Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ext/DataInterpolationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module DataInterpolationsChainRulesCoreExt

using DataInterpolations: _interpolate, derivative, AbstractInterpolation,
LinearInterpolation, QuadraticInterpolation,
LagrangeInterpolation, AkimaInterpolation,
BSplineInterpolation, BSplineApprox, get_idx, get_parameters,
munge_data
LinearInterpolation, QuadraticInterpolation,
LagrangeInterpolation, AkimaInterpolation,
BSplineInterpolation, BSplineApprox, get_idx, get_parameters,
munge_data
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(munge_data), u, t)
Expand Down
29 changes: 17 additions & 12 deletions ext/DataInterpolationsMakieExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@ using Makie
Makie.plottype(::AbstractInterpolation) = Makie.ScatterLines

# Define the attributes that you want to use
Makie.used_attributes(::Makie.PointBased, ::AbstractInterpolation) = (:plotdensity, :denseplot)
Makie.used_attributes(::Type{<:Makie.ScatterLines}, ::AbstractInterpolation) = (:plotdensity, :denseplot)
function Makie.used_attributes(::Makie.PointBased, ::AbstractInterpolation)
(:plotdensity, :denseplot)
end
function Makie.used_attributes(::Type{<:Makie.ScatterLines}, ::AbstractInterpolation)
(:plotdensity, :denseplot)
end

# Define the conversion of the data to the plot
function Makie.convert_arguments(
::Makie.PointBased,
A::AbstractInterpolation;
plotdensity = 10_000,
denseplot = true,
::Makie.PointBased,
A::AbstractInterpolation;
plotdensity = 10_000,
denseplot = true
)
DataInterpolations.to_plottable(A; plotdensity = plotdensity, denseplot = denseplot)
end
Expand All @@ -26,16 +30,17 @@ end
# and should actually be handled by a plot! method,
# except that doesn't work anymore or does it?
function Makie.convert_arguments(
::Type{<:Makie.ScatterLines},
A::AbstractInterpolation;
plotdensity = 10_000,
denseplot = true,
::Type{<:Makie.ScatterLines},
A::AbstractInterpolation;
plotdensity = 10_000,
denseplot = true
)
densex, densey = convert_arguments(Makie.PointBased(), A; plotdensity = plotdensity, denseplot = denseplot)
densex,
densey = convert_arguments(Makie.PointBased(), A; plotdensity = plotdensity, denseplot = denseplot)
return [
Makie.SpecApi.Lines(densex, densey),
Makie.SpecApi.Scatter(A.t, A.u)
]
end

end # module
end # module
88 changes: 42 additions & 46 deletions ext/DataInterpolationsSparseConnectivityTracerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,22 @@ using SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1
using SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1
using FillArrays: Fill # from FillArrays.jl
using DataInterpolations:
AbstractInterpolation,
LinearInterpolation,
QuadraticInterpolation,
LagrangeInterpolation,
AkimaInterpolation,
ConstantInterpolation,
QuadraticSpline,
CubicSpline,
BSplineInterpolation,
BSplineApprox,
CubicHermiteSpline,
# PCHIPInterpolation,
QuinticHermiteSpline,
output_size
AbstractInterpolation,
LinearInterpolation,
QuadraticInterpolation,
LagrangeInterpolation,
AkimaInterpolation,
ConstantInterpolation,
QuadraticSpline,
CubicSpline,
BSplineInterpolation,
BSplineApprox,
CubicHermiteSpline,
# PCHIPInterpolation,
QuinticHermiteSpline,
output_size#===========##===========#

#===========#
# Utilities #
#===========#

# Limit support to `u` begin an AbstractVector{<:Number} or AbstractMatrix{<:Number},
# to avoid any cases where the output size is dependent on the input value.
Expand All @@ -33,26 +31,26 @@ function _sct_interpolate(
uType::Type{<:AbstractVector{<:Number}},
t::GradientTracer,
is_der_1_zero,
is_der_2_zero,
)
is_der_2_zero
)
return gradient_tracer_1_to_1(t, is_der_1_zero)
end
function _sct_interpolate(
::AbstractInterpolation,
uType::Type{<:AbstractVector{<:Number}},
t::HessianTracer,
is_der_1_zero,
is_der_2_zero,
)
is_der_2_zero
)
return hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero)
end
function _sct_interpolate(
interp::AbstractInterpolation,
uType::Type{<:AbstractMatrix{<:Number}},
t::GradientTracer,
is_der_1_zero,
is_der_2_zero,
)
is_der_2_zero
)
t = gradient_tracer_1_to_1(t, is_der_1_zero)
N = only(output_size(interp))
return Fill(t, N)
Expand All @@ -62,48 +60,46 @@ function _sct_interpolate(
uType::Type{<:AbstractMatrix{<:Number}},
t::HessianTracer,
is_der_1_zero,
is_der_2_zero,
)
is_der_2_zero
)
t = hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero)
N = only(output_size(interp))
return Fill(t, N)
end
end#===========##===========#

#===========#
# Overloads #
#===========#

# We assume that with the exception of ConstantInterpolation and LinearInterpolation,
# all interpolations have a non-zero second derivative at some point in the input domain.

for (I, is_der1_zero, is_der2_zero) in (
(:ConstantInterpolation, true, true),
(:LinearInterpolation, false, true),
(:QuadraticInterpolation, false, false),
(:LagrangeInterpolation, false, false),
(:AkimaInterpolation, false, false),
(:QuadraticSpline, false, false),
(:CubicSpline, false, false),
(:BSplineInterpolation, false, false),
(:BSplineApprox, false, false),
(:CubicHermiteSpline, false, false),
(:QuinticHermiteSpline, false, false),
)
(:ConstantInterpolation, true, true),
(:LinearInterpolation, false, true),
(:QuadraticInterpolation, false, false),
(:LagrangeInterpolation, false, false),
(:AkimaInterpolation, false, false),
(:QuadraticSpline, false, false),
(:CubicSpline, false, false),
(:BSplineInterpolation, false, false),
(:BSplineApprox, false, false),
(:CubicHermiteSpline, false, false),
(:QuinticHermiteSpline, false, false)
)
@eval function (interp::$(I){uType})(
t::AbstractTracer
) where {uType <: AbstractArray{<:Number}}
) where {uType <: AbstractArray{<:Number}}
return _sct_interpolate(interp, uType, t, $is_der1_zero, $is_der2_zero)
end
end

# Some Interpolations require custom overloads on `Dual` due to mutation of caches.
for I in (
:LagrangeInterpolation,
:BSplineInterpolation,
:BSplineApprox,
:CubicHermiteSpline,
:QuinticHermiteSpline,
)
:LagrangeInterpolation,
:BSplineInterpolation,
:BSplineApprox,
:CubicHermiteSpline,
:QuinticHermiteSpline
)
@eval function (interp::$(I){uType})(d::Dual) where {uType <: AbstractVector}
p = interp(primal(d))
t = interp(tracer(d))
Expand Down
4 changes: 2 additions & 2 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1220,8 +1220,8 @@ function BSplineApprox(
end
for k in 2:(n - 1)
q[ax_u...,
k] = u[ax_u..., k] - sc[k, 1] * u[ax_u..., 1] -
sc[k, h] * u[ax_u..., end]
k] = u[ax_u..., k] - sc[k, 1] * u[ax_u..., 1] -
sc[k, h] * u[ax_u..., end]
end
Q = Array{T, N}(undef, size(u)[1:(end - 1)]..., h - 2)
for i in 2:(h - 1)
Expand Down
3 changes: 2 additions & 1 deletion test/derivative_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ function test_derivatives(method; args = [], kwargs = [], name::String)

# Interpolation transition points
for _t in t[2:(end - 1)]
if func isa Union{SmoothedConstantInterpolation, BSplineInterpolation, BSplineApprox}
if func isa
Union{SmoothedConstantInterpolation, BSplineInterpolation, BSplineApprox}
# TODO fix interpolations
continue
else
Expand Down
2 changes: 1 addition & 1 deletion test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ end
@testset "Output Type" begin
# Test consistency between eltype(u) and type of the output
u = Float32[-0.676367f0, 0.8449812f0, 1.2366607f0, -0.13347931f0, 1.9928657f0,
-0.63596356f0, 0.76009744f0, -0.30632544f0, 0.34649512f0, -0.3846099f0]
-0.63596356f0, 0.76009744f0, -0.30632544f0, 0.34649512f0, -0.3846099f0]
t = 0.1f0:0.1f0:1.0f0
for extrapolation_flag in instances(ExtrapolationType.T)
(extrapolation_flag == ExtrapolationType.None) && continue
Expand Down
Loading