<a href="https://colab.research.google.com/github/Fantiflex/MuOn-optimizer/blob/main/Optimizers_project_182.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Setup**



In [1]:
import math
import torch


**L-BFGS optimizer class setup**

In [3]:
class ManifoldLBFGS:
    """
    Riemannian L-BFGS on the Stiefel manifold using:
      - the Euclidean metric (Frobenius inner product),
      - a polar retraction (``polar_retraction`` or ``msign``),
      - vector transport by projection.

    Sign convention used throughout the class: ``two_loops`` returns
    ``d = H_k g`` (the inverse-Hessian approximation applied to the gradient)
    and ``step`` moves along ``-eta * d``.

    Typical usage:
        opt = ManifoldLBFGS(eta=0.1, history=10)
        # iteration k
        W = opt.step(W, Gk)      # take one step, returns W_{k+1}
        ...
        Gkp1 = ...               # gradient at the new point (supplied by the caller)
        opt.update(Gkp1)         # push (s_k, y_k) into the history
    """

    def __init__(self, eta=0.1, history=10, eps_curv=1e-12, use_polar_impl=True):
        """
        Args:
            eta: step size for the primal problem.
            history: memory length for L-BFGS (number of (s, y) pairs kept).
            eps_curv: curvature threshold; pairs with <s, y> <= eps_curv are rejected.
            use_polar_impl: if True retract with ``polar_retraction``, otherwise with ``msign``.
        """
        self.eta = eta
        self.m = history
        self.eps_curv = eps_curv
        # Rolling buffers, oldest first: s_i, y_i and rho_i = 1 / <s_i, y_i>.
        self.S, self.Y, self.RHO = [], [], []
        # Cache written by step() and consumed by update().
        self.last = None
        self.use_polar_impl = use_polar_impl

    def _retract(self, X):
        """Map an ambient matrix back onto the manifold with the configured retraction."""
        # NOTE(review): `msign` is not defined in this part of the file; it is
        # assumed to be provided elsewhere in the notebook.
        return polar_retraction(X) if self.use_polar_impl else msign(X)

    def two_loops(self, q):
        """
        L-BFGS two-loop recursion: apply the approximate inverse Hessian H_k to q.

        Args:
            q: current Riemannian gradient (already projected onto the tangent space).

        Returns:
            H_k q. ``step`` subtracts ``eta`` times this value, so the result
            must NOT be negated here. Uses the stored (s_i, y_i) pairs with
            rho_i = 1 / <s_i, y_i>.
        """
        if not self.S:
            # No curvature information yet: H_0 = I, hence H_0 q = q and
            # step() performs plain steepest descent (W - eta * q).
            return q

        # ===== Backward loop (newest pair first) =====
        # alpha_i = rho_i <s_i, q>, then q <- q - alpha_i y_i.
        alpha = []
        for s, y, rho in zip(reversed(self.S), reversed(self.Y), reversed(self.RHO)):
            a = rho * frob_inner(s, q)
            alpha.append(a)
            q = q - a * y

        # ===== Initial scaling: H0 = gamma * I =====
        # gamma is chosen so that H0 y_last ~ s_last (matches the most recent curvature).
        s_last, y_last = self.S[-1], self.Y[-1]
        sy = frob_inner(s_last, y_last)  # positive thanks to the safeguard in update()
        yy = frob_inner(y_last, y_last)
        gamma = (sy / yy) if yy > 0 else 1.0  # fall back to the identity if y_last is zero
        r = gamma * q

        # ===== Forward loop (oldest pair first) =====
        # alpha was filled newest-first, so it is consumed in reverse.
        # beta_i = rho_i <y_i, r>, then r <- r + s_i (alpha_i - beta_i).
        for s, y, rho, a in zip(self.S, self.Y, self.RHO, reversed(alpha)):
            beta = rho * frob_inner(y, r)
            r = r + s * (a - beta)

        return r

    def step(self, W, G):
        """
        Take a single quasi-Newton step from W using the raw gradient G (same shape as W).

        Pipeline:
          1) If W is wide (fewer rows than columns), transpose W and G to work in tall shape.
          2) Project the raw gradient onto the tangent space: g = Proj_TW(G).
          3) Build the quasi-Newton direction d = H_k g via the two-loop recursion.
          4) Project d again for numerical safety: d = Proj_TW(d).
          5) Take the ambient step and retract: W_new = Retr(W - eta * d).

        Returns:
            W_new with the same shape/orientation as the input W.
        """
        assert W.shape == G.shape, f"W and G must have the same shape, but got W.shape={W.shape} and G.shape={G.shape}"

        should_transpose = W.shape[0] < W.shape[1]
        if should_transpose:
            W = W.T
            G = G.T

        # Riemannian gradient (tangent at W).
        g = tangent_proj(W, G)

        # Quasi-Newton direction from the L-BFGS memory, re-projected in case
        # of small numerical drift.
        d = self.two_loops(g)
        d = tangent_proj(W, d)

        # Ambient move followed by a retraction back onto the manifold.
        W_new = self._retract(W - self.eta * d)

        # Cache what update() needs to form (s_k, y_k) once the new gradient is known.
        # W and W_new are stored in the (possibly transposed) working orientation.
        self.last = {
            "W": W,
            "W_new": W_new,
            "g": g,
            "d": d,
            "should_transpose": should_transpose,
        }
        return W_new.T if should_transpose else W_new

    def update(self, G_new):
        """
        Update the L-BFGS history once the raw gradient at W_new is known.

        Pipeline:
          1) Read the cache left by the last call to step().
          2) Build the new Riemannian gradient: g_new = Proj_{T_{W_new}}(G_new).
          3) Transport the previous displacement and gradient to the new tangent space:
                s_k = Transport(-eta * d)
                y_k = g_new - Transport(g)
          4) Curvature test: require a finite <s_k, y_k> > eps_curv.
          5) Push (s_k, y_k, 1/<s_k, y_k>) into the rolling buffers, dropping the oldest
             entry when the memory is full.

        Args:
            G_new: raw gradient at the point returned by step(), in the caller's orientation.

        Returns:
            True if the pair was stored; False if it was rejected by the curvature
            safeguard or if the memory length is not positive.
        """
        assert self.last is not None, "update() must be called after step()"

        W = self.last["W"]
        W_new = self.last["W_new"]
        d = self.last["d"]
        g = self.last["g"]

        # step() worked in tall orientation; bring G_new into the same orientation.
        if self.last["should_transpose"]:
            G_new = G_new.T

        # New Riemannian gradient.
        g_new = tangent_proj(W_new, G_new)

        # Form (s_k, y_k) in the new tangent space.
        s = transport_by_projection(W, W_new, -self.eta * d)  # displacement
        y = g_new - transport_by_projection(W, W_new, g)      # gradient change

        # Curvature condition <s, y> > 0 keeps the inverse-Hessian approximation
        # positive definite.
        sy = frob_inner(s, y)
        if not torch.isfinite(sy) or sy <= self.eps_curv:
            self.last = None
            return False

        # With no memory there is nothing to store (and nothing to pop).
        if self.m <= 0:
            self.last = None
            return False

        # Maintain the limited memory (FIFO).
        if len(self.S) == self.m:
            self.S.pop(0)
            self.Y.pop(0)
            self.RHO.pop(0)
        self.S.append(s.detach().clone())
        self.Y.append(y.detach().clone())
        self.RHO.append(1.0 / sy)

        self.last = None
        return True

**Useful Riemannian geometry functions**

In [None]:
def _sym(X):
    """Symmetric part of a square matrix, i.e. half of X plus its transpose."""
    doubled = X + X.T
    return 0.5 * doubled

def tangent_proj(W, Z):
    """
    Project an ambient matrix Z onto the tangent space of the Stiefel manifold at W.

    Tangent space condition at W: W^T Δ + Δ^T W = 0 (W^T Δ is skew-symmetric).
    Projection formula: Proj(Z) = Z - W Sym(W^T Z), with Sym(A) = (A + A^T) / 2.

    A wide W (fewer rows than columns) is handled by applying the same formula
    to the transposed pair and transposing the result back.

    Args:
        W: 2-D tensor of shape (n, p); the formula assumes orthonormal columns
           (orthonormal rows in the wide case).
        Z: 2-D tensor with the same shape as W.

    Returns:
        Tensor with the same shape as Z.
    """
    n, p = W.shape
    if n < p:
        # Wide: work in the transposed (tall) orientation, then transpose back.
        return tangent_proj(W.T, Z.T).T

    # Tall: standard formula. The p x p product is computed once and reused.
    WtZ = W.T @ Z
    return Z - W @ (0.5 * (WtZ + WtZ.T))


def polar_retraction(X):
    """
    Retract an ambient matrix back to the Stiefel manifold using the polar factor.

    For a thin SVD X = U Σ V^T the polar factor is U V^T; for full-column-rank X
    it is the matrix with orthonormal columns closest to X in Frobenius norm.

    The factor is computed directly from the SVD: torch.linalg does not provide
    a dedicated polar-decomposition routine, and calling SVD unguarded lets any
    genuine failure (e.g. non-finite input) surface instead of being masked.

    Args:
        X: 2-D tensor.

    Returns:
        Tensor with the same shape as X.
    """
    U, _, Vh = torch.linalg.svd(X, full_matrices=False)
    return U @ Vh

def transport_by_projection(W_old, W_new, Xi):
    """
    Move a tangent vector Xi to the tangent space at W_new by re-projecting it.

    T(Xi) = Proj_{T_{W_new}}(Xi). This transport is not isometric. W_old is
    accepted for signature symmetry only; the projection depends solely on the
    destination point W_new.
    """
    transported = tangent_proj(W_new, Xi)
    return transported

def frob_inner(X, Y):
    """
    Frobenius inner product <X, Y> = trace(X^T Y) of two same-shaped matrices.

    Returned as a 0-dimensional tensor. It supplies the curvature scalars
    (<s, y>), the initial scaling and the two-loop recursion coefficients.
    """
    # Contract both axes of X against the matching axes of Y.
    both_axes = [0, 1]
    return torch.tensordot(X, Y, dims=(both_axes, both_axes))

**Original version from Jeremy Bernstein's paper**

In [None]:
@torch.no_grad()  # pure tensor update: no autograd graph is recorded inside
def manifold_muon(W, G, eta=0.1, alpha=0.01, steps=100, tol=1e-6):
    """
    One manifold Muon step: dual ascent to find the update direction, then a
    primal step followed by a retraction.

    Args:
        W: 2-D PyTorch matrix, current weights.
        G: 2-D PyTorch matrix, loss gradient (same shape as W).
        eta: step size for the primal problem.
        alpha: step size for the dual problem.
        steps: maximum number of dual ascent iterations.
        tol: tolerance for the stopping criterion (RMS of the tangency residual).

    Returns:
        The updated weights, in the same orientation as the input W.
    """
    # NOTE(review): `msign` is not defined in this part of the file; it is
    # assumed to be provided elsewhere in the notebook.

    # Work with tall matrices (more stable for Stiefel); undo this on return.
    should_transpose = W.shape[0] < W.shape[1]
    if should_transpose:
        W = W.T
        G = G.T

    # Initialize the dual variable.
    Lambda = -0.25 * (W.T @ G + G.T @ W)

    # Ascend on the dual problem to find the update direction A.
    A = None
    for step in range(steps):
        # Candidate direction for the current dual variable.
        A = msign(G + 2 * W @ Lambda)
        # Deviation of A from the tangent space (zero when A is tangent at W).
        H = W.T @ A + A.T @ W
        # Stop once the RMS of the residual entries is below tol.
        if torch.norm(H) / math.sqrt(H.numel()) < tol:
            break
        # Dual update with a linearly decaying step size.
        Lambda -= alpha * (1 - step / steps) * H

    if A is None:
        # steps <= 0: no dual iteration ran, so use the direction given by the
        # initial dual variable instead of failing on an unbound A.
        A = msign(G + 2 * W @ Lambda)

    # Descend on the primal problem, then retract to the manifold.
    new_W = msign(W - eta * A)

    # Restore the orientation of the solution.
    return new_W.T if should_transpose else new_W

**Manifold MuOn adaptable for other optimizer instances**

In [None]:
@torch.no_grad()
def manifold_muon_general(W, G, eta=0.1, alpha=0.01, steps=100, tol=1e-6, *, opt=None):
    """
    Drop-in wrapper that replaces the Muon direction with a Riemannian L-BFGS step.

    Args:
        W (torch.Tensor): current point (n x p), ideally on Stiefel (W^T W ≈ I).
        G (torch.Tensor): raw gradient d(loss)/dW, same shape as W.
        eta (float): step size; overwrites ``opt.eta`` before stepping.
        alpha, steps, tol: kept for signature compatibility (ignored here).
        opt: persistent optimizer exposing ``eta`` and ``step(W, G)``, normally
            a ``ManifoldLBFGS``. When omitted, one shared ``ManifoldLBFGS(eta=0.1,
            history=10)`` is created on first use and reused by every later call
            that omits ``opt``. Its history is tied to a single weight matrix,
            so pass a dedicated instance per parameter when optimizing several.

    Returns:
        torch.Tensor: new W (same shape/orientation as the input).

    Notes:
        - This function ONLY performs the step. After recomputing the gradient
          at the returned W, call ``opt.update(G_new)`` once to feed (s_k, y_k)
          to L-BFGS.
        - Tall/wide handling is done inside ``opt.step``.
    """
    if opt is None:
        # Build the fallback lazily rather than as a default argument, so that
        # defining this function does not construct an optimizer.
        opt = getattr(manifold_muon_general, "_shared_opt", None)
        if opt is None:
            opt = ManifoldLBFGS(eta=0.1, history=10)
            manifold_muon_general._shared_opt = opt

    assert W.shape == G.shape, f"W and G must have the same shape, but got W.shape={W.shape} and G.shape={G.shape}"

    # Per-call step size (to schedule eta externally, pass a different eta each call).
    opt.eta = eta

    # One quasi-Newton step on the manifold; the retraction happens inside step().
    return opt.step(W, G)