-
Notifications
You must be signed in to change notification settings - Fork 97
/
Copy pathdoubly_stochastic.m
102 lines (90 loc) · 3.05 KB
/
doubly_stochastic.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
function B = doubly_stochastic(A, maxiter, mode, checkperiod)
% Project a matrix to the doubly stochastic matrices (Sinkhorn's algorithm)
%
% function B = doubly_stochastic(A)
% function B = doubly_stochastic(A, maxiter)
% function B = doubly_stochastic(A, [], mode)
% function B = doubly_stochastic(A, maxiter, mode)
% function B = doubly_stochastic(A, maxiter, mode, checkperiod)
%
% Given an element-wise non-negative matrix A of size nxn, returns a
% doubly-stochastic matrix B of size nxn by applying Sinkhorn's algorithm
% to A.
%
% maxiter (optional):
% Strictly positive integer representing the maximum
% number of iterations of Sinkhorn's algorithm.
% The default value of maxiter is n^2.
% mode (optional):
% Setting mode = 1 changes the behavior of the algorithm
% such that the input A is an n x p matrix with AA' having
% element-wise non-negative entries and the output B is also n x p
% such that BB' is a doubly-stochastic matrix. The default value is 0.
% checkperiod (optional):
% Only check stopping criteria every checkperiod iterations,
% to reduce computational burden.
% The file is based on developments in the research paper
% Philip A. Knight, "The Sinkhorn–Knopp Algorithm: Convergence and
% Applications" in SIAM Journal on Matrix Analysis and Applications 30(1),
% 261-275, 2008.
%
% Please cite the Manopt paper as well as the research paper.
% This file is part of Manopt: www.manopt.org.
% Original author: David Young, September 10, 2015.
% Contributors: Ahmed Douik, March 15, 2018.
% Pratik Jawanpuria and Bamdev Mishra, Sep 10, 2019.
% Change log:
% Sep. 10, 2019 (PJ, BM)
% Added the checkperiod parameter.
n = size(A, 1);
tol = eps(n);
if ~exist('maxiter', 'var') || isempty(maxiter)
maxiter = n^2;
end
if ~exist('mode', 'var') || isempty(mode)
mode = 0;
end
if ~exist('checkperiod', 'var') || isempty(checkperiod)
checkperiod = 100;
end
if mode == 1
C = A*A';
else
C = A;
end
iter = 0;
d_1 = 1./sum(C);
d_2 = 1./(C * d_1.');
while iter < maxiter
iter = iter + 1;
row = d_2.' * C;
% Check gap condition only at checkperiod intervals.
% It saves computations for large-scale scenarios.
if mod(iter, checkperiod) == 0
gap = max(abs(row .* d_1 - 1));
if isnan(gap)
break;
end
if gap <= tol
break;
end
end
d_1_prev = d_1;
d_2_prev = d_2;
d_1 = 1./row;
d_2 = 1./(C * d_1.');
if any(isinf(d_2)) || any(isnan(d_2)) || any(isinf(d_1)) || any(isnan(d_1))
warning('DoublyStochasticProjection:NanInfEncountered', ...
'Nan or Inf occured at iter %d. \n', iter);
d_1 = d_1_prev;
d_2 = d_2_prev;
break;
end
end
if mode == 1
v = sqrt(d_2(1)/d_1(1))*d_1;
B = diag(v)*A;
else
B = C .* (d_2 * d_1);
end
end