diff --git a/Project.toml b/Project.toml index 42d2d3f9d..f46dd330e 100644 --- a/Project.toml +++ b/Project.toml @@ -71,7 +71,7 @@ ChainRulesCore = "1" ConcreteStructs = "0.2.3" Distributions = "0.25" DocStringExtensions = "0.9" -Enzyme = "0.13" +Enzyme = "0.13.100" EnzymeCore = "0.7, 0.8" FastBroadcast = "0.3.5" FastClosures = "0.3.2" diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index ae5fd322b..c3d31c163 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -7,33 +7,58 @@ module DiffEqBaseEnzymeExt import Enzyme: Const using ChainRulesCore + + @inline function copy_or_reuse(config, val, idx) + if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val) + return deepcopy(val) + else + return val + end + end + + @inline function arg_copy(data, i) + config, args = data + copy_or_reuse(config, args[i].val, i + 5) + end + + # Note these following functions are generally not considered user facing from within Enzyme. + # They enable additional performance/usability here (e.g. inactive kwargs). + # Contact wsmoses@ before modifying (and beware their semantics may change without semver). + + Enzyme.EnzymeRules.inactive_kwarg(::typeof(DiffEqBase.solve_up), prob, sensalg::Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm}, u0, p, args...; kwargs...) = nothing + + Enzyme.EnzymeRules.has_easy_rule(::typeof(DiffEqBase.solve_up), prob, sensalg::Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm}, u0, p, args...; kwargs...) = nothing + function Enzyme.EnzymeRules.augmented_primal( config::Enzyme.EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, + func::Const{typeof(DiffEqBase.solve_up)}, RTA::Type{Duplicated{RT}}, prob, sensealg::Union{ Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where {RT} - @inline function copy_or_reuse(val, idx) - if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val) - return deepcopy(val) - else - return val - end - end - - @inline function arg_copy(i) - copy_or_reuse(args[i].val, i + 5) - end res = DiffEqBase._solve_adjoint( - copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), - copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), - SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...; + copy_or_reuse(config, prob.val, 2), copy_or_reuse(config, sensealg.val, 3), + copy_or_reuse(config, u0.val, 4), copy_or_reuse(config, p.val, 5), + SciMLBase.EnzymeOriginator(), ntuple(Base.Fix1(arg_copy, (config, args)), Val(length(args)))...; kwargs...) - dres = Enzyme.make_zero(res[1])::RT - tup = (dres, res[2]) - return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any) + primal = if Enzyme.EnzymeRules.needs_primal(config) + res[1] + else + nothing + end + + shadow = if Enzyme.EnzymeRules.needs_shadow(config) + Enzyme.make_zero(res[1])::RT + else + nothing + end + tup = if Enzyme.EnzymeRules.needs_shadow(config) + (shadow, res[2]) + else + nothing + end + return Enzyme.EnzymeRules.augmented_rule_return_type(config, RTA)(primal, shadow, tup) end function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}, @@ -41,20 +66,24 @@ module DiffEqBaseEnzymeExt sensealg::Union{ Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where {RT} - dres, clos = tape - dres = dres::RT - dargs = clos(dres) - for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) - if ptr isa Enzyme.Const - continue - end - if darg == ChainRulesCore.NoTangent() - continue + + if Enzyme.EnzymeRules.needs_shadow(config) + dres, clos = tape + dres = dres::RT + dargs = clos(dres) + for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) + if ptr isa Enzyme.Const + continue + end + if darg == ChainRulesCore.NoTangent() + continue + end + ptr.dval .+= darg end - ptr.dval .+= darg + Enzyme.make_zero!(dres.u) end - Enzyme.make_zero!(dres.u) - return ntuple(_ -> nothing, Val(length(args) + 4)) + + return ntuple(Returns(nothing), Val(length(args) + 4)) end end