Skip to content

Commit

Permalink
Merge pull request #47 from tlycken/scaling
Browse files Browse the repository at this point in the history
RFC: Scaling of interpolation objects (fixes #25)
  • Loading branch information
Tomas Lycken committed Sep 21, 2015
2 parents 3d73d1d + d3c7a31 commit 7646fe9
Show file tree
Hide file tree
Showing 20 changed files with 324 additions and 46 deletions.
4 changes: 2 additions & 2 deletions perf/run_shootout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ make_knots(A) = ntuple(d->collect(linspace(1,size(A,d),size(A,d))), ndims(A))
make_xi(A) = ntuple(d->collect(linspace(2,size(A,d)-1,size(A,d)-2)), ndims(A))

## Interpolations and Grid
function evaluate_grid(itp::Union(Array,Interpolations.AbstractInterpolation,Grid.InterpGrid), A)
function evaluate_grid(itp::Union{Array,Interpolations.AbstractInterpolation,Grid.InterpGrid), A}
s = zero(eltype(itp)) + zero(eltype(itp))
for I in iterrange(itp)
s += itp[I]
Expand Down Expand Up @@ -58,7 +58,7 @@ function evaluate_grid(itp::Dierckx.Spline2D, A)
end

# Slow approach for Dierckx
function evaluate_grid_scalar(itp::Union(Dierckx.Spline1D,Dierckx.Spline2D), A)
function evaluate_grid_scalar(itp::Union{Dierckx.Spline1D,Dierckx.Spline2D), A}
T = eltype(A)
s = zero(T) + zero(T)
for I in iterrange(A)
Expand Down
21 changes: 17 additions & 4 deletions src/Interpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ export
interpolate,
interpolate!,
extrapolate,
scale,

gradient!,

AbstractInterpolation,
AbstractExtrapolation,

OnCell,
OnGrid,

Expand All @@ -22,21 +26,23 @@ export
# see the following files for further exports:
# b-splines/b-splines.jl
# extrapolation/extrapolation.jl
# scaling/scaling.jl

using WoodburyMatrices, Ratios, AxisAlgorithms

import Base: convert, size, getindex, gradient, promote_rule
import Base: convert, size, getindex, gradient, scale, promote_rule

abstract InterpolationType
immutable NoInterp <: InterpolationType end
abstract GridType
immutable OnGrid <: GridType end
immutable OnCell <: GridType end

typealias DimSpec{T} Union(T,Tuple{Vararg{Union(T,NoInterp)}},NoInterp)
typealias DimSpec{T} Union{T,Tuple{Vararg{Union{T,NoInterp}}},NoInterp}

abstract AbstractInterpolation{T,N,IT<:DimSpec{InterpolationType},GT<:DimSpec{GridType}} <: AbstractArray{T,N}
abstract AbstractExtrapolation{T,N,ITPT,IT,GT} <: AbstractInterpolation{T,N,IT,GT}
abstract AbstractInterpolationWrapper{T,N,ITPT,IT,GT} <: AbstractInterpolation{T,N,IT,GT}
abstract AbstractExtrapolation{T,N,ITPT,IT,GT} <: AbstractInterpolationWrapper{T,N,ITPT,IT,GT}

abstract BoundaryCondition
immutable Flat <: BoundaryCondition end
Expand All @@ -53,7 +59,13 @@ typealias Natural Line
# TODO: size might have to be faster?
size{T,N}(itp::AbstractInterpolation{T,N}) = ntuple(i->size(itp,i), N)::NTuple{N,Int}
size(exp::AbstractExtrapolation, d) = size(exp.itp, d)
itptype{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}) = IT
bounds{T,N}(itp::AbstractInterpolation{T,N}) = tuple(zip(lbounds(itp), ubounds(itp))...)
bounds{T,N}(itp::AbstractInterpolation{T,N}, d) = (lbound(itp,d),ubound(itp,d))
lbounds{T,N}(itp::AbstractInterpolation{T,N}) = ntuple(i->lbound(itp,i), N)::NTuple{N,T}
ubounds{T,N}(itp::AbstractInterpolation{T,N}) = ntuple(i->ubound(itp,i), N)::NTuple{N,T}
lbound{T,N}(itp::AbstractInterpolation{T,N}, d) = convert(T, 1)
ubound{T,N}(itp::AbstractInterpolation{T,N}, d) = convert(T, size(itp, d))
itptype{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}) = IT
gridtype{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}) = GT

@inline gradient{T,N}(itp::AbstractInterpolation{T,N}, xs...) = gradient!(Array(T,N), itp, xs...)
Expand All @@ -62,5 +74,6 @@ include("nointerp/nointerp.jl")
include("b-splines/b-splines.jl")
include("gridded/gridded.jl")
include("extrapolation/extrapolation.jl")
include("scaling/scaling.jl")

end # module
5 changes: 5 additions & 0 deletions src/b-splines/b-splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ iextract(t, d) = t.parameters[d]
padextract(pad::Integer, d) = pad
padextract(pad::Tuple{Vararg{Integer}}, d) = pad[d]

lbound{T,N,TCoefs,IT}(itp::BSplineInterpolation{T,N,TCoefs,IT,OnGrid}, d) = one(T)
ubound{T,N,TCoefs,IT}(itp::BSplineInterpolation{T,N,TCoefs,IT,OnGrid}, d) = convert(T, size(itp, d))
lbound{T,N,TCoefs,IT}(itp::BSplineInterpolation{T,N,TCoefs,IT,OnCell}, d) = convert(T, .5)
ubound{T,N,TCoefs,IT}(itp::BSplineInterpolation{T,N,TCoefs,IT,OnCell}, d) = convert(T, size(itp, d) + .5)

@generated function size{T,N,TCoefs,IT,GT,pad}(itp::BSplineInterpolation{T,N,TCoefs,IT,GT,pad}, d)
quote
d <= $N ? size(itp.coefs, d) - 2*padextract($pad, d) : 1
Expand Down
2 changes: 1 addition & 1 deletion src/b-splines/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function gradient_impl{T,N,TCoefs,IT<:DimSpec{BSpline},GT<:DimSpec{GridType},Pad
gradient_exprs = Expr(:block, exs...)
quote
$meta
length(g) == $n || throw(DimensionMismatch("Gradient has wrong number of components"))
length(g) == $n || throw(ArgumentError(string("The length of the provided gradient vector (", length(g), ") did not match the number of interpolating dimensions (", n, ")")))
@nexprs $N d->(x_d = xs[d])

# Calculate the indices of all coefficients that will be used
Expand Down
4 changes: 2 additions & 2 deletions src/b-splines/prefiltering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function prefilter{TWeights,TCoefs,TSrc,N,IT<:Quadratic,GT<:GridType}(
prefilter!(TWeights, ret, BSpline{IT}, GT), Pad
end

function prefilter{TWeights,TCoefs,TSrc,N,IT<:Tuple{Vararg{Union(BSpline,NoInterp)}},GT<:DimSpec{GridType}}(
function prefilter{TWeights,TCoefs,TSrc,N,IT<:Tuple{Vararg{Union{BSpline,NoInterp}}},GT<:DimSpec{GridType}}(
::Type{TWeights}, ::Type{TCoefs}, A::Array{TSrc,N}, ::Type{IT}, ::Type{GT}
)
ret, Pad = copy_with_padding(TCoefs,A, IT)
Expand All @@ -59,7 +59,7 @@ function prefilter!{TWeights,TCoefs,N,IT<:Quadratic,GT<:GridType}(
ret
end

function prefilter!{TWeights,TCoefs,N,IT<:Tuple{Vararg{Union(BSpline,NoInterp)}},GT<:DimSpec{GridType}}(
function prefilter!{TWeights,TCoefs,N,IT<:Tuple{Vararg{Union{BSpline,NoInterp}}},GT<:DimSpec{GridType}}(
::Type{TWeights}, ret::Array{TCoefs,N}, ::Type{IT}, ::Type{GT}
)
local buf, shape, retrs
Expand Down
6 changes: 3 additions & 3 deletions src/b-splines/quadratic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function define_indices_d(::Type{BSpline{Quadratic{Periodic}}}, d, pad)
$symixm = mod1($symix - 1, size(itp,$d))
end
end
function define_indices_d{BC<:Union(InPlace,InPlaceQ)}(::Type{BSpline{Quadratic{BC}}}, d, pad)
function define_indices_d{BC<:Union{InPlace,InPlaceQ}}(::Type{BSpline{Quadratic{BC}}}, d, pad)
symix, symixm, symixp = symbol("ix_",d), symbol("ixm_",d), symbol("ixp_",d)
symx, symfx = symbol("x_",d), symbol("fx_",d)
pad == 0 || error("Use $BC only with interpolate!")
Expand Down Expand Up @@ -83,7 +83,7 @@ function inner_system_diags{T,Q<:Quadratic}(::Type{T}, n::Int, ::Type{Q})
(dl,d,du)
end

function prefiltering_system{T,TCoefs,BC<:Union(Flat,Reflect)}(::Type{T}, ::Type{TCoefs}, n::Int, ::Type{Quadratic{BC}}, ::Type{OnCell})
function prefiltering_system{T,TCoefs,BC<:Union{Flat,Reflect}}(::Type{T}, ::Type{TCoefs}, n::Int, ::Type{Quadratic{BC}}, ::Type{OnCell})
dl,d,du = inner_system_diags(T,n,Quadratic{BC})
d[1] = d[end] = -1
du[1] = dl[end] = 1
Expand Down Expand Up @@ -111,7 +111,7 @@ function prefiltering_system{T,TCoefs}(::Type{T}, ::Type{TCoefs}, n::Int, ::Type
Woodbury(lufact!(Tridiagonal(dl, d, du), Val{false}), rowspec, valspec, colspec), zeros(TCoefs, n)
end

function prefiltering_system{T,TCoefs,BC<:Union(Flat,Reflect)}(::Type{T}, ::Type{TCoefs}, n::Int, ::Type{Quadratic{BC}}, ::Type{OnGrid})
function prefiltering_system{T,TCoefs,BC<:Union{Flat,Reflect}}(::Type{T}, ::Type{TCoefs}, n::Int, ::Type{Quadratic{BC}}, ::Type{OnGrid})
dl,d,du = inner_system_diags(T,n,Quadratic{BC})
d[1] = d[end] = -1
du[1] = dl[end] = 0
Expand Down
17 changes: 7 additions & 10 deletions src/extrapolation/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,12 @@ ConstantExtrapolation{T,ITP,IT,GT}(::Type{T}, N, itp::ITP, ::Type{IT}, ::Type{GT
extrapolate{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}, ::Type{Flat}) =
ConstantExtrapolation(T,N,itp,IT,GT)

function extrap_prep{T,ITP,IT}(exp::Type{ConstantExtrapolation{T,1,ITP,IT,OnGrid}}, x)
:(x = clamp(x, 1, size(exp,1)))
function extrap_prep{T,ITP,IT,GT}(etp::Type{ConstantExtrapolation{T,1,ITP,IT,GT}}, x)
:(x = clamp(x, lbound(etp,1), ubound(etp,1)))
end
function extrap_prep{T,ITP,IT}(exp::Type{ConstantExtrapolation{T,1,ITP,IT,OnCell}}, x)
:(x = clamp(x, .5, size(exp,1)+.5))
end
function extrap_prep{T,N,ITP,IT}(exp::Type{ConstantExtrapolation{T,N,ITP,IT,OnGrid}}, xs...)
:(@nexprs $N d->(xs[d] = clamp(xs[d], 1, size(exp,d))))
end
function extrap_prep{T,N,ITP,IT}(exp::Type{ConstantExtrapolation{T,N,ITP,IT,OnCell}}, xs...)
:(@nexprs $N d->(xs[d] = clamp(xs[d], .5, size(exp,d)+.5)))
function extrap_prep{T,N,ITP,IT,GT}(etp::Type{ConstantExtrapolation{T,N,ITP,IT,GT}}, xs...)
:(@nexprs $N d->(xs[d] = clamp(xs[d], lbound(etp,d), ubound(etp,d))))
end

lbound(etp::ConstantExtrapolation, d) = lbound(etp.itp, d)
ubound(etp::ConstantExtrapolation, d) = ubound(etp.itp, d)
8 changes: 5 additions & 3 deletions src/extrapolation/error.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ ErrorExtrapolation{T,ITPT,IT,GT}(::Type{T}, N, itp::ITPT, ::Type{IT}, ::Type{GT}
extrapolate{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}, ::Type{Throw}) =
ErrorExtrapolation(T,N,itp,IT,GT)


function extrap_prep{T,N,ITPT,IT}(exp::Type{ErrorExtrapolation{T,N,ITPT,IT,OnGrid}}, xs...)
:(@nexprs $N d->(@show 1 <= xs[d] <= size(exp,d) || throw(BoundsError())))
function extrap_prep{T,N,ITPT,IT,GT}(etp::Type{ErrorExtrapolation{T,N,ITPT,IT,GT}}, xs...)
:(@nexprs $N d->(lbound(etp,d) <= xs[d] <= ubound(etp,d) || throw(BoundsError())))
end

lbound(etp::ErrorExtrapolation, d) = lbound(etp.itp, d)
ubound(etp::ErrorExtrapolation, d) = ubound(etp.itp, d)
2 changes: 1 addition & 1 deletion src/extrapolation/extrapolation.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export Throw,
FilledInterpolation # for direct control over typeof(fillvalue)
FilledExtrapolation # for direct control over typeof(fillvalue)

include("error.jl")

Expand Down
27 changes: 15 additions & 12 deletions src/extrapolation/filled.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
nindexes(N::Int) = N == 1 ? "1 index" : "$N indexes"


type FilledInterpolation{T,N,ITP<:AbstractInterpolation,IT,GT,FT} <: AbstractExtrapolation{T,N,ITP,IT,GT}
type FilledExtrapolation{T,N,ITP<:AbstractInterpolation,IT,GT,FT} <: AbstractExtrapolation{T,N,ITP,IT,GT}
itp::ITP
fillvalue::FT
end
@doc """
`FilledInterpolation(itp, fillvalue)` creates an extrapolation object that returns the `fillvalue` any time the indexes in `itp[x1,x2,...]` are out-of-bounds.
"""
`FilledExtrapolation(itp, fillvalue)` creates an extrapolation object that returns the `fillvalue` any time the indexes in `itp[x1,x2,...]` are out-of-bounds.
By comparison with `extrapolate`, this version lets you control the `fillvalue`'s type directly. It's important for the `fillvalue` to be of the same type as returned by `itp[x1,x2,...]` for in-bounds regions for the index types you are using; otherwise, indexing will be type-unstable (and slow).
""" ->
function FilledInterpolation{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}, fillvalue)
FilledInterpolation{T,N,typeof(itp),IT,GT,typeof(fillvalue)}(itp, fillvalue)
"""
function FilledExtrapolation{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}, fillvalue)
FilledExtrapolation{T,N,typeof(itp),IT,GT,typeof(fillvalue)}(itp, fillvalue)
end

@doc """
"""
`extrapolate(itp, fillvalue)` creates an extrapolation object that returns the `fillvalue` any time the indexes in `itp[x1,x2,...]` are out-of-bounds.
""" ->
extrapolate{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}, fillvalue) = FilledInterpolation(itp, convert(eltype(itp), fillvalue))
"""
extrapolate{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}, fillvalue) = FilledExtrapolation(itp, convert(eltype(itp), fillvalue))

@generated function getindex{T,N}(fitp::FilledInterpolation{T,N}, args::Number...)
@generated function getindex{T,N}(fitp::FilledExtrapolation{T,N}, args::Number...)
n = length(args)
n == N || return error("Must index $(N)-dimensional interpolation objects with $(nindexes(N))")
meta = Expr(:meta, :inline)
quote
$meta
# Check to see if we're in the extrapolation region, i.e.,
# out-of-bounds in an index
@nexprs $N d->((args[d] < 1 || args[d] > size(fitp.itp, d)) && return fitp.fillvalue)
@nexprs $N d->((args[d] < lbound(fitp,d) || args[d] > ubound(fitp, d)) && return fitp.fillvalue)
# In the interpolation region
return getindex(fitp.itp,args...)
end
end

getindex{T}(fitp::FilledInterpolation{T,1}, x::Number, y::Int) = y == 1 ? fitp[x] : throw(BoundsError())
getindex{T}(fitp::FilledExtrapolation{T,1}, x::Number, y::Int) = y == 1 ? fitp[x] : throw(BoundsError())

lbound(etp::FilledExtrapolation, d) = lbound(etp.itp, d)
ubound(etp::FilledExtrapolation, d) = ubound(etp.itp, d)
12 changes: 6 additions & 6 deletions src/extrapolation/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
@generated function getindex{T}(exp::AbstractExtrapolation{T,1}, x)
@generated function getindex{T}(etp::AbstractExtrapolation{T,1}, x)
quote
$(extrap_prep(exp, x))
exp.itp[x]
$(extrap_prep(etp, x))
etp.itp[x]
end
end

@generated function getindex{T,N,ITP,GT}(exp::AbstractExtrapolation{T,N,ITP,GT}, xs...)
@generated function getindex{T,N,ITP,GT}(etp::AbstractExtrapolation{T,N,ITP,GT}, xs...)
quote
$(extrap_prep(exp, xs...))
exp.itp[xs...]
$(extrap_prep(etp, xs...))
etp.itp[xs...]
end
end
2 changes: 1 addition & 1 deletion src/gridded/gridded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Gridded{D<:Degree}(::Type{D}) = Gridded{D}

griddedtype{D<:Degree}(::Type{Gridded{D}}) = D

typealias GridIndex{T} Union(AbstractVector{T}, Tuple)
typealias GridIndex{T} Union{AbstractVector{T}, Tuple}

# Because Ranges check bounds on getindex, it's actually faster to convert the
# knots to Vectors. It's also good to take a copy, so it doesn't get modified later.
Expand Down
90 changes: 90 additions & 0 deletions src/scaling/scaling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
export ScaledInterpolation

type ScaledInterpolation{T,N,ITPT,IT,GT,RT} <: AbstractInterpolationWrapper{T,N,ITPT,IT,GT}
itp::ITPT
ranges::RT
end
ScaledInterpolation{T,ITPT,IT,GT,RT}(::Type{T}, N, itp::ITPT, ::Type{IT}, ::Type{GT}, ranges::RT) =
ScaledInterpolation{T,N,ITPT,IT,GT,RT}(itp, ranges)
"""
`scale(itp, xs, ys, ...)` scales an existing interpolation object to allow for indexing using other coordinate axes than unit ranges, by wrapping the interpolation object and transforming the indices from the provided axes onto unit ranges upon indexing.
The parameters `xs` etc must be either ranges or linspaces, and there must be one coordinate range/linspace for each dimension of the interpolation object.
For every `NoInterp` dimension of the interpolation object, the range must be exactly `1:size(itp, d)`.
"""
function scale{T,N,IT,GT}(itp::AbstractInterpolation{T,N,IT,GT}, ranges::Range...)
length(ranges) == N || throw(ArgumentError("Must scale $N-dimensional interpolation object with exactly $N ranges (you used $(length(ranges)))"))
for d in 1:N
if iextract(IT,d) != NoInterp
length(ranges[d]) == size(itp,d) || throw(ArgumentError("The length of the range in dimension $d ($(length(ranges[d]))) did not equal the size of the interpolation object in that direction ($(size(itp,d)))"))
elseif ranges[d] != 1:size(itp,d)
throw(ArgumentError("NoInterp dimension $d must be scaled with unit range 1:$(size(itp,d))"))
end
end

ScaledInterpolation(T,N,itp,IT,GT,ranges)
end

@generated function getindex{T,N,ITPT,IT<:DimSpec}(sitp::ScaledInterpolation{T,N,ITPT,IT}, xs::Number...)
length(xs) == N || throw(ArgumentError("Must index into $N-dimensional scaled interpolation object with exactly $N indices (you used $(length(xs)))"))
interp_types = length(IT.parameters) == N ? IT.parameters : tuple([IT.parameters[1] for _ in 1:N]...)
interp_dimens = map(it -> interp_types[it] != NoInterp, 1:N)
interp_indices = map(i -> interp_dimens[i] ? :(coordlookup(sitp.ranges[$i], xs[$i])) : :(xs[$i]), 1:N)
return :(getindex(sitp.itp, $(interp_indices...)))
end

getindex{T}(sitp::ScaledInterpolation{T,1}, x::Number, y::Int) = y == 1 ? sitp[x] : throw(BoundsError())

size(sitp::ScaledInterpolation, d) = size(sitp.itp, d)
lbound{T,N,ITPT,IT}(sitp::ScaledInterpolation{T,N,ITPT,IT,OnGrid}, d) = 1 <= d <= N ? sitp.ranges[d][1] : throw(BoundsError())
lbound{T,N,ITPT,IT}(sitp::ScaledInterpolation{T,N,ITPT,IT,OnCell}, d) = 1 <= d <= N ? sitp.ranges[d][1] - boundstep(sitp.ranges[d]) : throw(BoundsError())
ubound{T,N,ITPT,IT}(sitp::ScaledInterpolation{T,N,ITPT,IT,OnGrid}, d) = 1 <= d <= N ? sitp.ranges[d][end] : throw(BoundsError())
ubound{T,N,ITPT,IT}(sitp::ScaledInterpolation{T,N,ITPT,IT,OnCell}, d) = 1 <= d <= N ? sitp.ranges[d][end] + boundstep(sitp.ranges[d]) : throw(BoundsError())

boundstep(r::LinSpace) = ((r.stop - r.start) / r.divisor) / 2
boundstep(r::FloatRange) = r.step / 2
boundstep(r::StepRange) = r.step / 2
boundstep(r::UnitRange) = 1//2

"""
Returns *half* the width of one step of the range.
This function is used to calculate the upper and lower bounds of `OnCell` interpolation objects.
""" boundstep

coordlookup(r::LinSpace, x) = (r.divisor * x + r.stop - r.len * r.start) / (r.stop - r.start)
coordlookup(r::FloatRange, x) = (r.divisor * x - r.start) / r.step + one(eltype(r))
coordlookup(r::StepRange, x) = (x - r.start) / r.step + one(eltype(r))
coordlookup(r::UnitRange, x) = x - r.start + one(eltype(r))
coordlookup(i::Bool, r::Range, x) = i ? coordlookup(r, x) : convert(typeof(coordlookup(r,x)), x)

gradient{T,N,ITPT,IT<:DimSpec}(sitp::ScaledInterpolation{T,N,ITPT,IT}, xs::Number...) = gradient!(Array(T,count_interp_dims(IT,N)), sitp, xs...)
@generated function gradient!{T,N,ITPT,IT}(g, sitp::ScaledInterpolation{T,N,ITPT,IT}, xs::Number...)
ndims(g) == 1 || throw(DimensionMismatch("g must be a vector (but had $(ndims(g)) dimensions)"))
length(xs) == N || throw(DimensionMismatch("Must index into $N-dimensional scaled interpolation object with exactly $N indices (you used $(length(xs)))"))

interp_types = length(IT.parameters) == N ? IT.parameters : tuple([IT.parameters[1] for _ in 1:N]...)
interp_dimens = map(it -> interp_types[it] != NoInterp, 1:N)
interp_indices = map(i -> interp_dimens[i] ? :(coordlookup(sitp.ranges[$i], xs[$i])) : :(xs[$i]), 1:N)

quote
length(g) == $(count_interp_dims(IT, N)) || throw(ArgumentError(string("The length of the provided gradient vector (", length(g), ") did not match the number of interpolating dimensions (", $(count_interp_dims(IT, N)), ")")))
gradient!(g, sitp.itp, $(interp_indices...))
for i in eachindex(g)
g[i] = rescale_gradient(sitp.ranges[i], g[i])
end
g
end
end

rescale_gradient(r::LinSpace, g) = g * r.divisor / (r.stop - r.start)
rescale_gradient(r::FloatRange, g) = g * r.divisor / r.step
rescale_gradient(r::StepRange, g) = g / r.step
rescale_gradient(r::UnitRange, g) = g

"""
`rescale_gradient(r::Range)`
Implements the chain rule dy/dx = dy/du * du/dx for use when calculating gradients with scaled interpolation objects.
""" rescale_gradient
2 changes: 1 addition & 1 deletion test/extrapolation/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ etpf = @inferred(extrapolate(itpg, NaN))
@test_throws BoundsError etpf[2.5,2]
@test_throws ErrorException etpf[2.5,2,1] # this will probably become a BoundsError someday

etpf = @inferred(FilledInterpolation(itpg, 'x'))
etpf = @inferred(FilledExtrapolation(itpg, 'x'))
@test_approx_eq etpf[2] f(2)
@test etpf[-1.5] == 'x'

Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ include("b-splines/runtests.jl")
# extrapolation tests
include("extrapolation/runtests.jl")

# scaling tests
include("scaling/runtests.jl")

# # test gradient evaluation
include("gradient.jl")

Expand Down

0 comments on commit 7646fe9

Please sign in to comment.