-
Notifications
You must be signed in to change notification settings - Fork 96
/
low_rank_matrix_completion.m
218 lines (188 loc) · 8.55 KB
/
low_rank_matrix_completion.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
function low_rank_matrix_completion()
% Given partial observation of a low rank matrix, attempts to complete it.
%
% function low_rank_matrix_completion()
%
% This example uses low-rank matrix completion to illustrate three
% different ways of optimizing under a rank constraint with Manopt.
% (The rank bound written k in this header is the variable r in the code.)
%
% Method 1: via (L, R) -> L*R.' parameterization (rank <= k)
% with burermonteiroLRlift and manoptlift
% Method 2: with desingularizationfactory (rank <= k)
% Method 3: with fixedrankembeddedfactory (rank == k)
%
% All three methods share the same cost, gradient and Hessian code (the
% nested functions cost, grad and hess below), the same initial guess and
% the same solver options, and are each run with trustregions.
%
% Input: None. This example file generates random data.
%
% Output: None. The results are displayed in figure 1 (cost versus time,
% and number of sparse-entry evaluations versus time).
%
% See also: fixedrankembeddedfactory desingularizationfactory
% manoptlift burermonteiroLRlift euclideanlargefactory
% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, July 15, 2014
% Contributors: Bart Vandereycken
% Change log:
%
% Xiaowen Jiang, Aug. 20, 2021
% Added AD to compute the egrad and the ehess
% NB, June 26, 2024
% Rewrote in full to include desingularization, lifts, large matrices
%% Create a problem instance
% We optimize over matrices of size m-by-n with rank <= r (or = r)
m = 1000;
n = 1000;
r = 10; % try setting this to 11 for example
% The structure Rmn provides functions to operate on large matrices.
% We use this for all operations on matrices X and Xdot that exist
% (mathematically at least) in R^(m x n). Under the hood, we avoid ever
% creating these matrices explicitly: they are stored in one of various
% efficient formats (sparse; factored; as functions), and this is
% exploited for all computations.
Rmn = euclideanlargefactory(m, n);
% Select some of the m*n entries uniformly at random.
% The dimension of the set of matrices of size m-by-n with rank r is
% r(m+n-r). We aim to sample a multiple of that. This multiplier is
% the oversampling factor (osf). M is sparse with 0s and 1s.
% random_mask also returns the oversampling factor it actually achieved,
% which is the value reported in the plot title further down.
desired_osf = 5;
[~, ~, M, actual_osf] = random_mask(m, n, r, desired_osf);
% Generate a ground-truth matrix of rank rtrue (potentially different
% from r) and compute the selected entries of that matrix, in atrue.
% The ground truth is stored in factored form (fields U, S, V) with
% singular values drawn in [0.5, 1]; it is never formed as a full matrix.
rtrue = 10;
Astar.U = stiefelfactory(m, rtrue).rand();
Astar.V = stiefelfactory(n, rtrue).rand();
Astar.S = diag(.5 + .5*rand(rtrue, 1)); % singular values of target A;
atrue = Rmn.sparseentries(M, Astar); % efficient sampling of entries
%% Create a problem structure defining the problem in R^(mxn)
% The problem structure 'downstairs' defines the problem over all mxn
% matrices, without rank constraint and with efficient use of sparsity.
downstairs.M = Rmn;
downstairs.cost = @cost;
downstairs.grad = @grad;
downstairs.hess = @hess;
% prepare: shared helper for cost and grad. It computes the residue at the
% observed entries only if the store does not hold it yet, so that cost
% and grad called with the same store evaluate the sparse entries once.
% Each actual evaluation is recorded in the 'sparseentries' counter.
function store = prepare(X, store)
if ~isfield(store, 'residue')
% Compute the possibly nonzero entries of M.*(X - A)
store.residue = Rmn.sparseentries(M, X) - atrue;
store = incrementcounter(store, 'sparseentries');
end
end
% cost: half the squared norm of the residue at the observed entries.
function [f, store] = cost(X, store)
store = prepare(X, store);
% f(X) = .5*norm(M.*(X - A), 'fro')^2
f = .5*norm(store.residue)^2;
end
% grad: gradient of the cost in R^(mxn), built from the cached residue.
% NOTE(review): replacesparseentries presumably returns a sparse matrix
% with the sparsity pattern of M and the given values -- confirm against
% the Manopt documentation.
function [g, store] = grad(X, store)
store = prepare(X, store);
% nabla f(X) = M.*(X - A)
g = replacesparseentries(M, store.residue);
end
% hess: Hessian of the cost in R^(mxn) applied to Xdot. The cost is
% quadratic, so the result does not depend on X (hence the unused-input
% pragma on the signature).
function [h, store] = hess(X, Xdot, store) %#ok<INUSD>
% nabla^2 f(X)[Xdot] = M.*Xdot
MXdot = Rmn.sparseentries(M, Xdot);
h = replacesparseentries(M, MXdot);
% Increment by 2 because Xdot can have rank up to 2r
store = incrementcounter(store, 'sparseentries', 2);
end
% Whenever we change f, it is a good idea to check the derivatives.
% checkgradient(downstairs);
% checkhessian(downstairs);
%% Define options for the manopt solvers
% The same options structure is passed to all three solver runs below.
% The 'sparseentries' counter is logged at each iteration through statsfun,
% which is what makes [info.sparseentries] available for the second plot.
options = struct();
stats = statscounters({'sparseentries'});
options.statsfun = statsfunhelper(stats);
options.theta = 0.4; % this is a parameter in trs_tCG_cached
options.maxtime = 30; % stop if we take more than maxtime seconds
% NOTE(review): a zero tolerance presumably disables the gradient-norm
% stopping criterion, leaving tolcost and maxtime to end the runs --
% confirm against the Manopt stopping criteria.
options.tolgradnorm = 0;
options.tolcost = 1e-12; % stop when f(X) is close to 0
%% Build an initial guess
% Random factors with small singular values (uniform in [0, 1/1000]).
% The same initial point is used by all three methods: in USV format for
% the desingularization and fixed-rank runs, and in LR format for the lift.
X0_USV.U = stiefelfactory(m, r).rand();
X0_USV.V = stiefelfactory(n, r).rand();
X0_USV.S = diag(rand(r, 1))/1000;
X0_LR = Rmn.to_LR(X0_USV); % the same X0 in LR format
%% Lift the downstairs problem through the LR' parameterization
% Method 1: manoptlift builds the 'upstairs' problem from the downstairs
% problem and the lift; the solver then runs on the upstairs problem.
lift = burermonteiroLRlift(m, n, r);
upstairs = manoptlift(downstairs, lift);
[X_LR, ~, LR_info] = trustregions(upstairs, X0_LR, options); %#ok<ASGLU>
%% Lift the downstairs problem through the desingularization
% Method 2: the cost and gradient functions are reused as is (as cost and
% egrad); only the Hessian needs a wrapper to convert the tangent vector.
desing.M = desingularizationfactory(m, n, r);
desing.cost = @cost;
desing.egrad = @grad;
desing.ehess = @ehess_xp;
function [h, store] = ehess_xp(X, Xdot, store)
% Here we compute the Euclidean Hessian.
% The inputs are a point X and a tangent vector Xdot.
% The latter is in tangent vector format, but we need it in ambient
% vector format. We need to do two things here:
% 1. Map Xdot (a tangent vector) to the ambient space,
% using M.tangent2ambient.
% 2. Extract the X component of the result, because ambient
% vectors in the desingularization have both an X and a P
% component, but only the X part is relevant to us.
Xdot_ambient = desing.M.tangent2ambient(X, Xdot).X;
[h, store] = hess(X, Xdot_ambient, store);
end
[X_XP, ~, XP_info] = trustregions(desing, X0_USV, options); %#ok<ASGLU>
%% Restrict the downstairs problem to matrices of rank exactly r
% Method 3: same structure as for the desingularization, with the
% fixed-rank manifold and its own Hessian wrapper.
fixedrk.M = fixedrankembeddedfactory(m, n, r);
fixedrk.cost = @cost;
fixedrk.egrad = @grad;
fixedrk.ehess = @ehess_fr;
function [h, store] = ehess_fr(X, Xdot, store)
% Same as in ehess_xp, we must first convert Xdot
% from tangent vector format to ambient format. Here the output of
% tangent2ambient is passed to hess directly (no .X component to extract).
Xdot_ambient = fixedrk.M.tangent2ambient(X, Xdot);
[h, store] = hess(X, Xdot_ambient, store);
end
[X_FR, ~, FR_info] = trustregions(fixedrk, X0_USV, options); %#ok<ASGLU>
%% Plot some statistics to compare the various approaches
% Top panel: cost value against elapsed time for the three runs, on a
% logarithmic vertical axis. The title reports the achieved oversampling
% factor and the fraction of entries of M that are observed.
figure(1);
clf;
subplot(2, 1, 1);
hplt = ...
semilogy([LR_info.time], [LR_info.cost], '.-', ...
[XP_info.time], [XP_info.cost], '.-', ...
[FR_info.time], [FR_info.cost], '.-');
set(hplt, 'LineWidth', 2);
set(hplt, 'MarkerSize', 20);
legend('LR', 'desingularisation', 'fixed rank', 'Location', 'northeast');
xlabel('Time [s]');
ylabel('Cost function value (training loss)');
title(sprintf('Oversampling factor: %.4g, Observed fraction: %.4g', ...
actual_osf, nnz(M)/numel(M)));
grid on;
% Bottom panel: value of the 'sparseentries' counter against elapsed time
% for the same three runs.
subplot(2, 1, 2);
hplt = ...
plot([LR_info.time], [LR_info.sparseentries], '.-', ...
[XP_info.time], [XP_info.sparseentries], '.-', ...
[FR_info.time], [FR_info.sparseentries], '.-');
set(hplt, 'LineWidth', 2);
set(hplt, 'MarkerSize', 20);
legend('LR', 'desingularisation', 'fixed rank', 'Location', 'northeast');
xlabel('Time [s]');
ylabel('Equivalent-calls to sampling sparse entries of low rank matrix');
grid on;
end
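% This function aims to select osf*r*(m+n-r) entries uniformly at random
% out of a matrix of size m-by-n, where osf is the oversampling factor.
% The output is a sparse matrix M of size m-by-n with ones at the selected
% entries and zeros elsewhere, together with the row and column indices
% (I, J) of those entries and the oversampling factor actually achieved.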
function [I, J, M, osf] = random_mask(m, n, r, osf)
% Selects about osf*r*(m+n-r) distinct entries of an m-by-n matrix at random.
%
% function [I, J, M, osf] = random_mask(m, n, r, osf)
%
% Inputs:
%   m, n  dimensions of the matrix to sample from
%   r     rank used to define the oversampling factor
%   osf   desired oversampling factor; the target number of entries is
%         round(osf*r*(m+n-r))
%
% Outputs:
%   I, J  column vectors with the row and column indices of the selected
%         entries (no entry is selected twice)
%   M     sparse matrix of size exactly m-by-n with a 1 at each selected
%         entry and 0 elsewhere
%   osf   oversampling factor actually achieved: numel(I)/(r*(m+n-r))

    % Expected number of collisions in a random sample of b numbers among a
    % numbers with replacement: b minus the expected number of distinct
    % values, a*(1 - ((a-1)/a)^b).
    expected_redundant = @(a, b) ceil(b + a*(((a-1)/a)^b - 1));

    % We'll select a few more to make up for collisions.
    desired_sample_size = round(osf*r*(m+n-r));
    initial_sample_size = desired_sample_size + ...
                          4*expected_redundant(m*n, desired_sample_size);

    % Linear indices into the m-by-n matrix; unique removes the collisions.
    sample = unique(randi(m*n, initial_sample_size, 1));

    % With high probability, we have too many in our sample.
    % Trim uniformly at random until we have just the right number.
    % One pass may delete fewer than needed (the positions drawn for
    % deletion can collide too), hence the loop.
    while numel(sample) > desired_sample_size
        to_delete = unique(randi(numel(sample), ...
                                 numel(sample) - desired_sample_size, 1));
        sample(to_delete) = [];
    end

    % Update the actual sample size and oversampling factor.
    sample_size = length(sample);
    osf = sample_size/(r*(m+n-r));

    % Convert the sampled linear indices into (i, j) pairs and a mask M.
    [I, J] = ind2sub([m, n], sample);

    % Pass the dimensions explicitly: without them, sparse sizes the output
    % as max(I)-by-max(J), which is smaller than m-by-n whenever the last
    % row or the last column happens not to be sampled.
    M = sparse(I, J, ones(sample_size, 1), m, n);
end