function [eta, Heta, hesscalls, stop_str, stats] = arc_conjugate_gradient(problem, x, grad, gradnorm, sigma, options, storedb, key)
% Subproblem solver for ARC based on a nonlinear conjugate gradient method.
%
% [eta, Heta, hesscalls, stop_str, stats] =
% arc_conjugate_gradient(problem, x, grad, gradnorm, sigma, options, storedb, key)
%
% This routine approximately solves the following problem:
%
% min_{eta in T_x M} m(eta), where
%
% m(eta) = <eta, g> + .5 <eta, H[eta]> + (sigma/3) ||eta||^3
%
% where eta is a tangent vector at x on the manifold given by problem.M,
% g = grad is a tangent vector at x, H[eta] is the result of applying the
% Hessian of the problem at x along eta, and the inner product and norm
% are those from the Riemannian structure on the tangent space T_x M.
%
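% For reference, the Riemannian gradient of the model at eta is
%
%     grad m(eta) = g + H[eta] + sigma*||eta||*eta,
%
% so that grad m(0) = g; this is the quantity controlled by the stopping
% criterion described next.
%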
% The solve is approximate in the sense that the returned eta need only
% satisfy the following conditions:
%
% ||gradient of m at eta|| <= theta*||eta||^2 and m(eta) <= m(0),
%
% where theta is specified in options.theta (see below for the default value).
% Since the gradient of the model at 0 is g, if it is zero, then eta = 0
% is returned. This is the only scenario where eta = 0 is returned.
%
% Numerical errors can perturb the described expected behavior.
%
% Inputs:
% problem: Manopt optimization problem structure
% x: point on the manifold problem.M
% grad: gradient of the cost function of the problem at x
% gradnorm: norm of the gradient, often available to the caller
% sigma: cubic regularization parameter (positive scalar)
% options: structure containing options for the subproblem solver
% storedb, key: caching data for problem at x
%
% Options specific to this subproblem solver:
% theta (0.25)
% Stopping criterion parameter for subproblem solver: the gradient of
% the model at the returned step should have norm no more than theta
% times the squared norm of the step.
% maxinner (the manifold's dimension)
% Maximum number of iterations of the conjugate gradient algorithm.
% beta_type ('P-R')
% The update rule for calculating beta:
% 'F-R' for Fletcher-Reeves, 'P-R' for Polak-Ribiere, and 'H-S' for
% Hestenes-Stiefel.
% subproblemstop ('sqrule')
% Stopping rule for the subsolver: 'sqrule' is the criterion stated
% above; 'grule', 'srule' and 'ssrule' are alternative rules whose exact
% formulas appear in the termination check in the code below.
%
% Outputs:
% eta: approximate solution to the cubic regularized subproblem at x
% Heta: Hess f(x)[eta] -- this is necessary in the outer loop, and it
% is often naturally available to the subproblem solver at the
% end of execution, so that it may be cheaper to return it here.
% hesscalls: number of Hessian calls during execution
% stop_str: string describing why the subsolver stopped
% stats: a structure containing some statistics about the inner work:
% the model cost value and model gradient norm are recorded at each
% inner iteration.
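%
% Usage sketch (hedged): this subsolver is normally not called directly;
% it is meant to be selected when calling the arc solver. The snippet
% below assumes a fully specified Manopt problem structure and the
% standard arc interface with an options.subproblemsolver field.
%
%     opts = struct();
%     opts.subproblemsolver = @arc_conjugate_gradient;
%     opts.theta = 0.5;         % looser model-gradient tolerance
%     opts.beta_type = 'H-S';   % Hestenes-Stiefel updates
%     [xopt, xcost, info] = arc(problem, [], opts);
%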
% This file is part of Manopt: www.manopt.org.
% Original authors: May 2, 2019,
% Bryan Zhu, Nicolas Boumal.
% Contributors:
% Change log:
%
% Aug. 19, 2019 (NB):
% Option maxiter_cg renamed to maxinner to match trustregions.
% TODO: Support preconditioning through getPrecon().
% Some shortcuts
M = problem.M;
n = M.dim();
inner = @(u, v) M.inner(x, u, v);
rnorm = @(u) M.norm(x, u);
tangent = @(u) problem.M.tangent(x, u);
Hess = @(u) getHessian(problem, x, u, storedb, key);
% Counter for Hessian calls issued
hesscalls = 0;
% If the gradient has norm zero, return a zero step
if gradnorm == 0
eta = M.zerovec(x);
Heta = eta;
stop_str = 'Cost gradient is zero';
stats = struct('gradnorms', 0, 'func_values', 0);
return;
end
% Set local defaults here
localdefaults.theta = 0.25;
localdefaults.maxinner = n;
localdefaults.beta_type = 'P-R';
localdefaults.subproblemstop = 'sqrule';
% Merge local defaults with user options, if any
if ~exist('options', 'var') || isempty(options)
options = struct();
end
options = mergeOptions(localdefaults, options);
% Calculate the Cauchy point as our initial step
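% The Cauchy point minimizes the model along the steepest descent
% direction -grad: with eta = -(R/||g||)*g, the model reduces to the
% scalar function
%     q(R) = -R*||g|| + (R^2/2)*<g, H[g]>/||g||^2 + (sigma/3)*R^3.
% Setting q'(R) = 0 and taking the positive root of the resulting
% quadratic in R gives R_c below, with temp = <g, H[g]>/(2*sigma*||g||^2).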
hess_grad = Hess(grad);
hesscalls = hesscalls + 1;
temp = inner(grad, hess_grad) / (2 * sigma * gradnorm * gradnorm);
R_c = -temp + sqrt(temp * temp + gradnorm / sigma);
eta = M.lincomb(x, -R_c / gradnorm, grad);
Heta = M.lincomb(x, -R_c / gradnorm, hess_grad);
% Initialize variables needed for calculation of conjugate direction
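% Convention: prev_grad and (later) new_grad store the *negative* model
% gradient; prev_conj is the current conjugate direction and Hp_conj is
% the Hessian applied to it, which lets Heta be updated by linearity
% without issuing extra Hessian calls on eta itself.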
prev_grad = M.lincomb(x, -1, grad);
prev_conj = prev_grad;
Hp_conj = M.lincomb(x, -1, hess_grad);
% Main conjugate gradients iteration
maxiter = min(options.maxinner, 3*n);
gradnorms = zeros(maxiter, 1);
func_values = zeros(maxiter, 1);
gradnorm_reached = false;
j = 1;
while j < maxiter
% Calculate the negative gradient of the model at the current eta
eta_norm = rnorm(eta);
new_grad = M.lincomb(x, 1, Heta, 1, grad);
new_grad = M.lincomb(x, -1, new_grad, -sigma * eta_norm, eta);
new_grad = tangent(new_grad);
% Compute some statistics
gradnorms(j) = rnorm(new_grad);
func_values(j) = inner(grad, eta) + 0.5 * inner(eta, Heta) + (sigma/3) * eta_norm^3;
if options.verbosity >= 4
fprintf('\nModel grad norm: %.16e, Iterate norm: %.16e', gradnorms(j), eta_norm);
end
% Check termination condition
% TODO -- factor this out, as it is the same for all subsolvers.
% TODO -- I haven't found a scenario where sqrule doesn't win.
% TODO -- 1e-4 might be too small (g, s, ss seem equivalent).
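% 'sqrule' is the criterion documented in the help above. The other
% rules compare the current model gradient norm to the one recorded at
% the first inner iteration, scaled by a factor involving that first
% norm ('grule'), the step norm ('srule'), or the step norm divided by
% sigma ('ssrule').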
switch lower(options.subproblemstop)
case 'sqrule'
stop = (gradnorms(j) <= options.theta * eta_norm^2);
case 'grule'
stop = (gradnorms(j) <= min(1e-4, sqrt(gradnorms(1)))*gradnorms(1));
case 'srule'
stop = (gradnorms(j) <= min(1e-4, eta_norm)*gradnorms(1));
case 'ssrule'
stop = (gradnorms(j) <= min(1e-4, eta_norm/max(1, sigma))*gradnorms(1));
otherwise
error(['Unknown value for options.subproblemstop.\n' ...
'Possible values: ''sqrule'', ''grule'', ' ...
'''srule'', ''ssrule''.\n']); % ...
% 'Current value: ', options.subproblemstop, '\n']);
end
if stop
stop_str = 'Model grad norm condition satisfied';
gradnorm_reached = true;
break;
end
% Calculate the conjugate direction using the selected beta rule
delta = M.lincomb(x, 1, new_grad, -1, prev_grad);
switch upper(options.beta_type)
case 'F-R'
beta = inner(new_grad, new_grad) / inner(prev_grad, prev_grad);
case 'P-R'
beta = max(0, inner(new_grad, delta) / inner(prev_grad, prev_grad));
case 'H-S'
beta = max(0, -inner(new_grad, delta) / inner(prev_conj, delta));
otherwise
error('Unknown options.beta_type. Should be F-R, P-R, or H-S.');
end
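% The max(0, .) in the P-R and H-S rules acts as an automatic restart:
% whenever beta would be negative, the search direction resets to the
% (negative) model gradient, a standard safeguard in nonlinear CG.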
new_conj = M.lincomb(x, 1, new_grad, beta, prev_conj);
Hn_grad = Hess(new_grad);
hesscalls = hesscalls + 1;
Hn_conj = M.lincomb(x, 1, Hn_grad, beta, Hp_conj);
% Find the optimal step in the conjugate direction
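% solve_along_line minimizes the cubic model along eta + alpha*new_conj
% over the scalar alpha; because of the (sigma/3)*||.||^3 term, the
% stationarity condition in alpha is a nonlinear scalar equation rather
% than the linear one of classical CG.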
alpha = solve_along_line(M, x, eta, new_conj, grad, Hn_conj, sigma);
if alpha == 0
stop_str = 'Unable to make further progress in search direction';
gradnorm_reached = true;
break;
end
eta = M.lincomb(x, 1, eta, alpha, new_conj);
Heta = M.lincomb(x, 1, Heta, alpha, Hn_conj);
prev_grad = new_grad;
prev_conj = new_conj;
Hp_conj = Hn_conj;
j = j + 1;
end
% Check why we stopped iterating
if ~gradnorm_reached
stop_str = sprintf(['Reached max number of conjugate gradient iterations ' ...
'(options.maxinner = %d)'], options.maxinner);
j = j - 1;
end
% Return the point we ended on
eta = tangent(eta);
stats = struct('gradnorms', gradnorms(1:j), 'func_values', func_values(1:j));
if options.verbosity >= 4
fprintf('\n');
end
end