From 52f2191a3671e29a82c7a7cdb4c24853eeaf375b Mon Sep 17 00:00:00 2001 From: termi-official Date: Thu, 13 Nov 2025 19:03:39 +0100 Subject: [PATCH 1/2] Add API to update iterative solver tolerances --- ext/LinearSolveIterativeSolversExt.jl | 27 ++++++++++++++++++++++++++- ext/LinearSolveKrylovKitExt.jl | 2 ++ src/common.jl | 19 +++++++++++++++++++ src/iterative_wrappers.jl | 2 ++ test/basictests.jl | 23 +++++++++++++++++++++++ 5 files changed, 72 insertions(+), 1 deletion(-) diff --git a/ext/LinearSolveIterativeSolversExt.jl b/ext/LinearSolveIterativeSolversExt.jl index e92d04ba2..24efae5cd 100644 --- a/ext/LinearSolveIterativeSolversExt.jl +++ b/ext/LinearSolveIterativeSolversExt.jl @@ -105,7 +105,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::IterativeSolversJL; kwargs... cache.Pr = Pr cache.precsisfresh = false end - if cache.isfresh || !(alg isa IterativeSolvers.GMRESIterable) + if cache.isfresh || !(cache.cacheval isa IterativeSolvers.GMRESIterable) solver = LinearSolve.init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, cache.maxiters, cache.abstol, cache.reltol, @@ -149,4 +149,29 @@ function purge_history!(iter::IterativeSolvers.GMRESIterable, x, b) nothing end +# The constructors above all set the tolerance as follows. +# tol = max(reltol * ||residual||, abstol) +# +# The iterable in turn is stored in `cache.cacheval`. +function update_tolerances_iterativesolversjl!(iter, atol, rtol) + Rnorm = norm(iter.r) + iter.tol = max(rtol * Rnorm, atol) +end +function update_tolerances_iterativesolversjl!(iter::IterativeSolvers.GMRESIterable, atol, rtol) + Rnorm = iter.residual.current + iter.tol = max(rtol * Rnorm, atol) +end +function update_tolerances_iterativesolversjl!(iter::IterativeSolvers.MINRESIterable, atol, rtol) + Rnorm = norm(iter.v_curr) + iter.tol = max(rtol * Rnorm, atol) +end +function update_tolerances_iterativesolversjl!(iter::IterativeSolvers.IDRSIterable, atol, rtol) + Rnorm = iter.normR + iter.tol = max(rtol * Rnorm, atol) +end + +function LinearSolve.update_tolerances_internal!(cache, alg::IterativeSolversJL, atol, rtol) + update_tolerances_iterativesolversjl!(cache.cacheval, atol, rtol) +end + end diff --git a/ext/LinearSolveKrylovKitExt.jl b/ext/LinearSolveKrylovKitExt.jl index b3a7a6f5f..61e42d5ee 100644 --- a/ext/LinearSolveKrylovKitExt.jl +++ b/ext/LinearSolveKrylovKitExt.jl @@ -48,4 +48,6 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovKitJL; kwargs...) iters = iters) end +LinearSolve.update_tolerances_internal!(cache, alg::KrylovKitJL, atol, rtol) = nothing + end diff --git a/src/common.jl b/src/common.jl index fe9471899..4141616f6 100644 --- a/src/common.jl +++ b/src/common.jl @@ -478,3 +478,22 @@ function SciMLBase.solve(prob::StaticLinearProblem, return SciMLBase.build_linear_solution( alg, u, nothing, prob; retcode = ReturnCode.Success) end + +function update_tolerances!(cache; abstol = nothing, reltol = nothing) + if abstol !== nothing + cache.abstol = abstol + end + if reltol !== nothing + cache.reltol = reltol + end + update_tolerances_internal!(cache, cache.alg, abstol, reltol) +end + + +function update_tolerances_internal!(cache, alg::AbstractFactorization, abstol, reltol) + error("Cannot update tolerances for factorization.") +end + +function update_tolerances_internal!(cache, alg::AbstractKrylovSubspaceMethod, abstol, reltol) + @warn "Tolerance update for Krylov subspace method '$typeof(alg)' not implemented." maxlog = 1 +end diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl index af2193741..7f30e0e39 100644 --- a/src/iterative_wrappers.jl +++ b/src/iterative_wrappers.jl @@ -338,3 +338,5 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...) return SciMLBase.build_linear_solution(alg, cache.u, Ref(resid), cache; iters = stats.niter, retcode, stats) end + +update_tolerances_internal!(cache, alg::KrylovJL, atol, rtol) = nothing diff --git a/test/basictests.jl b/test/basictests.jl index eb41e6ae8..c01dd7406 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -33,10 +33,17 @@ A4 = A2 .|> ComplexF32 b4 = b2 .|> ComplexF32 x4 = x2 .|> ComplexF32 +A5_ = A - 0.01Tridiagonal(ones(n,n)) + sparse([1], [8], 0.5, n,n) +A5 = sparse(transpose(A5_) * A5_) +x5 = zeros(n) +u5 = ones(n) +b5 = A5*u5 + prob1 = LinearProblem(A1, b1; u0 = x1) prob2 = LinearProblem(A2, b2; u0 = x2) prob3 = LinearProblem(A3, b3; u0 = x3) prob4 = LinearProblem(A4, b4; u0 = x4) +prob5 = LinearProblem(A5, b5) cache_kwargs = (;abstol = 1e-8, reltol = 1e-8, maxiter = 30) @@ -69,6 +76,19 @@ function test_interface(alg, prob1, prob2) return end +function test_tolerance_update(alg, prob, u) + cache = init(prob, alg; verbose=LinearVerbosity(; error_control=SciMLLogging.WarnLevel(), numerical=SciMLLogging.WarnLevel())) + LinearSolve.update_tolerances!(cache; reltol = 1e-2, abstol=1e-8) + u1 = copy(solve!(cache).u) + + LinearSolve.update_tolerances!(cache; reltol = 1e-8, abstol=1e-8) + u2 = solve!(cache).u + + @test norm(u2 - u) < norm(u1 - u) + + return +end + @testset "LinearSolve" begin @testset "Default Linear Solver" begin test_interface(nothing, prob1, prob2) @@ -379,6 +399,7 @@ end @testset "$name" begin test_interface(algorithm, prob1, prob2) test_interface(algorithm, prob3, prob4) + test_tolerance_update(algorithm, prob5, u5) end end end @@ -418,6 +439,7 @@ end @testset "$(alg[1])" begin test_interface(alg[2], prob1, prob2) test_interface(alg[2], prob3, prob4) + test_tolerance_update(alg[2], prob5, u5) end end end @@ -432,6 +454,7 @@ end @testset "$(alg[1])" begin test_interface(alg[2], prob1, prob2) test_interface(alg[2], prob3, prob4) + test_tolerance_update(alg[2], prob5, u5) end @test alg[2] isa KrylovKitJL end From 5e825b676cf59eac838ec02d42a920a6d6f85263 Mon Sep 17 00:00:00 2001 From: termi-official <9196588+termi-official@users.noreply.github.com> Date: Sun, 16 Nov 2025 10:41:40 +0100 Subject: [PATCH 2/2] Oopsie --- test/basictests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/basictests.jl b/test/basictests.jl index c01dd7406..953e6a56e 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -77,7 +77,7 @@ function test_interface(alg, prob1, prob2) end function test_tolerance_update(alg, prob, u) - cache = init(prob, alg; verbose=LinearVerbosity(; error_control=SciMLLogging.WarnLevel(), numerical=SciMLLogging.WarnLevel())) + cache = init(prob, alg) LinearSolve.update_tolerances!(cache; reltol = 1e-2, abstol=1e-8) u1 = copy(solve!(cache).u)