<a href="https://colab.research.google.com/github/Fantiflex/Modular_Manifold_Muon/blob/main/global_LBFGS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Setup**



In [None]:
import math
import torch


**L-BFGS optimizer class setup**

In [None]:
class ManifoldLBFGS:
    """
    Riemannian L-BFGS on the Stiefel manifold with:
      - Euclidean metric (Frobenius inner product)
      - retraction delegated to ``retract_stiefel_shape_preserving``
      - vector transport by projection (``transport_proj``)

    Typical usage:
        opt = ManifoldLBFGS(eta=0.1, history=10)
        # iteration k
        W = opt.step(W, Gk)           # take a step, returns W_{k+1}
        ...
        Gkp1 = ...                    # gradient at the new point (supplied by the caller)
        opt.update(Gkp1)              # update the (s_k, y_k) history
    """

    def __init__(self, eta=0.1, history=10, eps_curv=1e-12, c0=1e-4, c1=1.0, c2=None):
        """
        Args:
            eta: length of the (normalized) step taken by ``step``.
            history: maximum number of (s, y) pairs kept in memory.
            eps_curv: accepted for signature compatibility; not read anywhere
                in this class.
            c0, c1, c2: constants of the cautious-update threshold
                ``omega = min(c0, c1 * ||g||**c2)`` used in ``update`` and to
                clamp the initial scaling in ``_gamma_from_memory``.
                ``c2`` defaults to ``1 / (2 * history + 3)``.
        """
        self.eta = eta
        self.history = history
        self.c0, self.c1 = c0, c1
        self.m = history                # alias of ``history`` (kept for external readers)
        self.c2 = (1.0 / (2 * history + 3)) if c2 is None else c2
        self.pairs = []                 # list of (s_i, y_i) expressed in the current tangent space
        # (W_old, G_old_riem, step_vec, W_new) handed from step() to update()
        self._pending = None

    def two_loop(self, G, pairs, gamma, eps=1e-16):
        """
        L-BFGS two-loop recursion with initial inverse Hessian ``gamma * I``.

        Args:
            G: gradient tensor (same shape as every s_i / y_i).
            pairs: list of (s_i, y_i), oldest first.
            gamma: scalar scaling of the initial inverse Hessian.
            eps: added to each curvature ``<y_i, s_i>`` before dividing.

        Returns:
            The quasi-Newton descent direction ``-H G``.
        """
        q = G.clone()
        # <y_i, s_i> + eps is needed in both loops; compute it once per pair.
        denoms = [(y.flatten() @ s.flatten()).item() + eps for s, y in pairs]

        alphas = []
        for (s, y), denom in zip(reversed(pairs), reversed(denoms)):
            a = (s.flatten() @ q.flatten()).item() / denom
            alphas.append(a)
            q = q - a * y

        r = gamma * q
        # alphas was filled newest-first, so reverse it to walk oldest-first.
        for (s, y), denom, a in zip(pairs, denoms, reversed(alphas)):
            b = (y.flatten() @ r.flatten()).item() / denom
            r = r + (a - b) * s
        return -r

    @torch.no_grad()
    def step(self, W, G):
        """
        Take one RL-BFGS step ("Mannel" variant, per the original author's note).

        No line search is performed: the two-loop direction is normalized to
        unit Frobenius norm and scaled by ``self.eta`` before retraction.
        Gradient tracking is disabled because the retraction already returns a
        tensor without autograd history, so any graph built here would be unused.

        Args:
            W: current point.
            G: ambient (Euclidean) gradient at W, same shape as W.

        Returns:
            The new point W_{k+1}.
        """
        Griem = tangent_proj(W, G)
        gamma = self._gamma_from_memory(default=1.0)
        P = self.two_loop(Griem, self.pairs, gamma)

        p_norm = P.norm()
        if p_norm == 0:
            step_vec = self.eta * P
        else:
            step_vec = self.eta * P / (p_norm + 1e-16)
        W_new = retract_stiefel_shape_preserving(W, step_vec)

        # Keep what update() needs to build the next (s, y) pair.
        self._pending = (W.detach(), Griem.detach(), step_vec.detach(), W_new.detach())
        return W_new

    @torch.no_grad()
    def update(self, G_new):
        """
        Feed the gradient at the point returned by the last ``step`` call.

        Builds (s, y) in the tangent space at the new point, stores the pair
        only if it passes the cautious curvature test, then transports the
        whole memory to the new tangent space. Does nothing if ``step`` has
        not been called since the previous ``update``.

        Args:
            G_new: ambient gradient evaluated at the new point.
        """
        if self._pending is None:
            return
        W_old, G_old_riem, step_vec, W_new = self._pending

        # Projected transport of the step and of the old gradient.
        s = transport_proj(W_old, W_new, step_vec)
        G_new_riem = tangent_proj(W_new, G_new)
        y = G_new_riem - transport_proj(W_old, W_new, G_old_riem)

        omega = min(self.c0, self.c1 * (G_old_riem.norm().item() ** self.c2))
        sy = (s.flatten() @ y.flatten()).item()
        if sy >= omega * max(s.norm().pow(2).item(), y.norm().pow(2).item()):
            self.pairs.append((s.detach(), y.detach()))
            if len(self.pairs) > self.history:
                self.pairs.pop(0)

        # Move the whole memory into the new tangent space.
        self.pairs = [
            (transport_proj(W_old, W_new, si), transport_proj(W_old, W_new, yi))
            for (si, yi) in self.pairs
        ]
        self._pending = None

    def _gamma_from_memory(self, default=1.0):
        """
        Scaling of the initial inverse Hessian from the most recent pair.

        Uses ``<s, y> / <y, y>`` when the curvature is positive, otherwise
        ``||s|| / ||y||``, otherwise ``default``; the value is then clamped to
        ``[omega, 1 / omega]`` with ``omega = min(c0, c1)``.
        """
        if not self.pairs:
            return default
        s, y = self.pairs[-1]
        sty = (s.flatten() @ y.flatten()).item()
        yy = (y.flatten() @ y.flatten()).item()
        if sty > 0:
            g = sty / (yy + 1e-16)
        elif yy > 0:
            g = (s.norm() / (y.norm() + 1e-16)).item()
        else:
            g = default
        omega = min(self.c0, self.c1 * 1.0 ** self.c2)
        # omega == 0 (c0 or c1 set to 0) means "no upper clamp"; avoid 1/0.
        upper = 1.0 / omega if omega != 0 else float("inf")
        return max(omega, min(g, upper))

**Useful Riemannian geometry functions**

In [None]:
def _as_matrix(X):
    """
    Flatten X to 2D by merging every dimension after the first.

    A tensor that is already 2D is returned as-is (same object).

    Returns:
        (matrix, original_shape) so the caller can restore the shape later.
    """
    shape = X.shape
    matrix = X if X.ndim == 2 else X.reshape(shape[0], -1)
    return matrix, shape


def _from_matrix(X_mat, orig_shape):
    """
    Inverse of the flattening step: give a 2D matrix its original shape back.

    The element count must match ``orig_shape``; the leading dimension is
    expected to be the same as when the tensor was flattened.
    """
    restored = X_mat.reshape(orig_shape)
    return restored


def _sym(X):
    """Symmetrize a square matrix, i.e. compute (X + X^T) / 2."""
    transposed = X.T
    return (X + transposed) * 0.5

def tangent_proj(W, Z):
    """
    Project an ambient tensor Z onto the tangent space of the Stiefel manifold at W.

    Tangent vectors D at W satisfy W^T D + D^T W = 0, and the projection is
    Proj(Z) = Z - W * Sym(W^T Z).

    W and Z may be 2D (e.g. Linear weights) or higher-rank tensors (e.g. Conv2d
    weights). Every dimension after the first is flattened, so
    (out_c, in_c, kH, kW) is handled as (out_c, in_c*kH*kW); the result is
    reshaped back to W's shape.
    """
    target_shape = W.shape
    w2d = W.reshape(W.shape[0], -1)
    z2d = Z.reshape(Z.shape[0], -1)

    rows, cols = w2d.shape
    is_wide = rows < cols
    if is_wide:
        # Wide matrices are processed in the transposed (tall) layout.
        w2d, z2d = w2d.T, z2d.T

    gram = w2d.T @ z2d
    projected = z2d - w2d @ (0.5 * (gram + gram.T))

    if is_wide:
        projected = projected.T

    # Back to the shape of the weight.
    return projected.reshape(target_shape)


def polar_retraction(X):
    """
    Return the orthogonal polar factor of X, used to retract an ambient matrix
    back onto the Stiefel manifold.

    The factor is computed as ``U @ Vh`` from the reduced SVD ``X = U S Vh``.
    If the installed torch exposes ``torch.linalg.polar`` it is tried first;
    NOTE(review): that function is not part of the documented ``torch.linalg``
    API at the time of writing, so the SVD path is the one that normally runs.

    Args:
        X: matrix (or batch of matrices) accepted by ``torch.linalg.svd``.

    Returns:
        Tensor of the same shape as X.
    """
    # Probe for the optional API instead of raising/catching on every call.
    polar = getattr(torch.linalg, "polar", None)
    if polar is not None:
        try:
            U, _ = polar(X)
            return U
        except (RuntimeError, TypeError, ValueError):
            # Best effort only: fall through to the SVD formulation.
            pass

    # X = U S V^T  =>  polar(X) = U V^T
    U, _, Vh = torch.linalg.svd(X, full_matrices=False)
    return U @ Vh

def transport_proj(W_old, W_new, Xi):
    """
    Move a tangent vector from T_{W_old} to T_{W_new} by re-projection:
    T(Xi) = Proj_{T_{W_new}}(Xi).

    This transport is not isometric; it is chosen for simplicity and because it
    pairs naturally with a projection-style retraction. ``W_old`` is accepted
    for interface symmetry and is not used by the computation.
    """
    transported = tangent_proj(W_new, Xi)
    return transported

def frob_inner(X, Y):
    """
    Frobenius inner product <X, Y> = trace(X^T Y) = sum of elementwise products.

    Works for same-shaped tensors of any rank (2D matrices as well as
    conv-shaped weights); every axis is contracted, so the result is always a
    0-dim tensor.

    Args:
        X, Y: tensors with identical shape and dtype.

    Returns:
        0-dim tensor holding the inner product.

    Raises:
        RuntimeError: if the shapes differ.
    """
    if X.shape != Y.shape:
        raise RuntimeError(
            f"frob_inner expects same-shaped tensors, got {tuple(X.shape)} and {tuple(Y.shape)}"
        )
    # Flattening contracts over every axis regardless of rank.
    return torch.dot(X.reshape(-1), Y.reshape(-1))


def retract_qr(X, Xi):  # R_X(Xi)
    """
    QR-based retraction: return the Q factor of ``X + Xi`` with column signs
    fixed so that the matching R factor has a non-negative diagonal.

    A zero diagonal entry of R (rank-deficient ``X + Xi``) is treated as
    positive; flipping by ``sign(0) == 0`` would zero out that column and the
    result would no longer have orthonormal columns.

    Args:
        X: 2D base point.
        Xi: 2D update, same shape as X.

    Returns:
        Sign-normalized Q factor. With ``mode='reduced'`` a wide input
        (rows < cols) yields a square (rows x rows) Q, so the shape matches
        X only for tall or square inputs.
    """
    Y = X + Xi
    Q, R = torch.linalg.qr(Y, mode='reduced')
    d = torch.sign(torch.diag(R))
    d = torch.where(d == 0, torch.ones_like(d), d)
    # Scaling column j by d[j] is the same as right-multiplying by diag(d).
    return Q * d

def project_stiefel_keep_shape_qr(W):
    """
    Map W to a matrix with orthonormal columns (tall) or rows (wide) using a
    sign-normalized QR factorization, preserving the original tensor shape.

    For 2D weights this is the usual QR projection. Higher-rank weights
    (e.g. out_c x in_c x kH x kW) are flattened to (m, n) by merging every
    dimension after the first, projected, then reshaped back. Wide matrices
    are factorized in transposed form and transposed back.

    A zero diagonal entry of R (rank-deficient W) is treated as positive so
    that no column of Q is zeroed by ``sign(0) == 0``.

    Args:
        W: weight tensor with at least one dimension.

    Returns:
        Tensor of the same shape as W.
    """
    orig_shape = W.shape
    W_mat = W if W.ndim == 2 else W.reshape(W.shape[0], -1)  # (m, n)
    m, n = W_mat.shape

    is_wide = m < n
    # QR wants the tall orientation; wide inputs are transposed first.
    A = W_mat.T if is_wide else W_mat
    Q, R = torch.linalg.qr(A, mode='reduced')
    d = torch.sign(torch.diag(R))
    d = torch.where(d == 0, torch.ones_like(d), d)
    Qd = Q * d  # same as Q @ diag(d)

    W_proj = Qd.T if is_wide else Qd
    return W_proj.reshape(orig_shape)


# _retract ensures that W stays a Stiefel
@torch.no_grad()
def retract_stiefel_shape_preserving(X, Xi):
    """
    Retract ``X + Xi`` onto the Stiefel manifold via the polar factor ``U @ Vh``
    of its reduced SVD, returning a tensor with the same shape as X.

    Higher-rank inputs (e.g. conv weights) are flattened to
    (X.shape[0], -1) before the SVD, which is the same matrix view that
    ``tangent_proj`` uses; running the SVD directly on an N-D tensor would
    instead factorize each trailing 2D slice independently.

    Args:
        X: current point.
        Xi: update, same shape as X.

    Returns:
        Tensor of X's shape whose flattened matrix has orthonormal columns
        (tall) or rows (wide) when ``X + Xi`` has full rank.
    """
    Y = X + Xi
    Y_mat = Y.reshape(Y.shape[0], -1)
    U, _, Vh = torch.linalg.svd(Y_mat, full_matrices=False)
    return (U @ Vh).reshape(Y.shape)


**Original version from Jeremy Bernstein's paper**

In [None]:
# torch.no_grad() disables gradient tracking inside the function: this is an
# optimizer update, so no autograd graph is needed for these operations.
@torch.no_grad()
def manifold_muon(W, G, eta=0.1, alpha=0.01, steps=100, tol=1e-6):
    '''
    One manifold-Muon update using dual ascent to find the update direction.

    W is a 2D pytorch matrix, current weights
    G is a 2D pytorch matrix, loss gradient
    eta is the step size for the primal problem
    alpha is the step size for the dual problem
    steps is the number of steps for the dual problem
    tol is the tolerance for the stopping criterion

    Returns the updated weights in the same orientation as the input W.

    NOTE(review): ``msign`` is defined outside this block; it is used here both
    to build the direction and as the final retraction - confirm its exact
    semantics against its definition.
    '''
    # Ensure that W and G are both tall matrices (more stable for Stiefel)
    should_transpose = W.shape[0] < W.shape[1]
    if should_transpose:
        W, G = W.T, G.T

    # Initialize the dual variable
    Lambda = -0.25 * (W.T @ G + G.T @ W)

    # Ascend on the dual problem to find the update direction A
    A = None
    for step in range(steps):
        # Update the candidate direction A
        A = msign(G + 2 * W @ Lambda)
        # Measure deviation of A from the tangent space:
        H = W.T @ A + A.T @ W
        # Check the stopping criterion (RMS of the entries of H)
        if torch.norm(H) / math.sqrt(H.numel()) < tol:
            break
        # Update the dual variable with a linearly decaying step size
        Lambda -= alpha * (1 - step / steps) * H

    if A is None:
        # steps <= 0: the loop never ran, so build the direction from the
        # initial dual variable instead of failing on an unbound name.
        A = msign(G + 2 * W @ Lambda)

    # Descend on the primal problem, then retract to the manifold
    new_W = msign(W - eta * A)
    # Restore the shape of the solution and return
    return new_W.T if should_transpose else new_W

**Manifold Muon adaptable to other optimizer instances**

In [None]:
@torch.no_grad()
def manifold_muon_general(W, G, eta=0.1, alpha=0.01, steps=100, tol=1e-6, *, opt=None):
    """
    Drop-in wrapper that replaces the Muon direction with a Riemannian L-BFGS step.

    Args:
        W (torch.Tensor): current point, ideally on Stiefel (W^T W ≈ I).
        G (torch.Tensor): raw gradient d(loss)/dW, same shape as W.
        eta (float): step size; overrides opt.eta for this call.
        alpha, steps, tol: kept for signature compatibility (ignored here).
        opt (ManifoldLBFGS | None): persistent optimizer instance. Pass one
            instance per weight tensor so that its (s, y) history is kept
            between calls. When omitted, a fresh ``ManifoldLBFGS`` is created
            for this call only, so no history is carried over and no state is
            shared between unrelated weights.

    Returns:
        torch.Tensor: new W on the Stiefel manifold, same shape as the input.

    Notes:
        - This function ONLY performs the STEP. After you recompute the gradient
          at the returned W, call opt.update(G_new) once to feed (s_k, y_k) to L-BFGS.
        - No transposition is done here: W and G are passed to
          ManifoldLBFGS.step() unchanged.
        - When eta == 0 or G is identically zero, W is only projected with
          project_stiefel_keep_shape_qr and the optimizer is not touched.
    """
    assert W.shape == G.shape, f"W and G must have the same shape, but got W.shape={W.shape} and G.shape={G.shape}"

    if opt is None:
        # Built per call rather than once at definition time, so the fallback
        # never leaks state across calls or across different weight tensors.
        opt = ManifoldLBFGS(eta=eta, history=10)

    # INIT / projection only
    if eta == 0.0 or torch.all(G == 0):
        W_new = project_stiefel_keep_shape_qr(W)            # keeps the input shape
        assert W_new.shape == W.shape
        return W_new

    # Set per-step step size (to schedule eta externally, pass it on each call)
    opt.eta = eta

    # Take one quasi-Newton step on the manifold (retraction included inside .step)
    return opt.step(W, G)