<a href="https://colab.research.google.com/github/Fantiflex/MuOn-optimizer/blob/main/CNN_Main_for_global_linear_layers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import argparse
import os
import pickle
import time

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.optim import AdamW
from torch.utils.data import DataLoader


from google.colab import drive
drive.mount('/content/drive', force_remount=True)
device = torch.device("cuda" if torch.cuda.is_available() else "gpu")

%run "/content/drive/MyDrive/Colab_Notebooks/EECS182_project/hyperspherical_descent.ipynb"
%run "/content/drive/MyDrive/Colab_Notebooks/EECS182_project/LGFBS_global.ipynb"
# after this, the functions defined inside those notebooks are available in the current notebook

Mounted at /content/drive
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [21]:

# Per-channel mean / std handed to Normalize, and the batch size shared by
# both loaders.
NORM_MEAN = (0.49139968, 0.48215827, 0.44653124)
NORM_STD = (0.24703233, 0.24348505, 0.26158768)
BATCH_SIZE = 1024

# Convert images to tensors, then standardise each channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
)

# CIFAR-10 train / test splits, downloaded into ./data when missing.
train_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=True, transform=transform, download=True
)
test_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=False, transform=transform, download=True
)

# Only the training loader reshuffles between epochs.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# NOTE(review): not referenced by the code visible in this notebook; it may be
# read by the helper notebooks loaded with %run -- confirm before removing.
OPTS = {}



class CNN(nn.Module):
    """Bias-free convolutional classifier mapping (B, 3, 32, 32) images to 10 logits.

    ``features`` stacks three 3x3 convolutions with two 2x2 max-pools
    (3 x 32 x 32 -> 128 x 8 x 8); ``classifier`` is a four-layer MLP applied
    to the flattened feature maps.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(c_in, c_out):
            # Shape-preserving 3x3 convolution without bias.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)

        def dense(n_in, n_out):
            # Fully connected layer without bias.
            return nn.Linear(n_in, n_out, bias=False)

        # Layers are created in forward order and keep their positions inside
        # each Sequential, so parameter names and seeded initialisation are
        # unaffected by the helper functions above.
        self.features = nn.Sequential(
            conv3x3(3, 32),
            nn.ReLU(inplace=True),
            conv3x3(32, 64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 64 x 16 x 16
            conv3x3(64, 128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 128 x 8 x 8
        )

        self.classifier = nn.Sequential(
            dense(128 * 8 * 8, 256),
            nn.ReLU(inplace=True),
            dense(256, 256),
            nn.ReLU(inplace=True),
            dense(256, 128),
            nn.ReLU(inplace=True),
            dense(128, 10),
        )

    def forward(self, x):
        """Return class logits for a batch already shaped (B, 3, 32, 32)."""
        feature_maps = self.features(x)
        # Collapse channel and spatial dimensions into one vector per sample.
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flat)


def is_linear_weight(p):
    """Return True when ``p`` is exactly two-dimensional.

    Only 2-D matrices (the weights of the Linear layers) are treated as
    "Stiefel" parameters by the training loop.
    """
    return p.ndim == 2


def train(epochs, initial_lr, update, wd):
    """Train a freshly initialised ``CNN`` on ``train_loader``.

    Args:
        epochs: Number of passes over ``train_loader``.
        initial_lr: Starting learning rate. It is decayed linearly per step as
            ``initial_lr * (1 - step / total_steps)`` and applied to every
            update rule, including AdamW.
        update: Either the ``AdamW`` class or one of the update functions
            ``manifold_muon``, ``hyperspherical_descent``,
            ``manifold_muon_general``.
        wd: Weight decay; only used by AdamW.

    Returns:
        ``(model, epoch_losses, epoch_times)``: the trained model, the mean
        training loss of each epoch and the wall-clock seconds of each epoch.

    Raises:
        AssertionError: If ``update`` is not one of the supported rules.
    """
    # Follow the notebook-level ``device`` instead of hard-coding CUDA.
    model = CNN().to(device)
    criterion = nn.CrossEntropyLoss()

    if update == AdamW:
        optimizer = AdamW(model.parameters(), lr=initial_lr, weight_decay=wd)
        opts = None
    else:
        assert update in [manifold_muon, hyperspherical_descent, manifold_muon_general]
        optimizer = None
        opts = None
        if update == manifold_muon_general:
            # L-BFGS state only for the 2-D matrices (Linear weights).
            opts = {
                p: ManifoldLBFGS(eta=initial_lr, history=10, eps_curv=1e-12)
                for p in model.parameters()
                if is_linear_weight(p)
            }
            print("Nb de paramètres sur la manifold :", len(opts))

        if update == manifold_muon:
            nb_manifold = sum(1 for p in model.parameters() if is_linear_weight(p))
            print("Nb de paramètres sur la manifold (muon classique) :", nb_manifold)

    steps = epochs * len(train_loader)
    step = 0

    # --- Initial projection onto the manifold: Linear weights only ---
    # A zero gradient with eta=0.0 is passed so the call performs no descent
    # step; presumably only the helper's projection/retraction acts -- confirm
    # against the helper notebooks.
    if optimizer is None:
        for p in model.parameters():
            if is_linear_weight(p) and update in [manifold_muon, manifold_muon_general]:
                if update == manifold_muon_general:
                    p.data = manifold_muon_general(
                        p.data, torch.zeros_like(p.data), eta=0.0, opt=opts[p]
                    )
                else:  # classic muon
                    p.data = manifold_muon(
                        p.data, torch.zeros_like(p.data), eta=0.0
                    )
            # otherwise: convolution weights are left untouched

    epoch_losses = []
    epoch_times = []

    for epoch in range(epochs):
        start_time = time.time()
        running_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            model.zero_grad()
            loss.backward()
            # Linear decay from initial_lr towards zero over the whole run.
            lr = initial_lr * (1 - step / steps)

            with torch.no_grad():
                if optimizer is None:
                    # --- muon / hyperspherical_descent case ---
                    if update == manifold_muon_general:
                        # 1) Feed the new gradient to each per-parameter state
                        #    that already has a ``last`` entry (update of the
                        #    (s, y) pairs for the Linear weights).
                        for p in model.parameters():
                            if is_linear_weight(p) and getattr(opts.get(p, None), "last", None) is not None:
                                opts[p].update(p.grad)

                        # 2) Step: general muon on Linear weights, plain
                        #    Euclidean gradient step for everything else.
                        for p in model.parameters():
                            if is_linear_weight(p):
                                p.data = manifold_muon_general(
                                    p.data, p.grad, eta=lr, opt=opts[p]
                                )
                            else:
                                p.data -= lr * p.grad

                    elif update == manifold_muon:
                        # Classic muon on Linear weights only; Euclidean step
                        # for the remaining parameters.
                        for p in model.parameters():
                            if is_linear_weight(p):
                                p.data = manifold_muon(p.data, p.grad, eta=lr)
                            else:
                                p.data -= lr * p.grad

                    elif update == hyperspherical_descent:
                        # Applied to every parameter; restricting it to a
                        # subset is a design choice left open here.
                        for p in model.parameters():
                            p.data = hyperspherical_descent(p.data, p.grad, eta=lr)

                else:
                    # AdamW case: push the decayed rate into the optimizer so
                    # it follows the same schedule as the manual updates
                    # (previously ``lr`` was computed but never applied here).
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = lr
                    optimizer.step()

            step += 1
            running_loss += loss.item()
            if (i + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

        end_time = time.time()
        epoch_loss = running_loss / len(train_loader)
        epoch_time = end_time - start_time
        epoch_losses.append(epoch_loss)
        epoch_times.append(epoch_time)
        print(f"Epoch {epoch+1}, Loss: {epoch_loss}, Time: {epoch_time:.4f} seconds")

    return model, epoch_losses, epoch_times


def eval(model):
    """Measure classification accuracy on the test and training sets.

    Puts ``model`` in evaluation mode (it is left in that mode), runs it over
    ``test_loader`` and then ``train_loader`` without gradient tracking, and
    prints both accuracies.

    Args:
        model: Network returning one row of class scores per input image.

    Returns:
        ``[test_accuracy, train_accuracy]`` as percentages.
    """
    model.eval()

    # Send batches wherever the model lives instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    accs = []
    with torch.no_grad():
        for dataloader in [test_loader, train_loader]:
            correct = 0
            total = 0
            for images, labels in dataloader:
                images = images.to(model_device)
                labels = labels.to(model_device)
                outputs = model(images)
                # Predicted class = index of the largest score in each row.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            accs.append(100 * correct / total)

    print(f"Accuracy of the network on the {len(test_loader.dataset)} test images: {accs[0]} %")
    print(f"Accuracy of the network on the {len(train_loader.dataset)} train images: {accs[1]} %")
    return accs

def weight_stats(model):
    """Collect singular values and norms of every parameter of ``model``.

    Parameters with more than two dimensions are treated as batches of
    matrices over their last two dimensions; 1-D parameters are viewed as a
    single-row matrix (whose only singular value equals their norm).

    Args:
        model: Module whose ``parameters()`` are inspected.

    Returns:
        ``(singular_values, norms)``: two lists with one tensor per parameter,
        in ``model.parameters()`` order. All tensors are detached and live on
        the CPU so they can be pickled and re-loaded without a GPU.
    """
    singular_values = []
    norms = []
    # Statistics only: no autograd graph should be recorded or returned.
    with torch.no_grad():
        for p in model.parameters():
            matrix = p if p.dim() >= 2 else p.reshape(1, -1)
            # svdvals skips U and V, which were computed and discarded before
            # (torch.svd is also deprecated in favour of torch.linalg).
            s = torch.linalg.svdvals(matrix)
            singular_values.append(s.detach().cpu())
            norms.append(p.norm().detach().cpu())
    return singular_values, norms


if __name__ == "__main__":
    # Experiment driver: parse options, train, evaluate, then pickle the
    # metrics under ./results.
    parser = argparse.ArgumentParser(description="Train a model on CIFAR-10.")
    parser.add_argument("--epochs", type=int, default=5, help="Number of epochs to train for.")
    parser.add_argument("--lr", type=float, default=0.5, help="Initial learning rate.")
    parser.add_argument("--update", type=str, default="manifold_muon_general", choices=["manifold_muon", "hyperspherical_descent", "adam","manifold_muon_general"], help="Update rule to use.")
    parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator.")
    parser.add_argument("--wd", type=float, default=0.0, help="Weight decay for AdamW.")
    # An explicit empty argv is parsed so the defaults above are always used;
    # edit the defaults to change the experiment from inside the notebook.
    args = parser.parse_args([])

    # determinism flags
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Map the command-line name of each rule to its implementation.
    update_rules = {
        "manifold_muon": manifold_muon,
        "hyperspherical_descent": hyperspherical_descent,
        "adam": AdamW,
        "manifold_muon_general": manifold_muon_general
    }

    update = update_rules[args.update]

    print(f"Training with: {args.update}")
    print(f"Epochs: {args.epochs} --- LR: {args.lr}", f"--- WD: {args.wd}" if args.update == "adam" else "")

    model, epoch_losses, epoch_times = train(
        epochs=args.epochs,
        initial_lr=args.lr,
        update=update,
        wd=args.wd
    )
    test_acc, train_acc = eval(model)
    singular_values, norms = weight_stats(model)

    results = {
        "epochs": args.epochs,
        "lr": args.lr,
        "seed": args.seed,
        "wd": args.wd,
        "update": args.update,
        "epoch_losses": epoch_losses,
        "epoch_times": epoch_times,
        "test_acc": test_acc,
        "train_acc": train_acc,
        "singular_values": singular_values,
        "norms": norms
    }

    filename = f"update-{args.update}-lr-{args.lr}-wd-{args.wd}-seed-{args.seed}.pkl"
    os.makedirs("results", exist_ok=True)
    # Build the path once. Reusing double quotes inside a double-quoted
    # f-string (as the previous prints did) is a SyntaxError before Python 3.12.
    results_path = os.path.join("results", filename)

    print(f"Saving results to {results_path}")
    with open(results_path, "wb") as f:
        pickle.dump(results, f)
    print(f"Results saved to {results_path}")

Training with: manifold_muon_general
Epochs: 5 --- LR: 0.5 
Nb de paramètres sur la manifold : 4
Epoch 1, Loss: 2.2883199769623426, Time: 20.6788 seconds
Epoch 2, Loss: 2.215462621377439, Time: 20.0864 seconds
Epoch 3, Loss: 1.9384604862758092, Time: 19.6220 seconds
Epoch 4, Loss: 1.6169386712872251, Time: 18.9903 seconds
Epoch 5, Loss: 1.427728957059432, Time: 19.4774 seconds
Accuracy of the network on the 10000 test images: 51.45 %
Accuracy of the network on the 50000 train images: 51.898 %
Saving results to results/update-manifold_muon_general-lr-0.5-wd-0.0-seed-42.pkl
Results saved to results/update-manifold_muon_general-lr-0.5-wd-0.0-seed-42.pkl


First, uninstall the current PyTorch version. The `-y` flag in the command below answers pip's confirmation prompt automatically, so no input is needed.

In [None]:
!pip uninstall torch torchvision torchaudio -y

Found existing installation: torch 2.8.0+cu126
Uninstalling torch-2.8.0+cu126:
  Successfully uninstalled torch-2.8.0+cu126
Found existing installation: torchvision 0.23.0+cu126
Uninstalling torchvision-0.23.0+cu126:
  Successfully uninstalled torchvision-0.23.0+cu126
Found existing installation: torchaudio 2.8.0+cu126
Uninstalling torchaudio-2.8.0+cu126:
  Successfully uninstalled torchaudio-2.8.0+cu126


Next, install the desired PyTorch version. You can find the installation command for specific versions and CUDA compatibility on the official PyTorch website (`pytorch.org/get-started/locally/`).

For example, to install the latest stable version with CUDA 11.8 (common in Colab), you might use something like this:

In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting torch
  Downloading https://download.pytorch.org/whl/cu118/torch-2.7.1%2Bcu118-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Downloading https://download.pytorch.org/whl/cu118/torchvision-0.22.1%2Bcu118-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/cu118/torchaudio-2.7.1%2Bcu118-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.6 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (23.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.2/23.2 MB[0m [31m117.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux1_x86_64.wh

After running these commands, you should restart the runtime (`Runtime > Restart runtime` from the menu) for the changes to take effect. Then, you can verify the installed version:

In [None]:
# Run after restarting the runtime: import torch again and print the version
# string of the build that is now active.
import torch
print(torch.__version__)

2.7.1+cu118


Keep in mind that Colab environments are updated periodically, so a specific PyTorch version might become the default over time.