Permalink
Fetching contributors…
Cannot retrieve contributors at this time
286 lines (248 sloc) 9.74 KB
function [x, cost, info, options] = neldermead(problem, x, options)
% Nelder Mead optimization algorithm for derivative-free minimization.
%
% function [x, cost, info, options] = neldermead(problem)
% function [x, cost, info, options] = neldermead(problem, x0)
% function [x, cost, info, options] = neldermead(problem, x0, options)
% function [x, cost, info, options] = neldermead(problem, [], options)
%
% Apply a Nelder-Mead minimization algorithm to the problem defined in
% the problem structure, starting with the population x0 if it is provided
% (otherwise, a random population on the manifold is generated). A
% population is a cell containing points on the manifold. The number of
% elements in the cell must be dim+1, where dim is the dimension of the
% manifold: problem.M.dim().
%
% To specify options whilst not specifying an initial guess, give x0 as []
% (the empty matrix).
%
% This algorithm is a plain adaptation of the Euclidean Nelder-Mead method
% to the Riemannian setting. It comes with no convergence guarantees and
% there is room for improvement. In particular, we compute centroids as
% Karcher means, which seems overly expensive: cheaper forms of
% average-like quantities might work better.
% This solver is useful nonetheless for problems for which no derivatives
% are available, and it may constitute a starting point for the development
% of other Riemannian derivative-free methods.
%
% None of the options are mandatory. See in code for details.
%
% Requires problem.M.pairmean(x, y) to be defined (computes the average
% between two points, x and y).
%
% If options.statsfun is defined, it will receive a cell of points x (the
% current simplex being considered at that iteration), and, if required,
% one store structure corresponding to the best point, x{1}. The points are
% ordered by increasing cost: f(x{1}) <= f(x{2}) <= ... <= f(x{dim+1}),
% where dim = problem.M.dim().
%
% Based on http://www.optimization-online.org/DB_FILE/2007/08/1742.pdf.
%
% See also: manopt/solvers/pso/pso
% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%
% Apr. 4, 2015 (NB):
% Working with the new StoreDB class system.
% Clarified interactions with statsfun and store.
%
% Nov. 11, 2016 (NB):
% If options.verbosity is < 2, prints minimal output.
%
% Sep. 6, 2018 (NB):
% Using retraction instead of exponential.
% Verify that the problem description is sufficient for the solver.
if ~canGetCost(problem)
warning('manopt:getCost', ...
'No cost provided. The algorithm will likely abort.');
end
% Dimension of the manifold
dim = problem.M.dim();
% Set local defaults here
localdefaults.storedepth = 0; % no need for caching
localdefaults.maxiter = max(2000, 4*dim);
localdefaults.reflection = 1;
localdefaults.expansion = 2;
localdefaults.contraction = .5;
% forced to .5 to enable using pairmean functions in manifolds.
% localdefaults.shrinkage = .5;
% Merge global and local defaults, then merge w/ user options, if any.
localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
if ~exist('options', 'var') || isempty(options)
options = struct();
end
options = mergeOptions(localdefaults, options);
% Start timing for initialization.
timetic = tic();
% If no initial simplex x is given by the user, generate one at random.
if ~exist('x', 'var') || isempty(x)
x = cell(dim+1, 1);
for i = 1 : dim+1
x{i} = problem.M.rand();
end
end
% Create a store database and a key for each point.
storedb = StoreDB(options.storedepth);
key = cell(size(x));
for i = 1 : dim+1;
key{i} = storedb.getNewKey();
end
% Compute objective-related quantities for x, and setup a
% function evaluations counter.
costs = zeros(dim+1, 1);
for i = 1 : dim+1
costs(i) = getCost(problem, x{i}, storedb, key{i});
end
costevals = dim+1;
% Sort simplex points by cost.
[costs, order] = sort(costs);
x = x(order);
key = key(order);
% Iteration counter.
% At any point, iter is the number of fully executed iterations so far.
iter = 0;
% Save stats in a struct array info, and preallocate.
% savestats will be called twice for the initial iterate (number 0),
% which is unfortunate, but not problematic.
stats = savestats();
info(1) = stats;
info(min(10000, options.maxiter+1)).iter = [];
% Start iterating until stopping criterion triggers.
while true
% Make sure we don't use to much memory for the store database.
storedb.purge();
stats = savestats();
info(iter+1) = stats; %#ok<AGROW>
iter = iter + 1;
% Start timing this iteration.
timetic = tic();
% Sort simplex points by cost.
[costs, order] = sort(costs);
x = x(order);
key = key(order);
% Log / display iteration information here.
if options.verbosity >= 2
fprintf('Cost evals: %7d\tBest cost: %+.4e\t', ...
costevals, costs(1));
end
% Run standard stopping criterion checks.
[stop, reason] = stoppingcriterion(problem, x, options, info, iter);
if stop
if options.verbosity >= 1
fprintf([reason '\n']);
end
break;
end
% Compute a centroid for the dim best points.
xbar = centroid(problem.M, x(1:end-1));
% Compute the direction for moving along the axis xbar - worst x.
vec = problem.M.log(xbar, x{end});
% Reflection step
xr = problem.M.retr(xbar, vec, -options.reflection);
keyr = storedb.getNewKey();
costr = getCost(problem, xr, storedb, keyr);
costevals = costevals + 1;
% If the reflected point is honorable, drop the worst point,
% replace it by the reflected point and start new iteration.
if costr >= costs(1) && costr < costs(end-1)
if options.verbosity >= 2
fprintf('Reflection\n');
end
costs(end) = costr;
x{end} = xr;
key{end} = keyr;
continue;
end
% If the reflected point is better than the best point, expand.
if costr < costs(1)
xe = problem.M.retr(xbar, vec, -options.expansion);
keye = storedb.getNewKey();
coste = getCost(problem, xe, storedb, keye);
costevals = costevals + 1;
if coste < costr
if options.verbosity >= 2
fprintf('Expansion\n');
end
costs(end) = coste;
x{end} = xe;
key{end} = keye;
continue;
else
if options.verbosity >= 2
fprintf('Reflection (failed expansion)\n');
end
costs(end) = costr;
x{end} = xr;
key{end} = keyr;
continue;
end
end
% If the reflected point is worse than the second to worst point,
% contract.
if costr >= costs(end-1)
if costr < costs(end)
% do an outside contraction
xoc = problem.M.retr(xbar, vec, -options.contraction);
keyoc = storedb.getNewKey();
costoc = getCost(problem, xoc, storedb, keyoc);
costevals = costevals + 1;
if costoc <= costr
if options.verbosity >= 2
fprintf('Outside contraction\n');
end
costs(end) = costoc;
x{end} = xoc;
key{end} = keyoc;
continue;
end
else
% do an inside contraction
xic = problem.M.retr(xbar, vec, options.contraction);
keyic = storedb.getNewKey();
costic = getCost(problem, xic, storedb, keyic);
costevals = costevals + 1;
if costic <= costs(end)
if options.verbosity >= 2
fprintf('Inside contraction\n');
end
costs(end) = costic;
x{end} = xic;
key{end} = keyic;
continue;
end
end
end
% If we get here, shrink the simplex around x{1}.
if options.verbosity >= 2
fprintf('Shrinkage\n');
end
for i = 2 : dim+1
x{i} = problem.M.pairmean(x{1}, x{i});
key{i} = storedb.getNewKey();
costs(i) = getCost(problem, x{i}, storedb, key{i});
end
costevals = costevals + dim;
end
info = info(1:iter);
% Iteration done: return only the best point found.
cost = costs(1);
x = x{1};
key = key{1};
% Routine in charge of collecting the current iteration stats.
function stats = savestats()
stats.iter = iter;
stats.cost = costs(1);
stats.costevals = costevals;
if iter == 0
stats.time = toc(timetic);
else
stats.time = info(iter).time + toc(timetic);
end
% The statsfun can only possibly receive one store structure. We
% pass the key to the best point, so that the best point's store
% will be passed. But the whole cell x of points is passed through.
stats = applyStatsfun(problem, x, storedb, key{1}, options, stats);
end
end