Skip to content

Commit

Permalink
Merge pull request #123 from avik-pal/ap/fix_adjoint
Browse files Browse the repository at this point in the history
Patch Adjoint Sensitivity for Simple Nonlinear Solve Algorithms
  • Loading branch information
ChrisRackauckas committed Feb 6, 2024
2 parents f8408d4 + 17220bb commit ba3b5a4
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 9 deletions.
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.3.1"
version = "1.3.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -19,16 +19,19 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extensions]
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"

[compat]
ADTypes = "0.2.6"
ArrayInterface = "7"
ChainRulesCore = "1"
ConcreteStructs = "0.2"
DiffEqBase = "6.126"
FastClosures = "0.3"
Expand All @@ -39,6 +42,6 @@ MaybeInplace = "0.1"
PrecompileTools = "1"
Reexport = "1"
SciMLBase = "2.7"
StaticArraysCore = "1.4"
StaticArrays = "1"
StaticArraysCore = "1.4"
julia = "1.9"
21 changes: 21 additions & 0 deletions ext/SimpleNonlinearSolveChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module SimpleNonlinearSolveChainRulesCoreExt

using ChainRulesCore, DiffEqBase, SciMLBase, SimpleNonlinearSolve

# Reverse-mode rule for `SimpleNonlinearSolve.__internal_solve_up`. Both the primal result
# and the inner pullback are obtained from `DiffEqBase._solve_adjoint`, called with a
# `SciMLBase.ChainRulesOriginator()`.
#
# The expectation here is that no-one is using this directly inside a GPU kernel. We can
# eventually lift this requirement using a custom adjoint
function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up),
        prob::NonlinearProblem,
        sensealg::Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm}, u0, u0_changed,
        p, p_changed, alg, args...; kwargs...)
    originator = SciMLBase.ChainRulesOriginator()
    sol, inner_pullback = DiffEqBase._solve_adjoint(
        prob, sensealg, u0, p, originator, alg, args...; kwargs...)

    # `u0_changed` and `p_changed` are not forwarded to `_solve_adjoint`, so the inner
    # pullback carries no entries for them; `NoTangent()` is spliced in at their
    # positions (right after `∂u0` and right after `∂p`).
    function ∇__internal_solve_up(Δ)
        ∂f, ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂rest... = inner_pullback(Δ)
        return (∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), ∂originator,
            ∂rest...)
    end

    return sol, ∇__internal_solve_up
end

end
22 changes: 15 additions & 7 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,26 @@ include("ad.jl")
## Default algorithm

# Set the default bracketing method to ITP
function SciMLBase.solve(prob::IntervalNonlinearProblem; kwargs...)
    # No algorithm supplied: fall back to ITP for interval (bracketing) problems.
    default_alg = ITP()
    return solve(prob, default_alg; kwargs...)
end

function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing,
args...; kwargs...)
function SciMLBase.solve(prob::IntervalNonlinearProblem; kwargs...)
    # Keyword arguments are forwarded untouched; only the algorithm is filled in.
    return solve(prob, ITP(); kwargs...)
end
function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, args...; kwargs...)
    # An explicit `nothing` algorithm is replaced by ITP; the remaining positional and
    # keyword arguments are passed through as given.
    default_alg = ITP()
    return solve(prob, default_alg, args...; kwargs...)
end

# Bypass the high-level checks for NonlinearProblem for Simple Algorithms
# Entry point for solving a `NonlinearProblem` with a simple algorithm.
#
# Keywords:
#   - `sensealg`: sensitivity algorithm; when left as `nothing`, the `:sensealg` entry of
#     `prob.kwargs` is used if present.
#   - `u0`, `p`: optional overrides for `prob.u0` / `prob.p`.
#   - remaining `kwargs` are forwarded to `__internal_solve_up`.
#
# The `*_changed` flags passed on are `true` exactly when the caller supplied an override.
# `__internal_solve_up` only remakes the problem when a flag is `true`, so the flags must
# not be inverted: otherwise supplying both `u0` and `p` would drop both overrides.
function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
        args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
    if sensealg === nothing && haskey(prob.kwargs, :sensealg)
        sensealg = prob.kwargs[:sensealg]
    end
    u0_changed = u0 !== nothing
    p_changed = p !== nothing
    new_u0 = u0_changed ? u0 : prob.u0
    new_p = p_changed ? p : prob.p
    return __internal_solve_up(prob, sensealg, new_u0, u0_changed, new_p, p_changed,
        alg, args...; kwargs...)
end

# Internal solve step: rebuilds the problem with the given `u0` and `p` when either flag
# is set, then dispatches to `SciMLBase.__solve`. `sensealg` is accepted but not used in
# this method body.
function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed, p,
        p_changed, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
    needs_remake = u0_changed || p_changed
    prob = if needs_remake
        remake(_prob; u0, p)
    else
        _prob
    end
    return SciMLBase.__solve(prob, alg, args...; kwargs...)
end

Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
NonlinearProblemLibrary = "0.1.2"
14 changes: 14 additions & 0 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using ForwardDiff, SciMLSensitivity, SimpleNonlinearSolve, Test, Zygote

# Checks that the reverse-mode gradient (Zygote) of a scalar function of the nonlinear
# solution agrees with the forward-mode gradient (ForwardDiff).
@testset "Simple Adjoint Test" begin
    # Residual of the element-wise problem u.^2 .- p = 0
    ff(u, p) = u .^ 2 .- p

    # Scalar function of the parameters: sum of squares of the solution for `p`
    function solve_nlprob(p)
        prob = NonlinearProblem{false}(ff, [1.0, 2.0], p)
        return sum(abs2, solve(prob, SimpleNewtonRaphson()).u)
    end

    p = [3.0, 2.0]

    # Gradients are floating-point results, so compare approximately rather than exactly
    @test only(Zygote.gradient(solve_nlprob, p)) ≈ ForwardDiff.gradient(solve_nlprob, p)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ end
@time @safetestset "Matrix Resizing Tests" include("matrix_resizing_tests.jl")
@time @safetestset "Least Squares Tests" include("least_squares.jl")
@time @safetestset "23 Test Problems" include("23_test_problems.jl")
@time @safetestset "Simple Adjoint Tests" include("adjoint.jl")
end

if GROUP == "CUDA"
Expand Down

0 comments on commit ba3b5a4

Please sign in to comment.