Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward Mode overloads for Least Squares Problem #131

Merged
merged 4 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.4.3"
version = "1.5.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -22,34 +23,51 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
SimpleNonlinearSolveZygoteExt = "Zygote"

[compat]
ADTypes = "0.2.6"
AllocCheck = "0.1.1"
ArrayInterface = "7.7"
ChainRulesCore = "1.21"
Aqua = "0.8"
CUDA = "5.2"
ChainRulesCore = "1.22"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.146"
DiffResults = "1.1"
FastClosures = "0.3"
FiniteDiff = "2.22"
ForwardDiff = "0.10.36"
LinearAlgebra = "1.10"
LinearSolve = "2.25"
MaybeInplace = "0.1.1"
NonlinearProblemLibrary = "0.1.2"
Pkg = "1.10"
PolyesterForwardDiff = "0.1.1"
PrecompileTools = "1.2"
Random = "1.10"
ReTestItems = "1.23"
Reexport = "1.2"
SciMLBase = "2.23"
SciMLBase = "2.26.3"
SciMLSensitivity = "7.56"
StaticArrays = "1.9"
StaticArraysCore = "1.4.2"
Test = "1.10"
Zygote = "0.6.69"
julia = "1.10"

[extras]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand All @@ -65,4 +83,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test"]
test = ["Aqua", "AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff"]
12 changes: 12 additions & 0 deletions ext/SimpleNonlinearSolveZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module SimpleNonlinearSolveZygoteExt

import SimpleNonlinearSolve, Zygote

SimpleNonlinearSolve.__is_extension_loaded(::Val{:Zygote}) = true

function SimpleNonlinearSolve.__zygote_compute_nlls_vjp(f::F, u, p) where {F}
y, pb = Zygote.pullback(Base.Fix2(f, p), u)
return 2 .* only(pb(y))

Check warning on line 9 in ext/SimpleNonlinearSolveZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveZygoteExt.jl#L7-L9

Added lines #L7 - L9 were not covered by tests
end

end
1 change: 1 addition & 0 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat
AbstractSafeBestNonlinearTerminationMode,
NonlinearSafeTerminationReturnCode, get_termination_mode,
NONLINEARSOLVE_DEFAULT_NORM
import DiffResults
import ForwardDiff: Dual
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val
Expand Down
106 changes: 103 additions & 3 deletions src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,17 @@
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats,
sol.original)
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

function SciMLBase.solve(
prob::NonlinearLeastSquaresProblem{<:AbstractArray,
iip, <:Union{<:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
Expand All @@ -24,7 +33,8 @@
end
end

function __nlsolve_ad(prob, alg, args...; kwargs...)
function __nlsolve_ad(
prob::Union{IntervalNonlinearProblem, NonlinearProblem}, alg, args...; kwargs...)
p = value(prob.p)
if prob isa IntervalNonlinearProblem
tspan = value.(prob.tspan)
Expand Down Expand Up @@ -55,6 +65,96 @@
return sol, partials
end

function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...)
p = value(prob.p)
u0 = value(prob.u0)
newprob = NonlinearLeastSquaresProblem(prob.f, u0, p; prob.kwargs...)

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u

# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
# nested autodiff as the last resort
if SciMLBase.has_vjp(prob.f)
if isinplace(prob)
_F = @closure (du, u, p) -> begin
resid = similar(du, length(sol.resid))
prob.f(resid, u, p)
prob.f.vjp(du, resid, u, p)
du .*= 2
return nothing
end
else
_F = @closure (u, p) -> begin
resid = prob.f(u, p)
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
end
end
elseif SciMLBase.has_jac(prob.f)
if isinplace(prob)
_F = @closure (du, u, p) -> begin
J = similar(du, length(sol.resid), length(u))
prob.f.jac(J, u, p)
resid = similar(du, length(sol.resid))
prob.f(resid, u, p)
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
return nothing
end
else
_F = @closure (u, p) -> begin
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
end
end
else
if isinplace(prob)
_F = @closure (du, u, p) -> begin
resid = similar(du, length(sol.resid))
res = DiffResults.DiffResult(
resid, similar(du, length(sol.resid), length(u)))
_f = @closure (du, u) -> prob.f(du, u, p)
ForwardDiff.jacobian!(res, _f, resid, u)
mul!(reshape(du, 1, :), vec(DiffResults.value(res))',
DiffResults.jacobian(res), 2, false)
return nothing
end
else
# For small problems, nesting ForwardDiff is actually quite fast
if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) ≥ 50)
_F = @closure (u, p) -> __zygote_compute_nlls_vjp(prob.f, u, p)

Check warning on line 124 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L124

Added line #L124 was not covered by tests
else
_F = @closure (u, p) -> begin
T = promote_type(eltype(u), eltype(p))
res = DiffResults.DiffResult(
similar(u, T, size(sol.resid)), similar(
u, T, length(sol.resid), length(u)))
ForwardDiff.jacobian!(res, Base.Fix2(prob.f, p), u)
return reshape(
2 .* vec(DiffResults.value(res))' * DiffResults.jacobian(res),
size(u))
end
end
end
end

f_p = __nlsolve_∂f_∂p(prob, _F, uu, p)
f_x = __nlsolve_∂f_∂u(prob, _F, uu, p)

z_arr = -f_x \ f_p

pp = prob.p
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
if uu isa Number
partials = sum(sumfun, zip(z_arr, pp))

Check warning on line 148 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L148

Added line #L148 was not covered by tests
elseif p isa Number
partials = sumfun((z_arr, pp))

Check warning on line 150 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L150

Added line #L150 was not covered by tests
else
partials = sum(sumfun, zip(eachcol(z_arr), pp))
end

return sol, partials
end

@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
if isinplace(prob)
__f = p -> begin
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,6 @@ function __get_tolerance(x::Union{SArray, Number}, ::Nothing, ::Type{T}) where {
η = real(oneunit(T)) * (eps(real(one(T))))^(real(T)(0.8))
return T(η)
end

# Extension
function __zygote_compute_nlls_vjp end
9 changes: 9 additions & 0 deletions test/core/aqua_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@testitem "Aqua" begin
using Aqua

Aqua.test_all(SimpleNonlinearSolve; piracies = false, ambiguities = false)
Aqua.test_piracies(SimpleNonlinearSolve;
treat_as_own = [
NonlinearProblem, NonlinearLeastSquaresProblem, IntervalNonlinearProblem])
Aqua.test_ambiguities(SimpleNonlinearSolve; recursive = false)
end
122 changes: 120 additions & 2 deletions test/core/forward_ad_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testsetup module ForwardADTesting
@testsetup module ForwardADRootfindingTesting
using Reexport
@reexport using ForwardDiff, SimpleNonlinearSolve, StaticArrays, LinearAlgebra
import SimpleNonlinearSolve: AbstractSimpleNonlinearSolveAlgorithm
Expand Down Expand Up @@ -40,7 +40,7 @@ __compatible(::SimpleHalley, ::Val{:iip}) = false
export test_f, test_f!, jacobian_f, solve_with, __compatible
end

@testitem "ForwardDiff.jl Integration" setup=[ForwardADTesting] begin
@testitem "ForwardDiff.jl Integration: Rootfinding" setup=[ForwardADRootfindingTesting] begin
@testset "$(nameof(typeof(alg)))" for alg in (SimpleNewtonRaphson(),
SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleDFSane())
Expand Down Expand Up @@ -88,3 +88,121 @@ end
end
end
end

@testsetup module ForwardADNLLSTesting
using Reexport
@reexport using ForwardDiff, FiniteDiff, SimpleNonlinearSolve, StaticArrays, LinearAlgebra,
Zygote

true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])

const θ_true = [1.0, 0.1, 2.0, 0.5]
const x = [-1.0, -0.5, 0.0, 0.5, 1.0]
const y_target = true_function(x, θ_true)

function loss_function(θ, p)
ŷ = true_function(p, θ)
return ŷ .- y_target
end

function loss_function_jac(θ, p)
return ForwardDiff.jacobian(θ -> loss_function(θ, p), θ)
end

loss_function_vjp(v, θ, p) = reshape(vec(v)' * loss_function_jac(θ, p), size(θ))

function loss_function!(resid, θ, p)
ŷ = true_function(p, θ)
@. resid = ŷ - y_target
return
end

function loss_function_jac!(J, θ, p)
J .= ForwardDiff.jacobian(θ -> loss_function(θ, p), θ)
return
end

function loss_function_vjp!(vJ, v, θ, p)
vec(vJ) .= reshape(vec(v)' * loss_function_jac(θ, p), size(θ))
return
end

θ_init = θ_true .+ 0.1

export loss_function, loss_function!, loss_function_jac, loss_function_vjp,
loss_function_jac!, loss_function_vjp!, θ_init, x, y_target
end

@testitem "ForwardDiff.jl Integration: NLLS" setup=[ForwardADNLLSTesting] begin
@testset "$(nameof(typeof(alg)))" for alg in (
SimpleNewtonRaphson(), SimpleGaussNewton(),
SimpleNewtonRaphson(AutoFiniteDiff()), SimpleGaussNewton(AutoFiniteDiff()))
function obj_1(p)
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, p)
sol = solve(prob_oop, alg)
return sum(abs2, sol.u)
end

function obj_2(p)
ff = NonlinearFunction{false}(loss_function; jac = loss_function_jac)
prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
sol = solve(prob_oop, alg)
return sum(abs2, sol.u)
end

function obj_3(p)
ff = NonlinearFunction{false}(loss_function; vjp = loss_function_vjp)
prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
sol = solve(prob_oop, alg)
return sum(abs2, sol.u)
end

finitediff = FiniteDiff.finite_difference_gradient(obj_1, x)

fdiff1 = ForwardDiff.gradient(obj_1, x)
fdiff2 = ForwardDiff.gradient(obj_2, x)
fdiff3 = ForwardDiff.gradient(obj_3, x)

@test finitediff≈fdiff1 atol=1e-5
@test finitediff≈fdiff2 atol=1e-5
@test finitediff≈fdiff3 atol=1e-5
@test fdiff1 ≈ fdiff2 ≈ fdiff3

function obj_4(p)
prob_iip = NonlinearLeastSquaresProblem(
NonlinearFunction{true}(
loss_function!; resid_prototype = zeros(length(y_target))), θ_init, p)
sol = solve(prob_iip, alg)
return sum(abs2, sol.u)
end

function obj_5(p)
ff = NonlinearFunction{true}(
loss_function!; resid_prototype = zeros(length(y_target)), jac = loss_function_jac!)
prob_iip = NonlinearLeastSquaresProblem(
ff, θ_init, p)
sol = solve(prob_iip, alg)
return sum(abs2, sol.u)
end

function obj_6(p)
ff = NonlinearFunction{true}(
loss_function!; resid_prototype = zeros(length(y_target)), vjp = loss_function_vjp!)
prob_iip = NonlinearLeastSquaresProblem(
ff, θ_init, p)
sol = solve(prob_iip, alg)
return sum(abs2, sol.u)
end

finitediff = FiniteDiff.finite_difference_gradient(obj_4, x)

fdiff4 = ForwardDiff.gradient(obj_4, x)
fdiff5 = ForwardDiff.gradient(obj_5, x)
fdiff6 = ForwardDiff.gradient(obj_6, x)

@test finitediff≈fdiff4 atol=1e-5
@test finitediff≈fdiff5 atol=1e-5
@test finitediff≈fdiff6 atol=1e-5
@test fdiff4 ≈ fdiff5 ≈ fdiff6
end
end
Loading
Loading