# GPT-2 FineWeb-Edu Model Comparison & Analysis

This notebook provides tools for comparing multiple training runs, checkpoints, and model configurations:
- Multi-run loss comparison
- Checkpoint comparison
- Hyperparameter impact analysis
- Model performance benchmarks
- A/B testing visualizations


In [None]:
# Install required packages
# !pip install wandb matplotlib seaborn numpy pandas torch transformers

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import wandb
import os
from pathlib import Path
from datetime import datetime

# Global plot styling applied to every figure produced by this notebook
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (14, 8),
    'font.size': 10,
})

print("Libraries imported successfully!")


## 1. Load Multiple Training Runs

Load training histories from multiple WandB runs, or generate synthetic runs (varying the peak learning rate) when WandB is not used or is unavailable.


In [None]:
def generate_comparison_runs(n_runs=3, use_wandb=False, project_name="gpt-fineweb-demo", run_ids=None,
                             seed=None):
    """
    Generate or load multiple training runs for comparison.

    Args:
        n_runs: Number of runs to generate/load. Only three synthetic
            configurations exist, so larger values yield at most 3 synthetic runs.
        use_wandb: Whether to load run histories through the WandB public API.
            Any API error falls back to synthetic runs.
        project_name: WandB project name.
        run_ids: List of WandB run IDs (if None, takes the first ``n_runs``
            runs returned by ``api.runs(project_name)``).
        seed: Optional seed for the synthetic-noise generator. When None the
            global ``np.random`` state is used, as before.

    Returns:
        List of DataFrames, one per run. Every frame has ``run_name`` and
        ``run_id`` columns; synthetic frames additionally have ``step``,
        ``tokens_seen``, ``loss``, ``lr`` and ``tokens_per_sec``.
    """
    if use_wandb:
        try:
            api = wandb.Api()
            if run_ids:
                runs = [api.run(f"{project_name}/{rid}") for rid in run_ids]
            else:
                runs = list(api.runs(project_name))[:n_runs]

            # Collect into a local list so a failure part-way through never
            # leaves real runs mixed with the synthetic fallback below.
            wandb_runs = []
            for run in runs:
                history = run.history()
                history['run_name'] = run.name
                history['run_id'] = run.id
                wandb_runs.append(history)
            return wandb_runs
        except Exception as e:
            # Deliberate best-effort: any API/auth/network failure falls back
            # to synthetic data instead of aborting the notebook.
            print(f"WandB API error: {e}")
            print("Generating synthetic runs instead...")

    # --- Synthetic runs with variations ---
    # `np.random` (module) and a Generator share the `normal(loc, scale, size)` API.
    rng = np.random if seed is None else np.random.default_rng(seed)

    max_tokens = 200_000_000
    log_interval = 20
    tokens_per_step = 16 * 256
    warmup_tokens = 2_000_000

    steps = np.arange(0, max_tokens // tokens_per_step, log_interval)
    tokens_seen = steps * tokens_per_step  # integer array

    # Run variations: different learning rates, batch sizes, etc.
    configs = [
        {'name': 'Baseline (LR=3e-4)', 'lr_peak': 3e-4, 'lr_final': 6e-5, 'batch': 16, 'decay': 50e6},
        {'name': 'High LR (LR=5e-4)', 'lr_peak': 5e-4, 'lr_final': 8e-5, 'batch': 16, 'decay': 50e6},
        {'name': 'Low LR (LR=1e-4)', 'lr_peak': 1e-4, 'lr_final': 2e-5, 'batch': 16, 'decay': 50e6},
    ][:max(n_runs, 0)]

    runs_data = []
    for config in configs:
        # Training loss: exponential interpolation from 10.8 towards 4.19
        decay = np.exp(-tokens_seen / config['decay'])
        base_loss = 10.8 * decay + 4.19 * (1 - decay)
        # Add variation based on LR (higher LR might converge faster but less stable)
        if config['lr_peak'] > 3e-4:
            base_loss += 0.2 * np.exp(-tokens_seen / 30e6)  # Slightly higher early loss
        elif config['lr_peak'] < 3e-4:
            base_loss += 0.3 * np.exp(-tokens_seen / 70e6)  # Slower convergence
        base_loss += rng.normal(0, 0.1, len(base_loss))

        # Learning rate: linear warmup, then cosine decay from lr_peak to lr_final.
        # Computed with float arrays; the old `np.zeros_like(tokens_seen)` buffer
        # was integer-typed and truncated every LR to 0.
        progress = np.minimum((tokens_seen - warmup_tokens) / (max_tokens - warmup_tokens), 1.0)
        cosine = 0.5 * (1.0 + np.cos(np.pi * progress))
        lr = np.where(
            tokens_seen <= warmup_tokens,
            config['lr_peak'] * (tokens_seen / warmup_tokens),
            config['lr_final'] + (config['lr_peak'] - config['lr_final']) * cosine,
        )

        # Token processing speed (slight variation)
        base_speed = 6000 + rng.normal(0, 500)
        tokens_per_sec = base_speed + 1000 * (1 - np.exp(-tokens_seen / 20_000_000))
        tokens_per_sec += rng.normal(0, 200, len(tokens_per_sec))

        data = pd.DataFrame({
            'step': steps,
            'tokens_seen': tokens_seen,
            'loss': base_loss,
            'lr': lr,
            'tokens_per_sec': tokens_per_sec,
            'run_name': config['name'],
            'run_id': f"run_{len(runs_data)}",
        })

        runs_data.append(data)

    return runs_data

# Build the runs to compare (synthetic here; pass use_wandb=True to pull real runs)
runs_data = generate_comparison_runs(n_runs=3, use_wandb=False)

print(f"Loaded {len(runs_data)} runs for comparison:")
for position, run_frame in enumerate(runs_data, start=1):
    print(f"  {position}. {run_frame['run_name'].iloc[0]} ({len(run_frame)} data points)")


## 2. Multi-Run Loss Comparison


In [None]:
def plot_multi_run_comparison(runs_data, metric='loss', tokens_or_steps='tokens'):
    """
    Overlay one metric from several runs on two side-by-side panels.

    The left panel always uses a linear y-axis. The right panel uses a log
    y-axis when ``metric == 'loss'`` and otherwise repeats the linear view.

    Args:
        runs_data: List of DataFrames, one per run; each needs 'run_name',
            the requested metric column, and 'tokens_seen' or 'step'.
        metric: Column to compare ('loss', 'lr', 'tokens_per_sec', etc.)
        tokens_or_steps: 'tokens' plots against tokens seen (ticks shown in
            millions); any other value plots against training steps.

    Returns:
        The matplotlib Figure.
    """
    use_tokens = tokens_or_steps == 'tokens'
    is_loss = metric == 'loss'

    fig, (ax_lin, ax_alt) = plt.subplots(1, 2, figsize=(16, 6))
    palette = plt.cm.tab10(np.linspace(0, 1, len(runs_data)))
    line_style = dict(linewidth=2, alpha=0.8)

    for color, run_frame in zip(palette, runs_data):
        label = run_frame['run_name'].iloc[0]
        xs = run_frame['tokens_seen'] if use_tokens else run_frame['step']
        ys = run_frame[metric]

        ax_lin.plot(xs, ys, label=label, color=color, **line_style)
        draw_alt = ax_alt.semilogy if is_loss else ax_alt.plot
        draw_alt(xs, ys, label=label, color=color, **line_style)

    x_label = 'Tokens Seen' if use_tokens else 'Training Steps'
    y_label = metric.replace('_', ' ').title()
    alt_scale = "Log" if is_loss else "Linear"

    # (axis, y label, title, grid line selection)
    panels = (
        (ax_lin, y_label, f'{y_label} Comparison (Linear Scale)', 'major'),
        (ax_alt, y_label + (' (Log Scale)' if is_loss else ''),
         f'{y_label} Comparison ({alt_scale} Scale)', 'both' if is_loss else 'major'),
    )
    for ax, ylab, title, grid_which in panels:
        ax.set_xlabel(x_label, fontsize=12)
        ax.set_ylabel(ylab, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3, which=grid_which)
        if use_tokens:
            # A fresh formatter per axis: matplotlib binds each formatter to one axis.
            ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda v, _pos: f'{v/1e6:.1f}M'))

    plt.tight_layout()
    return fig

# Overlay training loss of every run against tokens seen
fig = plot_multi_run_comparison(runs_data, 'loss', 'tokens')
plt.show()


## 3. Final Metrics Comparison Table


In [None]:
def create_comparison_table(runs_data):
    """
    Create a comparison table of final metrics for all runs and render it.

    The row with the lowest final loss is highlighted in the 'Final Loss'
    and 'Final PPL' columns. PPL = exp(loss) is monotonic, so that row also
    has the lowest perplexity.

    Args:
        runs_data: List of DataFrames, one per run, each with 'run_name',
            'loss', 'lr' and 'tokens_seen'; 'tokens_per_sec' is optional.

    Returns:
        (fig, comparison_df): the Figure containing the rendered table, and a
        DataFrame with the same rows, values pre-formatted as strings.
    """
    comparison_data = []
    final_losses = []  # numeric values, so "best" is not judged from rounded strings

    for run_data in runs_data:
        run_name = run_data['run_name'].iloc[0]
        final_loss = float(run_data['loss'].iloc[-1])
        final_ppl = np.exp(final_loss)
        final_lr = run_data['lr'].iloc[-1]
        peak_lr = run_data['lr'].max()

        avg_speed = None
        if 'tokens_per_sec' in run_data.columns:
            avg_speed = run_data['tokens_per_sec'].mean()
        # An all-NaN/empty speed column yields a NaN mean; show 'N/A', not "nan".
        has_speed = avg_speed is not None and np.isfinite(avg_speed)

        total_tokens = run_data['tokens_seen'].iloc[-1]

        final_losses.append(final_loss)
        comparison_data.append({
            'Run Name': run_name,
            'Final Loss': f"{final_loss:.4f}",
            'Final PPL': f"{final_ppl:.2f}",
            'Peak LR': f"{peak_lr:.2e}",
            'Final LR': f"{final_lr:.2e}",
            'Avg Speed (tok/s)': f"{avg_speed:,.0f}" if has_speed else 'N/A',
            'Total Tokens': f"{total_tokens:,}",
        })

    comparison_df = pd.DataFrame(comparison_data)

    # Create visualization
    fig, ax = plt.subplots(figsize=(14, len(comparison_data) * 0.8 + 1))
    ax.axis('off')

    table = ax.table(cellText=comparison_df.values, colLabels=comparison_df.columns,
                    cellLoc='left', loc='center', bbox=[0, 0, 1, 1])
    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1, 2.5)

    # Highlight the best (lowest) final loss and its perplexity. Table row 0
    # is the header, so data row k lives at table row k + 1. Columns 1 and 2
    # are 'Final Loss' and 'Final PPL'. (Column 3, which the old code coloured
    # for "best PPL", is 'Peak LR'.)
    losses = np.asarray(final_losses, dtype=float)
    if not np.isnan(losses).all():
        best_row = int(np.nanargmin(losses)) + 1
        for col in (1, 2):
            table[(best_row, col)].set_facecolor('#90EE90')  # Light green

    ax.set_title('Run Comparison - Final Metrics', fontsize=14, fontweight='bold', pad=20)

    plt.tight_layout()
    return fig, comparison_df

# Render the final-metrics table, then echo the same rows as plain text
fig, comparison_df = create_comparison_table(runs_data)
plt.show()

print("\n=== Comparison Summary ===", comparison_df.to_string(index=False), sep="\n")


## 4. Perplexity Comparison with Confidence Intervals


In [None]:
def plot_perplexity_comparison(runs_data, tokens_or_steps='tokens', show_confidence=True,
                               confidence_window=50):
    """
    Plot perplexity (exp of loss) for every run on a shared log-scale axis.

    Each curve ends in a marker annotated with that run's final perplexity.

    Args:
        runs_data: List of DataFrames, one per run, each with 'run_name',
            'loss' and either 'tokens_seen' or 'step'.
        tokens_or_steps: 'tokens' plots against tokens seen (ticks in
            millions); any other value plots against training steps.
        show_confidence: If True, shade exp(rolling mean ± rolling std) of the
            loss around each curve. This is a variability band over a centred
            rolling window, not a formal statistical confidence interval.
        confidence_window: Number of logged points in the rolling window used
            for the band (clamped to the run's length).

    Returns:
        The matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=(14, 8))

    colors = plt.cm.tab10(np.linspace(0, 1, len(runs_data)))
    use_tokens = tokens_or_steps == 'tokens'

    # (final_x, final_ppl) per run; markers are drawn after all curves.
    finals = []

    for i, run_data in enumerate(runs_data):
        run_name = run_data['run_name'].iloc[0]
        x = run_data['tokens_seen'] if use_tokens else run_data['step']
        loss = run_data['loss']
        ppl = np.exp(loss)

        ax.plot(x, ppl, label=run_name, color=colors[i], linewidth=2.5, alpha=0.8)

        if show_confidence and len(loss) > 1:
            window = max(2, min(int(confidence_window), len(loss)))
            rolling = loss.rolling(window, center=True, min_periods=1)
            center = rolling.mean()
            # std of a single-sample window is NaN; treat it as zero width.
            spread = rolling.std().fillna(0.0)
            # Band is built in loss space and exponentiated, so it stays
            # positive and asymmetric on the log-scale PPL axis.
            ax.fill_between(x, np.exp(center - spread), np.exp(center + spread),
                            color=colors[i], alpha=0.15, linewidth=0)

        finals.append((x.iloc[-1], ppl.iloc[-1]))

    # Add final PPL annotations
    for i, (final_x, final_ppl) in enumerate(finals):
        ax.plot(final_x, final_ppl, 'o', color=colors[i], markersize=10, markeredgecolor='black',
                markeredgewidth=1)
        ax.annotate(f'{final_ppl:.1f}', xy=(final_x, final_ppl),
                    xytext=(10, 10), textcoords='offset points', fontsize=9,
                    bbox=dict(boxstyle='round,pad=0.3', facecolor=colors[i], alpha=0.3))

    x_label = 'Tokens Seen' if use_tokens else 'Training Steps'
    ax.set_xlabel(x_label, fontsize=12)
    ax.set_ylabel('Perplexity', fontsize=12)
    ax.set_title('Perplexity Comparison Across Runs', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11, loc='upper right')
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

    if use_tokens:
        ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x/1e6:.1f}M'))

    plt.tight_layout()
    return fig

# Perplexity of every run on a shared log axis, x = tokens seen
fig = plot_perplexity_comparison(runs_data, 'tokens')
plt.show()


## 5. Learning Rate Schedule Comparison


In [None]:
def plot_lr_schedule_comparison(runs_data, tokens_or_steps='tokens'):
    """
    Overlay every run's learning-rate curve on two panels.

    The left panel uses a linear y-axis, the right panel a logarithmic one.

    Args:
        runs_data: List of DataFrames, one per run, each with 'run_name',
            'lr' and either 'tokens_seen' or 'step'.
        tokens_or_steps: 'tokens' for a tokens-seen x-axis (ticks in
            millions); any other value uses training steps.

    Returns:
        The matplotlib Figure.
    """
    by_tokens = tokens_or_steps == 'tokens'
    fig, (lin_ax, log_ax) = plt.subplots(1, 2, figsize=(16, 6))
    palette = plt.cm.tab10(np.linspace(0, 1, len(runs_data)))

    for color, frame in zip(palette, runs_data):
        name = frame['run_name'].iloc[0]
        xs = frame['tokens_seen'] if by_tokens else frame['step']
        lin_ax.plot(xs, frame['lr'], label=name, color=color, linewidth=2, alpha=0.8)
        log_ax.semilogy(xs, frame['lr'], label=name, color=color, linewidth=2, alpha=0.8)

    horizontal = 'Tokens Seen' if by_tokens else 'Training Steps'
    # (axis, y label, title, grid line selection)
    panel_specs = (
        (lin_ax, 'Learning Rate', 'Learning Rate Schedule Comparison (Linear Scale)', 'major'),
        (log_ax, 'Learning Rate (Log Scale)', 'Learning Rate Schedule Comparison (Log Scale)', 'both'),
    )
    for axis, ylab, title, which in panel_specs:
        axis.set_xlabel(horizontal, fontsize=12)
        axis.set_ylabel(ylab, fontsize=12)
        axis.set_title(title, fontsize=14, fontweight='bold')
        axis.legend(fontsize=10)
        axis.grid(True, alpha=0.3, which=which)
        if by_tokens:
            axis.xaxis.set_major_formatter(plt.FuncFormatter(lambda v, _pos: f'{v/1e6:.1f}M'))

    plt.tight_layout()
    return fig

# Learning-rate schedules of all runs, linear and log views, x = tokens seen
fig = plot_lr_schedule_comparison(runs_data, 'tokens')
plt.show()


## 6. Convergence Speed Analysis

Compare how quickly each run first reaches a fixed target loss (measured in tokens seen or training steps).


In [None]:
def plot_convergence_analysis(runs_data, target_loss=5.0, tokens_or_steps='tokens'):
    """
    Analyze convergence speed: when each run first reaches a target loss.

    Left panel: every run's loss curve, the target as a dashed red line, and a
    marker at the first logged point whose loss is <= ``target_loss``.
    Right panel: one horizontal bar per run that reached the target, giving
    that first-crossing position. Runs that never reach the target are drawn
    on the left but get no marker or bar.

    The crossing is taken on the raw (unsmoothed) loss. On noisy curves it is
    therefore the first noisy dip below the target.

    Args:
        runs_data: List of DataFrames, one per run, each with 'run_name',
            'loss' and either 'tokens_seen' or 'step'.
        target_loss: Loss threshold that counts as "converged".
        tokens_or_steps: 'tokens' measures convergence in tokens seen
            (reported in millions); any other value measures it in steps.

    Returns:
        The matplotlib Figure.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    colors = plt.cm.tab10(np.linspace(0, 1, len(runs_data)))
    use_tokens = tokens_or_steps == 'tokens'

    def _fmt(value):
        # Tokens are shown in millions; steps as plain integers.
        return f'{value/1e6:.1f}M' if use_tokens else f'{value:,.0f}'

    # (run_name, x at first crossing, that run's own colour)
    converged = []

    for i, run_data in enumerate(runs_data):
        run_name = run_data['run_name'].iloc[0]
        x = run_data['tokens_seen'] if use_tokens else run_data['step']
        losses = run_data['loss'].to_numpy()

        # Plot 1: Loss curves with target line
        axes[0].plot(x, losses, label=run_name, color=colors[i], linewidth=2, alpha=0.8)

        # Find when loss first drops below target
        below_target = np.flatnonzero(losses <= target_loss)
        if below_target.size == 0:
            continue
        conv_x = x.iloc[below_target[0]]
        converged.append((run_name, conv_x, colors[i]))

        axes[0].plot(conv_x, target_loss, 'o', color=colors[i], markersize=10,
                     markeredgecolor='black', markeredgewidth=1)
        axes[0].annotate(_fmt(conv_x), xy=(conv_x, target_loss),
                         xytext=(10, 10), textcoords='offset points', fontsize=9)

    # Add target line
    axes[0].axhline(y=target_loss, color='red', linestyle='--', linewidth=2,
                    label=f'Target Loss: {target_loss}', alpha=0.7)

    x_label = 'Tokens Seen' if use_tokens else 'Training Steps'
    axes[0].set_xlabel(x_label, fontsize=12)
    axes[0].set_ylabel('Loss', fontsize=12)
    axes[0].set_title(f'Convergence to Target Loss = {target_loss}', fontsize=14, fontweight='bold')
    axes[0].legend(fontsize=10)
    axes[0].grid(True, alpha=0.3)
    if use_tokens:
        axes[0].xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x/1e6:.1f}M'))

    # Plot 2: first-crossing position per run, in the same units as the x-axis
    scale = 1e6 if use_tokens else 1.0
    if converged:
        names, conv_xs, bar_colors = zip(*converged)
        # Each bar keeps its run's colour even when earlier runs never converged.
        bars = axes[1].barh(names, [v / scale for v in conv_xs], color=list(bar_colors), alpha=0.7)
        # Value labels at the bar tip, vertically centred on the bar itself
        for bar, value in zip(bars, conv_xs):
            axes[1].text(bar.get_width(), bar.get_y() + bar.get_height() / 2,
                         f' {_fmt(value)}', va='center', fontsize=10)
    else:
        axes[1].text(0.5, 0.5, f'No run reached loss <= {target_loss}',
                     ha='center', va='center', transform=axes[1].transAxes, fontsize=12)
    axes[1].set_xlabel('Tokens to Converge (Millions)' if use_tokens else 'Steps to Converge',
                       fontsize=12)
    axes[1].set_title('Convergence Speed Comparison', fontsize=14, fontweight='bold')
    axes[1].grid(True, alpha=0.3, axis='x')

    plt.tight_layout()
    return fig

# Where each run first dips to loss 5.0, measured in tokens seen
fig = plot_convergence_analysis(runs_data, 5.0, 'tokens')
plt.show()


## 7. Side-by-Side Dashboard Comparison

Create a dashboard with one row per run showing its loss, perplexity, learning-rate schedule, and a summary of final metrics.


In [None]:
def create_comparison_dashboard(runs_data, tokens_or_steps='tokens'):
    """
    Draw one dashboard row per run: loss, perplexity, LR and a text summary.

    Args:
        runs_data: List of DataFrames, one per run, each with 'run_name',
            'loss', 'lr' and either 'tokens_seen' or 'step';
            'tokens_per_sec' is optional and only feeds the summary text.
        tokens_or_steps: 'tokens' for a tokens-seen x-axis (tick labels in
            millions, without suffix); any other value uses training steps.

    Returns:
        The matplotlib Figure.
    """
    row_count = len(runs_data)
    by_tokens = tokens_or_steps == 'tokens'
    horizontal = 'Tokens Seen (M)' if by_tokens else 'Steps'

    fig = plt.figure(figsize=(20, 5 * row_count))
    grid = fig.add_gridspec(row_count, 4, hspace=0.4, wspace=0.3)
    palette = plt.cm.tab10(np.linspace(0, 1, row_count))

    for row, (frame, color) in enumerate(zip(runs_data, palette)):
        name = frame['run_name'].iloc[0]
        xs = frame['tokens_seen'] if by_tokens else frame['step']
        loss_series = frame['loss']

        # (y data, log y-axis?, y label, panel title, grid line selection)
        curve_panels = (
            (loss_series, False, 'Loss', 'Loss', 'major'),
            (np.exp(loss_series), True, 'PPL', 'Perplexity', 'both'),
            (frame['lr'], True, 'LR', 'Learning Rate', 'both'),
        )
        for col, (ys, log_y, ylab, title, which) in enumerate(curve_panels):
            panel = fig.add_subplot(grid[row, col])
            draw = panel.semilogy if log_y else panel.plot
            draw(xs, ys, color=color, linewidth=2)
            panel.set_xlabel(horizontal, fontsize=10)
            panel.set_ylabel(ylab, fontsize=10)
            panel.set_title(f'{name}\n{title}', fontsize=11, fontweight='bold')
            panel.grid(True, alpha=0.3, which=which)
            if by_tokens:
                panel.xaxis.set_major_formatter(plt.FuncFormatter(lambda v, _pos: f'{v/1e6:.1f}'))

        # Final-metrics text panel in the last column
        text_panel = fig.add_subplot(grid[row, 3])
        text_panel.axis('off')

        last_loss = loss_series.iloc[-1]
        summary_lines = [
            f"Final Loss: {last_loss:.4f}",
            f"Final PPL: {np.exp(last_loss):.2f}",
            f"Peak LR: {frame['lr'].max():.2e}",
            f"Final LR: {frame['lr'].iloc[-1]:.2e}",
        ]
        if 'tokens_per_sec' in frame.columns:
            summary_lines.append(f"Avg Speed: {frame['tokens_per_sec'].mean():,.0f} tok/s")

        text_panel.text(0.1, 0.5, "\n".join(summary_lines), fontsize=10, verticalalignment='center',
                        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))
        text_panel.set_title(f'{name}\nSummary', fontsize=11, fontweight='bold')

    plt.suptitle('Multi-Run Comparison Dashboard', fontsize=16, fontweight='bold', y=0.995)
    return fig

# Per-run dashboard rows (loss, PPL, LR, summary), x = tokens seen
fig = create_comparison_dashboard(runs_data, 'tokens')
plt.show()
