#=
ogm_ls.jl
OGM with an MM line search
2019-03-16, Jeff Fessler, University of Michigan
=#
export ogm_ls
using LinearAlgebra: I, norm, dot
"""
    (x,out) = ogm_ls(B, gradf, curvf, x0; niter=?, ninner=?, fun=?)

Optimized gradient method (OGM) with a line search
([Drori & Taylor](http://doi.org/10.1007/s10107-019-01410-2))
to minimize a general "inverse problem" cost function of the form
``\\Psi(x) = \\sum_{j=1}^J f_j(B_j x)``
where each function ``f_j(v)`` has a quadratic majorizer of the form
```math
q_j(v;u) = f_j(u) + \\nabla f_j(u) (v - u) + 1/2 \\|v - u\\|^2_{C_j(u)}
```
where ``C_j(u)`` is a diagonal matrix of curvatures.
(It suffices for each ``f_j`` to have a Lipschitz smooth gradient.)
This OGM method uses a majorize-minimize (MM) line search.
# in
- `B` vector of ``J`` blocks ``B_1,…,B_J``
- `gradf` vector of ``J`` functions that return the gradients of ``f_1,…,f_J``
- `curvf` vector of ``J`` functions `z -> curv(z)` that return a scalar
  or a vector of curvature values for each element of ``z``
- `x0` initial guess; need `length(x0) == size(B[j],2)` for ``j=1,…,J``
# option
- `niter` number of outer iterations; default 50
- `ninner` number of inner iterations of MM line search; default 5
- `fun` user-defined function to be evaluated with two arguments `(x,iter)`.
  * It is evaluated at `(x0,0)` and then after each iteration.
# output
- `x` final iterate
- `out` vector of length `niter+1` containing
  `(fun(x0,0), fun(x1,1), ..., fun(x_niter,niter))` (all `0` by default)
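
# example
A minimal least-squares sketch; the data `A`, `y` below are illustrative:
```julia
A = randn(20, 5); y = randn(20) # hypothetical data
gf = v -> v - y # gradient of f(v) = (1/2) ||v - y||^2
cf = v -> 1 # unit curvature: the Hessian of f is I
xh, out = ogm_ls([A], [gf], [cf], zeros(5); niter=100)
```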
"""
function ogm_ls(
    B::AbstractVector{<:Any},
    gradf::AbstractVector{<:Function},
    curvf::AbstractVector{<:Function},
    x0::AbstractArray{<:Number} ; # usually Vector
    niter::Int = 50,
    ninner::Int = 5,
    fun::Function = (x,iter) -> 0,
)
    Base.require_one_based_indexing(B, gradf, curvf)
    out = Array{Any}(undef, niter+1)
    out[1] = fun(x0, 0)
    J = length(B)
    x = x0
    dir = [] # search direction, set each outer iteration
    grad_new = [] # gradient at the current iterate
    grad_sum = zeros(size(x0)) # running weighted sum of gradients
    ti = 1 # t_0
    thetai = 1 # theta_0
    B0 = [B[j] * x for j in 1:J] # B_j * x0
    Bx = copy(B0) # B_j * x
    By = copy(B0) # B_j * y
    grad = (Bx) -> sum([B[j]' * gradf[j](Bx[j]) for j in 1:J])
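    # grad applies the chain rule: the full gradient is
    # ∇Ψ(x) = Σ_j B_j' ∇f_j(B_j x), evaluated from the list Bx = [B_j * x]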
    for iter in 1:niter
        grad_new = grad(Bx) # gradient at x_{iter-1}
        grad_sum += ti * grad_new # sum_{j=0}^{iter-1} t_j * gradient_j
        thetai = (1 + sqrt(8*ti^2 + 1)) / 2 # theta_{i+1}
        ti = (1 + sqrt(4*ti^2 + 1)) / 2 # t_{i+1}
        tt = (iter < niter) ? ti : thetai # use the theta_i factor on the last iteration
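        # auxiliary point y_i combines the current iterate with x0,
        # per the OGM line-search form of Drori & Taylor cited above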
        yi = (1 - 1/tt) * x + (1/tt) * x0
        for j in 1:J # update B_j * y_i
            By[j] = (1 - 1/tt) * Bx[j] + (1/tt) * B0[j]
        end
        dir = -(1 - 1/tt) * grad_new - (2/tt) * grad_sum # search direction -d_i
        # MM-based line search for step size alpha
        # using h(a) = sum_j f_j(By_j + a * Bd_j)
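        # Each inner pass minimizes a quadratic majorizer of h at the current alf:
        # alf <- alf - h'(alf) / c(alf), where the curvature
        # c(alf) = sum_j sum(curvf[j](v_j) .* abs2.(Bd[j])) upper-bounds h''
        # by the quadratic-majorizer assumption on each f_j,
        # so each inner iteration monotonically decreases h.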
        Bd = [B[j] * dir for j in 1:J]
        alf = 0
        for ii in 1:ninner
            derh = 0 # derivative h'(alf)
            curv = 0 # majorizer curvature c(alf)
            for j in 1:J
                tmp = By[j] + alf * Bd[j]
                derh += real(dot(Bd[j], gradf[j](tmp)))
                curv += sum(curvf[j](tmp) .* abs2.(Bd[j]))
            end
            curv < 0 && error("bug: curv = $curv < 0")
            if curv > 0
                alf = alf - derh / curv
            end
            iszero(alf) && break
        end
        # # derivative of h(a) = cost(x + a * dir) where alpha is real:
        # dh = alf -> real(sum([Bd[j]' * gradf[j](By[j] + alf * Bd[j]) for j in 1:J]))
        # Ldh = sum([Lgf[j] * norm(Bd[j])^2 for j in 1:J]) # Lipschitz constant for dh
        # (alf, ) = gd(dh, Ldh, 0, niter=ninner) # GD-based line search
        # todo
        x = yi + alf * dir
        if iter < niter
            for j in 1:J # update B_j * x
                Bx[j] = By[j] + alf * Bd[j]
            end
        end
        # for j in 1:J # recursive update of B_j * y_i ???
        #     By[j] = (1 - 1/ti) * (By[j] + alf * Bd[j]) + (1/ti) * B0[j]
        # end
        out[iter+1] = fun(x, iter)
    end
    return x, out
end
"""
    (x,out) = ogm_ls(grad, curv, x0; kwargs...)

Special case of `ogm_ls` (OGM with a line search)
for minimizing a cost function
whose gradient is `grad(x)`
and that has a quadratic majorizer with diagonal Hessian given by `curv(x)`.
Typically `curv = (x) -> L` where `L` is the Lipschitz constant of `grad`.
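
# example
A minimal sketch for minimizing ``1/2 \\|x - b\\|^2`` (the vector `b` below is
illustrative); its gradient is `x -> x - b` and its curvature is the constant 1:
```julia
b = [1.0, 2.0] # hypothetical target
xh, _ = ogm_ls(x -> x - b, x -> 1, zeros(2); niter=20)
```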
"""
function ogm_ls(
    grad::Function,
    curv::Function,
    x0::AbstractVector{<:Number} ;
    kwargs...,
)
    return ogm_ls([I], [grad], [curv], x0; kwargs...)
end