Skip to content

Commit

Permalink
Merge pull request #245 from sathvikbhagavan/sb/order
Browse files Browse the repository at this point in the history
feat: add order arg for second order derivatives
  • Loading branch information
ChrisRackauckas committed May 6, 2024
2 parents 05b307c + a5953ae commit 19fae1f
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 30 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Expand Up @@ -4,21 +4,21 @@ version = "4.7.2"

[deps]
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[extensions]
DataInterpolationsChainRulesCoreExt = "ChainRulesCore"
DataInterpolationsOptimExt = ["ForwardDiff", "Optim"]
DataInterpolationsOptimExt = "Optim"
DataInterpolationsRegularizationToolsExt = "RegularizationTools"
DataInterpolationsSymbolicsExt = "Symbolics"

Expand Down
8 changes: 5 additions & 3 deletions ext/DataInterpolationsOptimExt.jl
Expand Up @@ -4,7 +4,7 @@ using DataInterpolations
import DataInterpolations: munge_data,
Curvefit, CurvefitCache, _interpolate, get_show, derivative,
ExtrapolationError,
integral, IntegralNotFoundError
integral, IntegralNotFoundError, DerivativeNotFoundError

isdefined(Base, :get_extension) ? (using Optim, ForwardDiff) :
(using ..Optim, ..ForwardDiff)
Expand Down Expand Up @@ -49,9 +49,11 @@ function _interpolate(A::CurvefitCache{<:AbstractVector{<:Number}},
end

function derivative(A::CurvefitCache{<:AbstractVector{<:Number}},
t::Union{AbstractVector{<:Number}, Number})
t::Union{AbstractVector{<:Number}, Number}, order = 1)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
ForwardDiff.derivative(x -> A.m(x, A.pmin), t)
order > 2 && throw(DerivativeNotFoundError())
order == 1 && return ForwardDiff.derivative(x -> A.m(x, A.pmin), t)
return ForwardDiff.derivative(t -> ForwardDiff.derivative(x -> A.m(x, A.pmin), t), t)
end

function get_show(A::CurvefitCache)
Expand Down
4 changes: 2 additions & 2 deletions ext/DataInterpolationsRegularizationToolsExt.jl
Expand Up @@ -356,8 +356,8 @@ end
function derivative(A::RegularizationSmooth{
<:AbstractVector{<:Number},
},
t::Number)
derivative(A.Aitp, t)
t::Number, order = 1)
derivative(A.Aitp, t, order)
end

function get_show(A::RegularizationSmooth)
Expand Down
4 changes: 2 additions & 2 deletions ext/DataInterpolationsSymbolicsExt.jl
Expand Up @@ -16,8 +16,8 @@ end
SymbolicUtils.promote_symtype(t::AbstractInterpolation, _...) = Real
Base.nameof(interp::AbstractInterpolation) = :Interpolation

function derivative(interp::AbstractInterpolation, t::Num)
Symbolics.wrap(SymbolicUtils.term(derivative, interp, unwrap(t)))
function derivative(interp::AbstractInterpolation, t::Num, order = 1)
Symbolics.wrap(SymbolicUtils.term(derivative, interp, unwrap(t), order))
end
SymbolicUtils.promote_symtype(::typeof(derivative), _...) = Real

Expand Down
7 changes: 7 additions & 0 deletions src/DataInterpolations.jl
Expand Up @@ -17,6 +17,7 @@ end

using LinearAlgebra, RecipesBase
using PrettyTables
using ForwardDiff
import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated,
bracketstrictlymontonic

Expand Down Expand Up @@ -54,6 +55,12 @@ function Base.showerror(io::IO, e::IntegralNotFoundError)
print(io, INTEGRAL_NOT_FOUND_ERROR)
end

const DERIVATIVE_NOT_FOUND_ERROR = "Derivatives greater than second order is not supported."
struct DerivativeNotFoundError <: Exception end
function Base.showerror(io::IO, e::DerivativeNotFoundError)
print(io, DERIVATIVE_NOT_FOUND_ERROR)
end

export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation,
AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline,
BSplineInterpolation, BSplineApprox
Expand Down
38 changes: 20 additions & 18 deletions src/derivatives.jl
@@ -1,41 +1,43 @@
function derivative(A, t)
function derivative(A, t, order = 1)
order > 2 && throw(DerivativeNotFoundError())
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
derivative(A, t, firstindex(A.t) - 1)[1]
order == 1 && return _derivative(A, t, firstindex(A.t) - 1)[1]
return ForwardDiff.derivative(t -> _derivative(A, t, firstindex(A.t) - 1)[1], t)
end

function derivative(A::LinearInterpolation{<:AbstractVector}, t::Number, iguess)
function _derivative(A::LinearInterpolation{<:AbstractVector}, t::Number, iguess)
idx = searchsortedfirstcorrelated(A.t, t, iguess)
idx > length(A.t) ? idx -= 1 : nothing
idx -= 1
idx == 0 ? idx += 1 : nothing
(A.u[idx + 1] - A.u[idx]) / (A.t[idx + 1] - A.t[idx]), idx
end

function derivative(A::LinearInterpolation{<:AbstractMatrix}, t::Number, iguess)
function _derivative(A::LinearInterpolation{<:AbstractMatrix}, t::Number, iguess)
idx = searchsortedfirstcorrelated(A.t, t, iguess)
idx > length(A.t) ? idx -= 1 : nothing
idx -= 1
idx == 0 ? idx += 1 : nothing
(@views @. (A.u[:, idx + 1] - A.u[:, idx]) / (A.t[idx + 1] - A.t[idx])), idx
end

function derivative(A::QuadraticInterpolation{<:AbstractVector}, t::Number, iguess)
function _derivative(A::QuadraticInterpolation{<:AbstractVector}, t::Number, iguess)
i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess)
dl₀ = (2t - A.t[i₁] - A.t[i₂]) / ((A.t[i₀] - A.t[i₁]) * (A.t[i₀] - A.t[i₂]))
dl₁ = (2t - A.t[i₀] - A.t[i₂]) / ((A.t[i₁] - A.t[i₀]) * (A.t[i₁] - A.t[i₂]))
dl₂ = (2t - A.t[i₀] - A.t[i₁]) / ((A.t[i₂] - A.t[i₀]) * (A.t[i₂] - A.t[i₁]))
A.u[i₀] * dl₀ + A.u[i₁] * dl₁ + A.u[i₂] * dl₂, i₀
end

function derivative(A::QuadraticInterpolation{<:AbstractMatrix}, t::Number, iguess)
function _derivative(A::QuadraticInterpolation{<:AbstractMatrix}, t::Number, iguess)
i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess)
dl₀ = (2t - A.t[i₁] - A.t[i₂]) / ((A.t[i₀] - A.t[i₁]) * (A.t[i₀] - A.t[i₂]))
dl₁ = (2t - A.t[i₀] - A.t[i₂]) / ((A.t[i₁] - A.t[i₀]) * (A.t[i₁] - A.t[i₂]))
dl₂ = (2t - A.t[i₀] - A.t[i₁]) / ((A.t[i₂] - A.t[i₀]) * (A.t[i₂] - A.t[i₁]))
(@views @. A.u[:, i₀] * dl₀ + A.u[:, i₁] * dl₁ + A.u[:, i₂] * dl₂), i₀
end

function derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
function _derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
der = zero(A.u[1])
for j in eachindex(A.t)
Expand Down Expand Up @@ -69,7 +71,7 @@ function derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
der
end

function derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number)
function _derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
der = zero(A.u[:, 1])
for j in eachindex(A.t)
Expand Down Expand Up @@ -98,15 +100,15 @@ function derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number)
tmp += k
end
end
@. der += A.u[:, j] * tmp
der += A.u[:, j] * tmp
end
der
end

derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number, i) = derivative(A, t), i
derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number, i) = derivative(A, t), i
_derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number, i) = _derivative(A, t), i
_derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number, i) = _derivative(A, t), i

function derivative(A::AkimaInterpolation{<:AbstractVector}, t::Number, iguess)
function _derivative(A::AkimaInterpolation{<:AbstractVector}, t::Number, iguess)
i = searchsortedfirstcorrelated(A.t, t, iguess)
i > length(A.t) ? i -= 1 : nothing
i -= 1
Expand All @@ -116,18 +118,18 @@ function derivative(A::AkimaInterpolation{<:AbstractVector}, t::Number, iguess)
(@evalpoly wj A.b[i] 2A.c[j] 3A.d[j]), i
end

function derivative(A::ConstantInterpolation{<:AbstractVector}, t::Number)
function _derivative(A::ConstantInterpolation{<:AbstractVector}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
return isempty(searchsorted(A.t, t)) ? zero(A.u[1]) : eltype(A.u)(NaN)
end

function derivative(A::ConstantInterpolation{<:AbstractMatrix}, t::Number)
function _derivative(A::ConstantInterpolation{<:AbstractMatrix}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
return isempty(searchsorted(A.t, t)) ? zero(A.u[:, 1]) : eltype(A.u)(NaN) .* A.u[:, 1]
end

# QuadraticSpline Interpolation
function derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
idx = searchsortedfirstcorrelated(A.t, t, iguess)
idx > length(A.t) ? idx -= 1 : nothing
idx == 1 ? idx += 1 : nothing
Expand All @@ -136,7 +138,7 @@ function derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
end

# CubicSpline Interpolation
function derivative(A::CubicSpline{<:AbstractVector}, t::Number, iguess)
function _derivative(A::CubicSpline{<:AbstractVector}, t::Number, iguess)
i = searchsortedfirstcorrelated(A.t, t, iguess)
i > length(A.t) ? i -= 1 : nothing
i -= 1
Expand All @@ -148,7 +150,7 @@ function derivative(A::CubicSpline{<:AbstractVector}, t::Number, iguess)
dI + dC + dD, i
end

function derivative(A::BSplineInterpolation{<:AbstractVector{<:Number}}, t::Number, iguess)
function _derivative(A::BSplineInterpolation{<:AbstractVector{<:Number}}, t::Number, iguess)
# change t into param [0 1]
t < A.t[1] && return zero(A.u[1]), 1
t > A.t[end] && return zero(A.u[end]), lastindex(t)
Expand All @@ -170,7 +172,7 @@ function derivative(A::BSplineInterpolation{<:AbstractVector{<:Number}}, t::Numb
end

# BSpline Curve Approx
function derivative(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number, iguess)
function _derivative(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number, iguess)
# change t into param [0 1]
t < A.t[1] && return zero(A.u[1]), 1
t > A.t[end] && return zero(A.u[end]), lastindex(t)
Expand Down
27 changes: 24 additions & 3 deletions test/derivative_tests.jl
Expand Up @@ -17,32 +17,53 @@ function test_derivatives(method, u, t; args = [], kwargs = [], name::String)
cdiff = central_fdm(5, 1; geom = true)(func, _t)
adiff = derivative(func, _t)
@test isapprox(cdiff, adiff, atol = 1e-8)
adiff2 = derivative(func, _t, 2)
cdiff2 = central_fdm(5, 1; geom = true)(t -> derivative(func, t), _t)
@test isapprox(cdiff2, adiff2, atol = 1e-8)
end

# Interpolation time points
for _t in t[2:(end - 1)]
fdiff = if func isa BSplineInterpolation || func isa BSplineApprox
forward_fdm(5, 1; geom = true)(func, _t)
if func isa BSplineInterpolation || func isa BSplineApprox
fdiff = forward_fdm(5, 1; geom = true)(func, _t)
fdiff2 = forward_fdm(5, 1; geom = true)(t -> derivative(func, t), _t)
else
backward_fdm(5, 1; geom = true)(func, _t)
fdiff = backward_fdm(5, 1; geom = true)(func, _t)
fdiff2 = backward_fdm(5, 1; geom = true)(t -> derivative(func, t), _t)
end
adiff = derivative(func, _t)
adiff2 = derivative(func, _t, 2)
@test isapprox(fdiff, adiff, atol = 1e-8)
@test isapprox(fdiff2, adiff2, atol = 1e-8)
end

# t = t0
fdiff = forward_fdm(5, 1; geom = true)(func, t[1])
adiff = derivative(func, t[1])
@test isapprox(fdiff, adiff, atol = 1e-8)
if !(func isa BSplineInterpolation || func isa BSplineApprox)
fdiff2 = forward_fdm(5, 1; geom = true)(t -> derivative(func, t), t[1])
adiff2 = derivative(func, t[1], 2)
@test isapprox(fdiff2, adiff2, atol = 1e-8)
end

# t = tend
fdiff = backward_fdm(5, 1; geom = true)(func, t[end])
adiff = derivative(func, t[end])
@test isapprox(fdiff, adiff, atol = 1e-8)
if !(func isa BSplineInterpolation || func isa BSplineApprox)
fdiff2 = backward_fdm(5, 1; geom = true)(t -> derivative(func, t), t[end])
adiff2 = derivative(func, t[end], 2)
@test isapprox(fdiff2, adiff2, atol = 1e-8)
end
end
@test_throws DataInterpolations.DerivativeNotFoundError derivative(
func, t[1], 3)
func = method(u, t, args...)
@test_throws DataInterpolations.ExtrapolationError derivative(func, t[1] - 1.0)
@test_throws DataInterpolations.ExtrapolationError derivative(func, t[end] + 1.0)
@test_throws DataInterpolations.DerivativeNotFoundError derivative(
func, t[1], 3)
end

@testset "Linear Interpolation" begin
Expand Down

0 comments on commit 19fae1f

Please sign in to comment.