<a href="https://colab.research.google.com/github/Fantiflex/MuOn-optimizer/blob/main/MainV1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import argparse
import os
import pickle
import time

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.optim import AdamW
from torch.utils.data import DataLoader


# Colab only: mount Google Drive at /content/drive (the %run lines in this cell read notebooks from there).
from google.colab import drive
drive.mount('/content/drive')
# Use CUDA when available, otherwise fall back to the CPU ("gpu" is not a valid torch device type).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Execute the two project notebooks stored on the mounted Drive via the IPython %run magic.
%run "/content/drive/MyDrive/Colab_Notebooks/EECS182_project/hyperspherical_descent.ipynb"
%run "/content/drive/MyDrive/Colab_Notebooks/EECS182_project/Optimizers_project_182.ipynb"
# After this, the functions defined inside those notebooks are available in the current notebook.

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [8]:

import argparse
import os
import pickle
import time

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.optim import AdamW
from torch.utils.data import DataLoader


# Colab only: mount Google Drive at /content/drive (the %run lines in this cell read notebooks from there).
from google.colab import drive
drive.mount('/content/drive')
# Use CUDA when available, otherwise fall back to the CPU ("gpu" is not a valid torch device type).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Redefine the problematic functions to be more robust
@torch.no_grad()
def polar_retraction(X):
    """Return the orthogonal polar factor of ``X``.

    The factor is computed as ``U @ Vh`` from the thin SVD ``X = U diag(S) Vh``.
    The unperturbed input is decomposed first, so the result is deterministic;
    a tiny random perturbation is only added if that decomposition fails.

    Args:
        X: 2-D tensor (any aspect ratio).

    Returns:
        Tensor with the same shape as ``X``.

    Raises:
        RuntimeError: if the SVD fails on both the original and the perturbed input.
    """
    # Prefer a native polar decomposition if this torch build exposes one
    # (NOTE(review): current torch releases do not appear to ship torch.linalg.polar).
    polar = getattr(torch.linalg, "polar", None)
    if polar is not None:
        try:
            U, _ = polar(X)
            return U
        except Exception as e:
            print(f"Warning: torch.linalg.polar failed with {type(e).__name__}: {e}. Falling back to SVD.")

    try:
        # full_matrices=False gives the thin factors, so U @ Vh has X's shape even when X is not square.
        U, _, Vh = torch.linalg.svd(X, full_matrices=False)
    except Exception as e:
        print(f"Warning: torch.linalg.svd failed with {type(e).__name__}: {e}. Retrying with a small random perturbation.")
        try:
            # Perturbing breaks exact ties between singular values that can stall the solver.
            U, _, Vh = torch.linalg.svd(X + 1e-7 * torch.randn_like(X), full_matrices=False)
        except Exception as svd_e:
            raise RuntimeError(f"Both SVD methods failed to converge: {svd_e}") from svd_e
    return U @ Vh

@torch.no_grad()
def msign(W):
    """Return the matrix sign (orthogonal polar factor) of ``W``.

    With the thin SVD ``W = U diag(S) Vh`` the result is ``U @ Vh``.
    ``torch.linalg.svd`` already returns ``Vh``, so it must not be transposed again.

    Args:
        W: 2-D tensor.

    Returns:
        Tensor with the same shape as ``W``.
    """
    try:
        U, _, Vh = torch.linalg.svd(W, full_matrices=False)
    except Exception as e:
        print(f"Warning: msign's torch.linalg.svd failed with {type(e).__name__}: {e}. Trying again with perturbation.")
        # Retry once on a slightly perturbed copy; a second failure propagates to the caller.
        U, _, Vh = torch.linalg.svd(W + 1e-7 * torch.randn_like(W), full_matrices=False)
    return U @ Vh

@torch.no_grad()
def project_to_stiefel(W):
    """Map ``W`` to the orthogonal polar factor ``U @ Vh`` of its thin SVD.

    ``torch.linalg.svd`` already returns ``Vh``, so it must not be transposed again;
    transposing it would give a different orthonormal matrix.

    Args:
        W: 2-D tensor.

    Returns:
        Tensor with the same shape as ``W``.
    """
    try:
        U, _, Vh = torch.linalg.svd(W, full_matrices=False)
    except Exception as e:
        print(f"Warning: project_to_stiefel's torch.linalg.svd failed with {type(e).__name__}: {e}. Trying again with perturbation.")
        # Retry once on a slightly perturbed copy; a second failure propagates to the caller.
        U, _, Vh = torch.linalg.svd(W + 1e-7 * torch.randn_like(W), full_matrices=False)
    return U @ Vh


# --- ManifoldLBFGS (re)definition START ---
# The ManifoldLBFGS class is NOT redefined in this cell; it is obtained by re-running the project
# notebooks below. Whether it benefits from the robust polar_retraction / msign defined above
# depends on how it looks those functions up:
#   * if its _retract method references them through the global scope, the definitions present in
#     this notebook's namespace are the ones used;
#   * if it is hard-coded to a local version, the overrides above have no effect, and the full class
#     definition from
#     '/content/drive/MyDrive/Colab_Notebooks/EECS182_project/Optimizers_project_182.ipynb'
#     must be copied into this cell and adapted.
# NOTE(review): the %run lines execute AFTER the definitions above, so any same-named function
# defined inside those notebooks presumably replaces the robust versions -- confirm, and if so move
# the %run lines above the redefinitions.

%run "/content/drive/MyDrive/Colab_Notebooks/EECS182_project/hyperspherical_descent.ipynb"
%run "/content/drive/MyDrive/Colab_Notebooks/EECS182_project/Optimizers_project_182.ipynb"

# --- ManifoldLBFGS (re)definition END ---


# Convert images to tensors, then normalise each of the three channels with fixed mean / std values.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.49139968, 0.48215827, 0.44653124),
            std=(0.24703233, 0.24348505, 0.26158768),
        ),
    ]
)

# CIFAR-10 train split first, then the test split; both stored under ./data.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(root="./data", train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# Only the training batches are shuffled.
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)


# Starts empty; nothing in this cell adds entries to it.
OPTS = dict()



class MLP(nn.Module):
    """Bias-free three-layer perceptron: 32*32*3 inputs -> 128 -> 64 -> 10 outputs."""

    def __init__(self):
        super().__init__()
        input_dim = 32 * 32 * 3
        self.fc1 = nn.Linear(input_dim, 128, bias=False)
        self.fc2 = nn.Linear(128, 64, bias=False)
        self.fc3 = nn.Linear(64, 10, bias=False)

    def forward(self, x):
        """Flatten ``x`` to rows of 32*32*3 values and apply the three layers (ReLU after the first two)."""
        flat = x.view(-1, 32 * 32 * 3)
        hidden = self.fc1(flat).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)


def train(epochs, initial_lr, update, wd):
    """Train a fresh ``MLP`` on ``train_loader`` and return it with per-epoch statistics.

    Args:
        epochs: Number of passes over ``train_loader``.
        initial_lr: Starting learning rate; decayed linearly towards 0 over all steps.
        update: Either the ``AdamW`` class or one of the update functions
            ``manifold_muon``, ``hyperspherical_descent``, ``manifold_muon_general``.
        wd: Weight decay, only used when ``update`` is ``AdamW``.

    Returns:
        ``(model, epoch_losses, epoch_times)``: the trained model, the mean training
        loss of each epoch and the wall-clock seconds spent in each epoch.
    """
    model = MLP().to(device)
    criterion = nn.CrossEntropyLoss()

    optimizer = None
    opts = None
    if update == AdamW:
        optimizer = AdamW(model.parameters(), lr=initial_lr, weight_decay=wd)
    else:
        assert update in [manifold_muon, hyperspherical_descent, manifold_muon_general]
        if update == manifold_muon_general:
            # One persistent ManifoldLBFGS state per parameter tensor.
            opts = {p: ManifoldLBFGS(eta=initial_lr, history=10, eps_curv=1e-12, use_polar_impl=True) for p in model.parameters()}

    steps = epochs * len(train_loader)
    step = 0

    if optimizer is None:
        # Put the initial weights on the manifold once, before any gradient step.
        for p in model.parameters():
            if update == manifold_muon_general:
                p.data = project_to_stiefel(p.data)
            else:
                # A zero-gradient, zero-step call leaves only the update rule's own retraction.
                p.data = update(p.data, torch.zeros_like(p.data), eta=0.0)

    epoch_losses = []
    epoch_times = []

    for epoch in range(epochs):
        start_time = time.time()
        running_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            model.zero_grad()
            loss.backward()
            lr = initial_lr * (1 - step / steps)
            if optimizer is not None:
                # AdamW: apply the same linear schedule, then take a regular optimizer step.
                for group in optimizer.param_groups:
                    group["lr"] = lr
                optimizer.step()
            else:
                with torch.no_grad():
                    if update == manifold_muon_general:
                        # 1) Finalize the previous (s, y) pair with the current gradient.
                        for p in model.parameters():
                            if getattr(opts[p], "last", None) is not None:
                                opts[p].update(p.grad)

                        # 2) Take a new L-BFGS step (note the opt=... argument and p.data).
                        for p in model.parameters():
                            p.data = update(p.data, p.grad, eta=lr, opt=opts[p])
                    else:
                        # Stateless update rules (manifold_muon, hyperspherical_descent).
                        for p in model.parameters():
                            p.data = update(p.data, p.grad, eta=lr)

            step += 1
            running_loss += loss.item()
            if (i+1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

        end_time = time.time()
        epoch_loss = running_loss / len(train_loader)
        epoch_time = end_time - start_time
        epoch_losses.append(epoch_loss)
        epoch_times.append(epoch_time)
        print(f"Epoch {epoch+1}, Loss: {epoch_loss}, Time: {epoch_time:.4f} seconds")
    return model, epoch_losses, epoch_times


def eval(model):
    """Print and return the model's accuracy on the test and training sets.

    Note: this name shadows the built-in ``eval`` inside the notebook.

    Args:
        model: Module producing one row of class scores per input image.

    Returns:
        ``[test_accuracy, train_accuracy]`` in percent.
    """
    model.eval()
    # Run on whatever device the model lives on instead of assuming CUDA is present.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        accs = []
        for dataloader in [test_loader, train_loader]:
            correct = 0
            total = 0
            for images, labels in dataloader:
                images = images.to(model_device)
                labels = labels.to(model_device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            accs.append(100 * correct / total)

    print(f"Accuracy of the network on the {len(test_loader.dataset)} test images: {accs[0]} %")
    print(f"Accuracy of the network on the {len(train_loader.dataset)} train images: {accs[1]} %")
    return accs

def weight_stats(model):
    """Gather, for every parameter of ``model``, its singular values and its norm.

    Returns:
        ``(singular_values, norms)``: two lists with one tensor per parameter,
        in ``model.parameters()`` order.
    """
    spectra, magnitudes = [], []
    for weight in model.parameters():
        try:
            sigma = torch.linalg.svdvals(weight)
        except RuntimeError:
            # svdvals failed: take the singular values from a thin SVD instead.
            sigma = torch.linalg.svd(weight, full_matrices=False)[1]
        except AttributeError:
            # torch.linalg.svdvals is missing: use the legacy torch.svd.
            sigma = torch.svd(weight, some=False)[1]

        spectra.append(sigma)
        magnitudes.append(weight.norm())
    return spectra, magnitudes


if __name__ == "__main__":
    # Build the CLI definition; inside the notebook an empty argv is parsed so the defaults apply.
    parser = argparse.ArgumentParser(description="Train a model on CIFAR-10.")
    parser.add_argument("--epochs", type=int, default=5, help="Number of epochs to train for.")
    parser.add_argument("--lr", type=float, default=0.1, help="Initial learning rate.")
    parser.add_argument("--update", type=str, default="manifold_muon_general", choices=["manifold_muon", "hyperspherical_descent", "adam","manifold_muon_general"], help="Update rule to use.")
    parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator.")
    parser.add_argument("--wd", type=float, default=0.0, help="Weight decay for AdamW.")
    args = parser.parse_args([])

    # determinism flags
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Map the CLI choice to the update function (or to the AdamW class).
    update_rules = {
        "manifold_muon": manifold_muon,
        "hyperspherical_descent": hyperspherical_descent,
        "adam": AdamW,
        "manifold_muon_general": manifold_muon_general
    }

    update = update_rules[args.update]

    print(f"Training with: {args.update}")
    print(f"Epochs: {args.epochs} --- LR: {args.lr}", f"--- WD: {args.wd}" if args.update == "adam" else "")

    model, epoch_losses, epoch_times = train(
        epochs=args.epochs,
        initial_lr=args.lr,
        update=update,
        wd=args.wd
    )
    test_acc, train_acc = eval(model)
    singular_values, norms = weight_stats(model)

    results = {
        "epochs": args.epochs,
        "lr": args.lr,
        "seed": args.seed,
        "wd": args.wd,
        "update": args.update,
        "epoch_losses": epoch_losses,
        "epoch_times": epoch_times,
        "test_acc": test_acc,
        "train_acc": train_acc,
        # Store detached CPU tensors / plain floats so the pickle can be loaded without a GPU.
        "singular_values": [s.detach().cpu() for s in singular_values],
        "norms": [n.item() for n in norms]
    }

    filename = f"update-{args.update}-lr-{args.lr}-wd-{args.wd}-seed-{args.seed}.pkl"
    os.makedirs("results", exist_ok=True)

    # Build the path once; nesting the same quote type inside an f-string is a SyntaxError before Python 3.12.
    save_path = os.path.join("results", filename)
    print(f"Saving results to {save_path}")
    with open(save_path, "wb") as f:
        pickle.dump(results, f)
    print(f"Results saved to {save_path}")

Training with: manifold_muon_general
Epochs: 5 --- LR: 0.1 
Epoch 1, Loss: 3.9194172596444887, Time: 12.0001 seconds
Epoch 2, Loss: 2.8151439112059924, Time: 11.7855 seconds
Epoch 3, Loss: 2.6001506192343578, Time: 12.1927 seconds


_LinAlgError: linalg.svd: The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated singular values (error code: 128).

**Test pour `manifold_muon_general`**

In [None]:
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import math
import argparse
import os
import pickle
import time
from torch.utils.data import DataLoader
from torch.optim import AdamW

# --- Missing definitions (from hyperspherical_descent.ipynb and Optimizers_project_182.ipynb) --- Start
@torch.no_grad()
def msign(W):
    """Return the matrix sign (orthogonal polar factor) of ``W``.

    With the thin SVD ``W = U diag(S) Vh`` the result is ``U @ Vh``.
    ``torch.linalg.svd`` already returns ``Vh``, so it must not be transposed again.

    Args:
        W: 2-D tensor.

    Returns:
        Tensor with the same shape as ``W``.
    """
    U, _, Vh = torch.linalg.svd(W, full_matrices=False)
    return U @ Vh

@torch.no_grad()
def project_to_stiefel(W):
    """Map ``W`` to the orthogonal polar factor ``U @ Vh`` of its thin SVD.

    ``torch.linalg.svd`` already returns ``Vh``, so it must not be transposed again;
    transposing it would give a different orthonormal matrix.

    Args:
        W: 2-D tensor.

    Returns:
        Tensor with the same shape as ``W``.
    """
    U, _, Vh = torch.linalg.svd(W, full_matrices=False)
    return U @ Vh

# ManifoldLBFGS is not defined in this cell. train() instantiates it when the update rule is
# manifold_muon_general, so it must already be in the namespace -- presumably from the %run of
# Optimizers_project_182.ipynb in cell vX_PDnXkgWkM. If this cell is run on its own and a
# NameError for ManifoldLBFGS is raised, run that cell first or add the class definition here.

# --- Missing definitions --- End

# -----------------------
# Utils: device & seeds
# -----------------------
# Device used by train() and eval_model() in this cell: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_determinism(seed=42):
    """Seed torch (``manual_seed`` and ``cuda.manual_seed_all``) and set cuDNN to deterministic, non-benchmark mode."""
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

@torch.no_grad()
def manifold_muon(W, G, eta=0.1, alpha=0.01, steps=100, tol=1e-6):
    """Take one manifold-Muon step on ``W`` using gradient ``G``.

    Args:
        W: Weight matrix; handled internally in tall orientation.
        G: Gradient with the same shape as ``W``.
        eta: Step size of the primal update.
        alpha: Base step size of the dual-ascent loop (decayed linearly per iteration).
        steps: Maximum number of dual-ascent iterations.
        tol: Stop early once the RMS of the tangency residual drops below this value.

    Returns:
        The updated matrix, ``msign(W - eta * A)``, in the orientation of the input.
    """
    # Work on a tall matrix; the transpose is undone on return.
    transposed = W.shape[0] < W.shape[1]
    if transposed:
        W, G = W.T, G.T

    # Initial value of the dual variable.
    Lambda = -0.25 * (W.T @ G + G.T @ W)

    for it in range(steps):
        A = msign(G + 2 * W @ Lambda)  # candidate direction: msign of (G + 2 W Lambda)
        H = W.T @ A + A.T @ W          # tangency residual
        residual_rms = torch.norm(H) / math.sqrt(H.numel())
        if residual_rms < tol:
            break
        decay = 1 - it / steps
        Lambda -= alpha * decay * H

    # Primal step followed by a retraction through msign.
    retracted = msign(W - eta * A)
    return retracted.T if transposed else retracted

@torch.no_grad()
def manifold_muon_general(W, G, eta=0.1, *, opt):
    """Set ``opt.eta`` to ``eta`` and return the result of ``opt.step(W, G)``.

    ``opt`` is a keyword-only, caller-owned optimizer state object.
    """
    setattr(opt, "eta", eta)
    stepped = opt.step(W, G)
    return stepped

# -----------------------
# CIFAR-10 & Model
# -----------------------
# Convert images to tensors, then normalise each of the three channels with fixed mean / std values.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.49139968, 0.48215827, 0.44653124),
            std=(0.24703233, 0.24348505, 0.26158768),
        ),
    ]
)

# CIFAR-10 train split first, then the test split; both stored under ./data.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(root="./data", train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# Only the training batches are shuffled; both loaders use 2 workers and pinned memory.
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False, num_workers=2, pin_memory=True)

class MLP(nn.Module):
    """Bias-free three-layer perceptron: 32*32*3 inputs -> 128 -> 64 -> 10 outputs."""

    def __init__(self):
        super().__init__()
        input_dim = 32 * 32 * 3
        self.fc1 = nn.Linear(input_dim, 128, bias=False)
        self.fc2 = nn.Linear(128, 64, bias=False)
        self.fc3 = nn.Linear(64, 10, bias=False)

    def forward(self, x):
        """Flatten everything after the first dimension and apply the three layers (ReLU after the first two)."""
        flat = x.view(x.shape[0], -1)
        hidden = self.fc1(flat).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)

# -----------------------
# Train / Eval
# -----------------------
def train(epochs, initial_lr, update, wd):
    """Train a fresh ``MLP`` on ``train_loader`` and return it with per-epoch statistics.

    Args:
        epochs: Number of passes over ``train_loader``.
        initial_lr: Starting learning rate; decayed linearly towards 0 over all steps.
        update: Either the ``AdamW`` class or one of the update functions
            (``manifold_muon_general``, ``manifold_muon``, ``hyperspherical_descent``).
        wd: Weight decay, only used when ``update`` is ``AdamW``.

    Returns:
        ``(model, epoch_losses, epoch_times)``: the trained model, the mean training
        loss of each epoch and the wall-clock seconds spent in each epoch.
    """
    model = MLP().to(device)
    criterion = nn.CrossEntropyLoss()

    # Only the AdamW choice uses a torch optimizer; the other rules update p.data directly.
    optimizer = None
    if update == AdamW:
        optimizer = AdamW(model.parameters(), lr=initial_lr, weight_decay=wd)

    # Per-parameter L-BFGS states (only for manifold_muon_general)
    opts = None
    if update == manifold_muon_general:
        opts = {p: ManifoldLBFGS(eta=initial_lr, history=10, use_polar_impl=True) for p in model.parameters()}

    steps = epochs * len(train_loader)
    step = 0

    # One-shot projection to the manifold (skipped for AdamW)
    if optimizer is None:
        with torch.no_grad():
            for p in model.parameters():
                if update == manifold_muon_general:
                    p.data = project_to_stiefel(p.data)
                else:
                    # Zero gradient and zero step size: only the rule's own retraction is applied.
                    p.data = update(p.data, torch.zeros_like(p.data), eta=0.0)

    epoch_losses, epoch_times = [], []

    for epoch in range(epochs):
        model.train()
        start_time = time.time()
        running_loss = 0.0

        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # Forward
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward
            model.zero_grad()
            loss.backward()

            # Schedule LR (simple linear decay)
            lr = float(initial_lr) * (1.0 - step / max(1, steps))

            with torch.no_grad():
                if optimizer is None:
                    if update == manifold_muon_general:
                        # 1) Finalize previous L-BFGS step with current grads (lazy curvature pair)
                        for p in model.parameters():
                            if opts[p].last is not None:
                                opts[p].update(p.grad)

                        # 2) Take a NEW L-BFGS manifold step
                        for p in model.parameters():
                            p.data = manifold_muon_general(p.data, p.grad, eta=lr, opt=opts[p])

                    else:  # manifold_muon or hyperspherical_descent (stateless)
                        for p in model.parameters():
                            p.data = update(p.data, p.grad, eta=lr)

                else:
                    # AdamW: push the scheduled learning rate into every group, then step.
                    for g in optimizer.param_groups:
                        g["lr"] = lr
                    optimizer.step()

            # Bookkeeping runs for every update rule.
            step += 1
            running_loss += loss.item()

            if (i + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{epochs}] Step [{i+1}/{len(train_loader)}] Loss: {loss.item():.4f}")

        end_time = time.time()
        epoch_loss = running_loss / len(train_loader)
        epoch_time = end_time - start_time
        epoch_losses.append(epoch_loss)
        epoch_times.append(epoch_time)
        print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Time: {epoch_time:.2f}s")

    return model, epoch_losses, epoch_times

@torch.no_grad()
def eval_model(model):
    """Print and return ``[test_accuracy, train_accuracy]`` (in percent) for ``model``."""
    model.eval()

    def _accuracy(loader):
        # Percentage of samples in `loader` whose highest-scoring class equals the label.
        hits, seen = 0, 0
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device, non_blocking=True)
            batch_labels = batch_labels.to(device, non_blocking=True)
            scores = model(batch_images)
            predicted = torch.max(scores, 1)[1]
            seen += batch_labels.size(0)
            hits += (predicted == batch_labels).sum().item()
        return 100.0 * hits / seen

    accs = [_accuracy(loader) for loader in (test_loader, train_loader)]
    print(f"Accuracy on test set:  {accs[0]:.2f}%")
    print(f"Accuracy on train set: {accs[1]:.2f}%")
    return accs

def weight_stats(model):
    """Gather, for every parameter of ``model``, its singular values and its norm.

    Returns:
        ``(sv_list, norms)``: two lists with one tensor per parameter,
        in ``model.parameters()`` order.
    """
    spectra, magnitudes = [], []
    for weight in model.parameters():
        try:
            sigma = torch.linalg.svdvals(weight)
        except RuntimeError:
            # svdvals failed: take the singular values from a thin SVD instead.
            sigma = torch.linalg.svd(weight, full_matrices=False)[1]
        spectra.append(sigma)
        magnitudes.append(weight.norm())
    return spectra, magnitudes

# -----------------------
# Main
# -----------------------
if __name__ == "__main__":
    # Entry point of the cell: parse options, train, evaluate, and pickle the results.
    parser = argparse.ArgumentParser(description="Train a model on CIFAR-10.")
    parser.add_argument("--epochs", type=int, default=5, help="Number of epochs.")
    parser.add_argument("--lr", type=float, default=0.1, help="Initial learning rate.")
    parser.add_argument("--update", type=str, default="manifold_muon_general",
                        choices=["manifold_muon_general", "manifold_muon", "adam", "hyperspherical_descent"],
                        help="Update rule to use.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--wd", type=float, default=0.0, help="Weight decay for AdamW.")
    # In notebooks, parse an empty argv so the defaults above are used:
    args = parser.parse_args([])

    set_determinism(args.seed)

    # Map each CLI choice to its update function (or to the AdamW class).
    update_rules = {
        "manifold_muon_general": manifold_muon_general,
        "manifold_muon": manifold_muon,
        "adam": AdamW,
        "hyperspherical_descent": globals().get('hyperspherical_descent', None) # None when the function is not defined in this namespace
    }

    # hyperspherical_descent is not defined in this cell; fail early with a clear message if it was requested but is missing.
    if update_rules["hyperspherical_descent"] is None and args.update == "hyperspherical_descent":
        raise NameError("hyperspherical_descent function is not defined. Ensure it's sourced from a %run notebook or defined locally.")

    update = update_rules[args.update]

    print(f"Training with: {args.update}")
    print(f"Epochs: {args.epochs} — LR: {args.lr}" + (f" — WD: {args.wd}" if args.update == "adam" else ""))

    model, epoch_losses, epoch_times = train(
        epochs=args.epochs,
        initial_lr=args.lr,
        update=update,
        wd=args.wd
    )

    # eval_model returns [test accuracy, train accuracy], in that order.
    test_acc, train_acc = eval_model(model)
    singular_values, norms = weight_stats(model)

    # Everything saved for this run; tensors are moved to the CPU and norms converted to Python floats.
    results = {
        "epochs": args.epochs,
        "lr": args.lr,
        "seed": args.seed,
        "wd": args.wd,
        "update": args.update,
        "epoch_losses": epoch_losses,
        "epoch_times": epoch_times,
        "test_acc": test_acc,
        "train_acc": train_acc,
        "singular_values": [s.cpu() for s in singular_values],
        "norms": [n.item() for n in norms],
    }

    # The file name encodes the hyper-parameters of the run.
    filename = f"update-{args.update}-lr-{args.lr}-wd-{args.wd}-seed-{args.seed}.pkl"
    os.makedirs("results", exist_ok=True)

    save_path = os.path.join("results", filename)
    print(f"Saving results to {save_path}")
    with open(save_path, "wb") as f:
        pickle.dump(results, f)
    print(f"Results saved to {save_path}")

Training with: manifold_muon_general
Epochs: 5 — LR: 0.1


NameError: name 'project_to_stiefel' is not defined