diff --git a/Project.toml b/Project.toml index a7aa31f..1d9cc13 100644 --- a/Project.toml +++ b/Project.toml @@ -18,19 +18,18 @@ MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore" SimpleNonlinearSolveReverseDiffExt = "ReverseDiff" -SimpleNonlinearSolveStaticArraysExt = "StaticArrays" SimpleNonlinearSolveTrackerExt = "Tracker" SimpleNonlinearSolveZygoteExt = "Zygote" @@ -40,7 +39,7 @@ AllocCheck = "0.1.1" Aqua = "0.8" ArrayInterface = "7.9" CUDA = "5.2" -ChainRulesCore = "1.22" +ChainRulesCore = "1.23" ConcreteStructs = "0.2.3" DiffEqBase = "6.149" DiffResults = "1.1" @@ -59,13 +58,14 @@ PrecompileTools = "1.2" Random = "1.10" ReTestItems = "1.23" Reexport = "1.2" -ReverseDiff = "1.15" +ReverseDiff = "1.15.3" SciMLBase = "2.37.0" SciMLSensitivity = "7.58" +Setfield = "1.1.1" StaticArrays = "1.9" StaticArraysCore = "1.4.2" Test = "1.10" -Tracker = "0.2.32" +Tracker = "0.2.33" Zygote = "0.6.69" julia = "1.10" diff --git a/ext/SimpleNonlinearSolveStaticArraysExt.jl b/ext/SimpleNonlinearSolveStaticArraysExt.jl deleted file mode 100644 index c865084..0000000 --- a/ext/SimpleNonlinearSolveStaticArraysExt.jl +++ /dev/null @@ -1,7 +0,0 @@ -module SimpleNonlinearSolveStaticArraysExt - -using SimpleNonlinearSolve: SimpleNonlinearSolve - -@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:StaticArrays}) = true - -end diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 04a32c9..eba6d99 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -24,6 +24,7 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val + using Setfield: @set! using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size end diff --git a/src/ad.jl b/src/ad.jl index ffae7cc..bb5afea 100644 --- a/src/ad.jl +++ b/src/ad.jl @@ -109,10 +109,12 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs.. end else # For small problems, nesting ForwardDiff is actually quite fast - _f = Base.Fix2(prob.f, newprob.p) if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) ≥ 50) # TODO: Remove once DI has the value_and_pullback_split defined - _F = @closure (u, p) -> __zygote_compute_nlls_vjp(_f, u, p) + _F = @closure (u, p) -> begin + _f = Base.Fix2(prob.f, p) + return __zygote_compute_nlls_vjp(_f, u, p) + end else _F = @closure (u, p) -> begin _f = Base.Fix2(prob.f, p) diff --git a/src/nlsolve/dfsane.jl b/src/nlsolve/dfsane.jl index 7dd1522..835ee4b 100644 --- a/src/nlsolve/dfsane.jl +++ b/src/nlsolve/dfsane.jl @@ -77,12 +77,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args... α_1 = one(T) f_1 = fx_norm - history_f_k = if x isa SArray || - (x isa Number && __is_extension_loaded(Val(:StaticArrays))) - ones(SVector{M, T}) * fx_norm - else - fill(fx_norm, M) - end + history_f_k = x isa SArray ? ones(SVector{M, T}) * fx_norm : + __history_vec(fx_norm, Val(M)) # Generate the cache @bb x_cache = similar(x) @@ -150,6 +146,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args... # Store function value if history_f_k isa SVector history_f_k = Base.setindex(history_f_k, fx_norm_new, mod1(k, M)) + elseif history_f_k isa NTuple + @set! history_f_k[mod1(k, M)] = fx_norm_new else history_f_k[mod1(k, M)] = fx_norm_new end @@ -158,3 +156,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args... return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters) end + +@inline @generated function __history_vec(fx_norm, ::Val{M}) where {M} + M ≥ 11 && return :(fill(fx_norm, M)) # Julia can't specialize here + return :(ntuple(Returns(fx_norm), $(M))) +end diff --git a/src/utils.jl b/src/utils.jl index 0bf027f..2e76a4b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -32,16 +32,15 @@ function value_and_jacobian( if isinplace(prob) if cache isa HasAnalyticJacobian - prob.f.jac(J, x, p) + prob.f.jac(J, x, prob.p) f(y, x) - else - DI.jacobian!(f, y, J, ad, x, cache) + return y, J end - return y, J + return DI.value_and_jacobian!(f, y, J, ad, x, cache) else cache isa HasAnalyticJacobian && return f(x), prob.f.jac(x, prob.p) J === nothing && return DI.value_and_jacobian(f, ad, x, cache) - y, _ = DI.value_and_jacobian!(f, J, ad, x, cache) + y, J = DI.value_and_jacobian!(f, J, ad, x, cache) return y, J end end @@ -63,8 +62,9 @@ end function compute_jacobian_and_hessian( ad, prob::AbstractNonlinearProblem, f::F, y, x) where {F} if x isa Number - df = @closure x -> DI.derivative(f, ad, x) - return f(x), df(x), DI.derivative(df, ad, x) + H = DI.second_derivative(f, ad, x) + v, J = DI.value_and_derivative(f, ad, x) + return v, J, H end if isinplace(prob) diff --git a/test/core/rootfind_tests.jl b/test/core/rootfind_tests.jl index 1ef0757..6165314 100644 --- a/test/core/rootfind_tests.jl +++ b/test/core/rootfind_tests.jl @@ -34,7 +34,6 @@ end export quadratic_f, quadratic_f!, quadratic_f2, newton_fails, TERMINATION_CONDITIONS, benchmark_nlsolve_oop, benchmark_nlsolve_iip - end @testitem "First Order Methods" setup=[RootfindingTesting] tags=[:core] begin @@ -42,7 +41,7 @@ end SimpleTrustRegion, (args...; kwargs...) -> SimpleTrustRegion( args...; nlsolve_update_rule = Val(true), kwargs...)) - @testset "AutoDiff: $(nameof(typeof(autodiff))))" for autodiff in ( + @testset "AutoDiff: $(nameof(typeof(autodiff)))" for autodiff in ( AutoFiniteDiff(), AutoForwardDiff(), AutoPolyesterForwardDiff()) @testset "[OOP] u0: $(typeof(u0))" for u0 in ( [1.0, 1.0], @SVector[1.0, 1.0], 1.0) @@ -59,7 +58,7 @@ end end end - @testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, + @testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) probN = NonlinearProblem(quadratic_f, u0, 2.0) @@ -79,7 +78,7 @@ end end end - @testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, + @testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) probN = NonlinearProblem(quadratic_f, u0, 2.0) @@ -104,7 +103,7 @@ end @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) end - @testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, + @testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) probN = NonlinearProblem(quadratic_f, u0, 2.0)