diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index fe49d9f6..aabda80f 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -1,10 +1,10 @@ module DataInterpolationsChainRulesCoreExt using DataInterpolations: _interpolate, derivative, AbstractInterpolation, - LinearInterpolation, QuadraticInterpolation, - LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox, get_idx, get_parameters, - munge_data + LinearInterpolation, QuadraticInterpolation, + LagrangeInterpolation, AkimaInterpolation, + BSplineInterpolation, BSplineApprox, get_idx, get_parameters, + munge_data using ChainRulesCore function ChainRulesCore.rrule(::typeof(munge_data), u, t) diff --git a/ext/DataInterpolationsMakieExt.jl b/ext/DataInterpolationsMakieExt.jl index 2581781b..1598cae0 100644 --- a/ext/DataInterpolationsMakieExt.jl +++ b/ext/DataInterpolationsMakieExt.jl @@ -8,15 +8,19 @@ using Makie Makie.plottype(::AbstractInterpolation) = Makie.ScatterLines # Define the attributes that you want to use -Makie.used_attributes(::Makie.PointBased, ::AbstractInterpolation) = (:plotdensity, :denseplot) -Makie.used_attributes(::Type{<:Makie.ScatterLines}, ::AbstractInterpolation) = (:plotdensity, :denseplot) +function Makie.used_attributes(::Makie.PointBased, ::AbstractInterpolation) + (:plotdensity, :denseplot) +end +function Makie.used_attributes(::Type{<:Makie.ScatterLines}, ::AbstractInterpolation) + (:plotdensity, :denseplot) +end # Define the conversion of the data to the plot function Makie.convert_arguments( - ::Makie.PointBased, - A::AbstractInterpolation; - plotdensity = 10_000, - denseplot = true, + ::Makie.PointBased, + A::AbstractInterpolation; + plotdensity = 10_000, + denseplot = true ) DataInterpolations.to_plottable(A; plotdensity = plotdensity, denseplot = denseplot) end @@ -26,16 +30,17 @@ end # and should actually be handled by a plot! method, # except that doesn't work anymore or does it? function Makie.convert_arguments( - ::Type{<:Makie.ScatterLines}, - A::AbstractInterpolation; - plotdensity = 10_000, - denseplot = true, + ::Type{<:Makie.ScatterLines}, + A::AbstractInterpolation; + plotdensity = 10_000, + denseplot = true ) - densex, densey = convert_arguments(Makie.PointBased(), A; plotdensity = plotdensity, denseplot = denseplot) + densex, + densey = convert_arguments(Makie.PointBased(), A; plotdensity = plotdensity, denseplot = denseplot) return [ Makie.SpecApi.Lines(densex, densey), Makie.SpecApi.Scatter(A.t, A.u) ] end -end # module \ No newline at end of file +end # module diff --git a/ext/DataInterpolationsSparseConnectivityTracerExt.jl b/ext/DataInterpolationsSparseConnectivityTracerExt.jl index 7f1df97b..d79e3741 100644 --- a/ext/DataInterpolationsSparseConnectivityTracerExt.jl +++ b/ext/DataInterpolationsSparseConnectivityTracerExt.jl @@ -5,24 +5,22 @@ using SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1 using SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1 using FillArrays: Fill # from FillArrays.jl using DataInterpolations: - AbstractInterpolation, - LinearInterpolation, - QuadraticInterpolation, - LagrangeInterpolation, - AkimaInterpolation, - ConstantInterpolation, - QuadraticSpline, - CubicSpline, - BSplineInterpolation, - BSplineApprox, - CubicHermiteSpline, - # PCHIPInterpolation, - QuinticHermiteSpline, - output_size + AbstractInterpolation, + LinearInterpolation, + QuadraticInterpolation, + LagrangeInterpolation, + AkimaInterpolation, + ConstantInterpolation, + QuadraticSpline, + CubicSpline, + BSplineInterpolation, + BSplineApprox, + CubicHermiteSpline, +# PCHIPInterpolation, + QuinticHermiteSpline, + output_size#===========##===========# -#===========# # Utilities # -#===========# # Limit support to `u` begin an AbstractVector{<:Number} or AbstractMatrix{<:Number}, # to avoid any cases where the output size is dependent on the input value. @@ -33,8 +31,8 @@ function _sct_interpolate( uType::Type{<:AbstractVector{<:Number}}, t::GradientTracer, is_der_1_zero, - is_der_2_zero, - ) + is_der_2_zero +) return gradient_tracer_1_to_1(t, is_der_1_zero) end function _sct_interpolate( @@ -42,8 +40,8 @@ function _sct_interpolate( uType::Type{<:AbstractVector{<:Number}}, t::HessianTracer, is_der_1_zero, - is_der_2_zero, - ) + is_der_2_zero +) return hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero) end function _sct_interpolate( @@ -51,8 +49,8 @@ function _sct_interpolate( uType::Type{<:AbstractMatrix{<:Number}}, t::GradientTracer, is_der_1_zero, - is_der_2_zero, - ) + is_der_2_zero +) t = gradient_tracer_1_to_1(t, is_der_1_zero) N = only(output_size(interp)) return Fill(t, N) @@ -62,48 +60,46 @@ function _sct_interpolate( uType::Type{<:AbstractMatrix{<:Number}}, t::HessianTracer, is_der_1_zero, - is_der_2_zero, - ) + is_der_2_zero +) t = hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero) N = only(output_size(interp)) return Fill(t, N) -end +end#===========##===========# -#===========# # Overloads # -#===========# # We assume that with the exception of ConstantInterpolation and LinearInterpolation, # all interpolations have a non-zero second derivative at some point in the input domain. for (I, is_der1_zero, is_der2_zero) in ( - (:ConstantInterpolation, true, true), - (:LinearInterpolation, false, true), - (:QuadraticInterpolation, false, false), - (:LagrangeInterpolation, false, false), - (:AkimaInterpolation, false, false), - (:QuadraticSpline, false, false), - (:CubicSpline, false, false), - (:BSplineInterpolation, false, false), - (:BSplineApprox, false, false), - (:CubicHermiteSpline, false, false), - (:QuinticHermiteSpline, false, false), - ) + (:ConstantInterpolation, true, true), + (:LinearInterpolation, false, true), + (:QuadraticInterpolation, false, false), + (:LagrangeInterpolation, false, false), + (:AkimaInterpolation, false, false), + (:QuadraticSpline, false, false), + (:CubicSpline, false, false), + (:BSplineInterpolation, false, false), + (:BSplineApprox, false, false), + (:CubicHermiteSpline, false, false), + (:QuinticHermiteSpline, false, false) +) @eval function (interp::$(I){uType})( t::AbstractTracer - ) where {uType <: AbstractArray{<:Number}} + ) where {uType <: AbstractArray{<:Number}} return _sct_interpolate(interp, uType, t, $is_der1_zero, $is_der2_zero) end end # Some Interpolations require custom overloads on `Dual` due to mutation of caches. for I in ( - :LagrangeInterpolation, - :BSplineInterpolation, - :BSplineApprox, - :CubicHermiteSpline, - :QuinticHermiteSpline, - ) + :LagrangeInterpolation, + :BSplineInterpolation, + :BSplineApprox, + :CubicHermiteSpline, + :QuinticHermiteSpline +) @eval function (interp::$(I){uType})(d::Dual) where {uType <: AbstractVector} p = interp(primal(d)) t = interp(tracer(d)) diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index 47c0773b..97ff494e 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -1220,8 +1220,8 @@ function BSplineApprox( end for k in 2:(n - 1) q[ax_u..., - k] = u[ax_u..., k] - sc[k, 1] * u[ax_u..., 1] - - sc[k, h] * u[ax_u..., end] + k] = u[ax_u..., k] - sc[k, 1] * u[ax_u..., 1] - + sc[k, h] * u[ax_u..., end] end Q = Array{T, N}(undef, size(u)[1:(end - 1)]..., h - 2) for i in 2:(h - 1) diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index 081de086..016ba8dc 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -35,7 +35,8 @@ function test_derivatives(method; args = [], kwargs = [], name::String) # Interpolation transition points for _t in t[2:(end - 1)] - if func isa Union{SmoothedConstantInterpolation, BSplineInterpolation, BSplineApprox} + if func isa + Union{SmoothedConstantInterpolation, BSplineInterpolation, BSplineApprox} # TODO fix interpolations continue else diff --git a/test/interface.jl b/test/interface.jl index a1dad63e..f7222f5b 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -59,7 +59,7 @@ end @testset "Output Type" begin # Test consistency between eltype(u) and type of the output u = Float32[-0.676367f0, 0.8449812f0, 1.2366607f0, -0.13347931f0, 1.9928657f0, - -0.63596356f0, 0.76009744f0, -0.30632544f0, 0.34649512f0, -0.3846099f0] + -0.63596356f0, 0.76009744f0, -0.30632544f0, 0.34649512f0, -0.3846099f0] t = 0.1f0:0.1f0:1.0f0 for extrapolation_flag in instances(ExtrapolationType.T) (extrapolation_flag == ExtrapolationType.None) && continue