# EECS 182 Manifold Muon Experiments

**Category 1: Optimizers & Hyperparameter Transfer**

This notebook runs the three core experiments for the final project:
1. **Experiment 1**: Inner Solver Comparison (5 solvers on SmallCNN)
2. **Experiment 2**: Multi-seed baseline comparisons with error bars (ResNet-18)
3. **Experiment 3**: Width transfer experiments (μP-style, on an MLP)

---

## Cell 1: Environment Setup

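When running in Google Colab, choose one of the following ways to get the code (by default, the setup cell below mounts Google Drive and reads the project from there):
- **Option A**: Clone from GitHub (if you've pushed your code)
- **Option B**: Upload the `muon_mve_project` folder manually (as `muon_mve_project.zip`)

When running locally, the cell infers the project root from the current working directory.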

In [None]:
# Cell 1: Environment Setup
import os
import subprocess
import sys

# Detect environment
IN_COLAB = 'google.colab' in sys.modules


def find_project_root(start, marker='train.py', max_levels=3):
    """Return the nearest directory at or above *start* that contains *marker*.

    *start* itself is checked first, then up to *max_levels* parent directories.
    Returns None when none of the checked directories contains the marker file.
    """
    candidate = os.path.abspath(start)
    for _ in range(max_levels + 1):
        if os.path.exists(os.path.join(candidate, marker)):
            return candidate
        parent = os.path.dirname(candidate)
        if parent == candidate:  # reached the filesystem root
            break
        candidate = parent
    return None


if IN_COLAB:
    print("Running in Google Colab")

    # How to obtain the code in Colab. Exactly one branch below sets
    # PROJECT_ROOT, so picking Option A/B is no longer silently overridden by
    # the Drive path further down (as happened when those lines were merely
    # uncommented).
    #   'drive'  - project already lives in Google Drive (default)
    #   'github' - OPTION A: clone from GitHub (if you've pushed your code)
    #   'upload' - OPTION B: upload muon_mve_project.zip to /content first
    COLAB_SOURCE = 'drive'

    if COLAB_SOURCE == 'github':
        if not os.path.isdir('/content/cs182'):
            subprocess.run(
                ['git', 'clone', 'https://github.com/achiii800/cs182.git', '/content/cs182'],
                check=True,
            )
        PROJECT_ROOT = '/content/cs182/muon_mve_project'
    elif COLAB_SOURCE == 'upload':
        import zipfile
        with zipfile.ZipFile('/content/muon_mve_project.zip') as archive:
            archive.extractall('/content/')
        PROJECT_ROOT = '/content/muon_mve_project'
    else:
        from google.colab import drive
        drive.mount('/content/drive')

        # Set this to your actual path in Drive
        PROJECT_ROOT = '/content/drive/MyDrive/cs182/muon_mve_project'

else:
    print("Running locally")
    # Notebooks have no __file__, so search upward from the working directory.
    # The cwd itself is checked first, which covers a notebook saved directly in
    # the project root (the previous heuristic skipped cwd entirely).
    PROJECT_ROOT = find_project_root(os.getcwd())
    if PROJECT_ROOT is None:
        # Fall back to the parent of cwd (e.g. notebook kept in notebooks/).
        PROJECT_ROOT = os.path.dirname(os.getcwd())
        print(f"  WARNING: train.py not found near {os.getcwd()}; guessing {PROJECT_ROOT}")

print(f"Project root: {PROJECT_ROOT}")
if not os.path.isdir(PROJECT_ROOT):
    raise FileNotFoundError(
        f"PROJECT_ROOT does not exist: {PROJECT_ROOT}. Fix the path in this cell."
    )
os.chdir(PROJECT_ROOT)
# Avoid stacking duplicate sys.path entries when the cell is re-run.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Verify structure
expected_files = ['train.py', 'requirements.txt', 'muon/__init__.py', 'models/__init__.py']
for f in expected_files:
    path = os.path.join(PROJECT_ROOT, f)
    status = '✓' if os.path.exists(path) else '✗ MISSING'
    print(f"  {status} {f}")

In [None]:
# Cell 2: Install Dependencies
import subprocess
import sys

# Install with the kernel's own interpreter (`python -m pip`) so the packages
# land in the environment this notebook actually imports from. A shell `!pip`
# can resolve to a different pip on local Jupyter setups, and is not valid
# plain Python outside IPython.
packages = ['torch', 'torchvision', 'numpy', 'matplotlib', 'scipy', 'pandas', 'tqdm', 'einops']
pip_result = subprocess.run([sys.executable, '-m', 'pip', 'install', *packages, '-q'])
if pip_result.returncode != 0:
    print(f"WARNING: pip install exited with code {pip_result.returncode}")

# Verify GPU
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Cell 3: Verify Module Imports
import os
os.chdir(PROJECT_ROOT)

# Pull in every project entry point the experiments rely on. A failure here
# usually means PROJECT_ROOT (Cell 1) points at the wrong directory.
try:
    from muon import (
        MuonSGD,
        create_optimizer,
        SpectralClipSolver,
        DualAscentSolver,
        FrankWolfeSolver,
        QuasiNewtonDualSolver,
        ADMMSolver,
        compute_spectral_norms,
        estimate_sharpness,
        estimate_gradient_noise_scale,
    )
    from models import get_model, SmallConvNet, ResNet18CIFAR, WidthScalableMLP
except Exception as err:
    print(f"✗ Import error: {err}")
    print("  Make sure PROJECT_ROOT is set correctly in Cell 1")
else:
    print("✓ All modules imported successfully!")
    print(f"  Available models: small_cnn, resnet18, tiny_vit, mlp_mixer, mlp")
    print(f"  Available solvers: spectral_clip, dual_ascent, frank_wolfe, quasi_newton, admm")

---
## Experiment 1: Inner Solver Comparison

Compare all 5 inner solvers on SmallCNN for quick iteration (20 epochs).

**Solvers**: SpectralClip, DualAscent, QuasiNewton, FrankWolfe, ADMM

In [None]:
# Cell 4: Experiment 1 - Inner Solver Comparison
import csv
import os
import subprocess
import sys

os.chdir(PROJECT_ROOT)

# Create log directory
log_dir = os.path.join(PROJECT_ROOT, 'logs', 'exp1_solver_comparison')
os.makedirs(log_dir, exist_ok=True)

solvers = ['spectral_clip', 'dual_ascent', 'quasi_newton', 'frank_wolfe', 'admm']


def summarize_training_log(log_path):
    """Print the number of logged rows and the final-epoch metrics of a CSV log.

    Columns are looked up by header name (as the plotting cells do with pandas)
    rather than by position, so a reordered or extended CSV still reports the
    right values; a missing column is shown as 'n/a' instead of raising.
    """
    if not os.path.exists(log_path):
        print(f"  Log file not found: {log_path}")
        return
    with open(log_path, newline='', encoding='utf-8') as f:
        rows = list(csv.DictReader(f))
    print(f"  Log file: {len(rows)} data rows (excluding header)")
    if rows:
        last = rows[-1]
        print(f"  Final epoch: train_loss={last.get('train_loss', 'n/a')}, "
              f"val_acc={last.get('val_acc', 'n/a')}")


print("="*60)
print("EXPERIMENT 1: Inner Solver Comparison on SmallCNN")
print("="*60)

for solver_name in solvers:
    print(f"\n{'='*60}")
    print(f"Training with solver: {solver_name}")
    print(f"{'='*60}")

    # train.py CLI flags use dashes (e.g. --inner-solver), not underscores.
    cmd = [
        sys.executable, 'train.py',
        '--model', 'small_cnn',
        '--optimizer', 'muon_sgd',
        '--inner-solver', solver_name,
        '--spectral-budget', '1.0',
        '--lr', '0.01',
        '--epochs', '20',
        '--logdir', log_dir,
        '--exp-name', f'solver_{solver_name}',
        '--num-workers', '2',
    ]

    print(f"Command: {' '.join(cmd)}\n")

    result = subprocess.run(cmd, capture_output=False)

    if result.returncode != 0:
        print(f"WARNING: Training failed with exit code {result.returncode}")
        # The CSV is not removed beforehand, so what follows may be stale.
        print("  (any log summary below may be left over from a previous run)")
    else:
        print(f"✓ {solver_name} completed successfully")

    summarize_training_log(os.path.join(log_dir, f'solver_{solver_name}.csv'))

print("\n" + "="*60)
print("EXPERIMENT 1 COMPLETE")
print(f"Logs saved to: {log_dir}")
print("="*60)

In [None]:
# Cell 5: Plot Experiment 1 Results
import pandas as pd
import matplotlib.pyplot as plt
import os

os.chdir(PROJECT_ROOT)
log_dir = os.path.join(PROJECT_ROOT, 'logs', 'exp1_solver_comparison')

solvers = ['spectral_clip', 'dual_ascent', 'quasi_newton', 'frank_wolfe', 'admm']
solver_colors = {
    'spectral_clip': '#1f77b4',
    'dual_ascent': '#ff7f0e',
    'quasi_newton': '#2ca02c',
    'frank_wolfe': '#d62728',
    'admm': '#9467bd'
}

# Columns every run needs for the loss/accuracy panels and the summary table.
required_columns = {'epoch', 'train_loss', 'val_acc'}

# Load all data
data = {}
for solver in solvers:
    path = os.path.join(log_dir, f'solver_{solver}.csv')
    if not os.path.exists(path):
        print(f"✗ Missing: {path}")
        continue
    try:
        df = pd.read_csv(path)
    except Exception as e:
        print(f"✗ Error loading {solver}: {e}")
        continue
    missing = required_columns - set(df.columns)
    if missing:
        print(f"✗ {solver}: missing columns {sorted(missing)}")
    elif len(df) > 0:
        data[solver] = df
        print(f"✓ Loaded {solver}: {len(df)} epochs")

if len(data) == 0:
    print("\n⚠️ No data to plot. Run Experiment 1 first.")
else:
    # One entry per panel, row-major over the 2x2 grid:
    # (CSV column, y-axis label, title, display scale).
    # val_acc is scaled by 100 for display, which assumes the CSV stores it as a
    # fraction in [0, 1] -- TODO confirm against train.py.
    panels = [
        ('train_loss', 'Training Loss', 'Training Loss by Inner Solver', 1),
        ('val_acc', 'Validation Accuracy (%)', 'Validation Accuracy by Inner Solver', 100),
        ('max_spectral_norm', 'Max Spectral Norm (σ_max)', 'Spectral Norm Control', 1),
        ('sharpness', 'Sharpness (SAM proxy)', 'Loss Landscape Sharpness', 1),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    for ax, (column, ylabel, title, scale) in zip(axes.flat, panels):
        for solver, df in data.items():
            if column in df.columns:
                ax.plot(df['epoch'], df[column] * scale, label=solver,
                        color=solver_colors[solver], linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        # Optional metrics may be absent from every log; calling legend() with
        # no labeled artists only emits a warning and an empty box.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        else:
            ax.text(0.5, 0.5, f"'{column}' not logged", ha='center', va='center',
                    transform=ax.transAxes)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    # Save figure
    fig_path = os.path.join(log_dir, 'exp1_results.png')
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    print(f"\n✓ Figure saved to: {fig_path}")

    plt.show()

    # Summary table, best final validation accuracy first
    print("\n" + "="*60)
    print("SUMMARY TABLE")
    print("="*60)
    print(f"{'Solver':<15} {'Final Val Acc':>15} {'Final Loss':>12} {'Final σ_max':>12}")
    print("-"*60)
    for solver, df in sorted(data.items(), key=lambda x: -x[1]['val_acc'].iloc[-1]):
        val_acc = df['val_acc'].iloc[-1] * 100
        loss = df['train_loss'].iloc[-1]
        # Report n/a instead of a fabricated 0.0 when the metric was not logged.
        if 'max_spectral_norm' in df.columns:
            sigma = f"{df['max_spectral_norm'].iloc[-1]:>12.4f}"
        else:
            sigma = f"{'n/a':>12}"
        print(f"{solver:<15} {val_acc:>14.2f}% {loss:>12.4f} {sigma}")

---
## Experiment 2: Multi-Seed Baseline Comparison

Compare SGD, AdamW, and MuonSGD (with the DualAscent and SpectralClip inner solvers) on ResNet-18, with 3 seeds each.

**Goal**: Get meaningful error bars for the paper.

In [None]:
# Cell 6: Experiment 2 - Multi-Seed Baseline Comparison
import os
import subprocess
import sys

os.chdir(PROJECT_ROOT)

log_dir = os.path.join(PROJECT_ROOT, 'logs', 'exp2_baselines')
os.makedirs(log_dir, exist_ok=True)

# Configurations: (optimizer, inner_solver, spectral_budget)
configs = [
    ('sgd', 'none', None),
    ('adamw', 'none', None),
    ('muon_sgd', 'dual_ascent', 0.1),
    ('muon_sgd', 'spectral_clip', 0.1),
]

# Learning rate per optimizer. All optimizers previously shared lr=0.1. That is
# a common SGD value, but it is typically far too large for AdamW and would make
# the AdamW baseline diverge or underperform for reasons unrelated to the
# comparison. Re-tune with a small LR sweep if needed.
lr_by_optimizer = {
    'sgd': 0.1,
    'adamw': 1e-3,
    'muon_sgd': 0.1,
}

seeds = [0, 1, 2]
epochs = 50  # Can reduce to 30 for faster iteration

print("="*70)
print("EXPERIMENT 2: Multi-Seed Baseline Comparison on ResNet-18")
print(f"Seeds: {seeds}, Epochs: {epochs}")
print("="*70)

failed_runs = []
for opt, solver, budget in configs:
    lr = lr_by_optimizer.get(opt, 0.1)
    for seed in seeds:
        exp_name = f'{opt}_{solver}_seed{seed}'
        print(f"\nRunning: {exp_name} (lr={lr})")

        cmd = [
            sys.executable, 'train.py',
            '--model', 'resnet18',
            '--optimizer', opt,
            '--inner-solver', solver,
            '--lr', str(lr),
            '--epochs', str(epochs),
            '--seed', str(seed),
            '--logdir', log_dir,
            '--exp-name', exp_name,
            '--num-workers', '2',
        ]

        if budget is not None:
            cmd.extend(['--spectral-budget', str(budget)])

        result = subprocess.run(cmd, capture_output=False)

        if result.returncode != 0:
            print(f"  ⚠️ Failed with code {result.returncode}")
            failed_runs.append(exp_name)
        else:
            print(f"  ✓ Completed")

print("\n" + "="*70)
print("EXPERIMENT 2 COMPLETE")
if failed_runs:
    print(f"Failed runs ({len(failed_runs)}): {', '.join(failed_runs)}")
print(f"Logs saved to: {log_dir}")
print("="*70)

In [None]:
# Cell 7: Plot Experiment 2 with Error Bars
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from collections import defaultdict

os.chdir(PROJECT_ROOT)
log_dir = os.path.join(PROJECT_ROOT, 'logs', 'exp2_baselines')

configs = [
    ('sgd', 'none'),
    ('adamw', 'none'),
    ('muon_sgd', 'dual_ascent'),
    ('muon_sgd', 'spectral_clip'),
]
seeds = [0, 1, 2]

config_colors = {
    'sgd_none': '#1f77b4',
    'adamw_none': '#ff7f0e',
    'muon_sgd_dual_ascent': '#2ca02c',
    'muon_sgd_spectral_clip': '#d62728',
}

config_labels = {
    'sgd_none': 'SGD',
    'adamw_none': 'AdamW',
    'muon_sgd_dual_ascent': 'MuonSGD (DualAscent)',
    'muon_sgd_spectral_clip': 'MuonSGD (SpectralClip)',
}

# Aggregate data across seeds
aggregated = defaultdict(list)

for opt, solver in configs:
    key = f'{opt}_{solver}'
    for seed in seeds:
        path = os.path.join(log_dir, f'{key}_seed{seed}.csv')
        if os.path.exists(path):
            try:
                df = pd.read_csv(path)
                if len(df) > 0:
                    aggregated[key].append(df)
            except Exception as e:
                print(f"Error loading {path}: {e}")

if len(aggregated) == 0:
    print("⚠️ No data found. Run Experiment 2 first.")
else:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    for key, dfs in aggregated.items():
        if not dfs:
            continue
        n_seeds = len(dfs)

        # Truncate every seed to the shortest run so the curves line up, and
        # say so instead of silently hiding an incomplete run.
        lengths = [len(df) for df in dfs]
        min_len = min(lengths)
        if len(set(lengths)) > 1:
            print(f"Note: {key} runs have unequal lengths {lengths}; truncating to {min_len} epochs")
        # Named epoch_axis so this cell does not overwrite the notebook-level
        # `epochs` run setting used by the experiment cells.
        epoch_axis = dfs[0]['epoch'].values[:min_len]

        train_losses = np.array([df['train_loss'].values[:min_len] for df in dfs])
        # Scaled to percent; assumes val_acc is logged as a fraction -- TODO confirm.
        val_accs = np.array([df['val_acc'].values[:min_len] * 100 for df in dfs])

        # Sample std (ddof=1) across seeds; with a single seed there is no
        # spread to estimate, so fall back to ddof=0 (zero-width band).
        ddof = 1 if n_seeds > 1 else 0
        mean_loss = train_losses.mean(axis=0)
        std_loss = train_losses.std(axis=0, ddof=ddof)
        mean_acc = val_accs.mean(axis=0)
        std_acc = val_accs.std(axis=0, ddof=ddof)

        color = config_colors.get(key, 'gray')
        label = f"{config_labels.get(key, key)} (n={n_seeds})"

        # Training Loss
        ax = axes[0]
        ax.plot(epoch_axis, mean_loss, label=label, color=color, linewidth=2)
        ax.fill_between(epoch_axis, mean_loss - std_loss, mean_loss + std_loss, color=color, alpha=0.2)

        # Validation Accuracy
        ax = axes[1]
        ax.plot(epoch_axis, mean_acc, label=label, color=color, linewidth=2)
        ax.fill_between(epoch_axis, mean_acc - std_acc, mean_acc + std_acc, color=color, alpha=0.2)

        print(f"{label}: Final Val Acc = {mean_acc[-1]:.2f}% ± {std_acc[-1]:.2f}%")

    # Seed counts can differ per config (failed/missing runs), so the titles do
    # not hard-code a number; the per-config n is shown in the legend.
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Training Loss')
    axes[0].set_title('Training Loss (mean ± std across seeds)')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Validation Accuracy (%)')
    axes[1].set_title('Validation Accuracy (mean ± std across seeds)')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()

    fig_path = os.path.join(log_dir, 'exp2_results_with_errorbars.png')
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    print(f"\n✓ Figure saved to: {fig_path}")

    plt.show()

---
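## Experiment 3: Width Transfer (μP-style)

Test whether hyperparameters (learning rate, spectral budget) transfer across network widths: an MLP is trained at width multipliers from 0.5× to 2.0×, with the learning rate and spectral budget held fixed across widths.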

**Hypothesis**: MuonSGD's spectral constraints should enable better width transfer than vanilla SGD.

In [None]:
# Cell 8: Experiment 3 - Width Transfer
import os
import subprocess
import sys

os.chdir(PROJECT_ROOT)

log_dir = os.path.join(PROJECT_ROOT, 'logs', 'exp3_width_transfer')
os.makedirs(log_dir, exist_ok=True)

# Width multipliers
widths = [0.5, 0.75, 1.0, 1.5, 2.0]
optimizers = [
    ('sgd', 'none', None),
    ('muon_sgd', 'dual_ascent', 0.1),
]
seeds = [0, 1, 2]
epochs = 30

print("="*70)
print("EXPERIMENT 3: Width Transfer on MLP")
print(f"Widths: {widths}")
print(f"Seeds: {seeds}, Epochs: {epochs}")
print("="*70)

failed_runs = []
for opt, solver, budget in optimizers:
    for width in widths:
        for seed in seeds:
            exp_name = f'{opt}_{solver}_width{width}_seed{seed}'
            print(f"Running: {exp_name}")

            cmd = [
                sys.executable, 'train.py',
                '--model', 'mlp',
                '--width-mult', str(width),
                '--optimizer', opt,
                '--inner-solver', solver,
                '--lr', '0.01',
                '--epochs', str(epochs),
                '--seed', str(seed),
                '--logdir', log_dir,
                '--exp-name', exp_name,
                '--num-workers', '2',
            ]

            if budget is not None:
                cmd.extend(['--spectral-budget', str(budget)])

            # Output is captured to keep the notebook readable across many runs.
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode != 0:
                failed_runs.append(exp_name)
                stderr_text = result.stderr or ''
                # Keep the full stderr on disk; Python tracebacks end with the
                # actual error, so show the last lines rather than the first
                # 100 characters (which are just the traceback preamble).
                err_path = os.path.join(log_dir, f'{exp_name}.stderr.txt')
                with open(err_path, 'w', encoding='utf-8') as fh:
                    fh.write(stderr_text)
                tail = stderr_text.strip().splitlines()[-5:]
                print(f"  ⚠️ Failed (exit code {result.returncode}); full stderr: {err_path}")
                for line in tail or ['unknown error']:
                    print(f"    {line}")
            else:
                print(f"  ✓ Completed")

print("\n" + "="*70)
print("EXPERIMENT 3 COMPLETE")
if failed_runs:
    print(f"Failed runs ({len(failed_runs)}): {', '.join(failed_runs)}")
print(f"Logs saved to: {log_dir}")
print("="*70)

In [None]:
# Cell 9: Plot Width Transfer Results
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from collections import defaultdict

os.chdir(PROJECT_ROOT)
log_dir = os.path.join(PROJECT_ROOT, 'logs', 'exp3_width_transfer')

widths = [0.5, 0.75, 1.0, 1.5, 2.0]
optimizers = [('sgd', 'none'), ('muon_sgd', 'dual_ascent')]
seeds = [0, 1, 2]

# Collect final accuracies: per optimizer config, one entry per width that has
# at least one readable run.
results = {}

for opt, solver in optimizers:
    key = f'{opt}_{solver}'
    results[key] = {'widths': [], 'mean_acc': [], 'std_acc': [], 'n_seeds': []}

    for width in widths:
        accs = []
        for seed in seeds:
            path = os.path.join(log_dir, f'{key}_width{width}_seed{seed}.csv')
            if not os.path.exists(path):
                continue
            try:
                df = pd.read_csv(path)
                if len(df) > 0:
                    # Scaled to percent; assumes val_acc is logged as a fraction -- TODO confirm.
                    accs.append(df['val_acc'].iloc[-1] * 100)
            except Exception as e:
                # Report unreadable logs instead of silently dropping them
                # (a bare except also swallowed KeyboardInterrupt).
                print(f"Error loading {path}: {e}")

        if accs:
            # Sample std across seeds; a single seed has no spread to estimate.
            ddof = 1 if len(accs) > 1 else 0
            results[key]['widths'].append(width)
            results[key]['mean_acc'].append(np.mean(accs))
            results[key]['std_acc'].append(np.std(accs, ddof=ddof))
            results[key]['n_seeds'].append(len(accs))

if all(len(v['widths']) == 0 for v in results.values()):
    print("⚠️ No data found. Run Experiment 3 first.")
else:
    fig, ax = plt.subplots(figsize=(10, 6))

    colors = {'sgd_none': '#1f77b4', 'muon_sgd_dual_ascent': '#2ca02c'}
    labels = {'sgd_none': 'SGD', 'muon_sgd_dual_ascent': 'MuonSGD (DualAscent)'}

    # `res` rather than `data` so this cell does not overwrite Cell 5's `data`.
    for key, res in results.items():
        if res['widths']:
            ax.errorbar(
                res['widths'], res['mean_acc'],
                yerr=res['std_acc'],
                label=labels.get(key, key),
                color=colors.get(key, 'gray'),
                marker='o', markersize=8, linewidth=2, capsize=5
            )

    # Reference line at width=1.0. Drawn before legend() so its label is
    # actually included in the legend.
    ax.axvline(x=1.0, color='gray', linestyle='--', alpha=0.5, label='Reference width')

    ax.set_xlabel('Width Multiplier', fontsize=12)
    ax.set_ylabel('Final Validation Accuracy (%)', fontsize=12)
    ax.set_title('Width Transfer: MuonSGD vs SGD on MLP', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()

    fig_path = os.path.join(log_dir, 'exp3_width_transfer.png')
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    print(f"✓ Figure saved to: {fig_path}")

    plt.show()

    # Summary
    print("\n" + "="*60)
    print("WIDTH TRANSFER SUMMARY")
    print("="*60)
    for key, res in results.items():
        if res['widths']:
            print(f"\n{labels.get(key, key)}:")
            for w, m, s, n in zip(res['widths'], res['mean_acc'], res['std_acc'], res['n_seeds']):
                print(f"  Width {w}: {m:.2f}% ± {s:.2f}% (n={n})")

---
## Quick Single-Run Testing

Use this cell for quick tests before running full experiments.

In [None]:
# Cell 10: Quick Test Run
import os
import subprocess
import sys

os.chdir(PROJECT_ROOT)

# Absolute log location, matching how the experiment cells pass --logdir, and
# reused below so the write and read paths cannot drift apart.
quick_log_dir = os.path.join(PROJECT_ROOT, 'logs', 'quick_test')

# Quick test: 5 epochs on SmallCNN
cmd = [
    sys.executable, 'train.py',
    '--model', 'small_cnn',
    '--optimizer', 'muon_sgd',
    '--inner-solver', 'dual_ascent',
    '--spectral-budget', '0.1',
    '--lr', '0.01',
    '--epochs', '5',
    '--logdir', quick_log_dir,
    '--exp-name', 'quick_test',
    '--num-workers', '0',  # No multiprocessing for debugging
]

print("Running quick test...")
print(f"Command: {' '.join(cmd)}\n")

result = subprocess.run(cmd)

if result.returncode == 0:
    print("\n✓ Quick test passed!")

    # Show results
    log_path = os.path.join(quick_log_dir, 'quick_test.csv')
    if os.path.exists(log_path):
        import pandas as pd
        df = pd.read_csv(log_path)
        # Show only the columns this run actually logged; the plotting cells
        # also treat max_spectral_norm as optional, so don't KeyError on it.
        wanted = ['epoch', 'train_loss', 'val_acc', 'max_spectral_norm']
        shown = [c for c in wanted if c in df.columns]
        print(f"\nResults ({len(df)} epochs):")
        print(df[shown].to_string())
    else:
        print(f"\n(no log file found at {log_path})")
else:
    print(f"\n✗ Test failed with code {result.returncode}")

---
## Generate Publication Figures

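Run this after all experiments have completed to generate publication-quality figures. It requires `scripts/generate_paper_figures.py`; if that script is missing, use the plotting cells above instead.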

In [None]:
# Cell 11: Generate All Publication Figures
import os
import subprocess
import sys

os.chdir(PROJECT_ROOT)

# The figure script is optional; when it is absent, point the user back to the
# per-experiment plotting cells.
gen_script = os.path.join(PROJECT_ROOT, 'scripts', 'generate_paper_figures.py')

if not os.path.exists(gen_script):
    print("Note: scripts/generate_paper_figures.py not found.")
    print("Use the individual plotting cells above instead.")
else:
    print("Generating publication figures...")
    result = subprocess.run([sys.executable, gen_script])
    if result.returncode != 0:
        print(f"⚠️ Figure generation failed with code {result.returncode}")
    else:
        print("✓ Figures generated successfully!")