Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using SciMLBase
using SciMLBase: AbstractSensitivityAlgorithm

import ChainRulesCore
import ChainRulesCore: NoTangent
import ChainRulesCore: NoTangent, Tangent

function ChainRulesCore.frule(::typeof(NonlinearSolveBase.solve_up), prob,
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
Expand All @@ -19,13 +19,23 @@ function ChainRulesCore.frule(::typeof(NonlinearSolveBase.solve_up), prob,
end

function ChainRulesCore.rrule(::typeof(NonlinearSolveBase.solve_up), prob::AbstractNonlinearProblem,
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
u0, p, args...; originator = SciMLBase.ChainRulesOriginator(),
kwargs...)
NonlinearSolveBase._solve_adjoint(
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
u0, p, args...; originator = SciMLBase.ChainRulesOriginator(),
kwargs...)
primal, inner_thunking_pb = NonlinearSolveBase._solve_adjoint(
prob, sensealg, u0, p,
originator, args...;
kwargs...)

# when using mooncake ∂sol would be a NamedTuple Tangent with cotangents of all the solution struct's fields.
# However the pullback for this rule - "steadystatebackpass" as defined in SciMLSensitivity/src/concrete_solve.jl/
# handles AD only when ∂sol is a ChainRulesCore.AbstractThunk object or a sol.u vector and similar data structures (not namedtuples).
# When using Mooncake, we pass in sol.u to inner_thunking_pb directly as this is the only field relevant to the solution's cotangent (given solve_up, AbstractNonlinearProblem setting).

function solve_up_adjoint(∂sol)
return inner_thunking_pb(∂sol isa Tangent{Any,<:NamedTuple} ? ∂sol.u : ∂sol)
end
return primal, solve_up_adjoint
end

end
39 changes: 17 additions & 22 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,25 @@ module NonlinearSolveBaseMooncakeExt

using NonlinearSolveBase, Mooncake
using SciMLBase: SciMLBase
import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive,
@from_rrule, @zero_adjoint, @mooncake_overlay, MinimalCtx,
NoPullback
using Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive,
@from_chainrules, @zero_adjoint, @mooncake_overlay, MinimalCtx,
NoPullback

@from_rrule(MinimalCtx,
Tuple{
typeof(NonlinearSolveBase.solve_up),
SciMLBase.AbstractNonlinearProblem,
Union{Nothing, SciMLBase.AbstractSensitivityAlgorithm},
Any,
Any,
Any
},
true,)
@from_chainrules MinimalCtx Tuple{typeof(NonlinearSolveBase.solve_up),
SciMLBase.AbstractNonlinearProblem,
Union{Nothing,SciMLBase.AbstractSensitivityAlgorithm},
Any,
Any,
Any
} true

# Dispatch for auto-alg
@from_rrule(MinimalCtx,
Tuple{
typeof(NonlinearSolveBase.solve_up),
SciMLBase.AbstractNonlinearProblem,
Union{Nothing, SciMLBase.AbstractSensitivityAlgorithm},
Any,
Any
},
true,)
@from_chainrules MinimalCtx Tuple{
typeof(NonlinearSolveBase.solve_up),
SciMLBase.AbstractNonlinearProblem,
Union{Nothing,SciMLBase.AbstractSensitivityAlgorithm},
Any,
Any
} true

end
2 changes: 1 addition & 1 deletion test/adjoint_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@

@test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme
@test ∂p_zygote ≈ ∂p_forwarddiff ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme
@test_broken ∂p_forwarddiff ≈ ∂p_mooncake
@test ∂p_forwarddiff ≈ ∂p_mooncake
end
Loading