diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 96d76d9..0713ea2 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -89,7 +89,7 @@ PrecompileTools.@compile_workload begin end export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, SimpleHalley, Klement, - Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld, ITP + Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld, ITP, SimpleGaussNewton export BatchedBroyden, BatchedSimpleNewtonRaphson, BatchedSimpleDFSane end # module diff --git a/src/raphson.jl b/src/raphson.jl index 6af24b7..d5bad8d 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -61,7 +61,9 @@ function SimpleNewtonRaphson(; batched = false, SciMLBase._unwrap_val(diff_type)}() end -function SciMLBase.__solve(prob::NonlinearProblem, +const SimpleGaussNewton = SimpleNewtonRaphson + +function SciMLBase.__solve(prob::Union{NonlinearProblem,NonlinearLeastSquaresProblem}, alg::SimpleNewtonRaphson, args...; abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...) @@ -74,6 +76,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, error("SimpleNewtonRaphson currently only supports out-of-place nonlinear problems") end + if prob isa NonlinearLeastSquaresProblem && !(typeof(prob.u0) <: Union{Number, AbstractVector}) + error("SimpleGaussNewton only supports Number and AbstactVector types. Please convert any problem of AbstractArray into one with u0 as AbstractVector") + end + atol = abstol !== nothing ? abstol : real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5) rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5) @@ -100,7 +106,13 @@ function SciMLBase.__solve(prob::NonlinearProblem, end iszero(fx) && return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) - Δx = _restructure(fx, dfx \ _vec(fx)) + + if prob isa NonlinearProblem + Δx = _restructure(fx, dfx \ _vec(fx)) + else + Δx = dfx \ fx + end + x -= Δx if isapprox(x, xo, atol = atol, rtol = rtol) return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success) diff --git a/test/least_squares.jl b/test/least_squares.jl new file mode 100644 index 0000000..0b66473 --- /dev/null +++ b/test/least_squares.jl @@ -0,0 +1,19 @@ +using SimpleNonlinearSolve, LinearAlgebra, Test + +true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]) + +θ_true = [1.0, 0.1, 2.0, 0.5] +x = [-1.0, -0.5, 0.0, 0.5, 1.0] +y_target = true_function(x, θ_true) + +function loss_function(θ, p) + ŷ = true_function(p, θ) + return abs2.(ŷ .- y_target) +end + +θ_init = θ_true .+ 0.1 +prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x) +sol = solve(prob_oop, SimpleNewtonRaphson()) +sol = solve(prob_oop, SimpleGaussNewton()) + +@test norm(sol.resid) < 1e-12 diff --git a/test/runtests.jl b/test/runtests.jl index 98a01bd..d0fd1ff 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,5 +7,6 @@ const GROUP = get(ENV, "GROUP", "All") @time @safetestset "Basic Tests + Some AD" include("basictests.jl") @time @safetestset "Inplace Tests" include("inplace.jl") @time @safetestset "Matrix Resizing Tests" include("matrix_resizing_tests.jl") + @time @safetestset "Least Squares Tests" include("least_squares.jl") end end