Skip to content

Commit

Permalink
Upgrade to ProximalAlgorithms v0.5 and AbstractOperators v0.3 (#40)
Browse files Browse the repository at this point in the history
* drop support for ForwardBackward optimizer;
* add support for PANOCplus;
* stabilize unstable unit test;
* bump upper version limit on DSP.
  • Loading branch information
dhanak committed Jun 30, 2023
1 parent d3a2f64 commit f0a25b0
Show file tree
Hide file tree
Showing 13 changed files with 136 additions and 139 deletions.
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StructuredOptimization"
uuid = "46cd3e9d-64ff-517d-a929-236bc1a1fc9d"
version = "0.3.0"
version = "0.4.0-ci+20230622"

[deps]
AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c"
Expand All @@ -12,11 +12,11 @@ ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"

[compat]
AbstractOperators = "0.1 - 0.2"
DSP = "0.5.1 - 0.6"
AbstractOperators = "0.3"
DSP = "0.5.1 - 0.7"
FFTW = "1"
ProximalAlgorithms = "0.3 - 0.4"
ProximalOperators = "0.8 - 0.14"
ProximalAlgorithms = "0.5"
ProximalOperators = "0.15"
RecursiveArrayTools = "1 - 2"
julia = "1.4"

Expand Down
2 changes: 1 addition & 1 deletion docs/src/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ You can pick the algorithm to use as `Solver` object from the
package. Currently, the following algorithms are supported.

```@docs
ForwardBackward
ZeroFPR
PANOC
PANOCplus
```


Expand Down
4 changes: 2 additions & 2 deletions src/StructuredOptimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ using AbstractOperators
using ProximalOperators
using ProximalAlgorithms

import ProximalAlgorithms:ForwardBackward, ZeroFPR, PANOC
export ForwardBackward, ZeroFPR, PANOC
import ProximalAlgorithms: ZeroFPR, PANOC, PANOCplus
export ZeroFPR, PANOC, PANOCplus

include("syntax/syntax.jl")
include("calculus/precomposeNonlinear.jl") # TODO move to ProximalOperators?
Expand Down
8 changes: 4 additions & 4 deletions src/arraypartition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import ProximalOperators
import RecursiveArrayTools

@inline function ProximalOperators.prox(
h::ProximalOperators.ProximableFunction,
h,
x::RecursiveArrayTools.ArrayPartition,
gamma...
)
Expand All @@ -13,7 +13,7 @@ import RecursiveArrayTools
end

@inline function ProximalOperators.gradient(
h::ProximalOperators.ProximableFunction,
h,
x::RecursiveArrayTools.ArrayPartition
)
# unwrap
Expand All @@ -24,13 +24,13 @@ end

@inline ProximalOperators.prox!(
y::RecursiveArrayTools.ArrayPartition,
h::ProximalOperators.ProximableFunction,
h,
x::RecursiveArrayTools.ArrayPartition,
gamma...
) = ProximalOperators.prox!(y.x, h, x.x, gamma...)

@inline ProximalOperators.gradient!(
y::RecursiveArrayTools.ArrayPartition,
h::ProximalOperators.ProximableFunction,
h,
x::RecursiveArrayTools.ArrayPartition
) = ProximalOperators.gradient!(y.x, h, x.x)
10 changes: 5 additions & 5 deletions src/calculus/precomposeNonlinear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@ import ProximalOperators: gradient!, gradient # this can be removed when moved t

export PrecomposeNonlinear

struct PrecomposeNonlinear{P <: ProximableFunction,
struct PrecomposeNonlinear{P,
T <: AbstractOperator,
D <: AbstractArray,
D <: AbstractArray,
C <: AbstractArray
} <: ProximableFunction
g::P
}
g::P
G::T
bufD::D
bufC::C
bufC2::C
end

function PrecomposeNonlinear(g::P, G::T) where {P, T}
function PrecomposeNonlinear(g::P, G::T) where {P, T}
t, s = domainType(G), size(G,2)
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s))
t, s = codomainType(G), size(G,1)
Expand Down
6 changes: 3 additions & 3 deletions src/solvers/build_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ julia> A, b = randn(10,4), randn(10);
julia> p = problem( ls(A*x - b ) , norm(x) <= 1 );
julia> StructuredOptimization.parse_problem(p, ForwardBackward());
julia> StructuredOptimization.parse_problem(p, PANOCplus());
```
"""
function parse_problem(terms::Tuple, solver::T) where T <: ForwardBackwardSolver
Expand Down Expand Up @@ -65,14 +65,14 @@ julia> A, b = randn(10,4), randn(10);
julia> p = problem(ls(A*x - b ), norm(x) <= 1);
julia> solve(p, ForwardBackward());
julia> solve(p, PANOCplus());
julia> ~x
```
"""
function solve(terms::Tuple, solver::ForwardBackwardSolver)
x, kwargs = parse_problem(terms, solver)
x_star, it = solver(~x; kwargs...)
x_star, it = solver(; x0 = ~x, kwargs...)
~x .= x_star
return x, it
end
32 changes: 16 additions & 16 deletions src/solvers/minimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ julia> @minimize ls(A*x-b) st x >= 0.;
julia> ~x # access array with solution
julia> @minimize ls(A*x-b) st norm(x) == 2.0 with ForwardBackward(fast=true);
julia> @minimize ls(A*x-b) st norm(x) == 2.0 with PANOCplus();
julia> ~x # access array with solution
```
Expand All @@ -29,28 +29,28 @@ Returns as output a tuple containing the optimization variables and the number
of iterations spent by the solver algorithm.
"""
macro minimize(cf::Union{Expr, Symbol})
cost = esc(cf)
cost = esc(cf)
return :(solve(problem($(cost)), default_solver()))
end

macro minimize(cf::Union{Expr, Symbol}, s::Symbol, cstr::Union{Expr, Symbol})
cost = esc(cf)
if s == :(st)
constraints = esc(cstr)
return :(solve(problem($(cost), $(constraints)), default_solver()))
elseif s == :(with)
solver = esc(cstr)
cost = esc(cf)
if s == :(st)
constraints = esc(cstr)
return :(solve(problem($(cost), $(constraints)), default_solver()))
elseif s == :(with)
solver = esc(cstr)
return :(solve(problem($(cost)), $(solver)))
else
error("wrong symbol after cost function! use `st` or `with`")
end
else
error("wrong symbol after cost function! use `st` or `with`")
end
end

macro minimize(cf::Union{Expr, Symbol}, s::Symbol, cstr::Union{Expr, Symbol}, w::Symbol, slv::Union{Expr, Symbol})
cost = esc(cf)
s != :(st) && error("wrong symbol after cost function! use `st`")
constraints = esc(cstr)
w != :(with) && error("wrong symbol after constraints! use `with`")
solver = esc(slv)
cost = esc(cf)
s != :(st) && error("wrong symbol after cost function! use `st`")
constraints = esc(cstr)
w != :(with) && error("wrong symbol after constraints! use `with`")
solver = esc(slv)
return :(solve(problem($(cost), $(constraints)), $(solver)))
end
6 changes: 1 addition & 5 deletions src/solvers/solvers_options.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
using ProximalAlgorithms

const ForwardBackwardSolver = Union{
ProximalAlgorithms.ForwardBackward,
ProximalAlgorithms.ZeroFPR,
ProximalAlgorithms.PANOC,
}
const ForwardBackwardSolver = ProximalAlgorithms.IterativeAlgorithm

const default_solver = ProximalAlgorithms.PANOC

0 comments on commit f0a25b0

Please sign in to comment.