Skip to content

Commit

Permalink
Merge pull request #148 from SciML/ap/di
Browse files Browse the repository at this point in the history
Use DifferentiationInterface
  • Loading branch information
avik-pal committed May 26, 2024
2 parents 0d300d3 + 29ae939 commit 6b682af
Show file tree
Hide file tree
Showing 13 changed files with 184 additions and 329 deletions.
18 changes: 9 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.8.1"
version = "1.9.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -17,21 +18,18 @@ MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
SimpleNonlinearSolveTrackerExt = "Tracker"
SimpleNonlinearSolveZygoteExt = "Zygote"

Expand All @@ -41,13 +39,14 @@ AllocCheck = "0.1.1"
Aqua = "0.8"
ArrayInterface = "7.9"
CUDA = "5.2"
ChainRulesCore = "1.22"
ChainRulesCore = "1.23"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.149"
DiffResults = "1.1"
DifferentiationInterface = "0.4"
ExplicitImports = "1.5.0"
FastClosures = "0.3.2"
FiniteDiff = "2.22"
FiniteDiff = "2.23.1"
ForwardDiff = "0.10.36"
LinearAlgebra = "1.10"
LinearSolve = "2.30"
Expand All @@ -59,13 +58,14 @@ PrecompileTools = "1.2"
Random = "1.10"
ReTestItems = "1.23"
Reexport = "1.2"
ReverseDiff = "1.15"
ReverseDiff = "1.15.3"
SciMLBase = "2.37.0"
SciMLSensitivity = "7.58"
Setfield = "1.1.1"
StaticArrays = "1.9"
StaticArraysCore = "1.4.2"
Test = "1.10"
Tracker = "0.2.32"
Tracker = "0.2.33"
Zygote = "0.6.69"
julia = "1.10"

Expand Down
20 changes: 0 additions & 20 deletions ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl

This file was deleted.

7 changes: 0 additions & 7 deletions ext/SimpleNonlinearSolveStaticArraysExt.jl

This file was deleted.

15 changes: 10 additions & 5 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ module SimpleNonlinearSolve
using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations

@recompile_invalidations begin
using ADTypes: ADTypes, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode,
AbstractSafeBestNonlinearTerminationMode, AbsNormTerminationMode,
NONLINEARSOLVE_DEFAULT_NORM
using DifferentiationInterface: DifferentiationInterface
using DiffResults: DiffResults
using FastClosures: @closure
using FiniteDiff: FiniteDiff
Expand All @@ -18,13 +20,16 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati
mul!, norm, transpose
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
using Reexport: @reexport
using SciMLBase: SciMLBase, IntervalNonlinearProblem, NonlinearFunction,
NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode, init,
remake, solve, AbstractNonlinearAlgorithm, build_solution, isinplace,
_unwrap_val
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
build_solution, isinplace, _unwrap_val
using Setfield: @set!
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
end

const DI = DifferentiationInterface

@reexport using SciMLBase

abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
Expand Down
94 changes: 36 additions & 58 deletions src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,15 @@
function SciMLBase.solve(
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

function SciMLBase.solve(
prob::NonlinearLeastSquaresProblem{
<:AbstractArray, iip, <:Union{<:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
@eval function SciMLBase.solve(
prob::$(pType){<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end
end

for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
Expand Down Expand Up @@ -47,8 +37,7 @@ function __nlsolve_ad(
tspan = value.(prob.tspan)
newprob = IntervalNonlinearProblem(prob.f, tspan, p; prob.kwargs...)
else
u0 = value(prob.u0)
newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...)
newprob = remake(prob; p, u0 = value(prob.u0))
end

sol = solve(newprob, alg, args...; kwargs...)
Expand All @@ -73,20 +62,16 @@ function __nlsolve_ad(
end

function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...)
p = value(prob.p)
u0 = value(prob.u0)
newprob = NonlinearLeastSquaresProblem(prob.f, u0, p; prob.kwargs...)

newprob = remake(prob; p = value(prob.p), u0 = value(prob.u0))
sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u

# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
# nested autodiff as the last resort
if SciMLBase.has_vjp(prob.f)
if isinplace(prob)
_F = @closure (du, u, p) -> begin
resid = similar(du, length(sol.resid))
resid = __similar(du, length(sol.resid))
prob.f(resid, u, p)
prob.f.vjp(du, resid, u, p)
du .*= 2
Expand All @@ -101,9 +86,9 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
elseif SciMLBase.has_jac(prob.f)
if isinplace(prob)
_F = @closure (du, u, p) -> begin
J = similar(du, length(sol.resid), length(u))
J = __similar(du, length(sol.resid), length(u))
prob.f.jac(J, u, p)
resid = similar(du, length(sol.resid))
resid = __similar(du, length(sol.resid))
prob.f(resid, u, p)
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
return nothing
Expand All @@ -116,43 +101,40 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
else
if isinplace(prob)
_F = @closure (du, u, p) -> begin
resid = similar(du, length(sol.resid))
res = DiffResults.DiffResult(
resid, similar(du, length(sol.resid), length(u)))
_f = @closure (du, u) -> prob.f(du, u, p)
ForwardDiff.jacobian!(res, _f, resid, u)
mul!(reshape(du, 1, :), vec(DiffResults.value(res))',
DiffResults.jacobian(res), 2, false)
resid = __similar(du, length(sol.resid))
v, J = DI.value_and_jacobian(_f, resid, AutoForwardDiff(), u)
mul!(reshape(du, 1, :), vec(v)', J, 2, false)
return nothing
end
else
# For small problems, nesting ForwardDiff is actually quite fast
if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) 50)
_F = @closure (u, p) -> __zygote_compute_nlls_vjp(prob.f, u, p)
# TODO: Remove once DI has the value_and_pullback_split defined
_F = @closure (u, p) -> begin
_f = Base.Fix2(prob.f, p)
return __zygote_compute_nlls_vjp(_f, u, p)
end
else
_F = @closure (u, p) -> begin
T = promote_type(eltype(u), eltype(p))
res = DiffResults.DiffResult(similar(u, T, size(sol.resid)),
similar(u, T, length(sol.resid), length(u)))
ForwardDiff.jacobian!(res, Base.Fix2(prob.f, p), u)
return reshape(
2 .* vec(DiffResults.value(res))' * DiffResults.jacobian(res),
size(u))
_f = Base.Fix2(prob.f, p)
v, J = DI.value_and_jacobian(_f, AutoForwardDiff(), u)
return reshape(2 .* vec(v)' * J, size(u))
end
end
end
end

f_p = __nlsolve_∂f_∂p(prob, _F, uu, p)
f_x = __nlsolve_∂f_∂u(prob, _F, uu, p)
f_p = __nlsolve_∂f_∂p(prob, _F, uu, newprob.p)
f_x = __nlsolve_∂f_∂u(prob, _F, uu, newprob.p)

z_arr = -f_x \ f_p

pp = prob.p
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
if uu isa Number
partials = sum(sumfun, zip(z_arr, pp))
elseif p isa Number
elseif pp isa Number
partials = sumfun((z_arr, pp))
else
partials = sum(sumfun, zip(eachcol(z_arr), pp))
Expand All @@ -164,7 +146,7 @@ end
@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
if isinplace(prob)
__f = p -> begin
du = similar(u, promote_type(eltype(u), eltype(p)))
du = __similar(u, promote_type(eltype(u), eltype(p)))
f(du, u, p)
return du
end
Expand All @@ -182,16 +164,12 @@ end

@inline function __nlsolve_∂f_∂u(prob, f::F, u, p) where {F}
if isinplace(prob)
du = similar(u)
__f = (du, u) -> f(du, u, p)
ForwardDiff.jacobian(__f, du, u)
__f = @closure (du, u) -> f(du, u, p)
return ForwardDiff.jacobian(__f, __similar(u), u)
else
__f = Base.Fix2(f, p)
if u isa Number
return ForwardDiff.derivative(__f, u)
else
return ForwardDiff.jacobian(__f, u)
end
u isa Number && return ForwardDiff.derivative(__f, u)
return ForwardDiff.jacobian(__f, u)
end
end

Expand Down
15 changes: 9 additions & 6 deletions src/nlsolve/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
α_1 = one(T)
f_1 = fx_norm

history_f_k = if x isa SArray ||
(x isa Number && __is_extension_loaded(Val(:StaticArrays)))
ones(SVector{M, T}) * fx_norm
else
fill(fx_norm, M)
end
history_f_k = x isa SArray ? ones(SVector{M, T}) * fx_norm :
__history_vec(fx_norm, Val(M))

# Generate the cache
@bb x_cache = similar(x)
Expand Down Expand Up @@ -150,6 +146,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
# Store function value
if history_f_k isa SVector
history_f_k = Base.setindex(history_f_k, fx_norm_new, mod1(k, M))
elseif history_f_k isa NTuple
@set! history_f_k[mod1(k, M)] = fx_norm_new
else
history_f_k[mod1(k, M)] = fx_norm_new
end
Expand All @@ -158,3 +156,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...

return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
end

@inline @generated function __history_vec(fx_norm, ::Val{M}) where {M}
M 11 && return :(fill(fx_norm, M)) # Julia can't specialize here
return :(ntuple(Returns(fx_norm), $(M)))
end
Loading

2 comments on commit 6b682af

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/107675

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.9.0 -m "<description of version>" 6b682af1e2155298064d2a035939e6aa373edbed
git push origin v1.9.0

Please sign in to comment.