function [x, info, options] = stochasticgradient(problem, x, options)
% Stochastic gradient (SG) minimization algorithm for Manopt.
%
% function [x, info, options] = stochasticgradient(problem)
% function [x, info, options] = stochasticgradient(problem, x0)
% function [x, info, options] = stochasticgradient(problem, x0, options)
% function [x, info, options] = stochasticgradient(problem, [], options)
%
% Apply the Riemannian stochastic gradient algorithm to the problem defined
% in the problem structure, starting at x0 if it is provided (otherwise, at
% a random point on the manifold). To specify options whilst not specifying
% an initial guess, give x0 as [] (the empty matrix).
%
% The problem structure must contain the following fields:
%
% problem.M:
%     Defines the manifold to optimize over, given by a factory.
%
% problem.partialgrad or problem.partialegrad (or equivalent)
%     Describes the partial gradients of the cost function. If the cost
%     function is of the form f(x) = sum_{k=1}^N f_k(x), then
%     partialgrad(x, K) = sum_{k \in K} grad f_k(x), and partialegrad is
%     defined likewise with Euclidean gradients in place of Riemannian ones.
%     As usual, partialgrad must define the Riemannian gradient, whereas
%     partialegrad defines a Euclidean (classical) gradient which will be
%     converted automatically to a Riemannian gradient. Use the tool
%     checkgradient(problem) to check it. K is a /row/ vector of indices,
%     possibly with repetitions (this solver samples with replacement), so
%     it is natural to loop over it with for k = K, ..., end.
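%     For instance, take a least-squares cost with
%     f_k(x) = (a_k'*x - b_k)^2, where a_k' is the k-th row of a matrix A.
%     The sketch below assumes that A and b are hypothetical data and that
%     the manifold is Euclidean. In that case the Euclidean gradient of f_k
%     is 2*(a_k'*x - b_k)*a_k, and a loop-based partialegrad could read:
%
%         problem.partialegrad = @(x, K) lsq_partialegrad(x, K, A, b);
%
%         % Hypothetical helper, e.g. a local function in the calling file.
%         function g = lsq_partialegrad(x, K, A, b)
%             g = zeros(size(x));
%             for k = K   % K is a row vector: k takes each index in turn
%                 g = g + 2*(A(k, :)*x - b(k))*A(k, :)';
%             end
%         end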
%
% problem.ncostterms
%     An integer specifying how many terms are in the cost function (in
%     the example above, that would be N).
%
% Importantly, the cost function itself need not be specified.
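%
% A minimal end-to-end sketch follows. It assumes a Euclidean
% least-squares problem with hypothetical random data A (N-by-n) and
% b (N-by-1), and uses euclideanfactory, the Manopt factory for Euclidean
% space. The vectorized partialegrad sums the gradients of the sampled
% terms, and indexing with A(K, :) counts repeated indices with
% multiplicity:
%
%     N = 1000; n = 5;
%     A = randn(N, n); b = randn(N, 1);           % hypothetical data
%     problem.M = euclideanfactory(n);
%     problem.partialegrad = @(x, K) 2*A(K, :)'*(A(K, :)*x - b(K));
%     problem.ncostterms = N;
%     x = stochasticgradient(problem);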
%
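% This file sets its own defaults for the options maxiter, batchsize,
% verbosity, storedepth, checkperiod and stepsizefun (see localdefaults in
% the code below). The default step size rule, stepsize_sg, reads further
% options of its own; see its help for details.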
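%
% For example, assuming problem is set up as described above, the
% following settings use mini-batches of 10 terms, run 5000 iterations,
% and check stopping criteria (and log statistics) every 500 iterations:
%
%     options.batchsize = 10;
%     options.maxiter = 5000;
%     options.checkperiod = 500;
%     options.verbosity = 1;   % print only the final messages
%     [x, info] = stochasticgradient(problem, [], options);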
%
% To record, for example, the value of the cost function or the norm of
% the gradient (statistics which the algorithm does not require and hence
% does not compute by default), one can set the following options:
%
%     metrics.cost = @(problem, x) getCost(problem, x);
%     metrics.gradnorm = @(problem, x) problem.M.norm(x, getGradient(problem, x));
%     options.statsfun = statsfunhelper(metrics);
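%
% The recorded values then appear as fields of the returned info struct
% array, with one entry for the initial point and one every
% options.checkperiod iterations after that. For example, they can be
% plotted as follows:
%
%     [x, info] = stochasticgradient(problem, [], options);
%     semilogy([info.iter], [info.gradnorm], '.-');
%     xlabel('Iteration'); ylabel('Gradient norm');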
%
% Important caveat: stochastic algorithms usually return an average of the
% last few iterates. Computing averages on manifolds can be expensive.
% Currently, this solver does not compute averages and simply returns the
% last iterate. Using options.statsfun, it is possible for the user to
% compute averages manually. If you have ideas on how to do this
% generically, we welcome feedback. In particular, approximate means could
% be computed with M.pairmean which is available in many geometries.
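%
% As a sketch of that idea (not a built-in feature of this solver), one
% can store the iterate in each stats entry and combine the last two
% saved iterates with pairmean afterwards. This assumes the manifold
% factory provides M.pairmean and that info holds at least two entries:
%
%     options.statsfun = @(problem, x, stats) setfield(stats, 'x', x);
%     [x, info] = stochasticgradient(problem, [], options);
%     xbar = problem.M.pairmean(info(end-1).x, info(end).x);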
%
% See also: steepestdescent

% This file is part of Manopt: www.manopt.org.
% Original authors: Bamdev Mishra <bamdevm@gmail.com>,
%                   Hiroyuki Kasai <kasai@is.uec.ac.jp>, and
%                   Hiroyuki Sato <hsato@ms.kagu.tus.ac.jp>, 22 April 2016.
% Contributors: Nicolas Boumal
% Change log:
%
%   06 July 2019 (BM):
%       Added preconditioner support. This allows the use of adaptive
%       algorithms.

    % Verify that the problem description is sufficient for the solver.
    if ~canGetPartialGradient(problem)
        warning('manopt:getPartialGradient', ...
                'No partial gradient provided. The algorithm will likely abort.');
    end

    % Set local defaults.
    localdefaults.maxiter = 1000;      % Maximum number of iterations
    localdefaults.batchsize = 1;       % Batch size (# cost terms per iter)
    localdefaults.verbosity = 2;       % Output verbosity (0, 1 or 2)
    localdefaults.storedepth = 20;     % Limit amount of caching

    % Check stopping criteria and save stats every checkperiod iterations.
    localdefaults.checkperiod = 100;

    % stepsizefun is a function implementing a step size selection
    % algorithm. See that function for help with options, which can be
    % specified in the options structure passed to the solver directly.
    localdefaults.stepsizefun = @stepsize_sg;
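
    % A custom rule must follow the calling convention used in the main
    % loop below. Here is a hypothetical sketch of a constant step size
    % rule; fixedstep_sg and options.fixedstep are not part of Manopt. It
    % assumes the factory's retr accepts a step length as third argument,
    % as Manopt factories typically do:
    %
    %   function [stepsize, newx, newkey, ssstats] = ...
    %               fixedstep_sg(problem, x, d, iter, options, storedb, key)
    %       stepsize = options.fixedstep;             % user-chosen constant
    %       newx = problem.M.retr(x, d, -stepsize);   % step along -d
    %       newkey = storedb.getNewKey();             % newx is a new point
    %       ssstats = struct();                       % no extra statistics
    %   end
    %
    % Such a rule would be selected with options.stepsizefun = @fixedstep_sg.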

    % Merge global and local defaults, then merge w/ user options, if any.
    localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
    if ~exist('options', 'var') || isempty(options)
        options = struct();
    end
    options = mergeOptions(localdefaults, options);

    assert(options.checkperiod >= 1, ...
           'options.checkperiod must be a positive integer (>= 1).');

    % If no initial point x is given by the user, generate one at random.
    if ~exist('x', 'var') || isempty(x)
        x = problem.M.rand();
    end

    % Create a store database and get a key for the current x.
    storedb = StoreDB(options.storedepth);
    key = storedb.getNewKey();

    % Elapsed time for the current set of iterations, where a set of
    % iterations comprises options.checkperiod iterations. We do not
    % count time spent on things such as logging statistics, as these are
    % not relevant to the actual optimization process.
    elapsed_time = 0;

    % Total number of completed steps.
    iter = 0;

    % Total number of saved stats at this point.
    savedstats = 0;

    % Collect and save stats in a struct array info, and preallocate.
    stats = savestats();
    info(1) = stats;
    savedstats = savedstats + 1;
    if isinf(options.maxiter)
        % We trust that if the user set maxiter = inf, then they defined
        % another stopping criterion.
        preallocate = 1e5;
    else
        preallocate = ceil(options.maxiter / options.checkperiod) + 1;
    end
    info(preallocate).iter = [];
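    % For example, with maxiter = 1000 and checkperiod = 100, stats are
    % saved at iter = 0, 100, 200, ..., 1000, which gives
    % ceil(1000 / 100) + 1 = 11 entries. When checkperiod does not divide
    % maxiter, this count is an upper bound, and unused entries are
    % trimmed after the main loop.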

    % Display information header for the user.
    if options.verbosity >= 2
        fprintf('    iter   time [s] step size\n');
    end

    % Main loop.
    stop = false;
    while iter < options.maxiter

        % Record start time.
        start_time = tic();

        % Draw the samples with replacement.
        idx_batch = randi(problem.ncostterms, options.batchsize, 1);
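        % For instance, with problem.ncostterms = 5 and
        % options.batchsize = 3, idx_batch could be [2; 5; 2]. If the
        % partial gradient sums over the batch as described in the help
        % text, term f_2 then contributes twice.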

        % Compute partial gradient on this batch.
        pgrad = getPartialGradient(problem, x, idx_batch, storedb, key);

        % Apply preconditioner to the partial gradient.
        Ppgrad = getPrecon(problem, x, pgrad, storedb, key);
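        % When problem.precon is not defined, getPrecon falls back to the
        % identity, so Ppgrad equals pgrad. A preconditioner is given as
        % problem.precon = @(x, u) ..., returning a tangent vector at x.
        % For example, on a Euclidean manifold with a hypothetical vector
        % dscale of positive entries, a diagonal scaling would read
        %     problem.precon = @(x, u) u ./ dscale;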

        % Compute a step size and the corresponding new point x.
        [stepsize, newx, newkey, ssstats] = ...
                   options.stepsizefun(problem, x, Ppgrad, iter, ...
                                       options, storedb, key);

        % Make the step: transfer iterate, remove cache from previous x.
        storedb.removefirstifdifferent(key, newkey);
        x = newx;
        key = newkey;

        % Make sure we do not use too much memory for the store database.
        storedb.purge();

        % Total number of completed steps.
        iter = iter + 1;

        % Elapsed time doing actual optimization work so far in this
        % set of options.checkperiod iterations.
        elapsed_time = elapsed_time + toc(start_time);

        % Check stopping criteria and save stats every checkperiod iters.
        if mod(iter, options.checkperiod) == 0

            % Log statistics for freshly executed iteration.
            stats = savestats();
            info(savedstats+1) = stats;
            savedstats = savedstats + 1;

            % Reset timer.
            elapsed_time = 0;

            % Print output.
            if options.verbosity >= 2
                fprintf('%8d %10.2f %.3e\n', ...
                        iter, stats.time, stepsize);
            end

            % Run standard stopping criterion checks.
            [stop, reason] = stoppingcriterion(problem, x, ...
                                               options, info, savedstats);
            if stop
                if options.verbosity >= 1
                    % Print reason as data so that any '%' in it is kept.
                    fprintf('%s\n', reason);
                end
                break;
            end

        end

    end

    % Keep only the relevant portion of the info struct-array.
    info = info(1:savedstats);

    % Display a final information message.
    if options.verbosity >= 1
        if ~stop
            % We stopped not because of stoppingcriterion but because the
            % loop came to an end, which means maxiter triggered.
            msg = 'Max iteration count reached; options.maxiter = %g.\n';
            fprintf(msg, options.maxiter);
        end
        fprintf('Total time is %f [s] (excludes statsfun)\n', ...
                info(end).time + elapsed_time);
    end

    % Helper function to collect statistics to be saved at
    % index savedstats+1 in info.
    function stats = savestats()
        stats.iter = iter;
        if savedstats == 0
            stats.time = 0;
            stats.stepsize = NaN;
            stats.stepsize_stats = [];
        else
            stats.time = info(savedstats).time + elapsed_time;
            stats.stepsize = stepsize;
            stats.stepsize_stats = ssstats;
        end
        stats = applyStatsfun(problem, x, storedb, key, options, stats);
    end

end