-
Notifications
You must be signed in to change notification settings - Fork 97
/
Copy pathstochasticgradient.m
245 lines (198 loc) · 9.07 KB
/
stochasticgradient.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
function [x, info, options] = stochasticgradient(problem, x, options)
% Stochastic gradient (SG) minimization algorithm for Manopt.
%
% function [x, info, options] = stochasticgradient(problem)
% function [x, info, options] = stochasticgradient(problem, x0)
% function [x, info, options] = stochasticgradient(problem, x0, options)
% function [x, info, options] = stochasticgradient(problem, [], options)
%
% Apply the Riemannian stochastic gradient algorithm to the problem defined
% in the problem structure, starting at x0 if it is provided (otherwise, at
% a random point on the manifold). To specify options whilst not specifying
% an initial guess, give x0 as [] (the empty matrix).
%
% The problem structure must contain the following fields:
%
%  problem.M:
%   Defines the manifold to optimize over, given by a factory.
%
%  problem.partialgrad or problem.partialegrad (or equivalent)
%   Describes the partial gradients of the cost function. If the cost
%   function is of the form f(x) = sum_{k=1}^N f_k(x),
%   then partialegrad(x, K) = sum_{k \in K} grad f_k(x).
%   As usual, partialgrad must define the Riemannian gradient, whereas
%   partialegrad defines a Euclidean (classical) gradient which will be
%   converted automatically to a Riemannian gradient. Use the tool
%   checkgradient(problem) to check it. K is a /row/ vector, which
%   makes it natural to write for k = K, ..., end.
%
%  problem.ncostterms
%   An integer specifying how many terms are in the cost function (in
%   the example above, that would be N.)
%
% Importantly, the cost function itself needs not be specified.
%
% Some of the options of the solver are specific to this file. Please have
% a look inside the code.
%
% To record the value of the cost function or the norm of the gradient for
% example (which are statistics the algorithm does not require and hence
% does not compute by default), one can set the following options:
%
%   metrics.cost = @(problem, x) getCost(problem, x);
%   metrics.gradnorm = @(problem, x) problem.M.norm(x, getGradient(problem, x));
%   options.statsfun = statsfunhelper(metrics);
%
% Important caveat: stochastic algorithms usually return an average of the
% last few iterates. Computing averages on manifolds can be expensive.
% Currently, this solver does not compute averages and simply returns the
% last iterate. Using options.statsfun, it is possible for the user to
% compute averages manually. If you have ideas on how to do this
% generically, we welcome feedback. In particular, approximate means could
% be computed with M.pairmean which is available in many geometries.
%
% See also: steepestdescent

% This file is part of Manopt: www.manopt.org.
% Original authors: Bamdev Mishra <bamdevm@gmail.com>,
%                   Hiroyuki Kasai <kasai@is.uec.ac.jp>, and
%                   Hiroyuki Sato <hsato@ms.kagu.tus.ac.jp>, 22 April 2016.
% Contributors: Nicolas Boumal
% Change log:
%
%   06 July 2019 (BM):
%       Added preconditioner support. This allows to use adaptive algorithms.

    % Verify that the problem description is sufficient for the solver.
    if ~canGetPartialGradient(problem)
        warning('manopt:getPartialGradient', ...
                'No partial gradient provided. The algorithm will likely abort.');
    end

    % Set local defaults.
    localdefaults.maxiter = 1000;     % Maximum number of iterations
    localdefaults.batchsize = 1;      % Batchsize (# cost terms per iter)
    localdefaults.verbosity = 2;      % Output verbosity (0, 1 or 2)
    localdefaults.storedepth = 20;    % Limit amount of caching

    % Check stopping criteria and save stats every checkperiod iterations.
    localdefaults.checkperiod = 100;

    % stepsizefun is a function implementing a step size selection
    % algorithm. See that function for help with options, which can be
    % specified in the options structure passed to the solver directly.
    localdefaults.stepsizefun = @stepsize_sg;

    % Merge global and local defaults, then merge w/ user options, if any.
    localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
    if ~exist('options', 'var') || isempty(options)
        options = struct();
    end
    options = mergeOptions(localdefaults, options);

    assert(options.checkperiod >= 1, ...
           'options.checkperiod must be a positive integer (>= 1).');

    % If no initial point x is given by the user, generate one at random.
    if ~exist('x', 'var') || isempty(x)
        x = problem.M.rand();
    end

    % Create a store database and get a key for the current x.
    storedb = StoreDB(options.storedepth);
    key = storedb.getNewKey();

    % Elapsed time for the current set of iterations, where a set of
    % iterations comprises options.checkperiod iterations. We do not
    % count time spent for such things as logging statistics, as these are
    % not relevant to the actual optimization process.
    elapsed_time = 0;

    % Total number of completed steps.
    iter = 0;

    % Total number of saved stats at this point.
    savedstats = 0;

    % Collect and save stats in a struct array info, and preallocate.
    stats = savestats();
    info(1) = stats;
    savedstats = savedstats + 1;
    if isinf(options.maxiter)
        % We trust that if the user set maxiter = inf, then they defined
        % another stopping criterion.
        preallocate = 1e5;
    else
        preallocate = ceil(options.maxiter / options.checkperiod) + 1;
    end
    info(preallocate).iter = [];

    % Display information header for the user.
    if options.verbosity >= 2
        fprintf('    iter       time [s]       step size\n');
    end

    % Main loop.
    stop = false;
    while iter < options.maxiter

        % Record start time.
        start_time = tic();

        % Draw the samples with replacement. Per the documented contract
        % above, the index set passed to partialgrad/partialegrad must be
        % a /row/ vector (so that user code such as "for k = K" iterates
        % over the individual indices rather than over the whole vector).
        idx_batch = randi(problem.ncostterms, 1, options.batchsize);

        % Compute partial gradient on this batch.
        pgrad = getPartialGradient(problem, x, idx_batch, storedb, key);

        % Apply preconditioner to the partial gradient.
        Ppgrad = getPrecon(problem, x, pgrad, storedb, key);

        % Compute a step size and the corresponding new point x.
        [stepsize, newx, newkey, ssstats] = ...
                           options.stepsizefun(problem, x, Ppgrad, iter, ...
                                               options, storedb, key);

        % Make the step: transfer iterate, remove cache from previous x.
        storedb.removefirstifdifferent(key, newkey);
        x = newx;
        key = newkey;

        % Make sure we do not use too much memory for the store database.
        storedb.purge();

        % Total number of completed steps.
        iter = iter + 1;

        % Elapsed time doing actual optimization work so far in this
        % set of options.checkperiod iterations.
        elapsed_time = elapsed_time + toc(start_time);

        % Check stopping criteria and save stats every checkperiod iters.
        if mod(iter, options.checkperiod) == 0

            % Log statistics for freshly executed iteration.
            stats = savestats();
            info(savedstats+1) = stats;
            savedstats = savedstats + 1;

            % Reset timer.
            elapsed_time = 0;

            % Print output.
            if options.verbosity >= 2
                fprintf('%8d     %10.2f       %.3e\n', ...
                        iter, stats.time, stepsize);
            end

            % Run standard stopping criterion checks.
            [stop, reason] = stoppingcriterion(problem, x, ...
                                               options, info, savedstats);
            if stop
                if options.verbosity >= 1
                    % Print the reason verbatim: it may contain characters
                    % (such as '%') that must not be treated as a format.
                    fprintf('%s\n', reason);
                end
                break;
            end

        end

    end

    % Keep only the relevant portion of the info struct-array.
    info = info(1:savedstats);

    % Display a final information message.
    if options.verbosity >= 1
        if ~stop
            % We stopped not because of stoppingcriterion but because the
            % loop came to an end, which means maxiter triggered.
            msg = 'Max iteration count reached; options.maxiter = %g.\n';
            fprintf(msg, options.maxiter);
        end
        fprintf('Total time is %f [s] (excludes statsfun)\n', ...
                info(end).time + elapsed_time);
    end

    % Helper function to collect statistics to be saved at
    % index checkperiodcount+1 in info. As a nested function, it reads
    % iter, savedstats, info, elapsed_time, stepsize and ssstats directly
    % from the parent workspace.
    function stats = savestats()
        stats.iter = iter;
        if savedstats == 0
            % Before the first step: no step size has been computed yet.
            stats.time = 0;
            stats.stepsize = NaN;
            stats.stepsize_stats = [];
        else
            % Times are cumulative across checkperiod windows.
            stats.time = info(savedstats).time + elapsed_time;
            stats.stepsize = stepsize;
            stats.stepsize_stats = ssstats;
        end
        stats = applyStatsfun(problem, x, storedb, key, options, stats);
    end

end