Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 149 additions & 42 deletions src/R2_alg.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,56 @@
export R2

mutable struct R2Solver{R, S <: AbstractVector{R}} <: AbstractOptimizationSolver
xk::S
∇fk::S
mν∇fk::S
xkn::S
s::S
has_bnds::Bool
l_bound::S
u_bound::S
l_bound_m_x::S
u_bound_m_x::S
Fobj_hist::Vector{R}
Hobj_hist::Vector{R}
Complex_hist::Vector{Int}
end

function R2Solver(x0::S, options::ROSolverOptions, l_bound::S, u_bound::S) where {R <: Real, S <: AbstractVector{R}}
maxIter = options.maxIter
xk = similar(x0)
∇fk = similar(x0)
mν∇fk = similar(x0)
xkn = similar(x0)
s = zero(x0)
has_bnds = any(l_bound .!= R(-Inf)) || any(u_bound .!= R(Inf))
if has_bnds
l_bound_m_x = similar(xk)
u_bound_m_x = similar(xk)
else
l_bound_m_x = similar(xk, 0)
u_bound_m_x = similar(xk, 0)
end
Fobj_hist = zeros(R, maxIter)
Hobj_hist = zeros(R, maxIter)
Complex_hist = zeros(Int, maxIter)
return R2Solver(
xk,
∇fk,
mν∇fk,
xkn,
s,
has_bnds,
l_bound,
u_bound,
l_bound_m_x,
u_bound_m_x,
Fobj_hist,
Hobj_hist,
Complex_hist,
)
end

"""
R2(nlp, h, options)
R2(f, ∇f!, h, options, x0)
Expand Down Expand Up @@ -51,9 +102,9 @@ function R2(nlp::AbstractNLPModel, args...; kwargs...)
x -> obj(nlp, x),
(g, x) -> grad!(nlp, x, g),
args...,
x0;
l_bound = nlp.meta.lvar,
u_bound = nlp.meta.uvar,
x0,
nlp.meta.lvar,
nlp.meta.uvar;
kwargs_dict...,
)
ξ = outdict[:ξ]
Expand All @@ -71,6 +122,7 @@ function R2(nlp::AbstractNLPModel, args...; kwargs...)
return stats
end

# method without bounds
function R2(
f::F,
∇f!::G,
Expand All @@ -79,6 +131,64 @@ function R2(
x0::AbstractVector{R};
selected::AbstractVector{<:Integer} = 1:length(x0),
kwargs...,
) where {F <: Function, G <: Function, H, R <: Real}
start_time = time()
elapsed_time = 0.0
solver = R2Solver(x0, options, similar(x0, 0), similar(x0, 0))
k, status, fk, hk, ξ = R2!(solver, f, ∇f!, h, options, x0; selected = selected)
elapsed_time = time() - start_time
outdict = Dict(
:Fhist => solver.Fobj_hist[1:k],
:Hhist => solver.Hobj_hist[1:k],
:Chist => solver.Complex_hist[1:k],
:NonSmooth => h,
:status => status,
:fk => fk,
:hk => hk,
:ξ => ξ,
:elapsed_time => elapsed_time,
)
return solver.xk, k, outdict
end

function R2(
f::F,
∇f!::G,
h::H,
options::ROSolverOptions{R},
x0::AbstractVector{R},
l_bound::AbstractVector{R},
u_bound::AbstractVector{R};
selected::AbstractVector{<:Integer} = 1:length(x0),
kwargs...,
) where {F <: Function, G <: Function, H, R <: Real}
start_time = time()
elapsed_time = 0.0
solver = R2Solver(x0, options, l_bound, u_bound)
k, status, fk, hk, ξ = R2!(solver, f, ∇f!, h, options, x0; selected = selected)
elapsed_time = time() - start_time
outdict = Dict(
:Fhist => solver.Fobj_hist[1:k],
:Hhist => solver.Hobj_hist[1:k],
:Chist => solver.Complex_hist[1:k],
:NonSmooth => h,
:status => status,
:fk => fk,
:hk => hk,
:ξ => ξ,
:elapsed_time => elapsed_time,
)
return solver.xk, k, outdict
end

function R2!(
solver::R2Solver{R},
f::F,
∇f!::G,
h::H,
options::ROSolverOptions{R},
x0::AbstractVector{R};
selected::AbstractVector{<:Integer} = 1:length(x0),
) where {F <: Function, G <: Function, H, R <: Real}
start_time = time()
elapsed_time = 0.0
Expand All @@ -94,17 +204,23 @@ function R2(
ν = options.ν
γ = options.γ

local l_bound, u_bound
has_bnds = false
for (key, val) in kwargs
if key == :l_bound
l_bound = val
has_bnds = has_bnds || any(l_bound .!= R(-Inf))
elseif key == :u_bound
u_bound = val
has_bnds = has_bnds || any(u_bound .!= R(Inf))
end
# retrieve workspace
xk = solver.xk
xk .= x0
∇fk = solver.∇fk
mν∇fk = solver.mν∇fk
xkn = solver.xkn
s = solver.s
has_bnds = solver.has_bnds
if has_bnds
l_bound = solver.l_bound
u_bound = solver.u_bound
l_bound_m_x = solver.l_bound_m_x
u_bound_m_x = solver.u_bound_m_x
end
Fobj_hist = solver.Fobj_hist
Hobj_hist = solver.Hobj_hist
Complex_hist = solver.Complex_hist

if verbose == 0
ptf = Inf
Expand All @@ -117,39 +233,38 @@ function R2(
end

# initialize parameters
xk = copy(x0)
hk = h(xk[selected])
hk = @views h(xk[selected])
if hk == Inf
verbose > 0 && @info "R2: finding initial guess where nonsmooth term is finite"
prox!(xk, h, x0, one(eltype(x0)))
hk = h(xk[selected])
hk = @views h(xk[selected])
hk < Inf || error("prox computation must be erroneous")
verbose > 0 && @debug "R2: found point where h has value" hk
end
hk == -Inf && error("nonsmooth term is not proper")

xkn = similar(xk)
s = zero(xk)
ψ = has_bnds ? shifted(h, xk, l_bound - xk, u_bound - xk, selected) : shifted(h, xk)
if has_bnds
@. l_bound_m_x = l_bound - xk
@. u_bound_m_x = u_bound - xk
ψ = shifted(h, xk, l_bound_m_x, u_bound_m_x, selected)
else
ψ = shifted(h, xk)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are also allocations here, right? We can think about it in a separate PR, but we should have an issue to track them.

Copy link
Member Author

@geoffroyleconte geoffroyleconte Jun 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I forgot this. We leave it as it is for this PR as you propose, and maybe create a function shifted! in ShiftedProximalOperators to allow us to preallocate the ShiftedProximalOperator structure?

end

Fobj_hist = zeros(maxIter)
Hobj_hist = zeros(maxIter)
Complex_hist = zeros(Int, maxIter)
if verbose > 0
#! format: off
@info @sprintf "%6s %8s %8s %7s %8s %7s %7s %7s %1s" "iter" "f(x)" "h(x)" "√ξ" "ρ" "σ" "‖x‖" "‖s‖" ""
#! format: off
end

local ξ
local ξ::R
k = 0
σk = max(1 / ν, σmin)
ν = 1 / σk

fk = f(xk)
∇fk = similar(xk)
∇f!(∇fk, xk)
mν∇fk = -ν * ∇fk
@. mν∇fk = -ν * ∇fk

optimal = false
tired = maxIter > 0 && k ≥ maxIter || elapsed_time > maxTime
Expand All @@ -162,7 +277,7 @@ function R2(

# define model
φk(d) = dot(∇fk, d)
mk(d) = φk(d) + ψ(d)
mk(d)::R = φk(d) + ψ(d)::R
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still some allocations here but I could not fix it, maybe it is more related to ProximalOperators.


prox!(s, ψ, mν∇fk, ν)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some allocations here as well, maybe also related to ProximalOperators.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably ShiftedProximalOperators?!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the allocations gone now?

Copy link
Member Author

@geoffroyleconte geoffroyleconte Jul 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are still a few (16 per iteration on line 280 and 282), I could not get lower on my PRs in ShiftedProximalOperators.
The biggest number of allocations comes from ψ = shifted(h, xk, l_bound_m_x, u_bound_m_x, selected) (as you mentionned below) at the begging of R2, but it is not inside a loop. We should think about a way of preallocating the ShiftedProximalOperators.

Complex_hist[k] += 1
Expand All @@ -181,16 +296,15 @@ function R2(
ξ > 0 || error("R2: prox-gradient step should produce a decrease but ξ = $(ξ)")
xkn .= xk .+ s
fkn = f(xkn)
hkn = h(xkn[selected])
hkn = @views h(xkn[selected])
hkn == -Inf && error("nonsmooth term is not proper")

Δobj = (fk + hk) - (fkn + hkn) + max(1, abs(fk + hk)) * 10 * eps()
ρk = Δobj / ξ

σ_stat = (η2 ≤ ρk < Inf) ? "↘" : (ρk < η1 ? "↗" : "=")

if (verbose > 0) && (k % ptf == 0)
#! format: off
σ_stat = (η2 ≤ ρk < Inf) ? "↘" : (ρk < η1 ? "↗" : "=")
@info @sprintf "%6d %8.1e %8.1e %7.1e %8.1e %7.1e %7.1e %7.1e %1s" k fk hk sqrt(ξ) ρk σk norm(xk) norm(s) σ_stat
#! format: on
end
Expand All @@ -201,7 +315,11 @@ function R2(

if η1 ≤ ρk < Inf
xk .= xkn
has_bnds && set_bounds!(ψ, l_bound - xk, u_bound - xk)
if has_bnds
@. l_bound_m_x = l_bound - xk
@. u_bound_m_x = u_bound - xk
set_bounds!(ψ, l_bound_m_x, u_bound_m_x)
end
fk = fkn
hk = hkn
∇f!(∇fk, xk)
Expand Down Expand Up @@ -239,17 +357,6 @@ function R2(
else
:exception
end
outdict = Dict(
:Fhist => Fobj_hist[1:k],
:Hhist => Hobj_hist[1:k],
:Chist => Complex_hist[1:k],
:NonSmooth => h,
:status => status,
:fk => fk,
:hk => hk,
:ξ => ξ,
:elapsed_time => elapsed_time,
)

return xk, k, outdict
return k, status, fk, hk, ξ
end