From ccdd7b3a49a6d29603afa86ed63963f425938271 Mon Sep 17 00:00:00 2001 From: Maxence Gollier Date: Tue, 30 Sep 2025 14:49:36 -0400 Subject: [PATCH] add computation function for bounds and initialize bounds across solvers --- src/LMTR_alg.jl | 15 +++------------ src/LM_alg.jl | 11 +++++------ src/R2DH.jl | 11 +++++------ src/R2N.jl | 11 +++++------ src/R2_alg.jl | 11 +++++------ src/RegularizedOptimization.jl | 11 ----------- src/TRDH_alg.jl | 23 +++++++++++------------ src/utils.jl | 15 +++++++++++++++ 8 files changed, 49 insertions(+), 59 deletions(-) diff --git a/src/LMTR_alg.jl b/src/LMTR_alg.jl index 693e23a1..f71a9003 100644 --- a/src/LMTR_alg.jl +++ b/src/LMTR_alg.jl @@ -328,10 +328,7 @@ function SolverCore.solve!( ∆_effective = min(β * χ(s), Δk) if has_bnds - @. l_bound_m_x = l_bound - xk - @. u_bound_m_x = u_bound - xk - @. l_bound_m_x .= max.(l_bound_m_x, -∆_effective) - @. u_bound_m_x .= min.(u_bound_m_x, ∆_effective) + update_bounds!(l_bound_m_x, u_bound_m_x, false, l_bound, u_bound, xk, ∆_effective) set_bounds!(ψ, l_bound_m_x, u_bound_m_x) set_bounds!(solver.subsolver.ψ, l_bound_m_x, u_bound_m_x) else @@ -399,10 +396,7 @@ function SolverCore.solve!( if η1 ≤ ρk < Inf xk .= xkn if has_bnds - @. l_bound_m_x = l_bound - xk - @. u_bound_m_x = u_bound - xk - @. l_bound_m_x .= max.(l_bound_m_x, -Δk) - @. u_bound_m_x .= min.(u_bound_m_x, Δk) + update_bounds!(l_bound_m_x, u_bound_m_x, false, l_bound, u_bound, xk, Δk) set_bounds!(ψ, l_bound_m_x, u_bound_m_x) set_bounds!(solver.subsolver.ψ, l_bound_m_x, u_bound_m_x) end @@ -430,10 +424,7 @@ function SolverCore.solve!( if ρk < η1 || ρk == Inf Δk = Δk / 2 if has_bnds - @. l_bound_m_x = l_bound - xk - @. u_bound_m_x = u_bound - xk - @. l_bound_m_x .= max.(l_bound_m_x, -Δk) - @. u_bound_m_x .= min.(u_bound_m_x, Δk) + update_bounds!(l_bound_m_x, u_bound_m_x, false, l_bound, u_bound, xk, ∆k) set_bounds!(ψ, l_bound_m_x, u_bound_m_x) set_bounds!(solver.subsolver.ψ, l_bound_m_x, u_bound_m_x) else diff --git a/src/LM_alg.jl b/src/LM_alg.jl index aefbce48..9c43e8b3 100644 --- a/src/LM_alg.jl +++ b/src/LM_alg.jl @@ -230,10 +230,10 @@ function SolverCore.solve!( m_monotone = length(m_fh_hist) + 1 if has_bnds - l_bound = solver.l_bound - u_bound = solver.u_bound - l_bound_m_x = solver.l_bound_m_x - u_bound_m_x = solver.u_bound_m_x + l_bound, u_bound = solver.l_bound, solver.u_bound + l_bound_m_x, u_bound_m_x = solver.l_bound_m_x, solver.u_bound_m_x + update_bounds!(l_bound_m_x, u_bound_m_x, l_bound, u_bound, xk) + set_bounds!(ψ, l_bound_m_x, u_bound_m_x) end # initialize parameters @@ -387,8 +387,7 @@ function SolverCore.solve!( xk .= xkn if has_bnds - @. l_bound_m_x = l_bound - xk - @. u_bound_m_x = u_bound - xk + update_bounds!(l_bound_m_x, u_bound_m_x, l_bound, u_bound, xk) set_bounds!(ψ, l_bound_m_x, u_bound_m_x) end diff --git a/src/R2DH.jl b/src/R2DH.jl index 8f924ab0..5c2f8f1e 100644 --- a/src/R2DH.jl +++ b/src/R2DH.jl @@ -253,10 +253,10 @@ function SolverCore.solve!( has_bnds = solver.has_bnds if has_bnds - l_bound_m_x = solver.l_bound_m_x - u_bound_m_x = solver.u_bound_m_x - l_bound = solver.l_bound - u_bound = solver.u_bound + l_bound, u_bound = solver.l_bound, solver.u_bound + l_bound_m_x, u_bound_m_x = solver.l_bound_m_x, solver.u_bound_m_x + update_bounds!(l_bound_m_x, u_bound_m_x, l_bound, u_bound, xk) + set_bounds!(ψ, l_bound_m_x, u_bound_m_x) end m_monotone = length(m_fh_hist) + 1 @@ -388,8 +388,7 @@ function SolverCore.solve!( if η1 ≤ ρk < Inf xk .= xkn if has_bnds - @. l_bound_m_x = l_bound - xk - @. u_bound_m_x = u_bound - xk + update_bounds!(l_bound_m_x, u_bound_m_x, l_bound, u_bound, xk) set_bounds!(ψ, l_bound_m_x, u_bound_m_x) end fk = fkn diff --git a/src/R2N.jl b/src/R2N.jl index 26a8eea0..bebe2553 100644 --- a/src/R2N.jl +++ b/src/R2N.jl @@ -243,10 +243,10 @@ function SolverCore.solve!( has_bnds = solver.has_bnds if has_bnds - l_bound_m_x = solver.l_bound_m_x - u_bound_m_x = solver.u_bound_m_x - l_bound = solver.l_bound - u_bound = solver.u_bound + l_bound, u_bound = solver.l_bound, solver.u_bound + l_bound_m_x, u_bound_m_x = solver.l_bound_m_x, solver.u_bound_m_x + update_bounds!(l_bound_m_x, u_bound_m_x, l_bound, u_bound, xk) + set_bounds!(ψ, l_bound_m_x, u_bound_m_x) end m_monotone = length(m_fh_hist) + 1 @@ -430,8 +430,7 @@ function SolverCore.solve!( if η1 ≤ ρk < Inf xk .= xkn if has_bnds - @. l_bound_m_x = l_bound - xk - @. u_bound_m_x = u_bound - xk + update_bounds!(l_bound_m_x, u_bound_m_x, l_bound, u_bound, xk) set_bounds!(ψ, l_bound_m_x, u_bound_m_x) end #update functions diff --git a/src/R2_alg.jl b/src/R2_alg.jl index b215101b..c952c768 100644 --- a/src/R2_alg.jl +++ b/src/R2_alg.jl @@ -343,10 +343,10 @@ function SolverCore.solve!( s = solver.s has_bnds = solver.has_bnds if has_bnds - l_bound = solver.l_bound - u_bound = solver.u_bound - l_bound_m_x = solver.l_bound_m_x - u_bound_m_x = solver.u_bound_m_x + l_bound, u_bound = solver.l_bound, solver.u_bound + l_bound_m_x, u_bound_m_x = solver.l_bound_m_x, solver.u_bound_m_x + update_bounds!(l_bound_m_x, u_bound_m_x, l_bound, u_bound, xk) + set_bounds!(ψ, l_bound_m_x, u_bound_m_x) end # initialize parameters @@ -462,8 +462,7 @@ function SolverCore.solve!( if η1 ≤ ρk < Inf xk .= xkn if has_bnds - @. l_bound_m_x = l_bound - xk - @. u_bound_m_x = u_bound - xk + update_bounds!(l_bound_m_x, u_bound_m_x, l_bound, u_bound, xk) set_bounds!(ψ, l_bound_m_x, u_bound_m_x) end fk = fkn diff --git a/src/RegularizedOptimization.jl b/src/RegularizedOptimization.jl index 57f9d13e..acc9406d 100644 --- a/src/RegularizedOptimization.jl +++ b/src/RegularizedOptimization.jl @@ -34,17 +34,6 @@ Notably, you can access, and modify, the following: - `stats.elapsed_time`: elapsed time in seconds. " -# update l_bound_k and u_bound_k -function update_bounds!(l_bound_k, u_bound_k, is_subsolver, l_bound, u_bound, xk, Δ) - if is_subsolver - @. l_bound_k = max(xk - Δ, l_bound) - @. u_bound_k = min(xk + Δ, u_bound) - else - @. l_bound_k = max(-Δ, l_bound - xk) - @. u_bound_k = min(Δ, u_bound - xk) - end -end - include("utils.jl") include("input_struct.jl") include("TR_alg.jl") diff --git a/src/TRDH_alg.jl b/src/TRDH_alg.jl index f91319f5..e53ba5da 100644 --- a/src/TRDH_alg.jl +++ b/src/TRDH_alg.jl @@ -265,11 +265,17 @@ function SolverCore.solve!( χ = solver.χ has_bnds = solver.has_bnds + is_subsolver = h isa ShiftedProximableFunction # case TRDH is used as a subsolver + if has_bnds - l_bound_m_x = solver.l_bound_m_x - u_bound_m_x = solver.u_bound_m_x - l_bound = solver.l_bound - u_bound = solver.u_bound + l_bound_m_x, u_bound_m_x = solver.l_bound_m_x, solver.u_bound_m_x + l_bound, u_bound = solver.l_bound, solver.u_bound + if is_subsolver + l_bound .= ψ.l + u_bound .= ψ.u + end + update_bounds!(l_bound_m_x, u_bound_m_x, is_subsolver, l_bound, u_bound, xk, Δk) + set_bounds!(ψ, l_bound_m_x, u_bound_m_x) end # initialize parameters @@ -286,13 +292,6 @@ function SolverCore.solve!( improper == true && @warn "TRDH: Improper term detected" improper == true && return stats - is_subsolver = h isa ShiftedProximableFunction # case TRDH is used as a subsolver - - if is_subsolver - l_bound .= ψ.l - u_bound .= ψ.u - end - if verbose > 0 @info log_header( [:iter, :fx, :hx, :xi, :ρ, :Δ, :normx, :norms, :normD, :arrow], @@ -439,7 +438,7 @@ function SolverCore.solve!( xk .= xkn if has_bnds update_bounds!(l_bound_m_x, u_bound_m_x, is_subsolver, l_bound, u_bound, xk, Δk) - has_bnds && set_bounds!(ψ, l_bound_m_x, u_bound_m_x) + set_bounds!(ψ, l_bound_m_x, u_bound_m_x) end fk = fkn hk = hkn diff --git a/src/utils.jl b/src/utils.jl index 912b186d..b0586372 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -155,3 +155,18 @@ function get_status( :unknown end end + +function update_bounds!(l_bound_m_x::V, u_bound_m_x::V, l_bound::V, u_bound::V, xk::V,) where {V <: AbstractVector} + @. l_bound_m_x = l_bound - xk + @. u_bound_m_x = u_bound - xk +end + +function update_bounds!(l_bound_m_x::V, u_bound_m_x::V, is_subsolver::Bool, l_bound::V, u_bound::V, xk::V, Δ::T) where {T <: Real, V <: AbstractVector{T}} + if is_subsolver + @. l_bound_m_x = max(xk - Δ, l_bound) + @. u_bound_m_x = min(xk + Δ, u_bound) + else + @. l_bound_m_x = max(-Δ, l_bound - xk) + @. u_bound_m_x = min(Δ, u_bound - xk) + end +end \ No newline at end of file