In [77]:
import torch
import geoopt
import numpy as np
import matplotlib.pyplot as plt

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Problem setup
dim_n = 10  # Rows
dim_p = 5   # Columns (must be <= dim_n)
A = torch.randn(dim_n, dim_p)  # Random matrix A
B = torch.randn(dim_n, dim_n)  # Target matrix B

In [78]:
X_init @ A

RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x5 and 10x5)

In [79]:
# Define the Stiefel manifold
stiefel_manifold = geoopt.manifolds.Stiefel()

# Initialize a point on the Stiefel manifold
X_init = torch.randn(dim_n, dim_p)
X_init, _ = torch.qr(X_init)  # Orthogonalize to be on the manifold
X_sgd = geoopt.ManifoldParameter(X_init.clone(), manifold=stiefel_manifold)
X_adam = geoopt.ManifoldParameter(X_init.clone(), manifold=stiefel_manifold)

# Define the loss function
def quadratic_loss(X):
    return torch.norm(A @ X - B, p='fro') ** 2

# Optimizers
lr = 0.01  # Learning rate
sgd_optimizer = geoopt.optim.RiemannianSGD([X_sgd], lr=lr)
adam_optimizer = geoopt.optim.RiemannianAdam([X_adam], lr=lr)

# Training loop settings
max_iters = 5000
convergence_threshold = 1e-6
patience = 10

In [85]:
x = X_sgd
g = torch.randn(X_sgd.shape)

In [89]:
rg = g - x @ g.T.conj() @ x

rgp = x.T.conj() @ rg

rgp.T.conj() @ rgp - rg.T.conj() @ rg

tensor([[ -9.0651,  -7.0870,  -6.0973,   1.0166,  -3.1239],
        [ -7.0870,  -8.2470,  -8.4099,  -0.8544,  -4.6700],
        [ -6.0973,  -8.4099, -14.3863,  -6.2188,  -7.9235],
        [  1.0166,  -0.8544,  -6.2188,  -9.1306,  -4.9703],
        [ -3.1239,  -4.6700,  -7.9235,  -4.9703,  -5.2943]],
       grad_fn=<SubBackward0>)

In [3]:
# Training loop for both optimizers
def train(optimizer, X, name):
    loss_history = []
    best_loss = float('inf')
    no_improve_count = 0
    
    for i in range(max_iters):
        optimizer.zero_grad()
        loss = quadratic_loss(X)
        loss.backward()
        optimizer.step()
        loss_history.append(loss.item())
        
        # Check convergence criteria
        if loss.item() < best_loss - convergence_threshold:
            best_loss = loss.item()
            no_improve_count = 0
        else:
            no_improve_count += 1
        
        if no_improve_count >= patience:
            print(f"{name} converged at iteration {i+1} with loss {loss.item():.6f}")
            break
        
    return loss_history

# Run the optimization
loss_sgd = train(sgd_optimizer, X_sgd, "Riemannian SGD")
loss_adam = train(adam_optimizer, X_adam, "Riemannian Adam")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x5 and 10x5)