<a href="https://colab.research.google.com/github/ashegde/notebooks/blob/main/sampling_the_stiefel_manifold_with_newton_schulz.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
In this notebook, we sample uniform random m-by-k matrices X from the Stiefel manifold V_{k,m} (X'X = I_{k-by-k}).
Note, this implies that (a) the k m-dimensional columns of X are orthonormal and (b) m >= k -- i.e.,
nrows >= ncols.

The function below implements the following method based on Theorem 2.2.1 of

Chikuse, Y. (2003). Statistics on special manifolds (Vol. 174). Springer Science & Business Media.

Part (iii) of this theorem states that if X is uniformly distributed on V_{k,m}, it can be expressed as X = Z(Z'Z)^{-0.5} where Z
is an m-by-k matrix with entries independent and identically distributed as N(0,1).

If we take the thin SVD of Z, Z = USV', then (Z'Z)^{-0.5} = V S^{-1} V', so X = Z(Z'Z)^{-0.5} = UV'. Thus, we can sample X
by generating Z and then using Newton-Schulz iteration to drive its singular values to unity.
Specifically, we use the Newton-Schulz iteration suggested in,

Bernstein, J., & Newhouse, L. (2024). Modular Duality in Deep Learning. arXiv preprint arXiv:2410.21265.
"""
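
The identity X = Z(Z'Z)^{-0.5} = UV' used above can be checked numerically. The following is a minimal sketch, assuming a small 8-by-3 Gaussian Z and a fixed seed (both chosen only for illustration); it builds X once through the inverse square root of Z'Z and once through the thin SVD, and both checks should print True.

```python
import numpy as np

rng = np.random.default_rng(0)  # illustrative seed, not from the notebook
m, k = 8, 3                     # illustrative sizes with m >= k
Z = rng.standard_normal((m, k))

# Route 1: X = Z (Z'Z)^{-1/2}; the inverse square root comes from the
# eigendecomposition of the symmetric positive-definite matrix Z'Z.
evals, evecs = np.linalg.eigh(Z.T @ Z)
inv_sqrt = evecs @ np.diag(evals ** -0.5) @ evecs.T
X_polar = Z @ inv_sqrt

# Route 2: X = U V' from the thin SVD Z = U S V'.
U, S, Vt = np.linalg.svd(Z, full_matrices=False)
X_svd = U @ Vt

print(np.allclose(X_polar, X_svd))               # the two routes agree
print(np.allclose(X_svd.T @ X_svd, np.eye(k)))   # the columns are orthonormal
```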

In [None]:
from typing import Callable

import numpy as np
import matplotlib.pyplot as plt

In [None]:

def newton_schulz(A: np.ndarray, n_steps: int = 15, eps: float = 1e-9) -> np.ndarray:
    """
    Rectangular Newton-Schulz Iteration

    iterate: X = 1.5 X - 0.5 XX'X
    """
    # Normalize by the Frobenius norm so that every singular value of X is at
    # most 1, which lies inside (0, sqrt(3)), the interval on which the
    # iteration drives singular values to 1.
    X = A / (eps + np.linalg.norm(A, ord='fro'))

    for _ in range(n_steps):
        # X'X is k-by-k, so this grouping avoids forming the m-by-m product XX'
        X = 1.5 * X - 0.5 * X @ (X.T @ X)
    return X

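def rstiefel(gain: float = 1.0, n_steps: int = 15) -> Callable[[np.ndarray], np.ndarray]:
    """Return an initializer that draws gain * X, with X a Newton-Schulz approximation of a uniform Stiefel sample."""

    def initializer(a: np.ndarray) -> np.ndarray: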
        if a.ndim < 2:
            raise ValueError("Only arrays with 2 or more dimensions are supported")

        if a.size == 0:
            # do nothing
            return a

        # flattened dims
        nrows = a.shape[0]
        ncols = a.size // nrows
        flattened = np.random.normal(size=(nrows, ncols), scale=1.0, loc=0.0)
        if nrows < ncols: # ensures m >= k
            flattened = flattened.T

        # orthogonalize with Newton-Schulz (drives all singular values toward 1)
        x = newton_schulz(flattened, n_steps)

        if nrows < ncols:
            x = x.T

        # restore the input shape (a no-op for 2-D inputs)
        return gain * x.reshape(a.shape)

    return initializer
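
Why the iteration in newton_schulz orthogonalizes its input: writing X = U diag(s) V', one step maps X to U diag(1.5 s - 0.5 s^3) V', so each singular value evolves independently under the scalar map s -> 1.5 s - 0.5 s^3. The sketch below applies that map to a few starting values in (0, sqrt(3)). The helper name ns_scalar_step and the starting values are illustrative choices, not part of the notebook.

```python
import numpy as np

def ns_scalar_step(s):
    """One Newton-Schulz step as seen by a single singular value."""
    return 1.5 * s - 0.5 * s**3

# Illustrative starting singular values, all inside (0, sqrt(3)).
s = np.array([0.05, 0.3, 1.0, 1.5])
for _ in range(15):
    s = ns_scalar_step(s)

print(s)  # every entry should be very close to 1
```

Small singular values grow by a factor of roughly 1.5 per step until they approach 1, so inputs whose normalized singular values are much smaller than these may need a larger n_steps.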

In [None]:
init_fn = rstiefel()

tmp = np.zeros((100, 10))
w = init_fn(tmp)

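# w is semi-orthogonal: w.T @ w is (approximately) the 10-by-10 identity,
# while w @ w.T is a rank-10 projection rather than the identity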
fig, axs = plt.subplots(1, 2, layout='constrained', figsize=(8, 4))
axs[0].imshow(w.T @ w)
axs[0].set_title('w.T @ w')
axs[1].imshow(w @ w.T)
axs[1].set_title('w @ w.T')

plt.show()
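
The plots can be complemented with a numerical check. The following is a self-contained sketch, assuming a fixed seed and a compact restatement of the iteration (newton_schulz_sketch is a hypothetical stand-in for the notebook's newton_schulz, not a replacement for it); it measures how far the result is from orthonormal columns and from the SVD-based factor UV'.

```python
import numpy as np

def newton_schulz_sketch(A, n_steps=15):
    """Compact restatement of the iteration above, for a standalone check."""
    X = A / np.linalg.norm(A)  # Frobenius norm by default for 2-D arrays
    for _ in range(n_steps):
        X = 1.5 * X - 0.5 * X @ (X.T @ X)
    return X

rng = np.random.default_rng(0)  # illustrative seed
Z = rng.standard_normal((100, 10))
W = newton_schulz_sketch(Z)

# Reference: the factor U V' computed from the thin SVD of Z.
U, _, Vt = np.linalg.svd(Z, full_matrices=False)

orth_err = np.abs(W.T @ W - np.eye(10)).max()  # distance from W'W = I
polar_err = np.abs(W - U @ Vt).max()           # distance from U V'
proj = W @ W.T                                 # should satisfy proj @ proj = proj

print(orth_err, polar_err, np.allclose(proj @ proj, proj))
```

Both errors should be near machine precision for a matrix of this size; the idempotence check reflects that w @ w.T projects onto the column space of w.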