Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat

@recompile_invalidations begin
using ADTypes, ArrayInterface, ConcreteStructs, DiffEqBase, FastClosures, FiniteDiff,
ForwardDiff, Reexport, LinearAlgebra, SciMLBase
ForwardDiff, Reexport, LinearAlgebra, SciMLBase

import DiffEqBase: AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode,
NonlinearSafeTerminationReturnCode, get_termination_mode,
NONLINEARSOLVE_DEFAULT_NORM
AbstractSafeNonlinearTerminationMode,
AbstractSafeBestNonlinearTerminationMode,
NonlinearSafeTerminationReturnCode, get_termination_mode,
NONLINEARSOLVE_DEFAULT_NORM
import ForwardDiff: Dual
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val
Expand Down Expand Up @@ -56,14 +57,16 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, args...;
end

# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
function SciMLBase.solve(
prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end
new_u0 = u0 !== nothing ? u0 : prob.u0
new_p = p !== nothing ? p : prob.p
return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p, p === nothing,
return __internal_solve_up(
prob, sensealg, new_u0, u0 === nothing, new_p, p === nothing,
alg, args...; kwargs...)
end

Expand Down Expand Up @@ -111,7 +114,7 @@ end
end

export SimpleBroyden, SimpleDFSane, SimpleGaussNewton, SimpleHalley, SimpleKlement,
SimpleLimitedMemoryBroyden, SimpleNewtonRaphson, SimpleTrustRegion
SimpleLimitedMemoryBroyden, SimpleNewtonRaphson, SimpleTrustRegion
export Alefeld, Bisection, Brent, Falsi, ITP, Ridder

end # module
9 changes: 6 additions & 3 deletions src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, <:AbstractArray},
function SciMLBase.solve(
prob::NonlinearProblem{<:Union{Number, <:AbstractArray},
iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats,
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats,
sol.original)
end

for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@eval begin
function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
function SciMLBase.solve(
prob::IntervalNonlinearProblem{uType, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::$(algType), args...; kwargs...) where {uType, T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
Expand Down
9 changes: 6 additions & 3 deletions src/bracketing/alefeld.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...;
end
ē, fc = d, f(c)
(a == c || b == c) &&
return build_solution(prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit,
return build_solution(
prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit,
left = a, right = b)
iszero(fc) &&
return build_solution(prob, alg, c, fc; retcode = ReturnCode.Success,
Expand All @@ -57,7 +58,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...;
end
fc = f(c)
(ā == c || b̄ == c) &&
return build_solution(prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit,
return build_solution(
prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit,
left = ā, right = b̄)
iszero(fc) &&
return build_solution(prob, alg, c, fc; retcode = ReturnCode.Success,
Expand All @@ -76,7 +78,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Alefeld, args...;
end
fc = f(c)
(ā == c || b̄ == c) &&
return build_solution(prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit,
return build_solution(
prob, alg, c, fc; retcode = ReturnCode.FloatingPointLimit,
left = ā, right = b̄)
iszero(fc) &&
return build_solution(prob, alg, c, fc; retcode = ReturnCode.Success,
Expand Down
3 changes: 2 additions & 1 deletion src/bracketing/bisection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
end

if iszero(fr)
return build_solution(prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight,
return build_solution(
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight,
left, right)
end

Expand Down
3 changes: 2 additions & 1 deletion src/bracketing/brent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
end

if iszero(fr)
return build_solution(prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight,
return build_solution(
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight,
left, right)
end

Expand Down
3 changes: 2 additions & 1 deletion src/bracketing/falsi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
end

if iszero(fr)
return build_solution(prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight,
return build_solution(
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight,
left, right)
end

Expand Down
3 changes: 2 additions & 1 deletion src/bracketing/itp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, args...;
end

if iszero(fr)
return build_solution(prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight,
return build_solution(
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight,
left, right)
end
ϵ = abstol
Expand Down
3 changes: 2 additions & 1 deletion src/bracketing/ridder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
end

if iszero(fr)
return build_solution(prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight,
return build_solution(
prob, alg, right, fr; retcode = ReturnCode.ExactSolutionRight,
left, right)
end

Expand Down
3 changes: 2 additions & 1 deletion src/nlsolve/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemo
init_α = inv(alg.alpha)
end

converged, res = __unrolled_lbroyden_initial_iterations(prob, xo, fo, δx, abstol, U, Vᵀ,
converged, res = __unrolled_lbroyden_initial_iterations(
prob, xo, fo, δx, abstol, U, Vᵀ,
threshold, ls_cache, init_α)

converged &&
Expand Down
9 changes: 5 additions & 4 deletions test/core/rootfind_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@testsetup module RootfindingTesting
using Reexport
@reexport using AllocCheck,
LinearSolve, StaticArrays, Random, LinearAlgebra, ForwardDiff, DiffEqBase
LinearSolve, StaticArrays, Random, LinearAlgebra, ForwardDiff, DiffEqBase
import PolyesterForwardDiff

quadratic_f(u, p) = u .* u .- p
Expand All @@ -22,7 +22,7 @@ end
const TERMINATION_CONDITIONS = [
NormTerminationMode(), RelTerminationMode(), RelNormTerminationMode(),
AbsTerminationMode(), AbsNormTerminationMode(), RelSafeTerminationMode(),
AbsSafeTerminationMode(), RelSafeBestTerminationMode(), AbsSafeBestTerminationMode(),
AbsSafeTerminationMode(), RelSafeBestTerminationMode(), AbsSafeBestTerminationMode()
]

function benchmark_nlsolve_oop(f::F, u0, p = 2.0; solver) where {F}
Expand All @@ -35,15 +35,16 @@ function benchmark_nlsolve_iip(f!::F, u0, p = 2.0; solver) where {F}
end

export quadratic_f, quadratic_f!, quadratic_f2, newton_fails, TERMINATION_CONDITIONS,
benchmark_nlsolve_oop, benchmark_nlsolve_iip
benchmark_nlsolve_oop, benchmark_nlsolve_iip

end

@testitem "First Order Methods" setup=[RootfindingTesting] begin
@testset "$(alg)" for alg in (SimpleNewtonRaphson, SimpleTrustRegion,
(args...; kwargs...) -> SimpleTrustRegion(args...; nlsolve_update_rule = Val(true),
kwargs...))
@testset "AutoDiff: $(nameof(typeof(autodiff))))" for autodiff in (AutoFiniteDiff(),
@testset "AutoDiff: $(nameof(typeof(autodiff))))" for autodiff in (
AutoFiniteDiff(),
AutoForwardDiff(), AutoPolyesterForwardDiff())
@testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0],
@SVector[1.0, 1.0], 1.0)
Expand Down