function [x, info, options] = stochasticgradient(problem, x, options)
% Stochastic gradient (SG) minimization algorithm for Manopt.
%
% function [x, info, options] = stochasticgradient(problem)
% function [x, info, options] = stochasticgradient(problem, x0)
% function [x, info, options] = stochasticgradient(problem, x0, options)
% function [x, info, options] = stochasticgradient(problem, [], options)
%
% Apply the Riemannian stochastic gradient algorithm to the problem defined
% in the problem structure, starting at x0 if it is provided (otherwise, at
% a random point on the manifold). To specify options whilst not specifying
% an initial guess, give x0 as [] (the empty matrix).
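%
% As an illustration, a minimal call could look as follows. This is a
% sketch only: the manifold, the gradient handle and the option values
% are placeholders to adapt to your problem (my_partial_egrad is a
% hypothetical function name):
%
%   problem.M = spherefactory(n);
%   problem.partialegrad = @(x, K) my_partial_egrad(x, K);
%   problem.ncostterms = N;
%   options.maxiter = 1000;
%   [x, info] = stochasticgradient(problem, [], options);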
%
% The problem structure must contain the following fields:
%
% problem.M:
% Defines the manifold to optimize over, given by a factory.
%
% problem.partialgrad or problem.partialegrad (or equivalent):
% Describes the partial gradients of the cost function. If the cost
% function is of the form f(x) = sum_{k=1}^N f_k(x), then
% partialgrad(x, K) = sum_{k in K} grad f_k(x), and similarly for
% partialegrad with Euclidean gradients.
% As usual, partialgrad must define the Riemannian gradient, whereas
% partialegrad defines a Euclidean (classical) gradient which is
% converted automatically to a Riemannian gradient. Use the tool
% checkgradient(problem) to check it. K is a /row/ vector of indices,
% which makes it natural to loop over its entries with: for k = K, ..., end.
%
% problem.ncostterms:
% An integer specifying how many terms are in the cost function (in
% the example above, that would be N).
%
% Importantly, the cost function itself need not be specified.
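%
% As a concrete (hypothetical) example, for a least-squares cost of the
% form f(x) = sum_{k=1}^N 0.5*||A(:, :, k)*x - b(:, k)||^2 over vectors x,
% a Euclidean partial gradient could be sketched as follows, assuming the
% data A and b are available in the enclosing scope:
%
%   problem.partialegrad = @partialegrad;
%   function g = partialegrad(x, K)
%       g = zeros(size(x));
%       for k = K   % K is a row vector: this loops over its entries
%           g = g + A(:, :, k)' * (A(:, :, k)*x - b(:, k));
%       end
%   end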
%
% Some of the options of the solver are specific to this file. Please have
% a look inside the code.
%
% To record statistics that the algorithm does not require (and hence
% does not compute by default), such as the value of the cost function
% or the norm of the gradient, one can set the following options:
%
% metrics.cost = @(problem, x) getCost(problem, x);
% metrics.gradnorm = @(problem, x) problem.M.norm(x, getGradient(problem, x));
% options.statsfun = statsfunhelper(metrics);
%
% Important caveat: stochastic algorithms usually return an average of the
% last few iterates. Computing averages on manifolds can be expensive.
% Currently, this solver does not compute averages and simply returns the
% last iterate. Using options.statsfun, it is possible for the user to
% compute averages manually. If you have ideas on how to do this
% generically, we welcome feedback. In particular, approximate means could
% be computed with M.pairmean which is available in many geometries.
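%
% As a sketch of one such manual approach (not part of this solver;
% mystatsfun is a hypothetical name), a crude running average could be
% maintained through options.statsfun using M.pairmean, e.g., with a
% nested function so that running_avg persists across calls:
%
%   running_avg = [];
%   options.statsfun = @mystatsfun;
%   function stats = mystatsfun(problem, x, stats)
%       if isempty(running_avg)
%           running_avg = x;
%       else
%           running_avg = problem.M.pairmean(running_avg, x);
%       end
%       stats.avg = running_avg;
%   end
%
% Note that repeated pairwise averaging weighs recent iterates more
% heavily, so this is only a rough proxy for a true average.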
%
% See also: steepestdescent

% This file is part of Manopt: www.manopt.org.
% Original authors: Bamdev Mishra <bamdevm@gmail.com>,
% Hiroyuki Kasai <kasai@is.uec.ac.jp>, and
% Hiroyuki Sato <hsato@ms.kagu.tus.ac.jp>, 22 April 2016.
% Contributors: Nicolas Boumal
% Change log:
% Verify that the problem description is sufficient for the solver.
if ~canGetPartialGradient(problem)
warning('manopt:getPartialGradient', ...
'No partial gradient provided. The algorithm will likely abort.');
end
% Set local defaults
localdefaults.maxiter = 1000; % Maximum number of iterations
localdefaults.batchsize = 1; % Batchsize (# cost terms per iter)
localdefaults.verbosity = 2; % Output verbosity (0, 1 or 2)
localdefaults.storedepth = 20; % Limit amount of caching
% Check stopping criteria and save stats every checkperiod iterations.
localdefaults.checkperiod = 100;
% stepsizefun is a function implementing a step size selection
% algorithm. See that function for help with options, which can be
% specified in the options structure passed to the solver directly.
localdefaults.stepsizefun = @stepsize_sg;
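% As a sketch (my_stepsize is a hypothetical name), a custom rule
% matching the call signature used in the main loop below could be
% supplied as:
%
%   options.stepsizefun = @my_stepsize;
%
%   function [stepsize, newx, newkey, ssstats] = ...
%                my_stepsize(problem, x, d, iter, options, storedb, key)
%       stepsize = 1/(1 + iter);                % decreasing step size
%       newx = problem.M.retr(x, d, -stepsize); % step along -d (d: gradient)
%       newkey = storedb.getNewKey();           % register the new point
%       ssstats = [];                           % no extra stats recorded
%   end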
% Merge global and local defaults, then merge w/ user options, if any.
localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
if ~exist('options', 'var') || isempty(options)
options = struct();
end
options = mergeOptions(localdefaults, options);
assert(options.checkperiod >= 1, ...
'options.checkperiod must be a positive integer (>= 1).');
% If no initial point x is given by the user, generate one at random.
if ~exist('x', 'var') || isempty(x)
x = problem.M.rand();
end
% Create a store database and get a key for the current x
storedb = StoreDB(options.storedepth);
key = storedb.getNewKey();
% Elapsed time for the current set of iterations, where a set of
% iterations comprises options.checkperiod iterations. We do not
% count time spent for such things as logging statistics, as these are
% not relevant to the actual optimization process.
elapsed_time = 0;
% Total number of completed steps
iter = 0;
% Total number of saved stats at this point.
savedstats = 0;
% Collect and save stats in a struct array info, and preallocate.
stats = savestats();
info(1) = stats;
savedstats = savedstats + 1;
if isinf(options.maxiter)
% We trust that if the user set maxiter = inf, then they defined
% another stopping criterion.
preallocate = 1e5;
else
preallocate = ceil(options.maxiter / options.checkperiod) + 1;
end
info(preallocate).iter = [];
% Display information header for the user.
if options.verbosity >= 2
fprintf('    iter   time [s]   step size\n');
end
% Main loop.
stop = false;
while iter < options.maxiter
% Record start time.
start_time = tic();
% Draw the samples uniformly at random, with replacement, as a
% row vector of indices (the documented format for K above).
idx_batch = randi(problem.ncostterms, 1, options.batchsize);
% Compute partial gradient on this batch.
pgrad = getPartialGradient(problem, x, idx_batch, storedb, key);
% Compute a step size and the corresponding new point x.
[stepsize, newx, newkey, ssstats] = ...
options.stepsizefun(problem, x, pgrad, iter, ...
options, storedb, key);
% Make the step: transfer iterate, remove cache from previous x.
storedb.removefirstifdifferent(key, newkey);
x = newx;
key = newkey;
% Make sure we do not use too much memory for the store database.
storedb.purge();
% Total number of completed steps.
iter = iter + 1;
% Elapsed time doing actual optimization work so far in this
% set of options.checkperiod iterations.
elapsed_time = elapsed_time + toc(start_time);
% Check stopping criteria and save stats every checkperiod iters.
if mod(iter, options.checkperiod) == 0
% Log statistics for freshly executed iteration.
stats = savestats();
info(savedstats+1) = stats;
savedstats = savedstats + 1;
% Reset timer.
elapsed_time = 0;
% Print output.
if options.verbosity >= 2
fprintf('%8d %10.2f %.3e\n', ...
iter, stats.time, stepsize);
end
% Run standard stopping criterion checks.
[stop, reason] = stoppingcriterion(problem, x, ...
options, info, savedstats);
if stop
if options.verbosity >= 1
fprintf([reason '\n']);
end
break;
end
end
end
% Keep only the relevant portion of the info struct-array.
info = info(1:savedstats);
% Display a final information message.
if options.verbosity >= 1
if ~stop
% We stopped not because of stoppingcriterion but because the
% loop came to an end, which means maxiter triggered.
msg = 'Max iteration count reached; options.maxiter = %g.\n';
fprintf(msg, options.maxiter);
end
fprintf('Total time is %f [s] (excludes statsfun)\n', ...
info(end).time + elapsed_time);
end
% Helper function to collect statistics to be saved at
% index savedstats+1 in info.
function stats = savestats()
stats.iter = iter;
if savedstats == 0
stats.time = 0;
stats.stepsize = NaN;
stats.stepsize_stats = [];
else
stats.time = info(savedstats).time + elapsed_time;
stats.stepsize = stepsize;
stats.stepsize_stats = ssstats;
end
stats = applyStatsfun(problem, x, storedb, key, options, stats);
end
end