Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gradient evaluation #18

Merged
merged 4 commits into from
Jan 2, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/Interpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@ module Interpolations
using Base.Cartesian
using Compat

import Base: size, eltype, getindex, ndims
import Base:
eltype,
gradient,
getindex,
ndims,
size

export
Interpolation,
Expand Down Expand Up @@ -152,9 +157,31 @@ for IT in (
ret
end
))

eval(ngenerate(
:N,
:(Array{promote_type(T,typeof(x)...),1}),
:(gradient!{T,N}(g::Array{T,1}, itp::Interpolation{T,N,$IT,$EB}, x::NTuple{N,Real}...)),
N->quote
$(extrap_transform_x(gr,eb,N))
$(define_indices(it,N))
@nexprs $N dim->begin
@nexprs $N d->begin
(d==dim
? $(gradient_coefficients(it,N,:d))
: $(coefficients(it,N,:d)))
end

@inbounds g[dim] = $(index_gen(degree(it),N))
end
g
end
))
end
end

gradient{T}(itp::Interpolation{T}, x...) = gradient!(Array(T,ndims(itp)), itp, x...)

# This creates prefilter specializations for all interpolation types that need them
for IT in (
Quadratic{Flat,OnCell},
Expand Down
14 changes: 12 additions & 2 deletions src/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,18 @@ function define_indices(::Constant, N)
:(@nexprs $N d->(ix_d = clamp(round(Int,x_d), 1, size(itp,d))))
end

function coefficients(::Constant, N)
:(@nexprs $N d->(c_d = one(typeof(x_d))))
function coefficients(c::Constant, N)
:(@nexprs $N d->($(coefficients(c, N, :d))))
end

function coefficients(::Constant, N, d)
sym, symx = symbol(string("c_",d)), symbol(string("x_",d))
:($sym = one(typeof($symx)))
end

function gradient_coefficients(::Constant, N, d)
sym, symx = symbol(string("c_",d)), symbol(string("x_",d))
:($sym = zero(typeof($symx)))
end

function index_gen(degree::ConstantDegree, N::Integer, offsets...)
Expand Down
21 changes: 16 additions & 5 deletions src/linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,23 @@ function define_indices(::Linear, N)
end
end

function coefficients(::Linear, N)
function coefficients(l::Linear, N)
:(@nexprs $N d->($(coefficients(l, N, :d))))
end

function coefficients(::Linear, N, d)
sym, symp, symfx = symbol(string("c_",d)), symbol(string("cp_",d)), symbol(string("fx_",d))
quote
@nexprs $N d->begin
c_d = one(typeof(fx_d)) - fx_d
cp_d = fx_d
end
$sym = one(typeof($symfx)) - $symfx
$symp = $symfx
end
end

function gradient_coefficients(::Linear,N,d)
sym, symp, symfx = symbol(string("c_",d)), symbol(string("cp_",d)), symbol(string("fx_",d))
quote
$sym = -one(typeof($symfx))
$symp = one(typeof($symfx))
end
end

Expand Down
26 changes: 20 additions & 6 deletions src/quadratic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,27 @@ function define_indices(q::Quadratic{Periodic}, N)
end
end

function coefficients(::Quadratic, N)
function coefficients(q::Quadratic, N)
:(@nexprs $N d->($(coefficients(q, N, :d))))
end

function coefficients(q::Quadratic, N, d)
symm, sym, symp = symbol(string("cm_",d)), symbol(string("c_",d)), symbol(string("cp_",d))
symfx = symbol(string("fx_",d))
quote
@nexprs $N d->begin
cm_d = .5 * (fx_d-.5)^2
c_d = .75 - fx_d^2
cp_d = .5 * (fx_d+.5)^2
end
$symm = .5 * ($symfx - .5)^2
$sym = .75 - $symfx^2
$symp = .5 * ($symfx + .5)^2
end
end

function gradient_coefficients(q::Quadratic, N, d)
symm, sym, symp = symbol(string("cm_",d)), symbol(string("c_",d)), symbol(string("cp_",d))
symfx = symbol(string("fx_",d))
quote
$symm = $symfx-.5
$sym = -2*$symfx
$symp = $symfx+.5
end
end

Expand Down
30 changes: 30 additions & 0 deletions test/gradient.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
module GradientTests
println("Testing gradient evaluation")
using Base.Test, Interpolations

nx = 10
f1(x) = sin((x-3)*2pi/(nx-1) - 1)
g1(x) = 2pi/(nx-1) * cos((x-3)*2pi/(nx-1) - 1)

# Gradient of Constant should always be 0
itp1 = Interpolation(Float64[f1(x) for x in 1:nx-1],
Constant(OnGrid()), ExtrapPeriodic())
for x in 1:nx
@test gradient(itp1, x)[1] == 0
end

# Since Linear is OnGrid in the domain, check the gradients between grid points
itp1 = Interpolation(Float64[f1(x) for x in 1:nx-1],
Linear(OnGrid()), ExtrapPeriodic())
for x in 2.5:nx-1.5
@test_approx_eq_eps g1(x) gradient(itp1, x)[1] abs(.1*g1(x))
end

# Since Quadratic is OnCell in the domain, check gradients at grid points
itp1 = Interpolation(Float64[f1(x) for x in 1:nx-1],
Quadratic(Periodic(),OnGrid()), ExtrapPeriodic())
for x in 2:nx-1
@test_approx_eq_eps g1(x) gradient(itp1, x)[1] abs(.05*g1(x))
end

end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ include("quadratic.jl")
# indices inbounds in A.
include("on-grid.jl")

# test gradient evaluation
include("gradient.jl")

# Tests copied from Grid.jl's old test suite
#include("grid.jl")

Expand Down