Skip to content

Commit

Permalink
Merge pull request #131 from avik-pal/ap/nlls_adjoint
Browse files Browse the repository at this point in the history
Forward Mode overloads for Least Squares Problem
  • Loading branch information
avik-pal committed Feb 26, 2024
2 parents 8fbfcbd + fe8f562 commit fd7d216
Show file tree
Hide file tree
Showing 9 changed files with 311 additions and 32 deletions.
24 changes: 21 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -22,34 +23,51 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
SimpleNonlinearSolveZygoteExt = "Zygote"

[compat]
ADTypes = "0.2.6"
AllocCheck = "0.1.1"
ArrayInterface = "7.7"
ChainRulesCore = "1.21"
Aqua = "0.8"
CUDA = "5.2"
ChainRulesCore = "1.22"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.146"
DiffResults = "1.1"
FastClosures = "0.3"
FiniteDiff = "2.22"
ForwardDiff = "0.10.36"
LinearAlgebra = "1.10"
LinearSolve = "2.25"
MaybeInplace = "0.1.1"
NonlinearProblemLibrary = "0.1.2"
Pkg = "1.10"
PolyesterForwardDiff = "0.1.1"
PrecompileTools = "1.2"
Random = "1.10"
ReTestItems = "1.23"
Reexport = "1.2"
SciMLBase = "2.23"
SciMLBase = "2.26.3"
SciMLSensitivity = "7.56"
StaticArrays = "1.9"
StaticArraysCore = "1.4.2"
Test = "1.10"
Zygote = "0.6.69"
julia = "1.10"

[extras]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand All @@ -65,4 +83,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test"]
test = ["Aqua", "AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff"]
12 changes: 12 additions & 0 deletions ext/SimpleNonlinearSolveZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module SimpleNonlinearSolveZygoteExt

import SimpleNonlinearSolve, Zygote

SimpleNonlinearSolve.__is_extension_loaded(::Val{:Zygote}) = true

function SimpleNonlinearSolve.__zygote_compute_nlls_vjp(f::F, u, p) where {F}
y, pb = Zygote.pullback(Base.Fix2(f, p), u)
return 2 .* only(pb(y))
end

end
1 change: 1 addition & 0 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat
AbstractSafeBestNonlinearTerminationMode,
NonlinearSafeTerminationReturnCode, get_termination_mode,
NONLINEARSOLVE_DEFAULT_NORM
import DiffResults
import ForwardDiff: Dual
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val
Expand Down
106 changes: 103 additions & 3 deletions src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,17 @@ function SciMLBase.solve(
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats,
sol.original)
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

function SciMLBase.solve(
prob::NonlinearLeastSquaresProblem{<:AbstractArray,
iip, <:Union{<:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
Expand All @@ -24,7 +33,8 @@ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
end
end

function __nlsolve_ad(prob, alg, args...; kwargs...)
function __nlsolve_ad(
prob::Union{IntervalNonlinearProblem, NonlinearProblem}, alg, args...; kwargs...)
p = value(prob.p)
if prob isa IntervalNonlinearProblem
tspan = value.(prob.tspan)
Expand Down Expand Up @@ -55,6 +65,96 @@ function __nlsolve_ad(prob, alg, args...; kwargs...)
return sol, partials
end

function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...)
p = value(prob.p)
u0 = value(prob.u0)
newprob = NonlinearLeastSquaresProblem(prob.f, u0, p; prob.kwargs...)

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u

# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
# nested autodiff as the last resort
if SciMLBase.has_vjp(prob.f)
if isinplace(prob)
_F = @closure (du, u, p) -> begin
resid = similar(du, length(sol.resid))
prob.f(resid, u, p)
prob.f.vjp(du, resid, u, p)
du .*= 2
return nothing
end
else
_F = @closure (u, p) -> begin
resid = prob.f(u, p)
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
end
end
elseif SciMLBase.has_jac(prob.f)
if isinplace(prob)
_F = @closure (du, u, p) -> begin
J = similar(du, length(sol.resid), length(u))
prob.f.jac(J, u, p)
resid = similar(du, length(sol.resid))
prob.f(resid, u, p)
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
return nothing
end
else
_F = @closure (u, p) -> begin
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
end
end
else
if isinplace(prob)
_F = @closure (du, u, p) -> begin
resid = similar(du, length(sol.resid))
res = DiffResults.DiffResult(
resid, similar(du, length(sol.resid), length(u)))
_f = @closure (du, u) -> prob.f(du, u, p)
ForwardDiff.jacobian!(res, _f, resid, u)
mul!(reshape(du, 1, :), vec(DiffResults.value(res))',
DiffResults.jacobian(res), 2, false)
return nothing
end
else
# For small problems, nesting ForwardDiff is actually quite fast
if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) 50)
_F = @closure (u, p) -> __zygote_compute_nlls_vjp(prob.f, u, p)
else
_F = @closure (u, p) -> begin
T = promote_type(eltype(u), eltype(p))
res = DiffResults.DiffResult(
similar(u, T, size(sol.resid)), similar(
u, T, length(sol.resid), length(u)))
ForwardDiff.jacobian!(res, Base.Fix2(prob.f, p), u)
return reshape(
2 .* vec(DiffResults.value(res))' * DiffResults.jacobian(res),
size(u))
end
end
end
end

f_p = __nlsolve_∂f_∂p(prob, _F, uu, p)
f_x = __nlsolve_∂f_∂u(prob, _F, uu, p)

z_arr = -f_x \ f_p

pp = prob.p
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
if uu isa Number
partials = sum(sumfun, zip(z_arr, pp))
elseif p isa Number
partials = sumfun((z_arr, pp))
else
partials = sum(sumfun, zip(eachcol(z_arr), pp))
end

return sol, partials
end

@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
if isinplace(prob)
__f = p -> begin
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,6 @@ function __get_tolerance(x::Union{SArray, Number}, ::Nothing, ::Type{T}) where {
η = real(oneunit(T)) * (eps(real(one(T))))^(real(T)(0.8))
return T(η)
end

# Extension
function __zygote_compute_nlls_vjp end
9 changes: 9 additions & 0 deletions test/core/aqua_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@testitem "Aqua" begin
using Aqua

Aqua.test_all(SimpleNonlinearSolve; piracies = false, ambiguities = false)
Aqua.test_piracies(SimpleNonlinearSolve;
treat_as_own = [
NonlinearProblem, NonlinearLeastSquaresProblem, IntervalNonlinearProblem])
Aqua.test_ambiguities(SimpleNonlinearSolve; recursive = false)
end
122 changes: 120 additions & 2 deletions test/core/forward_ad_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testsetup module ForwardADTesting
@testsetup module ForwardADRootfindingTesting
using Reexport
@reexport using ForwardDiff, SimpleNonlinearSolve, StaticArrays, LinearAlgebra
import SimpleNonlinearSolve: AbstractSimpleNonlinearSolveAlgorithm
Expand Down Expand Up @@ -40,7 +40,7 @@ __compatible(::SimpleHalley, ::Val{:iip}) = false
export test_f, test_f!, jacobian_f, solve_with, __compatible
end

@testitem "ForwardDiff.jl Integration" setup=[ForwardADTesting] begin
@testitem "ForwardDiff.jl Integration: Rootfinding" setup=[ForwardADRootfindingTesting] begin
@testset "$(nameof(typeof(alg)))" for alg in (SimpleNewtonRaphson(),
SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleDFSane())
Expand Down Expand Up @@ -88,3 +88,121 @@ end
end
end
end

@testsetup module ForwardADNLLSTesting
using Reexport
@reexport using ForwardDiff, FiniteDiff, SimpleNonlinearSolve, StaticArrays, LinearAlgebra,
Zygote

true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])

const θ_true = [1.0, 0.1, 2.0, 0.5]
const x = [-1.0, -0.5, 0.0, 0.5, 1.0]
const y_target = true_function(x, θ_true)

function loss_function(θ, p)
= true_function(p, θ)
return.- y_target
end

function loss_function_jac(θ, p)
return ForwardDiff.jacobian-> loss_function(θ, p), θ)
end

loss_function_vjp(v, θ, p) = reshape(vec(v)' * loss_function_jac(θ, p), size(θ))

function loss_function!(resid, θ, p)
= true_function(p, θ)
@. resid =- y_target
return
end

function loss_function_jac!(J, θ, p)
J .= ForwardDiff.jacobian-> loss_function(θ, p), θ)
return
end

function loss_function_vjp!(vJ, v, θ, p)
vec(vJ) .= reshape(vec(v)' * loss_function_jac(θ, p), size(θ))
return
end

θ_init = θ_true .+ 0.1

export loss_function, loss_function!, loss_function_jac, loss_function_vjp,
loss_function_jac!, loss_function_vjp!, θ_init, x, y_target
end

@testitem "ForwardDiff.jl Integration: NLLS" setup=[ForwardADNLLSTesting] begin
@testset "$(nameof(typeof(alg)))" for alg in (
SimpleNewtonRaphson(), SimpleGaussNewton(),
SimpleNewtonRaphson(AutoFiniteDiff()), SimpleGaussNewton(AutoFiniteDiff()))
function obj_1(p)
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, p)
sol = solve(prob_oop, alg)
return sum(abs2, sol.u)
end

function obj_2(p)
ff = NonlinearFunction{false}(loss_function; jac = loss_function_jac)
prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
sol = solve(prob_oop, alg)
return sum(abs2, sol.u)
end

function obj_3(p)
ff = NonlinearFunction{false}(loss_function; vjp = loss_function_vjp)
prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
sol = solve(prob_oop, alg)
return sum(abs2, sol.u)
end

finitediff = FiniteDiff.finite_difference_gradient(obj_1, x)

fdiff1 = ForwardDiff.gradient(obj_1, x)
fdiff2 = ForwardDiff.gradient(obj_2, x)
fdiff3 = ForwardDiff.gradient(obj_3, x)

@test finitedifffdiff1 atol=1e-5
@test finitedifffdiff2 atol=1e-5
@test finitedifffdiff3 atol=1e-5
@test fdiff1 fdiff2 fdiff3

function obj_4(p)
prob_iip = NonlinearLeastSquaresProblem(
NonlinearFunction{true}(
loss_function!; resid_prototype = zeros(length(y_target))), θ_init, p)
sol = solve(prob_iip, alg)
return sum(abs2, sol.u)
end

function obj_5(p)
ff = NonlinearFunction{true}(
loss_function!; resid_prototype = zeros(length(y_target)), jac = loss_function_jac!)
prob_iip = NonlinearLeastSquaresProblem(
ff, θ_init, p)
sol = solve(prob_iip, alg)
return sum(abs2, sol.u)
end

function obj_6(p)
ff = NonlinearFunction{true}(
loss_function!; resid_prototype = zeros(length(y_target)), vjp = loss_function_vjp!)
prob_iip = NonlinearLeastSquaresProblem(
ff, θ_init, p)
sol = solve(prob_iip, alg)
return sum(abs2, sol.u)
end

finitediff = FiniteDiff.finite_difference_gradient(obj_4, x)

fdiff4 = ForwardDiff.gradient(obj_4, x)
fdiff5 = ForwardDiff.gradient(obj_5, x)
fdiff6 = ForwardDiff.gradient(obj_6, x)

@test finitedifffdiff4 atol=1e-5
@test finitedifffdiff5 atol=1e-5
@test finitedifffdiff6 atol=1e-5
@test fdiff4 fdiff5 fdiff6
end
end
16 changes: 16 additions & 0 deletions test/core/least_squares_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
return.- y_target
end

function loss_function!(resid, θ, p)
= true_function(p, θ)
@. resid =- y_target
return
end

θ_init = θ_true .+ 0.1
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)

Expand All @@ -21,4 +27,14 @@
sol = solve(prob_oop, solver)
@test norm(sol.resid, Inf) < 1e-12
end

prob_iip = NonlinearLeastSquaresProblem(
NonlinearFunction{true}(loss_function!, resid_prototype = zeros(length(y_target))), θ_init, x)

@testset "Solver: $(nameof(typeof(solver)))" for solver in [
SimpleNewtonRaphson(AutoForwardDiff()), SimpleGaussNewton(AutoForwardDiff()),
SimpleNewtonRaphson(AutoFiniteDiff()), SimpleGaussNewton(AutoFiniteDiff())]
sol = solve(prob_iip, solver)
@test norm(sol.resid, Inf) < 1e-12
end
end
Loading

2 comments on commit fd7d216

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Version 1.5.0 already exists

Please sign in to comment.