Skip to content

Commit

Permalink
Merge pull request #91 from SciML/least_squares
Browse files Browse the repository at this point in the history
Add a dispatch to SimpleNewtonRaphson for NNLS and SimpleGaussNewton
  • Loading branch information
ChrisRackauckas committed Nov 3, 2023
2 parents 0d73574 + cf03317 commit ea20ad6
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ PrecompileTools.@compile_workload begin
end

export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, SimpleHalley, Klement,
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld, ITP
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld, ITP, SimpleGaussNewton
export BatchedBroyden, BatchedSimpleNewtonRaphson, BatchedSimpleDFSane

end # module
16 changes: 14 additions & 2 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ function SimpleNewtonRaphson(; batched = false,
SciMLBase._unwrap_val(diff_type)}()
end

function SciMLBase.__solve(prob::NonlinearProblem,
const SimpleGaussNewton = SimpleNewtonRaphson

function SciMLBase.__solve(prob::Union{NonlinearProblem,NonlinearLeastSquaresProblem},
alg::SimpleNewtonRaphson, args...; abstol = nothing,
reltol = nothing,
maxiters = 1000, kwargs...)
Expand All @@ -74,6 +76,10 @@ function SciMLBase.__solve(prob::NonlinearProblem,
error("SimpleNewtonRaphson currently only supports out-of-place nonlinear problems")
end

if prob isa NonlinearLeastSquaresProblem && !(typeof(prob.u0) <: Union{Number, AbstractVector})
error("SimpleGaussNewton only supports Number and AbstactVector types. Please convert any problem of AbstractArray into one with u0 as AbstractVector")
end

atol = abstol !== nothing ? abstol :
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
Expand All @@ -100,7 +106,13 @@ function SciMLBase.__solve(prob::NonlinearProblem,
end
iszero(fx) &&
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
Δx = _restructure(fx, dfx \ _vec(fx))

if prob isa NonlinearProblem
Δx = _restructure(fx, dfx \ _vec(fx))
else
Δx = dfx \ fx
end

x -= Δx
if isapprox(x, xo, atol = atol, rtol = rtol)
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
Expand Down
19 changes: 19 additions & 0 deletions test/least_squares.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using SimpleNonlinearSolve, LinearAlgebra, Test

true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])

θ_true = [1.0, 0.1, 2.0, 0.5]
x = [-1.0, -0.5, 0.0, 0.5, 1.0]
y_target = true_function(x, θ_true)

function loss_function(θ, p)
= true_function(p, θ)
return abs2.(ŷ .- y_target)
end

θ_init = θ_true .+ 0.1
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)
sol = solve(prob_oop, SimpleNewtonRaphson())
sol = solve(prob_oop, SimpleGaussNewton())

@test norm(sol.resid) < 1e-12
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ const GROUP = get(ENV, "GROUP", "All")
@time @safetestset "Basic Tests + Some AD" include("basictests.jl")
@time @safetestset "Inplace Tests" include("inplace.jl")
@time @safetestset "Matrix Resizing Tests" include("matrix_resizing_tests.jl")
@time @safetestset "Least Squares Tests" include("least_squares.jl")
end
end

0 comments on commit ea20ad6

Please sign in to comment.