In [24]:
import torch
import geoopt
import numpy as np
import matplotlib.pyplot as plt

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Problem setup
dim_n = 10  # Rows
dim_p = 10   # Columns (must be <= dim_n)
A = torch.randn(dim_n, dim_p)  # Random matrix A
B = torch.randn(dim_n, dim_n)  # Target matrix B

In [36]:
# Define the Stiefel manifold
stiefel_manifold = geoopt.manifolds.Stiefel()

# Initialize a point on the Stiefel manifold
X_init = torch.normal(0, 1, size=(dim_n, dim_p), dtype=torch.float64)
X_init, _ = torch.qr(X_init)  # Orthogonalize to be on the manifold
X_sgd = geoopt.ManifoldParameter(X_init.clone(), manifold=stiefel_manifold)
X_adam = geoopt.ManifoldParameter(X_init.clone(), manifold=stiefel_manifold)

# Define the loss function
def quadratic_loss(X):
    return torch.norm(A @ X - B, p='fro') ** 2

# Optimizers
lr = 0.01  # Learning rate
sgd_optimizer = geoopt.optim.RiemannianSGD([X_sgd], lr=lr)
adam_optimizer = geoopt.optim.RiemannianAdam([X_adam], lr=lr)

# Training loop settings
max_iters = 5000
convergence_threshold = 1e-6
patience = 10

In [28]:
x = X_init
g = torch.randn(dim_n, dim_p)

rg = g - x @ g.T.conj() @ x

In [34]:
from scipy.stats import unitary_group, ortho_group

u = ortho_group.rvs(dim_n)
u @ u.T.conj()

array([[ 1.00000000e+00,  4.66942264e-17, -1.84249235e-16,
         5.81449071e-17,  1.57494012e-16,  5.54329758e-17,
         9.99333942e-17, -1.64029362e-16,  1.31237965e-16,
        -7.84588524e-17],
       [ 4.66942264e-17,  1.00000000e+00, -2.81497212e-16,
        -3.10855765e-16,  1.83392528e-16, -1.21977183e-16,
        -2.76596993e-17,  1.12274052e-16,  4.67952735e-18,
         1.63741571e-17],
       [-1.84249235e-16, -2.81497212e-16,  1.00000000e+00,
        -1.37028934e-16,  8.17325399e-18,  6.00121326e-17,
         7.94909630e-18, -1.32214116e-17,  6.64422336e-17,
        -5.62358495e-17],
       [ 5.81449071e-17, -3.10855765e-16, -1.37028934e-16,
         1.00000000e+00, -2.31993071e-17,  5.82334148e-18,
         3.29532778e-19,  3.01553302e-17,  5.98394286e-18,
        -2.67887525e-17],
       [ 1.57494012e-16,  1.83392528e-16,  8.17325399e-18,
        -2.31993071e-17,  1.00000000e+00, -5.76856649e-17,
        -1.03132824e-16,  1.32072350e-16, -1.40487223e-16,
        -1.

In [33]:
torch.linalg.matrix_exp(rg @ x.T.conj()) @ x - x @ torch.linalg.matrix_exp(x.T.conj() @ rg)

tensor([[ 8.9407e-08,  1.7881e-07,  6.8545e-07,  1.7881e-07, -2.3842e-07,
          2.0862e-07,  3.2037e-07,  1.0431e-06, -2.0862e-07, -5.9605e-08],
        [ 5.9605e-08,  1.1921e-07, -7.0781e-08, -1.5646e-07, -7.4506e-08,
          9.3132e-08, -2.0862e-07, -1.7881e-07,  2.3842e-07, -1.2293e-07],
        [-7.4506e-08,  1.3411e-07,  0.0000e+00,  0.0000e+00,  1.7881e-07,
         -6.4075e-07,  1.7136e-07,  2.9802e-08,  2.8312e-07, -1.1921e-07],
        [-2.3842e-07,  1.7881e-07, -2.3842e-07, -5.0664e-07, -2.5332e-07,
          2.6822e-07, -1.7881e-07, -6.3330e-08,  3.8743e-07, -1.4901e-07],
        [-3.7998e-07, -2.7381e-07, -2.9802e-08, -5.5879e-08, -2.3097e-07,
         -4.7684e-07,  2.9802e-08,  2.9802e-08,  1.2293e-07,  2.9802e-08],
        [ 2.3842e-07,  2.3842e-07, -4.4703e-08, -2.9802e-08, -8.9407e-08,
         -1.4901e-08,  1.4156e-07,  4.6194e-07,  4.7684e-07,  5.0664e-07],
        [ 2.1979e-07, -1.1921e-07,  4.4703e-08,  1.1921e-07, -2.3842e-07,
          2.3842e-07, -1.7881e-0

In [26]:
U, S, V = torch.linalg.svd(X_init, full_matrices=False)
A = U @ V


In [27]:
# Create a random matrix for the complementary space of shape (n, n-p)
A = torch.randn(dim_n, dim_n - dim_p, dtype=X_init.dtype, device=X_init.device)
# Project out the components in the span of W so that A is in the complement of W:
A = A - X_init @ (X_init.T @ A)
# Perform QR on A to get an orthonormal basis for the complement:
Q, _ = torch.linalg.qr(A)
# Concatenate W and Q to form W_tilde:
W_tilde = torch.cat([X_init, Q], dim=1)

In [None]:
x = x_ini

In [89]:
rg = g - x @ g.T.conj() @ x

rgp = x.T.conj() @ rg

rgp.T.conj() @ rgp - rg.T.conj() @ rg

tensor([[ -9.0651,  -7.0870,  -6.0973,   1.0166,  -3.1239],
        [ -7.0870,  -8.2470,  -8.4099,  -0.8544,  -4.6700],
        [ -6.0973,  -8.4099, -14.3863,  -6.2188,  -7.9235],
        [  1.0166,  -0.8544,  -6.2188,  -9.1306,  -4.9703],
        [ -3.1239,  -4.6700,  -7.9235,  -4.9703,  -5.2943]],
       grad_fn=<SubBackward0>)

In [3]:
# Training loop for both optimizers
def train(optimizer, X, name):
    loss_history = []
    best_loss = float('inf')
    no_improve_count = 0
    
    for i in range(max_iters):
        optimizer.zero_grad()
        loss = quadratic_loss(X)
        loss.backward()
        optimizer.step()
        loss_history.append(loss.item())
        
        # Check convergence criteria
        if loss.item() < best_loss - convergence_threshold:
            best_loss = loss.item()
            no_improve_count = 0
        else:
            no_improve_count += 1
        
        if no_improve_count >= patience:
            print(f"{name} converged at iteration {i+1} with loss {loss.item():.6f}")
            break
        
    return loss_history

# Run the optimization
loss_sgd = train(sgd_optimizer, X_sgd, "Riemannian SGD")
loss_adam = train(adam_optimizer, X_adam, "Riemannian Adam")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x5 and 10x5)