# EECS 182 Manifold Muon Experiments - FIXED

**Category 1: Optimizers & Hyperparameter Transfer**

---

In [None]:
# Cell 1: Clone Repository and Setup
import os
import sys

# Clone the repo (run this once)
if not os.path.exists('/content/cs182'):
    !git clone https://github.com/achiii800/cs182.git /content/cs182
else:
    print("Repository already exists")

PROJECT_ROOT = '/content/cs182/muon_mve_project'
os.chdir(PROJECT_ROOT)
sys.path.insert(0, PROJECT_ROOT)

print(f"Working directory: {os.getcwd()}")
print(f"Files: {os.listdir('.')}")

In [None]:
# Cell 2: Install Dependencies
!pip install torch torchvision numpy matplotlib scipy pandas tqdm einops -q

import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Cell 3: Pre-download CIFAR-10 Dataset
# This ensures data is available before training
import torchvision
import torchvision.transforms as T

print("Downloading CIFAR-10 dataset...")
data_dir = os.path.join(PROJECT_ROOT, 'data')
os.makedirs(data_dir, exist_ok=True)

# Fetch the train split first, then the test split, into the same root folder.
# Each split gets its own plain ToTensor transform.
train_set, test_set = (
    torchvision.datasets.CIFAR10(
        root=data_dir,
        train=is_train,
        download=True,
        transform=T.ToTensor(),
    )
    for is_train in (True, False)
)

# Confirm what was fetched and where it lives.
print(
    f"✓ Train set: {len(train_set)} samples",
    f"✓ Test set: {len(test_set)} samples",
    f"✓ Data saved to: {data_dir}",
    sep="\n",
)

In [None]:
# Cell 4: DEBUG - Test imports and find the actual error
import os
import sys
os.chdir(PROJECT_ROOT)

import traceback

# Start from a clean slate: if a step below fails, the following step must not
# silently succeed by picking up a `model`/`opt` left in the kernel by an
# earlier run of this or another cell.
model = None
opt = None

print("Testing imports...")
try:
    from muon import (
        MuonSGD, create_optimizer,
        SpectralClipSolver, DualAscentSolver,
        compute_spectral_norms, estimate_sharpness, estimate_gradient_noise_scale,
        MetricsLogger
    )
    print("✓ muon imports OK")
except Exception as e:
    print(f"✗ muon import error: {e}")
    traceback.print_exc()

try:
    from models import get_model, SmallConvNet
    print("✓ models imports OK")
except Exception as e:
    print(f"✗ models import error: {e}")
    traceback.print_exc()

# Test creating a model
try:
    model = get_model('small_cnn', num_classes=10)
    print(f"✓ Model created: {sum(p.numel() for p in model.parameters()):,} params")
except Exception as e:
    print(f"✗ Model creation error: {e}")
    # Show the full traceback here too, like every other step in this cell.
    traceback.print_exc()

# Test creating optimizer (only meaningful when the model step succeeded)
if model is None:
    print("✗ Optimizer creation skipped: model creation failed above")
else:
    try:
        opt = create_optimizer(model, 'muon_sgd', 'dual_ascent', lr=0.01, spectral_budget=0.1)
        print(f"✓ Optimizer created: {type(opt).__name__}")
    except Exception as e:
        print(f"✗ Optimizer creation error: {e}")
        traceback.print_exc()

In [None]:
# Cell 5: DEBUG - Run train.py with visible error output
import subprocess
import sys
import os

os.chdir(PROJECT_ROOT)

# Assemble the train.py invocation one option at a time.
cmd = [sys.executable, 'train.py']
cmd += ['--model', 'small_cnn']
cmd += ['--optimizer', 'muon_sgd']
cmd += ['--inner-solver', 'dual_ascent']
cmd += ['--spectral-budget', '0.1']
cmd += ['--lr', '0.01']
cmd += ['--epochs', '2']         # Just 2 epochs for testing
cmd += ['--logdir', 'logs/debug']
cmd += ['--exp-name', 'debug_test']
cmd += ['--num-workers', '0']    # IMPORTANT: 0 workers for Colab!

print("Running debug test...")
print(f"Command: {' '.join(cmd)}\n")
print("=" * 60)

# Collect both output streams as text; nothing is shown until the child exits.
result = subprocess.run(
    cmd,
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    text=True,
)

print("STDOUT:")
print(result.stdout)

if result.stderr:
    print("\nSTDERR:")
    print(result.stderr)

print(f"\nExit code: {result.returncode}")

if result.returncode != 0:
    print("\n✗ FAILED - see error above")
else:
    print("\n✓ SUCCESS!")
    # Dump the CSV log if it exists at the expected location.
    log_path = os.path.join(PROJECT_ROOT, 'logs/debug/debug_test.csv')
    if os.path.exists(log_path):
        with open(log_path) as f:
            print(f"\nLog contents:\n{f.read()}")

In [None]:
# Cell 6: INLINE TRAINING (bypass subprocess entirely)
# If Cell 5 still fails, this runs training directly in the notebook

import os
import sys
import time
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T

os.chdir(PROJECT_ROOT)
sys.path.insert(0, PROJECT_ROOT)

from muon import (
    create_optimizer,
    compute_spectral_norms,
    estimate_sharpness,
    estimate_gradient_noise_scale,
)
from models import get_model

# Everything tunable for this inline run lives in one dictionary.
CONFIG = dict(
    model='small_cnn',
    optimizer='muon_sgd',
    inner_solver='dual_ascent',
    spectral_budget=0.1,
    lr=0.01,
    epochs=5,
    batch_size=128,
    seed=42,
)

print(f"Config: {CONFIG}")

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

torch.manual_seed(CONFIG['seed'])

# Data
# Per-channel (mean, std) passed to T.Normalize; defined once so the train and
# test pipelines cannot drift apart.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline adds random crop/flip augmentation before normalisation.
transform_train = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
# Evaluation pipeline is deterministic: tensor conversion + normalisation only.
transform_test = T.Compose([
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

data_dir = os.path.join(PROJECT_ROOT, 'data')
train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform_train)
test_set = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform_test)

# IMPORTANT: num_workers=0 for Colab!
# Pinned host memory only speeds up host-to-GPU copies; asking for it on a
# CPU-only runtime just makes PyTorch emit a warning, so enable it on CUDA only.
pin_memory = device.type == 'cuda'
train_loader = DataLoader(train_set, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=0, pin_memory=pin_memory)
test_loader = DataLoader(test_set, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=0, pin_memory=pin_memory)

print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")

# Build the network named in CONFIG, then move it onto the selected device.
model = get_model(CONFIG['model'], num_classes=10)
model = model.to(device)
print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")

# Optimizer settings are read from CONFIG and passed by keyword.
optimizer_kwargs = {
    'optimizer_type': CONFIG['optimizer'],
    'inner_solver_type': CONFIG['inner_solver'],
    'lr': CONFIG['lr'],
    'spectral_budget': CONFIG['spectral_budget'],
}
optimizer = create_optimizer(model, **optimizer_kwargs)
print(f"Optimizer: {type(optimizer).__name__}")

# Loss callable in the (model, inputs, targets) form used by the metric helpers
def ce_loss_fn(model, x, y):
    """Return the cross-entropy between ``model(x)`` and the targets ``y``."""
    logits = model(x)
    return F.cross_entropy(logits, y)

# One dict of metrics per epoch; turned into a DataFrame after the loop.
results = []

print("\nStarting training...")
print("-" * 80)

for epoch in range(1, CONFIG['epochs'] + 1):
    epoch_start = time.time()

    # ---- training pass ----
    model.train()
    running_loss, running_correct, total = 0.0, 0, 0

    for batch_idx, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch figure is a
        # per-sample average even when the last batch is smaller.
        running_loss += loss.item() * x.size(0)
        running_correct += logits.argmax(dim=1).eq(y).sum().item()
        total += y.size(0)

    train_loss = running_loss / total
    train_acc = running_correct / total

    # ---- evaluation pass (no gradients) ----
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0

    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            # Summed loss here, divided by the sample count after the loop.
            val_loss += F.cross_entropy(logits, y, reduction='sum').item()
            val_correct += logits.argmax(dim=1).eq(y).sum().item()
            val_total += y.size(0)

    val_loss = val_loss / val_total
    val_acc = val_correct / val_total

    # ---- spectral-norm metric: keep the largest reported value ----
    spec_norms = compute_spectral_norms(model, max_layers=8)
    if spec_norms:
        max_spec = max(spec_norms.values())
    else:
        max_spec = 0.0

    epoch_time = time.time() - epoch_start

    results.append(dict(
        epoch=epoch,
        train_loss=train_loss,
        train_acc=train_acc,
        val_loss=val_loss,
        val_acc=val_acc,
        max_spectral_norm=max_spec,
        time=epoch_time,
    ))

    print(f"Epoch {epoch:2d} | Train: {train_loss:.4f}/{train_acc:.4f} | Val: {val_loss:.4f}/{val_acc:.4f} | σ_max: {max_spec:.3f} | {epoch_time:.1f}s")

print("-" * 80)
print(f"Final val accuracy: {val_acc*100:.2f}%")

# Persist the per-epoch metrics as CSV (path is relative to the working dir).
import pandas as pd

os.makedirs('logs/inline_test', exist_ok=True)
df = pd.DataFrame(results)
df.to_csv('logs/inline_test/results.csv', index=False)
print(f"\n✓ Results saved to logs/inline_test/results.csv")

---
## Experiment 1: Inner Solver Comparison (Fixed)

Now that we know the inline training works, let's run the full experiment.

In [None]:
# Cell 7: Experiment 1 - Inner Solver Comparison (FIXED)
import os
import sys
import time
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T

os.chdir(PROJECT_ROOT)
sys.path.insert(0, PROJECT_ROOT)

from muon import (
    create_optimizer,
    compute_spectral_norms,
    estimate_sharpness,
    estimate_gradient_noise_scale,
)
from models import get_model

# Settings shared by every solver run in this experiment.
EPOCHS, BATCH_SIZE, SEED = 20, 128, 42
LR, SPECTRAL_BUDGET = 0.01, 1.0

# Inner solvers to compare; each name is passed to create_optimizer as-is.
solvers = [
    'spectral_clip',
    'dual_ascent',
    'quasi_newton',
    'frank_wolfe',
    'admm',
]

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

# Data (load once)
# Per-channel (mean, std) passed to T.Normalize; defined once so the train and
# test pipelines cannot drift apart.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline adds random crop/flip augmentation before normalisation.
transform_train = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
# Evaluation pipeline is deterministic: tensor conversion + normalisation only.
transform_test = T.Compose([
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

data_dir = os.path.join(PROJECT_ROOT, 'data')
train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform_train)
test_set = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform_test)

# Pinned host memory only speeds up host-to-GPU copies; asking for it on a
# CPU-only runtime just makes PyTorch emit a warning, so enable it on CUDA only.
pin_memory = device.type == 'cuda'
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=pin_memory)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=pin_memory)

# Results storage: one DataFrame per solver in memory, one CSV per solver on disk.
all_results = {}
log_dir = os.path.join(PROJECT_ROOT, 'logs', 'exp1_solver_comparison')
os.makedirs(log_dir, exist_ok=True)

print("="*60)
print("EXPERIMENT 1: Inner Solver Comparison")
print(f"Epochs: {EPOCHS}, LR: {LR}, Budget: {SPECTRAL_BUDGET}")
print("="*60)

def ce_loss_fn(model, x, y):
    """Return the cross-entropy between ``model(x)`` and the targets ``y``."""
    logits = model(x)
    return F.cross_entropy(logits, y)

# Train one fresh model per inner solver under identical seed, data and
# hyperparameters, logging per-epoch metrics to one CSV per solver.
for solver_name in solvers:
    print(f"\n{'='*60}")
    print(f"Solver: {solver_name}")
    print(f"{'='*60}")
    
    # Re-seed before building the model so every solver starts from the same
    # random state.
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
    
    # Fresh model
    model = get_model('small_cnn', num_classes=10).to(device)
    
    # Optimizer with this solver
    optimizer = create_optimizer(
        model,
        optimizer_type='muon_sgd',
        inner_solver_type=solver_name,
        lr=LR,
        spectral_budget=SPECTRAL_BUDGET,
    )
    
    results = []
    
    for epoch in range(1, EPOCHS + 1):
        epoch_start = time.time()
        
        # Train
        model.train()
        running_loss = 0.0
        running_correct = 0
        total = 0
        
        # First two training batches of the epoch, kept for the sharpness and
        # gradient-noise-scale estimates below.
        gns_batches = []
        
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            
            if len(gns_batches) < 2:
                gns_batches.append((x.clone(), y.clone()))
            
            optimizer.zero_grad(set_to_none=True)
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            optimizer.step()
            
            # Weight the batch-mean loss by batch size so the epoch figure is
            # a per-sample average.
            running_loss += loss.item() * x.size(0)
            running_correct += (logits.argmax(1) == y).sum().item()
            total += y.size(0)
        
        train_loss = running_loss / total
        train_acc = running_correct / total
        
        # Eval
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                val_loss += F.cross_entropy(logits, y, reduction='sum').item()
                val_correct += (logits.argmax(1) == y).sum().item()
                val_total += y.size(0)
        
        val_loss /= val_total
        val_acc = val_correct / val_total
        
        # Metrics
        spec_norms = compute_spectral_norms(model, max_layers=8)
        max_spec = max(spec_norms.values()) if spec_norms else 0.0
        
        # Sharpness (every epoch). Best-effort: a failure must not abort the
        # run, but it is reported and stored as NaN so it cannot be mistaken
        # for a real measurement of 0. Catching Exception (not a bare except)
        # keeps Ctrl-C / kernel interrupt working during long runs.
        try:
            sharpness = estimate_sharpness(model, ce_loss_fn, gns_batches[0][0], gns_batches[0][1], epsilon=1e-3)
        except Exception as err:
            print(f"  ! sharpness estimate failed ({type(err).__name__}: {err}); recording NaN")
            sharpness = float('nan')
        
        # GNS (every epoch), same best-effort policy; needs two batches.
        try:
            if len(gns_batches) >= 2:
                gns = estimate_gradient_noise_scale(model, ce_loss_fn, gns_batches[0], gns_batches[1])
            else:
                gns = float('nan')
        except Exception as err:
            print(f"  ! GNS estimate failed ({type(err).__name__}: {err}); recording NaN")
            gns = float('nan')
        
        epoch_time = time.time() - epoch_start
        
        results.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'max_spectral_norm': max_spec,
            'sharpness': sharpness,
            'gns': gns,
        })
        
        print(f"Epoch {epoch:2d} | Train: {train_loss:.4f}/{train_acc:.4f} | Val: {val_loss:.4f}/{val_acc:.4f} | σ: {max_spec:.3f} | {epoch_time:.1f}s")
    
    # Save
    df = pd.DataFrame(results)
    csv_path = os.path.join(log_dir, f'solver_{solver_name}.csv')
    df.to_csv(csv_path, index=False)
    all_results[solver_name] = df
    print(f"✓ Saved to {csv_path}")

print("\n" + "="*60)
print("EXPERIMENT 1 COMPLETE!")
print("="*60)

In [None]:
# Cell 8: Plot Experiment 1 Results
import pandas as pd
import matplotlib.pyplot as plt
import os

log_dir = os.path.join(PROJECT_ROOT, 'logs', 'exp1_solver_comparison')

solvers = ['spectral_clip', 'dual_ascent', 'quasi_newton', 'frank_wolfe', 'admm']
# One fixed hex colour per solver, paired positionally with `solvers`.
colors = dict(zip(solvers, ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']))

# Read whichever per-solver CSVs exist and are non-empty.
data = {}
for solver in solvers:
    path = os.path.join(log_dir, f'solver_{solver}.csv')
    if not os.path.exists(path):
        continue
    df = pd.read_csv(path)
    if len(df) == 0:
        continue
    data[solver] = df
    print(f"✓ {solver}: {len(df)} epochs, final val_acc={df['val_acc'].iloc[-1]*100:.2f}%")

if not data:
    print("No data found!")
else:
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # One curve per solver on each of the four panels.
    for solver, df in data.items():
        c = colors[solver]
        curves = (
            (axes[0, 0], df['train_loss']),
            (axes[0, 1], df['val_acc'] * 100),
            (axes[1, 0], df['max_spectral_norm']),
            (axes[1, 1], df['sharpness']),
        )
        for ax, series in curves:
            ax.plot(df['epoch'], series, label=solver, color=c, lw=2)

    # (y-label, title) for each panel, in row-major order to match axes.flat.
    axis_text = (
        ('Train Loss', 'Training Loss'),
        ('Val Acc (%)', 'Validation Accuracy'),
        ('σ_max', 'Max Spectral Norm'),
        ('Sharpness', 'Sharpness (SAM proxy)'),
    )
    for ax, (ylabel, title) in zip(axes.flat, axis_text):
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    fig_path = os.path.join(log_dir, 'exp1_results.png')
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    print(f"\n✓ Figure saved to {fig_path}")
    plt.show()

    # Final-epoch summary, best validation accuracy first.
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"{'Solver':<15} {'Val Acc':<12} {'Train Loss':<12} {'σ_max':<10}")
    print("-" * 60)
    ranked = sorted(data.items(), key=lambda item: -item[1]['val_acc'].iloc[-1])
    for solver, df in ranked:
        print(f"{solver:<15} {df['val_acc'].iloc[-1]*100:>10.2f}% {df['train_loss'].iloc[-1]:>12.4f} {df['max_spectral_norm'].iloc[-1]:>10.4f}")

---
## Experiment 2: Multi-Seed Baselines

In [None]:
# Cell 9: Experiment 2 - Multi-Seed Baseline Comparison on ResNet-18
import os
import sys
import time
import pandas as pd

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T

os.chdir(PROJECT_ROOT)
from muon import create_optimizer, compute_spectral_norms
from models import get_model

# Run length, batch size and the seeds each configuration is repeated with.
EPOCHS, BATCH_SIZE = 50, 128  # Full training
SEEDS = list(range(3))

# Each row: (name, optimizer, inner_solver, lr, spectral_budget)
configs = [
    ('sgd',       'sgd',      'none',          0.1,   None),
    ('adamw',     'adamw',    'none',          0.001, None),
    ('muon_dual', 'muon_sgd', 'dual_ascent',   0.1,   0.1),
    ('muon_clip', 'muon_sgd', 'spectral_clip', 0.1,   0.1),
]

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Data
# Per-channel (mean, std) passed to T.Normalize; defined once so the train and
# test pipelines cannot drift apart.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline adds random crop/flip augmentation before normalisation.
transform_train = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
# Evaluation pipeline is deterministic: tensor conversion + normalisation only.
transform_test = T.Compose([
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

data_dir = os.path.join(PROJECT_ROOT, 'data')
train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform_train)
test_set = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform_test)

# Pinned host memory only speeds up host-to-GPU copies; asking for it on a
# CPU-only runtime just makes PyTorch emit a warning, so enable it on CUDA only.
pin_memory = device.type == 'cuda'
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=pin_memory)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=pin_memory)

# One CSV per (configuration, seed) run is written into this directory.
log_dir = os.path.join(PROJECT_ROOT, 'logs', 'exp2_baselines')
os.makedirs(log_dir, exist_ok=True)

print("="*70)
print("EXPERIMENT 2: Multi-Seed Baselines on ResNet-18")
print(f"Epochs: {EPOCHS}, Seeds: {SEEDS}")
print("="*70)

# Train every (configuration, seed) pair from scratch and write one CSV of
# per-epoch metrics per run.
for config_name, opt_type, solver_type, lr, budget in configs:
    for seed in SEEDS:
        exp_name = f"{config_name}_seed{seed}"
        print(f"\nRunning: {exp_name}")
        
        # Seed before model construction so initial weights depend only on `seed`.
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        
        model = get_model('resnet18', num_classes=10).to(device)
        optimizer = create_optimizer(
            model, optimizer_type=opt_type, inner_solver_type=solver_type,
            lr=lr, spectral_budget=budget,
        )
        
        # Cosine LR schedule, stepped once per epoch over the whole run
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
        
        results = []
        
        for epoch in range(1, EPOCHS + 1):
            epoch_start = time.time()
            
            # Train
            model.train()
            running_loss, running_correct, total = 0.0, 0, 0
            
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad(set_to_none=True)
                logits = model(x)
                loss = F.cross_entropy(logits, y)
                loss.backward()
                optimizer.step()
                # Weight the batch-mean loss by batch size so the epoch figure
                # is a per-sample average.
                running_loss += loss.item() * x.size(0)
                running_correct += (logits.argmax(1) == y).sum().item()
                total += y.size(0)
            
            scheduler.step()
            train_loss = running_loss / total
            train_acc = running_correct / total
            
            # Eval
            model.eval()
            val_loss, val_correct, val_total = 0.0, 0, 0
            with torch.no_grad():
                for x, y in test_loader:
                    x, y = x.to(device), y.to(device)
                    logits = model(x)
                    val_loss += F.cross_entropy(logits, y, reduction='sum').item()
                    val_correct += (logits.argmax(1) == y).sum().item()
                    val_total += y.size(0)
            
            val_loss /= val_total
            val_acc = val_correct / val_total
            
            spec_norms = compute_spectral_norms(model, max_layers=8)
            max_spec = max(spec_norms.values()) if spec_norms else 0.0
            
            # Wall-clock cost of this epoch (train + eval + metrics). It was
            # timed but never stored; log it under the same 'time' key the
            # inline test cell uses.
            epoch_time = time.time() - epoch_start
            
            results.append({
                'epoch': epoch,
                'train_loss': train_loss, 'train_acc': train_acc,
                'val_loss': val_loss, 'val_acc': val_acc,
                'max_spectral_norm': max_spec,
                'time': epoch_time,
            })
            
            if epoch % 10 == 0:
                print(f"  Epoch {epoch:2d}: val_acc={val_acc*100:.2f}%")
        
        df = pd.DataFrame(results)
        df.to_csv(os.path.join(log_dir, f'{exp_name}.csv'), index=False)
        print(f"  ✓ Final: {val_acc*100:.2f}%")

print("\n" + "="*70)
print("EXPERIMENT 2 COMPLETE!")
print("="*70)

In [None]:
# Cell 10: Plot Experiment 2 with Error Bars
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# `os` is used below but was not imported by this cell; import it here so the
# cell does not depend on another cell's imports (the Experiment 1 plotting
# cell imports it explicitly as well).
import os

log_dir = os.path.join(PROJECT_ROOT, 'logs', 'exp2_baselines')

configs = ['sgd', 'adamw', 'muon_dual', 'muon_clip']
seeds = [0, 1, 2]

colors = {'sgd': '#1f77b4', 'adamw': '#ff7f0e', 'muon_dual': '#2ca02c', 'muon_clip': '#d62728'}
labels = {'sgd': 'SGD', 'adamw': 'AdamW', 'muon_dual': 'MuonSGD (DualAscent)', 'muon_clip': 'MuonSGD (SpectralClip)'}

# config name -> list of per-seed DataFrames (only runs whose CSV exists and
# is non-empty are collected)
aggregated = defaultdict(list)

for cfg in configs:
    for seed in seeds:
        path = os.path.join(log_dir, f'{cfg}_seed{seed}.csv')
        if os.path.exists(path):
            df = pd.read_csv(path)
            if len(df) > 0:
                aggregated[cfg].append(df)

if all(len(v) == 0 for v in aggregated.values()):
    print("No data found. Run Experiment 2 first.")
else:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    for cfg, dfs in aggregated.items():
        if len(dfs) == 0:
            continue
        
        # Seeds may have logged different numbers of epochs; truncate to the
        # shortest run so the per-seed curves stack into one array.
        min_len = min(len(df) for df in dfs)
        epochs = dfs[0]['epoch'].values[:min_len]
        
        losses = np.array([df['train_loss'].values[:min_len] for df in dfs])
        accs = np.array([df['val_acc'].values[:min_len] * 100 for df in dfs])
        
        # Mean and (population, ddof=0) standard deviation across seeds.
        mean_loss, std_loss = losses.mean(0), losses.std(0)
        mean_acc, std_acc = accs.mean(0), accs.std(0)
        
        c = colors[cfg]
        lbl = labels[cfg]
        
        axes[0].plot(epochs, mean_loss, label=lbl, color=c, lw=2)
        axes[0].fill_between(epochs, mean_loss-std_loss, mean_loss+std_loss, color=c, alpha=0.2)
        
        axes[1].plot(epochs, mean_acc, label=lbl, color=c, lw=2)
        axes[1].fill_between(epochs, mean_acc-std_acc, mean_acc+std_acc, color=c, alpha=0.2)
        
        print(f"{lbl}: {mean_acc[-1]:.2f}% ± {std_acc[-1]:.2f}%")
    
    axes[0].set_xlabel('Epoch'); axes[0].set_ylabel('Train Loss')
    axes[0].set_title('Training Loss (mean ± std)')
    axes[0].legend(); axes[0].grid(True, alpha=0.3)
    
    axes[1].set_xlabel('Epoch'); axes[1].set_ylabel('Val Acc (%)')
    axes[1].set_title('Validation Accuracy (mean ± std)')
    axes[1].legend(); axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(os.path.join(log_dir, 'exp2_results.png'), dpi=150, bbox_inches='tight')
    plt.show()

---
## Experiment 3: Width Transfer

In [None]:
# Cell 11: Experiment 3 - Width Transfer on MLP
import os
import sys
import time
import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T

os.chdir(PROJECT_ROOT)
from muon import create_optimizer, compute_spectral_norms
from models import get_model

# Sweep settings: one fixed learning rate is reused at every width.
EPOCHS, BATCH_SIZE, LR = 30, 128, 0.01
WIDTHS = [0.5, 0.75, 1.0, 1.5, 2.0]
SEEDS = [0, 1, 2]

# Each row: (name, optimizer, inner_solver, spectral_budget)
configs = [
    ('sgd', 'sgd', 'none', None),
    ('muon_dual', 'muon_sgd', 'dual_ascent', 0.1),
]

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Same un-augmented transform for both splits.
transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
data_dir = os.path.join(PROJECT_ROOT, 'data')
train_set, test_set = (
    torchvision.datasets.CIFAR10(root=data_dir, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)
# Only the training loader shuffles.
train_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=do_shuffle, num_workers=0)
    for split, do_shuffle in ((train_set, True), (test_set, False))
)

log_dir = os.path.join(PROJECT_ROOT, 'logs', 'exp3_width_transfer')
os.makedirs(log_dir, exist_ok=True)

banner = "=" * 70
print(banner)
print("EXPERIMENT 3: Width Transfer on MLP")
print(f"Widths: {WIDTHS}, Seeds: {SEEDS}, Epochs: {EPOCHS}")
print(banner)

# Sweep optimizer configuration x width multiplier x seed; each run trains a
# fresh MLP and writes its per-epoch validation accuracy to its own CSV.
for cfg_name, opt_type, solver_type, budget in configs:
    for width in WIDTHS:
        for seed in SEEDS:
            exp_name = f"{cfg_name}_w{width}_s{seed}"
            print(f"Running: {exp_name}")

            torch.manual_seed(seed)

            model = get_model('mlp', num_classes=10, width_mult=width)
            model = model.to(device)
            optimizer = create_optimizer(
                model,
                optimizer_type=opt_type,
                inner_solver_type=solver_type,
                lr=LR,
                spectral_budget=budget,
            )

            results = []

            for epoch in range(1, EPOCHS + 1):
                # ---- training pass ----
                model.train()
                for x, y in train_loader:
                    x = x.to(device)
                    y = y.to(device)
                    optimizer.zero_grad()
                    logits = model(x)
                    loss = F.cross_entropy(logits, y)
                    loss.backward()
                    optimizer.step()

                # ---- validation accuracy ----
                model.eval()
                correct, total = 0, 0
                with torch.no_grad():
                    for x, y in test_loader:
                        x = x.to(device)
                        y = y.to(device)
                        preds = model(x).argmax(dim=1)
                        correct += preds.eq(y).sum().item()
                        total += y.size(0)
                val_acc = correct / total
                results.append(dict(epoch=epoch, val_acc=val_acc))

            df = pd.DataFrame(results)
            df.to_csv(os.path.join(log_dir, f'{exp_name}.csv'), index=False)
            print(f"  Final: {val_acc*100:.2f}%")

print("\n" + "=" * 70)
print("EXPERIMENT 3 COMPLETE!")
print("=" * 70)

In [None]:
# Cell 12: Plot Width Transfer Results
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# `os` is used below but was not imported by this cell; import it here so the
# cell does not depend on another cell's imports.
import os

log_dir = os.path.join(PROJECT_ROOT, 'logs', 'exp3_width_transfer')

configs = ['sgd', 'muon_dual']
widths = [0.5, 0.75, 1.0, 1.5, 2.0]
seeds = [0, 1, 2]

colors = {'sgd': '#1f77b4', 'muon_dual': '#2ca02c'}
labels = {'sgd': 'SGD', 'muon_dual': 'MuonSGD (DualAscent)'}

# Per config: the widths that have data, plus mean/std (population, ddof=0)
# of the final-epoch validation accuracy across the available seeds.
results = {cfg: {'widths': [], 'mean': [], 'std': []} for cfg in configs}

for cfg in configs:
    for w in widths:
        accs = []
        for s in seeds:
            # File name must match the `{cfg}_w{width}_s{seed}` pattern the
            # Experiment 3 training cell writes.
            path = os.path.join(log_dir, f'{cfg}_w{w}_s{s}.csv')
            if os.path.exists(path):
                df = pd.read_csv(path)
                if len(df) > 0:
                    accs.append(df['val_acc'].iloc[-1] * 100)
        
        if accs:
            results[cfg]['widths'].append(w)
            results[cfg]['mean'].append(np.mean(accs))
            results[cfg]['std'].append(np.std(accs))

# Without any logs there is nothing to draw: say so (as the other plotting
# cells do) instead of saving an empty figure and triggering a legend warning.
if not any(results[cfg]['widths'] for cfg in configs):
    print("No data found. Run Experiment 3 first.")
else:
    fig, ax = plt.subplots(figsize=(10, 6))
    
    for cfg in configs:
        r = results[cfg]
        if r['widths']:
            ax.errorbar(r['widths'], r['mean'], yerr=r['std'],
                       label=labels[cfg], color=colors[cfg],
                       marker='o', markersize=8, lw=2, capsize=5)
    
    # Dashed reference line at width multiplier 1.0.
    ax.axvline(x=1.0, color='gray', linestyle='--', alpha=0.5)
    ax.set_xlabel('Width Multiplier', fontsize=12)
    ax.set_ylabel('Final Val Accuracy (%)', fontsize=12)
    ax.set_title('Width Transfer: MuonSGD vs SGD on MLP', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(os.path.join(log_dir, 'exp3_width_transfer.png'), dpi=150, bbox_inches='tight')
    plt.show()
    
    # Summary
    print("\nWidth Transfer Summary:")
    for cfg in configs:
        r = results[cfg]
        if r['widths']:
            print(f"\n{labels[cfg]}:")
            for w, m, s in zip(r['widths'], r['mean'], r['std']):
                print(f"  Width {w}: {m:.2f}% ± {s:.2f}%")