# GPT-2 FineWeb-Edu Training Visualizations

This notebook provides comprehensive visualizations for training metrics including:
- Training and validation loss curves
- Perplexity over time
- Learning rate schedule
- Token processing statistics
- Training progress tracking

## Data Sources
- WandB API (if available)
- Local checkpoint files (not yet implemented in this notebook)
- Synthetic data for demonstration


In [None]:
# Install required packages
# !pip install wandb matplotlib seaborn numpy pandas torch transformers

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import wandb
import os
from pathlib import Path
from datetime import datetime

# Global plot styling applied to every figure produced by this notebook
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (12, 6), 'font.size': 10})

print("Libraries imported successfully!")


## 1. Load Training Data from WandB

Load training metrics through the WandB API. If the API call fails, or if `use_api=False` (the default in the cell below), synthetic data mimicking the expected training behavior is generated instead.


In [None]:
def _load_from_wandb(project_name, run_id):
    """Fetch a run's history via the WandB API and split it into train/val frames.

    Returns:
        (train_data, val_data, run). val_data is None when no 'val_loss' values exist.

    Raises:
        ValueError: if the project has no runs. WandB/network errors propagate as-is.
    """
    api = wandb.Api()
    if run_id:
        run = api.run(f"{project_name}/{run_id}")
    else:
        runs = api.runs(project_name)
        if len(runs) == 0:
            raise ValueError(f"No runs found in project {project_name}")
        # Intended to be the latest run (assumes WandB lists runs newest-first -- TODO confirm).
        run = runs[0]

    history = run.history()

    # NOTE(review): assumes the training script logs 'loss' and 'val_loss' keys. In WandB
    # history, a row holds NaN for any metric not logged at that step, so split on non-null.
    val_data = None
    if 'val_loss' in history.columns:
        val_rows = history[history['val_loss'].notna()].reset_index(drop=True)
        val_data = val_rows if len(val_rows) > 0 else None

    train_data = history
    if 'loss' in history.columns:
        train_data = history[history['loss'].notna()].reset_index(drop=True)

    return train_data, val_data, run


def _generate_synthetic_data(seed=None):
    """Build synthetic train/val metrics that mimic the expected training behavior.

    Target final stats: train loss 4.19 (PPL 66), val loss 6.16 (PPL 478) at 200M tokens.

    Args:
        seed: If given, noise is drawn from a dedicated seeded generator. If None, the
            global ``np.random`` state is used, as before.

    Returns:
        (train_data, val_data, None)
    """
    rng = np.random if seed is None else np.random.default_rng(seed)

    max_tokens = 200_000_000
    log_interval = 20
    tokens_per_step = 16 * 256  # batch_size * context_length

    steps = np.arange(0, max_tokens // tokens_per_step, log_interval)
    tokens_seen = steps * tokens_per_step

    # Training loss: exponential decay from ~10.8 towards 4.19, plus noise.
    train_decay = np.exp(-tokens_seen / 50_000_000)
    train_loss = 10.8 * train_decay + 4.19 * (1 - train_decay)
    train_loss += rng.normal(0, 0.1, len(train_loss))

    # Validation loss is logged ~10x less often. It decays from ~10 towards 6.16.
    # val_steps are indices into `steps`, hence the extra log_interval factor.
    val_steps = np.arange(0, len(steps), max(1, len(steps) // 10))
    val_tokens = val_steps * tokens_per_step * log_interval
    val_decay = np.exp(-val_tokens / 80_000_000)
    val_loss = 10.0 * val_decay + 6.16 * (1 - val_decay)
    val_loss += rng.normal(0, 0.2, len(val_loss))

    # Learning rate: linear warmup, then cosine decay to final_lr. Computed in float.
    # The previous np.zeros_like(tokens_seen) inherited the integer dtype and
    # truncated every LR value to 0.
    warmup_tokens = 2_000_000
    peak_lr = 3e-4
    final_lr = 6e-5
    warmup_lr = peak_lr * (tokens_seen / warmup_tokens)
    progress = np.minimum(1.0, (tokens_seen - warmup_tokens) / (max_tokens - warmup_tokens))
    cosine = 0.5 * (1.0 + np.cos(np.pi * progress))
    decay_lr = final_lr + (peak_lr - final_lr) * cosine
    lr = np.where(tokens_seen <= warmup_tokens, warmup_lr, decay_lr)

    # Token processing speed: ramps from ~5000 towards ~7000 tok/s, plus noise.
    tokens_per_sec = 5000 + 2000 * (1 - np.exp(-tokens_seen / 20_000_000))
    tokens_per_sec += rng.normal(0, 200, len(tokens_per_sec))

    data = pd.DataFrame({
        'step': steps,
        'tokens_seen': tokens_seen,
        'loss': train_loss,
        'lr': lr,
        'tokens_per_sec': tokens_per_sec,
    })

    val_data = pd.DataFrame({
        'step': val_steps * log_interval,
        'tokens_seen': val_tokens,
        'val_loss': val_loss,
    })

    return data, val_data, None


def load_wandb_data(project_name="gpt-fineweb-demo", run_id=None, use_api=True, seed=None):
    """
    Load training data from WandB, falling back to synthetic data.

    Args:
        project_name: WandB project path passed to ``wandb.Api().runs`` / ``.run``.
        run_id: Specific run to load. If None, the first run returned by the API is used.
        use_api: If False, skip WandB and generate synthetic data directly.
        seed: Optional seed for the noise in synthetic data (reproducible plots).

    Returns:
        A 3-tuple ``(train_data, val_data, run)`` on every path:
        - train_data: DataFrame of training metrics.
        - val_data: DataFrame of validation metrics, or None if unavailable.
        - run: the WandB run object, or None for synthetic data.
    """
    if use_api:
        try:
            return _load_from_wandb(project_name, run_id)
        except Exception as e:
            # Deliberate best-effort: any API/auth/network failure falls back to synthetic data.
            print(f"WandB API error: {e}")
            print("Generating synthetic data instead...")
    return _generate_synthetic_data(seed)

# Load data: synthetic by default; pass use_api=True to pull a real run from WandB
train_data, val_data, wandb_run = load_wandb_data(use_api=False)

print(f"Loaded {train_data.shape[0]} training data points")
if val_data is not None:
    print(f"Loaded {val_data.shape[0]} validation data points")
print("\nFirst few rows:")
print(train_data.head(5))


## 2. Training and Validation Loss Curves


In [None]:
def plot_loss_curves(train_data, val_data=None, tokens_or_steps='tokens'):
    """
    Draw training (and optionally validation) loss on side-by-side linear and log axes.

    Args:
        train_data: DataFrame with a 'loss' column and the x column
            ('tokens_seen' or 'step').
        val_data: Optional DataFrame with a 'val_loss' column and the same x column.
        tokens_or_steps: 'tokens' plots against tokens seen, with ticks shown in
            millions. Any other value plots against training steps.

    Returns:
        The matplotlib Figure containing both panels.
    """
    fig, (linear_ax, log_ax) = plt.subplots(1, 2, figsize=(16, 6))

    by_tokens = tokens_or_steps == 'tokens'
    x_column = 'tokens_seen' if by_tokens else 'step'
    axis_label = 'Tokens Seen' if by_tokens else 'Training Steps'
    train_x = train_data[x_column]
    val_x = None if val_data is None else val_data[x_column]

    def as_millions(value, _pos):
        # Tick label such as '12.5M' for a raw token count.
        return f'{value/1e6:.1f}M'

    # (axis, draw function, y label, title, extra grid kwargs). The log panel
    # uses semilogy, which shows the loss decay more clearly.
    panels = (
        (linear_ax, linear_ax.plot, 'Loss', 'Training and Validation Loss', {}),
        (log_ax, log_ax.semilogy, 'Loss (Log Scale)',
         'Training and Validation Loss (Log Scale)', {'which': 'both'}),
    )
    for ax, draw, y_label, title, grid_kwargs in panels:
        draw(train_x, train_data['loss'], label='Training Loss', color='#2E86AB',
             linewidth=2, alpha=0.8)
        if val_x is not None:
            draw(val_x, val_data['val_loss'], label='Validation Loss', color='#A23B72',
                 linewidth=2, alpha=0.8, marker='o', markersize=4)

        ax.set_xlabel(axis_label, fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3, **grid_kwargs)

        if by_tokens:
            ax.xaxis.set_major_formatter(plt.FuncFormatter(as_millions))

    plt.tight_layout()
    return fig

# Loss curves plotted against the number of tokens seen
fig = plot_loss_curves(train_data, val_data=val_data, tokens_or_steps='tokens')
plt.show()


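## 3. Perplexity

Perplexity is the exponential of the cross-entropy loss, $\mathrm{PPL} = e^{\mathcal{L}}$.


In [None]: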
def plot_perplexity(train_data, val_data=None, tokens_or_steps='tokens'):
    """
    Plot training/validation perplexity (PPL = exp(loss)) on linear and log axes.

    On the linear panel, the final perplexity of each series is marked with a
    dashed horizontal line and a text label.

    Args:
        train_data: DataFrame with a 'loss' column and the x column
            ('tokens_seen' or 'step').
        val_data: Optional DataFrame with a 'val_loss' column and the same x column.
        tokens_or_steps: 'tokens' plots against tokens seen, with ticks shown in
            millions. Any other value plots against training steps.

    Returns:
        The matplotlib Figure containing both panels.
    """
    fig, (lin_ax, log_ax) = plt.subplots(1, 2, figsize=(16, 6))

    by_tokens = tokens_or_steps == 'tokens'
    x_label = 'Tokens Seen' if by_tokens else 'Training Steps'
    train_x = train_data['tokens_seen' if by_tokens else 'step']

    has_val = val_data is not None
    train_ppl = np.exp(train_data['loss'])
    if has_val:
        val_ppl = np.exp(val_data['val_loss'])
        val_x = val_data['tokens_seen' if by_tokens else 'step']

    def millions(value, _pos):
        return f'{value/1e6:.1f}M'

    def finish(ax, y_label, title, **grid_kwargs):
        # Shared labelling, legend and grid for both panels.
        ax.set_xlabel(x_label, fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3, **grid_kwargs)

    # Linear-scale panel, plus final-value markers.
    lin_ax.plot(train_x, train_ppl, label='Train PPL', color='#06A77D', linewidth=2, alpha=0.8)
    if has_val:
        lin_ax.plot(val_x, val_ppl, label='Val PPL', color='#F18F01',
                    linewidth=2, alpha=0.8, marker='o', markersize=4)
    finish(lin_ax, 'Perplexity', 'Training and Validation Perplexity')

    last_train = train_ppl.iloc[-1]
    lin_ax.axhline(y=last_train, color='#06A77D', linestyle='--', alpha=0.5, linewidth=1)
    lin_ax.text(train_x.iloc[-1] * 0.7, last_train * 1.1,
                f'Final Train PPL: {last_train:.1f}', fontsize=10, color='#06A77D')
    if has_val:
        last_val = val_ppl.iloc[-1]
        lin_ax.axhline(y=last_val, color='#F18F01', linestyle='--', alpha=0.5, linewidth=1)
        lin_ax.text(val_x.iloc[-1] * 0.7, last_val * 1.1,
                    f'Final Val PPL: {last_val:.1f}', fontsize=10, color='#F18F01')

    if by_tokens:
        lin_ax.xaxis.set_major_formatter(plt.FuncFormatter(millions))

    # Log-scale panel.
    log_ax.semilogy(train_x, train_ppl, label='Train PPL', color='#06A77D', linewidth=2, alpha=0.8)
    if has_val:
        log_ax.semilogy(val_x, val_ppl, label='Val PPL', color='#F18F01',
                        linewidth=2, alpha=0.8, marker='o', markersize=4)
    finish(log_ax, 'Perplexity (Log Scale)',
           'Training and Validation Perplexity (Log Scale)', which='both')

    if by_tokens:
        log_ax.xaxis.set_major_formatter(plt.FuncFormatter(millions))

    plt.tight_layout()
    return fig

# Perplexity curves, followed by a plain-text summary of the final values
fig = plot_perplexity(train_data, val_data=val_data, tokens_or_steps='tokens')
plt.show()

print("\n=== Final Training Metrics ===")
print(f"Final Train Loss: {train_data['loss'].iloc[-1]:.4f}")
print(f"Final Train PPL: {np.exp(train_data['loss'].iloc[-1]):.2f}")
if val_data is not None:
    print(f"Final Val Loss: {val_data['val_loss'].iloc[-1]:.4f}")
    print(f"Final Val PPL: {np.exp(val_data['val_loss'].iloc[-1]):.2f}")


## 4. Learning Rate Schedule Visualization


In [None]:
def plot_learning_rate_schedule(train_data, tokens_or_steps='tokens'):
    """
    Plot the learning rate over training with a log-scaled y-axis.

    The peak LR is marked with a red dot and a dashed vertical line. The last
    logged LR is marked with a green dot. Both values appear in the legend.

    Args:
        train_data: DataFrame with an 'lr' column and the x column
            ('tokens_seen' or 'step').
        tokens_or_steps: 'tokens' plots against tokens seen, with ticks shown in
            millions. Any other value plots against training steps.

    Returns:
        The matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    by_tokens = tokens_or_steps == 'tokens'
    xs = train_data['tokens_seen'] if by_tokens else train_data['step']
    lrs = train_data['lr']

    ax.plot(xs, lrs, color='#C73E1D', linewidth=2, alpha=0.8)
    ax.set_xlabel('Tokens Seen' if by_tokens else 'Training Steps', fontsize=12)
    ax.set_ylabel('Learning Rate', fontsize=12)
    ax.set_title('Learning Rate Schedule (Warmup + Cosine Decay)', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')  # warmup and decay span orders of magnitude

    # Peak: label-based lookup so it works with any DataFrame index.
    top = lrs.idxmax()
    top_lr, top_x = lrs.loc[top], xs.loc[top]
    ax.plot(top_x, top_lr, 'ro', markersize=8, label=f'Peak LR: {top_lr:.2e}')
    ax.axvline(x=top_x, color='r', linestyle='--', alpha=0.5, linewidth=1)

    # Final: positional lookup of the last logged row.
    last_lr, last_x = lrs.iloc[-1], xs.iloc[-1]
    ax.plot(last_x, last_lr, 'go', markersize=8, label=f'Final LR: {last_lr:.2e}')

    ax.legend(fontsize=11)

    if by_tokens:
        ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda value, _pos: f'{value/1e6:.1f}M'))

    plt.tight_layout()
    return fig

# Learning-rate schedule plot, followed by peak/final statistics
fig = plot_learning_rate_schedule(train_data, tokens_or_steps='tokens')
plt.show()

print("\n=== Learning Rate Statistics ===")
print(f"Peak LR: {train_data['lr'].max():.2e}")
print(f"Final LR: {train_data['lr'].iloc[-1]:.2e}")
print(f"LR Ratio (final/peak): {train_data['lr'].iloc[-1] / train_data['lr'].max():.3f}")


## 5. Token Processing Speed


In [None]:
def plot_token_processing_speed(train_data, tokens_or_steps='tokens'):
    """
    Plot throughput (tokens/sec) over training and as a histogram.

    The left panel shows the time series with a dashed mean line. The right
    panel shows the distribution with the mean (red) and median (green) marked.

    Args:
        train_data: DataFrame with a 'tokens_per_sec' column and the x column
            ('tokens_seen' or 'step').
        tokens_or_steps: 'tokens' plots against tokens seen, with ticks shown in
            millions. Any other value plots against training steps.

    Returns:
        The matplotlib Figure containing both panels.
    """
    fig, (trend_ax, hist_ax) = plt.subplots(1, 2, figsize=(16, 6))

    by_tokens = tokens_or_steps == 'tokens'
    xs = train_data['tokens_seen'] if by_tokens else train_data['step']
    speed = train_data['tokens_per_sec']
    mean_speed = speed.mean()
    median_speed = speed.median()

    # Left: throughput over time with its average.
    trend_ax.plot(xs, speed, color='#6A4C93', linewidth=2, alpha=0.8)
    trend_ax.set_xlabel('Tokens Seen' if by_tokens else 'Training Steps', fontsize=12)
    trend_ax.set_ylabel('Tokens per Second', fontsize=12)
    trend_ax.set_title('Token Processing Speed', fontsize=14, fontweight='bold')
    trend_ax.grid(True, alpha=0.3)
    trend_ax.axhline(y=mean_speed, color='r', linestyle='--', alpha=0.7, linewidth=2,
                     label=f'Average: {mean_speed:,.0f} tok/s')
    trend_ax.legend(fontsize=11)
    if by_tokens:
        trend_ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda value, _pos: f'{value/1e6:.1f}M'))

    # Right: distribution with mean and median markers.
    hist_ax.hist(speed, bins=50, color='#6A4C93', alpha=0.7, edgecolor='black')
    hist_ax.axvline(x=mean_speed, color='r', linestyle='--', linewidth=2, label=f'Mean: {mean_speed:,.0f}')
    hist_ax.axvline(x=median_speed, color='g', linestyle='--', linewidth=2,
                    label=f'Median: {median_speed:,.0f}')
    hist_ax.set_xlabel('Tokens per Second', fontsize=12)
    hist_ax.set_ylabel('Frequency', fontsize=12)
    hist_ax.set_title('Token Processing Speed Distribution', fontsize=14, fontweight='bold')
    hist_ax.legend(fontsize=11)
    hist_ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    return fig

# Throughput plots and statistics; skipped when the run did not log tokens_per_sec
if 'tokens_per_sec' not in train_data.columns:
    print("Token processing speed data not available in the dataset.")
else:
    fig = plot_token_processing_speed(train_data, tokens_or_steps='tokens')
    plt.show()

    print("\n=== Token Processing Statistics ===")
    print(f"Average Speed: {train_data['tokens_per_sec'].mean():,.0f} tokens/sec")
    print(f"Median Speed: {train_data['tokens_per_sec'].median():,.0f} tokens/sec")
    print(f"Max Speed: {train_data['tokens_per_sec'].max():,.0f} tokens/sec")
    print(f"Min Speed: {train_data['tokens_per_sec'].min():,.0f} tokens/sec")


## 6. Comprehensive Training Dashboard

Combine all metrics into a single comprehensive dashboard.


In [None]:
def create_training_dashboard(train_data, val_data=None, tokens_or_steps='tokens'):
    """
    Create a single-figure training dashboard with all key metrics.

    Layout (3x3 grid):
        - top: loss curves (2 columns) and perplexity
        - middle: learning rate, processing speed (only if 'tokens_per_sec' exists),
          and loss vs learning rate coloured by training progress
        - bottom: a summary table listing only the metrics that are available

    Args:
        train_data: DataFrame with 'loss', 'lr', 'tokens_seen' columns ('step' too when
            plotting against steps). 'tokens_per_sec' is optional.
        val_data: Optional DataFrame with 'val_loss' and the same x column. An empty
            DataFrame is treated like None.
        tokens_or_steps: 'tokens' plots against tokens seen, with ticks in millions.
            Any other value plots against training steps.

    Returns:
        The matplotlib Figure.
    """
    use_tokens = tokens_or_steps == 'tokens'
    has_val = val_data is not None and len(val_data) > 0
    has_speed = 'tokens_per_sec' in train_data.columns

    def _format_millions(ax):
        # Token axes are labelled "(M)", so the ticks are bare millions. Each axis gets
        # its own formatter instance, because matplotlib formatters are bound to one axis.
        if use_tokens:
            ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda v, _pos: f'{v/1e6:.1f}'))

    fig = plt.figure(figsize=(18, 12))
    gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

    x = train_data['tokens_seen'] if use_tokens else train_data['step']
    x_label = 'Tokens Seen (M)' if use_tokens else 'Training Steps'
    if has_val:
        x_val = val_data['tokens_seen'] if use_tokens else val_data['step']

    # 1. Loss curves (top left, spans 2 columns)
    ax1 = fig.add_subplot(gs[0, :2])
    ax1.plot(x, train_data['loss'], label='Training Loss', color='#2E86AB', linewidth=2)
    if has_val:
        ax1.plot(x_val, val_data['val_loss'], label='Validation Loss', color='#A23B72',
                 linewidth=2, marker='o', markersize=4)
    ax1.set_xlabel(x_label, fontsize=11)
    ax1.set_ylabel('Loss', fontsize=11)
    ax1.set_title('Training and Validation Loss', fontsize=12, fontweight='bold')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    _format_millions(ax1)

    # 2. Perplexity (top right)
    ax2 = fig.add_subplot(gs[0, 2])
    ax2.plot(x, np.exp(train_data['loss']), label='Train PPL', color='#06A77D', linewidth=2)
    if has_val:
        ax2.plot(x_val, np.exp(val_data['val_loss']), label='Val PPL', color='#F18F01',
                 linewidth=2, marker='o', markersize=4)
    ax2.set_xlabel(x_label, fontsize=11)
    ax2.set_ylabel('Perplexity', fontsize=11)
    ax2.set_title('Perplexity', fontsize=12, fontweight='bold')
    ax2.legend(fontsize=9)
    ax2.grid(True, alpha=0.3)
    _format_millions(ax2)

    # 3. Learning rate (middle left)
    ax3 = fig.add_subplot(gs[1, 0])
    ax3.plot(x, train_data['lr'], color='#C73E1D', linewidth=2)
    ax3.set_xlabel(x_label, fontsize=11)
    ax3.set_ylabel('Learning Rate', fontsize=11)
    ax3.set_title('Learning Rate Schedule', fontsize=12, fontweight='bold')
    ax3.set_yscale('log')
    ax3.grid(True, alpha=0.3)
    _format_millions(ax3)

    # 4. Token processing speed (middle center). Cell left empty if not logged.
    if has_speed:
        avg_speed = train_data['tokens_per_sec'].mean()
        ax4 = fig.add_subplot(gs[1, 1])
        ax4.plot(x, train_data['tokens_per_sec'], color='#6A4C93', linewidth=2)
        ax4.axhline(y=avg_speed, color='r', linestyle='--', label=f'Avg: {avg_speed:,.0f}')
        ax4.set_xlabel(x_label, fontsize=11)
        ax4.set_ylabel('Tokens/sec', fontsize=11)
        ax4.set_title('Processing Speed', fontsize=12, fontweight='bold')
        ax4.legend(fontsize=9)
        ax4.grid(True, alpha=0.3)
        _format_millions(ax4)

    # 5. Loss vs learning rate (middle right), coloured by training progress
    ax5 = fig.add_subplot(gs[1, 2])
    scatter = ax5.scatter(train_data['lr'], train_data['loss'],
                          c=x, cmap='viridis', alpha=0.6, s=20)
    ax5.set_xlabel('Learning Rate', fontsize=11)
    ax5.set_ylabel('Loss', fontsize=11)
    ax5.set_title('Loss vs Learning Rate', fontsize=12, fontweight='bold')
    ax5.set_xscale('log')
    ax5.grid(True, alpha=0.3)
    # The colorbar shows raw values, not millions, so drop the "(M)" suffix.
    fig.colorbar(scatter, ax=ax5, label=x_label.replace(' (M)', ''))

    # 6. Summary table (bottom, spans all columns). Only metrics that actually
    # exist are listed, instead of padding the table with "N/A | N/A" rows.
    ax6 = fig.add_subplot(gs[2, :])
    ax6.axis('off')

    final_train_loss = train_data['loss'].iloc[-1]
    summary_rows = [
        ('Final Train Loss', f"{final_train_loss:.4f}"),
        ('Final Train PPL', f"{np.exp(final_train_loss):.2f}"),
    ]
    if has_val:
        final_val_loss = val_data['val_loss'].iloc[-1]
        summary_rows += [
            ('Final Val Loss', f"{final_val_loss:.4f}"),
            ('Final Val PPL', f"{np.exp(final_val_loss):.2f}"),
        ]
    summary_rows += [
        ('Peak Learning Rate', f"{train_data['lr'].max():.2e}"),
        ('Final Learning Rate', f"{train_data['lr'].iloc[-1]:.2e}"),
    ]
    if has_speed:
        summary_rows.append(
            ('Avg Processing Speed', f"{train_data['tokens_per_sec'].mean():,.0f} tok/s"))
    summary_rows.append(('Total Tokens Processed', f"{train_data['tokens_seen'].iloc[-1]:,}"))

    summary_df = pd.DataFrame(summary_rows, columns=['Metric', 'Value'])
    table = ax6.table(cellText=summary_df.values, colLabels=list(summary_df.columns),
                      cellLoc='left', loc='center', bbox=[0, 0, 1, 1])
    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1, 2)
    ax6.set_title('Training Summary Statistics', fontsize=14, fontweight='bold', pad=20)

    # Title the figure being built, not whatever pyplot considers "current".
    fig.suptitle('GPT-2 FineWeb-Edu Training Dashboard', fontsize=16, fontweight='bold', y=0.995)
    return fig

# Render every metric in a single dashboard figure
fig = create_training_dashboard(train_data, val_data=val_data, tokens_or_steps='tokens')
plt.show()
