-
Notifications
You must be signed in to change notification settings - Fork 96
/
sympositivedefinitesimplexcomplexfactory.m
328 lines (268 loc) · 11.6 KB
/
sympositivedefinitesimplexcomplexfactory.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
function M = sympositivedefinitesimplexcomplexfactory(n, k)
% Manifold of k-tuples of n-by-n Hermitian positive definite matrices that
% sum to the identity matrix, with the bi-invariant geometry.
%
% function M = sympositivedefinitesimplexcomplexfactory(n, k)
%
% Given X1, X2, ... Xk Hermitian positive definite matrices, the constraint
% tackled is
% X1 + X2 + ... = I.
%
% The Riemannian structure enforced on the manifold
% M:={(X1, X2,...) : X1 + X2 + ... = I } is that of a submanifold of the
% total space, defined as the k-fold Cartesian product of the Riemannian
% manifold of n-by-n Hermitian positive definite matrices endowed with the
% bi-invariant metric.
%
% A point X on the manifold is represented as a multidimensional array
% of size n-by-n-by-k. Each n-by-n slice is Hermitian positive definite.
% Tangent vectors are represented as n-by-n-by-k multidimensional arrays, where
% each n-by-n matrix is Hermitian.
%
% The embedding space is the k Cartesian product of complex matrices of size
% n-by-n (Hermitian not required). The Euclidean gradient and Hessian expressions
% needed for egrad2rgrad and ehess2rhess are in the embedding space endowed with the
% usual metric for the complex plane identified with R^2.
%
% E = (C^(nxn))^k is the embedding space: we have the obvious representation of points
% there as 3D arrays of size nxnxk. It is equipped with the standard Euclidean metric.
%
% P = {X in C^(nxn) : X = X' and X positive definite} is a submanifold of C^(nxn).
% We turn it into a Riemannian manifold (but not a Riemannian submanifold) by equipping
% it with the bi-invariant metric.
%
% M = {X in P^k : X_1 + ... + X_k = I} is the manifold we care about here: it is
% a Riemannian submanifold of P^k, hence it is also a submanifold (but not a Riemannian
% submanifold) of E -- our embedding space.
%
%
% Please cite the Manopt paper as well as the research paper:
%
% @techreport{mishra2019riemannian,
% title={Riemannian optimization on the simplex of positive definite matrices},
% author={Mishra, B. and Kasai, H. and Jawanpuria, P.},
% institution={arXiv preprint arXiv:1906.10436},
% year={2019}
% }
%
% See also sympositivedefinitesimplexfactory multinomialfactory sympositivedefinitefactory
% This file is part of Manopt: www.manopt.org.
% Original author: Bamdev Mishra, September 18, 2019.
% Contributors: NB
% Change log:
% Dec. 16, 2019 (BM): Comments updated
% Nov. 1, 2021 (BM): Removed typos in Hessian expression
% June 26, 2024 (NB): Removed M.exp() as it was not implemented.
% Hermitian part of a square matrix (the ' operator is the conjugate
% transpose, so the result equals its own conjugate transpose).
symm = @(Z) 0.5*(Z + Z');

% Human-readable description of the manifold.
M.name = @() sprintf('%d complex hemitian positive definite matrices of size %dx%d such that their sum is the identiy matrix.', k, n, n);

% Dimension reported to solvers.
% NOTE(review): the real dimension of the n-by-n Hermitian matrices is n^2,
% which would suggest (k-1)*n^2 here. The value is kept as is -- confirm.
M.dim = @() (k-1)*n*(n+1);

% Helpers to evaluate traces of products without forming the products.
vec = @(Z) Z(:);                            % stack the entries in a column
trinner = @(P, Q) real(vec(P')'*vec(Q));    % = real(trace(P*Q))
trnorm = @(P) sqrt(trinner(P, P));          % = sqrt(real(trace(P^2)))
% Riemannian metric. Each of the k factors carries the bi-invariant metric
% of the positive definite cone, and the contributions are summed:
%   <eta, zeta>_X = sum over kk of trace( (X_kk\eta_kk) * (X_kk\zeta_kk) ).
% Only the real part of each trace is kept (see trinner).
M.inner = @innerproduct;
function val = innerproduct(X, eta, zeta)
    % X, eta, zeta: n-by-n-by-k arrays (point and two tangent vectors).
    % val: real scalar, the inner product of eta and zeta at X.
    val = 0;
    for idx = 1:k
        Xi = X(:, :, idx);
        val = val + trinner(Xi\eta(:, :, idx), Xi\zeta(:, :, idx));
    end
end
% Norm induced by the metric:
%   ||eta||_X = sqrt( sum over kk of trace( (X_kk\eta_kk)^2 ) ).
% X\eta is *not* Hermitian in general. Mathematically no real part needs to
% be taken, but rounding errors may produce a small imaginary part; it is
% discarded inside trnorm (through trinner).
M.norm = @innernorm;
function nrm = innernorm(X, eta)
    % X, eta: n-by-n-by-k arrays (point and tangent vector).
    % nrm: norm of eta at X.
    sqnorm = 0;
    for idx = 1:k
        sqnorm = sqnorm + trnorm(X(:, :, idx)\eta(:, :, idx))^2;
    end
    nrm = sqrt(sqnorm);
end
% % Same here: X\Y is not symmetric in general.
% % Same remark about taking the real part.
% M.dist = @innerdistance;
% function idistance = innerdistance(X, Y)
% idistance = 0;
% for kk = 1:k
% idistance = idistance + real(trnorm(real(logm(X(:,:,kk)\Y(:,:,kk))))); % BM okay, but need not be correct.
% end
% end
M.typicaldist = @() sqrt(k*n*(n+1)); % BM: to be looked into.

% Convert a Euclidean gradient (in the embedding space) into the
% Riemannian gradient at X.
M.egrad2rgrad = @egrad2rgrad;
function rgrad = egrad2rgrad(X, egrad)
    % X, egrad: n-by-n-by-k arrays.
    % rgrad: n-by-n-by-k array whose slices are
    %        X_kk*symm(egrad_kk)*X_kk + X_kk*Lambdasol*X_kk,
    %        where the second term is added by M.proj.
    scaled = nan(size(egrad));
    for idx = 1:k
        Xi = X(:, :, idx);
        % Gradient with respect to the bi-invariant metric on each factor.
        scaled(:, :, idx) = Xi*symm(egrad(:, :, idx))*Xi;
    end

    % Project onto the set X1dot + X2dot + ... = 0.
    rgrad = M.proj(X, scaled);
end
% Convert the Euclidean gradient and Hessian (in the embedding space) into
% the Riemannian Hessian at X along the tangent direction eta.
M.ehess2rhess = @ehess2rhess;
function Hess = ehess2rhess(X, egrad, ehess, eta)
    % X, egrad, ehess, eta: n-by-n-by-k arrays.
    % Hess: n-by-n-by-k array.
    %
    % The Riemannian gradient on slice kk is
    %   X_kk*egrad_kk*X_kk + X_kk*Lambdasol*X_kk.
    % Below we differentiate it along eta, subtract the correction term
    % symm(eta_kk*(egrad_kk + Lambdasol)*X_kk), and project the result.

    % Scaled gradient and its directional derivative along eta.
    scaledgrad = nan(size(egrad));
    scaledgraddot = nan(size(egrad));
    for idx = 1:k
        Xi = X(:, :, idx);
        di = eta(:, :, idx);
        Gi = symm(egrad(:, :, idx));
        Hi = symm(ehess(:, :, idx));
        scaledgrad(:, :, idx) = Xi*Gi*Xi;
        scaledgraddot(:, :, idx) = Xi*Hi*Xi + 2*symm(di*Gi*Xi);
    end

    % Multiplier Lambdasol appearing in the Riemannian gradient.
    Lambda = mylinearsolve(X, -sum(scaledgrad, 3));

    % Directional derivative of the multiplier (Lambdasoldot).
    multterm = nan(size(egrad));
    for idx = 1:k
        Xi = X(:, :, idx);
        di = eta(:, :, idx);
        multterm(:, :, idx) = 2*symm(di*Lambda*Xi);
    end
    Lambdadot = mylinearsolve(X, -sum(scaledgraddot, 3) - sum(multterm, 3));

    % Assemble each slice:
    %     X_kk*(ehess_kk + Lambdadot)*X_kk + 2*symm(eta_kk*(egrad_kk + Lambda)*X_kk)
    %   -                                      symm(eta_kk*(egrad_kk + Lambda)*X_kk).
    Hess = nan(size(X));
    for idx = 1:k
        Xi = X(:, :, idx);
        di = eta(:, :, idx);
        Gi = symm(egrad(:, :, idx));
        Hi = symm(ehess(:, :, idx));
        Hess(:, :, idx) = Xi*(Hi + Lambdadot)*Xi + symm(di*(Gi + Lambda)*Xi);
    end

    % Project onto the set X1dot + X2dot + ... = 0.
    Hess = M.proj(X, Hess);
end
% Projection onto the set X1dot + X2dot + ... = 0: each slice of eta is
% corrected by X_jj*Lambdasol*X_jj, with Lambdasol obtained from
% mylinearsolve for the right-hand side -sum(eta, 3).
M.proj = @innerprojection;
function zeta = innerprojection(X, eta)
    % X: n-by-n-by-k array (point); eta: n-by-n-by-k array to project.
    % zeta: n-by-n-by-k array, eta_jj + X_jj*Lambdasol*X_jj on each slice.
    Lambda = mylinearsolve(X, -sum(eta, 3));

    zeta = zeros(size(eta));
    for idx = 1:k
        Xi = X(:, :, idx);
        zeta(:, :, idx) = eta(:, :, idx) + Xi*Lambda*Xi;
    end

    % Debug checks: sum(zeta, 3) and innerproduct(X, zeta, eta - zeta)
    % should both be (numerically) zero.
end
function Lambdasol = mylinearsolve(X, RHS)
    % Solve sum over kk of X_kk*Lambda*X_kk = RHS for the n-by-n matrix
    % Lambda with pcg, working on the stacked real and imaginary parts.
    %
    % X: n-by-n-by-k array; RHS: n-by-n matrix.
    % Lambdasol: n-by-n matrix.
    %
    % NOTE(review): only the real part of the unknown is symmetrized; the
    % imaginary part is used as returned (not skew-symmetrized). The pcg
    % convergence flag is discarded. Confirm both are intended.
    pcgtol = 1e-8;
    pcgmaxit = 100;

    % Conversions between an n-by-n complex matrix and a real vector of
    % length 2*n^2 ([real part; imaginary part]).
    stack = @(Z) [real(Z(:)); imag(Z(:))];
    unstack = @(v) symm(reshape(v(1:n^2), [n n])) + 1i*reshape(v(n^2 + 1 : end), n, n);

    [sol, ~, ~, ~] = pcg(@applyoperator, stack(RHS), pcgtol, pcgmaxit);
    Lambdasol = unstack(sol);

    function out = applyoperator(v)
        % Apply Lambda -> sum over kk of X_kk*Lambda*X_kk in stacked form.
        Lam = unstack(v);
        acc = zeros(n, n);
        for idx = 1:k
            acc = acc + X(:, :, idx)*Lam*X(:, :, idx);
        end
        out = stack(acc);
    end
end
M.tangent = M.proj;
M.tangent2ambient = @(X, eta) eta;

% Tiny diagonal shift used in the retraction's linear solve.
myeps = eps;

% Retraction: second-order approximation of the matrix exponential on each
% factor, followed by a congruence that restores sum(Y, 3) = I.
M.retr = @retraction;
function Y = retraction(X, eta, t) % BM okay
    % X: n-by-n-by-k array (point); eta: n-by-n-by-k array (tangent vector).
    % t: optional scalar step size (defaults to 1).
    % Y: n-by-n-by-k array, the retracted point.
    if nargin < 3
        teta = eta;
    else
        teta = t*eta;
    end

    % The symm() calls are mathematically unnecessary but numerically
    % necessary.
    Y = zeros(size(X));
    for kk = 1:k
        Xk = X(:, :, kk);
        tetak = teta(:, :, kk);
        % Second-order approximation of expm.
        Y(:, :, kk) = symm(Xk + tetak + .5*tetak*((Xk + myeps*eye(n))\tetak));
    end

    % Rescale so that the slices sum to the identity:
    % Y_kk <- S^(-1/2) * Y_kk * S^(-1/2), with S = sum(Y, 3).
    Ysumsqrt = sqrtm(sum(Y, 3));
    % Loop bound is k. The previous bound '1:kk' only worked because kk
    % still held k from the loop above.
    for kk = 1:k
        Y(:, :, kk) = symm((Ysumsqrt\Y(:, :, kk))/Ysumsqrt);
    end

    % % Debug
    % norm(sum(Y, 3) - eye(n), 'fro') % This should be zero
end
% Hash of a point, computed from its stacked real and imaginary parts.
M.hash = @(X) ['z' hashmd5([real(X(:)); imag(X(:))])];% BM okay

% Generate a random point on the manifold. Each slice is first drawn as
% Q*D*Q' with Q from the QR factorization of a complex Gaussian matrix and
% D diagonal with entries in (1, 2); the slices are then rescaled so that
% they sum to the identity. The choice of distribution is arbitrary, and
% specific applications might require different ones.
M.rand = @random;
function X = random()
    % X: n-by-n-by-k array of Hermitian positive definite slices.
    X = nan(n, n, k);
    for idx = 1:k
        Dmat = diag(1 + rand(n, 1));
        [Qmat, ~] = qr(randn(n) +1i*randn(n)); % BM okay
        X(:, :, idx) = Qmat*Dmat*Qmat';
    end

    % Congruence by the inverse square root of the sum, so that the
    % slices sum to the identity.
    rootsum = sqrtm(sum(X, 3));
    for idx = 1:k
        X(:, :, idx) = symm((rootsum\X(:, :, idx))/rootsum); % To do
    end
end
% Generate a random unit-norm tangent vector at X: draw the Hermitian part
% of a complex Gaussian matrix for each slice, project with M.proj, then
% normalize with M.norm.
M.randvec = @randomvec;
function eta = randomvec(X)
    % X: n-by-n-by-k array (point).
    % eta: n-by-n-by-k array with M.norm(X, eta) equal to 1.
    eta = nan(size(X));
    for idx = 1:k
        eta(:, :, idx) = symm(randn(n,n) + 1i*randn(n, n)); % BM okay
    end
    eta = M.proj(X, eta); % To do
    eta = eta / M.norm(X, eta);
end
% Linear combinations of tangent vectors and the zero tangent vector.
M.lincomb = @matrixlincomb; % BM okay
M.zerovec = @(X) zeros(n,n,k); % BM okay

% Poor man's transporter: the tangent vector is simply projected onto the
% tangent space at the target point X2. It may perform poorly if the
% origin and target (X1 and X2) are far apart though. This should not be
% the case for typical optimization algorithms, which perform small steps.
M.transp = @(X1, X2, eta) M.proj(X2, eta);% To do

% vec stacks the real parts of all entries of U on top of the imaginary
% parts (a real column vector of length 2*n*n*k); mat is its inverse.
% They are not isometries, because of the unusual inner metric.
% Fix: vec previously called image() (a graphics function) instead of
% imag(), so it did not return the imaginary parts.
M.vec = @(X, U) [real(U(:)); imag(U(:))] ; % BM okay
M.mat = @(X, u) reshape(u(1:(n*n*k)) + 1i*u((n*n*k+1):end), n, n, k); % BM okay
M.vecmatareisometries = @() false;
end