# MoT Training Optimization Benchmark

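This notebook benchmarks the training-step speed of the Mixture of Thoughts (MoT) model under three configurations: an FP32 baseline, automatic mixed precision (AMP), and `torch.compile` + AMP.

**Goal:** Identify bottlenecks and optimize training speed to match the nanoGPT baseline (~11.79 ms/iter, measured outside this notebook).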

## Setup

In [None]:
import torch
import torch.nn as nn
import time
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import sys

# Add parent to path
sys.path.insert(0, str(Path.cwd().parent))

from mot.core.model import MixtureOfThoughtsTransformer, MoTConfig

# Record the software/hardware environment next to the timings below.
cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}\nCUDA available: {cuda_available}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}\nCUDA version: {torch.version.cuda}")
print(f"torch.compile available: {hasattr(torch, 'compile')}")

## Create Test Data

In [None]:
def create_dummy_data(seq_length=128, vocab_size=100, num_samples=100):
    """Return random token IDs for benchmarking.

    Args:
        seq_length: Number of tokens per sample.
        vocab_size: Token IDs are drawn uniformly from ``[0, vocab_size)``.
        num_samples: Number of sequences to generate.

    Returns:
        An int64 tensor of shape ``(num_samples, seq_length)``.
    """
    shape = (num_samples, seq_length)
    return torch.randint(low=0, high=vocab_size, size=shape)

# Prefer the GPU when one is visible; every benchmark below also runs on CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Random token IDs: 100 sequences of 128 tokens from a 100-token vocabulary.
data = create_dummy_data(num_samples=100, seq_length=128, vocab_size=100)
print(f"Test data shape: {data.shape}")

## Create Model

In [None]:
# Small MoT model shared by every benchmark in this notebook.
config = MoTConfig(
    vocab_size=100,
    hidden_size=256,
    num_hidden_layers=4,
    num_thoughts=8,
    max_position_embeddings=128,
)

print("Model configuration:")
for label, value in (
    ("Vocab size", config.vocab_size),
    ("Hidden size", config.hidden_size),
    ("Layers", config.num_hidden_layers),
    ("Thoughts", config.num_thoughts),
    ("Sequence length", config.max_position_embeddings),
):
    print(f"  {label}: {value}")

# Build a throwaway instance only to count parameters; it is not kept around.
params = sum(p.numel() for p in MixtureOfThoughtsTransformer(config).parameters())
print(f"\nModel parameters: {params:,}")

## Benchmark Functions

In [None]:
def _maybe_sync(device):
    """Block until queued CUDA work on ``device`` has finished; no-op on other devices."""
    if device.type == 'cuda':
        torch.cuda.synchronize(device)


def _make_grad_scaler():
    """Create a CUDA GradScaler, preferring the non-deprecated ``torch.amp`` API when present."""
    if hasattr(torch, 'amp') and hasattr(torch.amp, 'GradScaler'):
        return torch.amp.GradScaler('cuda')
    return torch.cuda.amp.GradScaler()


def _get_batch(data, step, batch_size):
    """Return ``batch_size`` consecutive samples starting at ``step * batch_size``, wrapping around ``data``."""
    idx = (torch.arange(batch_size) + step * batch_size) % len(data)
    return data[idx]


def _train_step(model, optimizer, batch, scaler):
    """Run one zero_grad/forward/backward/optimizer step; use FP16 autocast when ``scaler`` is given."""
    optimizer.zero_grad()
    if scaler is not None:
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            outputs = model(batch)
            loss = outputs['logits'].mean()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        outputs = model(batch)
        loss = outputs['logits'].mean()
        loss.backward()
        optimizer.step()


def benchmark_training(model, data, device, num_iters=10, use_amp=False, desc="",
                       num_warmup=5, batch_size=1):
    """Time full training steps (zero_grad, forward, backward, optimizer step).

    A fresh AdamW optimizer (lr=1e-4) is created on every call. ``model`` is
    called on a batch of samples and must return a mapping with a ``'logits'``
    entry; the mean of the logits serves as a stand-in loss.

    Args:
        model: Module to benchmark; it is switched to train mode.
        data: Tensor whose first dimension indexes samples. Batches are
            consecutive samples and wrap around when more steps are run than
            there are samples, so no batch is ever empty.
        device: Device each batch is moved to before the step.
        num_iters: Number of timed iterations (must be >= 1).
        use_amp: Enable FP16 autocast + GradScaler. Only honored on CUDA; on
            other devices a note is printed and the run uses FP32.
        desc: Label printed in the header.
        num_warmup: Untimed iterations run first (absorbs one-off costs such as
            torch.compile compilation).
        batch_size: Samples per step (must be >= 1).

    Returns:
        ``(mean_ms, std_ms, times_ms)`` where ``times_ms`` is the list of
        per-iteration wall times in milliseconds.

    Raises:
        ValueError: If ``num_iters`` or ``batch_size`` is < 1, or ``data`` is empty.
    """
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(data) == 0:
        raise ValueError("data must contain at least one sample")

    print(f"\n{'='*70}")
    print(f"Testing: {desc}")
    print(f"{'='*70}")

    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # AMP (FP16 autocast + loss scaling) is only wired up for CUDA.
    amp_active = use_amp and device.type == 'cuda'
    if use_amp and not amp_active:
        print(f"Note: AMP requested but device is {device.type}; running in FP32.")
    scaler = _make_grad_scaler() if amp_active else None

    print("Warming up...")
    for step in range(num_warmup):
        batch = _get_batch(data, step, batch_size).to(device)
        _train_step(model, optimizer, batch, scaler)
    # Drain queued warmup kernels so they are not billed to the first timed step.
    _maybe_sync(device)

    print(f"Running {num_iters} iterations...")
    times = []
    for step in range(num_iters):
        batch = _get_batch(data, step, batch_size).to(device)
        # Make sure the host-to-device copy is done before the clock starts.
        _maybe_sync(device)
        start = time.perf_counter()
        _train_step(model, optimizer, batch, scaler)
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        _maybe_sync(device)
        times.append((time.perf_counter() - start) * 1000)  # ms

    mean_time = np.mean(times)
    std_time = np.std(times)

    print("\nResults:")
    print(f"  Mean: {mean_time:.2f} ms")
    print(f"  Std:  {std_time:.2f} ms")
    print(f"  Min:  {np.min(times):.2f} ms")
    print(f"  Max:  {np.max(times):.2f} ms")

    return mean_time, std_time, times

## Test 1: Baseline (No Optimizations)

In [None]:
# FP32 reference run; every later speedup is measured against these numbers.
model_baseline = MixtureOfThoughtsTransformer(config).to(device)
baseline_mean, baseline_std, baseline_times = benchmark_training(
    model=model_baseline,
    data=data,
    device=device,
    num_iters=10,
    use_amp=False,
    desc="Baseline (no optimizations)",
)

# Release the model (and cached GPU blocks) before building the next one.
del model_baseline
if device.type == 'cuda':
    torch.cuda.empty_cache()

## Test 2: With AMP (Automatic Mixed Precision)

In [None]:
# Same architecture as the baseline, trained with FP16 autocast + gradient scaling.
# NOTE: benchmark_training only enables AMP on CUDA; on CPU this is another FP32 run.
model_amp = MixtureOfThoughtsTransformer(config).to(device)
amp_mean, amp_std, amp_times = benchmark_training(
    model=model_amp,
    data=data,
    device=device,
    num_iters=10,
    use_amp=True,
    desc="With AMP (FP16 mixed precision)",
)

speedup_amp = baseline_mean / amp_mean
print(f"\nSpeedup vs baseline: {speedup_amp:.2f}x")

# Release the model (and cached GPU blocks) before the next test.
del model_amp
if device.type == 'cuda':
    torch.cuda.empty_cache()

## Test 3: With torch.compile + AMP

In [None]:
if not (hasattr(torch, 'compile') and device.type == 'cuda'):
    # Skipped: later cells treat ``None`` as "not measured".
    print("torch.compile not available or not on CUDA")
    compiled_mean, compiled_std, compiled_times = None, None, None
else:
    print("Compiling model... (this may take 1-2 minutes on first run)")

    # torch.compile defers compilation to the first call, so that cost lands in
    # benchmark_training's untimed warmup rather than in the timed iterations.
    model_compiled = torch.compile(MixtureOfThoughtsTransformer(config).to(device))

    compiled_mean, compiled_std, compiled_times = benchmark_training(
        model=model_compiled,
        data=data,
        device=device,
        num_iters=10,
        use_amp=True,
        desc="With torch.compile + AMP",
    )

    speedup_compiled = baseline_mean / compiled_mean
    print(f"\nSpeedup vs baseline: {speedup_compiled:.2f}x")

    del model_compiled
    torch.cuda.empty_cache()

## Visualization

In [None]:
# Gather the runs that completed; the compile run is skipped when unavailable.
results = [('Baseline', baseline_mean, baseline_std), ('AMP', amp_mean, amp_std)]
if compiled_mean is not None:
    results.append(('Compile + AMP', compiled_mean, compiled_std))
methods, means, stds = (list(col) for col in zip(*results))

# Speedup of each run relative to the baseline (first) run.
speedups = [1.0] + [baseline_mean / m for m in means[1:]]

colors = ['#ff7f0e', '#2ca02c', '#1f77b4']
bar_colors = colors[:len(methods)]


def _annotate_bars(ax, bars, labels):
    """Write each label in bold, centered just above the top of its bar."""
    for bar, label in zip(bars, labels):
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(), label,
                ha='center', va='bottom', fontsize=10, fontweight='bold')


fig, (ax_time, ax_speed) = plt.subplots(1, 2, figsize=(15, 5))

# Left panel: mean time per iteration with +/- 1 std error bars.
time_bars = ax_time.bar(methods, means, yerr=stds, capsize=5, color=bar_colors)
ax_time.set_ylabel('Time per iteration (ms)', fontsize=12)
ax_time.set_title('Training Speed Comparison', fontsize=14, fontweight='bold')
ax_time.grid(axis='y', alpha=0.3)
_annotate_bars(ax_time, time_bars, [f'{m:.2f} ms' for m in means])

# Right panel: speedup vs baseline; the dashed line marks parity (1.0x).
speed_bars = ax_speed.bar(methods, speedups, color=bar_colors)
ax_speed.set_ylabel('Speedup (vs baseline)', fontsize=12)
ax_speed.set_title('Relative Performance', fontsize=14, fontweight='bold')
ax_speed.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
ax_speed.grid(axis='y', alpha=0.3)
_annotate_bars(ax_speed, speed_bars, [f'{s:.2f}x' for s in speedups])

plt.tight_layout()
plt.savefig('optimization_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nChart saved as 'optimization_comparison.png'")

## Comparison with nanoGPT Baseline

In [None]:
# nanoGPT reference time, measured outside this notebook ("from testing").
# Its model size, batch size and sequence length are not recorded here, so the
# ratios below are orientation only, not a like-for-like comparison.
nanogpt_time = 11.79  # ms/iter


def _vs_nanogpt(ms):
    """Describe ``ms`` relative to ``nanogpt_time``, e.g. '3.2x slower' or '1.5x faster'."""
    ratio = ms / nanogpt_time
    if ratio >= 1:
        return f"{ratio:.1f}x slower"
    # Report speed-ups as a >1 factor instead of a misleading "0.5x slower".
    return f"{1 / ratio:.1f}x faster"


print("\n" + "="*70)
print("Comparison with nanoGPT Baseline")
print("="*70)

print(f"\nnanoGPT (baseline):     {nanogpt_time:.2f} ms/iter")
print(f"MoT Baseline:           {baseline_mean:.2f} ms/iter ({_vs_nanogpt(baseline_mean)})")
print(f"MoT with AMP:           {amp_mean:.2f} ms/iter ({_vs_nanogpt(amp_mean)})")

if compiled_mean is not None:
    print(f"MoT with Compile+AMP:   {compiled_mean:.2f} ms/iter ({_vs_nanogpt(compiled_mean)})")

# Derive the first two insights from the measurements instead of hard-coding them.
optimized_times = [t for t in (amp_mean, compiled_mean) if t is not None]
best_optimized = min(optimized_times)
best_mot = min(baseline_mean, best_optimized)

print("\n" + "="*70)
print("Key Insights")
print("="*70)
if best_optimized < baseline_mean:
    print("\n1. Model computation optimizations (AMP, compile) provide speedup")
else:
    print("\n1. Model computation optimizations (AMP, compile) gave no speedup here")
print(f"2. Best MoT configuration is {_vs_nanogpt(best_mot)} than the nanoGPT baseline")
# This notebook feeds pre-built tensors directly, so DataLoader cost is never timed here.
print("3. Suspected remaining bottleneck (not measured here): DataLoader overhead in full training loop")
print("\nSolution: Adopt nanoGPT's data loading approach:")
print("  - Use numpy memmap for direct data access")
print("  - Eliminate DataLoader multiprocessing overhead")
print("  - Simplify training loop to minimize Python overhead")

## Results Summary

In [None]:
print("\n" + "="*70)
print("OPTIMIZATION SUMMARY")
print("="*70)

print(f"\nModel: {params:,} parameters")
print(f"Device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

print("\nResults:")
print(f"  Baseline:           {baseline_mean:7.2f} ms  (1.00x)")
print(f"  + AMP:              {amp_mean:7.2f} ms  ({baseline_mean/amp_mean:.2f}x speedup)")
if compiled_mean is not None:
    print(f"  + Compile + AMP:    {compiled_mean:7.2f} ms  ({baseline_mean/compiled_mean:.2f}x speedup)")

# Percentage of step time saved vs. the previous configuration
# (negative means the optimization made the step slower).
amp_saving = (1 - amp_mean / baseline_mean) * 100

print("\nRecommendations:")
if device.type != 'cuda':
    # benchmark_training only enables autocast on CUDA, so the "AMP" row is FP32 here.
    print("  ℹ️  Not on CUDA: the 'AMP' run executed in FP32, so AMP was not measured")
elif amp_saving > 0:
    print("  ✅ Enable AMP for ~{:.0f}% speedup".format(amp_saving))
else:
    print("  ⚠️  AMP did not speed up training here (~{:.0f}% slower)".format(-amp_saving))

if compiled_mean is not None:
    compile_saving = (1 - compiled_mean / amp_mean) * 100
    if compile_saving > 0:
        print("  ✅ Enable torch.compile for additional ~{:.0f}% speedup".format(compile_saving))
    else:
        print("  ⚠️  torch.compile did not add speedup over AMP here (~{:.0f}% slower)".format(-compile_saving))

# Not derived from the measurements above (no DataLoader is benchmarked in this notebook).
print("  ⚠️  Set num_workers=0 to eliminate multiprocessing overhead")
print("  🎯 Ultimate goal: Match nanoGPT's {:.2f} ms/iter".format(nanogpt_time))

print("\n" + "="*70)

## Next Steps

Based on these results:

1. **Immediate fixes** (already applied):
   - Set `num_workers: 0` in config
   - Set `use_compile: true` in config

2. **Test full training**:
   ```bash
   python scripts/train_with_config.py configs/training/small.yaml --batch-size 128
   ```

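3. **If still slow, adopt nanoGPT's approach** (the benchmarks above time only the model step on pre-generated tensors, so data-loading overhead is not captured here):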
   - Replace DataLoader with numpy memmap
   - Simplify training loop
   - Minimize Python overhead