-
Notifications
You must be signed in to change notification settings - Fork 97
/
Copy pathunitaryfactory.m
170 lines (141 loc) · 5.16 KB
/
unitaryfactory.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
function M = unitaryfactory(n, k)
% Returns a manifold structure to optimize over unitary matrices.
%
% function M = unitaryfactory(n)
% function M = unitaryfactory(n, k)
%
% Unitary group: deals with arrays X of size n x n x k (or n x n if k = 1,
% which is the default) such that each n x n matrix is unitary, that is,
%     X'*X = eye(n) if k = 1, or
%     X(:, :, i)' * X(:, :, i) = eye(n) for i = 1 : k if k > 1.
%
% This is a description of U(n)^k with the induced metric from the
% embedding space (C^nxn)^k, i.e., this manifold is a Riemannian
% submanifold of (C^nxn)^k endowed with the usual real inner product on
% C^nxn, namely, <A, B> = real(trace(A'*B)).
%
% This is important:
% Tangent vectors are represented in the Lie algebra, i.e., as
% skew-Hermitian matrices. Use the function M.tangent2ambient(X, H) to
% switch from the Lie algebra representation to the embedding space
% representation (it computes X*H, slice by slice). This is often
% necessary to define problem.ehess(X, H), as the input H will then be a
% skew-Hermitian matrix (but the output must not be, as the output is the
% Hessian in the embedding Euclidean space.)
%
% By default, the retraction is only a first-order approximation of the
% exponential. To force the use of a second-order approximation, call
% M.retr = M.retr_polar after creating M. This switches from a QR-based
% computation to an SVD-based computation.
%
% k must be a finite integer no less than 1. By default, k = 1.
%
% See also: stiefelcomplexfactory rotationsgroup stiefelfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, June 18, 2019.
% Contributors:
% Change log:

    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end

    % The number of copies must be a finite integer >= 1. Testing only
    % k > 1 would let values such as 1.5 or Inf through, producing a
    % meaningless name and dimension.
    if ~(isscalar(k) && isreal(k) && isfinite(k) && k >= 1 && k == round(k))
        error('k must be an integer no less than 1.');
    end

    if k == 1
        M.name = @() sprintf('Unitary manifold U(%d)', n);
    else
        M.name = @() sprintf('Product unitary manifold U(%d)^%d', n, k);
    end

    % Real dimension of U(n)^k.
    M.dim = @() k*(n^2);

    % Real inner product <A, B> = real(trace(A'*B)), summed over slices.
    M.inner = @(x, d1, d2) real(d1(:)'*d2(:));

    M.norm = @(x, d) norm(d(:));

    M.typicaldist = @() pi*sqrt(n*k);

    % Projection of an ambient vector H at X onto the Lie algebra
    % representation: skew-Hermitian part of X'*H, slice by slice.
    M.proj = @(X, H) multiskewh(multiprod(multihconj(X), H));

    M.tangent = @(X, H) multiskewh(H);

    % Tangent vectors are stored as skew-Hermitian U; their ambient
    % representation is X*U.
    M.tangent2ambient_is_identity = false;
    M.tangent2ambient = @(X, U) multiprod(X, U);

    M.egrad2rgrad = M.proj;

    M.ehess2rhess = @ehess2rhess;
    function Rhess = ehess2rhess(X, Egrad, Ehess, H)
        % Reminder: H contains skew-Hermitian matrices. The actual
        % direction that the point X is moved along is X*H.
        Xt = multihconj(X);
        XtEgrad = multiprod(Xt, Egrad);
        symXtEgrad = multiherm(XtEgrad);
        XtEhess = multiprod(Xt, Ehess);
        Rhess = multiskewh( XtEhess - multiprod(H, symXtEgrad) );
    end

    % This QR-based retraction is only a first-order approximation
    % of the exponential map, not a second-order one.
    M.retr_qr = @retraction_qr;
    function Y = retraction_qr(X, U, t)
        % It is necessary to call qr_unique rather than simply qr to ensure
        % this is a retraction, to avoid spurious column sign (phase)
        % ambiguities in the Q factor.
        XU = multiprod(X, U);
        if nargin < 3
            Y = qr_unique(X + XU);
        else
            Y = qr_unique(X + t*XU);
        end
    end

    % A second-order retraction is implemented here. To force its use,
    % after creating the factory M, execute M.retr = M.retr_polar.
    M.retr_polar = @retraction_polar;
    function Y = retraction_polar(X, U, t)
        if nargin == 3
            tU = t*U;
        else
            tU = U;
        end
        Y = X + multiprod(X, tU);
        % Replace each slice by its unitary polar factor Uk*Vk'.
        for kk = 1 : k
            [Uk, ~, Vk] = svd(Y(:, :, kk));
            Y(:, :, kk) = Uk*Vk';
        end
    end

    % By default, use QR retraction
    M.retr = M.retr_qr;

    M.exp = @exponential;
    function Y = exponential(X, U, t)
        if nargin == 3
            exptU = t*U;
        else
            exptU = U;
        end
        for kk = 1 : k
            exptU(:, :, kk) = expm(exptU(:, :, kk));
        end
        Y = multiprod(X, exptU);
    end

    M.log = @logarithm;
    function U = logarithm(X, Y)
        U = multiprod(multihconj(X), Y);
        for kk = 1 : k
            U(:, :, kk) = logm(U(:, :, kk));
        end
        % Ensure the tangent vector is in the Lie algebra.
        U = multiskewh(U);
    end

    % Hash on the real and imaginary parts of X.
    M.hash = @(X) ['z' hashmd5([real(X(:)) ; imag(X(:))])];

    M.rand = @() randunitary(n, k);

    M.randvec = @randomvec;
    function U = randomvec(X) %#ok<INUSD>
        % Random skew-Hermitian direction with unit norm (same norm as
        % M.norm, i.e., the 2-norm of all entries).
        U = randskewh(n, k);
        U = U / norm(U(:));
    end

    M.lincomb = @matrixlincomb;

    M.zerovec = @(x) zeros(n, n, k);

    % Tangent vectors live in the Lie algebra, so transport is the identity.
    M.transp = @(x1, x2, d) d;
    M.isotransp = M.transp; % the transport is isometric

    M.pairmean = @pairmean;
    function Y = pairmean(X1, X2)
        V = M.log(X1, X2);
        Y = M.exp(X1, .5*V);
    end

    M.dist = @(x, y) M.norm(x, M.log(x, y));

    % Tangent vectors are vectorized as [real parts ; imaginary parts].
    sz = n*n*k;
    M.vec = @(x, u_mat) [real(u_mat(:)) ; imag(u_mat(:))];
    M.mat = @(x, u_vec) reshape(u_vec(1:sz), [n, n, k]) ...
                        + 1i*reshape(u_vec((sz+1):end), [n, n, k]);
    M.vecmatareisometries = @() true;

    M.lie_identity = @() repmat(eye(n), [1, 1, k]);

end