<a href="https://colab.research.google.com/github/Fantiflex/Modular_Manifold_Muon/blob/main/Transformers_main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import os
import time
import pickle
import argparse

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

# -------------------------------------------------------------------
# MUON IMPORTS (adapt to your project)
# -------------------------------------------------------------------
# from your_muon_module import manifold_muon, manifold_muon_general, hyperspherical_descent, ManifoldLBFGS

# -------------------------------------------------------------------
# DATA
# -------------------------------------------------------------------
# Preprocessing shared by both splits: convert to a tensor, then normalise
# each of the three channels with a fixed mean/std
# (presumably the CIFAR-10 training-set statistics -- TODO confirm).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.49139968, 0.48215827, 0.44653124),
            std=(0.24703233, 0.24348505, 0.26158768),
        ),
    ]
)

# CIFAR-10 train/test splits, downloaded into ./data when missing.
train_dataset = torchvision.datasets.CIFAR10(
    "./data",
    train=True,
    transform=transform,
    download=True,
)
test_dataset = torchvision.datasets.CIFAR10(
    "./data",
    train=False,
    transform=transform,
    download=True,
)

# Batches of 1024 samples; only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

# -------------------------------------------------------------------
# MODELS
# -------------------------------------------------------------------
class MLP(nn.Module):
    """Bias-free three-layer perceptron over flattened 32x32x3 images.

    Args:
        hidden_dim: width of the two hidden layers.
        num_classes: number of output logits.
    """

    def __init__(self, hidden_dim=512, num_classes=10):
        super().__init__()
        in_features = 32 * 32 * 3
        self.fc1 = nn.Linear(in_features, hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.fc3 = nn.Linear(hidden_dim, num_classes, bias=False)

    def forward(self, x):
        """Return class logits for a batch of images."""
        # Flatten every sample to a 3072-vector (same size fc1 was built with).
        flat = x.view(-1, self.fc1.in_features)
        hidden = self.fc1(flat).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)


class SimpleCNN(nn.Module):
    """
    Simple CNN for CIFAR-10.

    Three bias-free 3x3 convolutions (3 -> 64 -> 128 -> 256 channels), each
    followed by batch normalisation and ReLU, then global average pooling
    and a bias-free linear classifier.
    The conv kernels and the fc weight can be handed to Muon / L-BFGS
    (flattened to matrices).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)

        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(256)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes, bias=False)

    def forward(self, x):
        """Return class logits for a batch of images."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, bn in stages:
            x = torch.relu(bn(conv(x)))
        pooled = self.pool(x)                          # [B, 256, 1, 1]
        features = pooled.view(pooled.size(0), -1)     # [B, 256]
        return self.fc(features)                       # [B, num_classes]


class SimpleViT(nn.Module):
    """
    Small Vision Transformer for CIFAR-10.

    With the default arguments:
    - a 32x32 image is cut into 4x4 patches, i.e. 8x8 = 64 patches;
    - each flattened patch (48 values) goes through a Linear 48 -> embed_dim;
    - a learned positional embedding is added;
    - 2 TransformerEncoder layers follow (d_model=128, 4 heads);
    - tokens are average-pooled, then a Linear maps to the 10 classes.
    """

    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_chans=3,
        num_classes=10,
        embed_dim=128,
        depth=2,
        num_heads=4,
        mlp_ratio=4.0,
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans

        patches_per_side = img_size // patch_size

        # Plain 2D matrix -> well suited to Muon / L-BFGS
        self.patch_embed = nn.Linear(
            in_chans * patch_size * patch_size, embed_dim, bias=False
        )

        # Learned position embedding, one vector per patch
        self.pos_embed = nn.Parameter(
            torch.zeros(1, patches_per_side * patches_per_side, embed_dim)
        )

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=int(embed_dim * mlp_ratio),
                batch_first=True,
                activation="gelu",
                norm_first=True,
            ),
            num_layers=depth,
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes, bias=False)

    def forward(self, x):
        """Return class logits for a batch of square img_size x img_size images."""
        batch, chans, height, width = x.shape
        assert height == self.img_size and width == self.img_size

        p = self.patch_size
        # Split both spatial axes into (grid, p), bring the two grid axes
        # forward, then flatten every patch: [B, num_patches, patch_dim]
        patches = (
            x.reshape(batch, chans, height // p, p, width // p, p)
            .permute(0, 2, 4, 1, 3, 5)
            .contiguous()
            .view(batch, -1, chans * p * p)
        )

        tokens = self.patch_embed(patches) + self.pos_embed   # [B, num_patches, embed_dim]
        pooled = self.encoder(tokens).mean(dim=1)             # [B, embed_dim]
        return self.head(self.norm(pooled))                   # [B, num_classes]


def build_model(model_type: str) -> nn.Module:
    """Instantiate the architecture named by ``model_type`` with default arguments.

    Args:
        model_type: "mlp", "cnn" or "vit" (case-insensitive).

    Returns:
        A new MLP, SimpleCNN or SimpleViT instance.

    Raises:
        ValueError: if the lower-cased name is not one of the three above.
    """
    model_type = model_type.lower()
    if model_type not in {"mlp", "cnn", "vit"}:
        raise ValueError(f"Unknown model_type: {model_type}")

    constructors = {"mlp": MLP, "cnn": SimpleCNN, "vit": SimpleViT}
    return constructors[model_type]()

# -------------------------------------------------------------------
# GLOBAL STATE FOR L-BFGS
# -------------------------------------------------------------------
# Per-parameter optimizer state shared through a module global; train()
# rebinds it to a fresh dict at the start of every run.
OPTS = {}  # map: param -> ManifoldLBFGS instance

# -------------------------------------------------------------------
# TRAIN
# -------------------------------------------------------------------
def train(epochs, initial_lr, update, wd, model_type):
    """Train a freshly built model on the module-level ``train_loader``.

    Args:
        epochs: number of passes over ``train_loader``.
        initial_lr: starting learning rate; it is decayed linearly at every
            step as ``initial_lr * (1 - step / steps)``.
        update: either the ``AdamW`` class (selects the AdamW path) or a
            callable invoked as ``update(W, G, eta=...)`` that returns the
            new weight matrix. When it is ``manifold_muon_general``, one
            ``ManifoldLBFGS`` object per 2D/4D parameter is kept in ``OPTS``
            and passed to it as ``opt=``.
        wd: weight decay; only forwarded to AdamW.
        model_type: forwarded to ``build_model``.

    Returns:
        ``(model, epoch_losses, epoch_times)``: the trained model, the mean
        training loss of each epoch, and each epoch's wall-clock duration
        in seconds.

    Note:
        The model and every batch are moved with ``.cuda()``, so a CUDA
        device is required.
    """
    global OPTS
    OPTS = {}  # reset for every run

    model = build_model(model_type).cuda()
    criterion = nn.CrossEntropyLoss()

    # AdamW case vs. manifold "update" rule
    if update == AdamW:
        optimizer = AdamW(model.parameters(), lr=initial_lr, weight_decay=wd)
    else:
        optimizer = None  # Muon / hyperspherical / L-BFGS

    # Are we in the globalized L-BFGS case?
    use_lbfgs = (update is manifold_muon_general)

    # Total number of optimisation steps, used by the linear LR decay below
    steps = epochs * len(train_loader)
    step = 0

    # ----- Initial projection of the manifold parameters -----
    # Each rule is called once with a zero gradient and eta=0.0
    # (presumably this only projects the initial weights -- TODO confirm).
    if optimizer is None:
        with torch.no_grad():
            for p in model.parameters():
                # 2D parameters (Linear): these can be placed on the Stiefel manifold
                if p.ndim == 2:
                    if use_lbfgs:
                        opt_p = ManifoldLBFGS(eta=initial_lr, history=10)
                        OPTS[p] = opt_p
                        p.data = manifold_muon_general(
                            p.data,
                            torch.zeros_like(p.data),
                            eta=0.0,
                            opt=opt_p,
                        )
                    else:
                        p.data = update(p.data, torch.zeros_like(p.data), eta=0.0)

                # 4D parameters (conv): flattened into a matrix [out, in*kH*kW]
                elif p.ndim == 4:
                    shape = p.shape
                    W = p.data.view(shape[0], -1)
                    Z = torch.zeros_like(W)
                    if use_lbfgs:
                        opt_p = ManifoldLBFGS(eta=initial_lr, history=10)
                        OPTS[p] = opt_p
                        W_new = manifold_muon_general(W, Z, eta=0.0, opt=opt_p)
                    else:
                        W_new = update(W, Z, eta=0.0)
                    p.data.copy_(W_new.view(shape))
                # Everything else (bias, LayerNorm, pos_embed...) stays Euclidean

    epoch_losses = []
    epoch_times = []

    for epoch in range(epochs):
        start_time = time.time()
        running_loss = 0.0
        model.train()

        for i, (images, labels) in enumerate(train_loader):
            images = images.cuda()
            labels = labels.cuda()

            # Forward
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward
            model.zero_grad()
            loss.backward()

            # Linear decay: initial_lr at step 0, approaching 0 at the last step
            lr = initial_lr * (1 - step / steps)

            with torch.no_grad():
                if optimizer is None:
                    # --------- Muon / hyperspherical / L-BFGS case ---------
                    for p in model.parameters():
                        if p.grad is None:
                            continue

                        # Globalized L-BFGS case + manifold parameter
                        if use_lbfgs and p in OPTS and p.ndim in (2, 4):
                            opt_p = OPTS[p]

                            if p.ndim == 2:
                                W = p.data
                                G = p.grad
                            else:  # 4D conv
                                shape = p.shape
                                W = p.data.view(shape[0], -1)
                                G = p.grad.view(shape[0], -1)

                            # 1) Update the curvature estimate with the current gradient
                            #    (opt.update will use the _pending info from the previous step)
                            opt_p.update(G)

                            # 2) Take a quasi-Newton step on the manifold
                            W_new = manifold_muon_general(W, G, eta=lr, opt=opt_p)

                            if p.ndim == 2:
                                p.data.copy_(W_new)
                            else:
                                p.data.copy_(W_new.view(shape))

                        else:
                            # ----- Stateless "update" case (manifold_muon / hyperspherical) -----
                            if p.ndim == 2:
                                p.data = update(p.data, p.grad, eta=lr)
                            elif p.ndim == 4:
                                shape = p.shape
                                W = p.data.view(shape[0], -1)
                                G = p.grad.view(shape[0], -1)
                                W_new = update(W, G, eta=lr)
                                p.data.copy_(W_new.view(shape))
                            else:
                                # Euclidean parameters (embeddings, LayerNorm, bias...)
                                # get a plain gradient-descent step
                                p.data = p.data - lr * p.grad
                else:
                    # --------- AdamW case ---------
                    # Push the decayed learning rate into the optimizer before stepping
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = lr
                    optimizer.step()

            step += 1
            running_loss += loss.item()

            if (i + 1) % 100 == 0:
                print(
                    f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(train_loader)}], "
                    f"Loss: {loss.item():.4f}"
                )

        end_time = time.time()
        # Mean of the per-batch losses seen during this epoch
        epoch_loss = running_loss / len(train_loader)
        epoch_time = end_time - start_time
        epoch_losses.append(epoch_loss)
        epoch_times.append(epoch_time)
        print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Time: {epoch_time:.4f} seconds")

    return model, epoch_losses, epoch_times

# -------------------------------------------------------------------
# EVAL
# -------------------------------------------------------------------
def eval(model):
    """Report the top-1 accuracy of ``model`` on the test and training sets.

    The model is switched to evaluation mode (and left in it). Batches from
    the module-level ``test_loader`` and ``train_loader`` are moved to the
    device the model's parameters live on, so the function no longer
    requires the model to sit on the current CUDA device.

    NOTE: the name shadows the ``eval`` builtin; it is kept because callers
    use it.

    Args:
        model: the network to evaluate.

    Returns:
        ``[test_accuracy, train_accuracy]``, both as percentages.
    """
    # Follow the model's device; a parameter-less model keeps the original
    # behaviour of sending the data to CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    model.eval()
    accs = []
    with torch.no_grad():
        # Order matters: callers unpack the result as (test, train).
        for dataloader in (test_loader, train_loader):
            correct = 0
            total = 0
            for images, labels in dataloader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                # Index of the largest logit = predicted class
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            accs.append(100.0 * correct / total)

    print(
        f"Accuracy of the network on the {len(test_loader.dataset)} test images: "
        f"{accs[0]} %"
    )
    print(
        f"Accuracy of the network on the {len(train_loader.dataset)} train images: "
        f"{accs[1]} %"
    )
    return accs

# -------------------------------------------------------------------
# WEIGHT STATS
# -------------------------------------------------------------------
def weight_stats(model):
    """Collect singular values and norms of a model's parameters.

    Every parameter contributes its norm. Parameters with at least two
    dimensions are additionally flattened to a matrix of shape
    ``[size(0), -1]`` and contribute the singular values of that matrix.

    Args:
        model: module whose ``parameters()`` are inspected.

    Returns:
        ``(singular_values, norms)``: a list of 1-D tensors (one per
        parameter with ``ndim >= 2``, values in descending order, detached
        from autograd) and a list of Python floats (one per parameter).
    """
    singular_values = []
    norms = []
    # Pure diagnostics: run without autograd so the returned tensors do not
    # keep a graph alive.
    with torch.no_grad():
        for p in model.parameters():
            norms.append(p.norm().item())
            if p.ndim >= 2:
                # reshape (not view) also works for non-contiguous parameters
                mat = p.reshape(p.size(0), -1)
                # Only the singular values are needed, so skip computing U and V
                # (torch.svd is deprecated in favour of torch.linalg).
                singular_values.append(torch.linalg.svdvals(mat))
    return singular_values, norms

# -------------------------------------------------------------------
# MAIN
# -------------------------------------------------------------------
if __name__ == "__main__":
    # ---- Command-line interface ----
    parser = argparse.ArgumentParser(description="Train a model on CIFAR-10.")
    parser.add_argument(
        "--model",
        type=str,
        default="mlp",
        choices=["mlp", "cnn", "vit"],
        help="Model type to use.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=5,
        help="Number of epochs to train for.",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.1,
        help="Initial learning rate.",
    )
    parser.add_argument(
        "--update",
        type=str,
        default="manifold_muon_general",
        choices=["manifold_muon", "manifold_muon_general",
                 "hyperspherical_descent", "adam"],
        help="Update rule to use.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed for the random number generator.",
    )
    parser.add_argument(
        "--wd",
        type=float,
        default=0.0,
        help="Weight decay for AdamW.",
    )
    # An empty argv makes argparse fall back to the defaults above;
    # pass None instead when launching from the command line.
    args = parser.parse_args([])

    # ---- Determinism ----
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # ---- Resolve the update rule selected by name ----
    update_rules = {
        "manifold_muon": manifold_muon,
        "manifold_muon_general": manifold_muon_general,
        "hyperspherical_descent": hyperspherical_descent,
        "adam": AdamW,
    }
    update = update_rules[args.update]

    print(f"Training {args.model.upper()} with: {args.update}")
    # Weight decay is only reported for AdamW, the only rule that uses it.
    summary = f"Epochs: {args.epochs} --- LR: {args.lr}"
    if args.update == "adam":
        summary = summary + f" --- WD: {args.wd}"
    print(summary)

    # ---- Train, evaluate, collect weight statistics ----
    model, epoch_losses, epoch_times = train(
        epochs=args.epochs,
        initial_lr=args.lr,
        update=update,
        wd=args.wd,
        model_type=args.model,
    )
    test_acc, train_acc = eval(model)
    singular_values, norms = weight_stats(model)

    results = {
        "model": args.model,
        "epochs": args.epochs,
        "lr": args.lr,
        "seed": args.seed,
        "wd": args.wd,
        "update": args.update,
        "epoch_losses": epoch_losses,
        "epoch_times": epoch_times,
        "test_acc": test_acc,
        "train_acc": train_acc,
        "singular_values": singular_values,
        "norms": norms,
    }

    # ---- Persist the run under results/, one file per configuration ----
    os.makedirs("results", exist_ok=True)
    filename = (
        f"model-{args.model}-update-{args.update}-lr-{args.lr}"
        f"-wd-{args.wd}-seed-{args.seed}.pkl"
    )
    save_path = os.path.join("results", filename)
    print(f"Saving results to {save_path}")
    with open(save_path, "wb") as f:
        pickle.dump(results, f)
    print(f"Results saved to {save_path}")


Training VIT with: adam
Epochs: 20 --- LR: 0.001 --- WD: 0.0
Epoch 1, Loss: 1.9035, Time: 16.2980 seconds
Epoch 2, Loss: 1.5903, Time: 15.5355 seconds
Epoch 3, Loss: 1.4021, Time: 15.7451 seconds
Epoch 4, Loss: 1.2736, Time: 15.7018 seconds
Epoch 5, Loss: 1.1680, Time: 16.0391 seconds
Epoch 6, Loss: 1.0991, Time: 16.0083 seconds
Epoch 7, Loss: 1.0356, Time: 15.6703 seconds
Epoch 8, Loss: 0.9829, Time: 15.5816 seconds
Epoch 9, Loss: 0.9432, Time: 15.3213 seconds
Epoch 10, Loss: 0.9123, Time: 15.5982 seconds
Epoch 11, Loss: 0.8782, Time: 15.7600 seconds
Epoch 12, Loss: 0.8391, Time: 15.2527 seconds
Epoch 13, Loss: 0.8041, Time: 15.4981 seconds
Epoch 14, Loss: 0.7828, Time: 15.5278 seconds
Epoch 15, Loss: 0.7516, Time: 16.4243 seconds
Epoch 16, Loss: 0.7305, Time: 16.0905 seconds
Epoch 17, Loss: 0.7134, Time: 15.9614 seconds
Epoch 18, Loss: 0.6938, Time: 15.4912 seconds
Epoch 19, Loss: 0.6761, Time: 15.4057 seconds
Epoch 20, Loss: 0.6609, Time: 16.2186 seconds
Accuracy of the network on t

In [1]:
import argparse
import os
import pickle
import time

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Colab only: mount Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')
# Fall back to the CPU when CUDA is unavailable ("gpu" is not a valid
# torch device type and made torch.device raise on CPU-only machines).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Execute the helper notebooks stored on Google Drive inside this kernel.
# NOTE(review): presumably these provide manifold_muon, manifold_muon_general,
# hyperspherical_descent and ManifoldLBFGS -- confirm against those notebooks.
%run "/content/drive/MyDrive/Colab_Notebooks/EECS182_project/hyperspherical_descent.ipynb"
%run "/content/drive/MyDrive/Colab_Notebooks/EECS182_project/LGFBS_global.ipynb"
# after this, the functions defined inside those notebooks are available in the current notebook
from torch.optim import AdamW
from torch.utils.data import DataLoader

Mounted at /content/drive
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
