# EECS 182 Final: Manifold Muon - Quick Experiments (~90 min total)

**Run all cells in order. Total time: ~90 minutes on T4 GPU.**

Experiments:
1. Inner Solver Comparison (5 solvers, 20 epochs) - ~25 min
2. Multi-seed runs for error bars (2 best solvers, 3 seeds) - ~20 min  
3. LR Stability Envelope (shows the widened range of stable learning rates) - ~15 min
4. Width Transfer (Category 1 key deliverable) - ~25 min

In [None]:
# Cell 1: Setup - Clone repo and install dependencies (~2 min)
# Quietly (-q) install the third-party packages the project depends on.
!pip install -q torch torchvision matplotlib pandas numpy

# Clone from GitHub (your private repo)
# Option 1: If you're logged into Colab with GitHub access
# NOTE: git refuses to clone into an existing non-empty directory, so
# re-running this cell after a successful clone prints an error here; the
# %cd below still works against the existing checkout.
!git clone https://github.com/aksrao-berkeley/cs182.git

# Option 2: With PAT (uncomment if needed)
# NOTE(review): a token pasted into this URL ends up in the saved notebook
# and in the clone's remote URL -- do not commit or share it.
# !git clone https://<PAT>@github.com/aksrao-berkeley/cs182.git

# All later cells use paths relative to this directory (train.py, logs/,
# plots/), so the working directory must be switched before running them.
%cd cs182/muon_mve_project
!ls

In [None]:
# Cell 2: Verify GPU
# Report whether PyTorch can see a CUDA device and, if it can, print the
# name and total memory of device 0.
import torch

has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    device_index = 0
    print(f"GPU: {torch.cuda.get_device_name(device_index)}")
    device_props = torch.cuda.get_device_properties(device_index)
    print(f"Memory: {device_props.total_memory / 1e9:.1f} GB")

In [None]:
# Cell 3: Create output directories
import os
import subprocess
import time

# One log directory per experiment plus a shared directory for figures.
# exist_ok=True makes the cell safe to re-run.
for out_dir in (
    'logs/exp1_solver_comparison',
    'logs/exp2_multiseed',
    'logs/exp3_lr_sweep',
    'logs/exp4_width_transfer',
    'plots',
):
    os.makedirs(out_dir, exist_ok=True)
print("Directories created!")

---
## Experiment 1: Inner Solver Comparison (~25 min)
Compare all 5 inner solvers on SmallCNN, 20 epochs each

In [None]:
# Cell 4: Run Experiment 1 - Inner Solver Comparison
# Launches train.py once per inner solver (same seed, LR and budget for all)
# and records the total wall-clock time in `exp1_time`.
# Cell 5 expects each run to leave
# logs/exp1_solver_comparison/solver_<solver>.csv behind.
solvers = ['spectral_clip', 'dual_ascent', 'quasi_newton', 'frank_wolfe', 'admm']
exp1_start = time.time()
exp1_failed = []  # solvers whose training process exited with a non-zero status

for solver in solvers:
    print(f"\n{'='*60}")
    print(f"Running: {solver}")
    print(f"{'='*60}")

    cmd = [
        'python', 'train.py',
        '--model', 'small_cnn',
        '--optimizer', 'muon_sgd',
        '--inner-solver', solver,
        '--spectral-budget', '0.1',
        '--lr', '0.05',
        '--epochs', '20',
        '--batch-size', '128',
        '--logdir', 'logs/exp1_solver_comparison',
        '--exp-name', f'solver_{solver}',
        '--seed', '42',
        '--log-interval', '200'
    ]
    # Keep the sweep going when one solver fails, but surface the failure:
    # the plotting cells silently skip missing CSVs, so an unreported crash
    # would otherwise just vanish from the figures.
    result = subprocess.run(cmd)
    if result.returncode != 0:
        exp1_failed.append(solver)
        print(f"WARNING: {solver} exited with code {result.returncode}")

exp1_time = time.time() - exp1_start
print(f"\nExperiment 1 complete! Time: {exp1_time/60:.1f} min")
if exp1_failed:
    print(f"WARNING: failed runs (logs may be missing or partial): {', '.join(exp1_failed)}")

In [None]:
# Cell 5: Generate Experiment 1 plots
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Load the per-solver CSV logs; solvers without a log file are left out.
log_paths = {
    name: f'logs/exp1_solver_comparison/solver_{name}.csv' for name in solvers
}
solver_dfs = {
    name: pd.read_csv(csv_path)
    for name, csv_path in log_paths.items()
    if os.path.exists(csv_path)
}

# One entry per panel of the 2x2 grid, in row-major order:
# (CSV column, scale factor or None, y-axis label, panel title).
panels = [
    ('train_loss', None, 'Training Loss', 'Training Loss by Inner Solver'),
    ('val_acc', 100, 'Validation Accuracy (%)', 'Validation Accuracy by Inner Solver'),
    ('max_spectral_norm', None, 'Max Spectral Norm', 'Spectral Norm Trajectory'),
    ('sharpness', None, 'Sharpness (SAM proxy)', 'Sharpness Trajectory'),
]

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
for ax, (column, scale, y_label, panel_title) in zip(axes.flat, panels):
    for name, log in solver_dfs.items():
        # val_acc is scaled to percent; every other column is plotted as logged.
        values = log[column] if scale is None else log[column] * scale
        ax.plot(log['epoch'], values, label=name, linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
for figure_path, save_kwargs in (
    ('plots/exp1_solver_comparison.png', {'dpi': 150}),
    ('plots/exp1_solver_comparison.pdf', {}),
):
    plt.savefig(figure_path, bbox_inches='tight', **save_kwargs)
plt.show()

# Summary table: metrics from the last logged row of each solver's CSV.
heavy_rule = "=" * 60
print("\n" + heavy_rule)
print("EXPERIMENT 1 SUMMARY: Inner Solver Comparison")
print(heavy_rule)
print(f"{'Solver':<15} {'Final Val Acc':<15} {'Final Loss':<15} {'Final σ_max':<15}")
print("-" * 60)
for name, log in solver_dfs.items():
    last_row = log.iloc[-1]
    print(f"{name:<15} {last_row['val_acc']*100:.1f}%{'':<8} {last_row['val_loss']:.4f}{'':<8} {last_row['max_spectral_norm']:.4f}")

---
## Experiment 2: Multi-seed Runs for Error Bars (~20 min)
Run 3 seeds each (15 epochs per run) for dual_ascent and spectral_clip (the best performers)

In [None]:
# Cell 6: Run Experiment 2 - Multi-seed
# Repeats training for each of the two chosen solvers under several seeds so
# Cell 7 can draw mean +/- std bands. Total wall-clock time goes to `exp2_time`.
exp2_start = time.time()
seeds = [0, 1, 2]
best_solvers = ['dual_ascent', 'spectral_clip']
exp2_failed = []  # (solver, seed) pairs whose process exited non-zero

for solver in best_solvers:
    for seed in seeds:
        print(f"\nRunning {solver} seed={seed}...")
        cmd = [
            'python', 'train.py',
            '--model', 'small_cnn',
            '--optimizer', 'muon_sgd',
            '--inner-solver', solver,
            '--spectral-budget', '0.1',
            '--lr', '0.05',
            '--epochs', '15',
            '--batch-size', '128',
            '--logdir', 'logs/exp2_multiseed',
            '--exp-name', f'{solver}_seed{seed}',
            '--seed', str(seed),
            '--log-interval', '300'
        ]
        # Continue with the remaining seeds on failure, but report it: a
        # missing seed silently shrinks the sample behind the error bars.
        result = subprocess.run(cmd)
        if result.returncode != 0:
            exp2_failed.append((solver, seed))
            print(f"WARNING: {solver} seed={seed} exited with code {result.returncode}")

exp2_time = time.time() - exp2_start
print(f"\nExperiment 2 complete! Time: {exp2_time/60:.1f} min")
if exp2_failed:
    failed_desc = ', '.join(f'{s} seed={sd}' for s, sd in exp2_failed)
    print(f"WARNING: failed runs (logs may be missing or partial): {failed_desc}")

In [None]:
# Cell 7: Generate Experiment 2 plots (with error bars)
# For each solver, stacks the per-seed logs and plots the mean curve with a
# +/- one standard deviation band for validation accuracy and training loss.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
colors = {'dual_ascent': 'blue', 'spectral_clip': 'orange'}

for solver in best_solvers:
    dfs = []
    for seed in seeds:
        path = f'logs/exp2_multiseed/{solver}_seed{seed}.csv'
        if os.path.exists(path):
            dfs.append(pd.read_csv(path))

    if not dfs:
        print(f"WARNING: no logs found for {solver}; nothing plotted")
        continue
    if len(dfs) < len(seeds):
        print(f"WARNING: {solver}: only {len(dfs)} of {len(seeds)} seed logs found")

    # A run that stopped early has fewer rows than the others. Truncate all
    # logs to the shortest one so they stack into a rectangular array
    # instead of a ragged one.
    n_rows = min(len(df) for df in dfs)
    epochs = dfs[0]['epoch'].values[:n_rows]
    val_accs = np.array([df['val_acc'].values[:n_rows] for df in dfs])
    train_losses = np.array([df['train_loss'].values[:n_rows] for df in dfs])

    # Statistics across seeds (axis 0). np.std defaults to the population
    # standard deviation (ddof=0).
    mean_acc = val_accs.mean(axis=0) * 100
    std_acc = val_accs.std(axis=0) * 100
    mean_loss = train_losses.mean(axis=0)
    std_loss = train_losses.std(axis=0)

    # Validation Accuracy with error bars
    axes[0].plot(epochs, mean_acc, label=solver, color=colors[solver], linewidth=2)
    axes[0].fill_between(epochs, mean_acc - std_acc, mean_acc + std_acc,
                         color=colors[solver], alpha=0.2)

    # Training Loss with error bars
    axes[1].plot(epochs, mean_loss, label=solver, color=colors[solver], linewidth=2)
    axes[1].fill_between(epochs, mean_loss - std_loss, mean_loss + std_loss,
                         color=colors[solver], alpha=0.2)

# Derive the seed count from `seeds` so the title stays correct if the list
# in Cell 6 is edited.
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Validation Accuracy (%)')
axes[0].set_title(f'Val Accuracy with Error Bars (n={len(seeds)} seeds)')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Training Loss')
axes[1].set_title(f'Training Loss with Error Bars (n={len(seeds)} seeds)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('plots/exp2_error_bars.png', dpi=150, bbox_inches='tight')
plt.savefig('plots/exp2_error_bars.pdf', bbox_inches='tight')
plt.show()

---
## Experiment 3: LR Stability Envelope (~15 min)
Show that spectral constraints widen the stable learning rate region

In [None]:
# Cell 8: Run Experiment 3 - LR Sweep
# Short (5-epoch, no LR schedule) runs over a grid of learning rates for the
# spectrally constrained optimizer and for vanilla SGD. Cell 9 classifies
# each run as converged or diverged from its logged validation loss.
exp3_start = time.time()
lr_values = [0.001, 0.005, 0.01, 0.05, 0.1, 0.2, 0.5, 1.0]
configs = [
    ('muon_sgd', 'spectral_clip', 'MuonSGD+SpectralClip'),
    ('sgd', 'none', 'Vanilla SGD'),
]

for opt, solver, name in configs:
    for lr in lr_values:
        print(f"Running {name} lr={lr}...")
        cmd = [
            'python', 'train.py',
            '--model', 'small_cnn',
            '--optimizer', opt,
            '--inner-solver', solver,
            '--spectral-budget', '0.1',
            '--lr', str(lr),
            '--epochs', '5',  # Short runs to detect divergence
            '--batch-size', '128',
            '--logdir', 'logs/exp3_lr_sweep',
            '--exp-name', f'{opt}_{solver}_lr{lr}',
            '--seed', '42',
            '--log-interval', '500',
            '--lr-schedule', 'none'  # No schedule for stability scan
        ]
        try:
            result = subprocess.run(cmd, timeout=120)  # 2 min timeout per run
        except subprocess.TimeoutExpired:
            # NOTE(review): a run killed here may leave a partial CSV, and
            # Cell 9 would read its last row as the final loss -- confirm
            # how train.py writes its log.
            print(f"  Timeout (likely diverged) for lr={lr}")
        else:
            # A crash is not a timeout; report it so it is not mistaken for
            # a clean run when the results are read back.
            if result.returncode != 0:
                print(f"  WARNING: exited with code {result.returncode} for lr={lr}")

exp3_time = time.time() - exp3_start
print(f"\nExperiment 3 complete! Time: {exp3_time/60:.1f} min")

In [None]:
# Cell 9: Generate Experiment 3 plots - LR Envelope
# Reads the last logged validation loss of every LR-sweep run, classifies
# the run as converged or diverged, and scatter-plots the result per optimizer.
fig, ax = plt.subplots(figsize=(8, 5))

results = {}
for opt, solver, name in configs:
    final_losses = []
    converged = []
    for lr in lr_values:
        path = f'logs/exp3_lr_sweep/{opt}_{solver}_lr{lr}.csv'
        # Default for a missing, empty or header-only log: treat as diverged.
        final_loss = float('inf')
        if os.path.exists(path):
            try:
                df = pd.read_csv(path)
            except pd.errors.EmptyDataError:
                # Zero-byte file, e.g. the run was killed before writing.
                df = None
            if df is not None and len(df) > 0:
                final_loss = df['val_loss'].iloc[-1]
        # Consider diverged if loss >= 10 or NaN (inf fails the < 10 test).
        is_converged = final_loss < 10 and not np.isnan(final_loss)
        final_losses.append(final_loss if is_converged else np.nan)
        converged.append(is_converged)
    results[name] = {'losses': final_losses, 'converged': converged}

markers = {'MuonSGD+SpectralClip': 'o', 'Vanilla SGD': 's'}
colors = {'MuonSGD+SpectralClip': 'blue', 'Vanilla SGD': 'red'}

for name, data in results.items():
    # Plot converged points
    conv_lrs = [lr for lr, ok in zip(lr_values, data['converged']) if ok]
    conv_losses = [loss for loss, ok in zip(data['losses'], data['converged']) if ok]
    ax.scatter(conv_lrs, conv_losses, marker=markers[name], s=80,
               label=f"{name} (converged)", color=colors[name], zorder=3)

    # Plot diverged points as 'x' markers at a fixed height (y=0.5)
    div_lrs = [lr for lr, ok in zip(lr_values, data['converged']) if not ok]
    if div_lrs:
        ax.scatter(div_lrs, [0.5] * len(div_lrs), marker='x', s=80,
                   color=colors[name], alpha=0.5, label=f"{name} (diverged)")

ax.set_xscale('log')
ax.set_xlabel('Learning Rate')
ax.set_ylabel('Final Validation Loss (5 epochs)')
ax.set_title('LR Stability Envelope: Spectral Constraints Widen Stable Region')
ax.legend(loc='upper left')
ax.grid(True, alpha=0.3)
# NOTE: runs counted as converged with a loss between 5 and 10 fall outside
# this visible range.
ax.set_ylim([0, 5])

plt.tight_layout()
plt.savefig('plots/exp3_lr_envelope.png', dpi=150, bbox_inches='tight')
plt.savefig('plots/exp3_lr_envelope.pdf', bbox_inches='tight')
plt.show()

# Print max stable LR for each (0 when no run converged)
print("\nMax Stable Learning Rates:")
for name, data in results.items():
    max_stable = max((lr for lr, ok in zip(lr_values, data['converged']) if ok), default=0)
    print(f"  {name}: {max_stable}")

---
## Experiment 4: Width Transfer (Category 1 key deliverable) (~25 min)
Test if hyperparameters (LR, spectral budget) transfer across network widths

In [None]:
# Cell 10: Run Experiment 4 - Width Transfer
# Trains both optimizers at several width multipliers with one fixed LR and
# spectral budget, to see whether the hyperparameters carry over across
# widths. Total wall-clock time goes to `exp4_time`.
exp4_start = time.time()
widths = [0.5, 1.0, 2.0]
# NOTE: this rebinds `configs` to 2-tuples; Experiment 3 (Cells 8-9) used
# the same name with 3-tuples, so re-running Cell 9 after this cell fails
# to unpack. Re-run Cell 8 first in that case.
configs = [
    ('muon_sgd', 'spectral_clip'),
    ('sgd', 'none'),
]

# Fixed hyperparameters tuned at width=1.0
fixed_lr = 0.05
exp4_failed = []  # (optimizer, solver, width) triples whose process exited non-zero

for opt, solver in configs:
    for width in widths:
        print(f"\nRunning {opt}/{solver} width={width}...")
        cmd = [
            'python', 'train.py',
            '--model', 'small_cnn',
            '--optimizer', opt,
            '--inner-solver', solver,
            '--spectral-budget', '0.1',
            '--lr', str(fixed_lr),
            '--width-mult', str(width),
            '--epochs', '15',
            '--batch-size', '128',
            '--logdir', 'logs/exp4_width_transfer',
            '--exp-name', f'{opt}_{solver}_width{width}',
            '--seed', '42',
            '--log-interval', '300'
        ]
        # Keep sweeping on failure, but report it: Cell 11 silently skips
        # widths whose CSV is missing.
        result = subprocess.run(cmd)
        if result.returncode != 0:
            exp4_failed.append((opt, solver, width))
            print(f"WARNING: {opt}/{solver} width={width} exited with code {result.returncode}")

exp4_time = time.time() - exp4_start
print(f"\nExperiment 4 complete! Time: {exp4_time/60:.1f} min")
if exp4_failed:
    failed_desc = ', '.join(f'{o}/{s} width={w}' for o, s, w in exp4_failed)
    print(f"WARNING: failed runs (logs may be missing or partial): {failed_desc}")

In [None]:
# Cell 11: Generate Experiment 4 plots - Width Transfer
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Collect the last-row validation accuracy (in percent) and training loss
# for every (optimizer, width) run that produced a log file.
width_results = {}
for opt, solver in configs:
    entry = {'widths': [], 'val_accs': [], 'train_losses': []}
    width_results[f"{opt}/{solver}"] = entry
    for width in widths:
        csv_path = f'logs/exp4_width_transfer/{opt}_{solver}_width{width}.csv'
        if not os.path.exists(csv_path):
            continue
        log = pd.read_csv(csv_path)
        entry['widths'].append(width)
        entry['val_accs'].append(log['val_acc'].iloc[-1] * 100)
        entry['train_losses'].append(log['train_loss'].iloc[-1])

# Plots 1 and 2: final metric vs width, one line per optimizer.
# (axis, key into width_results, y-axis label, panel title)
summary_panels = [
    (axes[0], 'val_accs', 'Final Validation Accuracy (%)',
     'Hyperparameter Transfer: Accuracy vs Width'),
    (axes[1], 'train_losses', 'Final Training Loss',
     'Hyperparameter Transfer: Loss vs Width'),
]
for panel_ax, metric_key, y_label, panel_title in summary_panels:
    for name, data in width_results.items():
        # Muon runs: solid blue circles; everything else: dashed red squares.
        is_muon = 'muon' in name
        panel_ax.plot(data['widths'], data[metric_key],
                      'o-' if is_muon else 's--', label=name,
                      markersize=10, linewidth=2,
                      color='blue' if is_muon else 'red')
    panel_ax.set_xlabel('Width Multiplier')
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(panel_title)
    panel_ax.legend()
    panel_ax.grid(True, alpha=0.3)

# Plot 3: Overlay training curves for different widths (MuonSGD only)
curve_ax = axes[2]
for width in widths:
    csv_path = f'logs/exp4_width_transfer/muon_sgd_spectral_clip_width{width}.csv'
    if os.path.exists(csv_path):
        log = pd.read_csv(csv_path)
        curve_ax.plot(log['epoch'], log['val_acc']*100, label=f'width={width}', linewidth=2)

curve_ax.set_xlabel('Epoch')
curve_ax.set_ylabel('Validation Accuracy (%)')
curve_ax.set_title('MuonSGD Training Curves Across Widths\n(Same LR, Same Spectral Budget)')
curve_ax.legend()
curve_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('plots/exp4_width_transfer.png', dpi=150, bbox_inches='tight')
plt.savefig('plots/exp4_width_transfer.pdf', bbox_inches='tight')
plt.show()

# Print transfer quality: spread = max - min final accuracy across widths.
banner = "=" * 60
print("\n" + banner)
print("WIDTH TRANSFER ANALYSIS")
print(banner)
print("(Good transfer = similar accuracy across widths with fixed hyperparams)\n")
for name, data in width_results.items():
    accs = data['val_accs']
    if not accs:
        continue
    spread = max(accs) - min(accs)
    print(f"{name}:")
    for w, acc in zip(data['widths'], accs):
        print(f"  width={w}: {acc:.1f}%")
    print(f"  Spread: {spread:.1f}% (lower = better transfer)\n")

---
## Summary & Final Outputs

In [None]:
# Cell 12: Generate summary table for paper
# Prints one row per inner solver from the Experiment 1 logs, using the last
# logged row of each CSV. Solvers without a log file are skipped.
print("\n" + "="*70)
print("FINAL SUMMARY TABLE FOR PAPER")
print("="*70)

# Experiment 1 results
print("\nTable 1: Inner Solver Comparison (SmallCNN, CIFAR-10, 20 epochs)")
print("-"*70)
print(f"{'Solver':<18} {'Val Acc (%)':<12} {'Train Loss':<12} {'σ_max':<10} {'Time (s)':<10}")
print("-"*70)

for solver in solvers:
    path = f'logs/exp1_solver_comparison/solver_{solver}.csv'
    if os.path.exists(path):
        df = pd.read_csv(path)
        final = df.iloc[-1]
        # assumes the log has a per-row 'time' column holding seconds, so the
        # sum is the run's total time -- TODO confirm against train.py
        total_time = df['time'].sum()
        print(f"{solver:<18} {final['val_acc']*100:<12.1f} {final['train_loss']:<12.4f} {final['max_spectral_norm']:<10.4f} {total_time:<10.0f}")

# Report where the plots were saved and list them (nothing is copied here)
print("\n" + "="*70)
print("Generated plots saved to: ./plots/")
print("="*70)
!ls -la plots/

In [None]:
# Cell 13: Download plots (for Colab)
# Zips the plots/ and logs/ directories and, when running inside Colab,
# triggers a browser download of both archives.
import shutil

# Build both archives first so they exist even when the Colab download
# helper is unavailable (e.g. when running under local Jupyter).
plots_zip = shutil.make_archive('experiment_plots', 'zip', 'plots')
logs_zip = shutil.make_archive('experiment_logs', 'zip', 'logs')

try:
    from google.colab import files
except ImportError:
    # Not on Colab: do not abort a "Run all"; just say where the files are.
    print(f"\nNot running in Colab; archives written to {plots_zip} and {logs_zip}")
else:
    files.download('experiment_plots.zip')
    files.download('experiment_logs.zip')
    print("\nDownload complete! Use these files in your LaTeX document.")

In [None]:
# Cell 14: Total time summary
# Adds up the wall-clock times (in seconds) recorded by the four experiment
# run cells and prints them in minutes.
total_time = exp1_time + exp2_time + exp3_time + exp4_time

heavy_rule = '=' * 50
timing_lines = [
    f"Experiment 1 (Solver Comparison): {exp1_time/60:.1f} min",
    f"Experiment 2 (Multi-seed):        {exp2_time/60:.1f} min",
    f"Experiment 3 (LR Envelope):       {exp3_time/60:.1f} min",
    f"Experiment 4 (Width Transfer):    {exp4_time/60:.1f} min",
]

print(f"\n{heavy_rule}")
print("TOTAL EXPERIMENT TIME")
print(heavy_rule)
print("\n".join(timing_lines))
print("-" * 50)
print(f"TOTAL:                            {total_time/60:.1f} min")