From 6952cf7a3e5990f5960a41207a1a7e6b4fae6c54 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 28 Jan 2025 10:18:21 +0100 Subject: [PATCH 1/2] feat: support contexts with symbolic backends --- DifferentiationInterface/Project.toml | 4 +- .../docs/src/explanation/backends.md | 34 +- ...ntiationInterfaceFastDifferentiationExt.jl | 15 +- .../onearg.jl | 442 +++++++++++------- .../twoarg.jl | 229 +++++---- .../DifferentiationInterfaceSymbolicsExt.jl | 13 +- .../onearg.jl | 383 ++++++++++----- .../twoarg.jl | 175 ++++--- DifferentiationInterface/src/utils/context.jl | 6 + .../SymbolicBackends/fastdifferentiation.jl | 4 +- .../test/Back/SymbolicBackends/symbolics.jl | 11 +- 11 files changed, 848 insertions(+), 468 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 438b7312c..14ed008ff 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.33" +version = "0.6.34" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -55,7 +55,7 @@ Diffractor = "=0.2.6" Enzyme = "0.13.17" EnzymeCore = "0.8.8" ExplicitImports = "1.10.1" -FastDifferentiation = "0.4.1" +FastDifferentiation = "0.4.3" FiniteDiff = "2.23.1" FiniteDifferences = "0.12.31" ForwardDiff = "0.10.36" diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index f206f90d9..47461c402 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -56,23 +56,23 @@ In practice, many AD backends have custom implementations for high-level operato Moreover, each context type is supported by a specific subset of backends: -| | [`Constant`](@ref) | -| -------------------------- | ------------------ | -| `AutoChainRules` | ✅ | -| `AutoDiffractor` | ❌ | -| `AutoEnzyme` (forward) | ✅ | -| `AutoEnzyme` (reverse) | ✅ | -| `AutoFastDifferentiation` | ❌ | -| `AutoFiniteDiff` | ✅ | -| `AutoFiniteDifferences` | ✅ | -| `AutoForwardDiff` | ✅ | -| `AutoGTPSA` | ✅ | -| `AutoMooncake` | ✅ | -| `AutoPolyesterForwardDiff` | ✅ | -| `AutoReverseDiff` | ✅ | -| `AutoSymbolics` | ❌ | -| `AutoTracker` | ✅ | -| `AutoZygote` | ✅ | +| | [`Constant`](@ref) | [`Cache`](@ref) | +| -------------------------- | ------------------ | --------------- | +| `AutoChainRules` | ✅ | ❌ | +| `AutoDiffractor` | ❌ | ❌ | +| `AutoEnzyme` (forward) | ✅ | ✅ | +| `AutoEnzyme` (reverse) | ✅ | ✅ | +| `AutoFastDifferentiation` | ✅ | ✅ | +| `AutoFiniteDiff` | ✅ | ✅ | +| `AutoFiniteDifferences` | ✅ | ✅ | +| `AutoForwardDiff` | ✅ | ✅ | +| `AutoGTPSA` | ✅ | ❌ | +| `AutoMooncake` | ✅ | ❌ | +| `AutoPolyesterForwardDiff` | ✅ | ✅ | +| `AutoReverseDiff` | ✅ | ❌ | +| `AutoSymbolics` | ✅ | ❌ | +| `AutoTracker` | ✅ | ❌ | +| `AutoZygote` | ✅ | ❌ | ## Second order diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl index 6c27d109f..430cc95f9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl @@ -18,14 +18,23 @@ using FastDifferentiation.RuntimeGeneratedFunctions: RuntimeGeneratedFunction DI.check_available(::AutoFastDifferentiation) = true -monovec(x::Number) = [x] - -myvec(x::Number) = monovec(x) +myvec(x::Number) = [x] myvec(x::AbstractArray) = vec(x) +variablize(::Number, name::Symbol) = only(make_variables(name)) +variablize(x::AbstractArray, name::Symbol) = make_variables(name, size(x)...) + +function variablize(contexts::NTuple{C,DI.Context}) where {C} + map(enumerate(contexts)) do (k, c) + variablize(DI.unwrap(c), Symbol("context$k")) + end +end + dense_ad(backend::AutoFastDifferentiation) = backend dense_ad(backend::AutoSparse{<:AutoFastDifferentiation}) = ADTypes.dense_ad(backend) +myvec_unwrap(x) = myvec(DI.unwrap(x)) + include("onearg.jl") include("twoarg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index 3e6075369..5ec4cc644 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -6,20 +6,23 @@ struct FastDifferentiationOneArgPushforwardPrep{Y,E1,E1!} <: DI.PushforwardPrep jvp_exe!::E1! end -function DI.prepare_pushforward(f, ::AutoFastDifferentiation, x, tx::NTuple) - y_prototype = f(x) - x_var = if x isa Number - only(make_variables(:x)) - else - make_variables(:x, size(x)...) - end - y_var = f(x_var) - - x_vec_var = x_var isa Number ? monovec(x_var) : vec(x_var) - y_vec_var = y_var isa Number ? monovec(y_var) : vec(y_var) +function DI.prepare_pushforward( + f, ::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C} +) where {C} + y_prototype = f(x, map(DI.unwrap, contexts)...) + x_var = variablize(x, :x) + context_vars = variablize(contexts) + y_var = f(x_var, context_vars...) + x_vec_var = myvec(x_var) + context_vec_vars = map(myvec, context_vars) + y_vec_var = myvec(y_var) jv_vec_var, v_vec_var = jacobian_times_v(y_vec_var, x_vec_var) - jvp_exe = make_function(jv_vec_var, vcat(x_vec_var, v_vec_var); in_place=false) - jvp_exe! = make_function(jv_vec_var, vcat(x_vec_var, v_vec_var); in_place=true) + jvp_exe = make_function( + jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=false + ) + jvp_exe! = make_function( + jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true + ) return FastDifferentiationOneArgPushforwardPrep(y_prototype, jvp_exe, jvp_exe!) end @@ -29,13 +32,14 @@ function DI.pushforward( ::AutoFastDifferentiation, x, tx::NTuple, -) + contexts::Vararg{DI.Context,C}, +) where {C} ty = map(tx) do dx - v_vec = vcat(myvec(x), myvec(dx)) + result = prep.jvp_exe(myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) if prep.y_prototype isa Number - return only(prep.jvp_exe(v_vec)) + return only(result) else - return reshape(prep.jvp_exe(v_vec), size(prep.y_prototype)) + return reshape(result, size(prep.y_prototype)) end end return ty @@ -48,11 +52,11 @@ function DI.pushforward!( ::AutoFastDifferentiation, x, tx::NTuple, -) + contexts::Vararg{DI.Context,C}, +) where {C} for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] - v_vec = vcat(myvec(x), myvec(dx)) - prep.jvp_exe!(vec(dy), v_vec) + prep.jvp_exe!(vec(dy), myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) end return ty end @@ -63,8 +67,10 @@ function DI.value_and_pushforward( backend::AutoFastDifferentiation, x, tx::NTuple, -) - return f(x), DI.pushforward(f, prep, backend, x, tx) + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), + DI.pushforward(f, prep, backend, x, tx, contexts...) end function DI.value_and_pushforward!( @@ -74,8 +80,10 @@ function DI.value_and_pushforward!( backend::AutoFastDifferentiation, x, tx::NTuple, -) - return f(x), DI.pushforward!(f, ty, prep, backend, x, tx) + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), + DI.pushforward!(f, ty, prep, backend, x, tx, contexts...) end ## Pullback @@ -85,31 +93,40 @@ struct FastDifferentiationOneArgPullbackPrep{E1,E1!} <: DI.PullbackPrep vjp_exe!::E1! end -function DI.prepare_pullback(f, ::AutoFastDifferentiation, x, ty::NTuple) - x_var = if x isa Number - only(make_variables(:x)) - else - make_variables(:x, size(x)...) - end - y_var = f(x_var) +function DI.prepare_pullback( + f, ::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C} +) where {C} + x_var = variablize(x, :x) + context_vars = variablize(contexts) + y_var = f(x_var, context_vars...) - x_vec_var = x_var isa Number ? monovec(x_var) : vec(x_var) - y_vec_var = y_var isa Number ? monovec(y_var) : vec(y_var) + x_vec_var = myvec(x_var) + context_vec_vars = map(myvec, context_vars) + y_vec_var = myvec(y_var) vj_vec_var, v_vec_var = jacobian_transpose_v(y_vec_var, x_vec_var) - vjp_exe = make_function(vj_vec_var, vcat(x_vec_var, v_vec_var); in_place=false) - vjp_exe! = make_function(vj_vec_var, vcat(x_vec_var, v_vec_var); in_place=true) + vjp_exe = make_function( + vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=false + ) + vjp_exe! = make_function( + vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true + ) return FastDifferentiationOneArgPullbackPrep(vjp_exe, vjp_exe!) end function DI.pullback( - f, prep::FastDifferentiationOneArgPullbackPrep, ::AutoFastDifferentiation, x, ty::NTuple -) + f, + prep::FastDifferentiationOneArgPullbackPrep, + ::AutoFastDifferentiation, + x, + ty::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} tx = map(ty) do dy - v_vec = vcat(myvec(x), myvec(dy)) + result = prep.vjp_exe(myvec(x), myvec(dy), map(myvec_unwrap, contexts)...) if x isa Number - return only(prep.vjp_exe(v_vec)) + return only(result) else - return reshape(prep.vjp_exe(v_vec), size(x)) + return reshape(result, size(x)) end end return tx @@ -122,11 +139,11 @@ function DI.pullback!( ::AutoFastDifferentiation, x, ty::NTuple, -) + contexts::Vararg{DI.Context,C}, +) where {C} for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] - v_vec = vcat(myvec(x), myvec(dy)) - prep.vjp_exe!(vec(dx), v_vec) + prep.vjp_exe!(vec(dx), myvec(x), myvec(dy), map(myvec_unwrap, contexts)...) end return tx end @@ -137,8 +154,10 @@ function DI.value_and_pullback( backend::AutoFastDifferentiation, x, ty::NTuple, -) - return f(x), DI.pullback(f, prep, backend, x, ty) + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), + DI.pullback(f, prep, backend, x, ty, contexts...) end function DI.value_and_pullback!( @@ -148,8 +167,10 @@ function DI.value_and_pullback!( backend::AutoFastDifferentiation, x, ty::NTuple, -) - return f(x), DI.pullback!(f, tx, prep, backend, x, ty) + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), + DI.pullback!(f, tx, prep, backend, x, ty, contexts...) end ## Derivative @@ -160,40 +181,59 @@ struct FastDifferentiationOneArgDerivativePrep{Y,E1,E1!} <: DI.DerivativePrep der_exe!::E1! end -function DI.prepare_derivative(f, ::AutoFastDifferentiation, x) - y_prototype = f(x) - x_var = only(make_variables(:x)) - y_var = f(x_var) +function DI.prepare_derivative( + f, ::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C} +) where {C} + y_prototype = f(x, map(DI.unwrap, contexts)...) + x_var = variablize(x, :x) + context_vars = variablize(contexts) + y_var = f(x_var, context_vars...) - x_vec_var = monovec(x_var) - y_vec_var = y_var isa Number ? monovec(y_var) : vec(y_var) + x_vec_var = myvec(x_var) + context_vec_vars = map(myvec, context_vars) + y_vec_var = myvec(y_var) der_vec_var = derivative(y_vec_var, x_var) - der_exe = make_function(der_vec_var, x_vec_var; in_place=false) - der_exe! = make_function(der_vec_var, x_vec_var; in_place=true) + der_exe = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=false) + der_exe! = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=true) return FastDifferentiationOneArgDerivativePrep(y_prototype, der_exe, der_exe!) end function DI.derivative( - f, prep::FastDifferentiationOneArgDerivativePrep, ::AutoFastDifferentiation, x -) + f, + prep::FastDifferentiationOneArgDerivativePrep, + ::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + result = prep.der_exe(myvec(x), map(myvec_unwrap, contexts)...) if prep.y_prototype isa Number - return only(prep.der_exe(monovec(x))) + return only(result) else - return reshape(prep.der_exe(monovec(x)), size(prep.y_prototype)) + return reshape(result, size(prep.y_prototype)) end end function DI.derivative!( - f, der, prep::FastDifferentiationOneArgDerivativePrep, ::AutoFastDifferentiation, x -) - prep.der_exe!(vec(der), monovec(x)) + f, + der, + prep::FastDifferentiationOneArgDerivativePrep, + ::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + prep.der_exe!(vec(der), myvec(x), map(myvec_unwrap, contexts)...) return der end function DI.value_and_derivative( - f, prep::FastDifferentiationOneArgDerivativePrep, backend::AutoFastDifferentiation, x -) - return f(x), DI.derivative(f, prep, backend, x) + f, + prep::FastDifferentiationOneArgDerivativePrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), + DI.derivative(f, prep, backend, x, contexts...) end function DI.value_and_derivative!( @@ -202,8 +242,10 @@ function DI.value_and_derivative!( prep::FastDifferentiationOneArgDerivativePrep, backend::AutoFastDifferentiation, x, -) - return f(x), DI.derivative!(f, der, prep, backend, x) + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), + DI.derivative!(f, der, prep, backend, x, contexts...) end ## Gradient @@ -213,37 +255,54 @@ struct FastDifferentiationOneArgGradientPrep{E1,E1!} <: DI.GradientPrep jac_exe!::E1! end -function DI.prepare_gradient(f, backend::AutoFastDifferentiation, x) - x_var = make_variables(:x, size(x)...) - y_var = f(x_var) +function DI.prepare_gradient( + f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C} +) where {C} + x_var = variablize(x, :x) + context_vars = variablize(contexts) + y_var = f(x_var, context_vars...) - x_vec_var = vec(x_var) - y_vec_var = monovec(y_var) + x_vec_var = myvec(x_var) + context_vec_vars = map(myvec, context_vars) + y_vec_var = myvec(y_var) jac_var = jacobian(y_vec_var, x_vec_var) - jac_exe = make_function(jac_var, x_vec_var; in_place=false) - jac_exe! = make_function(jac_var, x_vec_var; in_place=true) + jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) + jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) return FastDifferentiationOneArgGradientPrep(jac_exe, jac_exe!) end function DI.gradient( - f, prep::FastDifferentiationOneArgGradientPrep, ::AutoFastDifferentiation, x -) - jac = prep.jac_exe(vec(x)) + f, + prep::FastDifferentiationOneArgGradientPrep, + ::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + jac = prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) grad_vec = @view jac[1, :] return reshape(grad_vec, size(x)) end function DI.gradient!( - f, grad, prep::FastDifferentiationOneArgGradientPrep, ::AutoFastDifferentiation, x -) - prep.jac_exe!(reshape(grad, 1, length(grad)), vec(x)) + f, + grad, + prep::FastDifferentiationOneArgGradientPrep, + ::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + prep.jac_exe!(reshape(grad, 1, length(grad)), myvec(x), map(myvec_unwrap, contexts)...) return grad end function DI.value_and_gradient( - f, prep::FastDifferentiationOneArgGradientPrep, backend::AutoFastDifferentiation, x -) - return f(x), DI.gradient(f, prep, backend, x) + f, + prep::FastDifferentiationOneArgGradientPrep, + backend::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), DI.gradient(f, prep, backend, x, contexts...) end function DI.value_and_gradient!( @@ -252,8 +311,10 @@ function DI.value_and_gradient!( prep::FastDifferentiationOneArgGradientPrep, backend::AutoFastDifferentiation, x, -) - return f(x), DI.gradient!(f, grad, prep, backend, x) + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), + DI.gradient!(f, grad, prep, backend, x, contexts...) end ## Jacobian @@ -265,21 +326,26 @@ struct FastDifferentiationOneArgJacobianPrep{Y,E1,E1!} <: DI.JacobianPrep end function DI.prepare_jacobian( - f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x -) - y_prototype = f(x) - x_var = make_variables(:x, size(x)...) - y_var = f(x_var) - - x_vec_var = vec(x_var) - y_vec_var = vec(y_var) + f, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + y_prototype = f(x, map(DI.unwrap, contexts)...) + x_var = variablize(x, :x) + context_vars = variablize(contexts) + y_var = f(x_var, context_vars...) + + x_vec_var = myvec(x_var) + context_vec_vars = map(myvec, context_vars) + y_vec_var = myvec(y_var) jac_var = if backend isa AutoSparse sparse_jacobian(y_vec_var, x_vec_var) else jacobian(y_vec_var, x_vec_var) end - jac_exe = make_function(jac_var, x_vec_var; in_place=false) - jac_exe! = make_function(jac_var, x_vec_var; in_place=true) + jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) + jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) return FastDifferentiationOneArgJacobianPrep(y_prototype, jac_exe, jac_exe!) end @@ -288,8 +354,9 @@ function DI.jacobian( prep::FastDifferentiationOneArgJacobianPrep, ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, -) - return prep.jac_exe(vec(x)) + contexts::Vararg{DI.Context,C}, +) where {C} + return prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) end function DI.jacobian!( @@ -298,8 +365,9 @@ function DI.jacobian!( prep::FastDifferentiationOneArgJacobianPrep, ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, -) - prep.jac_exe!(jac, vec(x)) + contexts::Vararg{DI.Context,C}, +) where {C} + prep.jac_exe!(jac, myvec(x), map(myvec_unwrap, contexts)...) return jac end @@ -308,8 +376,9 @@ function DI.value_and_jacobian( prep::FastDifferentiationOneArgJacobianPrep, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, -) - return f(x), DI.jacobian(f, prep, backend, x) + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end function DI.value_and_jacobian!( @@ -318,8 +387,10 @@ function DI.value_and_jacobian!( prep::FastDifferentiationOneArgJacobianPrep, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, -) - return f(x), DI.jacobian!(f, jac, prep, backend, x) + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), + DI.jacobian!(f, jac, prep, backend, x, contexts...) end ## Second derivative @@ -332,31 +403,40 @@ struct FastDifferentiationAllocatingSecondDerivativePrep{Y,D,E2,E2!} <: der2_exe!::E2! end -function DI.prepare_second_derivative(f, backend::AutoFastDifferentiation, x) - y_prototype = f(x) - x_var = only(make_variables(:x)) - y_var = f(x_var) +function DI.prepare_second_derivative( + f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C} +) where {C} + y_prototype = f(x, map(DI.unwrap, contexts)...) + x_var = variablize(x, :x) + context_vars = variablize(contexts) + y_var = f(x_var, context_vars...) - x_vec_var = monovec(x_var) - y_vec_var = y_var isa Number ? monovec(y_var) : vec(y_var) + x_vec_var = myvec(x_var) + context_vec_vars = map(myvec, context_vars) + y_vec_var = myvec(y_var) der2_vec_var = derivative(y_vec_var, x_var, x_var) - der2_exe = make_function(der2_vec_var, x_vec_var; in_place=false) - der2_exe! = make_function(der2_vec_var, x_vec_var; in_place=true) + der2_exe = make_function(der2_vec_var, x_vec_var, context_vec_vars...; in_place=false) + der2_exe! = make_function(der2_vec_var, x_vec_var, context_vec_vars...; in_place=true) - derivative_prep = DI.prepare_derivative(f, backend, x) + derivative_prep = DI.prepare_derivative(f, backend, x, contexts...) return FastDifferentiationAllocatingSecondDerivativePrep( y_prototype, derivative_prep, der2_exe, der2_exe! ) end function DI.second_derivative( - f, prep::FastDifferentiationAllocatingSecondDerivativePrep, ::AutoFastDifferentiation, x -) + f, + prep::FastDifferentiationAllocatingSecondDerivativePrep, + ::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + result = prep.der2_exe(myvec(x), map(myvec_unwrap, contexts)...) if prep.y_prototype isa Number - return only(prep.der2_exe(monovec(x))) + return only(result) else - return reshape(prep.der2_exe(monovec(x)), size(prep.y_prototype)) + return reshape(result, size(prep.y_prototype)) end end @@ -364,10 +444,11 @@ function DI.second_derivative!( f, der2, prep::FastDifferentiationAllocatingSecondDerivativePrep, - backend::AutoFastDifferentiation, + ::AutoFastDifferentiation, x, -) - prep.der2_exe!(vec(der2), monovec(x)) + contexts::Vararg{DI.Context,C}, +) where {C} + prep.der2_exe!(vec(der2), myvec(x), map(myvec_unwrap, contexts)...) return der2 end @@ -376,9 +457,10 @@ function DI.value_derivative_and_second_derivative( prep::FastDifferentiationAllocatingSecondDerivativePrep, backend::AutoFastDifferentiation, x, -) - y, der = DI.value_and_derivative(f, prep.derivative_prep, backend, x) - der2 = DI.second_derivative(f, prep, backend, x) + contexts::Vararg{DI.Context,C}, +) where {C} + y, der = DI.value_and_derivative(f, prep.derivative_prep, backend, x, contexts...) + der2 = DI.second_derivative(f, prep, backend, x, contexts...) return y, der, der2 end @@ -389,9 +471,10 @@ function DI.value_derivative_and_second_derivative!( prep::FastDifferentiationAllocatingSecondDerivativePrep, backend::AutoFastDifferentiation, x, -) - y, _ = DI.value_and_derivative!(f, der, prep.derivative_prep, backend, x) - DI.second_derivative!(f, der2, prep, backend, x) + contexts::Vararg{DI.Context,C}, +) where {C} + y, _ = DI.value_and_derivative!(f, der, prep.derivative_prep, backend, x, contexts...) + DI.second_derivative!(f, der2, prep, backend, x, contexts...) return y, der, der2 end @@ -403,25 +486,37 @@ struct FastDifferentiationHVPPrep{E2,E2!,E1} <: DI.HVPPrep gradient_prep::E1 end -function DI.prepare_hvp(f, backend::AutoFastDifferentiation, x, tx::NTuple) - x_var = make_variables(:x, size(x)...) - y_var = f(x_var) +function DI.prepare_hvp( + f, backend::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C} +) where {C} + x_var = variablize(x, :x) + context_vars = variablize(contexts) + y_var = f(x_var, context_vars...) - x_vec_var = vec(x_var) + x_vec_var = myvec(x_var) + context_vec_vars = map(myvec, context_vars) hv_vec_var, v_vec_var = hessian_times_v(y_var, x_vec_var) - hvp_exe = make_function(hv_vec_var, vcat(x_vec_var, v_vec_var); in_place=false) - hvp_exe! = make_function(hv_vec_var, vcat(x_vec_var, v_vec_var); in_place=true) + hvp_exe = make_function( + hv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=false + ) + hvp_exe! = make_function( + hv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true + ) - gradient_prep = DI.prepare_gradient(f, backend, x) + gradient_prep = DI.prepare_gradient(f, backend, x, contexts...) return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!, gradient_prep) end function DI.hvp( - f, prep::FastDifferentiationHVPPrep, ::AutoFastDifferentiation, x, tx::NTuple -) + f, + prep::FastDifferentiationHVPPrep, + ::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} tg = map(tx) do dx - v_vec = vcat(vec(x), vec(dx)) - dg_vec = prep.hvp_exe(v_vec) + dg_vec = prep.hvp_exe(myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) return reshape(dg_vec, size(x)) end return tg @@ -434,20 +529,25 @@ function DI.hvp!( ::AutoFastDifferentiation, x, tx::NTuple, -) + contexts::Vararg{DI.Context,C}, +) where {C} for b in eachindex(tx, tg) dx, dg = tx[b], tg[b] - v_vec = vcat(vec(x), vec(dx)) - prep.hvp_exe!(dg, v_vec) + prep.hvp_exe!(dg, myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) end return tg end function DI.gradient_and_hvp( - f, prep::FastDifferentiationHVPPrep, backend::AutoFastDifferentiation, x, tx::NTuple -) - tg = DI.hvp(f, prep, backend, x, tx) - grad = DI.gradient(f, prep.gradient_prep, backend, x) + f, + prep::FastDifferentiationHVPPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} + tg = DI.hvp(f, prep, backend, x, tx, contexts...) + grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...) return grad, tg end @@ -459,9 +559,10 @@ function DI.gradient_and_hvp!( backend::AutoFastDifferentiation, x, tx::NTuple, -) - DI.hvp!(f, tg, prep, backend, x, tx) - DI.gradient!(f, grad, prep.gradient_prep, backend, x) + contexts::Vararg{DI.Context,C}, +) where {C} + DI.hvp!(f, tg, prep, backend, x, tx, contexts...) + DI.gradient!(f, grad, prep.gradient_prep, backend, x, contexts...) return grad, tg end @@ -474,22 +575,27 @@ struct FastDifferentiationHessianPrep{G,E2,E2!} <: DI.HessianPrep end function DI.prepare_hessian( - f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x -) - x_var = make_variables(:x, size(x)...) - y_var = f(x_var) + f, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + x_var = variablize(x, :x) + context_vars = variablize(contexts) + y_var = f(x_var, context_vars...) - x_vec_var = vec(x_var) + x_vec_var = myvec(x_var) + context_vec_vars = map(myvec, context_vars) hess_var = if backend isa AutoSparse sparse_hessian(y_var, x_vec_var) else hessian(y_var, x_vec_var) end - hess_exe = make_function(hess_var, x_vec_var; in_place=false) - hess_exe! = make_function(hess_var, x_vec_var; in_place=true) + hess_exe = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=false) + hess_exe! = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=true) - gradient_prep = DI.prepare_gradient(f, dense_ad(backend), x) + gradient_prep = DI.prepare_gradient(f, dense_ad(backend), x, contexts...) return FastDifferentiationHessianPrep(gradient_prep, hess_exe, hess_exe!) end @@ -498,8 +604,9 @@ function DI.hessian( prep::FastDifferentiationHessianPrep, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, -) - return prep.hess_exe(vec(x)) + contexts::Vararg{DI.Context,C}, +) where {C} + return prep.hess_exe(myvec(x), map(myvec_unwrap, contexts)...) end function DI.hessian!( @@ -508,8 +615,9 @@ function DI.hessian!( prep::FastDifferentiationHessianPrep, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, -) - prep.hess_exe!(hess, vec(x)) + contexts::Vararg{DI.Context,C}, +) where {C} + prep.hess_exe!(hess, myvec(x), map(myvec_unwrap, contexts)...) return hess end @@ -518,9 +626,12 @@ function DI.value_gradient_and_hessian( prep::FastDifferentiationHessianPrep, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, -) - y, grad = DI.value_and_gradient(f, prep.gradient_prep, dense_ad(backend), x) - hess = DI.hessian(f, prep, backend, x) + contexts::Vararg{DI.Context,C}, +) where {C} + y, grad = DI.value_and_gradient( + f, prep.gradient_prep, dense_ad(backend), x, contexts... + ) + hess = DI.hessian(f, prep, backend, x, contexts...) return y, grad, hess end @@ -531,8 +642,11 @@ function DI.value_gradient_and_hessian!( prep::FastDifferentiationHessianPrep, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, -) - y, _ = DI.value_and_gradient!(f, grad, prep.gradient_prep, dense_ad(backend), x) - DI.hessian!(f, hess, prep, backend, x) + contexts::Vararg{DI.Context,C}, +) where {C} + y, _ = DI.value_and_gradient!( + f, grad, prep.gradient_prep, dense_ad(backend), x, contexts... + ) + DI.hessian!(f, hess, prep, backend, x, contexts...) return y, grad, hess end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl index 0e44f44cd..db77e3591 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl @@ -5,20 +5,24 @@ struct FastDifferentiationTwoArgPushforwardPrep{E1,E1!} <: DI.PushforwardPrep jvp_exe!::E1! end -function DI.prepare_pushforward(f!, y, ::AutoFastDifferentiation, x, tx::NTuple) - x_var = if x isa Number - only(make_variables(:x)) - else - make_variables(:x, size(x)...) - end - y_var = make_variables(:y, size(y)...) - f!(y_var, x_var) +function DI.prepare_pushforward( + f!, y, ::AutoFastDifferentiation, x, tx::NTuple, contexts::Vararg{DI.Context,C} +) where {C} + x_var = variablize(x, :x) + context_vars = variablize(contexts) + y_var = variablize(y, :y) + f!(y_var, x_var, context_vars...) - x_vec_var = x_var isa Number ? monovec(x_var) : vec(x_var) - y_vec_var = vec(y_var) + x_vec_var = myvec(x_var) + context_vec_vars = map(myvec, context_vars) + y_vec_var = myvec(y_var) jv_vec_var, v_vec_var = jacobian_times_v(y_vec_var, x_vec_var) - jvp_exe = make_function(jv_vec_var, vcat(x_vec_var, v_vec_var); in_place=false) - jvp_exe! = make_function(jv_vec_var, vcat(x_vec_var, v_vec_var); in_place=true) + jvp_exe = make_function( + jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=false + ) + jvp_exe! = make_function( + jv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true + ) return FastDifferentiationTwoArgPushforwardPrep(jvp_exe, jvp_exe!) end @@ -29,10 +33,10 @@ function DI.pushforward( ::AutoFastDifferentiation, x, tx::NTuple, -) + contexts::Vararg{DI.Context,C}, +) where {C} ty = map(tx) do dx - v_vec = vcat(myvec(x), myvec(dx)) - reshape(prep.jvp_exe(v_vec), size(y)) + reshape(prep.jvp_exe(myvec(x), myvec(dx), map(myvec_unwrap, contexts)...), size(y)) end return ty end @@ -45,11 +49,11 @@ function DI.pushforward!( ::AutoFastDifferentiation, x, tx::NTuple, -) + contexts::Vararg{DI.Context,C}, +) where {C} for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] - v_vec = vcat(myvec(x), myvec(dx)) - prep.jvp_exe!(vec(dy), v_vec) + prep.jvp_exe!(myvec(dy), myvec(x), myvec(dx), map(myvec_unwrap, contexts)...) end return ty end @@ -61,9 +65,10 @@ function DI.value_and_pushforward( backend::AutoFastDifferentiation, x, tx::NTuple, -) - ty = DI.pushforward(f!, y, prep, backend, x, tx) - f!(y, x) + contexts::Vararg{DI.Context,C}, +) where {C} + ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...) + f!(y, x, map(DI.unwrap, contexts)...) return y, ty end @@ -75,9 +80,10 @@ function DI.value_and_pushforward!( backend::AutoFastDifferentiation, x, tx::NTuple, -) - DI.pushforward!(f!, y, ty, prep, backend, x, tx) - f!(y, x) + contexts::Vararg{DI.Context,C}, +) where {C} + DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) + f!(y, x, map(DI.unwrap, contexts)...) return y, ty end @@ -88,20 +94,24 @@ struct FastDifferentiationTwoArgPullbackPrep{E1,E1!} <: DI.PullbackPrep vjp_exe!::E1! end -function DI.prepare_pullback(f!, y, ::AutoFastDifferentiation, x, ty::NTuple) - x_var = if x isa Number - only(make_variables(:x)) - else - make_variables(:x, size(x)...) - end - y_var = make_variables(:y, size(y)...) - f!(y_var, x_var) +function DI.prepare_pullback( + f!, y, ::AutoFastDifferentiation, x, ty::NTuple, contexts::Vararg{DI.Context,C} +) where {C} + x_var = variablize(x, :x) + context_vars = variablize(contexts) + y_var = variablize(y, :y) + f!(y_var, x_var, context_vars...) - x_vec_var = x_var isa Number ? monovec(x_var) : vec(x_var) - y_vec_var = y_var isa Number ? monovec(y_var) : vec(y_var) + x_vec_var = myvec(x_var) + context_vec_vars = map(myvec, context_vars) + y_vec_var = myvec(y_var) vj_vec_var, v_vec_var = jacobian_transpose_v(y_vec_var, x_vec_var) - vjp_exe = make_function(vj_vec_var, vcat(x_vec_var, v_vec_var); in_place=false) - vjp_exe! = make_function(vj_vec_var, vcat(x_vec_var, v_vec_var); in_place=true) + vjp_exe = make_function( + vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=false + ) + vjp_exe! = make_function( + vj_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true + ) return FastDifferentiationTwoArgPullbackPrep(vjp_exe, vjp_exe!) end @@ -112,13 +122,14 @@ function DI.pullback( ::AutoFastDifferentiation, x, ty::NTuple, -) + contexts::Vararg{DI.Context,C}, +) where {C} tx = map(ty) do dy - v_vec = vcat(myvec(x), myvec(dy)) + result = prep.vjp_exe(myvec(x), myvec(dy), map(myvec_unwrap, contexts)...) if x isa Number - return only(prep.vjp_exe(v_vec)) + return only(result) else - return reshape(prep.vjp_exe(v_vec), size(x)) + return reshape(result, size(x)) end end return tx @@ -132,11 +143,11 @@ function DI.pullback!( ::AutoFastDifferentiation, x, ty::NTuple, -) + contexts::Vararg{DI.Context,C}, +) where {C} for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] - v_vec = vcat(myvec(x), myvec(dy)) - prep.vjp_exe!(vec(dx), v_vec) + prep.vjp_exe!(myvec(dx), myvec(x), myvec(dy), map(myvec_unwrap, contexts)...) end return tx end @@ -148,9 +159,10 @@ function DI.value_and_pullback( backend::AutoFastDifferentiation, x, ty::NTuple, -) - tx = DI.pullback(f!, y, prep, backend, x, ty) - f!(y, x) + contexts::Vararg{DI.Context,C}, +) where {C} + tx = DI.pullback(f!, y, prep, backend, x, ty, contexts...) + f!(y, x, map(DI.unwrap, contexts)...) return y, tx end @@ -162,9 +174,10 @@ function DI.value_and_pullback!( backend::AutoFastDifferentiation, x, ty::NTuple, -) - DI.pullback!(f!, y, tx, prep, backend, x, ty) - f!(y, x) + contexts::Vararg{DI.Context,C}, +) where {C} + DI.pullback!(f!, y, tx, prep, backend, x, ty, contexts...) + f!(y, x, map(DI.unwrap, contexts)...) return y, tx end @@ -175,46 +188,72 @@ struct FastDifferentiationTwoArgDerivativePrep{E1,E1!} <: DI.DerivativePrep der_exe!::E1! end -function DI.prepare_derivative(f!, y, ::AutoFastDifferentiation, x) - x_var = only(make_variables(:x)) - y_var = make_variables(:y, size(y)...) - f!(y_var, x_var) +function DI.prepare_derivative( + f!, y, ::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C} +) where {C} + x_var = variablize(x, :x) + context_vars = variablize(contexts) + y_var = variablize(y, :y) + f!(y_var, x_var, context_vars...) - x_vec_var = monovec(x_var) - y_vec_var = vec(y_var) + x_vec_var = myvec(x_var) + context_vec_vars = map(myvec, context_vars) + y_vec_var = myvec(y_var) der_vec_var = derivative(y_vec_var, x_var) - der_exe = make_function(der_vec_var, x_vec_var; in_place=false) - der_exe! = make_function(der_vec_var, x_vec_var; in_place=true) + der_exe = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=false) + der_exe! = make_function(der_vec_var, x_vec_var, context_vec_vars...; in_place=true) return FastDifferentiationTwoArgDerivativePrep(der_exe, der_exe!) end function DI.value_and_derivative( - f!, y, prep::FastDifferentiationTwoArgDerivativePrep, ::AutoFastDifferentiation, x -) - f!(y, x) - der = reshape(prep.der_exe(monovec(x)), size(y)) + f!, + y, + prep::FastDifferentiationTwoArgDerivativePrep, + ::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + f!(y, x, map(DI.unwrap, contexts)...) + der = reshape(prep.der_exe(myvec(x), map(myvec_unwrap, contexts)...), size(y)) return y, der end function DI.value_and_derivative!( - f!, y, der, prep::FastDifferentiationTwoArgDerivativePrep, ::AutoFastDifferentiation, x -) - f!(y, x) - prep.der_exe!(der, monovec(x)) + f!, + y, + der, + prep::FastDifferentiationTwoArgDerivativePrep, + ::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + f!(y, x, map(DI.unwrap, contexts)...) + prep.der_exe!(der, myvec(x), map(myvec_unwrap, contexts)...) return y, der end function DI.derivative( - f!, y, prep::FastDifferentiationTwoArgDerivativePrep, ::AutoFastDifferentiation, x -) - der = reshape(prep.der_exe(monovec(x)), size(y)) + f!, + y, + prep::FastDifferentiationTwoArgDerivativePrep, + ::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + der = reshape(prep.der_exe(myvec(x), map(myvec_unwrap, contexts)...), size(y)) return der end function DI.derivative!( - f!, y, der, prep::FastDifferentiationTwoArgDerivativePrep, ::AutoFastDifferentiation, x -) - prep.der_exe!(der, monovec(x)) + f!, + y, + der, + prep::FastDifferentiationTwoArgDerivativePrep, + ::AutoFastDifferentiation, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + prep.der_exe!(der, myvec(x), map(myvec_unwrap, contexts)...) return der end @@ -226,21 +265,27 @@ struct FastDifferentiationTwoArgJacobianPrep{E1,E1!} <: DI.JacobianPrep end function DI.prepare_jacobian( - f!, y, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x -) - x_var = make_variables(:x, size(x)...) - y_var = make_variables(:y, size(y)...) - f!(y_var, x_var) + f!, + y, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + x_var = variablize(x, :x) + context_vars = variablize(contexts) + y_var = variablize(y, :y) + f!(y_var, x_var, context_vars...) - x_vec_var = vec(x_var) - y_vec_var = vec(y_var) + x_vec_var = myvec(x_var) + context_vec_vars = map(myvec, context_vars) + y_vec_var = myvec(y_var) jac_var = if backend isa AutoSparse sparse_jacobian(y_vec_var, x_vec_var) else jacobian(y_vec_var, x_vec_var) end - jac_exe = make_function(jac_var, x_vec_var; in_place=false) - jac_exe! = make_function(jac_var, x_vec_var; in_place=true) + jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false) + jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true) return FastDifferentiationTwoArgJacobianPrep(jac_exe, jac_exe!) end @@ -250,9 +295,10 @@ function DI.value_and_jacobian( prep::FastDifferentiationTwoArgJacobianPrep, ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, -) - f!(y, x) - jac = prep.jac_exe(vec(x)) + contexts::Vararg{DI.Context,C}, +) where {C} + f!(y, x, map(DI.unwrap, contexts)...) + jac = prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) return y, jac end @@ -263,9 +309,10 @@ function DI.value_and_jacobian!( prep::FastDifferentiationTwoArgJacobianPrep, ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, -) - f!(y, x) - prep.jac_exe!(jac, vec(x)) + contexts::Vararg{DI.Context,C}, +) where {C} + f!(y, x, map(DI.unwrap, contexts)...) + prep.jac_exe!(jac, myvec(x), map(myvec_unwrap, contexts)...) return y, jac end @@ -275,8 +322,9 @@ function DI.jacobian( prep::FastDifferentiationTwoArgJacobianPrep, ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, -) - jac = prep.jac_exe(vec(x)) + contexts::Vararg{DI.Context,C}, +) where {C} + jac = prep.jac_exe(myvec(x), map(myvec_unwrap, contexts)...) return jac end @@ -287,7 +335,8 @@ function DI.jacobian!( prep::FastDifferentiationTwoArgJacobianPrep, ::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x, -) - prep.jac_exe!(jac, vec(x)) + contexts::Vararg{DI.Context,C}, +) where {C} + prep.jac_exe!(jac, myvec(x), map(myvec_unwrap, contexts)...) return jac end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl index 3b02560ae..4e9c79a35 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl @@ -21,14 +21,21 @@ using Symbolics.RuntimeGeneratedFunctions: RuntimeGeneratedFunction DI.check_available(::AutoSymbolics) = true DI.pullback_performance(::AutoSymbolics) = DI.PullbackSlow() -monovec(x::Number) = [x] - -myvec(x::Number) = monovec(x) +myvec(x::Number) = [x] myvec(x::AbstractArray) = vec(x) dense_ad(backend::AutoSymbolics) = backend dense_ad(backend::AutoSparse{<:AutoSymbolics}) = ADTypes.dense_ad(backend) +variablize(::Number, name::Symbol) = variable(name) +variablize(x::AbstractArray, name::Symbol) = variables(name, axes(x)...) + +function variablize(contexts::NTuple{C,DI.Context}) where {C} + map(enumerate(contexts)) do (k, c) + variablize(DI.unwrap(c), Symbol("context$k")) + end +end + include("onearg.jl") include("twoarg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 32756b58e..9f0102a8d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -5,23 +5,18 @@ struct SymbolicsOneArgPushforwardPrep{E1,E1!} <: DI.PushforwardPrep pf_exe!::E1! end -function DI.prepare_pushforward(f, ::AutoSymbolics, x, tx::NTuple) +function DI.prepare_pushforward( + f, ::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C} +) where {C} dx = first(tx) - x_var = if x isa Number - variable(:x) - else - variables(:x, axes(x)...) - end - dx_var = if dx isa Number - variable(:dx) - else - variables(:dx, axes(dx)...) - end + x_var = variablize(x, :x) + dx_var = variablize(dx, :dx) t_var = variable(:t) - step_der_var = derivative(f(x_var + t_var * dx_var), t_var) + context_vars = variablize(contexts) + step_der_var = derivative(f(x_var + t_var * dx_var, context_vars...), t_var) pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x)))) - res = build_function(pf_var, vcat(myvec(x_var), myvec(dx_var)); expression=Val(false)) + res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false)) (pf_exe, pf_exe!) = if res isa Tuple res elseif res isa RuntimeGeneratedFunction @@ -31,30 +26,45 @@ function DI.prepare_pushforward(f, ::AutoSymbolics, x, tx::NTuple) end function DI.pushforward( - f, prep::SymbolicsOneArgPushforwardPrep, ::AutoSymbolics, x, tx::NTuple -) + f, + prep::SymbolicsOneArgPushforwardPrep, + ::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} ty = map(tx) do dx - v_vec = vcat(myvec(x), myvec(dx)) - dy = prep.pf_exe(v_vec) + dy = prep.pf_exe(x, dx, map(DI.unwrap, contexts)...) end return ty end function DI.pushforward!( - f, ty::NTuple, prep::SymbolicsOneArgPushforwardPrep, ::AutoSymbolics, x, tx::NTuple -) + f, + ty::NTuple, + prep::SymbolicsOneArgPushforwardPrep, + ::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] - v_vec = vcat(myvec(x), myvec(dx)) - prep.pf_exe!(dy, v_vec) + prep.pf_exe!(dy, x, dx, map(DI.unwrap, contexts)...) end return ty end function DI.value_and_pushforward( - f, prep::SymbolicsOneArgPushforwardPrep, backend::AutoSymbolics, x, tx::NTuple -) - return f(x), DI.pushforward(f, prep, backend, x, tx) + f, + prep::SymbolicsOneArgPushforwardPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), + DI.pushforward(f, prep, backend, x, tx, contexts...) end function DI.value_and_pushforward!( @@ -64,8 +74,10 @@ function DI.value_and_pushforward!( backend::AutoSymbolics, x, tx::NTuple, -) - return f(x), DI.pushforward!(f, ty, prep, backend, x, tx) + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), + DI.pushforward!(f, ty, prep, backend, x, tx, contexts...) end ## Derivative @@ -75,11 +87,14 @@ struct SymbolicsOneArgDerivativePrep{E1,E1!} <: DI.DerivativePrep der_exe!::E1! end -function DI.prepare_derivative(f, ::AutoSymbolics, x) - x_var = variable(:x) - der_var = derivative(f(x_var), x_var) +function DI.prepare_derivative( + f, ::AutoSymbolics, x, contexts::Vararg{DI.Context,C} +) where {C} + x_var = variablize(x, :x) + context_vars = variablize(contexts) + der_var = derivative(f(x_var, context_vars...), x_var) - res = build_function(der_var, x_var; expression=Val(false)) + res = build_function(der_var, x_var, context_vars...; expression=Val(false)) (der_exe, der_exe!) = if res isa Tuple res elseif res isa RuntimeGeneratedFunction @@ -88,25 +103,49 @@ function DI.prepare_derivative(f, ::AutoSymbolics, x) return SymbolicsOneArgDerivativePrep(der_exe, der_exe!) end -function DI.derivative(f, prep::SymbolicsOneArgDerivativePrep, ::AutoSymbolics, x) - return prep.der_exe(x) +function DI.derivative( + f, + prep::SymbolicsOneArgDerivativePrep, + ::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + return prep.der_exe(x, map(DI.unwrap, contexts)...) end -function DI.derivative!(f, der, prep::SymbolicsOneArgDerivativePrep, ::AutoSymbolics, x) - prep.der_exe!(der, x) +function DI.derivative!( + f, + der, + prep::SymbolicsOneArgDerivativePrep, + ::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + prep.der_exe!(der, x, map(DI.unwrap, contexts)...) return der end function DI.value_and_derivative( - f, prep::SymbolicsOneArgDerivativePrep, backend::AutoSymbolics, x -) - return f(x), DI.derivative(f, prep, backend, x) + f, + prep::SymbolicsOneArgDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), + DI.derivative(f, prep, backend, x, contexts...) end function DI.value_and_derivative!( - f, der, prep::SymbolicsOneArgDerivativePrep, backend::AutoSymbolics, x -) - return f(x), DI.derivative!(f, der, prep, backend, x) + f, + der, + prep::SymbolicsOneArgDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), + DI.derivative!(f, der, prep, backend, x, contexts...) end ## Gradient @@ -116,35 +155,57 @@ struct SymbolicsOneArgGradientPrep{E1,E1!} <: DI.GradientPrep grad_exe!::E1! end -function DI.prepare_gradient(f, ::AutoSymbolics, x) - x_var = variables(:x, axes(x)...) +function DI.prepare_gradient( + f, ::AutoSymbolics, x, contexts::Vararg{DI.Context,C} +) where {C} + x_var = variablize(x, :x) + context_vars = variablize(contexts) # Symbolic.gradient only accepts vectors - grad_var = gradient(f(x_var), vec(x_var)) + grad_var = gradient(f(x_var, context_vars...), vec(x_var)) - res = build_function(grad_var, vec(x_var); expression=Val(false)) + res = build_function(grad_var, vec(x_var), context_vars...; expression=Val(false)) (grad_exe, grad_exe!) = res return SymbolicsOneArgGradientPrep(grad_exe, grad_exe!) end -function DI.gradient(f, prep::SymbolicsOneArgGradientPrep, ::AutoSymbolics, x) - return reshape(prep.grad_exe(vec(x)), size(x)) +function DI.gradient( + f, prep::SymbolicsOneArgGradientPrep, ::AutoSymbolics, x, contexts::Vararg{DI.Context,C} +) where {C} + return reshape(prep.grad_exe(vec(x), map(DI.unwrap, contexts)...), size(x)) end -function DI.gradient!(f, grad, prep::SymbolicsOneArgGradientPrep, ::AutoSymbolics, x) - prep.grad_exe!(vec(grad), vec(x)) +function DI.gradient!( + f, + grad, + prep::SymbolicsOneArgGradientPrep, + ::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + prep.grad_exe!(vec(grad), vec(x), map(DI.unwrap, contexts)...) return grad end function DI.value_and_gradient( - f, prep::SymbolicsOneArgGradientPrep, backend::AutoSymbolics, x -) - return f(x), DI.gradient(f, prep, backend, x) + f, + prep::SymbolicsOneArgGradientPrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), DI.gradient(f, prep, backend, x, contexts...) end function DI.value_and_gradient!( - f, grad, prep::SymbolicsOneArgGradientPrep, backend::AutoSymbolics, x -) - return f(x), DI.gradient!(f, grad, prep, backend, x) + f, + grad, + prep::SymbolicsOneArgGradientPrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), + DI.gradient!(f, grad, prep, backend, x, contexts...) end ## Jacobian @@ -155,16 +216,20 @@ struct SymbolicsOneArgJacobianPrep{E1,E1!} <: DI.JacobianPrep end function DI.prepare_jacobian( - f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x -) - x_var = variables(:x, axes(x)...) + f, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + x_var = variablize(x, :x) + context_vars = variablize(contexts) jac_var = if backend isa AutoSparse - sparsejacobian(vec(f(x_var)), vec(x_var)) + sparsejacobian(vec(f(x_var, context_vars...)), vec(x_var)) else - jacobian(f(x_var), x_var) + jacobian(f(x_var, context_vars...), x_var) end - res = build_function(jac_var, x_var; expression=Val(false)) + res = build_function(jac_var, x_var, context_vars...; expression=Val(false)) (jac_exe, jac_exe!) = res return SymbolicsOneArgJacobianPrep(jac_exe, jac_exe!) end @@ -174,8 +239,9 @@ function DI.jacobian( prep::SymbolicsOneArgJacobianPrep, ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, -) - return prep.jac_exe(x) + contexts::Vararg{DI.Context,C}, +) where {C} + return prep.jac_exe(x, map(DI.unwrap, contexts)...) end function DI.jacobian!( @@ -184,8 +250,9 @@ function DI.jacobian!( prep::SymbolicsOneArgJacobianPrep, ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, -) - prep.jac_exe!(jac, x) + contexts::Vararg{DI.Context,C}, +) where {C} + prep.jac_exe!(jac, x, map(DI.unwrap, contexts)...) return jac end @@ -194,8 +261,9 @@ function DI.value_and_jacobian( prep::SymbolicsOneArgJacobianPrep, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, -) - return f(x), DI.jacobian(f, prep, backend, x) + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...) end function DI.value_and_jacobian!( @@ -204,8 +272,10 @@ function DI.value_and_jacobian!( prep::SymbolicsOneArgJacobianPrep, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, -) - return f(x), DI.jacobian!(f, jac, prep, backend, x) + contexts::Vararg{DI.Context,C}, +) where {C} + return f(x, map(DI.unwrap, contexts)...), + DI.jacobian!(f, jac, prep, backend, x, contexts...) end ## Hessian @@ -216,19 +286,25 @@ struct SymbolicsOneArgHessianPrep{G,E2,E2!} <: DI.HessianPrep hess_exe!::E2! end -function DI.prepare_hessian(f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x) - x_var = variables(:x, axes(x)...) +function DI.prepare_hessian( + f, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + x_var = variablize(x, :x) + context_vars = variablize(contexts) # Symbolic.hessian only accepts vectors hess_var = if backend isa AutoSparse - sparsehessian(f(x_var), vec(x_var)) + sparsehessian(f(x_var, context_vars...), vec(x_var)) else - hessian(f(x_var), vec(x_var)) + hessian(f(x_var, context_vars...), vec(x_var)) end - res = build_function(hess_var, vec(x_var); expression=Val(false)) + res = build_function(hess_var, vec(x_var), context_vars...; expression=Val(false)) (hess_exe, hess_exe!) = res - gradient_prep = DI.prepare_gradient(f, dense_ad(backend), x) + gradient_prep = DI.prepare_gradient(f, dense_ad(backend), x, contexts...) return SymbolicsOneArgHessianPrep(gradient_prep, hess_exe, hess_exe!) end @@ -237,8 +313,9 @@ function DI.hessian( prep::SymbolicsOneArgHessianPrep, ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, -) - return prep.hess_exe(vec(x)) + contexts::Vararg{DI.Context,C}, +) where {C} + return prep.hess_exe(vec(x), map(DI.unwrap, contexts)...) end function DI.hessian!( @@ -247,8 +324,9 @@ function DI.hessian!( prep::SymbolicsOneArgHessianPrep, ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, -) - prep.hess_exe!(hess, vec(x)) + contexts::Vararg{DI.Context,C}, +) where {C} + prep.hess_exe!(hess, vec(x), map(DI.unwrap, contexts)...) return hess end @@ -257,9 +335,12 @@ function DI.value_gradient_and_hessian( prep::SymbolicsOneArgHessianPrep, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, -) - y, grad = DI.value_and_gradient(f, prep.gradient_prep, dense_ad(backend), x) - hess = DI.hessian(f, prep, backend, x) + contexts::Vararg{DI.Context,C}, +) where {C} + y, grad = DI.value_and_gradient( + f, prep.gradient_prep, dense_ad(backend), x, contexts... + ) + hess = DI.hessian(f, prep, backend, x, contexts...) return y, grad, hess end @@ -270,9 +351,12 @@ function DI.value_gradient_and_hessian!( prep::SymbolicsOneArgHessianPrep, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, -) - y, _ = DI.value_and_gradient!(f, grad, prep.gradient_prep, dense_ad(backend), x) - DI.hessian!(f, hess, prep, backend, x) + contexts::Vararg{DI.Context,C}, +) where {C} + y, _ = DI.value_and_gradient!( + f, grad, prep.gradient_prep, dense_ad(backend), x, contexts... + ) + DI.hessian!(f, hess, prep, backend, x, contexts...) return y, grad, hess end @@ -284,52 +368,81 @@ struct SymbolicsOneArgHVPPrep{G,E2,E2!} <: DI.HVPPrep hvp_exe!::E2! end -function DI.prepare_hvp(f, backend::AutoSymbolics, x, tx::NTuple) - x_var = variables(:x, axes(x)...) - dx_var = variables(:dx, axes(x)...) +function DI.prepare_hvp( + f, backend::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C} +) where {C} + dx = first(tx) + x_var = variablize(x, :x) + dx_var = variablize(dx, :dx) + context_vars = variablize(contexts) # Symbolic.hessian only accepts vectors - hess_var = hessian(f(x_var), vec(x_var)) + hess_var = hessian(f(x_var, context_vars...), vec(x_var)) hvp_vec_var = hess_var * vec(dx_var) - res = build_function(hvp_vec_var, vcat(vec(x_var), vec(dx_var)); expression=Val(false)) + res = build_function( + hvp_vec_var, vec(x_var), vec(dx_var), context_vars...; expression=Val(false) + ) (hvp_exe, hvp_exe!) = res - gradient_prep = DI.prepare_gradient(f, backend, x) + gradient_prep = DI.prepare_gradient(f, backend, x, contexts...) return SymbolicsOneArgHVPPrep(gradient_prep, hvp_exe, hvp_exe!) end -function DI.hvp(f, prep::SymbolicsOneArgHVPPrep, ::AutoSymbolics, x, tx::NTuple) +function DI.hvp( + f, + prep::SymbolicsOneArgHVPPrep, + ::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} return map(tx) do dx - v_vec = vcat(vec(x), vec(dx)) - dg_vec = prep.hvp_exe(v_vec) + dg_vec = prep.hvp_exe(vec(x), vec(dx), map(DI.unwrap, contexts)...) reshape(dg_vec, size(x)) end end function DI.hvp!( - f, tg::NTuple, prep::SymbolicsOneArgHVPPrep, ::AutoSymbolics, x, tx::NTuple -) + f, + tg::NTuple, + prep::SymbolicsOneArgHVPPrep, + ::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} for b in eachindex(tx, tg) dx, dg = tx[b], tg[b] - v_vec = vcat(vec(x), vec(dx)) - prep.hvp_exe!(vec(dg), v_vec) + prep.hvp_exe!(vec(dg), vec(x), vec(dx), map(DI.unwrap, contexts)...) end return tg end function DI.gradient_and_hvp( - f, prep::SymbolicsOneArgHVPPrep, backend::AutoSymbolics, x, tx::NTuple -) - tg = DI.hvp(f, prep, backend, x, tx) - grad = DI.gradient(f, prep.gradient_prep, backend, x) + f, + prep::SymbolicsOneArgHVPPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} + tg = DI.hvp(f, prep, backend, x, tx, contexts...) + grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...) return grad, tg end function DI.gradient_and_hvp!( - f, grad, tg::NTuple, prep::SymbolicsOneArgHVPPrep, backend::AutoSymbolics, x, tx::NTuple -) - DI.hvp!(f, tg, prep, backend, x, tx) - DI.gradient!(f, grad, prep.gradient_prep, backend, x) + f, + grad, + tg::NTuple, + prep::SymbolicsOneArgHVPPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} + DI.hvp!(f, tg, prep, backend, x, tx, contexts...) + DI.gradient!(f, grad, prep.gradient_prep, backend, x, contexts...) return grad, tg end @@ -341,46 +454,68 @@ struct SymbolicsOneArgSecondDerivativePrep{D,E1,E1!} <: DI.SecondDerivativePrep der2_exe!::E1! end -function DI.prepare_second_derivative(f, backend::AutoSymbolics, x) - x_var = variable(:x) - der_var = derivative(f(x_var), x_var) +function DI.prepare_second_derivative( + f, backend::AutoSymbolics, x, contexts::Vararg{DI.Context,C} +) where {C} + x_var = variablize(x, :x) + context_vars = variablize(contexts) + der_var = derivative(f(x_var, context_vars...), x_var) der2_var = derivative(der_var, x_var) - res = build_function(der2_var, x_var; expression=Val(false)) + res = build_function(der2_var, x_var, context_vars...; expression=Val(false)) (der2_exe, der2_exe!) = if res isa Tuple res elseif res isa RuntimeGeneratedFunction res, nothing end - derivative_prep = DI.prepare_derivative(f, backend, x) + derivative_prep = DI.prepare_derivative(f, backend, x, contexts...) return SymbolicsOneArgSecondDerivativePrep(derivative_prep, der2_exe, der2_exe!) end function DI.second_derivative( - f, prep::SymbolicsOneArgSecondDerivativePrep, ::AutoSymbolics, x -) - return prep.der2_exe(x) + f, + prep::SymbolicsOneArgSecondDerivativePrep, + ::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + return prep.der2_exe(x, map(DI.unwrap, contexts)...) end function DI.second_derivative!( - f, der2, prep::SymbolicsOneArgSecondDerivativePrep, ::AutoSymbolics, x -) - prep.der2_exe!(der2, x) + f, + der2, + prep::SymbolicsOneArgSecondDerivativePrep, + ::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + prep.der2_exe!(der2, x, map(DI.unwrap, contexts)...) return der2 end function DI.value_derivative_and_second_derivative( - f, prep::SymbolicsOneArgSecondDerivativePrep, backend::AutoSymbolics, x -) - y, der = DI.value_and_derivative(f, prep.derivative_prep, backend, x) - der2 = DI.second_derivative(f, prep, backend, x) + f, + prep::SymbolicsOneArgSecondDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + y, der = DI.value_and_derivative(f, prep.derivative_prep, backend, x, contexts...) + der2 = DI.second_derivative(f, prep, backend, x, contexts...) return y, der, der2 end function DI.value_derivative_and_second_derivative!( - f, der, der2, prep::SymbolicsOneArgSecondDerivativePrep, backend::AutoSymbolics, x -) - y, _ = DI.value_and_derivative!(f, der, prep.derivative_prep, backend, x) - DI.second_derivative!(f, der2, prep, backend, x) + f, + der, + der2, + prep::SymbolicsOneArgSecondDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + y, _ = DI.value_and_derivative!(f, der, prep.derivative_prep, backend, x, contexts...) + DI.second_derivative!(f, der2, prep, backend, x, contexts...) return y, der, der2 end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index 10cff3c37..ffe6ee0f4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -5,55 +5,67 @@ struct SymbolicsTwoArgPushforwardPrep{E1,E1!} <: DI.PushforwardPrep pushforward_exe!::E1! end -function DI.prepare_pushforward(f!, y, ::AutoSymbolics, x, tx::NTuple) +function DI.prepare_pushforward( + f!, y, ::AutoSymbolics, x, tx::NTuple, contexts::Vararg{DI.Context,C} +) where {C} dx = first(tx) - x_var = if x isa Number - variable(:x) - else - variables(:x, axes(x)...) - end - dx_var = if dx isa Number - variable(:dx) - else - variables(:dx, axes(dx)...) - end - y_var = variables(:y, axes(y)...) + x_var = variablize(x, :x) + dx_var = variablize(dx, :dx) + context_vars = variablize(contexts) + y_var = variablize(y, :y) t_var = variable(:t) - f!(y_var, x_var + t_var * dx_var) + f!(y_var, x_var + t_var * dx_var, context_vars...) step_der_var = derivative(y_var, t_var) pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x)))) - res = build_function(pf_var, vcat(myvec(x_var), myvec(dx_var)); expression=Val(false)) + res = build_function(pf_var, x_var, dx_var, context_vars...; expression=Val(false)) (pushforward_exe, pushforward_exe!) = res return SymbolicsTwoArgPushforwardPrep(pushforward_exe, pushforward_exe!) end function DI.pushforward( - f!, y, prep::SymbolicsTwoArgPushforwardPrep, ::AutoSymbolics, x, tx::NTuple -) + f!, + y, + prep::SymbolicsTwoArgPushforwardPrep, + ::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} ty = map(tx) do dx - v_vec = vcat(myvec(x), myvec(dx)) - dy = prep.pushforward_exe(v_vec) + dy = prep.pushforward_exe(x, dx, map(DI.unwrap, contexts)...) end return ty end function DI.pushforward!( - f!, y, ty::NTuple, prep::SymbolicsTwoArgPushforwardPrep, ::AutoSymbolics, x, tx::NTuple -) + f!, + y, + ty::NTuple, + prep::SymbolicsTwoArgPushforwardPrep, + ::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] - v_vec = vcat(myvec(x), myvec(dx)) - prep.pushforward_exe!(dy, v_vec) + prep.pushforward_exe!(dy, x, dx, map(DI.unwrap, contexts)...) end return ty end function DI.value_and_pushforward( - f!, y, prep::SymbolicsTwoArgPushforwardPrep, backend::AutoSymbolics, x, tx::NTuple -) - ty = DI.pushforward(f!, y, prep, backend, x, tx) - f!(y, x) + f!, + y, + prep::SymbolicsTwoArgPushforwardPrep, + backend::AutoSymbolics, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}, +) where {C} + ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...) + f!(y, x, map(DI.unwrap, contexts)...) return y, ty end @@ -65,9 +77,10 @@ function DI.value_and_pushforward!( backend::AutoSymbolics, x, tx::NTuple, -) - DI.pushforward!(f!, y, ty, prep, backend, x, tx) - f!(y, x) + contexts::Vararg{DI.Context,C}, +) where {C} + DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) + f!(y, x, map(DI.unwrap, contexts)...) return y, ty end @@ -78,39 +91,68 @@ struct SymbolicsTwoArgDerivativePrep{E1,E1!} <: DI.DerivativePrep der_exe!::E1! end -function DI.prepare_derivative(f!, y, ::AutoSymbolics, x) - x_var = variable(:x) - y_var = variables(:y, axes(y)...) - f!(y_var, x_var) +function DI.prepare_derivative( + f!, y, ::AutoSymbolics, x, contexts::Vararg{DI.Context,C} +) where {C} + x_var = variablize(x, :x) + y_var = variablize(y, :y) + context_vars = variablize(contexts) + f!(y_var, x_var, context_vars...) der_var = derivative(y_var, x_var) - res = build_function(der_var, x_var; expression=Val(false)) + res = build_function(der_var, x_var, context_vars...; expression=Val(false)) (der_exe, der_exe!) = res return SymbolicsTwoArgDerivativePrep(der_exe, der_exe!) end -function DI.derivative(f!, y, prep::SymbolicsTwoArgDerivativePrep, ::AutoSymbolics, x) - return prep.der_exe(x) +function DI.derivative( + f!, + y, + prep::SymbolicsTwoArgDerivativePrep, + ::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + return prep.der_exe(x, map(DI.unwrap, contexts)...) end -function DI.derivative!(f!, y, der, prep::SymbolicsTwoArgDerivativePrep, ::AutoSymbolics, x) - prep.der_exe!(der, x) +function DI.derivative!( + f!, + y, + der, + prep::SymbolicsTwoArgDerivativePrep, + ::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + prep.der_exe!(der, x, map(DI.unwrap, contexts)...) return der end function DI.value_and_derivative( - f!, y, prep::SymbolicsTwoArgDerivativePrep, backend::AutoSymbolics, x -) - der = DI.derivative(f!, y, prep, backend, x) - f!(y, x) + f!, + y, + prep::SymbolicsTwoArgDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + der = DI.derivative(f!, y, prep, backend, x, contexts...) + f!(y, x, map(DI.unwrap, contexts)...) return y, der end function DI.value_and_derivative!( - f!, y, der, prep::SymbolicsTwoArgDerivativePrep, backend::AutoSymbolics, x -) - DI.derivative!(f!, y, der, prep, backend, x) - f!(y, x) + f!, + y, + der, + prep::SymbolicsTwoArgDerivativePrep, + backend::AutoSymbolics, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + DI.derivative!(f!, y, der, prep, backend, x, contexts...) + f!(y, x, map(DI.unwrap, contexts)...) return y, der end @@ -122,18 +164,23 @@ struct SymbolicsTwoArgJacobianPrep{E1,E1!} <: DI.JacobianPrep end function DI.prepare_jacobian( - f!, y, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x -) - x_var = variables(:x, axes(x)...) - y_var = variables(:y, axes(y)...) - f!(y_var, x_var) + f!, + y, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + x, + contexts::Vararg{DI.Context,C}, +) where {C} + x_var = variablize(x, :x) + y_var = variablize(y, :y) + context_vars = variablize(contexts) + f!(y_var, x_var, context_vars...) jac_var = if backend isa AutoSparse sparsejacobian(vec(y_var), vec(x_var)) else jacobian(y_var, x_var) end - res = build_function(jac_var, x_var; expression=Val(false)) + res = build_function(jac_var, x_var, context_vars...; expression=Val(false)) (jac_exe, jac_exe!) = res return SymbolicsTwoArgJacobianPrep(jac_exe, jac_exe!) end @@ -144,8 +191,9 @@ function DI.jacobian( prep::SymbolicsTwoArgJacobianPrep, ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, -) - return prep.jac_exe(x) + contexts::Vararg{DI.Context,C}, +) where {C} + return prep.jac_exe(x, map(DI.unwrap, contexts)...) end function DI.jacobian!( @@ -155,8 +203,9 @@ function DI.jacobian!( prep::SymbolicsTwoArgJacobianPrep, ::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, -) - prep.jac_exe!(jac, x) + contexts::Vararg{DI.Context,C}, +) where {C} + prep.jac_exe!(jac, x, map(DI.unwrap, contexts)...) return jac end @@ -166,9 +215,10 @@ function DI.value_and_jacobian( prep::SymbolicsTwoArgJacobianPrep, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, -) - jac = DI.jacobian(f!, y, prep, backend, x) - f!(y, x) + contexts::Vararg{DI.Context,C}, +) where {C} + jac = DI.jacobian(f!, y, prep, backend, x, contexts...) + f!(y, x, map(DI.unwrap, contexts)...) return y, jac end @@ -179,8 +229,9 @@ function DI.value_and_jacobian!( prep::SymbolicsTwoArgJacobianPrep, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x, -) - DI.jacobian!(f!, y, jac, prep, backend, x) - f!(y, x) + contexts::Vararg{DI.Context,C}, +) where {C} + DI.jacobian!(f!, y, jac, prep, backend, x, contexts...) + f!(y, x, map(DI.unwrap, contexts)...) return y, jac end diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 1abdbdd2e..15a9d4a0e 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -32,6 +32,9 @@ Concrete type of [`Context`](@ref) argument which is kept constant during differ Note that an operator can be prepared with an arbitrary value of the constant. However, same-point preparation must occur with the exact value that will be reused later. +!!! warning + Some backends require any `Constant` context to be a `Number` or an `AbstractArray`. + # Example ```jldoctest @@ -65,6 +68,9 @@ maker(::Constant) = constant_maker Concrete type of [`Context`](@ref) argument which can be mutated with active values during differentiation. The initial values present inside the cache do not matter. + +!!! warning + Most backends require any `Cache` context to be an `AbstractArray`. """ struct Cache{T} <: Context data::T diff --git a/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl b/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl index 1f049e161..191025ae2 100644 --- a/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl +++ b/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl @@ -17,8 +17,8 @@ end test_differentiation( AutoFastDifferentiation(), - filter(default_scenarios()) do s - !(s.x isa Matrix) && !(s.y isa Matrix) + filter(default_scenarios(; include_constantified=true, include_cachified=true)) do s + !(s.x isa AbstractMatrix) && !(s.y isa AbstractMatrix) end; logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl b/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl index 4eac97dab..31f8316a0 100644 --- a/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl +++ b/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl @@ -15,7 +15,16 @@ for backend in [AutoSymbolics(), AutoSparse(AutoSymbolics())] @test check_inplace(backend) end -test_differentiation(AutoSymbolics(); logging=LOGGING); +test_differentiation( + AutoSymbolics(), default_scenarios(; include_constantified=true); logging=LOGGING +); + +test_differentiation( + AutoSymbolics(), + default_scenarios(; include_normal=false, include_cachified=true); + excluded=[:jacobian], # TODO: figure out why this fails + logging=LOGGING, +); test_differentiation( AutoSparse(AutoSymbolics()), From 4d3ae60f491da9bedec948e4648ec35064d0da98 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 28 Jan 2025 12:35:38 +0100 Subject: [PATCH 2/2] No `myvec` --- .../DifferentiationInterfaceSymbolicsExt.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl index 4e9c79a35..2dc8d0018 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl @@ -21,9 +21,6 @@ using Symbolics.RuntimeGeneratedFunctions: RuntimeGeneratedFunction DI.check_available(::AutoSymbolics) = true DI.pullback_performance(::AutoSymbolics) = DI.PullbackSlow() -myvec(x::Number) = [x] -myvec(x::AbstractArray) = vec(x) - dense_ad(backend::AutoSymbolics) = backend dense_ad(backend::AutoSparse{<:AutoSymbolics}) = ADTypes.dense_ad(backend)