From 29456a009c26ec721e273bfd3f846c853c0b2526 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 17 Nov 2025 11:55:34 +0530 Subject: [PATCH 1/2] feat: update to Symbolics@7 --- Project.toml | 2 +- ext/DataInterpolationsSymbolicsExt.jl | 41 +++++++++++++++++++++------ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 7f38079d..efaea1d5 100644 --- a/Project.toml +++ b/Project.toml @@ -51,7 +51,7 @@ SafeTestsets = "0.1" SparseConnectivityTracer = "1" StableRNGs = "1" StaticArrays = "1.9" -Symbolics = "6.46" +Symbolics = "6.46, 7" Test = "1.10" Unitful = "1.21.1" Zygote = "0.6.77, 0.7" diff --git a/ext/DataInterpolationsSymbolicsExt.jl b/ext/DataInterpolationsSymbolicsExt.jl index 317856b7..04e91a68 100644 --- a/ext/DataInterpolationsSymbolicsExt.jl +++ b/ext/DataInterpolationsSymbolicsExt.jl @@ -8,17 +8,40 @@ using Symbolics: Num, unwrap, SymbolicUtils @register_symbolic (interp::AbstractInterpolation)(t) Base.nameof(interp::AbstractInterpolation) = :Interpolation -function derivative(interp::AbstractInterpolation, t::Num, order = 1) - Symbolics.wrap(SymbolicUtils.term(derivative, interp, unwrap(t), order)) -end -SymbolicUtils.promote_symtype(::typeof(derivative), _...) = Real +@static if pkgversion(Symbolics) >= v"7" + @register_symbolic derivative(interp::AbstractInterpolation, t, order::Integer) false + function SymbolicUtils.promote_symtype(::typeof(derivative), Ti::SymbolicUtils.TypeT, + Tt::SymbolicUtils.TypeT, + To::SymbolicUtils.TypeT) + @assert Ti <: AbstractInterpolation + @assert Tt <: Real + @assert To <: Integer + Real + end + function SymbolicUtils.promote_shape(::typeof(derivative), + @nospecialize(shi::SymbolicUtils.ShapeT), + @nospecialize(sht::SymbolicUtils.ShapeT), + @nospecialize(sho::SymbolicUtils.ShapeT)) + @assert !SymbolicUtils.is_array_shape(shi) + @assert !SymbolicUtils.is_array_shape(sht) + @assert !SymbolicUtils.is_array_shape(sho) + return SymbolicUtils.ShapeVecT() + end -function Symbolics.derivative(::typeof(derivative), args::NTuple{3, Any}, ::Val{2}) - Symbolics.unwrap(derivative(args[1], Symbolics.wrap(args[2]), args[3] + 1)) -end + @register_derivative derivative(interp, t, ord) 2 derivative(interp, t, ord + 1) + @register_derivative (interp::AbstractInterpolation)(t) 1 derivative(interp, t, 1) +else + function derivative(interp::AbstractInterpolation, t::Num, order = 1) + Symbolics.wrap(SymbolicUtils.term(derivative, interp, unwrap(t), order)) + end + SymbolicUtils.promote_symtype(::typeof(derivative), _...) = Real + function Symbolics.derivative(::typeof(derivative), args::NTuple{3, Any}, ::Val{2}) + Symbolics.unwrap(derivative(args[1], Symbolics.wrap(args[2]), args[3] + 1)) + end -function Symbolics.derivative(interp::AbstractInterpolation, args::NTuple{1, Any}, ::Val{1}) - Symbolics.unwrap(derivative(interp, Symbolics.wrap(args[1]))) + function Symbolics.derivative(interp::AbstractInterpolation, args::NTuple{1, Any}, ::Val{1}) + Symbolics.unwrap(derivative(interp, Symbolics.wrap(args[1]))) + end end end # module From 0e9903bb3fdd68f06ff4b1bcad162c42a4a0f79f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 17 Nov 2025 12:16:21 +0530 Subject: [PATCH 2/2] test: update tests to Symbolics@7 --- test/derivative_tests.jl | 4 ++-- test/interface.jl | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/derivative_tests.jl b/test/derivative_tests.jl index 9b35bf89..a5257ce7 100644 --- a/test/derivative_tests.jl +++ b/test/derivative_tests.jl @@ -335,8 +335,8 @@ end expr = A(ω) @test isequal(Symbolics.derivative(expr, τ), D(ω) * DataInterpolations.derivative(A, ω)) - derivexpr1 = expand_derivatives(substitute(D(A(ω)), Dict(ω => 0.5τ))) - derivexpr2 = expand_derivatives(substitute(D2(A(ω)), Dict(ω => 0.5τ))) + derivexpr1 = expand_derivatives(substitute(D(A(ω)), Dict(ω => 0.5τ); filterer = Returns(true))) + derivexpr2 = expand_derivatives(substitute(D2(A(ω)), Dict(ω => 0.5τ); filterer = Returns(true))) symfunc1 = Symbolics.build_function(derivexpr1, τ; expression = Val{false}) symfunc2 = Symbolics.build_function(derivexpr2, τ; expression = Val{false}) @test symfunc1(0.5) == 1.5 diff --git a/test/interface.jl b/test/interface.jl index f7222f5b..fa8f635b 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -23,9 +23,9 @@ end @variables t x(t) substitute(A(t), Dict(t => x)) t_val = 2.7 - @test substitute(A(t), Dict(t => t_val)) == A(t_val) - @test substitute(B(A(t)), Dict(t => t_val)) == B(A(t_val)) - @test substitute(A(B(A(t))), Dict(t => t_val)) == A(B(A(t_val))) + @test substitute(A(t), Dict(t => t_val); fold = Val(true)) == A(t_val) + @test substitute(B(A(t)), Dict(t => t_val); fold = Val(true)) == B(A(t_val)) + @test substitute(A(B(A(t))), Dict(t => t_val); fold = Val(true)) == A(B(A(t_val))) end @testset "Type Inference" begin