diff --git a/Project.toml b/Project.toml index 90c87d393..7772ea7d5 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" @@ -44,24 +43,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" -[sources.BracketingNonlinearSolve] -path = "lib/BracketingNonlinearSolve" - -[sources.NonlinearSolveBase] -path = "lib/NonlinearSolveBase" - -[sources.NonlinearSolveFirstOrder] -path = "lib/NonlinearSolveFirstOrder" - -[sources.NonlinearSolveQuasiNewton] -path = "lib/NonlinearSolveQuasiNewton" - -[sources.NonlinearSolveSpectralMethods] -path = "lib/NonlinearSolveSpectralMethods" - -[sources.SimpleNonlinearSolve] -path = "lib/SimpleNonlinearSolve" - [extensions] NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt" NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration" @@ -83,7 +64,6 @@ BenchmarkTools = "1.4" BracketingNonlinearSolve = "1" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" -DiffEqBase = "6.188" DifferentiationInterface = "0.7.3" ExplicitImports = "1.5" FastClosures = "0.3.2" @@ -104,7 +84,7 @@ NLSolvers = "0.5" NLsolve = "4.5" NaNMath = "1" NonlinearProblemLibrary = "0.1.2" -NonlinearSolveBase = "1.14" +NonlinearSolveBase = "1.15" NonlinearSolveFirstOrder = "1.2" NonlinearSolveQuasiNewton = "1.8" NonlinearSolveSpectralMethods = "1.1" @@ -117,6 +97,7 @@ Preferences = "1.4.3" Random = "1.10" ReTestItems = "1.24" Reexport = "1.2.2" +ReverseDiff = "1.15" SIAMFANLEquations = "1.0.1" SciMLBase = "2.116" SimpleNonlinearSolve = "2.1" @@ -155,6 +136,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" @@ -163,7 +145,26 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[sources.BracketingNonlinearSolve] +path = "lib/BracketingNonlinearSolve" + +[sources.NonlinearSolveBase] +path = "lib/NonlinearSolveBase" + +[sources.NonlinearSolveFirstOrder] +path = "lib/NonlinearSolveFirstOrder" + +[sources.NonlinearSolveQuasiNewton] +path = "lib/NonlinearSolveQuasiNewton" + +[sources.NonlinearSolveSpectralMethods] +path = "lib/NonlinearSolveSpectralMethods" + +[sources.SimpleNonlinearSolve] +path = "lib/SimpleNonlinearSolve" + [targets] -test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"] +test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker"] diff --git a/docs/Project.toml b/docs/Project.toml index 2ecbbf3cc..bd67a099f 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,7 +4,6 @@ AlgebraicMultigrid = "2169fc97-5a83-5252-b627-83903c6c433c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" @@ -69,7 +68,6 @@ AlgebraicMultigrid = "0.5, 0.6, 1" ArrayInterface = "6, 7" BenchmarkTools = "1" BracketingNonlinearSolve = "1" -DiffEqBase = "6.188" DifferentiationInterface = "0.7.3" Documenter = "1" DocumenterCitations = "1" @@ -79,7 +77,7 @@ InteractiveUtils = "<0.0.1, 1" LineSearch = "0.1" LinearSolve = "2, 3" NonlinearSolve = "4" -NonlinearSolveBase = "1" +NonlinearSolveBase = "1.15" NonlinearSolveFirstOrder = "1" NonlinearSolveHomotopyContinuation = "0.1" NonlinearSolveQuasiNewton = "1" diff --git a/docs/make.jl b/docs/make.jl index cce131634..42cbb4b56 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,8 +1,7 @@ using Documenter, DocumenterCitations, DocumenterInterLinks -import DiffEqBase using Sundials -using NonlinearSolveBase, SciMLBase, DiffEqBase +using NonlinearSolveBase, SciMLBase using SimpleNonlinearSolve, BracketingNonlinearSolve using NonlinearSolveFirstOrder, NonlinearSolveQuasiNewton, NonlinearSolveSpectralMethods using NonlinearSolveHomotopyContinuation, NonlinearSolveSciPy @@ -33,7 +32,7 @@ makedocs(; sitename = "NonlinearSolve.jl", authors = "SciML", modules = [ - NonlinearSolveBase, SciMLBase, DiffEqBase, + NonlinearSolveBase, SciMLBase, SimpleNonlinearSolve, BracketingNonlinearSolve, NonlinearSolveFirstOrder, NonlinearSolveQuasiNewton, NonlinearSolveSpectralMethods, NonlinearSolveHomotopyContinuation, diff --git a/lib/BracketingNonlinearSolve/Project.toml b/lib/BracketingNonlinearSolve/Project.toml index a4976d168..bac48ac19 100644 --- a/lib/BracketingNonlinearSolve/Project.toml +++ b/lib/BracketingNonlinearSolve/Project.toml @@ -30,7 +30,7 @@ ConcreteStructs = "0.2.3" ExplicitImports = "1.10.1" ForwardDiff = "0.10.36, 1" InteractiveUtils = "<0.0.1, 1" -NonlinearSolveBase = "1.1" +NonlinearSolveBase = "1.15" PrecompileTools = "1.2" Reexport = "1.2.2" SciMLBase = "2.116" diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index da80f2ec4..fe0f84721 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolveBase" uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" authors = ["Avik Pal and contributors"] -version = "1.14.1" +version = "1.15.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -22,27 +22,38 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" +SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" [weakdeps] BandedMatrices = "aae01518-5342-5314-be14-df237901396f" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extensions] NonlinearSolveBaseBandedMatricesExt = "BandedMatrices" -NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase" +NonlinearSolveBaseChainRulesCoreExt = "ChainRulesCore" +NonlinearSolveBaseEnzymeExt = ["ChainRulesCore", "Enzyme"] NonlinearSolveBaseForwardDiffExt = "ForwardDiff" NonlinearSolveBaseLineSearchExt = "LineSearch" NonlinearSolveBaseLinearSolveExt = "LinearSolve" +NonlinearSolveBaseMooncakeExt = "Mooncake" +NonlinearSolveBaseReverseDiffExt = "ReverseDiff" NonlinearSolveBaseSparseArraysExt = "SparseArrays" NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings" +NonlinearSolveBaseTrackerExt = "Tracker" + [compat] ADTypes = "1.9" @@ -50,12 +61,13 @@ Adapt = "4.1.0" Aqua = "0.8.7" ArrayInterface = "7.9" BandedMatrices = "1.5" +ChainRulesCore = "1" CommonSolve = "0.2.4" Compat = "4.15" ConcreteStructs = "0.2.3" -DiffEqBase = "6.188" DifferentiationInterface = "0.7.3" EnzymeCore = "0.8" +Enzyme = "0.13.12" ExplicitImports = "1.10.1" FastClosures = "0.3" ForwardDiff = "0.10.36, 1" @@ -65,24 +77,28 @@ LinearAlgebra = "1.10" LinearSolve = "3.15" Markdown = "1.10" MaybeInplace = "0.1.4" +Mooncake = "0.4" Preferences = "1.4" Printf = "1.10" RecursiveArrayTools = "3" +ReverseDiff = "1.15" SciMLBase = "2.116" SciMLJacobianOperators = "0.1.1" SciMLOperators = "1.7" +SciMLStructures = "1.5" +Setfield = "1.1.2" SparseArrays = "1.10" SparseMatrixColorings = "0.4.5" StaticArraysCore = "1.4" SymbolicIndexingInterface = "0.3.43" Test = "1.10" +Tracker = "0.2.35" TimerOutputs = "0.5.23" julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -91,4 +107,4 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "BandedMatrices", "DiffEqBase", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "LinearAlgebra", "SparseArrays", "Test"] +test = ["Aqua", "BandedMatrices", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "LinearAlgebra", "SparseArrays", "Test"] diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl new file mode 100644 index 000000000..d60be6211 --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseChainRulesCoreExt.jl @@ -0,0 +1,31 @@ +module NonlinearSolveBaseChainRulesCoreExt + +using NonlinearSolveBase +using NonlinearSolveBase: AbstractNonlinearProblem +using SciMLBase +using SciMLBase: AbstractSensitivityAlgorithm + +import ChainRulesCore +import ChainRulesCore: NoTangent + +function ChainRulesCore.frule(::typeof(NonlinearSolveBase.solve_up), prob, + sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, + u0, p, args...; originator = SciMLBase.ChainRulesOriginator(), + kwargs...) + NonlinearSolveBase._solve_forward( + prob, sensealg, u0, p, + originator, args...; + kwargs...) +end + +function ChainRulesCore.rrule(::typeof(NonlinearSolveBase.solve_up), prob::AbstractNonlinearProblem, + sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, + u0, p, args...; originator = SciMLBase.ChainRulesOriginator(), + kwargs...) + NonlinearSolveBase._solve_adjoint( + prob, sensealg, u0, p, + originator, args...; + kwargs...) +end + +end \ No newline at end of file diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl deleted file mode 100644 index a1d6c44ce..000000000 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl +++ /dev/null @@ -1,16 +0,0 @@ -module NonlinearSolveBaseDiffEqBaseExt - -using DiffEqBase: DiffEqBase -using SciMLBase: SciMLBase, remake - -using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem - -function DiffEqBase.get_concrete_problem( - prob::ImmutableNonlinearProblem, isadapt; kwargs...) - u0 = SciMLBase.get_concrete_u0(prob, isadapt, nothing, kwargs) - u0 = SciMLBase.promote_u0(u0, prob.p, nothing) - p = SciMLBase.get_concrete_p(prob, kwargs) - return remake(prob; u0 = u0, p = p) -end - -end diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl new file mode 100644 index 000000000..3bdc7f42e --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl @@ -0,0 +1,61 @@ +module NonlinearSolveBaseEnzymeExt + +@static if isempty(VERSION.prerelease) + using NonlinearSolveBase + import SciMLBase: SciMLBase, value + using Enzyme + import Enzyme: Const + using ChainRulesCore + + function Enzyme.EnzymeRules.augmented_primal( + config::Enzyme.EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(NonlinearSolveBase.solve_up)}, ::Type{Duplicated{RT}}, prob, + sensealg::Union{ + Const{Nothing}, Const{<:SciMLBase.AbstractSensitivityAlgorithm}}, + u0, p, args...; kwargs...) where {RT} + @inline function copy_or_reuse(val, idx) + if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val) + return deepcopy(val) + else + return val + end + end + + @inline function arg_copy(i) + copy_or_reuse(args[i].val, i + 5) + end + + res = NonlinearSolveBase._solve_adjoint( + copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), + copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), + SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...; + kwargs...) + + dres = Enzyme.make_zero(res[1])::RT + tup = (dres, res[2]) + return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any) + end + + function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(NonlinearSolveBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob, + sensealg::Union{ + Const{Nothing}, Const{<:SciMLBase.AbstractSensitivityAlgorithm}}, + u0, p, args...; kwargs...) where {RT} + dres, clos = tape + dres = dres::RT + dargs = clos(dres) + for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) + if ptr isa Enzyme.Const + continue + end + if darg == ChainRulesCore.NoTangent() + continue + end + ptr.dval .+= darg + end + Enzyme.make_zero!(dres.u) + return ntuple(_ -> nothing, Val(length(args) + 4)) + end +end + +end diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl new file mode 100644 index 000000000..91b901099 --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseMooncakeExt.jl @@ -0,0 +1,31 @@ +module NonlinearSolveBaseMooncakeExt + +using NonlinearSolveBase, Mooncake +using SciMLBase: SciMLBase +import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive, + @from_rrule, @zero_adjoint, @mooncake_overlay, MinimalCtx, + NoPullback + +@from_rrule(MinimalCtx, + Tuple{ + typeof(NonlinearSolveBase.solve_up), + SciMLBase.AbstractDEProblem, + Union{Nothing, SciMLBase.AbstractSensitivityAlgorithm}, + Any, + Any, + Any + }, + true,) + +# Dispatch for auto-alg +@from_rrule(MinimalCtx, + Tuple{ + typeof(NonlinearSolveBase.solve_up), + SciMLBase.AbstractDEProblem, + Union{Nothing, SciMLBase.AbstractSensitivityAlgorithm}, + Any, + Any + }, + true,) + +end diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseReverseDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseReverseDiffExt.jl new file mode 100644 index 000000000..fdfe774db --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseReverseDiffExt.jl @@ -0,0 +1,103 @@ +module NonlinearSolveBaseReverseDiffExt + +using NonlinearSolveBase +import SciMLBase: SciMLBase, value +import ReverseDiff +import ArrayInterface + +# `ReverseDiff.TrackedArray` +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::ReverseDiff.TrackedArray, + p::ReverseDiff.TrackedArray, args...; kwargs...) + ReverseDiff.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0, p::ReverseDiff.TrackedArray, + args...; kwargs...) + ReverseDiff.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::ReverseDiff.TrackedArray, p, + args...; kwargs...) + ReverseDiff.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +# `AbstractArray{<:ReverseDiff.TrackedReal}` +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, + u0::AbstractArray{<:ReverseDiff.TrackedReal}, + p::AbstractArray{<:ReverseDiff.TrackedReal}, args...; + kwargs...) + NonlinearSolveBase.solve_up(prob, sensealg, ArrayInterface.aos_to_soa(u0), + ArrayInterface.aos_to_soa(p), args...; + kwargs...) +end + +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0, + p::AbstractArray{<:ReverseDiff.TrackedReal}, + args...; kwargs...) + NonlinearSolveBase.solve_up( + prob, sensealg, u0, ArrayInterface.aos_to_soa(p), args...; kwargs...) +end + +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::ReverseDiff.TrackedArray, + p::AbstractArray{<:ReverseDiff.TrackedReal}, + args...; kwargs...) + NonlinearSolveBase.solve_up( + prob, sensealg, u0, ArrayInterface.aos_to_soa(p), args...; kwargs...) +end + +# function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem, +# sensealg::Union{ +# SciMLBase.AbstractOverloadingSensitivityAlgorithm, +# Nothing}, +# u0::AbstractArray{<:ReverseDiff.TrackedReal}, p, +# args...; kwargs...) +# NonlinearSolveBase.solve_up( +# prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...) +# end + +# function NonlinearSolveBase.solve_up(prob::SciMLBase.DEProblem, +# sensealg::Union{ +# SciMLBase.AbstractOverloadingSensitivityAlgorithm, +# Nothing}, +# u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray, +# args...; kwargs...) +# NonlinearSolveBase.solve_up( +# prob, sensealg, ArrayInterface.aos_to_soa(u0), p, args...; kwargs...) +# end + +# Required becase ReverseDiff.@grad function SciMLBase.solve_up is not supported! +import NonlinearSolveBase: solve_up +ReverseDiff.@grad function solve_up(prob, sensealg, u0, p, args...; kwargs...) + out = NonlinearSolveBase._solve_adjoint(prob, sensealg, ReverseDiff.value(u0), + ReverseDiff.value(p), + SciMLBase.ReverseDiffOriginator(), args...; kwargs...) + function actual_adjoint(_args...) + original_adjoint = out[2](_args...) + if isempty(args) # alg is missing + tuple(original_adjoint[1:4]..., original_adjoint[6:end]...) + else + original_adjoint + end + end + Array(out[1]), actual_adjoint +end + +end diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl new file mode 100644 index 000000000..dd73531c9 --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseTrackerExt.jl @@ -0,0 +1,49 @@ +module NonlinearSolveBaseTrackerExt + +using NonlinearSolveBase +import SciMLBase: SciMLBase, value +import Tracker + +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::Tracker.TrackedArray, + p::Tracker.TrackedArray, args...; kwargs...) + Tracker.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::Tracker.TrackedArray, p, args...; + kwargs...) + Tracker.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +function NonlinearSolveBase.solve_up(prob::SciMLBase.NonlinearProblem, + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0, p::Tracker.TrackedArray, args...; + kwargs...) + Tracker.track(NonlinearSolveBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) +end + +Tracker.@grad function NonlinearSolveBase.solve_up(prob, + sensealg::Union{Nothing, + SciMLBase.AbstractOverloadingSensitivityAlgorithm + }, + u0, p, args...; + kwargs...) + sol, + pb_f = NonlinearSolveBase._solve_adjoint( + prob, sensealg, Tracker.data(u0), Tracker.data(p), + SciMLBase.TrackerOriginator(), args...; kwargs...) + + if sol isa AbstractArray + !hasfield(typeof(sol), :u) && return sol, pb_f # being safe here + return sol.u, pb_f # AbstractNoTimeSolution isa AbstractArray + end + return convert(AbstractArray, sol), pb_f +end + +end diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index e3efc9ef5..eac8dc1e9 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -15,15 +15,22 @@ using StaticArraysCore: StaticArray, SMatrix, SArray, MArray using CommonSolve: CommonSolve, init using EnzymeCore: EnzymeCore using MaybeInplace: @bb -using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition +using RecursiveArrayTools: RecursiveArrayTools, AbstractVectorOfArray, ArrayPartition using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem, - AbstractNonlinearAlgorithm, + AbstractNonlinearAlgorithm, _concrete_solve_adjoint, _concrete_solve_forward, NonlinearProblem, NonlinearLeastSquaresProblem, NonlinearFunction, NLStats, LinearProblem, - LinearAliasSpecifier, ImmutableNonlinearProblem + LinearAliasSpecifier, ImmutableNonlinearProblem, NonlinearAliasSpecifier, + promote_u0, get_concrete_u0, get_concrete_p, + has_kwargs, extract_alg, promote_u0, checkkwargs, SteadyStateProblem, + NoDefaultAlgorithmError, NonSolverError, KeywordArgError, AbstractDEAlgorithm +import SciMLBase: solve, init, __init, __solve, wrap_sol, get_root_indp, isinplace, remake + using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator using SciMLOperators: AbstractSciMLOperator, IdentityOperator using SymbolicIndexingInterface: SymbolicIndexingInterface +import SciMLStructures +using Setfield: @set! using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind, mul! using Markdown: @doc_str diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 91b7a6aa6..0fe813dbd 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -1,3 +1,225 @@ +struct EvalFunc{F} <: Function + f::F +end +(f::EvalFunc)(args...) = f.f(args...) + +""" +```julia +solve(prob::NonlinearProblem, alg::Union{AbstractNonlinearAlgorithm,Nothing}; kwargs...) +``` + +## Arguments + +The only positional argument is `alg` which is optional. By default, `alg = nothing`. +If `alg = nothing`, then `solve` dispatches to the NonlinearSolve.jl automated +algorithm selection (if `using NonlinearSolve` was done, otherwise it will +error with a `MethodError`). + +## Keyword Arguments + +The NonlinearSolve.jl universe has a large set of common arguments available +for the `solve` function. These arguments apply to `solve` on any problem type and +are only limited by limitations of the specific implementations. + +Many of the defaults depend on the algorithm or the package the algorithm derives +from. Not all of the interface is provided by every algorithm. +For more detailed information on the defaults and the available options +for specific algorithms / packages, see the manual pages for the solvers of specific +problems. + +#### Error Control + +* `abstol`: Absolute tolerance. +* `reltol`: Relative tolerance. + +### Miscellaneous + +* `maxiters`: Maximum number of iterations before stopping. Defaults to 1e5. +* `verbose`: Toggles whether warnings are thrown when the solver exits early. + Defaults to true. + +### Sensitivity Algorithms (`sensealg`) + +`sensealg` is used for choosing the way the automatic differentiation is performed. + For more information, see the documentation for SciMLSensitivity: + https://docs.sciml.ai/SciMLSensitivity/stable/ +""" +function solve(prob::AbstractNonlinearProblem, args...; sensealg = nothing, + u0 = nothing, p = nothing, wrap = Val(true), kwargs...) + if sensealg === nothing && haskey(prob.kwargs, :sensealg) + sensealg = prob.kwargs[:sensealg] + end + + if haskey(prob.kwargs, :alias_u0) + @warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`." + alias_spec = NonlinearAliasSpecifier(alias_u0 = prob.kwargs[:alias_u0]) + elseif haskey(kwargs, :alias_u0) + @warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`." + alias_spec = NonlinearAliasSpecifier(alias_u0 = kwargs[:alias_u0]) + end + + if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa Bool + alias_spec = NonlinearAliasSpecifier(alias = prob.kwargs[:alias]) + elseif haskey(kwargs, :alias) && kwargs[:alias] isa Bool + alias_spec = NonlinearAliasSpecifier(alias = kwargs[:alias]) + end + + if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa NonlinearAliasSpecifier + alias_spec = prob.kwargs[:alias] + elseif haskey(kwargs, :alias) && kwargs[:alias] isa NonlinearAliasSpecifier + alias_spec = kwargs[:alias] + else + alias_spec = NonlinearAliasSpecifier(alias_u0 = false) + end + + alias_u0 = alias_spec.alias_u0 + + u0 = u0 !== nothing ? u0 : prob.u0 + p = p !== nothing ? p : prob.p + + if wrap isa Val{true} + wrap_sol(solve_up(prob, + sensealg, + u0, + p, + args...; + alias_u0 = alias_u0, + originator = SciMLBase.ChainRulesOriginator(), + kwargs...)) + else + solve_up(prob, + sensealg, + u0, + p, + args...; + alias_u0 = alias_u0, + originator = SciMLBase.ChainRulesOriginator(), + kwargs...) + end +end + +function solve_up(prob::AbstractNonlinearProblem, sensealg, u0, p, + args...; originator = SciMLBase.ChainRulesOriginator(), + kwargs...) + alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs) + if isnothing(alg) || !(alg isa AbstractNonlinearSolveAlgorithm) # Default algorithm handling + _prob = get_concrete_problem(prob, true; u0 = u0, + p = p, kwargs...) + solve_call(_prob, args...; kwargs...) + else + _prob = get_concrete_problem(prob, true; u0 = u0, p = p, kwargs...) + #check_prob_alg_pairing(_prob, alg) # use alg for improved inference + if length(args) > 1 + solve_call(_prob, alg, Base.tail(args)...; kwargs...) + else + solve_call(_prob, alg; kwargs...) + end + end +end + +function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing, + kwargs...) + kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle + kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? + _prob.kwargs[:kwargshandle] : kwargshandle + + if has_kwargs(_prob) + kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) + end + + checkkwargs(kwargshandle; kwargs...) + if isdefined(_prob, :u0) + if _prob.u0 isa Array + if !isconcretetype(RecursiveArrayTools.recursive_unitless_eltype(_prob.u0)) + throw(NonConcreteEltypeError(RecursiveArrayTools.recursive_unitless_eltype(_prob.u0))) + end + + if !(eltype(_prob.u0) <: Number) && !(eltype(_prob.u0) <: Enum) && + !(_prob.u0 isa AbstractVector{<:AbstractArray} && _prob isa BVProblem) + # Allow Enums for FunctionMaps, make into a trait in the future + # BVPs use Vector of Arrays for initial guesses + throw(NonNumberEltypeError(eltype(_prob.u0))) + end + end + + if _prob.u0 === nothing + return build_null_solution(_prob, args...; kwargs...) + end + end + + if hasfield(typeof(_prob), :f) && hasfield(typeof(_prob.f), :f) && + _prob.f.f isa EvalFunc + Base.invokelatest(__solve, _prob, args...; kwargs...)#::T + else + __solve(_prob, args...; kwargs...)#::T + end +end + +function solve_call(prob::SteadyStateProblem, + alg::AbstractNonlinearAlgorithm, args...; + kwargs...) + solve_call(NonlinearProblem(prob), + alg, args...; + kwargs...) +end + +function init( + prob::AbstractNonlinearProblem, args...; sensealg = nothing, + u0 = nothing, p = nothing, kwargs...) + if sensealg === nothing && has_kwargs(prob) && haskey(prob.kwargs, :sensealg) + sensealg = prob.kwargs[:sensealg] + end + + u0 = u0 !== nothing ? u0 : prob.u0 + p = p !== nothing ? p : prob.p + + init_up(prob, sensealg, u0, p, args...; kwargs...) +end + +function init_up(prob::AbstractNonlinearProblem, + sensealg, u0, p, args...; kwargs...) + alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs) + if isnothing(alg) || !(alg isa AbstractNonlinearAlgorithm) # Default algorithm handling + _prob = get_concrete_problem(prob, true; u0 = u0, + p = p, kwargs...) + init_call(_prob, args...; kwargs...) + else + tstops = get(kwargs, :tstops, nothing) + if tstops === nothing && has_kwargs(prob) + tstops = get(prob.kwargs, :tstops, nothing) + end + if !(tstops isa Union{Nothing, AbstractArray, Tuple, Real}) && + !SciMLBase.allows_late_binding_tstops(alg) + throw(LateBindingTstopsNotSupportedError()) + end + _prob = get_concrete_problem(prob, true; u0 = u0, p = p, kwargs...) + #check_prob_alg_pairing(_prob, alg) # alg for improved inference + if length(args) > 1 + init_call(_prob, alg, Base.tail(args)...; kwargs...) + else + init_call(_prob, alg; kwargs...) + end + end +end + +function init_call(_prob, args...; merge_callbacks=true, kwargshandle=nothing, + kwargs...) + kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle + kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? + _prob.kwargs[:kwargshandle] : kwargshandle + if has_kwargs(_prob) + kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) + end + + checkkwargs(kwargshandle; kwargs...) + if hasfield(typeof(_prob), :f) && hasfield(typeof(_prob.f), :f) && + _prob.f.f isa EvalFunc + Base.invokelatest(__init, _prob, args...; kwargs...)#::T + else + __init(_prob, args...; kwargs...)#::T + end +end + function SciMLBase.__solve( prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveAlgorithm, args...; kwargs... @@ -127,6 +349,30 @@ function SciMLBase.__solve( __generated_polysolve(prob, alg, args...; kwargs...) end +function SciMLBase.__solve( + prob::AbstractNonlinearProblem, args...; default_set = false, second_time = false, + kwargs...) + if second_time + throw(NoDefaultAlgorithmError()) + elseif length(args) > 0 && !(first(args) isa AbstractNonlinearAlgorithm) + throw(NonSolverError()) + else + __solve(prob, nothing, args...; default_set = false, second_time = true, kwargs...) + end +end + +function __init(prob::AbstractNonlinearProblem, args...; default_set = false, second_time = false, + kwargs...) + if second_time + throw(NoDefaultAlgorithmError()) + elseif length(args) > 0 && !(first(args) isa + Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm}) + throw(NonSolverError()) + else + __init(prob, nothing, args...; default_set = false, second_time = true, kwargs...) + end +end + @generated function __generated_polysolve( prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...; stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, @@ -297,6 +543,10 @@ SII.state_values(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob get_u(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob) +# has_kwargs(_prob::AbstractNonlinearProblem) = has_kwargs(typeof(_prob)) +# Base.@pure __has_kwargs(::Type{T}) where {T} = :kwargs ∈ fieldnames(T) +# has_kwargs(::Type{T}) where {T} = __has_kwargs(T) + function SciMLBase.reinit!( cache::NonlinearSolveNoInitCache, u0 = cache.prob.u0; p = cache.prob.p, kwargs... ) @@ -328,3 +578,159 @@ function CommonSolve.solve!(cache::NonlinearSolveNoInitCache) end return CommonSolve.solve(cache.prob, cache.alg, cache.args...; cache.kwargs...) end + +function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, + kwargs...) + alg = extract_alg(args, kwargs, prob.kwargs) + if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling + _prob = get_concrete_problem(prob, true; u0 = u0, + p = p, kwargs...) + else + _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) + end + + if has_kwargs(_prob) + kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) + end + + if length(args) > 1 + _concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator, + Base.tail(args)...; kwargs...) + else + _concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator; kwargs...) + end +end + +function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, + kwargs...) + alg = extract_alg(args, kwargs, prob.kwargs) + if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling + _prob = get_concrete_problem(prob, true; u0 = u0, + p = p, kwargs...) + else + _prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...) + end + + if has_kwargs(_prob) + kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs) + end + + if length(args) > 1 + _concrete_solve_forward(_prob, alg, sensealg, u0, p, originator, + Base.tail(args)...; kwargs...) + else + _concrete_solve_forward(_prob, alg, sensealg, u0, p, originator; kwargs...) + end +end + +function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...) + oldprob = prob + prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) + if prob !== oldprob + kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) + end + p = get_concrete_p(prob, kwargs) + u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) + u0 = promote_u0(u0, p, nothing) + remake(prob; u0 = u0, p = p) +end + +function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...) + oldprob = prob + prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...) + if prob !== oldprob + kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) + end + p = get_concrete_p(prob, kwargs) + u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) + u0 = promote_u0(u0, p, nothing) + remake(prob; u0 = u0, p = p) +end + +function get_concrete_problem( + prob::ImmutableNonlinearProblem, isadapt; kwargs...) + u0 = get_concrete_u0(prob, isadapt, nothing, kwargs) + u0 = promote_u0(u0, prob.p, nothing) + p = get_concrete_p(prob, kwargs) + return remake(prob; u0 = u0, p = p) +end + +function get_concrete_problem(prob::SteadyStateProblem, isadapt; kwargs...) + oldprob = prob + prob = get_updated_symbolic_problem(SciMLBase.get_root_indp(prob), prob; kwargs...) + if prob !== oldprob + kwargs = (; kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob)) + end + p = get_concrete_p(prob, kwargs) + u0 = get_concrete_u0(prob, isadapt, Inf, kwargs) + u0 = promote_u0(u0, p, nothing) + remake(prob; u0 = u0, p = p) +end + + +""" +Given the index provider `indp` used to construct the problem `prob` being solved, return +an updated `prob` to be used for solving. All implementations should accept arbitrary +keyword arguments. + +Should be called before the problem is solved, after performing type-promotion on the +problem. If the returned problem is not `===` the provided `prob`, it is assumed to +contain the `u0` and `p` passed as keyword arguments. + +# Keyword Arguments + +- `u0`, `p`: Override values for `state_values(prob)` and `parameter_values(prob)` which + should be used instead of the ones in `prob`. +""" +function get_updated_symbolic_problem(indp, prob; kw...) + return prob +end + +function build_null_solution( + prob::Union{NonlinearProblem, SteadyStateProblem}, + args...; + saveat = (), + save_everystep = true, + save_on = true, + save_start = save_everystep || isempty(saveat) || + saveat isa Number || prob.tspan[1] in saveat, + save_end = true, + kwargs...) + prob, success = hack_null_solution_init(prob) + retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure + SciMLBase.build_solution(prob, nothing, Float64[], nothing; retcode) +end + +function build_null_solution( + prob::NonlinearLeastSquaresProblem, + args...; abstol = 1e-6, kwargs...) + prob, success = hack_null_solution_init(prob) + retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure + + if isinplace(prob) + resid = isnothing(prob.f.resid_prototype) ? Float64[] : copy(prob.f.resid_prototype) + prob.f(resid, prob.u0, prob.p) + else + resid = prob.f(prob.f.resid_prototype, prob.p) + end + + if success + retcode = norm(resid) < abstol ? ReturnCode.Success : ReturnCode.Failure + end + + SciMLBase.build_solution(prob, nothing, Float64[], resid; retcode) +end + +function hack_null_solution_init(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem, SteadyStateProblem}) + if SciMLBase.has_initialization_data(prob.f) + initializeprob = prob.f.initialization_data.initializeprob + nlsol = solve(initializeprob) + success = SciMLBase.successful_retcode(nlsol) + if prob.f.initialization_data.initializeprobpmap !== nothing + @set! prob.p = prob.f.initializeprobpmap(prob, nlsol) + end + else + success = true + end + return prob, success +end diff --git a/lib/NonlinearSolveBase/src/termination_conditions.jl b/lib/NonlinearSolveBase/src/termination_conditions.jl index 771c2517f..5442f22dd 100644 --- a/lib/NonlinearSolveBase/src/termination_conditions.jl +++ b/lib/NonlinearSolveBase/src/termination_conditions.jl @@ -82,7 +82,6 @@ function CommonSolve.init( length(saved_value_prototype) == 0 && (saved_value_prototype = nothing) leastsq = typeof(prob) <: NonlinearLeastSquaresProblem - return NonlinearTerminationModeCache( u_unaliased, ReturnCode.Default, abstol, reltol, best_value, mode, initial_objective, objectives_trace, 0, saved_value_prototype, diff --git a/lib/NonlinearSolveBase/src/utils.jl b/lib/NonlinearSolveBase/src/utils.jl index 05bc71158..18d78c451 100644 --- a/lib/NonlinearSolveBase/src/utils.jl +++ b/lib/NonlinearSolveBase/src/utils.jl @@ -320,4 +320,6 @@ function clean_sprint_struct(x, indent::Int) return "$(name)(\n$(spacing)$(join(modifiers, ",\n$(spacing)"))\n$(spacing_last))" end +set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = x + end diff --git a/lib/NonlinearSolveBase/test/runtests.jl b/lib/NonlinearSolveBase/test/runtests.jl index 95ae283cc..86eb95730 100644 --- a/lib/NonlinearSolveBase/test/runtests.jl +++ b/lib/NonlinearSolveBase/test/runtests.jl @@ -7,17 +7,18 @@ using InteractiveUtils, Test @testset "NonlinearSolveBase.jl" begin @testset "Aqua" begin using Aqua, NonlinearSolveBase + using NonlinearSolveBase: AbstractNonlinearProblem, NonlinearProblem Aqua.test_all( NonlinearSolveBase; piracies = false, ambiguities = false, stale_deps = false ) Aqua.test_stale_deps(NonlinearSolveBase; ignore = [:TimerOutputs]) - Aqua.test_piracies(NonlinearSolveBase) + Aqua.test_piracies(NonlinearSolveBase, treat_as_own = [AbstractNonlinearProblem, NonlinearProblem]) Aqua.test_ambiguities(NonlinearSolveBase; recursive = false) end @testset "Explicit Imports" begin - import ForwardDiff, SparseArrays, DiffEqBase + import ForwardDiff, SparseArrays using ExplicitImports, NonlinearSolveBase @test check_no_implicit_imports(NonlinearSolveBase; skip = (Base, Core)) === nothing diff --git a/lib/NonlinearSolveFirstOrder/Project.toml b/lib/NonlinearSolveFirstOrder/Project.toml index 37e5bd1ba..72059a11b 100644 --- a/lib/NonlinearSolveFirstOrder/Project.toml +++ b/lib/NonlinearSolveFirstOrder/Project.toml @@ -8,12 +8,11 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" -LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -36,7 +35,6 @@ BenchmarkTools = "1.5.0" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" DifferentiationInterface = "0.7.3" -DiffEqBase = "6.188" Enzyme = "0.13.12" ExplicitImports = "1.5" FiniteDiff = "2.24" @@ -49,7 +47,7 @@ LinearAlgebra = "1.10" LinearSolve = "2.36.1, 3" MaybeInplace = "0.1.4" NonlinearProblemLibrary = "0.1.2" -NonlinearSolveBase = "1.14" +NonlinearSolveBase = "1.15" Pkg = "1.10" PrecompileTools = "1.2" Random = "1.10" @@ -59,7 +57,7 @@ SciMLBase = "2.116" SciMLJacobianOperators = "0.1.0" Setfield = "1.1.1" SparseArrays = "1.10" -SparseConnectivityTracer = "1" +SparseConnectivityTracer = "1, 1" SparseMatrixColorings = "0.4.5" StableRNGs = "1" StaticArrays = "1.9.8" @@ -73,9 +71,9 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" diff --git a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl index 57a1f0105..0a6c009c6 100644 --- a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl +++ b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl @@ -12,7 +12,6 @@ using LineSearch: BackTracking using StaticArraysCore: SArray using CommonSolve: CommonSolve -using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LinearSolve: LinearSolve # Trigger Linear Solve extension in NonlinearSolveBase using MaybeInplace: @bb using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, diff --git a/lib/NonlinearSolveHomotopyContinuation/Project.toml b/lib/NonlinearSolveHomotopyContinuation/Project.toml index f011bb485..21a839d47 100644 --- a/lib/NonlinearSolveHomotopyContinuation/Project.toml +++ b/lib/NonlinearSolveHomotopyContinuation/Project.toml @@ -31,7 +31,7 @@ HomotopyContinuation = "2.12.0" LinearAlgebra = "1.10" NaNMath = "1.1" NonlinearSolve = "4.10" -NonlinearSolveBase = "1.14" +NonlinearSolveBase = "1.15" SciMLBase = "2.116" SymbolicIndexingInterface = "0.3.43" TaylorDiff = "0.3.1" diff --git a/lib/NonlinearSolveQuasiNewton/Project.toml b/lib/NonlinearSolveQuasiNewton/Project.toml index 7a6194560..cb2eaa8f8 100644 --- a/lib/NonlinearSolveQuasiNewton/Project.toml +++ b/lib/NonlinearSolveQuasiNewton/Project.toml @@ -7,7 +7,6 @@ version = "1.8.1" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" @@ -21,8 +20,8 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -[sources.NonlinearSolveBase] -path = "../NonlinearSolveBase" +[sources] +NonlinearSolveBase = {path = "../NonlinearSolveBase"} [extensions] NonlinearSolveQuasiNewtonForwardDiffExt = "ForwardDiff" @@ -34,7 +33,6 @@ ArrayInterface = "7.16.0" BenchmarkTools = "1.5.0" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" -DiffEqBase = "6.188" Enzyme = "0.13.12" ExplicitImports = "1.5" FiniteDiff = "2.24" @@ -47,7 +45,7 @@ LinearAlgebra = "1.10" LinearSolve = "2.36.1, 3" MaybeInplace = "0.1.4" NonlinearProblemLibrary = "0.1.2" -NonlinearSolveBase = "1.14" +NonlinearSolveBase = "1.15" Pkg = "1.10" PrecompileTools = "1.2" ReTestItems = "1.24" diff --git a/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl b/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl index 167f1fa85..a7f3d93c3 100644 --- a/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl +++ b/lib/NonlinearSolveQuasiNewton/src/NonlinearSolveQuasiNewton.jl @@ -8,7 +8,6 @@ using ArrayInterface: ArrayInterface using StaticArraysCore: StaticArray, Size, MArray using CommonSolve: CommonSolve -using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LinearAlgebra: LinearAlgebra, Diagonal, dot, diag using LinearSolve: LinearSolve # Trigger Linear Solve extension in NonlinearSolveBase using MaybeInplace: @bb diff --git a/lib/NonlinearSolveSciPy/Project.toml b/lib/NonlinearSolveSciPy/Project.toml index 2745b8006..4726d30d7 100644 --- a/lib/NonlinearSolveSciPy/Project.toml +++ b/lib/NonlinearSolveSciPy/Project.toml @@ -18,7 +18,7 @@ path = "../NonlinearSolveBase" ConcreteStructs = "0.2.3" Hwloc = "3" InteractiveUtils = "<0.0.1, 1" -NonlinearSolveBase = "1.14" +NonlinearSolveBase = "1.15" PrecompileTools = "1.2" PythonCall = "0.9" ReTestItems = "1.24" diff --git a/lib/NonlinearSolveSpectralMethods/Project.toml b/lib/NonlinearSolveSpectralMethods/Project.toml index f4e17f1bc..bc58d7030 100644 --- a/lib/NonlinearSolveSpectralMethods/Project.toml +++ b/lib/NonlinearSolveSpectralMethods/Project.toml @@ -6,7 +6,6 @@ version = "1.3.1" [deps] CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" @@ -28,7 +27,6 @@ Aqua = "0.8" BenchmarkTools = "1.5.0" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" -DiffEqBase = "6.188" ExplicitImports = "1.5" ForwardDiff = "0.10.36, 1" Hwloc = "3" @@ -36,7 +34,7 @@ InteractiveUtils = "<0.0.1, 1" LineSearch = "0.1.4" MaybeInplace = "0.1.4" NonlinearProblemLibrary = "0.1.2" -NonlinearSolveBase = "1.14" +NonlinearSolveBase = "1.15" Pkg = "1.10" PrecompileTools = "1.2" ReTestItems = "1.24" diff --git a/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl b/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl index 93a620761..2706d5670 100644 --- a/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl +++ b/lib/NonlinearSolveSpectralMethods/src/NonlinearSolveSpectralMethods.jl @@ -5,7 +5,6 @@ using Reexport: @reexport using PrecompileTools: @compile_workload, @setup_workload using CommonSolve: CommonSolve -using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LineSearch: RobustNonMonotoneLineSearch using MaybeInplace: @bb using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, diff --git a/lib/SCCNonlinearSolve/Project.toml b/lib/SCCNonlinearSolve/Project.toml index 1d39e5fed..fa3e615f1 100644 --- a/lib/SCCNonlinearSolve/Project.toml +++ b/lib/SCCNonlinearSolve/Project.toml @@ -19,7 +19,7 @@ Hwloc = "3" InteractiveUtils = "<0.0.1, 1" NonlinearProblemLibrary = "0.1.2" NonlinearSolve = "4.8" -NonlinearSolveBase = "1.5.1" +NonlinearSolveBase = "1.15" NonlinearSolveFirstOrder = "1" Pkg = "1.10" PrecompileTools = "1.2" diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index e73fbfc42..f0cbbc99c 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -25,7 +25,6 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -37,7 +36,6 @@ path = "../NonlinearSolveBase" [extensions] SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore" -SimpleNonlinearSolveDiffEqBaseExt = "DiffEqBase" SimpleNonlinearSolveReverseDiffExt = "ReverseDiff" SimpleNonlinearSolveTrackerExt = "Tracker" @@ -49,7 +47,6 @@ BracketingNonlinearSolve = "1.1" ChainRulesCore = "1.24" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" -DiffEqBase = "6.188" DifferentiationInterface = "0.7.3" Enzyme = "0.13.11" ExplicitImports = "1.9" @@ -61,7 +58,7 @@ LineSearch = "0.1.3" LinearAlgebra = "1.10" MaybeInplace = "0.1.4" NonlinearProblemLibrary = "0.1.2" -NonlinearSolveBase = "1.14" +NonlinearSolveBase = "1.15" Pkg = "1.10" PolyesterForwardDiff = "0.1.3" PrecompileTools = "1.2" @@ -80,7 +77,6 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -96,4 +92,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "DiffEqBase", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"] +test = ["Aqua", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"] diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl index a9d86ea84..efb4cace7 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl @@ -2,11 +2,10 @@ module SimpleNonlinearSolveChainRulesCoreExt using ChainRulesCore: ChainRulesCore, NoTangent -using NonlinearSolveBase: ImmutableNonlinearProblem +using NonlinearSolveBase: ImmutableNonlinearProblem, _solve_adjoint using SciMLBase: ChainRulesOriginator, NonlinearLeastSquaresProblem -using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up, - solve_adjoint +using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up function ChainRulesCore.rrule( ::typeof(simplenonlinearsolve_solve_up), @@ -14,7 +13,7 @@ function ChainRulesCore.rrule( sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs... ) out, - ∇internal = solve_adjoint( + ∇internal = _solve_adjoint( prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs... ) function ∇simplenonlinearsolve_solve_up(Δ) diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl deleted file mode 100644 index 4954ffb26..000000000 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl +++ /dev/null @@ -1,13 +0,0 @@ -module SimpleNonlinearSolveDiffEqBaseExt - -using DiffEqBase: DiffEqBase - -using SimpleNonlinearSolve: SimpleNonlinearSolve - -SimpleNonlinearSolve.is_extension_loaded(::Val{:DiffEqBase}) = true - -function SimpleNonlinearSolve.solve_adjoint_internal(args...; kwargs...) - return DiffEqBase._solve_adjoint(args...; kwargs...) -end - -end diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl index dca55621a..27e1cc1ac 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl @@ -1,12 +1,12 @@ module SimpleNonlinearSolveReverseDiffExt -using NonlinearSolveBase: ImmutableNonlinearProblem +using NonlinearSolveBase: ImmutableNonlinearProblem, _solve_adjoint using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem, remake using ArrayInterface: ArrayInterface using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal -using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint +using SimpleNonlinearSolve: SimpleNonlinearSolve import SimpleNonlinearSolve: simplenonlinearsolve_solve_up for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem) @@ -27,7 +27,7 @@ for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem) u0, p = ReverseDiff.value(tu0), ReverseDiff.value(tp) prob = remake(tprob; u0, p) out, - ∇internal = solve_adjoint( + ∇internal = _solve_adjoint( prob, sensealg, u0, p, ReverseDiffOriginator(), alg, args...; kwargs...) function ∇simplenonlinearsolve_solve_up(Δ...) diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl index 551d7080a..9f71c4f55 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl @@ -1,12 +1,12 @@ module SimpleNonlinearSolveTrackerExt -using NonlinearSolveBase: ImmutableNonlinearProblem +using NonlinearSolveBase: ImmutableNonlinearProblem, _solve_adjoint using SciMLBase: TrackerOriginator, NonlinearLeastSquaresProblem, remake using ArrayInterface: ArrayInterface using Tracker: Tracker, TrackedArray, TrackedReal -using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint +using SimpleNonlinearSolve: SimpleNonlinearSolve for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem) aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any) @@ -26,7 +26,7 @@ for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem) u0, p = Tracker.data(tu0), Tracker.data(tp) prob = remake(tprob; u0, p) out, - ∇internal = solve_adjoint( + ∇internal = _solve_adjoint( prob, sensealg, u0, p, TrackerOriginator(), alg, args...; kwargs...) function ∇simplenonlinearsolve_solve_up(Δ) diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index 7a8c5a308..782de6468 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -124,15 +124,6 @@ function simplenonlinearsolve_solve_up( return SciMLBase.__solve(prob, alg, args...; kwargs...) end -# NOTE: This is defined like this so that we don't have to keep have 2 args for the -# extensions -function solve_adjoint(args...; kws...) - is_extension_loaded(Val(:DiffEqBase)) && return solve_adjoint_internal(args...; kws...) - error("Adjoint sensitivity analysis requires `DiffEqBase.jl` to be explicitly loaded.") -end - -function solve_adjoint_internal end - @setup_workload begin for T in (Float64,) prob_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)) diff --git a/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl b/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl index 1580ade60..c56850eb5 100644 --- a/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl +++ b/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl @@ -1,5 +1,5 @@ @testitem "Simple Adjoint Test" tags=[:adjoint] begin - using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, DiffEqBase + using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote ff(u, p) = u .^ 2 .- p diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 8eddf0712..d11f99749 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -8,7 +8,6 @@ using FastClosures: @closure using ADTypes: ADTypes using ArrayInterface: ArrayInterface using CommonSolve: CommonSolve, init, solve, solve! -using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches using LinearAlgebra: LinearAlgebra using LineSearch: BackTracking using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, diff --git a/test/adjoint_tests.jl b/test/adjoint_tests.jl new file mode 100644 index 000000000..8882c1916 --- /dev/null +++ b/test/adjoint_tests.jl @@ -0,0 +1,27 @@ +@testitem "Adjoint Tests" tags = [:nopre] begin + using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, Enzyme, Mooncake + + ff(u, p) = u .^ 2 .- p + + function solve_nlprob(p) + prob = NonlinearProblem{false}(ff, [1.0, 2.0], p) + sol = solve(prob, NewtonRaphson()) + res = sol isa AbstractArray ? sol : sol.u + return sum(abs2, res) + end + + p = [3.0, 2.0] + + ∂p_zygote = only(Zygote.gradient(solve_nlprob, p)) + ∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p) + ∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p))) + ∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p) + ∂p_enzyme = Enzyme.gradient(Enzyme.set_runtime_activity(Enzyme.Reverse), solve_nlprob, p)[1] + + cache = Mooncake.prepare_gradient_cache(solve_nlprob, p) + ∂p_mooncake = Mooncake.value_and_gradient!!(cache, solve_nlprob, p)[2][2] + + @test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme + @test ∂p_zygote ≈ ∂p_forwarddiff ≈ ∂p_tracker ≈ ∂p_reversediff ≈ ∂p_enzyme + @test_broken ∂p_forwarddiff ≈ ∂p_mooncake +end diff --git a/test/runtests.jl b/test/runtests.jl index c946e93ef..3ed6985dc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,8 @@ if GROUP == "all" || GROUP == "nopre" # Only add Enzyme for nopre group if not on prerelease Julia if isempty(VERSION.prerelease) push!(EXTRA_PKGS, Pkg.PackageSpec("Enzyme")) + push!(EXTRA_PKGS, Pkg.PackageSpec("Mooncake")) + push!(EXTRA_PKGS, Pkg.PackageSpec("SciMLSensitivity")) end end if GROUP == "all" || GROUP == "cuda" @@ -41,6 +43,7 @@ if GROUP == "all" || GROUP == "cuda" push!(EXTRA_PKGS, Pkg.PackageSpec("CUDA")) end end + length(EXTRA_PKGS) ≥ 1 && Pkg.add(EXTRA_PKGS) # Use sequential execution for wrapper tests to avoid parallel initialization issues