Skip to content
Permalink
master
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
function B = doubly_stochastic(A, maxiter, mode, checkperiod)
% Project a matrix to the doubly stochastic matrices (Sinkhorn's algorithm)
%
% function B = doubly_stochastic(A)
% function B = doubly_stochastic(A, maxiter)
% function B = doubly_stochastic(A, [], mode)
% function B = doubly_stochastic(A, maxiter, mode)
% function B = doubly_stochastic(A, maxiter, mode, checkperiod)
%
% Given an element-wise non-negative matrix A of size nxn, returns a
% doubly-stochastic matrix B of size nxn by applying Sinkhorn's algorithm
% to A.
%
% maxiter (optional):
% Strictly positive integer representing the maximum
% number of iterations of Sinkhorn's algorithm.
% The default value of maxiter is n^2.
% mode (optional):
% Setting mode = 1 changes the behavior of the algorithm
% such that the input A is an n x p matrix with AA' having
% element-wise non-negative entries and the output B is also n x p
% such that BB' is a doubly-stochastic matrix. The default value is 0.
% checkperiod (optional):
% Only check stopping criteria every checkperiod iterations,
% to reduce computational burden.
% The file is based on developments in the research paper
% Philip A. Knight, "The Sinkhorn–Knopp Algorithm: Convergence and
% Applications" in SIAM Journal on Matrix Analysis and Applications 30(1),
% 261-275, 2008.
%
% Please cite the Manopt paper as well as the research paper.
% This file is part of Manopt: www.manopt.org.
% Original author: David Young, September 10, 2015.
% Contributors: Ahmed Douik, March 15, 2018.
% Pratik Jawanpuria and Bamdev Mishra, Sep 10, 2019.
% Change log:
% Sep. 10, 2019 (PJ, BM)
% Added the checkperiod parameter.
n = size(A, 1);
tol = eps(n);
if ~exist('maxiter', 'var') || isempty(maxiter)
maxiter = n^2;
end
if ~exist('mode', 'var') || isempty(mode)
mode = 0;
end
if ~exist('checkperiod', 'var') || isempty(checkperiod)
checkperiod = 100;
end
if mode == 1
C = A*A';
else
C = A;
end
iter = 0;
d_1 = 1./sum(C);
d_2 = 1./(C * d_1.');
while iter < maxiter
iter = iter + 1;
row = d_2.' * C;
% Check gap condition only at checkperiod intervals.
% It saves computations for large-scale scenarios.
if mod(iter, checkperiod) == 0
gap = max(abs(row .* d_1 - 1));
if isnan(gap)
break;
end
if gap <= tol
break;
end
end
d_1_prev = d_1;
d_2_prev = d_2;
d_1 = 1./row;
d_2 = 1./(C * d_1.');
if any(isinf(d_2)) || any(isnan(d_2)) || any(isinf(d_1)) || any(isnan(d_1))
warning('DoublyStochasticProjection:NanInfEncountered', ...
'Nan or Inf occured at iter %d. \n', iter);
d_1 = d_1_prev;
d_2 = d_2_prev;
break;
end
end
if mode == 1
v = sqrt(d_2(1)/d_1(1))*d_1;
B = diag(v)*A;
else
B = C .* (d_2 * d_1);
end
end