# GPT Research Paper | Part III

## Pre-training: Language Modeling, Training, and Generation

---

**Paper:** [Improving Language Understanding by Generative Pre-Training](https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf)

**Authors:** Alec Radford, Karthik Narasimhan, Tim Salimans, Ilya Sutskever (OpenAI, 2018)

---

In Part II, we built the complete GPT architecture. Now we dive into **how it's trained**:

1. The **language modeling objective** - what GPT learns
2. The **training procedure** - optimization details from the paper
3. **Text generation** - how to sample from the trained model

This is where the magic happens - turning 117M randomly initialized parameters into a model that understands language.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, Rectangle, Circle, FancyArrowPatch
import numpy as np
import math
from dataclasses import dataclass
from typing import Optional, Tuple, List
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

torch.manual_seed(42)
np.random.seed(42)

print("PyTorch version:", torch.__version__)

---

## 1. The Language Modeling Objective

### 1.1 What the Paper Says

From Section 3.1 (Unsupervised pre-training):

> *"Given an unsupervised corpus of tokens $\mathcal{U} = \{u_1, ..., u_n\}$, we use a standard language modeling objective to maximize the following likelihood:"*

$$L_1(\mathcal{U}) = \sum_i \log P(u_i | u_{i-k}, ..., u_{i-1}; \Theta)$$

This is **the core of GPT** - predict each token given the previous tokens.

### 1.2 Breaking Down the Equation

| Symbol | Meaning |
|--------|--------|
| $\mathcal{U} = \{u_1, ..., u_n\}$ | The training corpus (sequence of tokens) |
| $u_i$ | The token at position $i$ |
| $k$ | Context window size (512 in GPT) |
| $u_{i-k}, ..., u_{i-1}$ | The $k$ tokens before position $i$ |
| $\Theta$ | Model parameters (~117M) |
| $P(u_i | ...)$ | Probability of token $u_i$ given context |

### 1.3 The Intuition

Consider the sentence: "The cat sat on the mat"

GPT learns to predict:
- P("cat" | "The") should be high
- P("sat" | "The cat") should be high
- P("on" | "The cat sat") should be high
- P("banana" | "The cat sat") should be low

By learning these patterns across billions of tokens, GPT develops an understanding of language.
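
As a rough illustration of the objective, the sketch below sums log-probabilities for the example sentence. The conditional probabilities are made-up numbers rather than model outputs, and `sequence_log_likelihood` is a hypothetical helper, not something from the paper.

```python
import math

# Hypothetical conditional probabilities P(token | context), for illustration only.
step_probs = [
    ("cat", 0.20),  # P("cat" | "The")
    ("sat", 0.30),  # P("sat" | "The cat")
    ("on", 0.50),   # P("on"  | "The cat sat")
]

def sequence_log_likelihood(probs):
    """Sum of log-probabilities: this sequence's contribution to L1."""
    return sum(math.log(p) for _, p in probs)

log_likelihood = sequence_log_likelihood(step_probs)
print(f"log-likelihood: {log_likelihood:.4f}")

# Swapping in an unlikely continuation lowers the objective.
unlikely = step_probs[:2] + [("banana", 0.001)]
print(f"with 'banana':  {sequence_log_likelihood(unlikely):.4f}")
```

Training pushes the parameters toward values that make this sum as large as possible over the whole corpus.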

In [None]:
def visualize_language_modeling():
    """Visualize the language modeling objective."""
    fig, ax = plt.subplots(figsize=(16, 8))
    ax.set_xlim(0, 16)
    ax.set_ylim(0, 8)
    ax.axis('off')
    
    # Title
    ax.text(8, 7.5, 'Language Modeling Objective', fontsize=16, fontweight='bold', ha='center')
    ax.text(8, 7, r'$L_1(\mathcal{U}) = \sum_i \log P(u_i | u_{i-k}, ..., u_{i-1}; \Theta)$', 
            fontsize=14, ha='center', style='italic')
    
    # Input sequence
    tokens = ['<s>', 'The', 'cat', 'sat', 'on', 'the', 'mat']
    colors = ['#3498db', '#3498db', '#3498db', '#3498db', '#3498db', '#3498db', '#3498db']
    
    ax.text(1, 5.8, 'Input tokens (context):', fontsize=11, fontweight='bold')
    for i, (tok, col) in enumerate(zip(tokens[:-1], colors[:-1])):
        x = 1.5 + i * 2
        rect = FancyBboxPatch((x-0.5, 5), 1.2, 0.6, boxstyle="round,pad=0.02",
                              facecolor=col, edgecolor='black', linewidth=1.5, alpha=0.8)
        ax.add_patch(rect)
        ax.text(x + 0.1, 5.3, tok, ha='center', va='center', fontsize=11, color='white', fontweight='bold')
    
    # Arrows down
    for i in range(6):
        x = 1.6 + i * 2
        ax.annotate('', xy=(x, 4.3), xytext=(x, 4.9),
                    arrowprops=dict(arrowstyle='->', color='black', lw=1.5))
    
    # GPT box
    rect_gpt = FancyBboxPatch((1, 3.5), 12, 0.7, boxstyle="round,pad=0.03",
                              facecolor='#2c3e50', edgecolor='black', linewidth=2)
    ax.add_patch(rect_gpt)
    ax.text(7, 3.85, 'GPT Model (12 layers, 768 dim, 12 heads)', 
            ha='center', va='center', fontsize=12, color='white', fontweight='bold')
    
    # Arrows down
    for i in range(6):
        x = 1.6 + i * 2
        ax.annotate('', xy=(x, 2.8), xytext=(x, 3.4),
                    arrowprops=dict(arrowstyle='->', color='black', lw=1.5))
    
    # Predictions
    ax.text(1, 2.3, 'Predict next token:', fontsize=11, fontweight='bold')
    predictions = ['The', 'cat', 'sat', 'on', 'the', 'mat']
    for i, pred in enumerate(predictions):
        x = 1.5 + i * 2
        rect = FancyBboxPatch((x-0.5, 1.5), 1.2, 0.6, boxstyle="round,pad=0.02",
                              facecolor='#e74c3c', edgecolor='black', linewidth=1.5, alpha=0.8)
        ax.add_patch(rect)
        ax.text(x + 0.1, 1.8, pred, ha='center', va='center', fontsize=11, color='white', fontweight='bold')
    
    # Loss computation
    ax.text(8, 0.8, 'Loss = -log P("The"|"<s>") - log P("cat"|"<s> The") - log P("sat"|"<s> The cat") - ...', 
            fontsize=10, ha='center', style='italic')
    ax.text(8, 0.3, 'Minimize this loss = Maximize probability of correct next tokens', 
            fontsize=11, ha='center', fontweight='bold', color='#27ae60')
    
    plt.tight_layout()
    plt.show()

visualize_language_modeling()

### 1.4 The Loss Function: Cross-Entropy

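The paper's objective translates to **cross-entropy loss**: maximizing $L_1$ is equivalent to minimizing the average negative log-likelihood of the correct next token (written here with the full prefix as context; the model sees at most the last $k$ tokens):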

$$\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N} \log P(u_i | u_1, ..., u_{i-1})$$

In PyTorch, this is `F.cross_entropy(logits, targets)`.
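
A quick way to check this correspondence is to compute the average negative log-probability by hand and compare it with `F.cross_entropy`. The sizes and target IDs below are arbitrary toy values chosen for the sketch.

```python
import torch
import torch.nn.functional as F

# Toy sizes and values, chosen only for illustration.
vocab_size, n_positions = 10, 4
gen = torch.Generator().manual_seed(0)
logits = torch.randn(n_positions, vocab_size, generator=gen)
targets = torch.tensor([3, 1, 7, 0])  # arbitrary "correct next token" IDs

# Manual computation: -(1/N) * sum_i log P(u_i | context)
log_probs = F.log_softmax(logits, dim=-1)
manual_loss = -log_probs[torch.arange(n_positions), targets].mean()

builtin_loss = F.cross_entropy(logits, targets)
print(manual_loss.item(), builtin_loss.item())
```

The two values should agree up to floating-point error, since `F.cross_entropy` applies log-softmax and averages over positions by default.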

### 1.5 Teacher Forcing

During training, we use **teacher forcing**:
- Feed the **correct** previous tokens (not model predictions)
- This allows parallel computation of all positions
- The causal mask ensures position $i$ only attends to positions $\le i$, so no prediction can see the token it is trying to predict
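
The shifted input/target pairs and the mask can be sketched with plain Python lists. This is only an illustration of the idea; the model itself builds the mask as a tensor.

```python
tokens = ["<s>", "The", "cat", "sat", "on", "the"]

# Inputs and targets are the same sequence, offset by one position.
inputs = tokens[:-1]
targets = tokens[1:]

# Causal mask: row i marks which input positions position i may attend to.
T = len(inputs)
causal_mask = [[1 if j <= i else 0 for j in range(T)] for i in range(T)]

for i, row in enumerate(causal_mask):
    visible = [tok for tok, m in zip(inputs, row) if m]
    print(f"pos {i}: sees {visible} -> predicts {targets[i]!r}")
```

Every row is computed in the same forward pass, which is why a single sequence yields one training signal per position.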

In [None]:
def demonstrate_teacher_forcing():
    """Show how teacher forcing enables parallel training."""
    
    print("Teacher Forcing in GPT Training")
    print("=" * 60)
    
    print("\nInput sequence:  ['<s>', 'The', 'cat', 'sat', 'on']")
    print("Target sequence: ['The', 'cat', 'sat', 'on', 'the']")
    print()
    print("At each position, GPT predicts the NEXT token:")
    print("-" * 60)
    print(f"{'Position':<10} {'Input (context)':<36} {'Target (predict)'}")
    print("-" * 60)
    
    contexts = [
        ("0", "'<s>'", "'The'"),
        ("1", "'<s>', 'The'", "'cat'"),
        ("2", "'<s>', 'The', 'cat'", "'sat'"),
        ("3", "'<s>', 'The', 'cat', 'sat'", "'on'"),
        ("4", "'<s>', 'The', 'cat', 'sat', 'on'", "'the'"),
    ]
    
    # Column width 36 fits the longest context string, keeping the table aligned
    for pos, ctx, tgt in contexts:
        print(f"{pos:<10} {ctx:<36} {tgt}")
    
    print("\n" + "=" * 60)
    print("Key insight: ALL positions computed in PARALLEL (not sequential)")
    print("The causal mask ensures each position only sees past tokens.")

demonstrate_teacher_forcing()

---

## 2. Training Procedure

### 2.1 Dataset: BooksCorpus

From Section 4.1:

> *"We use the BooksCorpus dataset for training the language model. It contains over 7,000 unique unpublished books from a variety of genres including Adventure, Fantasy, and Romance."*

Why BooksCorpus?
- **Long contiguous text**: Books have coherent narratives spanning many pages
- **Diverse topics**: Fiction covers many domains and writing styles
- **~1 billion words**: Enough data to train 117M parameters

### 2.2 Tokenization: BPE

From Section 4.1:

> *"We used a bytepair encoding (BPE) vocabulary with 40,000 merges."*

BPE (Byte Pair Encoding) is a subword tokenization method:
- Starts with character-level vocabulary
- Iteratively merges most frequent pairs
- Balances vocabulary size with handling rare words
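
The merge loop can be sketched on a toy corpus. This is a simplified illustration of the BPE idea with made-up word frequencies, not the tokenizer used for GPT; `most_frequent_pair` and `merge_pair` are hypothetical helper names.

```python
from collections import Counter

def most_frequent_pair(words):
    """Count adjacent symbol pairs, weighted by word frequency."""
    pairs = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs.most_common(1)[0][0]

def merge_pair(words, pair):
    """Replace every occurrence of `pair` with a single merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = merged.get(tuple(out), 0) + freq
    return merged

# Toy corpus: each word is a tuple of characters with a made-up frequency.
words = {tuple("low"): 5, tuple("lower"): 2, tuple("newest"): 6, tuple("widest"): 3}

merges = []
for _ in range(3):
    pair = most_frequent_pair(words)
    merges.append(pair)
    words = merge_pair(words, pair)

print(merges)
print(list(words))
```

After a few merges the frequent suffix `est` becomes a single symbol, while rarer words stay split into smaller pieces. Running 40,000 such merges on a real corpus gives the vocabulary the paper describes.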

### 2.3 Training Hyperparameters

From Section 4.1:

> *"We trained for 100 epochs on minibatches of 64 randomly sampled, contiguous sequences of 512 tokens. Since layernorm is used extensively throughout the model, a simple weight initialization of N(0, 0.02) was sufficient."*

> *"We used the Adam optimization scheme with a max learning rate of 2.5e-4. The learning rate was increased linearly from zero over the first 2000 updates and annealed to 0 using a cosine schedule."*

Let's extract all the details:

In [None]:
@dataclass
class GPTConfig:
    """Model configuration from paper."""
    vocab_size: int = 40478
    n_positions: int = 512
    n_embd: int = 768
    n_layer: int = 12
    n_head: int = 12
    n_inner: int = 3072
    embd_pdrop: float = 0.1
    attn_pdrop: float = 0.1
    resid_pdrop: float = 0.1


@dataclass 
class TrainingConfig:
    """
    Training configuration - all values from paper Section 4.1.
    """
    # === Batch and Sequence ===
    batch_size: int = 64              # "minibatches of 64"
    seq_length: int = 512             # "contiguous sequences of 512 tokens"
    
    # === Training Duration ===
    epochs: int = 100                 # "trained for 100 epochs"
    
    # === Optimizer ===
    optimizer: str = "Adam"           # "Adam optimization scheme"
    max_lr: float = 2.5e-4            # "max learning rate of 2.5e-4"
    
    # === Learning Rate Schedule ===
    warmup_steps: int = 2000          # "increased linearly from zero over the first 2000 updates"
    schedule: str = "cosine"          # "annealed to 0 using a cosine schedule"
    
    # === Regularization ===
    weight_decay: float = 0.01        # "modified version of L2 regularization ... with w = 0.01 on all non bias or gain weights"
    
    # === Initialization ===
    init_std: float = 0.02            # "weight initialization of N(0, 0.02)"


train_config = TrainingConfig()

print("GPT Training Configuration (from paper Section 4.1)")
print("=" * 60)
print(f"\n[Data]")
print(f"  Dataset:        BooksCorpus (~1B words)")
print(f"  Tokenization:   BPE with 40,000 merges")
print(f"  Batch size:     {train_config.batch_size}")
print(f"  Sequence length:{train_config.seq_length}")
print(f"\n[Training]")
print(f"  Epochs:         {train_config.epochs}")
print(f"  Optimizer:      {train_config.optimizer}")
print(f"  Max LR:         {train_config.max_lr}")
print(f"\n[LR Schedule]")
print(f"  Warmup steps:   {train_config.warmup_steps}")
print(f"  Schedule:       Linear warmup + {train_config.schedule} decay")
print(f"\n[Initialization]")
print(f"  Weight init:    N(0, {train_config.init_std})")

### 2.4 Learning Rate Schedule

The paper specifies a **warmup + cosine annealing** schedule:

1. **Warmup phase** (steps 0 to 2000): Linear increase from 0 to max_lr
2. **Cosine decay** (steps 2000 to end): Cosine annealing from max_lr to 0

$$\text{lr}(t) = \begin{cases} 
\text{max\_lr} \cdot \frac{t}{\text{warmup}} & \text{if } t < \text{warmup} \\
\text{max\_lr} \cdot \frac{1}{2}\left(1 + \cos\left(\pi \cdot \frac{t - \text{warmup}}{T - \text{warmup}}\right)\right) & \text{otherwise}
\end{cases}$$
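
To make the formula concrete, the sketch below evaluates it at a few steps. The total step count `total` is an assumed value for illustration; the paper does not state it.

```python
import math

def lr_at(step, max_lr=2.5e-4, warmup=2000, total=100_000):
    """Evaluate the piecewise schedule; `total` is an assumed step count."""
    if step < warmup:
        return max_lr * step / warmup
    progress = (step - warmup) / (total - warmup)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

for step in (0, 1000, 2000, 51_000, 100_000):
    print(f"step {step:>7,}: lr = {lr_at(step):.3e}")
```

The rate is half of `max_lr` midway through warmup, peaks at step 2000, returns to half at the midpoint of the cosine phase, and reaches zero at the final step.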

In [None]:
def get_lr_scheduler(optimizer, warmup_steps: int, total_steps: int):
    """
    Create the learning rate scheduler from the paper:
    - Linear warmup for first 2000 steps
    - Cosine annealing to 0 after warmup
    """
    def lr_lambda(step):
        if step < warmup_steps:
            # Linear warmup
            return step / warmup_steps
        else:
            # Cosine annealing
            progress = (step - warmup_steps) / (total_steps - warmup_steps)
            return 0.5 * (1.0 + math.cos(math.pi * progress))
    
    return LambdaLR(optimizer, lr_lambda)


def visualize_lr_schedule():
    """Visualize the learning rate schedule from the paper."""
    
    # Simulate training
    total_steps = 100000  # Illustrative value; the real count depends on dataset size and epochs
    warmup_steps = 2000
    max_lr = 2.5e-4
    
    # Create dummy optimizer
    dummy_param = torch.nn.Parameter(torch.zeros(1))
    optimizer = AdamW([dummy_param], lr=max_lr)
    scheduler = get_lr_scheduler(optimizer, warmup_steps, total_steps)
    
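    # Collect LR values by evaluating the schedule's multiplier at each step
    # (get_last_lr() only changes after scheduler.step(), so it cannot be used here)
    lr_lambda = scheduler.lr_lambdas[0]
    steps = list(range(0, total_steps, 100))
    lrs = [max_lr * lr_lambda(step) for step in steps]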
    
    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Full schedule
    ax1 = axes[0]
    ax1.plot(steps, lrs, 'b-', linewidth=2)
    ax1.axvline(x=warmup_steps, color='r', linestyle='--', linewidth=1.5, label='End of warmup')
    ax1.axhline(y=max_lr, color='g', linestyle=':', linewidth=1.5, label=f'Max LR = {max_lr}')
    ax1.fill_between([0, warmup_steps], 0, max_lr, alpha=0.2, color='red', label='Warmup phase')
    ax1.set_xlabel('Training Step', fontsize=12)
    ax1.set_ylabel('Learning Rate', fontsize=12)
    ax1.set_title('GPT Learning Rate Schedule\n"Linear warmup + cosine annealing"', 
                  fontsize=13, fontweight='bold')
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)
    ax1.set_xlim(0, total_steps)
    ax1.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
    
    # Zoom on warmup
    ax2 = axes[1]
    warmup_steps_plot = [s for s in steps if s <= 5000]
    warmup_lrs = lrs[:len(warmup_steps_plot)]
    ax2.plot(warmup_steps_plot, warmup_lrs, 'b-', linewidth=2)
    ax2.axvline(x=warmup_steps, color='r', linestyle='--', linewidth=1.5, label='End of warmup (step 2000)')
    ax2.axhline(y=max_lr, color='g', linestyle=':', linewidth=1.5)
    ax2.fill_between([0, warmup_steps], 0, max_lr, alpha=0.2, color='red')
    ax2.set_xlabel('Training Step', fontsize=12)
    ax2.set_ylabel('Learning Rate', fontsize=12)
    ax2.set_title('Zoom: Warmup Phase\n"Increased linearly from zero over first 2000 updates"', 
                  fontsize=13, fontweight='bold')
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)
    ax2.set_xlim(0, 5000)
    ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
    
    plt.tight_layout()
    plt.show()
    
    print("Why this schedule?")
    print("  - Warmup: Prevents early training instability (large gradients at start)")
    print("  - Cosine: Smooth decay allows smaller, more careful updates towards end of training")
    print("  - Now standard in most transformer training")

visualize_lr_schedule()

### 2.5 The Adam Optimizer

From the paper:

> *"We used the Adam optimization scheme"*

Adam (Adaptive Moment Estimation) maintains:
- First moment estimate (momentum)
- Second moment estimate (adaptive learning rates)

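The update rule:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$
$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
$$\theta_t = \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Here $g_t$ is the gradient at step $t$, $\alpha$ is the learning rate, and $\hat{m}_t$, $\hat{v}_t$ are the bias-corrected moment estimates.

Standard hyperparameters (the PyTorch defaults; the paper does not list them):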
- $\beta_1 = 0.9$
- $\beta_2 = 0.999$
- $\epsilon = 10^{-8}$
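
A single-parameter version of the update makes the mechanics easier to follow. This is a minimal sketch with made-up gradients; `adam_step` is a hypothetical helper, and real training uses the optimizer classes in `torch.optim`.

```python
import math

def adam_step(theta, grad, m, v, t, lr=2.5e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter (t starts at 1)."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)   # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)   # bias-corrected second moment
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

theta, m, v = 1.0, 0.0, 0.0
for t, grad in enumerate([0.5, 0.4, -0.1], start=1):  # made-up gradients
    new_theta, m, v = adam_step(theta, grad, m, v, t)
    print(f"t={t}: theta {theta:.6f} -> {new_theta:.6f}")
    theta = new_theta
```

On the first step the bias correction cancels the moving-average factors, so the parameter moves by almost exactly the learning rate regardless of the gradient's magnitude.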

---

## 3. Implementation: Complete Training Loop

Let's implement the full training procedure. First, we need our model from Part II:

In [None]:
# === Model components from Part II ===

config = GPTConfig()

def gelu_approx(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

class LayerNorm(nn.Module):
    def __init__(self, n_embd, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(n_embd))
        self.beta = nn.Parameter(torch.zeros(n_embd))
    
    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.scale = 1.0 / math.sqrt(self.head_dim)
        
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        
        mask = torch.tril(torch.ones(config.n_positions, config.n_positions))
        self.register_buffer('mask', mask.view(1, 1, config.n_positions, config.n_positions))
    
    def forward(self, x):
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(out))

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
        self.c_proj = nn.Linear(config.n_inner, config.n_embd)
        self.dropout = nn.Dropout(config.resid_pdrop)
    
    def forward(self, x):
        return self.dropout(self.c_proj(gelu_approx(self.c_fc(x))))

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd)
        self.mlp = MLP(config)
    
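    def forward(self, x):
        # Note: LayerNorm is applied before each sub-layer here (pre-LN, as in GPT-2).
        # The original GPT applies it after the residual add: x = ln(x + sublayer(x)).
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x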

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.lm_head.weight = self.wte.weight  # Weight tying
        
        self.apply(self._init_weights)
        n_params = sum(p.numel() for p in self.parameters())
        print(f"GPT initialized: {n_params:,} parameters")
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(self, input_ids, targets=None):
        B, T = input_ids.shape
        
        tok_emb = self.wte(input_ids)
        pos_emb = self.wpe(torch.arange(T, device=input_ids.device))
        x = self.drop(tok_emb + pos_emb)
        
        for block in self.blocks:
            x = block(x)
        
        x = self.ln_f(x)
        logits = self.lm_head(x)
        
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        
        return logits, loss


# Create model
model = GPT(config)

In [None]:
class Trainer:
    """
    GPT Trainer implementing the paper's training procedure.
    
    From Section 4.1:
    - Adam optimizer with max LR 2.5e-4
    - Linear warmup for 2000 steps
    - Cosine annealing to 0
    - Batch size 64, sequence length 512
    """
    
    def __init__(self, model, train_config):
        self.model = model
        self.config = train_config
        
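        # Optimizer: the paper uses Adam with "a modified version of L2 regularization"
        # (w = 0.01) on all non-bias, non-gain weights; AdamW's decoupled weight decay
        # plays that role here.
        decay = [p for p in model.parameters() if p.dim() >= 2]
        no_decay = [p for p in model.parameters() if p.dim() < 2]  # biases, LayerNorm gains
        self.optimizer = AdamW(
            [
                {"params": decay, "weight_decay": train_config.weight_decay},
                {"params": no_decay, "weight_decay": 0.0},
            ],
            lr=train_config.max_lr,
            betas=(0.9, 0.999)
        )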
        
        self.step = 0
        self.losses = []
    
    def get_lr(self, step, total_steps):
        """Get learning rate for current step."""
        warmup = self.config.warmup_steps
        max_lr = self.config.max_lr
        
        if step < warmup:
            # Linear warmup
            return max_lr * step / warmup
        else:
            # Cosine annealing
            progress = (step - warmup) / (total_steps - warmup)
            return max_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
    
    def train_step(self, input_ids, targets, total_steps):
        """Single training step."""
        # Update learning rate
        lr = self.get_lr(self.step, total_steps)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        
        # Forward pass
        self.model.train()
        logits, loss = self.model(input_ids, targets)
        
        # Backward pass
        self.optimizer.zero_grad()
        loss.backward()
        
        # Gradient clipping (common practice)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        
        # Update weights
        self.optimizer.step()
        
        self.step += 1
        self.losses.append(loss.item())
        
        return loss.item(), lr


# Create trainer
trainer = Trainer(model, train_config)
print("\nTrainer initialized with paper's hyperparameters")

In [None]:
def demo_training(model, trainer, num_steps=100):
    """
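    Demonstrate training on synthetic data.
    In practice, you'd use BooksCorpus.
    Note: with num_steps below warmup_steps (2000), the learning rate
    stays in the linear warmup phase for the whole demo.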
    """
    print("Training Demo (synthetic data)")
    print("=" * 60)
    print(f"Running {num_steps} steps to demonstrate training loop...")
    print()
    
    total_steps = num_steps
    batch_size = 8  # Smaller for demo
    seq_len = 64    # Smaller for demo
    
    losses = []
    lrs = []
    
    for step in range(num_steps):
        # Generate a synthetic batch of seq_len + 1 random tokens
        tokens = torch.randint(0, config.vocab_size, (batch_size, seq_len + 1))
        # For language modeling, targets are inputs shifted by 1
        input_ids = tokens[:, :-1]
        targets = tokens[:, 1:].contiguous()  # contiguous so .view(-1) works in the loss
        
        loss, lr = trainer.train_step(input_ids, targets, total_steps)
        losses.append(loss)
        lrs.append(lr)
        
        if (step + 1) % 20 == 0:
            print(f"Step {step+1:4d} | Loss: {loss:.4f} | LR: {lr:.2e}")
    
    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    ax1 = axes[0]
    ax1.plot(losses, 'b-', linewidth=1.5, alpha=0.7)
    ax1.set_xlabel('Step', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title('Training Loss', fontsize=13, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    
    ax2 = axes[1]
    ax2.plot(lrs, 'r-', linewidth=2)
    ax2.set_xlabel('Step', fontsize=12)
    ax2.set_ylabel('Learning Rate', fontsize=12)
    ax2.set_title('Learning Rate Schedule', fontsize=13, fontweight='bold')
    ax2.grid(True, alpha=0.3)
    ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nFinal loss: {losses[-1]:.4f}")
    print(f"Expected random loss: {math.log(config.vocab_size):.4f}")
    print("\nNote: With real data and more steps, loss would decrease significantly.")

demo_training(model, trainer, num_steps=100)

---

## 4. Text Generation

### 4.1 Autoregressive Generation

Once trained, GPT generates text **autoregressively**:
1. Start with a prompt
2. Predict the next token
3. Append the predicted token to the sequence
4. Repeat until done
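
The loop above can be sketched with a lookup table standing in for the model. `NEXT_TOKEN`, `predict_next`, and `generate_greedy` are hypothetical names for this toy example; a real model would return a probability distribution instead of a single token.

```python
# Hypothetical next-token table standing in for a trained model.
NEXT_TOKEN = {
    ("The", "cat"): "sat",
    ("cat", "sat"): "on",
    ("sat", "on"): "the",
    ("on", "the"): "mat",
    ("the", "mat"): "<eos>",
}

def predict_next(context):
    """Toy 'model': looks at the last two tokens only."""
    return NEXT_TOKEN.get(tuple(context[-2:]), "<eos>")

def generate_greedy(prompt, max_new_tokens=10):
    sequence = list(prompt)                  # 1. start with a prompt
    for _ in range(max_new_tokens):
        next_token = predict_next(sequence)  # 2. predict the next token
        if next_token == "<eos>":
            break                            # 4. stop when done
        sequence.append(next_token)          # 3. append it and repeat
    return sequence

print(generate_greedy(["The", "cat"]))
```

Each new token becomes part of the context for the next prediction, which is what makes the process autoregressive.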

### 4.2 Sampling Strategies

The model outputs a probability distribution over the vocabulary. How do we select the next token?

| Strategy | Description | Pros | Cons |
|----------|-------------|------|------|
| **Greedy** | Always pick highest prob | Fast, deterministic | Repetitive, boring |
| **Temperature** | Scale logits before softmax | Controls randomness | Can be incoherent |
| **Top-k** | Sample from k highest probs | Balances diversity | Fixed k may not work for all distributions |
| **Top-p (nucleus)** | Sample from smallest set summing to p | Adaptive | Slightly more compute |
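
To see how top-k and top-p differ, the sketch below filters a small made-up distribution with each. It uses plain Python lists for readability; `top_k_filter` and `top_p_filter` are hypothetical helpers that return the surviving token indices with renormalized probabilities.

```python
def top_k_filter(probs, k):
    """Keep the k most probable tokens and renormalize."""
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}

def top_p_filter(probs, p):
    """Keep the smallest set of top tokens whose cumulative mass reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = [], 0.0
    for i in order:
        keep.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}

probs = [0.50, 0.30, 0.10, 0.05, 0.05]   # made-up next-token distribution
print("top-k (k=2):   ", top_k_filter(probs, k=2))
print("top-p (p=0.85):", top_p_filter(probs, p=0.85))
```

Top-k always keeps the same number of candidates, while top-p keeps more candidates when the distribution is flat and fewer when it is peaked.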

### 4.3 Temperature Scaling

$$P(x_i) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

Where $T$ is temperature:
- $T = 1$: Original distribution
- $T < 1$: More peaked (more confident)
- $T > 1$: More uniform (more random)
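
A small numeric check of these three cases, using made-up logits and a hypothetical `softmax_with_temperature` helper:

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits divided by the temperature T."""
    scaled = [z / temperature for z in logits]
    max_z = max(scaled)                       # subtract max for numerical stability
    exps = [math.exp(z - max_z) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.5, 1.8, 1.2, 0.8]                 # made-up logits
for T in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T}: " + ", ".join(f"{p:.2f}" for p in probs))
```

The probability of the top token grows as $T$ shrinks and falls toward the uniform value as $T$ grows.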

In [None]:
def visualize_sampling_strategies():
    """Visualize different sampling strategies."""
    
    # Simulated logits for 10 tokens
    np.random.seed(42)
    logits = np.array([2.5, 1.8, 1.2, 0.8, 0.5, 0.3, 0.1, -0.2, -0.5, -1.0])
    token_names = ['the', 'a', 'cat', 'dog', 'sat', 'ran', 'big', 'on', 'in', 'of']
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Original probabilities
    probs = np.exp(logits) / np.exp(logits).sum()
    
    ax1 = axes[0, 0]
    bars = ax1.bar(token_names, probs, color='#3498db', edgecolor='black')
    bars[0].set_color('#e74c3c')  # Highest
    ax1.set_ylabel('Probability', fontsize=11)
    ax1.set_title('Original Distribution (T=1)\nGreedy selects "the" (highest)', fontsize=12, fontweight='bold')
    ax1.set_ylim(0, 0.5)
    for i, p in enumerate(probs):
        ax1.text(i, p + 0.01, f'{p:.2f}', ha='center', fontsize=9)
    
    # Low temperature
    T_low = 0.5
    probs_low = np.exp(logits / T_low) / np.exp(logits / T_low).sum()
    
    ax2 = axes[0, 1]
    bars = ax2.bar(token_names, probs_low, color='#3498db', edgecolor='black')
    bars[0].set_color('#e74c3c')
    ax2.set_ylabel('Probability', fontsize=11)
    ax2.set_title(f'Low Temperature (T={T_low})\nMore confident, less diverse', fontsize=12, fontweight='bold')
    ax2.set_ylim(0, 0.8)
    for i, p in enumerate(probs_low):
        if p > 0.01:
            ax2.text(i, p + 0.01, f'{p:.2f}', ha='center', fontsize=9)
    
    # High temperature
    T_high = 2.0
    probs_high = np.exp(logits / T_high) / np.exp(logits / T_high).sum()
    
    ax3 = axes[1, 0]
    bars = ax3.bar(token_names, probs_high, color='#3498db', edgecolor='black')
    ax3.set_ylabel('Probability', fontsize=11)
    ax3.set_title(f'High Temperature (T={T_high})\nMore uniform, more random', fontsize=12, fontweight='bold')
    ax3.set_ylim(0, 0.3)
    for i, p in enumerate(probs_high):
        ax3.text(i, p + 0.005, f'{p:.2f}', ha='center', fontsize=9)
    
    # Top-k sampling
    k = 3
    top_k_probs = probs.copy()
    top_k_probs[k:] = 0
    top_k_probs = top_k_probs / top_k_probs.sum()  # Renormalize
    
    ax4 = axes[1, 1]
    colors = ['#2ecc71' if i < k else '#ecf0f1' for i in range(len(probs))]
    bars = ax4.bar(token_names, top_k_probs, color=colors, edgecolor='black')
    ax4.set_ylabel('Probability', fontsize=11)
    ax4.set_title(f'Top-k Sampling (k={k})\nOnly sample from top {k} tokens', fontsize=12, fontweight='bold')
    ax4.set_ylim(0, 0.6)
    for i, p in enumerate(top_k_probs):
        if p > 0.01:
            ax4.text(i, p + 0.01, f'{p:.2f}', ha='center', fontsize=9)
    
    plt.tight_layout()
    plt.show()

visualize_sampling_strategies()

In [None]:
@torch.no_grad()
def generate(
    model,
    input_ids: torch.Tensor,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
) -> torch.Tensor:
    """
    Autoregressive text generation.
    
    Args:
        model: The GPT model
        input_ids: Starting tokens, shape (batch, seq_len)
        max_new_tokens: How many tokens to generate
        temperature: Sampling temperature (1.0 = normal)
        top_k: If set, only sample from top k tokens
        top_p: If set, use nucleus sampling
    
    Returns:
        Generated token IDs including the input
    """
    model.eval()
    generated = input_ids.clone()
    
    for _ in range(max_new_tokens):
        # Crop to max sequence length if needed
        idx_cond = generated if generated.size(1) <= model.config.n_positions else generated[:, -model.config.n_positions:]
        
        # Forward pass
        logits, _ = model(idx_cond)
        
        # Get logits for the last position
        logits = logits[:, -1, :] / temperature
        
        # Apply top-k filtering
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')
        
        # Apply top-p (nucleus) filtering
        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            
            # Remove tokens with cumulative probability above threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = float('-inf')
        
        # Sample
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        
        # Append
        generated = torch.cat([generated, next_token], dim=1)
    
    return generated


print("Generation function defined.")
print("\nSupported sampling methods:")
print("  - temperature: Controls randomness")
print("  - top_k: Sample from top k tokens")
print("  - top_p: Nucleus sampling")

In [None]:
def demo_generation(model):
    """
    Demonstrate generation with different settings.
    Note: Model is untrained, so outputs will be random.
    """
    print("Generation Demo (untrained model)")
    print("=" * 60)
    print("Note: Model is untrained, outputs will be random tokens.")
    print("With a trained model, this would produce coherent text.")
    print()
    
    # Start with a "prompt" (just random tokens for demo)
    prompt = torch.randint(0, 100, (1, 5))  # Use low token IDs
    
    settings = [
        ("Greedy (temperature=0.001)", {"temperature": 0.001}),
        ("Normal (temperature=1.0)", {"temperature": 1.0}),
        ("Creative (temperature=1.5)", {"temperature": 1.5}),
        ("Top-k (k=10)", {"top_k": 10}),
        ("Top-p (p=0.9)", {"top_p": 0.9}),
    ]
    
    for name, kwargs in settings:
        output = generate(model, prompt, max_new_tokens=10, **kwargs)
        generated_tokens = output[0, prompt.size(1):].tolist()  # Only new tokens
        print(f"{name}:")
        print(f"  Token IDs: {generated_tokens}")
        print()

demo_generation(model)

### 4.4 Visualizing the Generation Process

In [None]:
def visualize_generation_process():
    """Visualize step-by-step autoregressive generation."""
    
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.set_xlim(0, 16)
    ax.set_ylim(0, 10)
    ax.axis('off')
    
    ax.text(8, 9.5, 'Autoregressive Generation: Step by Step', 
            fontsize=16, fontweight='bold', ha='center')
    
    # Define the steps
    steps = [
        ("Step 1: Start with prompt", ['The', 'cat'], None, None),
        ("Step 2: Predict next token", ['The', 'cat'], 'sat', ['sat: 0.35', 'ran: 0.20', 'is: 0.15', '...']),
        ("Step 3: Append and continue", ['The', 'cat', 'sat'], 'on', ['on: 0.40', 'down: 0.25', '...']),
        ("Step 4: Continue...", ['The', 'cat', 'sat', 'on'], 'the', ['the: 0.50', 'a: 0.20', '...']),
        ("Step 5: Final", ['The', 'cat', 'sat', 'on', 'the', 'mat'], None, None),
    ]
    
    for i, (title, tokens, next_tok, probs) in enumerate(steps):
        y = 8 - i * 1.6
        
        # Step title
        ax.text(0.5, y + 0.3, title, fontsize=11, fontweight='bold')
        
        # Tokens (those appended since the previous step are shown as "Just added")
        n_prev = len(steps[i - 1][1]) if i > 0 else len(tokens)
        for j, tok in enumerate(tokens):
            x = 1 + j * 1.3
            color = '#2ecc71' if j >= n_prev else '#3498db'
            rect = FancyBboxPatch((x-0.4, y-0.3), 1.0, 0.5, boxstyle="round,pad=0.02",
                                  facecolor=color, edgecolor='black', linewidth=1.5)
            ax.add_patch(rect)
            ax.text(x + 0.1, y - 0.05, tok, ha='center', va='center', fontsize=10, 
                   color='white', fontweight='bold')
        
        # Next token prediction
        if next_tok:
            x = 1 + len(tokens) * 1.3
            rect = FancyBboxPatch((x-0.4, y-0.3), 1.0, 0.5, boxstyle="round,pad=0.02",
                                  facecolor='#e74c3c', edgecolor='black', linewidth=1.5)
            ax.add_patch(rect)
            ax.text(x + 0.1, y - 0.05, next_tok, ha='center', va='center', fontsize=10, 
                   color='white', fontweight='bold')
            
            # Arrow
            ax.annotate('', xy=(x - 0.5, y - 0.05), xytext=(x - 0.8, y - 0.05),
                        arrowprops=dict(arrowstyle='->', color='#e74c3c', lw=2))
        
        # Probability distribution
        if probs:
            prob_text = '  '.join(probs)
            ax.text(10, y - 0.05, f'P(next): {prob_text}', fontsize=9, 
                   style='italic', color='gray')
    
    # Legend
    ax.text(1, 0.5, 'Legend:', fontsize=11, fontweight='bold')
    for i, (color, label) in enumerate([('#3498db', 'Context'), ('#e74c3c', 'Predicted'), ('#2ecc71', 'Just added')]):
        x = 3 + i * 3
        rect = Rectangle((x, 0.3), 0.5, 0.4, facecolor=color, edgecolor='black')
        ax.add_patch(rect)
        ax.text(x + 0.7, 0.5, label, fontsize=10, va='center')
    
    plt.tight_layout()
    plt.show()

visualize_generation_process()
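
The loop in the figure can be written in a few lines if the model is replaced by a lookup table. The sketch below is a toy: `TOY_MODEL` hard-codes the next-token probabilities shown in the figure (the entry for `'the'` is made up so the sequence can finish), and `toy_generate` picks the most likely token at each step, i.e. greedy decoding. Both names are illustrative assumptions, not part of the paper or of the `generate` function used above.

```python
# Toy stand-in for a language model: maps the last token to P(next token).
# A real GPT conditions on the full context; using only the last token
# keeps this example small. Probabilities for 'cat', 'sat' and 'on' follow
# the figure (the '...' mass is omitted); the 'the' entry is invented.
TOY_MODEL = {
    "cat": {"sat": 0.35, "ran": 0.20, "is": 0.15},
    "sat": {"on": 0.40, "down": 0.25},
    "on": {"the": 0.50, "a": 0.20},
    "the": {"mat": 0.30, "sofa": 0.20},
}


def toy_generate(prompt, max_new_tokens):
    """Greedy autoregressive generation: predict, append, repeat."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        dist = TOY_MODEL.get(tokens[-1])
        if dist is None:  # no known continuation: stop early
            break
        next_token = max(dist, key=dist.get)  # greedy: highest probability
        tokens.append(next_token)             # the new token becomes context
    return tokens


result = toy_generate(["The", "cat"], max_new_tokens=4)
print(result)
```

Swapping the `max(...)` line for a random draw weighted by `dist` would turn this into sampling, which is where temperature, top-k, and top-p come in.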

---

## 5. Perplexity: Measuring Language Model Quality

### 5.1 What is Perplexity?

Perplexity is the standard metric for language models. It measures how "surprised" the model is by the test data:

$$\text{PPL} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log P(u_i | u_1, ..., u_{i-1})\right) = \exp(\mathcal{L})$$

Here $N$ is the number of predicted tokens and $\mathcal{L}$ is the mean per-token cross-entropy loss, computed with the natural logarithm.

### 5.2 Intuition

- **Lower perplexity = Better model**
- Perplexity of $k$ means the model is "as confused as if choosing uniformly among $k$ options"
- Uniform baseline over a vocabulary of size $V$: $\text{PPL} = V$ (40,478 for GPT)
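
As a quick check of these bullets, the sketch below computes perplexity directly from the probabilities a model assigned to the true tokens, using only the standard library. The helper name `perplexity_from_probs` and the example probabilities are illustrative assumptions, not values from the paper.

```python
import math


def perplexity_from_probs(probs):
    """Perplexity from the probabilities assigned to the true tokens."""
    # Average negative log-likelihood per token (natural log, so nats).
    nll = -sum(math.log(p) for p in probs) / len(probs)
    return math.exp(nll)


V = 40478  # GPT vocabulary size

# A model that guesses uniformly assigns 1/V to every true token.
uniform_ppl = perplexity_from_probs([1.0 / V] * 100)

# A model that always assigns 1/18 to the true token behaves as if it
# were choosing among 18 equally likely options.
k18_ppl = perplexity_from_probs([1.0 / 18] * 100)

# Perplexity is a geometric mean, so one badly predicted token hurts a lot.
mixed_ppl = perplexity_from_probs([0.5, 0.5, 0.5, 0.001])

print(f"Uniform over V:     {uniform_ppl:.1f}")
print(f"Always 1/18:        {k18_ppl:.1f}")
print(f"Three 0.5, one 1e-3: {mixed_ppl:.2f}")
```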

### 5.3 GPT's Results

From the paper (Section 4.1, which reports the token-level perplexity on the pre-training corpus):

| Model | BooksCorpus Perplexity |
|-------|----------------------|
| GPT | **18.4** |

This means GPT is "as confused as if choosing among ~18 equally likely tokens" - remarkably good given a vocabulary of over 40,000 tokens!
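
To relate this figure to the loss values printed during training, it helps to convert perplexity back to cross-entropy by inverting $\text{PPL} = \exp(\mathcal{L})$. The arithmetic below uses only the reported perplexity and the vocabulary size; values are rounded.

$$\mathcal{L}_{\text{GPT}} = \ln 18.4 \approx 2.91 \text{ nats/token} \quad (\log_2 18.4 \approx 4.20 \text{ bits/token})$$

$$\mathcal{L}_{\text{uniform}} = \ln 40{,}478 \approx 10.61 \text{ nats/token}$$

So a loss near 10.6 corresponds to uniform guessing, while a loss near 2.9 corresponds to the perplexity reported in the paper.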

In [None]:
def compute_perplexity(model, input_ids, targets):
    """
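    Compute perplexity as exp(mean cross-entropy loss).

    Assumes model(input_ids, targets) returns (logits, loss), with loss
    averaged over all target positions.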
    """
    model.eval()
    with torch.no_grad():
        _, loss = model(input_ids, targets)
    perplexity = torch.exp(loss)
    return perplexity.item(), loss.item()


# Demo: inputs and targets are independent random tokens, so this only checks the untrained baseline
test_input = torch.randint(0, config.vocab_size, (4, 128))
test_target = torch.randint(0, config.vocab_size, (4, 128))

ppl, loss = compute_perplexity(model, test_input, test_target)

print("Perplexity Demo")
print("=" * 50)
print(f"Loss:       {loss:.4f}")
print(f"Perplexity: {ppl:.2f}")
print(f"\nFor reference:")
print(f"  Random baseline PPL:  {config.vocab_size} (vocab size)")
print(f"  GPT paper result:     18.4 (on BooksCorpus)")
print(f"\nOur untrained model:    {ppl:.2f} (close to vocab size as expected)")
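
When evaluating on more than one batch, perplexity is normally computed from the token-weighted average loss, not by averaging per-batch perplexities. The sketch below illustrates the difference with made-up `(mean_loss, num_tokens)` pairs; the function name `corpus_perplexity` and all numbers are illustrative assumptions, not results from the paper.

```python
import math


def corpus_perplexity(batch_stats):
    """Perplexity over a whole evaluation set.

    batch_stats: iterable of (mean_loss, num_tokens) pairs, where mean_loss
    is the average cross-entropy (in nats) over the tokens of one batch.
    """
    total_nll = 0.0
    total_tokens = 0
    for mean_loss, num_tokens in batch_stats:
        total_nll += mean_loss * num_tokens  # recover the batch's summed NLL
        total_tokens += num_tokens
    return math.exp(total_nll / total_tokens)


# Hypothetical per-batch results: the last batch is shorter and harder.
stats = [(2.8, 512), (3.0, 512), (4.5, 64)]

ppl_weighted = corpus_perplexity(stats)
ppl_naive = sum(math.exp(loss) for loss, _ in stats) / len(stats)

print(f"Token-weighted perplexity:      {ppl_weighted:.2f}")
print(f"Mean of per-batch perplexities: {ppl_naive:.2f}")
```

With `compute_perplexity` above, `mean_loss` would be the returned loss and `num_tokens` would be `targets.numel()`, assuming the model's loss is a mean over all target positions.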

---

## 6. Summary

### 6.1 Key Training Details (All from Paper Section 4.1)

| Aspect | Value | Paper Quote |
|--------|-------|-------------|
| **Dataset** | BooksCorpus | "over 7,000 unique unpublished books" |
| **Tokenization** | BPE | "40,000 merges" |
| **Batch size** | 64 | "minibatches of 64" |
| **Sequence length** | 512 | "contiguous sequences of 512 tokens" |
| **Epochs** | 100 | "train for 100 epochs" |
| **Optimizer** | Adam | "Adam optimization scheme" |
| **Max LR** | 2.5e-4 | "max learning rate of 2.5e-4" |
| **Warmup** | 2000 steps | "increased linearly... over the first 2000 updates" |
| **Schedule** | Cosine | "annealed to 0 using a cosine schedule" |
| **Init** | N(0, 0.02) | "weight initialization of N(0, 0.02)" |

### 6.2 The Language Modeling Objective

$$L_1(\mathcal{U}) = \sum_i \log P(u_i | u_{i-k}, ..., u_{i-1}; \Theta)$$

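- Predict each token given the previous context
- In practice, minimize the mean cross-entropy, which is $-L_1$ divided by the number of predicted tokens
- Teacher forcing with causal masking lets every position in a sequence be trained in parallel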

### 6.3 Generation

- Autoregressive: predict one token, append, repeat
- Sampling strategies: greedy, temperature, top-k, top-p
- Temperature controls the trade-off between diversity and coherence

### 6.4 What's Next

**Part IV**: Fine-tuning
- Task-specific input transformations
- The auxiliary loss trick
- Results on downstream tasks

---

## References

1. Radford et al. (2018). [Improving Language Understanding by Generative Pre-Training](https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf)
2. Kingma & Ba (2014). [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
3. Sennrich et al. (2016). [Neural Machine Translation of Rare Words with Subword Units](https://arxiv.org/abs/1508.07909) (BPE)
4. Holtzman et al. (2019). [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) (Top-p sampling)