# GPT-2 Research Paper | Part IV

## Training & Scaling: The Hidden Engineering Story

---

**Paper:** [Language Models are Unsupervised Multitask Learners](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)

**Authors:** Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever (OpenAI, 2019)

---

## The Story Nobody Tells

Most GPT-2 discussions focus on what it can do. This notebook focuses on **how it was built**.

Training GPT-2 XL (1.5B parameters) in 2019 was a **massive engineering challenge**. Several of its choices (curated web data, byte-level BPE, scaling up a decoder-only Transformer) carried over into later LLMs.

This notebook covers:

1. **WebText Dataset Creation** - The exact pipeline that made GPT-2 possible
2. **Byte-Level BPE Tokenization** - A technical deep dive
3. **Training Infrastructure** - Hardware, distributed training, optimizations
4. **Training Dynamics** - Loss curves, learning rates, batch sizes
5. **Scaling Laws** - The discovery that changed AI forever
6. **Compute Analysis** - FLOPs, GPU-hours, costs
7. **Reproducibility Guide** - How to train your own GPT-2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, Rectangle, Circle, FancyArrowPatch, Wedge, Arc
from matplotlib.patches import ConnectionPatch, Polygon, Ellipse, PathPatch
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.sankey import Sankey
import numpy as np
import math
from typing import List, Dict, Tuple
from dataclasses import dataclass

plt.style.use('default')
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['font.size'] = 10

np.random.seed(42)
torch.manual_seed(42)

print("Ready to explore GPT-2 training!")

---

## 1. WebText: Building the Dataset

### 1.1 The Problem with Existing Datasets

Before GPT-2, language models were trained on:

| Dataset | Size | Problem |
|---------|------|--------|
| BooksCorpus | ~1GB | Small, limited domains |
| Wikipedia | ~12GB | Encyclopedic style only |
| Common Crawl | ~300GB | Extremely noisy |

### 1.2 The WebText Innovation

From the paper:

> *"We created a new web scrape which emphasizes document quality. To do this, we only scraped web pages which have been curated/filtered by humans."*

The key insight: **Use Reddit as a quality filter.**

> *"We scraped all outbound links from Reddit, a social media platform, which received at least 3 karma. This can be thought of as a heuristic indicator for whether other users found the link interesting, educational, or just funny."*

### 1.3 The Complete Pipeline

```
Reddit API → Extract links with 3+ karma → Download pages →
Extract text → Clean/Deduplicate → Remove Wikipedia → WebText
```
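A minimal sketch of the link-filtering stage, assuming a hypothetical list of `(url, karma)` records. The 3-karma threshold and the Wikipedia exclusion come from the paper; the record format, the domain check, and exact-URL de-duplication are illustrative stand-ins. The real pipeline also downloaded each page and extracted text with the Dragnet and Newspaper content extractors, which is not reproduced here.

```python
from urllib.parse import urlparse

KARMA_THRESHOLD = 3  # from the paper: keep links with at least 3 karma


def filter_links(submissions):
    """Keep unique outbound links with >= 3 karma, excluding Wikipedia domains.

    `submissions` is a hypothetical list of (url, karma) pairs.
    """
    seen = set()
    kept = []
    for url, karma in submissions:
        if karma < KARMA_THRESHOLD:
            continue  # fails the quality heuristic
        host = urlparse(url).netloc.lower()
        # Illustrative stand-in for the paper's Wikipedia removal: drop any *.wikipedia.org host
        if host == "wikipedia.org" or host.endswith(".wikipedia.org"):
            continue
        if url in seen:
            continue  # exact-URL de-duplication (the paper also de-duplicated documents)
        seen.add(url)
        kept.append(url)
    return kept


submissions = [
    ("https://example.com/article", 12),
    ("https://example.com/article", 40),          # duplicate link
    ("https://blog.example.org/post", 2),         # below the karma threshold
    ("https://en.wikipedia.org/wiki/GPT-2", 99),  # excluded domain
    ("https://news.example.net/story", 3),
]
print(filter_links(submissions))
```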

### 1.4 WebText Statistics

| Metric | Value |
|--------|-------|
| Total Documents | ~8 million |
| Total Size | ~40 GB |
| Token Count | ~10 billion BPE tokens |
| Avg Doc Size | ~5 KB of text (40 GB ÷ 8M docs) |
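As a sanity check on these rounded figures, the per-document and per-token averages follow from simple division. This sketch treats 1 GB as 10⁹ bytes and uses only the approximate values from the table:

```python
# Rough averages derived from the rounded WebText figures above
num_docs = 8_000_000     # ~8 million documents
total_bytes = 40e9       # ~40 GB of text, taking 1 GB = 10**9 bytes
total_tokens = 10e9      # ~10 billion BPE tokens

bytes_per_doc = total_bytes / num_docs
tokens_per_doc = total_tokens / num_docs
bytes_per_token = total_bytes / total_tokens

print(f"~{bytes_per_doc:,.0f} bytes per document")
print(f"~{tokens_per_doc:,.0f} tokens per document")
print(f"~{bytes_per_token:.1f} bytes per token")
```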

### 1.5 Why Wikipedia Was Excluded

From the paper:

> *"We removed all Wikipedia documents from WebText since it is a common data source for other datasets and could complicate analysis due to overlapping training data with test evaluation tasks."*

This was for **fair evaluation**, not because Wikipedia is bad data.

In [None]:
def visualize_webtext_pipeline():
    """
    Comprehensive visualization of the WebText creation pipeline.
    """
    fig = plt.figure(figsize=(20, 14))
    
    gs = gridspec.GridSpec(3, 2, height_ratios=[1.5, 1, 1], hspace=0.35, wspace=0.25)
    
    # === TOP: Complete Pipeline ===
    ax1 = fig.add_subplot(gs[0, :])
    ax1.set_xlim(0, 24)
    ax1.set_ylim(0, 10)
    ax1.axis('off')
    ax1.set_title('WebText Creation Pipeline: From Reddit to Training Data', 
                  fontsize=14, fontweight='bold', pad=15)
    
    # Pipeline stages
    stages = [
        (2, 'Reddit\nAPI', '~45M posts\nwith links', '#FF4500', '100%'),
        (6, 'Karma\nFilter', 'Keep only\n3+ karma', '#f39c12', '~60%'),
        (10, 'URL\nExtraction', 'Outbound\nlinks only', '#3498db', '~45%'),
        (14, 'Web\nScraping', 'Download\npages', '#27ae60', '~35%'),
        (18, 'Cleaning', 'Dedup, filter\nboilerplate', '#9b59b6', '~20%'),
        (22, 'WebText', '8M docs\n40GB', '#e74c3c', 'Final'),
    ]
    
    for i, (x, name, desc, color, pct) in enumerate(stages):
        # Main box
        rect = FancyBboxPatch((x-1.5, 4), 3, 4, boxstyle="round,pad=0.05",
                              facecolor=color, edgecolor='black', linewidth=2)
        ax1.add_patch(rect)
        ax1.text(x, 7, name, ha='center', va='center', 
                fontsize=11, fontweight='bold', color='white')
        ax1.text(x, 5, desc, ha='center', va='center', 
                fontsize=9, color='white')
        
        # Percentage below
        ax1.text(x, 3, pct, ha='center', fontsize=10, fontweight='bold', color=color)
        
        # Arrow to next
        if i < len(stages) - 1:
            ax1.annotate('', xy=(x+2, 6), xytext=(x+1.6, 6),
                        arrowprops=dict(arrowstyle='->', color='black', lw=2))
    
    # Funnel visualization
    ax1.text(12, 1.5, 'Data Funnel: most content filtered out (stage percentages are illustrative)', 
             ha='center', fontsize=11, color='gray', style='italic')
    
    # === MIDDLE LEFT: Karma Distribution ===
    ax2 = fig.add_subplot(gs[1, 0])
    
    # Simulated karma distribution (power law)
    karma_values = np.concatenate([
        np.random.exponential(2, 30000),  # Most posts low karma
        np.random.exponential(10, 5000) + 10,  # Some medium
        np.random.exponential(50, 500) + 100,  # Few high
    ])
    karma_values = np.clip(karma_values, 0, 500)
    
    ax2.hist(karma_values, bins=50, color='#3498db', edgecolor='black', alpha=0.7)
    ax2.axvline(x=3, color='#e74c3c', linewidth=3, linestyle='--', label='Threshold (3 karma)')
    
    # Shade included region
    ax2.axvspan(3, 500, alpha=0.2, color='#27ae60')
    
    ax2.set_xlabel('Karma', fontsize=11)
    ax2.set_ylabel('Number of Posts', fontsize=11)
    ax2.set_title('Karma Distribution & Filter Threshold', fontsize=12, fontweight='bold')
    ax2.legend(fontsize=10)
    ax2.set_xlim(0, 100)
    
    # === MIDDLE RIGHT: Content Types ===
    ax3 = fig.add_subplot(gs[1, 1])
    
    categories = ['News/Articles', 'Educational', 'Entertainment', 
                  'Technical', 'Discussion', 'Creative']
    sizes = [28, 22, 18, 14, 10, 8]  # Illustrative percentages; the paper reports no content breakdown
    colors = ['#e74c3c', '#3498db', '#f39c12', '#27ae60', '#9b59b6', '#1abc9c']
    explode = (0.05, 0, 0, 0, 0, 0)
    
    wedges, texts, autotexts = ax3.pie(sizes, explode=explode, labels=categories, 
                                        colors=colors, autopct='%1.0f%%',
                                        shadow=True, startangle=90,
                                        textprops={'fontsize': 9})
    ax3.set_title('Illustrative Content Distribution in WebText', fontsize=12, fontweight='bold')
    
    # === BOTTOM: Dataset Comparison ===
    ax4 = fig.add_subplot(gs[2, :])
    ax4.set_xlim(0, 20)
    ax4.set_ylim(0, 8)
    ax4.axis('off')
    ax4.set_title('WebText vs Other Datasets', fontsize=13, fontweight='bold', pad=10)
    
    datasets = [
        ('BooksCorpus\n(GPT-1)', 1, '#95a5a6', '~7K books', 'Limited domains'),
        ('Wikipedia', 12, '#3498db', '6M articles', 'One style'),
        ('WebText', 40, '#e74c3c', '8M pages', 'Diverse, quality'),
        ('Common Crawl', 300, '#f39c12', 'Petabytes', 'Very noisy'),
    ]
    
    max_size = 300
    bar_positions = [2, 6, 10, 15]
    
    for (name, size, color, count, note), x in zip(datasets, bar_positions):
        # Log scale for height
        height = 5 * np.log10(size + 1) / np.log10(max_size + 1)
        
        rect = FancyBboxPatch((x-1, 2), 2.5, height, boxstyle="round,pad=0.02",
                              facecolor=color, edgecolor='black', linewidth=2)
        ax4.add_patch(rect)
        
        ax4.text(x+0.25, 1.5, name, ha='center', fontsize=9, fontweight='bold')
        ax4.text(x+0.25, 2 + height + 0.3, f'{size} GB', ha='center', fontsize=10, fontweight='bold')
        ax4.text(x+0.25, 2 + height/2, count, ha='center', va='center', fontsize=8, color='white')
        ax4.text(x+0.25, 0.7, note, ha='center', fontsize=8, color='gray')
    
    # Highlight WebText
    ax4.annotate('Best quality-size\ntradeoff!', xy=(10, 5.5), xytext=(12.5, 6.5),
                fontsize=10, fontweight='bold', color='#e74c3c',
                arrowprops=dict(arrowstyle='->', color='#e74c3c', lw=2))
    
    plt.tight_layout()
    plt.show()

visualize_webtext_pipeline()

---

## 2. Byte-Level BPE: The Tokenization Revolution

### 2.1 The Problem with Previous Tokenizers

GPT-1 applied BPE to Unicode characters after heavy preprocessing (ftfy cleaning, spaCy word tokenization, lowercasing):

```python
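# GPT-1 style (illustrative): lowercased input, character-level BPE vocabulary
text = "Hello café!"
tokens = ["hello", "caf", "[UNK]", "!"]  # a character missing from the vocabulary (here 'é') becomes [UNK]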
```

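Problems:
- Characters outside the base vocabulary become unknown tokens (`[UNK]`), losing information
- Language-specific preprocessing (lowercasing, word tokenization) is required
- Emoji, rare scripts, and unusual symbols in code or markup may not be representable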

### 2.2 The Byte-Level BPE Solution

From the paper:

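> *"In contrast, a byte-level version of BPE only requires a base vocabulary of size 256."*

By comparison, BPE over Unicode code points would need a base vocabulary of over 130,000 symbols before adding any multi-symbol tokens.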

The key insight: **Work at the byte level, not the character level.**
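A quick way to see why 256 base symbols suffice: any Python string encodes to UTF-8 bytes in the range 0-255, and decoding recovers it exactly. This sketch shows only the byte view; GPT-2's tokenizer additionally maps each byte to a printable character before applying BPE.

```python
# Every string, in any script, becomes a sequence of UTF-8 bytes (values 0-255),
# so a 256-symbol base vocabulary never needs an unknown token.
samples = ["Hello", "café", "日本語", "😀"]

for s in samples:
    raw = s.encode("utf-8")
    print(f"{s!r:>10} -> {len(s)} chars, {len(raw)} bytes: {list(raw)}")

# Round trip: decoding the bytes recovers the exact original text
assert all(s.encode("utf-8").decode("utf-8") == s for s in samples)
```

Note that non-ASCII characters take several bytes each ("é" takes 2, each of "日本語" takes 3, the emoji takes 4), which is why BPE merges are needed to keep sequences short.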

### 2.3 How It Works

1. **Start with 256 byte tokens** (0x00 to 0xFF)
2. **Count all adjacent token pairs** (initially byte pairs) in the training data
3. **Merge the most common pair** into a new token
4. **Repeat 50,000 times**
5. **Add 1 special token** (`<|endoftext|>`)
6. **Result: 50,257 tokens**

### 2.4 The Vocabulary Composition

| Component | Count | Description |
|-----------|-------|-------------|
| Base bytes | 256 | All possible byte values |
| BPE merges | 50,000 | Learned subword units |
| Special tokens | 1 | `<|endoftext|>` |
| **Total** | **50,257** | Complete vocabulary |

### 2.5 Preventing Suboptimal Merges

From the paper:

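> *"We observed BPE including many versions of common words like dog since they occur in many variations such as dog. dog! dog? ... we prevent BPE from merging across character categories for any byte sequence. We add an exception for spaces..."*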

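In practice this is enforced by splitting text into chunks by character category before BPE runs, so merges cannot cross:
- Letters and punctuation
- Letters and whitespace, except that a single leading space may join the following word (` dog`)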
- Numbers and letters
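GPT-2's released encoder implements these rules by splitting text with a regular expression before BPE runs, so merges happen only inside each chunk. The original pattern relies on the third-party `regex` module for the Unicode classes `\p{L}` (letters) and `\p{N}` (numbers). The sketch below is an approximation using only the standard-library `re`, with `[^\W\d_]` standing in for letters and `\d` for numbers, so it may differ from GPT-2 on some inputs.

```python
import re

# Stdlib approximation of GPT-2's pre-tokenization pattern:
#   [^\W\d_]        ~ \p{L}  (letters)
#   \d              ~ \p{N}  (numbers)
#   (?:[^\s\w]|_)   ~ [^\s\p{L}\p{N}]  (anything else that is not whitespace)
PAT = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?[^\W\d_]+| ?\d+| ?(?:[^\s\w]|_)+|\s+(?!\S)|\s+""")


def pretokenize(text: str) -> list:
    """Split text into category chunks; BPE merges are then applied within each chunk."""
    return PAT.findall(text)


chunks = pretokenize("The dog! dog? 123abc")
print(chunks)
```

The output keeps `dog` apart from `!` and `?`, attaches the leading space to ` dog`, and separates ` 123` from `abc`: the three behaviors listed above.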

In [None]:
def visualize_byte_level_bpe_detailed():
    """
    Detailed visualization of byte-level BPE tokenization.
    """
    fig = plt.figure(figsize=(20, 14))
    
    gs = gridspec.GridSpec(3, 2, height_ratios=[1.2, 1, 1], hspace=0.35, wspace=0.25)
    
    # === TOP LEFT: The BPE Algorithm ===
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.set_xlim(0, 12)
    ax1.set_ylim(0, 12)
    ax1.axis('off')
    ax1.set_title('BPE Merge Algorithm', fontsize=13, fontweight='bold')
    
    # Show merge iterations
    text = "the cat sat"
    iterations = [
        ('Start', list(text), 256),
        ('Merge 1: t+h', ['th', 'e', ' ', 'c', 'a', 't', ' ', 's', 'a', 't'], 257),
        ('Merge 2: th+e', ['the', ' ', 'c', 'a', 't', ' ', 's', 'a', 't'], 258),
        ('Merge 3: a+t', ['the', ' ', 'c', 'at', ' ', 's', 'at'], 259),
        ('Merge 4: c+at', ['the', ' ', 'cat', ' ', 's', 'at'], 260),
    ]
    
    for i, (stage, tokens, vocab_size) in enumerate(iterations):
        y = 10.5 - i * 2
        
        # Stage label
        ax1.text(0.2, y, stage, fontsize=9, va='center', fontweight='bold')
        
        # Tokens
        x = 3.5
        colors = plt.cm.Set3(np.linspace(0, 1, len(tokens)))
        for j, (tok, color) in enumerate(zip(tokens, colors)):
            width = max(len(tok) * 0.3, 0.4)
            rect = FancyBboxPatch((x, y-0.35), width, 0.7, boxstyle="round,pad=0.02",
                                  facecolor=color, edgecolor='black', linewidth=1)
            ax1.add_patch(rect)
            display_tok = tok if tok != ' ' else '␣'
            ax1.text(x + width/2, y, display_tok, ha='center', va='center', 
                    fontsize=8, fontweight='bold')
            x += width + 0.1
        
        # Vocab size
        ax1.text(11, y, f'V={vocab_size}', fontsize=9, va='center', color='gray')
    
    ax1.text(6, 0.3, '...continue for 50,000 merges', fontsize=10, ha='center', 
             style='italic', color='gray')
    
    # === TOP RIGHT: Byte vs Character Level ===
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.set_xlim(0, 12)
    ax2.set_ylim(0, 12)
    ax2.axis('off')
    ax2.set_title('Byte-Level vs Character-Level', fontsize=13, fontweight='bold')
    
    # Character level problem
    rect1 = FancyBboxPatch((0.5, 7), 11, 4, boxstyle="round,pad=0.05",
                           facecolor='#fadbd8', edgecolor='#e74c3c', linewidth=2)
    ax2.add_patch(rect1)
    ax2.text(6, 10.5, 'Character-Level (GPT-1 style)', fontsize=11, ha='center', 
             fontweight='bold', color='#c0392b')
    ax2.text(6, 9.5, 'Input: "Hello 日本語 😀"', fontsize=10, ha='center', family='monospace')
    ax2.text(6, 8.5, 'Output: [Hello] [日] [本] [語] [UNK]', fontsize=10, ha='center', 
             family='monospace')
    ax2.text(6, 7.5, '→ Emoji becomes [UNK]! Information lost.', fontsize=9, ha='center', 
             color='#c0392b')
    
    # Byte level solution
    rect2 = FancyBboxPatch((0.5, 2), 11, 4, boxstyle="round,pad=0.05",
                           facecolor='#d5f4e6', edgecolor='#27ae60', linewidth=2)
    ax2.add_patch(rect2)
    ax2.text(6, 5.5, 'Byte-Level (GPT-2)', fontsize=11, ha='center', 
             fontweight='bold', color='#1e8449')
    ax2.text(6, 4.5, 'Input: "Hello 日本語 😀"', fontsize=10, ha='center', family='monospace')
    ax2.text(6, 3.5, 'Output: [Hello] [Ġ] [æ][Ĺ][¥] ... [ðŁĺ][Ģ]', fontsize=10, ha='center', 
             family='monospace')
    ax2.text(6, 2.5, '→ Everything encoded! No information loss.', fontsize=9, ha='center', 
             color='#1e8449')
    
    ax2.text(6, 0.8, 'Key: 256 bytes can represent ANY text', fontsize=10, ha='center', 
             fontweight='bold', color='#2c3e50')
    
    # === MIDDLE: Vocabulary Composition ===
    ax3 = fig.add_subplot(gs[1, :])
    ax3.set_xlim(0, 20)
    ax3.set_ylim(0, 8)
    ax3.axis('off')
    ax3.set_title('GPT-2 Vocabulary Composition: 50,257 Tokens', fontsize=13, fontweight='bold', pad=10)
    
    # Visual breakdown (scaled bar)
    total_width = 16
    start_x = 2
    
    components = [
        ('Base Bytes (256)', 256, '#e74c3c', 'ASCII + extended'),
        ('BPE Merges (50,000)', 50000, '#3498db', 'Learned subwords'),
        ('<|endoftext|> (1)', 1, '#27ae60', 'Document separator'),
    ]
    
    # Draw as proportional blocks
    x = start_x
    total = 50257
    
    for name, count, color, desc in components:
        # Width proportional to count (log scale for visibility)
        width = max(total_width * np.log(count + 1) / np.log(total + 1), 0.8)
        if count == 1:
            width = 0.8  # Make special token visible
        
        rect = FancyBboxPatch((x, 3.5), width, 2.5, boxstyle="round,pad=0.02",
                              facecolor=color, edgecolor='black', linewidth=2)
        ax3.add_patch(rect)
        
        ax3.text(x + width/2, 5.5, name.split('(')[0].strip(), ha='center', 
                fontsize=9, fontweight='bold', color='white')
        ax3.text(x + width/2, 4.3, f'{count:,}', ha='center', fontsize=10, color='white')
        ax3.text(x + width/2, 2.8, desc, ha='center', fontsize=8, color='gray')
        
        x += width + 0.2
    
    # Formula
    ax3.text(10, 1, '256 + 50,000 + 1 = 50,257 total tokens', 
             ha='center', fontsize=12, fontweight='bold', family='monospace')
    
    # === BOTTOM: Merge Prevention Rules ===
    ax4 = fig.add_subplot(gs[2, :])
    ax4.set_xlim(0, 20)
    ax4.set_ylim(0, 6)
    ax4.axis('off')
    ax4.set_title('Merge Prevention Rules: Avoiding Suboptimal Tokenization', 
                  fontsize=13, fontweight='bold', pad=10)
    
    rules = [
        ('Letters + Punctuation', 'dog + ! → dog!', 'Keeps "dog" reusable', '#e74c3c'),
        ('Letters + Whitespace', 'the + ␣ → the␣', 'Preserves word boundaries', '#3498db'),
        ('Numbers + Letters', '123 + abc → 123abc', 'Keeps numbers separate', '#27ae60'),
        ('Across Categories', 'Various merges', 'Prevented by regex rules', '#f39c12'),
    ]
    
    for i, (rule, example, benefit, color) in enumerate(rules):
        x = 0.5 + i * 5
        
        rect = FancyBboxPatch((x, 1), 4.5, 4, boxstyle="round,pad=0.05",
                              facecolor=color, edgecolor='black', linewidth=2, alpha=0.85)
        ax4.add_patch(rect)
        
        ax4.text(x + 2.25, 4.3, 'BLOCKED', ha='center', fontsize=9, 
                fontweight='bold', color='white')
        ax4.text(x + 2.25, 3.4, rule, ha='center', fontsize=9, color='white')
        ax4.text(x + 2.25, 2.5, example, ha='center', fontsize=8, 
                color='white', family='monospace')
        ax4.text(x + 2.25, 1.5, benefit, ha='center', fontsize=7, color='white')
    
    plt.tight_layout()
    plt.show()

visualize_byte_level_bpe_detailed()

In [None]:
class SimpleBPE:
    """
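    Simplified BPE implementation for demonstration.
    
    Shows the core merge loop behind GPT-2 tokenization. It omits GPT-2's regex
    pre-tokenization (so merges here can cross word boundaries) and its
    byte-to-printable-character mapping.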
    """
    
    def __init__(self):
        self.vocab = {bytes([i]): i for i in range(256)}  # Base bytes
        self.merges = {}  # (token1, token2) -> new_token
        self.next_id = 256
    
    def get_pairs(self, tokens: List[bytes]) -> Dict[Tuple[bytes, bytes], int]:
        """Count all adjacent pairs."""
        pairs = {}
        for i in range(len(tokens) - 1):
            pair = (tokens[i], tokens[i + 1])
            pairs[pair] = pairs.get(pair, 0) + 1
        return pairs
    
    def merge_pair(self, tokens: List[bytes], pair: Tuple[bytes, bytes]) -> List[bytes]:
        """Merge all occurrences of a pair."""
        new_token = pair[0] + pair[1]
        new_tokens = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
                new_tokens.append(new_token)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        return new_tokens
    
    def train(self, text: str, num_merges: int = 100, verbose: bool = True) -> None:
        """Train BPE on text."""
        # Convert to bytes
        tokens = [bytes([b]) for b in text.encode('utf-8')]
        
        if verbose:
            print(f"Starting with {len(set(tokens))} unique byte tokens")
            print(f"Initial sequence length: {len(tokens)}")
        
        for i in range(num_merges):
            # Get pair counts
            pairs = self.get_pairs(tokens)
            if not pairs:
                break
            
            # Find most common pair
            best_pair = max(pairs, key=pairs.get)
            
            # Merge it
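            new_token = best_pair[0] + best_pair[1]
            # Two different pairs can concatenate to the same bytes; reuse the existing ID
            if new_token not in self.vocab:
                self.vocab[new_token] = self.next_id
                self.next_id += 1
            self.merges[best_pair] = new_token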
            
            # Apply merge
            tokens = self.merge_pair(tokens, best_pair)
            
            if verbose and (i + 1) % 20 == 0:
                print(f"Merge {i+1}: {repr(best_pair[0])} + {repr(best_pair[1])} -> "
                      f"{repr(new_token)} (count: {pairs[best_pair]})")
        
        if verbose:
            print(f"\nFinal vocab size: {len(self.vocab)}")
            print(f"Final sequence length: {len(tokens)}")
            print(f"Compression ratio: {len(text.encode('utf-8')) / len(tokens):.2f}x")
    
    def encode(self, text: str) -> List[int]:
        """Encode text to token IDs."""
        tokens = [bytes([b]) for b in text.encode('utf-8')]
        
        # Apply merges in order
        for pair, new_token in self.merges.items():
            tokens = self.merge_pair(tokens, pair)
        
        return [self.vocab[t] for t in tokens]


# Demonstrate
print("=" * 70)
print("BPE TRAINING DEMONSTRATION")
print("=" * 70)

sample_text = """The cat sat on the mat. The cat was happy. The mat was soft.
The happy cat sat on the soft mat. Cats like mats. Mats like cats."""

bpe = SimpleBPE()
bpe.train(sample_text, num_merges=50, verbose=True)

# Test encoding
test = "The cat sat"
encoded = bpe.encode(test)
print(f"\nEncoded '{test}': {encoded}")

---

## 3. Training Infrastructure

### 3.1 Hardware Requirements (2019)

Training GPT-2 XL required substantial compute, but the paper does not report hardware or training time. The figures below are estimates:

| Component | Specification (estimated) |
|-----------|---------------|
| GPUs | 8× V100 32GB |
| Memory | 32GB per GPU |
| Interconnect | NVLink |
| Training Time | ~1 week for XL |
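A back-of-the-envelope estimate of per-GPU model-state memory, assuming standard mixed-precision Adam bookkeeping with no sharding across GPUs: FP16 weights and gradients at 2 bytes each, plus an FP32 master copy and Adam's two moment buffers at 4 bytes each, for 16 bytes per parameter. GPT-2's actual precision and optimizer setup are not reported, so these byte counts are assumptions.

```python
def training_memory_gb(num_params: float,
                       weight_bytes: int = 2,      # FP16 weights (assumed)
                       grad_bytes: int = 2,        # FP16 gradients (assumed)
                       optimizer_bytes: int = 12,  # FP32 master copy + Adam m and v (assumed)
                       ) -> dict:
    """Estimate model-state memory in GB (1 GB = 1e9 bytes), excluding activations."""
    parts = {
        "weights": num_params * weight_bytes / 1e9,
        "gradients": num_params * grad_bytes / 1e9,
        "optimizer": num_params * optimizer_bytes / 1e9,
    }
    parts["total"] = sum(parts.values())
    return parts


for name, n in [("Small", 117e6), ("Medium", 345e6), ("Large", 774e6), ("XL", 1.5e9)]:
    est = training_memory_gb(n)
    print(f"{name:>6}: {est['total']:.1f} GB of model state (activations extra)")
```

Counting only Adam's two moments (8 bytes per parameter) gives 12 GB of optimizer state for XL; including the FP32 master copy raises it to 18 GB. Either way, model state alone occupies most of a 32 GB card before any activations are stored.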

### 3.2 Training Configuration

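The batch size, context length, and learning-rate procedure come from the paper; the remaining rows are GPT-1 settings or unknowns, since the GPT-2 paper does not report them:

| Hyperparameter | Value |
|----------------|-------|
| Batch Size | 512 sequences |
| Sequence Length | 1024 tokens |
| Tokens per Batch | 512 × 1024 = 524,288 |
| Learning Rate | Manually tuned per model for best perplexity on a 5% held-out WebText sample |
| Optimizer | Adam (used by GPT-1; not stated for GPT-2) |
| Weight Decay | 0.01 (GPT-1's setting; not stated for GPT-2) |
| Warmup Steps | Not reported |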
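The batch arithmetic is easy to check, and combining it with the ~10B-token estimate for WebText gives the number of optimizer steps in one pass over the data. The paper reports neither step counts nor epochs, so this is an illustration only:

```python
batch_size = 512        # sequences per step (from the paper)
context_length = 1024   # tokens per sequence (from the paper)
webtext_tokens = 10e9   # approximate WebText size from Section 1.4

tokens_per_step = batch_size * context_length
steps_per_epoch = webtext_tokens / tokens_per_step

print(f"Tokens per step: {tokens_per_step:,}")  # 524,288
print(f"Steps per pass over WebText: ~{steps_per_epoch:,.0f}")
```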

### 3.3 Memory Optimization Techniques

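Common techniques for fitting large models in GPU memory (the paper does not say which, if any, GPT-2 used):

1. **Gradient Checkpointing**: Recompute activations during the backward pass instead of storing them, trading compute for memory
2. **Mixed Precision (FP16)**: Store activations and working copies of weights in 16 bits, roughly halving their memory
3. **Gradient Accumulation**: Sum gradients over several micro-batches before each update to reach a large effective batch
4. **Model Parallelism**: Split the model's layers or tensors across GPUs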
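Gradient accumulation works because, with equal-sized micro-batches, the average of the micro-batch mean gradients equals the full-batch mean gradient. The sketch below checks this in plain Python on a hypothetical one-parameter linear model with squared-error loss; the data and batch sizes are made up for illustration.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean squared error for the one-parameter model y_hat = w * x."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.0, 13.8, 16.2]
w = 0.5
full_batch = len(xs)   # stands in for the 512-sequence batch
micro_batch = 2        # hypothetical amount that fits in memory at once
accum_steps = full_batch // micro_batch

# Average the micro-batch gradients, then apply a single optimizer update
accumulated = 0.0
for k in range(accum_steps):
    lo, hi = k * micro_batch, (k + 1) * micro_batch
    accumulated += grad_mse(w, xs[lo:hi], ys[lo:hi]) / accum_steps

full = grad_mse(w, xs, ys)
print(f"accumulated: {accumulated:.6f}  full batch: {full:.6f}")
```

In PyTorch the same idea is usually implemented by calling `backward()` on each micro-batch loss divided by the number of accumulation steps (gradients add up in each parameter's `.grad`) and calling `optimizer.step()` once per effective batch.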

### 3.4 The Batch Size Choice

From the paper:

> *"A larger batchsize of 512 is used."*

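Why 512? The paper does not say. It is 8× GPT-1's batch of 64 sequences, and because the context also doubled from 512 to 1,024 tokens, each step processes 16× as many tokens (524,288 vs. 32,768). Typical motivations for larger batches include:
- Lower-variance gradient estimates
- More stable training
- Better hardware utilization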

In [None]:
def visualize_training_infrastructure():
    """
    Visualize the training infrastructure for GPT-2.
    """
    fig = plt.figure(figsize=(20, 14))
    
    gs = gridspec.GridSpec(2, 2, hspace=0.3, wspace=0.25)
    
    # === TOP LEFT: Hardware Setup ===
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.set_xlim(0, 12)
    ax1.set_ylim(0, 12)
    ax1.axis('off')
    ax1.set_title('Estimated Training Hardware (2019)', fontsize=13, fontweight='bold')
    
    # GPU cluster visualization
    for i in range(8):
        row = i // 4
        col = i % 4
        x = 1.5 + col * 2.5
        y = 7 - row * 3
        
        rect = FancyBboxPatch((x, y), 2, 2.5, boxstyle="round,pad=0.03",
                              facecolor='#27ae60', edgecolor='black', linewidth=2)
        ax1.add_patch(rect)
        ax1.text(x + 1, y + 1.8, f'GPU {i}', ha='center', fontsize=9, 
                fontweight='bold', color='white')
        ax1.text(x + 1, y + 1.2, 'V100', ha='center', fontsize=8, color='white')
        ax1.text(x + 1, y + 0.6, '32GB', ha='center', fontsize=8, color='white')
    
    # NVLink connections
    ax1.text(6, 10.5, '8× NVIDIA V100 32GB', ha='center', fontsize=12, fontweight='bold')
    ax1.text(6, 3, 'Connected via NVLink', ha='center', fontsize=10, color='gray')
    
    # Total memory
    rect_total = FancyBboxPatch((2, 0.5), 8, 1.5, boxstyle="round,pad=0.03",
                                 facecolor='#3498db', edgecolor='black', linewidth=2)
    ax1.add_patch(rect_total)
    ax1.text(6, 1.25, 'Total GPU Memory: 256 GB', ha='center', 
             fontsize=11, fontweight='bold', color='white')
    
    # === TOP RIGHT: Training Configuration ===
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.set_xlim(0, 12)
    ax2.set_ylim(0, 12)
    ax2.axis('off')
    ax2.set_title('Training Configuration', fontsize=13, fontweight='bold')
    
    configs = [
        ('Batch Size', '512 sequences', '#e74c3c'),
        ('Sequence Length', '1024 tokens', '#3498db'),
        ('Tokens/Batch', '524,288', '#27ae60'),
        ('Optimizer', 'Adam', '#f39c12'),
        ('Weight Decay', '0.01', '#9b59b6'),
        ('Precision', 'Mixed FP16/FP32', '#1abc9c'),
    ]
    
    for i, (param, value, color) in enumerate(configs):
        y = 10.5 - i * 1.7
        
        rect = FancyBboxPatch((0.5, y-0.5), 5, 1.2, boxstyle="round,pad=0.02",
                              facecolor=color, edgecolor='black', linewidth=1.5)
        ax2.add_patch(rect)
        ax2.text(3, y, param, ha='center', va='center', 
                fontsize=10, fontweight='bold', color='white')
        
        rect2 = FancyBboxPatch((6, y-0.5), 5, 1.2, boxstyle="round,pad=0.02",
                               facecolor='#f5f5f5', edgecolor=color, linewidth=2)
        ax2.add_patch(rect2)
        ax2.text(8.5, y, value, ha='center', va='center', fontsize=10)
    
    # === BOTTOM LEFT: Memory Breakdown ===
    ax3 = fig.add_subplot(gs[1, 0])
    
    # Memory usage breakdown for GPT-2 XL
    categories = ['Parameters\n(FP16)', 'Gradients\n(FP16)', 'Optimizer\nStates', 'Activations']
    sizes = [3.0, 3.0, 12.0, 8.0]  # GB, estimated
    colors = ['#e74c3c', '#3498db', '#f39c12', '#27ae60']
    
    bars = ax3.bar(categories, sizes, color=colors, edgecolor='black', linewidth=2)
    
    ax3.set_ylabel('Memory (GB)', fontsize=11)
    ax3.set_title('Memory Breakdown: GPT-2 XL Training (per GPU)', fontsize=12, fontweight='bold')
    ax3.set_ylim(0, 35)  # leave room for the 32 GB reference line and total label
    
    for bar, size in zip(bars, sizes):
        ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.3,
                f'{size:.1f} GB', ha='center', fontsize=10, fontweight='bold')
    
    ax3.axhline(y=32, color='gray', linestyle='--', linewidth=2)
    ax3.text(3.5, 32.5, 'V100 32GB limit', fontsize=9, color='gray')
    
    total = sum(sizes)
    ax3.text(1.5, 28, f'Total: ~{total:.0f} GB', fontsize=11, fontweight='bold')
    
    # === BOTTOM RIGHT: Training Timeline ===
    ax4 = fig.add_subplot(gs[1, 1])
    ax4.set_xlim(0, 12)
    ax4.set_ylim(0, 10)
    ax4.axis('off')
    ax4.set_title('Estimated Training Time by Model Size', fontsize=13, fontweight='bold')
    
    models = [
        ('GPT-2 Small', '117M', '~1 day', '#3498db'),
        ('GPT-2 Medium', '345M', '~2 days', '#27ae60'),
        ('GPT-2 Large', '774M', '~4 days', '#f39c12'),
        ('GPT-2 XL', '1.5B', '~7 days', '#e74c3c'),
    ]
    
    for i, (name, params, time, color) in enumerate(models):
        y = 8 - i * 2
        
        # Model box
        rect = FancyBboxPatch((0.5, y-0.5), 4, 1.3, boxstyle="round,pad=0.02",
                              facecolor=color, edgecolor='black', linewidth=2)
        ax4.add_patch(rect)
        ax4.text(2.5, y+0.2, name, ha='center', fontsize=10, fontweight='bold', color='white')
        ax4.text(2.5, y-0.3, params, ha='center', fontsize=9, color='white')
        
        # Time bar
        days = float(time.split('~')[1].split()[0])
        bar_width = days * 0.8
        rect2 = FancyBboxPatch((5, y-0.35), bar_width, 0.9, boxstyle="round,pad=0.02",
                               facecolor=color, edgecolor='black', linewidth=1, alpha=0.6)
        ax4.add_patch(rect2)
        ax4.text(5 + bar_width + 0.3, y, time, va='center', fontsize=10, fontweight='bold')
    
    ax4.text(6, 0.5, '(On 8× V100 GPUs with optimizations)', ha='center', 
             fontsize=10, color='gray', style='italic')
    
    plt.tight_layout()
    plt.show()

visualize_training_infrastructure()

---

## 4. Training Dynamics

### 4.1 Learning Rate Schedule

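The GPT-2 paper does not describe its learning-rate schedule. GPT-1 used **linear warmup followed by cosine annealing**, and this notebook assumes GPT-2 did the same:

1. **Warmup phase**: Linear increase from 0 to peak LR (over 2,000 updates in GPT-1)
2. **Decay phase**: Cosine decay to ~0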
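A minimal sketch of that schedule as a function of the step number. The peak rate of 2.5e-4 and the 2,000 warmup steps are GPT-1's published settings, and the 100,000 total steps match the plot below; none of these values are reported for GPT-2.

```python
import math


def lr_at(step: int, peak_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup from 0 to peak_lr, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at 0 after the schedule ends
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Illustrative settings borrowed from GPT-1
for s in [0, 1000, 2000, 51000, 100000]:
    print(s, f"{lr_at(s, 2.5e-4, 2000, 100000):.2e}")
```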

### 4.2 Learning Rates by Model Size

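Larger models typically use smaller learning rates. The paper reports only that each model's learning rate was manually tuned for the best perplexity on a 5% held-out sample of WebText; the values below are illustrative, not published: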

| Model | Peak Learning Rate |
|-------|-------------------|
| Small (117M) | 2.5e-4 |
| Medium (345M) | 2.5e-4 |
| Large (774M) | 1.5e-4 |
| XL (1.5B) | 1.0e-4 |

### 4.3 Loss Curves

What we expect during training:

1. **Rapid initial drop**: Model learns basic patterns
2. **Gradual improvement**: Learns finer patterns
3. **Plateau**: Approaches dataset entropy
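The rapid initial drop has a natural starting point: a model that assigns equal probability to all 50,257 tokens has a cross-entropy of ln(50,257) nats per token, the random baseline drawn in the simulated loss plot below. A short sketch of that baseline and of the loss-to-perplexity conversion (perplexity = e^loss for losses measured in nats):

```python
import math

vocab_size = 50257
# Cross-entropy of a uniform guess over the vocabulary is ln(V) nats
uniform_loss = math.log(vocab_size)
print(f"Uniform-guess loss: {uniform_loss:.2f} nats/token")
print(f"Equivalent perplexity: {math.exp(uniform_loss):,.0f}")  # equals the vocab size

# Converting example losses (in nats) to perplexity
for loss in [6.0, 4.0, 3.0]:
    print(f"loss {loss:.1f} nats -> perplexity {math.exp(loss):.1f}")
```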

### 4.4 Overfitting Considerations

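With ~10B tokens in WebText and 1.5B parameters:
- Ratio: ~6.7 dataset tokens per parameter (the total number of tokens processed during training is not reported)
- The paper notes that all model sizes still underfit WebText, so overfitting was not the binding constraint
- By later standards GPT-2 XL is **undertrained**: Chinchilla (Hoffmann et al., 2022) suggests roughly 20 tokens per parameter for compute-optimal training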
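The ratio is simple division, sketched below under the assumption of a single pass over ~10B WebText tokens. The 20:1 figure is the commonly cited Chinchilla rule of thumb, not a GPT-2 number. By this single-pass measure only Large (~12.9) and XL (~6.7) fall below 20:1, while Small and Medium sit above it.

```python
# Tokens-per-parameter ratios, assuming one pass over ~10B WebText tokens
dataset_tokens = 10e9
models = {"Small": 117e6, "Medium": 345e6, "Large": 774e6, "XL": 1.5e9}
CHINCHILLA_RATIO = 20  # rule of thumb from Hoffmann et al. (2022)

for name, n_params in models.items():
    ratio = dataset_tokens / n_params
    optimal_tokens = CHINCHILLA_RATIO * n_params
    print(f"{name:>6}: {ratio:5.1f} tokens/param; "
          f"~{optimal_tokens / 1e9:.1f}B tokens for a 20:1 ratio")
```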

In [None]:
def visualize_training_dynamics():
    """
    Visualize training dynamics including learning rate and loss curves.
    """
    fig = plt.figure(figsize=(18, 12))
    
    gs = gridspec.GridSpec(2, 2, hspace=0.3, wspace=0.25)
    
    # === TOP LEFT: Learning Rate Schedule ===
    ax1 = fig.add_subplot(gs[0, 0])
    
    # Cosine schedule with warmup
    total_steps = 100000
    warmup_steps = 2000
    peak_lr = 2.5e-4
    
    steps = np.arange(total_steps)
    lr = np.zeros(total_steps)
    
    # Warmup
    lr[:warmup_steps] = peak_lr * steps[:warmup_steps] / warmup_steps
    
    # Cosine decay
    decay_steps = total_steps - warmup_steps
    lr[warmup_steps:] = peak_lr * 0.5 * (1 + np.cos(np.pi * np.arange(decay_steps) / decay_steps))
    
    ax1.plot(steps / 1000, lr * 1e4, color='#3498db', linewidth=2.5)
    ax1.fill_between(steps / 1000, 0, lr * 1e4, alpha=0.3, color='#3498db')
    
    ax1.axvline(x=warmup_steps/1000, color='#e74c3c', linestyle='--', linewidth=2)
    ax1.text(warmup_steps/1000 + 2, peak_lr * 1e4 * 0.9, 'Warmup\nends', 
             fontsize=9, color='#e74c3c')
    
    ax1.set_xlabel('Training Steps (thousands)', fontsize=11)
    ax1.set_ylabel('Learning Rate (×10⁻⁴)', fontsize=11)
    ax1.set_title('Learning Rate Schedule: Cosine with Warmup', fontsize=12, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    ax1.set_xlim(0, 100)
    
    # === TOP RIGHT: Simulated Loss Curve ===
    ax2 = fig.add_subplot(gs[0, 1])
    
    # Simulate loss curves for different model sizes
    np.random.seed(42)
    
    def simulate_loss(steps, final_loss, noise_scale=0.02):
        """Simulate a realistic loss curve."""
        initial_loss = np.log(50257)  # ~10.8 (random baseline)
        decay = np.exp(-steps / 20000)
        base_loss = final_loss + (initial_loss - final_loss) * decay
        noise = np.random.randn(len(steps)) * noise_scale * decay
        return base_loss + noise
    
    steps = np.linspace(0, 100000, 500)
    
    models_loss = [
        ('GPT-2 Small', 4.5, '#3498db'),
        ('GPT-2 Medium', 3.8, '#27ae60'),
        ('GPT-2 Large', 3.3, '#f39c12'),
        ('GPT-2 XL', 3.0, '#e74c3c'),
    ]
    
    for name, final_loss, color in models_loss:
        loss = simulate_loss(steps, final_loss)
        ax2.plot(steps / 1000, loss, linewidth=2, color=color, label=name, alpha=0.8)
    
    ax2.set_xlabel('Training Steps (thousands)', fontsize=11)
    ax2.set_ylabel('Training Loss (nats)', fontsize=11)
    ax2.set_title('Training Loss by Model Size (Simulated)', fontsize=12, fontweight='bold')
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)
    ax2.set_xlim(0, 100)
    ax2.set_ylim(2, 11)
    
    # Random baseline
    ax2.axhline(y=np.log(50257), color='gray', linestyle='--', linewidth=1.5)
    ax2.text(80, np.log(50257) + 0.2, 'Random baseline', fontsize=9, color='gray')
    
    # === BOTTOM LEFT: Tokens per Parameter ===
    ax3 = fig.add_subplot(gs[1, 0])
    
    models = ['Small', 'Medium', 'Large', 'XL']
    params = [117, 345, 774, 1500]  # Millions
    tokens = 10000  # Millions (10B)
    
    ratios = [tokens / p for p in params]
    colors = ['#3498db', '#27ae60', '#f39c12', '#e74c3c']
    
    bars = ax3.bar(models, ratios, color=colors, edgecolor='black', linewidth=2)
    
    ax3.axhline(y=20, color='#9b59b6', linestyle='--', linewidth=2, label='Chinchilla optimal (20:1)')
    
    ax3.set_ylabel('Tokens per Parameter', fontsize=11)
    ax3.set_title('Data-to-Parameter Ratio', fontsize=12, fontweight='bold')
    ax3.legend(fontsize=10)
    
    for bar, ratio in zip(bars, ratios):
        ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 2,
                f'{ratio:.1f}:1', ha='center', fontsize=10, fontweight='bold')
    
    ax3.text(1.5, 75, 'GPT-2 was undertrained!', fontsize=11, fontweight='bold', color='#e74c3c')
    ax3.text(1.5, 68, '(By modern standards)', fontsize=10, color='gray')
    
    # === BOTTOM RIGHT: Learning Rate by Model Size ===
    ax4 = fig.add_subplot(gs[1, 1])
    ax4.set_xlim(0, 12)
    ax4.set_ylim(0, 10)
    ax4.axis('off')
    ax4.set_title('Learning Rate Selection by Model Size (Illustrative)', fontsize=13, fontweight='bold')
    
    lr_data = [
        ('GPT-2 Small', '117M', '2.5e-4', '#3498db'),
        ('GPT-2 Medium', '345M', '2.5e-4', '#27ae60'),
        ('GPT-2 Large', '774M', '1.5e-4', '#f39c12'),
        ('GPT-2 XL', '1.5B', '1.0e-4', '#e74c3c'),
    ]
    
    for i, (name, params, lr_val, color) in enumerate(lr_data):
        y = 8 - i * 2
        
        rect = FancyBboxPatch((0.5, y-0.5), 3, 1.3, boxstyle="round,pad=0.02",
                              facecolor=color, edgecolor='black', linewidth=2)
        ax4.add_patch(rect)
        ax4.text(2, y, name, ha='center', va='center', 
                fontsize=10, fontweight='bold', color='white')
        
        ax4.text(4.5, y, params, va='center', fontsize=10)
        
        # LR bar
        lr_float = float(lr_val)
        bar_width = lr_float / 2.5e-4 * 4  # Normalized
        rect2 = FancyBboxPatch((6, y-0.35), bar_width, 0.7, boxstyle="round,pad=0.02",
                               facecolor=color, edgecolor='black', linewidth=1, alpha=0.6)
        ax4.add_patch(rect2)
        ax4.text(6 + bar_width + 0.3, y, lr_val, va='center', fontsize=10, family='monospace')
    
    ax4.text(6, 0.5, 'Larger models → Smaller learning rates', ha='center', 
             fontsize=11, fontweight='bold', color='#2c3e50')
    
    plt.tight_layout()
    plt.show()

visualize_training_dynamics()

---

## 5. Scaling Laws: The Discovery That Changed AI

### 5.1 What GPT-2 Revealed

The paper reported a remarkable finding in its abstract:

> *"The capacity of the language model is essential to the success of zero-shot task transfer and increasing it improves performance in a log-linear fashion across tasks."*

### 5.2 The Power Law Relationship

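In the form later fitted by Kaplan et al. (2020), loss scales as a power law in model size:

$$L(N) = \left(\frac{N_c}{N}\right)^{\alpha_N}$$

Where:
- $L$ = cross-entropy loss (nats per token)
- $N$ = number of non-embedding parameters
- $N_c$ = fitted constant ($\approx 8.8 \times 10^{13}$ in Kaplan et al.)
- $\alpha_N \approx 0.076$ (Kaplan et al. 2020)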
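To see what this means numerically, the sketch below evaluates $L(N)$ for the four GPT-2 sizes. It assumes Kaplan et al.'s fitted constants, $N_c \approx 8.8 \times 10^{13}$ and $\alpha_N \approx 0.076$. Those were fitted on non-embedding parameter counts and on a different dataset (WebText2), so plugging in GPT-2's total parameter counts gives only a rough illustration, not GPT-2's measured loss.

```python
# Kaplan et al. (2020) fitted constants (assumed here). They were fitted on
# non-embedding parameter counts and WebText2 test loss in nats per token.
N_C = 8.8e13
ALPHA_N = 0.076


def kaplan_loss(n_params: float) -> float:
    """Predicted loss L(N) = (N_c / N) ** alpha_N, in nats per token."""
    return (N_C / n_params) ** ALPHA_N


gpt2_sizes = {"Small": 117e6, "Medium": 345e6, "Large": 774e6, "XL": 1.5e9}
for name, n in gpt2_sizes.items():
    print(f"{name:<7} N = {n:.2e}  ->  L(N) = {kaplan_loss(n):.3f} nats")

# A 10x increase in N multiplies the predicted loss by 10 ** -0.076 (about 0.84)
print(f"loss ratio per 10x params: {kaplan_loss(1e10) / kaplan_loss(1e9):.3f}")
```

The useful takeaway is the ratio: under this fit, every 10× increase in parameters multiplies the predicted loss by about 0.84, a reduction of roughly 16%.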

### 5.3 The Three Scaling Dimensions

Later work (Kaplan et al. 2020) formalized three scaling dimensions:

1. **Parameters (N)**: Model size
2. **Data (D)**: Training tokens
3. **Compute (C)**: FLOPs

All three follow power laws, as long as the other two dimensions are not the bottleneck.
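A power law $y = a x^{b}$ becomes a straight line in log-log space, $\log y = b \log x + \log a$, so its exponent can be estimated with ordinary linear regression. That is the idea behind the `np.polyfit` call in the `visualize_scaling_laws` cell. The minimal sketch below uses synthetic losses generated from an assumed exponent of 0.076 plus small noise (not real measurements) to show the fit recovering the exponent and then extrapolating beyond the data:

```python
from typing import Tuple

import numpy as np


def fit_power_law(x: np.ndarray, y: np.ndarray) -> Tuple[float, float]:
    """Fit y = a * x**b by linear least squares in log10-log10 space; returns (a, b)."""
    b, log10_a = np.polyfit(np.log10(x), np.log10(y), 1)
    return float(10 ** log10_a), float(b)


# Synthetic "losses" from an assumed law L = (8.8e13 / N) ** 0.076, with ~0.2% noise
rng = np.random.default_rng(0)
n = np.array([117e6, 345e6, 774e6, 1.5e9])
loss = (8.8e13 / n) ** 0.076 * np.exp(rng.normal(0.0, 0.002, size=n.shape))

a, b = fit_power_law(n, loss)
print(f"fitted exponent b = {b:.3f}")          # should be close to -0.076
print(f"extrapolated loss at 15B params: {a * 15e9 ** b:.3f} nats")
```

With only four points the exponent is sensitive to noise, so extrapolating far beyond the fitted range (as the GPT-3 and PaLM markers in the figure do) should be read as a trend line, not a prediction.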

### 5.4 Why This Matters

If performance is predictable:
- Can estimate results before training
- Can plan compute budgets
- Can justify large investments
- No plateau appeared over the range studied, so further scaling was expected to keep paying off

This insight helped motivate GPT-3 and the modern LLM era.

In [None]:
def visualize_scaling_laws():
    """
    Comprehensive visualization of scaling laws.
    """
    fig = plt.figure(figsize=(20, 14))
    
    gs = gridspec.GridSpec(2, 2, hspace=0.3, wspace=0.25)
    
    # === TOP LEFT: Loss vs Parameters (Linear) ===
    ax1 = fig.add_subplot(gs[0, 0])
    
    # GPT-2 data
    params = np.array([117, 345, 774, 1500])  # Millions
    perplexity = np.array([37.50, 26.37, 22.05, 17.48])  # Zero-shot WikiText-103 PPL (Radford et al. 2019)
    loss = np.log(perplexity)  # Convert to nats
    
    ax1.plot(params, loss, 'o-', color='#e74c3c', linewidth=2.5, markersize=12,
             markerfacecolor='white', markeredgewidth=2)
    
    for p, l in zip(params, loss):
        ax1.annotate(f'{np.exp(l):.1f}', xy=(p, l), xytext=(p+50, l+0.1),
                    fontsize=9, fontweight='bold')
    
    ax1.set_xlabel('Parameters (Millions)', fontsize=11)
    ax1.set_ylabel('Loss (nats)', fontsize=11)
    ax1.set_title('Loss vs Parameters (Linear Scale)', fontsize=12, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    ax1.set_xlim(0, 1700)
    
    # === TOP RIGHT: Loss vs Parameters (Log-Log) ===
    ax2 = fig.add_subplot(gs[0, 1])
    
    ax2.loglog(params, loss, 'o', color='#e74c3c', markersize=12,
               markerfacecolor='white', markeredgewidth=2)
    
    # Power law fit
    log_params = np.log10(params)
    log_loss = np.log10(loss)
    coeffs = np.polyfit(log_params, log_loss, 1)
    
    fit_params = np.logspace(1.5, 6, 100)  # 30M to 1T
    fit_loss = 10 ** (coeffs[0] * np.log10(fit_params) + coeffs[1])
    
    ax2.loglog(fit_params, fit_loss, '--', color='#3498db', linewidth=2,
               label=rf'Power law: $L \propto N^{{{coeffs[0]:.3f}}}$')
    
    # Extrapolation points
    future_models = [
        (175000, 'GPT-3 (extrap.)', '#27ae60'),
        (540000, 'PaLM (extrap.)', '#9b59b6'),
    ]
    
    for size, name, color in future_models:
        pred_loss = 10 ** (coeffs[0] * np.log10(size) + coeffs[1])
        ax2.scatter([size], [pred_loss], color=color, s=150, marker='*', zorder=5)
        ax2.annotate(name, xy=(size, pred_loss), xytext=(size*1.5, pred_loss*1.1),
                    fontsize=9, color=color, fontweight='bold')
    
    ax2.set_xlabel('Parameters (Millions)', fontsize=11)
    ax2.set_ylabel('Loss (nats)', fontsize=11)
    ax2.set_title('Log-Log Scale: THE POWER LAW!', fontsize=12, fontweight='bold', color='#27ae60')
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)
    
    # === BOTTOM LEFT: Three Scaling Dimensions ===
    ax3 = fig.add_subplot(gs[1, 0])
    ax3.set_xlim(0, 12)
    ax3.set_ylim(0, 10)
    ax3.axis('off')
    ax3.set_title('The Three Scaling Dimensions', fontsize=13, fontweight='bold')
    
    dimensions = [
        ('N', 'Parameters', 'Model size\n(weights)', '#e74c3c', 'L ∝ N^{-0.076}'),
        ('D', 'Data', 'Training tokens\n(examples)', '#3498db', 'L ∝ D^{-0.095}'),
        ('C', 'Compute', 'FLOPs\n(operations)', '#27ae60', 'L ∝ C^{-0.050}'),
    ]
    
    for i, (symbol, name, desc, color, formula) in enumerate(dimensions):
        x = 1 + i * 4
        
        # Main circle
        circle = plt.Circle((x + 1.5, 6), 1.3, facecolor=color, edgecolor='black', linewidth=2)
        ax3.add_patch(circle)
        ax3.text(x + 1.5, 6.3, symbol, ha='center', va='center', 
                fontsize=20, fontweight='bold', color='white')
        ax3.text(x + 1.5, 5.5, name, ha='center', va='center', 
                fontsize=10, color='white')
        
        # Description
        ax3.text(x + 1.5, 4, desc, ha='center', va='center', fontsize=9, color='gray')
        
        # Formula
        rect = FancyBboxPatch((x, 1.5), 3, 1, boxstyle="round,pad=0.02",
                              facecolor='#f5f5f5', edgecolor=color, linewidth=2)
        ax3.add_patch(rect)
        ax3.text(x + 1.5, 2, formula, ha='center', va='center', 
                fontsize=10, family='monospace')
    
    ax3.text(6, 0.3, 'All three follow power laws! (Kaplan et al. 2020)', 
             ha='center', fontsize=11, fontweight='bold', color='#2c3e50')
    
    # === BOTTOM RIGHT: Why It Matters ===
    ax4 = fig.add_subplot(gs[1, 1])
    ax4.set_xlim(0, 12)
    ax4.set_ylim(0, 10)
    ax4.axis('off')
    ax4.set_title('Why Scaling Laws Changed AI', fontsize=13, fontweight='bold')
    
    implications = [
        ('Predictable', 'Can forecast performance\nbefore training', '#e74c3c'),
        ('No Plateau', 'Scaling continues to work\n(no ceiling observed yet)', '#3498db'),
        ('Investment\nJustification', '10× compute →\nmeasurable improvement', '#27ae60'),
        ('Research\nDirection', 'Scale > Architecture\n(for now)', '#f39c12'),
    ]
    
    for i, (title, desc, color) in enumerate(implications):
        row = i // 2
        col = i % 2
        x = 0.5 + col * 6
        y = 6.5 - row * 4
        
        rect = FancyBboxPatch((x, y), 5, 3, boxstyle="round,pad=0.05",
                              facecolor=color, edgecolor='black', linewidth=2, alpha=0.85)
        ax4.add_patch(rect)
        
        ax4.text(x + 2.5, y + 2.2, title, ha='center', va='center', 
                fontsize=10, fontweight='bold', color='white')
        ax4.text(x + 2.5, y + 1, desc, ha='center', va='center', 
                fontsize=9, color='white')
    
    plt.tight_layout()
    plt.show()

visualize_scaling_laws()

---

## 6. Compute Analysis

### 6.1 FLOPs Calculation

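For training a transformer (one forward plus one backward pass per token), a standard approximation is:

$$\text{FLOPs} \approx 6 \times N \times D$$

Where:
- $N$ = number of parameters
- $D$ = number of training tokens processed
- 6 = 2 FLOPs per parameter per token in the forward pass (one multiply and one add) + about 4 in the backward pass (gradients with respect to both activations and weights)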
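As a sanity check on $N$, the sketch below counts GPT-2's parameters directly from its architecture and then applies the 6ND rule. It assumes the standard GPT-2 layer shapes (token and position embeddings; per block, two LayerNorms, a fused Q/K/V projection, an output projection, and a 4× MLP; a final LayerNorm; output head tied to the token embedding) and this notebook's ~10B-token estimate, which is not a figure from the paper.

```python
def gpt2_param_count(n_layer: int, d_model: int, vocab: int = 50257, n_ctx: int = 1024) -> int:
    """Count GPT-2 parameters (weights + biases), with the LM head tied to the token embedding."""
    embeddings = vocab * d_model + n_ctx * d_model           # wte + wpe
    per_layer = (
        2 * 2 * d_model                                      # ln_1 and ln_2 (gain + bias)
        + d_model * 3 * d_model + 3 * d_model                # fused Q, K, V projection
        + d_model * d_model + d_model                        # attention output projection
        + d_model * 4 * d_model + 4 * d_model                # MLP up-projection
        + 4 * d_model * d_model + d_model                    # MLP down-projection
    )
    final_ln = 2 * d_model
    return embeddings + n_layer * per_layer + final_ln


def training_flops(n_params: float, n_tokens: float) -> float:
    """C ~= 6 * N * D: 2 FLOPs/param/token forward + ~4 backward."""
    return 6 * n_params * n_tokens


for name, n_layer, d_model in [("Small", 12, 768), ("Medium", 24, 1024),
                               ("Large", 36, 1280), ("XL", 48, 1600)]:
    n = gpt2_param_count(n_layer, d_model)
    print(f"{name:<7} N = {n:>13,}  C(10B tokens) = {training_flops(n, 10e9):.2e} FLOPs")
```

With these shapes the counts come out to about 124M, 355M, 774M, and 1.56B parameters. The first two are higher than the 117M and 345M quoted in the paper, a discrepancy often noted for the released checkpoints. For XL the estimate becomes about 9.3 × 10¹⁹ FLOPs, close to the ~9 × 10¹⁹ used in this section.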

### 6.2 GPT-2 XL Training Compute

| Metric | Value |
|--------|-------|
| Parameters | 1.5B |
| Training tokens | ~10B (assumed) |
| Total FLOPs | ~9 × 10¹⁹ |
| V100 peak (FP16 tensor cores) | 125 TFLOPS |
| Utilization | ~30-50% |
| Effective throughput | ~40-60 TFLOPS (≈50 used below) |
| GPU-hours | ~400-700 (≈500 at 50 TFLOPS) |

### 6.3 Cost Estimation (2019 Prices)

| Resource | Cost |
|----------|------|
| V100 GPU (cloud) | ~$3/hour |
| 8 GPUs × 168 hours (1 week) | ~$4,000 |
| With overhead (experiments, reruns) | ~$10,000-50,000 |
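The GPU-hour and dollar figures above follow from one division and one multiplication. A small sketch, assuming the V100 FP16 tensor-core peak of 125 TFLOPS, the 30-50% utilization range, 8 GPUs, and the ~$3/GPU-hour 2019 cloud price from these tables, makes the sensitivity to utilization explicit:

```python
def gpu_hours(total_flops: float, peak_tflops: float, utilization: float) -> float:
    """GPU-hours = FLOPs / (peak FLOP/s * utilization * 3600 s per hour)."""
    return total_flops / (peak_tflops * 1e12 * utilization * 3600)


TOTAL_FLOPS = 9e19          # GPT-2 XL: 6 * 1.5e9 params * 10e9 tokens (assumed token count)
V100_PEAK_TFLOPS = 125.0    # V100 FP16 tensor-core peak
PRICE_PER_GPU_HOUR = 3.0    # rough 2019 cloud price in USD (assumption)
N_GPUS = 8

for util in (0.3, 0.4, 0.5):
    hours = gpu_hours(TOTAL_FLOPS, V100_PEAK_TFLOPS, util)
    days = hours / N_GPUS / 24
    cost = hours * PRICE_PER_GPU_HOUR
    print(f"utilization {util:.0%}: {hours:,.0f} GPU-hours, "
          f"{days:.1f} days on {N_GPUS} GPUs, ~${cost:,.0f}")
```

Under these assumptions a single successful run lands between roughly 400 and 670 GPU-hours, or about $1,200 to $2,000 of compute; the higher figures in the cost table presumably budget for overhead such as tuning runs and restarts.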

### 6.4 Comparison to GPT-3

| Model | Parameters | Training FLOPs | Estimated Cost |
|-------|------------|----------------|----------------|
| GPT-2 XL | 1.5B | 9×10¹⁹ | ~$50K |
| GPT-3 | 175B | 3.1×10²³ | ~$5-12M |

GPT-3 required **~3,400× more compute**!
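The ratio comes straight from the table, and the 6ND rule shows where it originates: roughly 117× more parameters times 30× more tokens, assuming GPT-3's reported ~300B training tokens and this notebook's ~10B estimate for GPT-2 XL.

$$\frac{C_{\text{GPT-3}}}{C_{\text{GPT-2 XL}}} \approx \frac{3.1 \times 10^{23}}{9 \times 10^{19}} \approx 3.4 \times 10^{3}, \qquad \frac{175\text{B}}{1.5\text{B}} \times \frac{300\text{B}}{10\text{B}} \approx 117 \times 30 \approx 3.5 \times 10^{3}$$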

In [None]:
def compute_training_flops(params_millions: float, tokens_billions: float) -> float:
    """
    Estimate training FLOPs using the 6ND approximation.
    
    Args:
        params_millions: Model parameters in millions
        tokens_billions: Training tokens in billions
        
    Returns:
        Total FLOPs
    """
    N = params_millions * 1e6
    D = tokens_billions * 1e9
    return 6 * N * D


def flops_to_gpu_hours(flops: float, gpu_tflops: float = 50.0) -> float:
    """
    Convert FLOPs to GPU-hours.
    
    Args:
        flops: Total FLOPs
        gpu_tflops: Effective GPU throughput in TFLOPs
        
    Returns:
        GPU-hours required
    """
    flops_per_hour = gpu_tflops * 1e12 * 3600
    return flops / flops_per_hour


def print_compute_analysis():
    """Print detailed compute analysis."""
    print("=" * 80)
    print("GPT-2 COMPUTE ANALYSIS")
    print("=" * 80)
    
    models = [
        ('GPT-2 Small', 117, 10),
        ('GPT-2 Medium', 345, 10),
        ('GPT-2 Large', 774, 10),
        ('GPT-2 XL', 1500, 10),
    ]
    
    print(f"\n{'Model':<18} {'Params':<10} {'Tokens':<10} {'FLOPs':<15} {'GPU-hrs':<12} {'Cost (est)'}")
    print("-" * 80)
    
    for name, params, tokens in models:
        flops = compute_training_flops(params, tokens)
        gpu_hours = flops_to_gpu_hours(flops)
        cost = gpu_hours * 3  # $3/GPU-hour estimate
        
        print(f"{name:<18} {str(params) + 'M':<10} {str(tokens) + 'B':<10} {flops:<15.2e} "
              f"{gpu_hours:<12,.0f} ${cost:,.0f}")
    
    print("\n" + "=" * 80)
    print("COMPARISON TO LATER MODELS")
    print("=" * 80)
    
    comparison = [
        ('GPT-2 XL', 1500, 10, 50000),
        ('GPT-3', 175000, 300, 5000000),
        ('GPT-4 (est)', 1000000, 2000, 100000000),
    ]
    
    print(f"\n{'Model':<15} {'Params':<12} {'Tokens':<12} {'FLOPs':<18} {'Cost (est)'}")
    print("-" * 80)
    
    for name, params, tokens, cost in comparison:
        flops = compute_training_flops(params, tokens)
        print(f"{name:<15} {format(params / 1000, 'g') + 'B':<12} {str(tokens) + 'B':<12} {flops:<18.2e} ${cost:,}")
    
    print("\n" + "=" * 80)
    print("KEY FORMULAS")
    print("=" * 80)
    print("\nTraining FLOPs ≈ 6 × N × D")
    print("  N = parameters, D = tokens")
    print("  6 = 2 (forward) + 4 (backward) FLOPs per parameter per token")
    print("\nGPU-hours = FLOPs / (peak_TFLOPS × 1e12 × utilization × 3600)")
    print("  Typical utilization: 30-50% (flops_to_gpu_hours takes effective TFLOPS directly)")

print_compute_analysis()

In [None]:
def visualize_compute_analysis():
    """
    Visualize compute requirements and costs.
    """
    fig = plt.figure(figsize=(18, 12))
    
    gs = gridspec.GridSpec(2, 2, hspace=0.3, wspace=0.25)
    
    # === TOP LEFT: FLOPs by Model Size ===
    ax1 = fig.add_subplot(gs[0, 0])
    
    models = ['Small\n117M', 'Medium\n345M', 'Large\n774M', 'XL\n1.5B']
    flops = [compute_training_flops(117, 10),
             compute_training_flops(345, 10),
             compute_training_flops(774, 10),
             compute_training_flops(1500, 10)]
    
    colors = plt.cm.Reds(np.linspace(0.3, 0.8, 4))
    bars = ax1.bar(models, [f/1e18 for f in flops], color=colors, edgecolor='black', linewidth=2)
    
    ax1.set_ylabel('FLOPs (×10¹⁸)', fontsize=11)
    ax1.set_title('Training Compute by Model Size', fontsize=12, fontweight='bold')
    
    for bar, f in zip(bars, flops):
        ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 2,
                f'{f:.1e}', ha='center', fontsize=9, rotation=45)
    
    # === TOP RIGHT: GPU Hours ===
    ax2 = fig.add_subplot(gs[0, 1])
    
    gpu_hours = [flops_to_gpu_hours(f) for f in flops]
    
    bars2 = ax2.bar(models, gpu_hours, color=colors, edgecolor='black', linewidth=2)
    
    ax2.set_ylabel('GPU-hours (V100 @ 50 TFLOPS)', fontsize=11)
    ax2.set_title('Training Time by Model Size', fontsize=12, fontweight='bold')
    
    for bar, h in zip(bars2, gpu_hours):
        days = h / 24 / 8  # Assuming 8 GPUs
        ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 20,
                f'{h:.0f}h\n({days:.1f}d @ 8 GPU)', ha='center', fontsize=8)
    
    # === BOTTOM LEFT: Cost Comparison ===
    ax3 = fig.add_subplot(gs[1, 0])
    
    model_costs = [
        ('GPT-2 XL', 50000, '#3498db'),
        ('GPT-3', 5000000, '#e74c3c'),
        ('GPT-4 (est)', 100000000, '#27ae60'),
    ]
    
    names = [m[0] for m in model_costs]
    costs = [m[1] for m in model_costs]
    colors_cost = [m[2] for m in model_costs]
    
    bars3 = ax3.barh(names, costs, color=colors_cost, edgecolor='black', linewidth=2)
    ax3.set_xscale('log')
    ax3.set_xlabel('Estimated Training Cost ($)', fontsize=11)
    ax3.set_title('Training Cost Comparison (Log Scale)', fontsize=12, fontweight='bold')
    
    for bar, cost in zip(bars3, costs):
        ax3.text(cost * 1.5, bar.get_y() + bar.get_height()/2,
                f'${cost:,.0f}', va='center', fontsize=10, fontweight='bold')
    
    # === BOTTOM RIGHT: Compute Formula ===
    ax4 = fig.add_subplot(gs[1, 1])
    ax4.set_xlim(0, 12)
    ax4.set_ylim(0, 10)
    ax4.axis('off')
    ax4.set_title('The Compute Formula', fontsize=13, fontweight='bold')
    
    # Main formula
    rect = FancyBboxPatch((1, 6), 10, 3, boxstyle="round,pad=0.05",
                          facecolor='#e8f6f3', edgecolor='#1abc9c', linewidth=2)
    ax4.add_patch(rect)
    ax4.text(6, 8.2, 'Training FLOPs Formula', fontsize=12, ha='center', fontweight='bold')
    ax4.text(6, 7, 'C ≈ 6 × N × D', fontsize=16, ha='center', family='monospace', fontweight='bold')
    ax4.text(6, 6.3, 'C=compute, N=params, D=tokens', fontsize=9, ha='center', color='gray')
    
    # Breakdown
    breakdown = [
        ('6 =', 'Forward (2) + Backward (4)', '#e74c3c'),
        ('N =', 'Model parameters', '#3498db'),
        ('D =', 'Training tokens', '#27ae60'),
    ]
    
    for i, (symbol, desc, color) in enumerate(breakdown):
        y = 4.5 - i * 1.3
        
        rect = FancyBboxPatch((1.5, y-0.4), 1.5, 0.8, boxstyle="round,pad=0.02",
                              facecolor=color, edgecolor='black', linewidth=1.5)
        ax4.add_patch(rect)
        ax4.text(2.25, y, symbol, ha='center', va='center', 
                fontsize=10, fontweight='bold', color='white')
        ax4.text(3.5, y, desc, va='center', fontsize=10)
    
    # GPT-2 XL example
    ax4.text(6, 0.5, 'GPT-2 XL: 6 × 1.5B × 10B ≈ 9×10¹⁹ FLOPs', 
             ha='center', fontsize=10, family='monospace')
    
    plt.tight_layout()
    plt.show()

visualize_compute_analysis()

---

## 7. Reproducibility Guide

### 7.1 Can You Train GPT-2 Today?

**Yes!** With modern hardware, GPT-2 is very trainable. The times below are rough estimates; actual training time depends heavily on the token budget and the GPU utilization achieved:

| Model | Minimum GPU | Training Time |
|-------|-------------|---------------|
| Small (117M) | 1× RTX 3090 | ~12 hours |
| Medium (345M) | 1× A100 40GB | ~24 hours |
| Large (774M) | 2× A100 40GB | ~3 days |
| XL (1.5B) | 4× A100 80GB | ~5 days |
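One way to sanity-check the GPU column is to count model-state memory. A common rule of thumb for mixed-precision training with Adam, popularized by the ZeRO paper, is about 16 bytes per parameter. The sketch below applies that rule; it is an approximation that ignores activations, temporary buffers, and framework overhead.

```python
def model_state_gb(n_params: float, bytes_per_param: int = 16) -> float:
    """Memory for weights, gradients, and Adam optimizer states, in GB (1e9 bytes).

    The default 16 bytes/param assumes mixed-precision training with Adam:
    half-precision weights (2) and gradients (2), plus fp32 master weights,
    first moment, and second moment (4 + 4 + 4). Activations are NOT included.
    """
    return n_params * bytes_per_param / 1e9


for name, n in [("Small", 117e6), ("Medium", 345e6), ("Large", 774e6), ("XL", 1.5e9)]:
    print(f"{name:<7} ~{model_state_gb(n):5.1f} GB of model state (before activations)")
```

For XL this gives about 24 GB, in line with the ZeRO paper's estimate for a 1.5B-parameter GPT-2. Activation memory, which grows with batch size, sequence length, and depth, comes on top of this; gradient checkpointing targets that part.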

### 7.2 Open-Source Implementations

1. **nanoGPT** (Karpathy): Clean, minimal implementation
2. **Hugging Face**: Full-featured transformers library
3. **Megatron-LM**: NVIDIA's distributed training

### 7.3 Key Training Tips

1. **Use mixed precision (FP16/BF16)**: Roughly halves activation memory and runs matrix multiplies on faster tensor cores
2. **Gradient checkpointing**: Trade compute for memory
3. **Flash Attention**: Reported 2-4× speedup on attention
4. **Gradient accumulation**: Simulate larger batches
5. **Learning rate warmup**: Critical for stability
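Several of these tips combine in a single optimizer step. The sketch below is a minimal, CPU-runnable illustration, assuming PyTorch 1.10 or later (for `torch.autocast` on CPU): bf16 autocast for mixed precision, gradient accumulation over micro-batches, and a linear-warmup-then-cosine learning-rate schedule. The peak rate (6e-4), 2,000 warmup steps, and 600,000 decay steps mirror the nanoGPT-style config in the figure below; the 6e-5 floor, gradient clipping at 1.0, and the toy model are assumptions for illustration. Flash Attention and gradient checkpointing are not shown.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def lr_at(step: int, max_lr: float = 6e-4, min_lr: float = 6e-5,
          warmup: int = 2000, decay_iters: int = 600_000) -> float:
    """Linear warmup to max_lr, then cosine decay to min_lr (constants are assumptions)."""
    if step < warmup:
        return max_lr * (step + 1) / warmup
    if step >= decay_iters:
        return min_lr
    progress = (step - warmup) / (decay_iters - warmup)   # 0 -> 1 over the decay phase
    return min_lr + 0.5 * (1.0 + math.cos(math.pi * progress)) * (max_lr - min_lr)


def train_step(model, optimizer, micro_batches, step, device_type="cpu", max_grad_norm=1.0):
    """One optimizer step: scheduled LR, bf16 autocast, gradient accumulation, clipping."""
    for group in optimizer.param_groups:
        group["lr"] = lr_at(step)
    optimizer.zero_grad(set_to_none=True)
    n_micro = len(micro_batches)
    total_loss = 0.0
    for x, y in micro_batches:
        # Mixed precision: eligible ops (e.g. Linear) run in bf16; parameters stay fp32
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits = model(x)                                   # (B, T, vocab)
        # Loss computed in fp32 for numerical stability
        loss = F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), y.reshape(-1))
        # Divide by n_micro so the accumulated gradient equals the mean over all micro-batches
        (loss / n_micro).backward()
        total_loss += loss.item()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return total_loss / n_micro


# Toy usage with a hypothetical tiny "language model" and random tokens
torch.manual_seed(0)
vocab_size, d_model = 100, 32
model = nn.Sequential(nn.Embedding(vocab_size, d_model), nn.Linear(d_model, vocab_size))
opt = torch.optim.AdamW(model.parameters(), lr=0.0, weight_decay=0.1)
batches = [(torch.randint(0, vocab_size, (4, 16)), torch.randint(0, vocab_size, (4, 16)))
           for _ in range(4)]
print(f"mean micro-batch loss: {train_step(model, opt, batches, step=0):.3f}")
```

On a GPU the same logic applies with `device_type="cuda"`. With FP16 instead of BF16 one would also need a gradient scaler (for example `torch.cuda.amp.GradScaler`) to avoid underflow, which this sketch omits.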

In [None]:
def visualize_reproducibility_guide():
    """
    Guide for reproducing GPT-2 training.
    """
    fig = plt.figure(figsize=(18, 14))
    
    gs = gridspec.GridSpec(2, 2, hspace=0.3, wspace=0.25)
    
    # === TOP LEFT: Hardware Requirements ===
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.set_xlim(0, 12)
    ax1.set_ylim(0, 12)
    ax1.axis('off')
    ax1.set_title('Hardware Requirements (2024)', fontsize=13, fontweight='bold')
    
    requirements = [
        ('GPT-2 Small', '117M', '1× RTX 3090', '12GB', '~12 hrs', '#3498db'),
        ('GPT-2 Medium', '345M', '1× A100 40GB', '24GB', '~24 hrs', '#27ae60'),
        ('GPT-2 Large', '774M', '2× A100 40GB', '48GB', '~3 days', '#f39c12'),
        ('GPT-2 XL', '1.5B', '4× A100 80GB', '160GB', '~5 days', '#e74c3c'),
    ]
    
    headers = ['Model', 'Params', 'GPU', 'VRAM (est)', 'Time']
    
    # Header
    for j, header in enumerate(headers):
        x = 0.5 + j * 2.2
        rect = FancyBboxPatch((x, 10), 2, 1, boxstyle="round,pad=0.02",
                              facecolor='#2c3e50', edgecolor='black', linewidth=1)
        ax1.add_patch(rect)
        ax1.text(x + 1, 10.5, header, ha='center', va='center', 
                fontsize=9, fontweight='bold', color='white')
    
    # Rows
    for i, (model, params, gpu, vram, time, color) in enumerate(requirements):
        y = 8.5 - i * 2
        values = [model, params, gpu, vram, time]
        
        for j, val in enumerate(values):
            x = 0.5 + j * 2.2
            facecolor = color if j == 0 else '#f5f5f5'
            textcolor = 'white' if j == 0 else 'black'
            
            rect = FancyBboxPatch((x, y), 2, 1.5, boxstyle="round,pad=0.02",
                                  facecolor=facecolor, edgecolor='gray', linewidth=1)
            ax1.add_patch(rect)
            ax1.text(x + 1, y + 0.75, val, ha='center', va='center', 
                    fontsize=8, color=textcolor, fontweight='bold' if j == 0 else 'normal')
    
    # === TOP RIGHT: Training Tips ===
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.set_xlim(0, 12)
    ax2.set_ylim(0, 12)
    ax2.axis('off')
    ax2.set_title('Essential Training Tips', fontsize=13, fontweight='bold')
    
    tips = [
        ('Mixed Precision', 'FP16/BF16', '~2× less activation memory', '#e74c3c'),
        ('Flash Attention', 'Fused kernels', '2-4× speedup', '#3498db'),
        ('Grad Checkpoint', 'Recompute activs', '3× less memory', '#27ae60'),
        ('Grad Accumulation', 'Virtual batches', 'Simulate large batch', '#f39c12'),
        ('LR Warmup', 'Linear ramp', 'Training stability', '#9b59b6'),
    ]
    
    for i, (name, technique, benefit, color) in enumerate(tips):
        y = 10.5 - i * 2
        
        rect = FancyBboxPatch((0.5, y-0.6), 3, 1.4, boxstyle="round,pad=0.02",
                              facecolor=color, edgecolor='black', linewidth=2)
        ax2.add_patch(rect)
        ax2.text(2, y+0.2, name, ha='center', va='center', 
                fontsize=9, fontweight='bold', color='white')
        ax2.text(2, y-0.3, technique, ha='center', va='center', 
                fontsize=8, color='white')
        
        ax2.text(4, y, f'→ {benefit}', va='center', fontsize=9)
    
    # === BOTTOM: Training Code Example ===
    ax3 = fig.add_subplot(gs[1, :])
    ax3.set_xlim(0, 20)
    ax3.set_ylim(0, 10)
    ax3.axis('off')
    ax3.set_title('Training Configuration Example (nanoGPT Style)', fontsize=13, fontweight='bold')
    
    # Code box
    rect = FancyBboxPatch((1, 1), 18, 8, boxstyle="round,pad=0.05",
                          facecolor='#1e1e1e', edgecolor='#3498db', linewidth=2)
    ax3.add_patch(rect)
    
    code = """# GPT-2 Small Training Configuration
config = {
    'n_layer': 12,
    'n_head': 12,
    'n_embd': 768,
    'vocab_size': 50257,
    'block_size': 1024,      # Context length
    'batch_size': 12,        # Per GPU
    'gradient_accumulation_steps': 40,  # Effective batch = 480 seqs (~491K tokens)
    'learning_rate': 6e-4,
    'warmup_iters': 2000,
    'lr_decay_iters': 600000,
    'weight_decay': 0.1,
    'dtype': 'bfloat16',     # Mixed precision
}"""
    
    lines = code.split('\n')
    for i, line in enumerate(lines):
        color = '#f39c12' if line.strip().startswith('#') else '#d4d4d4'
        if '=' in line and not line.strip().startswith('#'):
            parts = line.split('=', 1)  # split on the first '=' only
            ax3.text(1.5, 8.3 - i * 0.5, parts[0], fontsize=9, 
                    family='monospace', color='#9cdcfe')
            ax3.text(1.5 + len(parts[0]) * 0.15, 8.3 - i * 0.5, '=' + parts[1], 
                    fontsize=9, family='monospace', color='#ce9178')
        else:
            ax3.text(1.5, 8.3 - i * 0.5, line, fontsize=9, 
                    family='monospace', color=color)
    
    plt.tight_layout()
    plt.show()

visualize_reproducibility_guide()

---

## Summary

### Key Training Insights

| Aspect | Detail |
|--------|--------|
| **Dataset** | WebText: 40GB from Reddit-filtered web pages |
| **Tokenizer** | Byte-level BPE with 50,257 tokens |
| **Training** | ~10B tokens (est.), batch size 512, cosine LR (assumed, as in GPT-1) |
| **Compute** | ~9×10¹⁹ FLOPs for GPT-2 XL |
| **Scaling** | Power law: L ∝ N^(-0.076) (Kaplan et al. 2020) |

### The Legacy

GPT-2's training established:
1. **WebText methodology** → CommonCrawl filtering
2. **Byte-level BPE** → Universal tokenization
3. **Scaling laws** → Predictable improvement
4. **Compute scaling** → Billion-dollar training runs

---

### What's Next: Part V - Complete Implementation

We'll cover:
- Full GPT-2 implementation from scratch
- Loading pretrained weights
- Text generation
- Fine-tuning examples

---

## References

1. Radford et al. (2019). [Language Models are Unsupervised Multitask Learners](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
2. Kaplan et al. (2020). [Scaling Laws for Neural Language Models](https://arxiv.org/abs/2001.08361)
3. Hoffmann et al. (2022). [Training Compute-Optimal Large Language Models](https://arxiv.org/abs/2203.15556) (Chinchilla)
4. Karpathy. [nanoGPT](https://github.com/karpathy/nanoGPT)