{{ message }}
/ manopt Public
Switch branches/tags
Could not load branches
Nothing to show

Cannot retrieve contributors at this time
168 lines (140 sloc) 5.1 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 function M = unitaryfactory(n, k) % Returns a manifold structure to optimize over unitary matrices. % % function M = unitaryfactory(n) % function M = unitaryfactory(n, k) % % Unitary group: deals with arrays U of size n x n x k (or n x n if k = 1, % which is the default) such that each n x n matrix is unitary, that is, % X'*X = eye(n) if k = 1, or % X(:, :, i)' * X(:, :, i) = eye(n) for i = 1 : k if k > 1. % % This is a description of U(n)^k with the induced metric from the % embedding space (C^nxn)^k, i.e., this manifold is a Riemannian % submanifold of (C^nxn)^k endowed with the usual real inner product on % C^nxn, namely, = real(trace(A'*B)). % % This is important: % Tangent vectors are represented in the Lie algebra, i.e., as % skew-Hermitian matrices. Use the function M.tangent2ambient(X, H) to % switch from the Lie algebra representation to the embedding space % representation. This is often necessary to define problem.ehess(X, H), % as the input H will then be a skew-Hermitian matrix (but the output must % not be, as the output is the Hessian in the embedding Euclidean space.) % % By default, the retraction is only a first-order approximation of the % exponential. To force the use of a second-order approximation, call % M.retr = M.retr_polar after creating M. This switches from a QR-based % computation to an SVD-based computation. % % By default, k = 1. % % See also: stiefelcomplexfactory rotationsgroup stiefelfactory % This file is part of Manopt: www.manopt.org. % Original author: Nicolas Boumal, June 18, 2019. % Contributors: % Change log: if ~exist('k', 'var') || isempty(k) k = 1; end if k == 1 M.name = @() sprintf('Unitary manifold U(%d)', n); elseif k > 1 M.name = @() sprintf('Product unitary manifold U(%d)^%d', n, k); else error('k must be an integer no less than 1.'); end M.dim = @() k*(n^2); M.inner = @(x, d1, d2) real(d1(:)'*d2(:)); M.norm = @(x, d) norm(d(:)); M.typicaldist = @() pi*sqrt(n*k); M.proj = @(X, H) multiskewh(multiprod(multihconj(X), H)); M.tangent = @(X, H) multiskewh(H); M.tangent2ambient_is_identity = false; M.tangent2ambient = @(X, U) multiprod(X, U); M.egrad2rgrad = M.proj; M.ehess2rhess = @ehess2rhess; function Rhess = ehess2rhess(X, Egrad, Ehess, H) % Reminder : H contains skew-Hermitian matrices. The actual % direction that the point X is moved along is X*H. Xt = multihconj(X); XtEgrad = multiprod(Xt, Egrad); symXtEgrad = multiherm(XtEgrad); XtEhess = multiprod(Xt, Ehess); Rhess = multiskewh( XtEhess - multiprod(H, symXtEgrad) ); end % This QR-based retraction is only a first-order approximation % of the exponential map, not a second-order one. M.retr_qr = @retraction_qr; function Y = retraction_qr(X, U, t) % It is necessary to call qr_unique rather than simply qr to ensure % this is a retraction, to avoid spurious column sign flips. XU = multiprod(X, U); if nargin < 3 Y = qr_unique(X + XU); else Y = qr_unique(X + t*XU); end end % A second-order retraction is implemented here. To force its use, % after creating the factory M, execute M.retr = M.retr_polar. M.retr_polar = @retraction_polar; function Y = retraction_polar(X, U, t) if nargin == 3 tU = t*U; else tU = U; end Y = X + multiprod(X, tU); for kk = 1 : k [Uk, ~, Vk] = svd(Y(:, :, kk)); Y(:, :, kk) = Uk*Vk'; end end % By default, use QR retraction M.retr = M.retr_qr; M.exp = @exponential; function Y = exponential(X, U, t) if nargin == 3 exptU = t*U; else exptU = U; end for kk = 1 : k exptU(:, :, kk) = expm(exptU(:, :, kk)); end Y = multiprod(X, exptU); end M.log = @logarithm; function U = logarithm(X, Y) U = multiprod(multihconj(X), Y); for kk = 1 : k U(:, :, kk) = logm(U(:, :, kk)); end % Ensure the tangent vector is in the Lie algebra. U = multiskewh(U); end M.hash = @(X) ['z' hashmd5([real(X(:)) ; imag(X(:))])]; M.rand = @() randunitary(n, k); M.randvec = @randomvec; function U = randomvec(X) %#ok U = randskewh(n, k); nrmU = sqrt(U(:)'*U(:)); U = U / nrmU; end M.lincomb = @matrixlincomb; M.zerovec = @(x) zeros(n, n, k); M.transp = @(x1, x2, d) d; M.isotransp = M.transp; % the transport is isometric M.pairmean = @pairmean; function Y = pairmean(X1, X2) V = M.log(X1, X2); Y = M.exp(X1, .5*V); end M.dist = @(x, y) M.norm(x, M.log(x, y)); sz = n*n*k; M.vec = @(x, u_mat) [real(u_mat(:)) ; imag(u_mat(:))]; M.mat = @(x, u_vec) reshape(u_vec(1:sz), [n, n, k]) ... + 1i*reshape(u_vec((sz+1):end), [n, n, k]); M.vecmatareisometries = @() true; end