Skip to content
This repository has been archived by the owner on Oct 31, 2024. It is now read-only.

Commit

Permalink
Forward Mode overloads for Least Squares Problem
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 26, 2024
1 parent 8995a23 commit f3b73b8
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 2 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.4.3"
version = "1.5.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -22,11 +22,13 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
SimpleNonlinearSolveZygoteExt = "Zygote"

[compat]
ADTypes = "0.2.6"
Expand Down
7 changes: 7 additions & 0 deletions ext/SimpleNonlinearSolveZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module SimpleNonlinearSolveZygoteExt

import SimpleNonlinearSolve

SimpleNonlinearSolve.__is_extension_loaded(::Val{:Zygote}) = true

Check warning on line 5 in ext/SimpleNonlinearSolveZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SimpleNonlinearSolveZygoteExt.jl#L5

Added line #L5 was not covered by tests

end
3 changes: 2 additions & 1 deletion src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
end
end

function __nlsolve_ad(prob, alg, args...; kwargs...)
function __nlsolve_ad(
prob::Union{IntervalNonlinearProblem, NonlinearProblem}, alg, args...; kwargs...)
p = value(prob.p)
if prob isa IntervalNonlinearProblem
tspan = value.(prob.tspan)
Expand Down
16 changes: 16 additions & 0 deletions test/core/least_squares_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
return.- y_target
end

function loss_function!(resid, θ, p)
= true_function(p, θ)
@. resid =- y_target
return
end

θ_init = θ_true .+ 0.1
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)

Expand All @@ -21,4 +27,14 @@
sol = solve(prob_oop, solver)
@test norm(sol.resid, Inf) < 1e-12
end

prob_iip = NonlinearLeastSquaresProblem(
NonlinearFunction{true}(loss_function!, resid_prototype = zeros(length(y_target))), θ_init, x)

@testset "Solver: $(nameof(typeof(solver)))" for solver in [
SimpleNewtonRaphson(AutoForwardDiff()), SimpleGaussNewton(AutoForwardDiff()),
SimpleNewtonRaphson(AutoFiniteDiff()), SimpleGaussNewton(AutoFiniteDiff())]
sol = solve(prob_iip, solver)
@test norm(sol.resid, Inf) < 1e-12
end
end

0 comments on commit f3b73b8

Please sign in to comment.