# Muon Experiments Notebook

**EECS 182 Final Project - Running Experiments**

This notebook runs the core experiments for our project on spectral-norm constrained optimization.

---

## Experiments Overview

1. **Baseline Comparison**: SGD vs MuonSGD (SpectralClip and DualAscent inner solvers) on SmallCNN
2. **Inner Solver Comparison**: Different solvers on ResNet-18
3. **Spectral Budget Sweep**: Effect of budget on training dynamics
4. **Width Transfer Experiment**: muP-style width scaling
5. **LR Stability Envelope**: Finding the maximum stable learning rate

In [None]:
# Setup and imports
import sys
import os

# Resolve the project root (the notebook's parent directory) ONCE, as an
# absolute path. A relative '..' entry on sys.path is resolved against the
# *current* working directory at import time, so combining it with
# os.chdir('..') made that entry point at the grandparent directory; and
# re-running this cell used to climb one more level each time.
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.abspath('..')
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
os.chdir(PROJECT_ROOT)  # Change to project root (safe to re-run)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import json
import time

# Our modules
from muon import (
    MuonSGD, MuonAdamW, create_optimizer,
    SpectralClipSolver, FrankWolfeSolver, DualAscentSolver,
    QuasiNewtonDualSolver, ADMMSolver, get_inner_solver,
    compute_spectral_norms, estimate_sharpness, estimate_gradient_noise_scale,
    MetricsLogger
)
from models import get_model, MODEL_REGISTRY

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"Available models: {list(MODEL_REGISTRY.keys())}")

In [None]:
# Data loading utilities

def get_cifar10_loaders(batch_size=128, num_workers=2, root='./data', pin_memory=None):
    """Create CIFAR-10 train/test loaders with standard augmentation.

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: DataLoader worker processes for both loaders.
        root: dataset directory; CIFAR-10 is downloaded there if missing.
        pin_memory: whether to pin host memory for faster host-to-GPU
            copies. ``None`` (default) enables it only when CUDA is
            available; pinning on a CPU-only machine only produces a warning.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles, applies
        random crop + horizontal flip, and drops the last incomplete batch;
        the test loader keeps dataset order and every sample.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    # Per-channel CIFAR-10 mean/std, shared by both splits.
    normalize = T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))

    transform_train = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        normalize,
    ])

    transform_test = T.Compose([
        T.ToTensor(),
        normalize,
    ])

    train_set = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform_train
    )
    test_set = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform_test
    )

    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory, drop_last=True
    )
    test_loader = DataLoader(
        test_set, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory
    )

    return train_loader, test_loader

# Build the shared CIFAR-10 loaders that every experiment below reads from.
train_loader, test_loader = get_cifar10_loaders(batch_size=128)
print("Train batches: {}, Test batches: {}".format(len(train_loader), len(test_loader)))

In [None]:
# Training utilities

def ce_loss_fn(model, x, y):
    """Return the mean cross-entropy of ``model`` on the batch ``(x, y)``.

    Has the ``(model, x, y)`` call signature that this notebook passes to
    ``estimate_sharpness`` and ``estimate_gradient_noise_scale``.
    """
    logits = model(x)
    return F.cross_entropy(logits, y)

@torch.no_grad()
def evaluate(model, loader, device):
    """Compute ``(mean_loss, accuracy)`` of ``model`` over every batch of ``loader``.

    Both numbers are averaged per sample (not per batch). The model is left
    in eval mode; gradients are disabled for the whole call.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        out = model(inputs)
        # Sum-reduced loss so the final division weights every sample equally.
        loss_sum += F.cross_entropy(out, targets, reduction='sum').item()
        n_correct += out.argmax(1).eq(targets).sum().item()
        n_seen += targets.size(0)

    return loss_sum / n_seen, n_correct / n_seen


def train_epoch(model, loader, optimizer, device, scheduler=None):
    """Run one optimization pass over ``loader``; return ``(mean_loss, accuracy)``.

    Loss and accuracy are measured on each batch *before* that batch's
    update and averaged per sample. If ``scheduler`` is given (truthy), its
    ``step()`` is called once after the whole pass.
    """
    model.train()
    running_loss = 0.0
    running_hits = 0
    seen = 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        out = model(inputs)
        batch_loss = F.cross_entropy(out, targets)
        batch_loss.backward()
        optimizer.step()

        # Re-weight the batch-mean loss by batch size for a per-sample average.
        running_loss += batch_loss.item() * inputs.size(0)
        running_hits += (out.argmax(1) == targets).sum().item()
        seen += targets.size(0)

    if scheduler:
        scheduler.step()

    return running_loss / seen, running_hits / seen


def run_experiment(
    model_name='small_cnn',
    optimizer_type='muon_sgd',
    inner_solver_type='spectral_clip',
    spectral_budget=0.1,
    lr=0.1,
    epochs=10,
    width_mult=1.0,
    seed=42,
    verbose=True
):
    """Run a complete training experiment and return metrics history.

    Trains ``get_model(model_name, ...)`` on the notebook-level
    ``train_loader`` / ``test_loader`` on the global ``device``, with a
    per-epoch ``CosineAnnealingLR`` schedule, logging one metrics dict per
    epoch through ``MetricsLogger``.

    Args:
        model_name: model key passed to ``get_model``.
        optimizer_type: optimizer name passed to ``create_optimizer``.
        inner_solver_type: inner solver name; ``'none'`` makes the spectral
            budget be passed as ``None``.
        spectral_budget: budget forwarded to ``create_optimizer``.
        lr: initial learning rate (cosine-annealed over ``epochs``).
        epochs: number of training epochs.
        width_mult: width multiplier passed to ``get_model``.
        seed: seed for torch's (and CUDA's) RNGs.
        verbose: print model info and a one-line summary per epoch.

    Returns:
        ``(history, model)``: ``logger.history`` (later cells index it as a
        list of the logged dicts) and the trained model. Each logged dict has
        keys ``epoch``, ``train_loss``, ``train_acc``, ``val_loss``,
        ``val_acc``, ``max_spectral_norm``, ``sharpness``,
        ``grad_noise_scale``, ``lr`` (the LR used *during* that epoch) and
        ``time`` (seconds, including evaluation and metric estimation).

    Raises:
        ValueError: if ``train_loader`` yields fewer than two batches; two
            fixed probe batches are needed by the metric estimators.
    """

    # Set seed
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Create model
    model = get_model(model_name, num_classes=10, width_mult=width_mult).to(device)
    num_params = sum(p.numel() for p in model.parameters())

    if verbose:
        print(f"Model: {model_name}, Params: {num_params:,}, Width: {width_mult}")

    # Create optimizer
    optimizer = create_optimizer(
        model,
        optimizer_type=optimizer_type,
        inner_solver_type=inner_solver_type,
        lr=lr,
        momentum=0.9,
        weight_decay=5e-4,
        spectral_budget=spectral_budget if inner_solver_type != 'none' else None
    )

    # LR scheduler (stepped once per epoch)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Metrics logger
    logger = MetricsLogger()

    # Two fixed probe batches for sharpness / gradient-noise-scale estimation.
    # Move them to the model's device up front so the estimators never mix
    # CPU inputs with a CUDA model (a no-op if they already move inputs).
    probe_batches = []
    for x_probe, y_probe in train_loader:
        probe_batches.append([x_probe.to(device), y_probe.to(device)])
        if len(probe_batches) >= 2:
            break
    if len(probe_batches) < 2:
        raise ValueError(
            f"train_loader must yield at least 2 batches for metric "
            f"estimation, got {len(probe_batches)}"
        )

    # Training loop
    for epoch in range(1, epochs + 1):
        t0 = time.time()

        # Record the LR actually used for this epoch's updates. Reading it
        # after scheduler.step() would log the *next* epoch's LR.
        current_lr = optimizer.param_groups[0]['lr']

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, device)
        scheduler.step()

        # Evaluate
        val_loss, val_acc = evaluate(model, test_loader, device)

        # Compute metrics
        spec_norms = compute_spectral_norms(model, max_layers=8)
        max_spec = max(spec_norms.values()) if spec_norms else 0.0

        sharpness = estimate_sharpness(
            model, ce_loss_fn,
            probe_batches[0][0], probe_batches[0][1],
            epsilon=1e-3
        )

        gns = estimate_gradient_noise_scale(
            model, ce_loss_fn,
            probe_batches[0], probe_batches[1]
        )

        epoch_time = time.time() - t0

        # Log
        metrics = {
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'max_spectral_norm': max_spec,
            'sharpness': sharpness,
            'grad_noise_scale': gns,
            'lr': current_lr,
            'time': epoch_time
        }
        logger.log(metrics)

        if verbose:
            print(f"Epoch {epoch:3d} | Train: {train_loss:.4f}/{train_acc:.4f} | "
                  f"Val: {val_loss:.4f}/{val_acc:.4f} | σ_max: {max_spec:.3f}")

    return logger.history, model

---
## Experiment 1: Baseline Comparison (SmallCNN)

Compare SGD against MuonSGD (with the SpectralClip and DualAscent inner solvers) on the small CNN to verify that everything works end to end.

In [None]:
# Experiment 1: quick sanity-check comparison on SmallCNN.
EPOCHS = 10  # Quick run; increase for full experiments

# Each entry: (display name, optimizer type, inner solver, learning rate).
configs = [
    ('SGD', 'sgd', 'none', 0.1),
    ('MuonSGD + SpectralClip', 'muon_sgd', 'spectral_clip', 0.1),
    ('MuonSGD + DualAscent', 'muon_sgd', 'dual_ascent', 0.1),
]

results_baseline = {}
separator = '=' * 60

for label, opt_type, solver_type, base_lr in configs:
    print('\n' + separator)
    print(f"Running: {label}")
    print(separator)

    results_baseline[label] = run_experiment(
        model_name='small_cnn',
        optimizer_type=opt_type,
        inner_solver_type=solver_type,
        spectral_budget=0.1,
        lr=base_lr,
        epochs=EPOCHS,
        seed=42
    )[0]

print("\nBaseline experiments complete!")

In [None]:
# Plot baseline comparison: one panel per tracked metric, one line per run.
fig, axes = plt.subplots(2, 3, figsize=(14, 8))

# (history key, axis/panel title); also reused by the solver-comparison plot.
metrics_to_plot = [
    ('val_loss', 'Validation Loss'),
    ('val_acc', 'Validation Accuracy'),
    ('max_spectral_norm', 'Max Spectral Norm'),
    ('sharpness', 'Sharpness'),
    ('grad_noise_scale', 'Gradient Noise Scale'),
    ('train_loss', 'Training Loss'),
]

for ax, (metric, title) in zip(axes.flatten(), metrics_to_plot):
    for name, history in results_baseline.items():
        epoch_axis = [h['epoch'] for h in history]
        values = [h[metric] for h in history]
        ax.plot(epoch_axis, values, label=name, linewidth=2)

    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    ax.set_title(title)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # Log scale for GNS
    if metric == 'grad_noise_scale':
        ax.set_yscale('log')

plt.tight_layout()
# savefig does not create missing directories, and logs/ is otherwise only
# created by the final "Save All Results" cell.
os.makedirs('logs', exist_ok=True)
plt.savefig('logs/baseline_comparison.png', dpi=150)
plt.show()

print("\nFinal validation accuracies:")
for name, history in results_baseline.items():
    print(f"  {name}: {history[-1]['val_acc']:.4f}")

---
## Experiment 2: Inner Solver Comparison (ResNet-18)

Compare all inner solvers on ResNet-18 for a more realistic benchmark.

In [None]:
# Experiment 2: Inner solver comparison on ResNet-18
import traceback

EPOCHS = 20  # Moderate run

results_solvers = {}
failed_solvers = {}  # name -> repr(exception) for runs that crashed

solver_configs = [
    ('SGD (baseline)', 'sgd', 'none'),
    ('SpectralClip', 'muon_sgd', 'spectral_clip'),
    ('DualAscent', 'muon_sgd', 'dual_ascent'),
    ('QuasiNewton', 'muon_sgd', 'quasi_newton'),
    ('FrankWolfe', 'muon_sgd', 'frank_wolfe'),
    ('ADMM', 'muon_sgd', 'admm'),
]

for name, opt, solver in solver_configs:
    print(f"\n{'='*60}")
    print(f"Running: {name}")
    print('='*60)

    try:
        history, _ = run_experiment(
            model_name='resnet18',
            optimizer_type=opt,
            inner_solver_type=solver,
            spectral_budget=0.1,
            lr=0.1,
            epochs=EPOCHS,
            seed=42
        )
    except Exception as e:
        # Keep the sweep going so one broken solver doesn't abort the rest,
        # but show the full traceback (the message alone is rarely enough to
        # debug) and record the failure so a missing curve is explained.
        print(f"Error with {name}: {e}")
        traceback.print_exc()
        failed_solvers[name] = repr(e)
    else:
        results_solvers[name] = history

print("\nSolver comparison complete!")
if failed_solvers:
    print(f"Failed solvers: {', '.join(failed_solvers)}")

In [None]:
# Plot solver comparison
# Same panel layout as the baseline plot; defined here as well so this cell
# does not depend on the Experiment 1 plotting cell having been run.
metrics_to_plot = [
    ('val_loss', 'Validation Loss'),
    ('val_acc', 'Validation Accuracy'),
    ('max_spectral_norm', 'Max Spectral Norm'),
    ('sharpness', 'Sharpness'),
    ('grad_noise_scale', 'Gradient Noise Scale'),
    ('train_loss', 'Training Loss'),
]

fig, axes = plt.subplots(2, 3, figsize=(14, 8))

for ax, (metric, title) in zip(axes.flatten(), metrics_to_plot):
    for name, history in results_solvers.items():
        epoch_axis = [h['epoch'] for h in history]
        values = [h[metric] for h in history]
        ax.plot(epoch_axis, values, label=name, linewidth=1.5)

    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    ax.set_title(f'{title} (ResNet-18)')
    ax.legend(fontsize=7)
    ax.grid(True, alpha=0.3)

    if metric == 'grad_noise_scale':
        ax.set_yscale('log')

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs('logs', exist_ok=True)
plt.savefig('logs/solver_comparison_resnet18.png', dpi=150)
plt.show()

print("\nFinal validation accuracies (ResNet-18):")
# Best final validation accuracy first.
for name, history in sorted(results_solvers.items(), key=lambda x: -x[1][-1]['val_acc']):
    print(f"  {name:20s}: {history[-1]['val_acc']:.4f}")

---
## Experiment 3: Spectral Budget Sweep

How does the spectral budget affect training dynamics?

In [None]:
# Experiment 3: sweep the spectral budget with every other setting fixed.
EPOCHS = 15

budgets = [0.01, 0.05, 0.1, 0.2, 0.5, 1.0]
results_budget = {}

for budget in budgets:
    print(f"\nRunning with spectral_budget = {budget}...")

    run_history = run_experiment(
        model_name='small_cnn',
        optimizer_type='muon_sgd',
        inner_solver_type='spectral_clip',
        spectral_budget=budget,
        lr=0.1,
        epochs=EPOCHS,
        seed=42,
        verbose=False
    )[0]
    results_budget[budget] = run_history

    final = run_history[-1]
    print(f"  Final val_acc: {final['val_acc']:.4f}, max_spec: {final['max_spectral_norm']:.3f}")

print("\nBudget sweep complete!")

In [None]:
# Plot budget sweep results
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# One viridis colour per budget, in sweep order.
colors = plt.cm.viridis(np.linspace(0, 1, len(budgets)))

# Left and middle panels: per-epoch trajectories, one line per budget.
# Each entry: (axis, history key, y-label, panel title).
trajectory_panels = [
    (axes[0], 'val_acc', 'Validation Accuracy', 'Accuracy vs Spectral Budget'),
    (axes[1], 'max_spectral_norm', 'Max Spectral Norm', 'Spectral Norm Trajectory'),
]
for ax, metric, ylabel, title in trajectory_panels:
    for budget, color in zip(budgets, colors):
        history = results_budget[budget]
        epoch_axis = [h['epoch'] for h in history]
        values = [h[metric] for h in history]
        ax.plot(epoch_axis, values, label=f'η={budget}', color=color, linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Right panel: final accuracy vs budget
final_accs = [results_budget[b][-1]['val_acc'] for b in budgets]
axes[2].plot(budgets, final_accs, 'o-', markersize=10, linewidth=2)
axes[2].set_xlabel('Spectral Budget (η)')
axes[2].set_ylabel('Final Validation Accuracy')
axes[2].set_title('Accuracy vs Budget')
axes[2].set_xscale('log')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs('logs', exist_ok=True)
plt.savefig('logs/budget_sweep.png', dpi=150)
plt.show()

---
## Experiment 4: Width Transfer (muP-style)

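Test whether a single set of hyperparameters (lr = 0.1, spectral budget = 0.1) transfers across MLP widths when using spectral constraints, compared with plain SGD at the same settings.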

In [None]:
# Experiment 4: width transfer -- same hyperparameters at every MLP width,
# once with MuonSGD + SpectralClip and once with plain SGD as the baseline.
EPOCHS = 20

width_mults = [0.5, 0.75, 1.0, 1.5, 2.0]
results_width_muon = {}
results_width_sgd = {}

# Each entry: (banner, optimizer type, inner solver, dict to fill).
width_runs = [
    ("Testing width transfer with MuonSGD...", 'muon_sgd', 'spectral_clip', results_width_muon),
    ("\nTesting width transfer with SGD (baseline)...", 'sgd', 'none', results_width_sgd),
]

for banner, opt_type, solver_type, store in width_runs:
    print(banner)
    for width in width_mults:
        print(f"\n  Width = {width}x")
        hist, _ = run_experiment(
            model_name='mlp',
            optimizer_type=opt_type,
            inner_solver_type=solver_type,
            spectral_budget=0.1,
            lr=0.1,
            epochs=EPOCHS,
            width_mult=width,
            seed=42,
            verbose=False
        )
        store[width] = hist
        print(f"    Final val_acc: {hist[-1]['val_acc']:.4f}")

print("\nWidth transfer experiments complete!")

In [None]:
# Plot width transfer results
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Final accuracy vs width
muon_accs = [results_width_muon[w][-1]['val_acc'] for w in width_mults]
sgd_accs = [results_width_sgd[w][-1]['val_acc'] for w in width_mults]

axes[0].plot(width_mults, muon_accs, 'o-', label='MuonSGD', markersize=10, linewidth=2)
axes[0].plot(width_mults, sgd_accs, 's--', label='SGD', markersize=10, linewidth=2)
axes[0].set_xlabel('Width Multiplier')
axes[0].set_ylabel('Final Validation Accuracy')
axes[0].set_title('Width Transfer: Accuracy vs Width')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Spread of final accuracy across widths (population std): a lower value
# means the fixed hyperparameters behave more uniformly as width changes.
muon_std = np.std(muon_accs)
sgd_std = np.std(sgd_accs)

axes[1].bar(['MuonSGD', 'SGD'], [muon_std, sgd_std], color=['tab:blue', 'tab:orange'])
axes[1].set_ylabel('Std Dev of Accuracy Across Widths')
axes[1].set_title('Transfer Stability (lower = better)')
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs('logs', exist_ok=True)
plt.savefig('logs/width_transfer.png', dpi=150)
plt.show()

print("\nTransfer stability (std dev across widths):")
print(f"  MuonSGD: {muon_std:.4f}")
print(f"  SGD:     {sgd_std:.4f}")

---
## Experiment 5: LR Stability Envelope

Find the maximum stable learning rate for different optimizers.

In [None]:
# Experiment 5: LR stability scan
from muon import lr_stability_scan

# Seed every freshly built model so all learning rates, and both optimizers,
# start from identical initial weights; otherwise part of the difference in
# the envelopes comes from different random initializations. (This also
# resets torch's global RNG, so anything drawn from it afterwards, such as
# DataLoader shuffling, repeats identically for each model.)
SCAN_SEED = 42


def model_factory():
    torch.manual_seed(SCAN_SEED)
    # num_classes=10 matches CIFAR-10 and the models built in run_experiment.
    return get_model('small_cnn', num_classes=10)


def sgd_factory(model, lr):
    return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)


def muon_factory(model, lr):
    return MuonSGD(
        model.parameters(), lr=lr, momentum=0.9,
        spectral_budget=0.1, inner_solver=SpectralClipSolver()
    )


print("Scanning LR envelope for SGD...")
sgd_envelope = lr_stability_scan(
    model_factory, sgd_factory, ce_loss_fn, train_loader,
    lr_range=(1e-3, 2.0), num_lrs=12, steps_per_lr=100,
    device=device
)

print("\nScanning LR envelope for MuonSGD...")
muon_envelope = lr_stability_scan(
    model_factory, muon_factory, ce_loss_fn, train_loader,
    lr_range=(1e-3, 2.0), num_lrs=12, steps_per_lr=100,
    device=device
)

print(f"\nMax stable LR:")
print(f"  SGD:     {sgd_envelope['max_stable_lr']:.4f}")
print(f"  MuonSGD: {muon_envelope['max_stable_lr']:.4f}")

In [None]:
# Plot LR envelope: converged runs as dots at their final loss, diverged runs
# as crosses, and a dashed vertical line at each optimizer's max stable LR.
fig, ax = plt.subplots(figsize=(8, 5))

# Each entry: (scan result, colour, legend name, short name for max-LR line).
envelope_styles = [
    (sgd_envelope, 'blue', 'SGD', 'SGD'),
    (muon_envelope, 'red', 'MuonSGD', 'Muon'),
]

for env, color, label, _short in envelope_styles:
    converged = env['converged']
    conv_lr = [lr for lr, c in zip(env['lr_values'], converged) if c]
    conv_loss = [loss for loss, c in zip(env['final_losses'], converged) if c]
    div_lr = [lr for lr, c in zip(env['lr_values'], converged) if not c]

    ax.scatter(conv_lr, conv_loss, c=color, label=f'{label} (converged)', s=80, marker='o')
    # Diverged runs have no meaningful final loss; pin them at the worst
    # converged loss (or 10 if nothing converged) so they still show on the
    # LR axis.
    div_y = max(conv_loss) if conv_loss else 10
    ax.scatter(div_lr, [div_y] * len(div_lr),
               c=color, marker='x', s=80, label=f'{label} (diverged)')

# Max stable LR lines (drawn after all scatters to keep the legend order).
for env, color, _label, short in envelope_styles:
    max_lr = env['max_stable_lr']
    # NOTE(review): unclear what lr_stability_scan returns when no LR is
    # stable; skip non-positive/None values, which cannot be placed on a log
    # x-axis (and None would also break the format spec).
    if max_lr is not None and max_lr > 0:
        ax.axvline(max_lr, color=color, linestyle='--', alpha=0.7,
                   label=f"{short} max stable: {max_lr:.3f}")

ax.set_xscale('log')
ax.set_xlabel('Learning Rate')
ax.set_ylabel('Final Loss (after 100 steps)')
ax.set_title('Learning Rate Stability Envelope')
ax.legend(loc='upper left')
ax.grid(True, alpha=0.3)

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs('logs', exist_ok=True)
plt.savefig('logs/lr_envelope.png', dpi=150)
plt.show()

---
## Save All Results

In [None]:
# Save all results to JSON for later analysis
import json
import os

os.makedirs('logs', exist_ok=True)


def _to_jsonable(obj):
    """Fallback for ``json.dump`` covering numeric containers json can't encode.

    The metric histories and LR-scan results may hold NumPy arrays/scalars
    or torch tensors (e.g. spectral norms or LR grids), depending on what
    the ``muon`` helpers return; convert those to plain Python numbers and
    lists. Any other type still raises ``TypeError``, as json does by default.
    """
    if isinstance(obj, (np.ndarray, np.generic)):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


all_results = {
    'baseline_comparison': dict(results_baseline),
    'solver_comparison': dict(results_solvers),
    # JSON object keys must be strings, so numeric sweep values are stringified.
    'budget_sweep': {str(k): v for k, v in results_budget.items()},
    'width_transfer_muon': {str(k): v for k, v in results_width_muon.items()},
    'width_transfer_sgd': {str(k): v for k, v in results_width_sgd.items()},
    'lr_envelope_sgd': sgd_envelope,
    'lr_envelope_muon': muon_envelope,
}

with open('logs/all_experiment_results.json', 'w') as f:
    json.dump(all_results, f, indent=2, default=_to_jsonable)

print("All results saved to logs/all_experiment_results.json")

---
## Summary

Expected findings (hypotheses — confirm or revise against the results above once all cells have been run):

1. **Baseline Comparison**: MuonSGD with spectral constraints achieves competitive accuracy while controlling spectral norms

2. **Inner Solver Comparison**: Different solvers show trade-offs between computational cost and constraint satisfaction

3. **Spectral Budget**: There's an optimal budget range - too small constrains learning, too large doesn't help

4. **Width Transfer**: Spectral constraints improve hyperparameter transfer across widths

5. **LR Stability**: MuonSGD may extend the stable LR region compared to vanilla SGD

---

Continue to **03_analysis.ipynb** for deeper analysis of these results.