-
Notifications
You must be signed in to change notification settings - Fork 28
/
pogm_restart.jl
266 lines (221 loc) · 8.52 KB
/
pogm_restart.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
#=
pogm_restart.jl
2017-03-31, Donghwan Kim and Jeff Fessler, University of Michigan
=#
export pogm_restart
using LinearAlgebra: norm
"""
    gr_restart(Fgrad, ynew_yold, restart_cutoff)

Gradient-based restart test.
Return `true` when the real inner product of `-Fgrad` with the momentum
direction `ynew_yold = y_new - y_old` is at or below
`restart_cutoff * norm(Fgrad) * norm(ynew_yold)`,
i.e., when cos(angle) between them drops below the cutoff
(cutoff 0 means "restart once the angle exceeds 90 degrees").
"""
function gr_restart(Fgrad, ynew_yold, restart_cutoff)
    inner = -sum(Float64, real(Fgrad .* ynew_yold)) # ⟨-Fgrad, y_new - y_old⟩
    bound = restart_cutoff * norm(Fgrad) * norm(ynew_yold)
    return inner <= bound
end
"""
x, out = pogm_restart(x0, Fcost, f_grad, f_L ;
f_mu=0, mom=:pogm, restart=:gr, restart_cutoff=0.,
bsig=1, niter=10, g_prox=(z,c)->z, fun=...)
Iterative proximal algorithms (PGM=ISTA, FPGM=FISTA, POGM) with restart.
# in
- `x0` initial guess
- `Fcost` function for computing the cost function value ``F(x)``
- (evaluated every iteration; its value is used only when `restart === :fr`)
- `f_grad` function for computing the gradient of ``f(x)``
- `f_L` Lipschitz constant of the gradient of ``f(x)``
# option
- `f_mu` strong convexity parameter of ``f(x)``; default 0.
- if `f_mu > 0`, ``(\\alpha, \\beta_k, \\gamma_k)`` is chosen by Table 1 in [KF18]
- `g_prox` function `g_prox(z,c)` for the proximal operator for ``g(x)``
- `g_prox(z,c)` computes ``argmin_x 1/2 \\|z-x\\|^2 + c \\, g(x)``
- `mom` momentum option
- `:pogm` POGM (fastest); default!
- `:fpgm` (FISTA), ``\\gamma_k = 0``
- `:pgm` PGM (ISTA), ``\\beta_k = \\gamma_k = 0``
- `restart` restart option
- `:gr` gradient restart; default!
- `:fr` function restart
- `:none` no restart
- `restart_cutoff` for `:gr` restart if cos(angle) < this; default 0.
- `bsig` gradient "gamma" decrease option (value within [0 1]); default 1
- see ``\\bar{\\sigma}`` in [KF18]
- `niter` number of iterations; default 10
- `fun` function`(iter, xk, yk, is_restart)`
user-defined function evaluated each `iter`
with secondary `xk`, primary `yk`,
and boolean `is_restart` indicating whether this iteration was a restart
# out
- `x` final iterate
- for PGM (ISTA): ``x_N = y_N``
- for FPGM (FISTA): primary iterate ``y_N``
- for POGM: secondary iterate ``x_N``, see [KF18]
- `out [fun(0, x0, x0, false), fun(1, x1, y1, is_restart), ...]` array of length `[niter+1]`
Optimization Problem: Nonsmooth Composite Convex Minimization
* ``argmin_x F(x), F(x) := f(x) + g(x))``
- ``f(x)`` smooth convex function
- ``g(x)`` convex function, possibly nonsmooth and "proximal-friendly" [CP11]
# Optimization Algorithms:
Accelerated First-order Algorithms when ``g(x) = 0`` [KF18]
iterate as below for given coefficients ``(\\alpha, \\beta_k, \\gamma_k)``
* For k = 0,1,...
- ``y_{k+1} = x_k - \\alpha f'(x_k)`` : gradient update
- ``x_{k+1} = y_{k+1} + \\beta_k (y_{k+1} - y_k) + \\gamma_k (y_{k+1} - x_k)`` : momentum update
Proximal versions of the above for ``g(x) \\neq 0`` are in the below references,
and use the proximal operator
``prox_g(z) = argmin_x {1/2\\|z-x\\|^2 + g(x)}``.
- Proximal Gradient method (PGM or ISTA) - ``\\beta_k = \\gamma_k = 0``. [BT09]
- Fast Proximal Gradient Method (FPGM or FISTA) - ``\\gamma_k = 0``. [BT09]
- Proximal Optimized Gradient Method (POGM) - [THG15]
- FPGM(FISTA) with Restart - [OC15]
- POGM with Restart - [KF18]
# references
- [CP11] P. L. Combettes, J. C. Pesquet,
"Proximal splitting methods in signal processing,"
Fixed-Point Algorithms for Inverse Problems in Science and Engineering,
Springer, Optimization and Its Applications, 2011.
- [KF18] D. Kim, J.A. Fessler,
"Adaptive restart of the optimized gradient method for convex optimization," 2018
Arxiv:1703.04641,
[http://doi.org/10.1007/s10957-018-1287-4]
- [BT09] A. Beck, M. Teboulle:
"A fast iterative shrinkage-thresholding algorithm for linear inverse problems,"
SIAM J. Imaging Sci., 2009.
- [THG15] A.B. Taylor, J.M. Hendrickx, F. Glineur,
"Exact worst-case performance of first-order algorithms
for composite convex optimization," Arxiv:1512.07516, 2015,
SIAM J. Opt. 2017
[http://doi.org/10.1137/16m108104x]
Copyright 2017-3-31, Donghwan Kim and Jeff Fessler, University of Michigan
2018-08-13 Julia 0.7.0
2019-02-24 interface redesign
"""
function pogm_restart(
    x0,
    Fcost::Function,
    f_grad::Function,
    f_L::Real ;
    f_mu::Real = 0.,
    mom::Symbol = :pogm, # other choices: :pgm :fpgm
    restart::Symbol = :gr, # other choices: :fr :none
    restart_cutoff::Real = 0.,
    bsig::Real = 1,
    niter::Int = 10,
    g_prox::Function = (z, c::Real) -> z, # default: g(x) = 0, so prox is identity
    fun::Function = (iter::Int, xk, yk, is_restart::Bool) -> undef,
)

    # argument checking
    !in(mom, (:pgm, :fpgm, :pogm)) && throw(ArgumentError("mom $mom"))
    !in(restart, (:none, :gr, :fr)) && throw(ArgumentError("restart $restart"))
    f_L < 0 && throw(ArgumentError("f_L=$f_L < 0"))
    f_mu < 0 && throw(ArgumentError("f_mu=$f_mu < 0"))
    bsig < 0 && throw(ArgumentError("bsig=$bsig < 0"))
    !((-1 < restart_cutoff) && (restart_cutoff < 1)) &&
        throw(ArgumentError("restart_cutoff=$restart_cutoff"))

    L = f_L # Lipschitz constant of the gradient of f
    mu = f_mu # strong convexity parameter of f
    q = mu/L # inverse condition number (q = 0 when mu = 0)

    # initialize momentum parameters
    told = 1 # t_0 for the t-sequence used when mu = 0
    sig = 1 # sigma in [KF18]; shrunk by bsig on POGM "gamma decrease" events
    zetaold = 1 # dummy # zeta; overwritten before first use in POGM branch

    # initialize iterates (all start at x0)
    xold = x0
    yold = x0
    uold = x0
    zold = x0
    Fcostold = Fcost(x0)
    Fgradold = zeros(size(x0)) # dummy # placeholder until first POGM Fgrad

    # save initial state via user callback
    out = Array{Any}(undef, niter+1)
    out[1] = fun(0, x0, x0, false)
    xnew = []
    ynew = []

    # iterations
    for iter in 1:niter

        # step size: 2/(L+mu) when mu is known for :pgm; otherwise 1/L
        if mom === :pgm && mu != 0
            alpha = 2. / (L+mu)
        else
            alpha = 1. / L
        end

        fgrad = f_grad(xold)
        is_restart = false

        if mom === :pgm || mom === :fpgm
            ynew = g_prox(xold - alpha * fgrad, alpha) # standard PG update
            Fgrad = -(1. / alpha) * (ynew - xold) # standard composite gradient mapping
            Fcostnew = Fcost(ynew)
            # restart condition
            if restart != :none
                # function/gradient restart
                if ((restart === :fr && Fcostnew > Fcostold)
                    || (restart === :gr && gr_restart(Fgrad, ynew-yold, restart_cutoff)))
                    told = 1 # reset momentum
                    is_restart = true
                end
                Fcostold = Fcostnew
            end
        elseif mom === :pogm # POGM
            # gradient update for POGM [see KF18]
            unew = xold - alpha * fgrad
            # restart + "gamma" decrease conditions checked later for POGM,
            # unlike PGM, FPGM above
        end

        # momentum coefficient "beta"
        if mom === :fpgm && mu != 0 # known mu > 0: constant beta (Table 1, KF18)
            beta = (1 - sqrt(q)) / (1 + sqrt(q))
        elseif mom === :pogm && mu != 0 # known mu > 0: POGM beta (Table 1, KF18)
            beta = (2 + q - sqrt(q^2+8*q))^2 / 4. / (1-q)
        # for "mu" = 0 or for unknown "mu"
        elseif mom != :pgm
            if mom === :pogm && iter == niter # && iszero(restart)
                # final POGM iteration uses factor 8 in the t update [KF18]
                tnew = 0.5 * (1 + sqrt(1 + 8 * told^2))
            else
                tnew = 0.5 * (1 + sqrt(1 + 4 * told^2))
            end
            beta = (told - 1) / tnew
        end

        # momentum update
        if mom === :pgm
            xnew = ynew
        elseif mom === :fpgm
            xnew = ynew + beta * (ynew - yold)
        elseif mom === :pogm # see [KF18]
            # momentum coefficient "gamma"
            if mu != 0
                gamma = (2 + q - sqrt(q^2+8*q)) / 2.
            else
                gamma = sig * told / tnew
            end
            znew = (unew + beta * (unew - uold) + gamma * (unew - xold)
                - beta * alpha / zetaold * (xold - zold))
            zetanew = alpha * (1 + beta + gamma)
            xnew = g_prox(znew, zetanew) # non-standard PG update for POGM
            # non-standard composite gradient mapping for POGM:
            Fgrad = fgrad - 1/zetanew * (xnew - znew)
            ynew = xold - alpha * Fgrad
            Fcostnew = Fcost(xnew)
            # restart + "gamma" decrease conditions for POGM
            if restart != :none
                # function/gradient restart
                if ((restart === :fr && Fcostnew > Fcostold)
                    || (restart === :gr && gr_restart(Fgrad, ynew-yold, restart_cutoff)))
                    tnew = 1 # reset momentum
                    sig = 1 # reset sigma
                    is_restart = true
                # gradient "gamma" decrease: successive composite gradients
                # with negative inner product trigger shrinking sigma by bsig
                elseif sum(Float64, real(Fgrad .* Fgradold)) < 0
                    sig = bsig * sig
                end
                Fcostold = Fcostnew
                Fgradold = Fgrad
            end
            uold = unew
            zold = znew
            zetaold = zetanew
        end

        out[iter+1] = fun(iter, xnew, ynew, is_restart) # save
        xold = xnew
        yold = ynew
        if mom != :pgm && iszero(mu) # t-sequence is only used when mu = 0
            told = tnew
        end
    end # for iter

    # PGM/FPGM return the primary iterate y_N; POGM returns the secondary x_N
    return ((mom === :pogm) ? xnew : ynew), out
end # pogm_restart()