# 🚀 Muon MVE - Quick Start (Colab Ready)

**EECS 182 Final Project: Spectral-Norm Constrained Optimization**

This notebook provides a quick demo that runs in Google Colab or locally. It requires the project's `muon` and `models` modules to be importable (see Step 1).

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/)

---

In [None]:
# Step 1: Setup (run this cell first!)
import sys
import os

# If running in Colab, clone the repo
if 'google.colab' in sys.modules:
    print("Running in Colab - setting up environment...")
    # Uncomment and modify if you have a GitHub repo:
    # !git clone https://github.com/yourusername/muon_mve_project.git
    # %cd muon_mve_project
    # !pip install -q -r requirements.txt
    print("Note: Upload the project files or clone from your repo.")
elif os.path.basename(os.getcwd()) == 'notebooks':
    # Local setup: the project directory is the parent of notebooks/.
    # Only change directory when we are actually inside notebooks/, so that
    # re-running this cell does not keep walking up the directory tree.
    os.chdir('..')

# Check GPU
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Step 2: Import modules
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

# Import our modules (expected to be importable from the project directory; see Step 1)
try:
    from muon import MuonSGD, SpectralClipSolver, compute_spectral_norms
    from models import get_model
    print("✓ Successfully imported muon modules")
except ImportError as e:
    print(f"Import error: {e}")
    print("Make sure you're in the project directory or have uploaded the files.")
    # Re-raise: every later cell depends on these names, so stopping here gives
    # a clear error instead of a confusing NameError several cells later.
    raise

In [None]:
# Step 3: Load CIFAR-10
# Standard CIFAR-10 per-channel normalization constants, shared by both
# pipelines so train and test inputs are always normalized identically.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
DATA_ROOT = './data'
BATCH_SIZE = 128
NUM_WORKERS = 2

# Training-time augmentation: random 32x32 crops of a 4-pixel-padded image + horizontal flips.
transform_train = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
# Evaluation: no augmentation, only tensor conversion and normalization.
transform_test = T.Compose([
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

train_set = torchvision.datasets.CIFAR10(DATA_ROOT, train=True, download=True, transform=transform_train)
test_set = torchvision.datasets.CIFAR10(DATA_ROOT, train=False, download=True, transform=transform_test)

# drop_last=True skips the final partial batch so every training step sees a full batch.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

print(f"✓ Loaded CIFAR-10: {len(train_set)} train, {len(test_set)} test")

In [None]:
# Step 4: Quick Demo - Compare SGD vs MuonSGD

EPOCHS = 5  # Quick demo; increase for real experiments


def train_one_epoch(model, optimizer, loader, device):
    """Run one optimization pass over ``loader`` and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


@torch.no_grad()
def evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a fraction in [0, 1]."""
    model.eval()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        correct += (model(x).argmax(1) == y).sum().item()
        total += y.size(0)
    return correct / total


def train_and_evaluate(model, optimizer, scheduler, epochs):
    """Train ``model`` for ``epochs`` epochs and evaluate it after each one.

    Uses the notebook-level ``train_loader``, ``test_loader`` and ``device``
    (Steps 1 and 3), and ``compute_spectral_norms`` from the ``muon`` package.
    The scheduler is stepped once per epoch, right after the training pass.

    Returns:
        dict of per-epoch lists:
          - 'train_loss': mean cross-entropy over that epoch's training batches
          - 'val_acc': top-1 accuracy on ``test_loader``
          - 'max_spec': largest value returned by ``compute_spectral_norms(model)``
            (0 if that result is empty)
    """
    history = {'train_loss': [], 'val_acc': [], 'max_spec': []}

    for epoch in range(1, epochs + 1):
        train_loss = train_one_epoch(model, optimizer, train_loader, device)
        scheduler.step()

        val_acc = evaluate_accuracy(model, test_loader, device)
        spec_norms = compute_spectral_norms(model)
        max_spec = max(spec_norms.values()) if spec_norms else 0

        history['train_loss'].append(train_loss)
        history['val_acc'].append(val_acc)
        history['max_spec'].append(max_spec)

        print(f"  Epoch {epoch}: loss={train_loss:.4f}, "
              f"acc={val_acc:.4f}, σ_max={max_spec:.3f}")

    return history

# Both runs share a seed and the same base hyperparameters, so they differ only
# in the optimizer (plain SGD vs. MuonSGD with a spectral budget).
SEED = 42
BASE_HPARAMS = dict(lr=0.1, momentum=0.9, weight_decay=5e-4)


def run_experiment(title, make_optimizer, epochs, model_name='small_cnn', seed=SEED):
    """Build a fresh model, train it with the optimizer from ``make_optimizer``, and return it.

    Args:
        title: Banner text printed before training starts.
        make_optimizer: Callable taking ``model.parameters()`` and returning an optimizer.
        epochs: Number of training epochs; also used as the cosine schedule's ``T_max``.
        model_name: Architecture name passed to ``get_model``.
        seed: Passed to ``torch.manual_seed`` right before the model is built, so
            every run starts from the same torch RNG state.

    Returns:
        ``(model, history)``, where ``history`` is the dict from ``train_and_evaluate``.
    """
    print("\n" + "="*50)
    print(title)
    print("="*50)
    torch.manual_seed(seed)
    model = get_model(model_name).to(device)
    optimizer = make_optimizer(model.parameters())
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    history = train_and_evaluate(model, optimizer, scheduler, epochs)
    return model, history


# Train with SGD
model_sgd, history_sgd = run_experiment(
    "Training with SGD (baseline)",
    lambda params: torch.optim.SGD(params, **BASE_HPARAMS),
    EPOCHS,
)

# Train with MuonSGD
model_muon, history_muon = run_experiment(
    "Training with MuonSGD (spectral constraint)",
    lambda params: MuonSGD(params, **BASE_HPARAMS,
                           spectral_budget=0.1, inner_solver=SpectralClipSolver()),
    EPOCHS,
)

In [None]:
# Step 5: Visualize Results
# label -> (history dict from Step 4, matplotlib line format)
runs = {'SGD': (history_sgd, 'b-'), 'MuonSGD': (history_muon, 'r-')}
# (history key, y-axis label, panel title)
panels = [
    ('train_loss', 'Training Loss', 'Training Loss'),
    ('val_acc', 'Validation Accuracy', 'Validation Accuracy'),
    ('max_spec', 'Max Spectral Norm', 'Spectral Norm Control'),
]

fig, axes = plt.subplots(1, len(panels), figsize=(14, 4))

for ax, (key, ylabel, title) in zip(axes, panels):
    for label, (history, fmt) in runs.items():
        # Derive the x-axis from the recorded history rather than EPOCHS, so the
        # plot stays correct if EPOCHS was edited after training.
        epochs = range(1, len(history[key]) + 1)
        ax.plot(epochs, history[key], fmt, label=label, linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n" + "="*50)
print("Summary")
print("="*50)
for label, (history, _fmt) in runs.items():
    print(f"{label:<7} - Final Acc: {history['val_acc'][-1]:.4f}, "
          f"Final σ_max: {history['max_spec'][-1]:.3f}")

---
## 🎯 Key Observations

1. **Spectral Norm Control**: MuonSGD keeps the spectral norms bounded, while SGD allows them to grow.

2. **Comparable Accuracy**: Despite the constraint, MuonSGD achieves similar accuracy to SGD.

3. **Stability**: The spectral constraint provides a Lipschitz bound on layer-wise transformations.

---

## 📚 Next Steps

1. **More epochs**: Set `EPOCHS` to 50–100 for full convergence
2. **Different architectures**: Try `resnet18`, `tiny_vit`, or `mlp_mixer`
3. **Different solvers**: Compare `spectral_clip`, `dual_ascent`, `frank_wolfe`, etc.
4. **Width transfer**: Test if hyperparameters transfer across model widths

See the other notebooks for comprehensive experiments!

In [None]:
# Bonus: Try different inner solvers
from muon import get_inner_solver, SOLVER_REGISTRY

# List every solver name registered in the muon package.
print("Available inner solvers:")
for name in SOLVER_REGISTRY:
    print(f"  - {name}")

# Example: build a Frank-Wolfe inner solver through the registry factory.
print("\nTrying Frank-Wolfe solver...")
fw_solver = get_inner_solver('frank_wolfe', max_iters=5)
solver_cls_name = type(fw_solver).__name__
print(f"Created: {solver_cls_name}")