function M = hyperbolicfactory(n, m, transposed)
% Factory for matrices whose columns live on the hyperbolic manifold
%
% function M = hyperbolicfactory(n)
% function M = hyperbolicfactory(n, m)
% function M = hyperbolicfactory(n, m, transposed)
%
% Returns a structure M which describes the hyperbolic manifold in Manopt.
% A point on the manifold is a matrix X of size (n+1)-by-m whose columns
% live on the hyperbolic manifold, that is, for each column x of X, we have
%
% -x(1)^2 + x(2)^2 + x(3)^2 + ... + x(n+1)^2 = -1.
%
% The positive branch is selected by M.rand(), that is, x(1) > 0, but all
% tools work on the negative branch as well.
%
% Equivalently, defining the Minkowski (semi) inner product
%
% <x, y> = -x(1)y(1) + x(2)y(2) + x(3)y(3) + ... + x(n+1)y(n+1)
%
% and the induced Minkowski (semi) norm ||x||^2 = <x, x>, we can write
% compactly that each column of X has squared Minkowski norm equal to -1.
%
% The set of matrices X that satisfy this constraint is a smooth manifold.
% Tangent vectors at X are matrices U of the same size as X. If x and u are
% the kth columns of X and U respectively, then <x, u> = 0.
%
% This manifold is turned into a Riemannian manifold by restricting the
% Minkowski inner product to each tangent space (a simple calculation
% confirms that this metric is indeed Riemannian and not just
% semi-Riemannian, that is, it is positive definite when restricted to each
% tangent space). This is the hyperbolic manifold: for m = 1, all of its
% sectional curvatures are equal to -1. This representation is called the
% hyperboloid model or the Lorentz model of hyperbolic geometry.
%
% This manifold is an embedded submanifold of Euclidean space (the set of
% matrices of size (n+1)-by-m equipped with the usual trace inner product).
% Thus, when defining the Euclidean gradient for example (problem.egrad),
% it should be specified as if the function were defined in Euclidean space
% directly. The tool M.egrad2rgrad will automatically convert that gradient
% to the correct Riemannian gradient, as needed to satisfy the metric. The
% same is true for the Euclidean Hessian and other tools that manipulate
% elements in the embedding space.
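%
% As an illustration (a minimal sketch, not taken from Manopt itself): with
% a hypothetical cost f(X) = .5*norm(X - A, 'fro')^2 for some given matrix
% A of size (n+1)-by-m, one would supply the plain Euclidean gradient:
%
%   problem.M = hyperbolicfactory(n, m);
%   problem.cost = @(X) .5*norm(X - A, 'fro')^2;
%   problem.egrad = @(X) X - A;     % Euclidean gradient, no correction
%   checkgradient(problem);         % numerical check of the conversion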
%
% Importantly, the resulting manifold is /not/ a Riemannian submanifold of
% Euclidean space, because its metric is not obtained simply by restricting
% the Euclidean metric to the tangent spaces. However, it is a
% semi-Riemannian submanifold of Minkowski space, that is, the set of
% matrices of size (n+1)-by-m equipped with the Minkowski inner product.
% Minkowski space itself can be seen as a (linear) semi-Riemannian manifold
% embedded in Euclidean space. This view is entirely equivalent to the one
% described above (the Riemannian structure of the resulting manifold is
% exactly the same), and it is useful to derive some of the tools this
% factory provides.
%
% If transposed is set to true (it is false by default), then the matrices
% are transposed: a point X on the manifold is a matrix of size m-by-(n+1)
% and each row is an element in hyperbolic space. It is the same geometry,
% just a different representation.
%
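% Example (a minimal sketch, assuming Manopt is on the Matlab path):
%
%   M = hyperbolicfactory(3, 5);   % five points, each with n = 3
%   X = M.rand();                  % 4-by-5 matrix, positive branch
%   q = -X(1, :).^2 + sum(X(2:end, :).^2, 1);  % entries should be -1
%   U = M.randvec(X);              % random tangent vector with M.norm = 1
%   Y = M.exp(X, U, 0.5);          % move along the geodesic for time 0.5
%   d = M.dist(X, Y);              % should be 0.5, up to round-off
%
%   Mt = hyperbolicfactory(3, 5, true);
%   size(Mt.rand())                % 5-by-4: each row is a point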
%
% Resources:
%
% 1. Nickel and Kiela, "Learning Continuous Hierarchies in the Lorentz
% Model of Hyperbolic Geometry", ICML, 2018.
%
% 2. Wilson and Leimeister, "Gradient descent in hyperbolic space",
% arXiv preprint arXiv:1805.08207 (2018).
%
% 3. Pennec, "Hessian of the Riemannian squared distance", HAL INRIA, 2017.
%
% Ported primarily from the McTorch toolbox at
% https://github.com/mctorch/mctorch.
%
% See also: poincareballfactory spherefactory obliquefactory obliquecomplexfactory

% This file is part of Manopt: www.manopt.org.
% Original authors: Bamdev Mishra <bamdevm@gmail.com>, Mayank Meghwanshi,
% Pratik Jawanpuria, Anoop Kunchukuttan, and Hiroyuki Kasai Oct 28, 2018.
% Contributors: Nicolas Boumal
% Change log:
% May 14, 2020 (NB):
% Clarified comments about distance computation.
% July 13, 2020 (NB):
% Added pairmean function.
% Sep. 24, 2023 (NB):
% Edited out bsxfun() for improved speed.
% July 2, 2024 (NB):
% Made M.paralleltransp = M.isotransp and M.retr2 = M.exp available.
% Added offtangent to assess whether a given v is tangent at x.
% Design note: all functions that are defined here but not exposed
% outside work for non-transposed representations. Only the wrappers
% that eventually expose functionalities handle transposition. This
% makes it easier to compose functions internally.
%
% July 2024: This transposition should be edited out as it was in
% obliquefactory.
if ~exist('m', 'var') || isempty(m)
m = 1;
end
if ~exist('transposed', 'var') || isempty(transposed)
transposed = false;
end
if transposed
trnsp = @(X) X';
trnspstr = ', transposed';
else
trnsp = @(X) X;
trnspstr = '';
end
M.name = @() sprintf('Hyperbolic manifold H(%d, %d)%s', n, m, trnspstr);
M.dim = @() n*m;
M.typicaldist = @() sqrt(n*m);
% Returns a row vector q such that q(k) is the Minkowski inner product
% of columns U(:, k) and V(:, k). This is defined in all of Minkowski
% space, not only on tangent spaces. In particular, if X is a point on
% the manifold, then inner_minkowski_columns(X, X) should return a
% vector of all -1's.
function q = inner_minkowski_columns(U, V)
q = -U(1, :).*V(1, :) + sum(U(2:end, :).*V(2:end, :), 1);
end
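% Worked example (a sketch for illustration): with n = 1, the columns of
%   X = [sqrt(2), 1 ; 1, 0]
% are x = [sqrt(2); 1] and y = [1; 0]. Both lie on the manifold since
%   <x, x> = -2 + 1 = -1   and   <y, y> = -1 + 0 = -1,
% so that inner_minkowski_columns(X, X) should return [-1, -1], up to
% round-off.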
% Riemannian metric: we sum over the m copies of the hyperbolic
% manifold, each equipped with a restriction of the Minkowski metric.
M.inner = @(X, U, V) sum(inner_minkowski_columns(trnsp(U), trnsp(V)));
% Mathematically, the Riemannian metric is positive definite, hence
% M.inner(X, U, U) is nonnegative whenever U is tangent at X.
% Numerically, because the inner product involves a difference of
% positive numbers, round-off may result in a small negative number.
% Taking the max against 0 avoids imaginary results.
M.norm = @(X, U) sqrt(max(M.inner(X, U, U), 0));
M.dist = @(X, Y) norm(dists(trnsp(X), trnsp(Y)));
% This function returns a row vector of length m such that d(k) is the
% geodesic distance between X(:, k) and Y(:, k).
function d = dists(X, Y)
% Mathematically, each column of U = X-Y has nonnegative squared
% Minkowski norm. To avoid potentially imaginary results due to
% round-off errors, we take the max against 0.
U = X-Y;
mink_sqnorms = max(0, inner_minkowski_columns(U, U));
mink_norms = sqrt(mink_sqnorms);
d = 2*asinh(.5*mink_norms);
% The formula above is equivalent to
% d = max(0, real(acosh(-inner_minkowski_columns(X, Y))));
% but is numerically more accurate when distances are small.
% When distances are large, it is better to use the acosh formula.
end
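% Derivation sketch for the asinh formula in dists: if x and y are on the
% same branch with <x, x> = <y, y> = -1 and geodesic distance d, then
% <x, y> = -cosh(d). Hence,
%   <x-y, x-y> = <x, x> - 2<x, y> + <y, y> = 2*(cosh(d) - 1) = 4*sinh(d/2)^2.
% Taking square roots, the Minkowski norm of x-y is 2*sinh(d/2), that is,
% d = 2*asinh(.5*mink_norm).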
M.proj = @(X, U) trnsp(projection(trnsp(X), trnsp(U)));
function PU = projection(X, U)
inners = inner_minkowski_columns(X, U);
PU = U + X .* inners;
end
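% Why this is a projection (a short check): for a column x with
% <x, x> = -1 and any u, the output p = u + <x, u>*x satisfies
%   <x, p> = <x, u> + <x, u>*<x, x> = <x, u> - <x, u> = 0,
% so p is tangent at x. If u is already tangent, <x, u> = 0 and p = u.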
M.tangent = M.proj;
% Look inside the code of the tool offtangent: you can check that the
% fallback code there does not work for this factory.
% Indeed, run
% M = hyperbolicfactory(5);
% x = M.rand();
% v = randn(size(x));
% w = M.lincomb(x, 1, M.tangent(x, v), -1, v);
% M.norm(x, w)
% The result is zero, even though v was random: there is no chance that
% it would happen to be tangent. The reason is that M.inner(x, w, w) is
% a negative number, and hence M.norm (which takes the max against 0)
% returns 0. Therefore, we implement a dedicated function here to
% quantify how far off a given vector v is from being tangent.
M.offtangent = @(X, V) offtangent(trnsp(X), trnsp(V));
function val = offtangent(X, V)
if isequal(size(X), size(V)) && isequal(size(V), [n+1, m])
val = norm(V - projection(X, V), 'fro');
else
val = Inf;
end
end
% For Riemannian submanifolds, converting the Euclidean gradient into
% the Riemannian gradient amounts to an orthogonal projection. Here
% however, the manifold is not a Riemannian submanifold of Euclidean
% space, hence extra corrections are required to account for the change
% of metric.
M.egrad2rgrad = @(X, egrad) trnsp(egrad2rgrad(trnsp(X), trnsp(egrad)));
function rgrad = egrad2rgrad(X, egrad)
egrad(1, :) = -egrad(1, :);
rgrad = projection(X, egrad);
end
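% Reasoning sketch: let J = diag(-1, 1, ..., 1), so that <a, b> = a'*J*b
% for columns a, b. Since J*J is the identity, the directional derivative
% of the cost at x along u is egrad'*u = (J*egrad)'*J*u = <J*egrad, u>.
% Thus J*egrad (egrad with its first row negated) represents the
% differential in the Minkowski inner product, and projecting it to the
% tangent space yields the Riemannian gradient.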
M.ehess2rhess = @(X, egrad, ehess, U) ...
trnsp(ehess2rhess(trnsp(X), trnsp(egrad), trnsp(ehess), trnsp(U)));
function rhess = ehess2rhess(X, egrad, ehess, U)
egrad(1, :) = -egrad(1, :);
ehess(1, :) = -ehess(1, :);
inners = inner_minkowski_columns(X, egrad);
rhess = projection(X, ehess + U .* inners);
end
% For the exponential, we cannot separate trnsp() nicely from the main
% function because the third input, t, is optional.
M.exp = @exponential;
function Y = exponential(X, U, t)
X = trnsp(X);
U = trnsp(U);
if nargin < 3
tU = U; % corresponds to t = 1
else
tU = t*U;
end
% Compute the individual Minkowski norms of the columns of U.
mink_inners = inner_minkowski_columns(tU, tU);
mink_norms = sqrt(max(0, mink_inners));
% Coefficients for the exponential. For b, note that NaN's appear
% when an element of mink_norms is zero, in which case the correct
% convention is to define sinh(0)/0 = 1.
a = cosh(mink_norms);
b = sinh(mink_norms)./mink_norms;
b(isnan(b)) = 1;
Y = X .* a + tU .* b;
Y = trnsp(Y);
end
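% Worked example (a sketch, n = 1, m = 1, not transposed): at x = [1; 0],
% the vector u = [0; 1] is tangent since <x, u> = 0, and its Minkowski
% norm is 1. Then, exponential(x, u, t) = cosh(t)*x + sinh(t)*u, that is,
% [cosh(t); sinh(t)]: a unit-speed parametrization of the branch
% -y(1)^2 + y(2)^2 = -1 with y(1) > 0.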
M.retr = M.exp;
M.retr2 = M.exp;
M.log = @(X, Y) trnsp(logarithm(trnsp(X), trnsp(Y)));
function U = logarithm(X, Y)
d = dists(X, Y);
a = d./sinh(d);
a(isnan(a)) = 1;
U = projection(X, Y .* a);
end
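% Reasoning sketch: with d the distance between columns x and y, we have
% <x, y> = -cosh(d), hence projection(x, y) = y - cosh(d)*x. By linearity
% of the projection, the code above computes (d/sinh(d))*(y - cosh(d)*x),
% whose Minkowski norm is d. As a sanity check, one would expect
%   M.norm(X, M.log(X, M.exp(X, U)) - U)
% to be close to zero for a tangent vector U at X of moderate norm.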
M.hash = @(X) ['z' hashmd5(X(:))];
M.rand = @() trnsp(myrand());
function X = myrand()
X1 = randn(n, m);
x0 = sqrt(1 + sum(X1.^2, 1)); % selects positive branch
X = [x0; X1];
end
M.normalize = @(X, U) U / M.norm(X, U);
M.randvec = @(X) M.normalize(X, M.proj(X, randn(size(X))));
M.lincomb = @matrixlincomb;
M.zerovec = @(X) zeros(size(X));
M.transp = @(X1, X2, U) M.proj(X2, U);
M.paralleltransp = @(X1, X2, U) ...
trnsp(parallel_transport(trnsp(X1), trnsp(X2), trnsp(U)));
M.isotransp = M.paralleltransp;
function V = parallel_transport(X1, X2, U)
V = inner_minkowski_columns(X2, U);
V = V ./ (1 - inner_minkowski_columns(X1, X2)) .* (X1 + X2);
V = U + V;
end
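% Sanity checks one could run from outside this factory (a sketch; the
% variable names are only illustrative):
%   M = hyperbolicfactory(4, 3);
%   X1 = M.rand(); X2 = M.rand();
%   U = M.randvec(X1); V = M.randvec(X1);
%   PU = M.paralleltransp(X1, X2, U);
%   PV = M.paralleltransp(X1, X2, V);
%   M.offtangent(X2, PU)                     % should be close to 0
%   M.inner(X2, PU, PV) - M.inner(X1, U, V)  % should be close to 0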
M.pairmean = @(x1, x2) M.exp(x1, M.log(x1, x2), .5);
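% Illustration (a sketch): the pair mean is the midpoint of the geodesic
% from x1 to x2. With y = M.pairmean(x1, x2), one would expect both
% M.dist(x1, y) and M.dist(y, x2) to be close to .5*M.dist(x1, x2).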
% vec returns a column-vector representation of a tangent vector given
% as a matrix; mat maps that column vector back to the matrix
% representation. vec and mat are thus inverses of each other.
vect = @(X) X(:);
M.vec = @(X, U_mat) vect(trnsp(U_mat));
M.mat = @(X, U_vec) trnsp(reshape(U_vec, [n+1, m]));
M.vecmatareisometries = @() false;
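% Illustration (a sketch): for a tangent vector U at X,
%   isequal(M.mat(X, M.vec(X, U)), U)
% should be true, whereas norm(M.vec(X, U)) generally differs from
% M.norm(X, U) because the metric counts the first entry of each point
% with a minus sign. This is why vecmatareisometries returns false.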
end