# GPT-2 Research Paper | Part V

## Complete Implementation: From Scratch to Generation

---

**Paper:** [Language Models are Unsupervised Multitask Learners](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)

**Authors:** Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever (OpenAI, 2019)

---

## The Final Piece: Working Code

This notebook brings everything together into a **complete, working implementation**.

We'll cover:

1. **Full GPT-2 Architecture** - Every component explained and implemented
2. **Weight Loading** - Load pretrained weights from Hugging Face
3. **Tokenization** - Using tiktoken for byte-level BPE
4. **Text Generation** - Multiple decoding strategies
5. **Zero-Shot Prompting** - Implementing tasks from the paper
6. **Fine-Tuning** - Adapt GPT-2 for custom tasks
7. **Performance Optimization** - Making it fast

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, Rectangle, Circle
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
import numpy as np
import math
import json
import os
from dataclasses import dataclass
from typing import Optional, Tuple, List, Dict, Union
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

plt.style.use('default')
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['font.size'] = 10

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

---

## 1. Complete GPT-2 Configuration

### 1.1 The Config Class

First, let's define a configuration class that holds all hyperparameters.

In [None]:
@dataclass
class GPT2Config:
    """
    GPT-2 Model Configuration.
    
    Contains all hyperparameters needed to build a GPT-2 model.
    Default values correspond to GPT-2 Small (124M parameters).
    
    Attributes:
        vocab_size: Size of vocabulary (50257 for GPT-2)
        n_positions: Maximum sequence length (1024 for GPT-2)
        n_embd: Embedding dimension
        n_layer: Number of transformer blocks
        n_head: Number of attention heads
        n_inner: FFN inner dimension (default: 4 * n_embd)
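        activation: Activation function name ('gelu' or 'gelu_new'); informational
            only, since the MLP below always calls gelu_new
        resid_pdrop: Dropout probability for residual connections
        embd_pdrop: Dropout probability for embeddings
        attn_pdrop: Dropout probability for attention weights
        layer_norm_eps: Epsilon for layer normalization
        initializer_range: Std of the normal distribution used for weight init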
    """
    # Model architecture
    vocab_size: int = 50257
    n_positions: int = 1024
    n_embd: int = 768
    n_layer: int = 12
    n_head: int = 12
    n_inner: Optional[int] = None  # Defaults to 4 * n_embd
    activation: str = 'gelu_new'
    
    # Regularization
    resid_pdrop: float = 0.1
    embd_pdrop: float = 0.1
    attn_pdrop: float = 0.1
    layer_norm_eps: float = 1e-5
    
    # Initialization
    initializer_range: float = 0.02
    
    def __post_init__(self):
        if self.n_inner is None:
            self.n_inner = 4 * self.n_embd
    
    @property
    def head_dim(self) -> int:
        """Dimension of each attention head."""
        return self.n_embd // self.n_head
    
    @classmethod
    def from_pretrained(cls, model_name: str) -> 'GPT2Config':
        """
        Create config for a pretrained model.
        
        Args:
            model_name: One of 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'
        """
        configs = {
            'gpt2': dict(n_layer=12, n_head=12, n_embd=768),
            'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024),
            'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280),
            'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600),
        }
        if model_name not in configs:
            raise ValueError(f"Unknown model: {model_name}. Choose from {list(configs.keys())}")
        return cls(**configs[model_name])


# Show all configurations
print("GPT-2 Model Configurations:")
print("=" * 70)
for name in ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']:
    config = GPT2Config.from_pretrained(name)
    params = config.vocab_size * config.n_embd  # Token embeddings (lm_head is tied, adds nothing)
    params += config.n_positions * config.n_embd  # Position embeddings
    # Per layer: attention (4n² + 4n) + MLP (8n² + 5n) + two LayerNorms (4n) = 12n² + 13n
    params += config.n_layer * (12 * config.n_embd ** 2 + 13 * config.n_embd)
    params += 2 * config.n_embd  # Final LN
    print(f"{name:<15} layers={config.n_layer:<3} heads={config.n_head:<3} "
          f"d_model={config.n_embd:<5} params≈{params/1e6:.0f}M")
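
The loop above folds every term into one formula. A minimal sketch that splits the same count by component makes it easier to see where the parameters live. The helper name `gpt2_param_breakdown` is ours (not a library function), and it assumes tied input/output embeddings and biases on every linear layer, as in the model built below.

```python
from typing import Dict


def gpt2_param_breakdown(n_layer: int, n_embd: int,
                         vocab_size: int = 50257,
                         n_positions: int = 1024) -> Dict[str, int]:
    """Hypothetical helper: parameter count per component of a GPT-2 model."""
    d = n_embd
    attn = (d * 3 * d + 3 * d) + (d * d + d)      # c_attn + c_proj, weights and biases
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)   # c_fc + c_proj, weights and biases
    norms = 2 * (2 * d)                           # ln_1 and ln_2, weight + bias each
    return {
        'token_emb': vocab_size * d,              # shared with lm_head (weight tying)
        'pos_emb': n_positions * d,
        'blocks': n_layer * (attn + mlp + norms),
        'final_ln': 2 * d,
    }


breakdown = gpt2_param_breakdown(n_layer=12, n_embd=768)
total = sum(breakdown.values())
for name, count in breakdown.items():
    print(f"{name:<10} {count:>12,}  ({count / total:.1%})")
print(f"{'total':<10} {total:>12,}")
```

For GPT-2 Small the token embedding alone is roughly 31% of the total, which is why weight tying with the output head saves so much.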

---

## 2. Building Blocks

### 2.1 GELU Activation

GPT-2 uses the GELU (Gaussian Error Linear Unit) activation function.

In [None]:
def gelu_new(x: torch.Tensor) -> torch.Tensor:
    """
    GELU activation function (new/fast approximation).
    
    This is the tanh approximation of GELU used in the original GPT-2 code;
    it differs slightly from the exact erf-based GELU, x * Φ(x).
    
    Formula:
        GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
    
    Args:
        x: Input tensor
        
    Returns:
        Activated tensor
    """
    return 0.5 * x * (1.0 + torch.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
    ))


# Visualize GELU
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

x = torch.linspace(-4, 4, 1000)

# GELU vs ReLU
axes[0].plot(x.numpy(), gelu_new(x).numpy(), 'b-', linewidth=2.5, label='GELU')
axes[0].plot(x.numpy(), F.relu(x).numpy(), 'r--', linewidth=2, label='ReLU')
axes[0].axhline(y=0, color='gray', linestyle='-', alpha=0.3)
axes[0].axvline(x=0, color='gray', linestyle='-', alpha=0.3)
axes[0].set_xlabel('x', fontsize=11)
axes[0].set_ylabel('f(x)', fontsize=11)
axes[0].set_title('GELU vs ReLU', fontsize=12, fontweight='bold')
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)
axes[0].set_xlim(-4, 4)
axes[0].set_ylim(-1, 4)

# Gradient
x_grad = x.clone().requires_grad_(True)
y = gelu_new(x_grad)
grad = torch.autograd.grad(y.sum(), x_grad)[0]

x_relu = x.clone().requires_grad_(True)
y_relu = F.relu(x_relu)
grad_relu = torch.autograd.grad(y_relu.sum(), x_relu)[0]

axes[1].plot(x.numpy(), grad.detach().numpy(), 'b-', linewidth=2.5, label='GELU gradient')
axes[1].plot(x.numpy(), grad_relu.detach().numpy(), 'r--', linewidth=2, label='ReLU gradient')
axes[1].axhline(y=0, color='gray', linestyle='-', alpha=0.3)
axes[1].axhline(y=1, color='gray', linestyle='--', alpha=0.3)
axes[1].set_xlabel('x', fontsize=11)
axes[1].set_ylabel('df/dx', fontsize=11)
axes[1].set_title('Gradient Comparison', fontsize=12, fontweight='bold')
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)
axes[1].set_xlim(-4, 4)
axes[1].set_ylim(-0.5, 1.5)

plt.tight_layout()
plt.show()

print("\nKey difference: GELU has smooth gradients everywhere (no hard cutoff at 0)")
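
The formula in the docstring is an approximation of the exact GELU, x·Φ(x), where Φ is the standard normal CDF. A small stdlib-only sketch (function names are illustrative) compares the two on a grid to show how small the gap is:

```python
import math


def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), with the normal CDF written via erf
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    # Scalar version of the same tanh approximation as gelu_new above
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


xs = [i / 100 for i in range(-600, 601)]  # grid over [-6, 6]
max_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)
print(f"max |exact - tanh approx| on [-6, 6]: {max_err:.2e}")
```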

### 2.2 Layer Normalization

In [None]:
class LayerNorm(nn.Module):
    """
    Layer Normalization with learnable parameters.
    
    Normalizes the last dimension of the input tensor.
    
    Formula:
        y = (x - mean) / sqrt(var + eps) * weight + bias
    
    Args:
        normalized_shape: Size of the last dimension
        eps: Small constant for numerical stability
    """
    
    def __init__(self, normalized_shape: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute mean and variance along last dimension
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        
        # Normalize
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        
        # Scale and shift
        return self.weight * x_norm + self.bias


# Test LayerNorm
ln = LayerNorm(768)
x = torch.randn(2, 10, 768)
y = ln(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Output mean (should be ~0): {y.mean().item():.6f}")
print(f"Output std (should be ~1): {y.std().item():.6f}")
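
The `unbiased=False` argument matters: LayerNorm divides the squared deviations by n, not n - 1, which is also what `torch.nn.LayerNorm` does. The pure-Python sketch below (no learnable scale or shift) normalizes a 4-element vector and compares both variance estimates. For n = 768, as in the test above, the two differ only by a factor of 768/767, which is why the unbiased `y.std()` still prints close to 1.

```python
import math
from typing import List


def layer_norm(x: List[float], eps: float = 1e-5) -> List[float]:
    """Normalize one vector (no weight/bias), using the biased variance."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # divide by n, like unbiased=False
    return [(v - mean) / math.sqrt(var + eps) for v in x]


x = [1.0, 2.0, 3.0, 4.0]
y = layer_norm(x)
n = len(y)
biased_var = sum(v * v for v in y) / n          # ~1 by construction
unbiased_var = sum(v * v for v in y) / (n - 1)  # ~n/(n-1), what var(unbiased=True) reports
print(y, biased_var, unbiased_var)
```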

### 2.3 Multi-Head Causal Self-Attention

In [None]:
class CausalSelfAttention(nn.Module):
    """
    Multi-Head Causal Self-Attention.
    
    This implements the attention mechanism used in GPT-2 with:
    - Combined Q/K/V projection (efficient single matmul)
    - Causal masking (tokens can only attend to previous tokens)
    - Multi-head attention
    - Scaled dot-product attention
    
    Args:
        config: GPT2Config object
    """
    
    def __init__(self, config: GPT2Config):
        super().__init__()
        
        assert config.n_embd % config.n_head == 0, \
            f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
        
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.head_dim
        
        # Combined Q, K, V projection: [n_embd] -> [3 * n_embd]
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        
        # Output projection: [n_embd] -> [n_embd]
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        
        # Dropout
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        
        # Causal mask (lower triangular)
        # Register as buffer so it's moved with the model but not trained
        self.register_buffer(
            'bias',
            torch.tril(torch.ones(config.n_positions, config.n_positions))
                .view(1, 1, config.n_positions, config.n_positions)
        )
    
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Forward pass.
        
        Args:
            x: Input tensor [batch, seq_len, n_embd]
            attention_mask: Optional attention mask
            use_cache: Whether to return key/value for caching
            past_key_value: Cached key/value from previous forward pass
            
        Returns:
            output: Attention output [batch, seq_len, n_embd]
            present_key_value: Cached key/value (if use_cache=True)
        """
        B, T, C = x.shape  # batch, sequence length, embedding dim
        
        # Calculate Q, K, V in one matmul
        qkv = self.c_attn(x)  # [B, T, 3*C]
        q, k, v = qkv.split(self.n_embd, dim=2)  # Each: [B, T, C]
        
        # Reshape for multi-head attention
        # [B, T, C] -> [B, T, n_head, head_dim] -> [B, n_head, T, head_dim]
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        
        # Handle cached key/values for efficient generation
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        
        present_key_value = (k, v) if use_cache else None
        
        # Attention scores: [B, n_head, T, T_kv]
        # Scale by 1/sqrt(head_dim) so score variance stays near 1 and softmax does not saturate
        scale = 1.0 / math.sqrt(self.head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        
        # Apply causal mask
        T_kv = k.shape[2]
        causal_mask = self.bias[:, :, T_kv - T:T_kv, :T_kv]
        attn = attn.masked_fill(causal_mask == 0, float('-inf'))
        
        # Apply additional attention mask if provided
        if attention_mask is not None:
            attn = attn + attention_mask
        
        # Softmax and dropout
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        
        # Apply attention to values: [B, n_head, T, head_dim]
        out = torch.matmul(attn, v)
        
        # Reshape back: [B, n_head, T, head_dim] -> [B, T, C]
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        
        # Output projection and dropout
        out = self.resid_dropout(self.c_proj(out))
        
        return out, present_key_value


# Test attention
config = GPT2Config()
attn = CausalSelfAttention(config)
x = torch.randn(2, 10, 768)
out, _ = attn(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Causal mask shape: {attn.bias.shape}")
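
The slice `self.bias[:, :, T_kv - T:T_kv, :T_kv]` is the subtle part once a KV cache is in play: the T new queries sit at absolute positions past_len through past_len + T - 1, so their mask rows start at T_kv - T. A pure-Python sketch of that indexing (helper names are illustrative) covers the prefill and single-step cases:

```python
from typing import List


def causal_mask(n_positions: int) -> List[List[int]]:
    # Row i is a query position, column j a key position; 1 means "may attend"
    return [[1 if j <= i else 0 for j in range(n_positions)] for i in range(n_positions)]


def mask_slice(full: List[List[int]], past_len: int, T: int) -> List[List[int]]:
    # Same indexing as self.bias[:, :, T_kv - T:T_kv, :T_kv] in CausalSelfAttention
    T_kv = past_len + T
    return [row[:T_kv] for row in full[T_kv - T:T_kv]]


full = causal_mask(6)
prefill = mask_slice(full, past_len=0, T=3)  # first call: 3 prompt tokens, empty cache
step = mask_slice(full, past_len=3, T=1)     # next call: 1 new token, 3 cached keys
print(prefill)
print(step)
```

During a cached step the single new query may attend to every cached key, so its row is all ones.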

### 2.4 Feed-Forward Network (MLP)

In [None]:
class MLP(nn.Module):
    """
    Position-wise Feed-Forward Network.
    
    Two linear transformations with GELU activation:
        FFN(x) = Dropout(Linear₂(GELU(Linear₁(x))))
    
    The inner dimension is typically 4× the embedding dimension.
    
    Args:
        config: GPT2Config object
    """
    
    def __init__(self, config: GPT2Config):
        super().__init__()
        
        # Expand: n_embd -> n_inner (4 * n_embd)
        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
        
        # Contract: n_inner -> n_embd
        self.c_proj = nn.Linear(config.n_inner, config.n_embd)
        
        # Dropout
        self.dropout = nn.Dropout(config.resid_pdrop)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = gelu_new(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


# Test MLP
mlp = MLP(config)
x = torch.randn(2, 10, 768)
out = mlp(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Inner dimension: {config.n_inner}")

### 2.5 Transformer Block

In [None]:
class Block(nn.Module):
    """
    GPT-2 Transformer Block with Pre-LayerNorm.
    
    Architecture:
        x = x + Attention(LayerNorm(x))  # Pre-LN for attention
        x = x + MLP(LayerNorm(x))        # Pre-LN for MLP
    
    This is different from the original Transformer (Post-LN):
        x = LayerNorm(x + Attention(x))  # Post-LN
    
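    Pre-LN leaves an unnormalized identity path from each block's input to its
    output, which tends to make deep stacks easier to optimize. GPT-2 also adds
    a final LayerNorm (ln_f) after the last block.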
    
    Args:
        config: GPT2Config object
    """
    
    def __init__(self, config: GPT2Config):
        super().__init__()
        
        # Pre-LayerNorm for attention
        self.ln_1 = LayerNorm(config.n_embd, eps=config.layer_norm_eps)
        self.attn = CausalSelfAttention(config)
        
        # Pre-LayerNorm for MLP
        self.ln_2 = LayerNorm(config.n_embd, eps=config.layer_norm_eps)
        self.mlp = MLP(config)
    
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Forward pass.
        
        Args:
            x: Input tensor [batch, seq_len, n_embd]
            attention_mask: Optional attention mask
            use_cache: Whether to cache key/values
            past_key_value: Cached key/values from previous step
            
        Returns:
            output: Block output [batch, seq_len, n_embd]
            present_key_value: Cached key/values (if use_cache=True)
        """
        # Pre-LN Attention: x = x + Attn(LN(x))
        attn_out, present_key_value = self.attn(
            self.ln_1(x),
            attention_mask=attention_mask,
            use_cache=use_cache,
            past_key_value=past_key_value,
        )
        x = x + attn_out
        
        # Pre-LN MLP: x = x + MLP(LN(x))
        x = x + self.mlp(self.ln_2(x))
        
        return x, present_key_value


# Test Block
block = Block(config)
x = torch.randn(2, 10, 768)
out, _ = block(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
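
One common intuition for the Pre-LN choice: if every sub-layer outputs zero (a rough stand-in for a network whose residual projections start small), a Pre-LN stack is exactly the identity, while a Post-LN stack still renormalizes the residual stream at every block. The pure-Python sketch below uses a hypothetical `zero_sublayer` in place of attention or the MLP. It illustrates the identity path only; it is not a gradient-flow analysis.

```python
import math
from typing import Callable, List

Vec = List[float]


def layer_norm(x: Vec, eps: float = 1e-5) -> Vec:
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def pre_ln_block(x: Vec, sublayer: Callable[[Vec], Vec]) -> Vec:
    # GPT-2 style: x + f(LN(x)); the residual path itself is never normalized
    h = sublayer(layer_norm(x))
    return [a + b for a, b in zip(x, h)]


def post_ln_block(x: Vec, sublayer: Callable[[Vec], Vec]) -> Vec:
    # Original Transformer style: LN(x + f(x)); the residual sum is normalized
    h = sublayer(x)
    return layer_norm([a + b for a, b in zip(x, h)])


def zero_sublayer(v: Vec) -> Vec:
    return [0.0] * len(v)


x = [1.0, 2.0, 3.0, 4.0]
pre, post = x, x
for _ in range(48):  # as many blocks as GPT-2 XL
    pre = pre_ln_block(pre, zero_sublayer)
    post = post_ln_block(post, zero_sublayer)
print("pre-LN :", pre)
print("post-LN:", post)
```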

---

## 3. Complete GPT-2 Model

In [None]:
class GPT2Model(nn.Module):
    """
    Complete GPT-2 Language Model.
    
    Architecture:
        1. Token Embedding + Position Embedding
        2. Dropout
        3. N × Transformer Blocks (Pre-LN)
        4. Final LayerNorm
        5. Language Model Head (tied with token embedding)
    
    Args:
        config: GPT2Config object
    """
    
    def __init__(self, config: GPT2Config):
        super().__init__()
        self.config = config
        
        # Token embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        
        # Position embeddings (learned, not sinusoidal)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        
        # Dropout after embeddings
        self.drop = nn.Dropout(config.embd_pdrop)
        
        # Transformer blocks
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        
        # Final layer norm (GPT-2 specific)
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_eps)
        
        # Language model head (weight tied with wte)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
        # Weight tying
        self.lm_head.weight = self.wte.weight
        
        # Initialize weights
        self.apply(self._init_weights)
        
        # Scale residual projections (c_proj) by 1/sqrt(N), N = 2 * n_layer residual layers,
        # following the initialization described in the GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(
                    p, 
                    mean=0.0, 
                    std=config.initializer_range / math.sqrt(2 * config.n_layer)
                )
        
        # Report number of parameters
        n_params = sum(p.numel() for p in self.parameters())
        print(f"GPT-2 initialized with {n_params:,} parameters")
    
    def _init_weights(self, module: nn.Module):
        """Initialize weights."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
    
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass.
        
        Args:
            input_ids: Token IDs [batch, seq_len]
            attention_mask: Attention mask [batch, seq_len]
            labels: Labels for language modeling loss [batch, seq_len]
            use_cache: Whether to return cached key/values
            past_key_values: Cached key/values from previous forward
            
        Returns:
            Dictionary containing:
                - logits: Vocabulary logits [batch, seq_len, vocab_size]
                - loss: Language modeling loss (if labels provided)
                - past_key_values: Cached key/values (if use_cache=True)
        """
        B, T = input_ids.shape
        
        # Determine position indices
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
        else:
            past_length = 0
        
        position_ids = torch.arange(
            past_length, past_length + T, 
            device=input_ids.device
        ).unsqueeze(0)
        
        # Get embeddings
        token_emb = self.wte(input_ids)  # [B, T, n_embd]
        pos_emb = self.wpe(position_ids)  # [1, T, n_embd]
        
        # Combine embeddings
        hidden_states = self.drop(token_emb + pos_emb)
        
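        # Process padding mask: 1 = attend, 0 = ignore (length must equal past_length + T)
        if attention_mask is not None:
            # Convert [B, T_kv] mask to [B, 1, 1, T_kv] so it broadcasts over heads and queries
            attention_mask = attention_mask.view(B, 1, 1, -1).to(hidden_states.dtype)
            # Masked positions get a large negative bias before softmax
            attention_mask = (1.0 - attention_mask) * -10000.0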
        
        # Forward through transformer blocks
        presents = [] if use_cache else None
        
        for i, block in enumerate(self.h):
            past_kv = past_key_values[i] if past_key_values is not None else None
            
            hidden_states, present_kv = block(
                hidden_states,
                attention_mask=attention_mask,
                use_cache=use_cache,
                past_key_value=past_kv,
            )
            
            if use_cache:
                presents.append(present_kv)
        
        # Final layer norm
        hidden_states = self.ln_f(hidden_states)
        
        # Language model head
        logits = self.lm_head(hidden_states)  # [B, T, vocab_size]
        
        # Calculate loss if labels provided
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
        
        return {
            'logits': logits,
            'loss': loss,
            'past_key_values': presents,
        }
    
    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        do_sample: bool = True,
    ) -> torch.Tensor:
        """
        Generate text autoregressively.
        
        Args:
            input_ids: Starting token IDs [batch, seq_len]
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_k: Keep only top k tokens with highest probability
            top_p: Keep the smallest set of most likely tokens whose cumulative
                probability exceeds top_p (nucleus sampling)
            do_sample: Whether to sample (True) or take argmax (False)
            
        Returns:
            Generated token IDs [batch, seq_len + max_new_tokens]
        """
        self.eval()
        
        for _ in range(max_new_tokens):
            # Crop to max context length if needed
            idx_cond = input_ids if input_ids.shape[1] <= self.config.n_positions \
                       else input_ids[:, -self.config.n_positions:]
            
            # Forward pass
            outputs = self(idx_cond)
            logits = outputs['logits'][:, -1, :]  # [B, vocab_size]
            
            # Apply temperature
            logits = logits / temperature
            
            # Apply top-k filtering
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')
            
            # Apply top-p (nucleus) filtering
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                
                # Remove tokens with cumulative probability above threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')
            
            # Sample or argmax
            probs = F.softmax(logits, dim=-1)
            
            if do_sample:
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(probs, dim=-1, keepdim=True)
            
            # Append to sequence
            input_ids = torch.cat([input_ids, next_token], dim=1)
        
        return input_ids


# Create model
print("Creating GPT-2 Small:")
model = GPT2Model(GPT2Config())

# Test forward pass
x = torch.randint(0, 50257, (2, 10))
outputs = model(x, labels=x)

print(f"\nInput shape: {x.shape}")
print(f"Logits shape: {outputs['logits'].shape}")
print(f"Loss: {outputs['loss'].item():.4f}")
print(f"Expected loss (random): {math.log(50257):.4f}")
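
Two details of the loss above are worth checking by hand: the one-position shift between logits and labels, and why an untrained model should score about log(V). A stdlib sketch with illustrative token IDs (helper names are ours):

```python
import math
from typing import List, Tuple


def shift_for_lm(tokens: List[int]) -> List[Tuple[int, int]]:
    """Pair the token seen at each position with the label it is trained to predict."""
    inputs = tokens[:-1]  # positions kept by logits[..., :-1, :] (last one has no target)
    labels = tokens[1:]   # mirrors labels[..., 1:] (first token is never a target)
    return list(zip(inputs, labels))


def uniform_cross_entropy(vocab_size: int) -> float:
    # A model assigning 1/V to every token pays -log(1/V) = log(V) at each position
    return -math.log(1.0 / vocab_size)


pairs = shift_for_lm([10, 20, 30, 40])
ce = uniform_cross_entropy(50257)
print(pairs)
print(f"{ce:.4f}")
```

With small initial weights the logits are close to zero, so the softmax is close to uniform and the measured loss lands near this value.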

---

## 4. Loading Pretrained Weights

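Let's load the official GPT-2 weights from Hugging Face. The checkpoint stores its projection layers as `Conv1D` modules, whose weight matrices are the transpose of `nn.Linear`'s, so the loader transposes them while copying.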

In [None]:
def load_pretrained_gpt2(model_name: str = 'gpt2') -> GPT2Model:
    """
    Load pretrained GPT-2 weights from Hugging Face.
    
    Args:
        model_name: One of 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'
        
    Returns:
        GPT2Model with pretrained weights
    """
    from transformers import GPT2LMHeadModel
    
    print(f"Loading {model_name} from Hugging Face...")
    
    # Load HF model
    hf_model = GPT2LMHeadModel.from_pretrained(model_name)
    hf_state = hf_model.state_dict()
    
    # Create our model
    config = GPT2Config.from_pretrained(model_name)
    model = GPT2Model(config)
    
    # Map HF keys to our keys
    key_mapping = {
        'transformer.wte.weight': 'wte.weight',
        'transformer.wpe.weight': 'wpe.weight',
        'transformer.ln_f.weight': 'ln_f.weight',
        'transformer.ln_f.bias': 'ln_f.bias',
        'lm_head.weight': 'lm_head.weight',
    }
    
    # Add block mappings
    for i in range(config.n_layer):
        prefix_hf = f'transformer.h.{i}'
        prefix_ours = f'h.{i}'
        
        key_mapping.update({
            f'{prefix_hf}.ln_1.weight': f'{prefix_ours}.ln_1.weight',
            f'{prefix_hf}.ln_1.bias': f'{prefix_ours}.ln_1.bias',
            f'{prefix_hf}.attn.c_attn.weight': f'{prefix_ours}.attn.c_attn.weight',
            f'{prefix_hf}.attn.c_attn.bias': f'{prefix_ours}.attn.c_attn.bias',
            f'{prefix_hf}.attn.c_proj.weight': f'{prefix_ours}.attn.c_proj.weight',
            f'{prefix_hf}.attn.c_proj.bias': f'{prefix_ours}.attn.c_proj.bias',
            f'{prefix_hf}.ln_2.weight': f'{prefix_ours}.ln_2.weight',
            f'{prefix_hf}.ln_2.bias': f'{prefix_ours}.ln_2.bias',
            f'{prefix_hf}.mlp.c_fc.weight': f'{prefix_ours}.mlp.c_fc.weight',
            f'{prefix_hf}.mlp.c_fc.bias': f'{prefix_ours}.mlp.c_fc.bias',
            f'{prefix_hf}.mlp.c_proj.weight': f'{prefix_ours}.mlp.c_proj.weight',
            f'{prefix_hf}.mlp.c_proj.bias': f'{prefix_ours}.mlp.c_proj.bias',
        })
    
    # Copy weights
    new_state = {}
    for hf_key, our_key in key_mapping.items():
        if hf_key in hf_state:
            weight = hf_state[hf_key]
            
            # HF's Conv1D stores weights as [in_features, out_features];
            # nn.Linear expects [out_features, in_features], so transpose
            if hf_key.endswith(('c_attn.weight', 'c_proj.weight', 'c_fc.weight')):
                weight = weight.t()
            
            new_state[our_key] = weight
    
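    # Load into model. strict=False because the causal-mask buffers (h.{i}.attn.bias)
    # are not in the mapping; lm_head.weight is tied to wte.weight. Anything else
    # missing or unexpected indicates a mapping bug, so fail loudly instead of silently.
    result = model.load_state_dict(new_state, strict=False)
    missing = [k for k in result.missing_keys
               if not k.endswith('.attn.bias') and k != 'lm_head.weight']
    if missing or result.unexpected_keys:
        raise RuntimeError(f"Missing: {missing}, unexpected: {result.unexpected_keys}")
    
    print(f"Loaded {len(new_state)} weight tensors")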
    return model


# Load pretrained model
try:
    model = load_pretrained_gpt2('gpt2')
    model = model.to(device)
    model.eval()
    print("\nModel loaded successfully!")
except Exception as e:
    print(f"Could not load pretrained weights: {e}")
    print("Using randomly initialized model instead.")
    model = GPT2Model(GPT2Config()).to(device)
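
The transpose in `load_pretrained_gpt2` can be checked on a tiny example. A stdlib sketch (helper names are illustrative) computes one projection both ways: Conv1D-style with an [in, out] weight, and Linear-style with the transposed [out, in] weight:

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def conv1d_forward(x: Matrix, w_in_out: Matrix, bias: List[float]) -> Matrix:
    # Hugging Face Conv1D convention: weight is [in_features, out_features], y = x @ W + b
    return [[v + c for v, c in zip(row, bias)] for row in matmul(x, w_in_out)]


def linear_forward(x: Matrix, w_out_in: Matrix, bias: List[float]) -> Matrix:
    # nn.Linear convention: weight is [out_features, in_features], y = x @ W.T + b
    return [[v + c for v, c in zip(row, bias)] for row in matmul(x, transpose(w_out_in))]


x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]       # 2 tokens, 3 input features
w_hf = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # [in=3, out=2], as stored by Conv1D
b = [0.01, -0.02]

y_hf = conv1d_forward(x, w_hf, b)
y_ours = linear_forward(x, transpose(w_hf), b)  # the .t() applied while loading
print(y_hf == y_ours, y_hf)
```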

---

## 5. Tokenization

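GPT-2 uses byte-level BPE tokenization with a 50,257-token vocabulary. We use tiktoken's `gpt2` encoding and fall back to the Hugging Face `GPT2Tokenizer` if tiktoken is not installed.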

In [None]:
try:
    import tiktoken
    enc = tiktoken.get_encoding('gpt2')
    print("Using tiktoken for tokenization")
except ImportError:
    from transformers import GPT2Tokenizer
    enc = GPT2Tokenizer.from_pretrained('gpt2')
    print("Using Hugging Face tokenizer")


def encode(text: str) -> torch.Tensor:
    """Encode text to token IDs with shape [1, seq_len]."""
    # Both tiktoken's Encoding and HF's GPT2Tokenizer expose .encode(text) -> List[int]
    tokens = enc.encode(text)
    return torch.tensor(tokens, dtype=torch.long).unsqueeze(0)


def decode(tokens: torch.Tensor) -> str:
    """Decode one sequence of token IDs (shape [seq_len] or [1, seq_len]) to text."""
    # reshape(-1) keeps a list even for a single token (squeeze() would yield a bare int)
    return enc.decode(tokens.reshape(-1).tolist())


# Test tokenization
test_text = "Hello, world! How are you today?"
tokens = encode(test_text)
decoded = decode(tokens)

print(f"Original: '{test_text}'")
print(f"Tokens: {tokens.tolist()}")
print(f"Decoded: '{decoded}'")
print(f"Match: {test_text == decoded}")
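
Byte-level BPE starts from the 256 possible byte values, so any UTF-8 string can be encoded without an unknown token, and then repeatedly merges the most frequent adjacent pair into a new token ID. The toy trainer below uses hypothetical helpers (not tiktoken's API) to illustrate that loop on a single string. It deliberately omits parts of the real GPT-2 tokenizer, such as its regex pre-splitting of text and its fixed list of 50,000 learned merges.

```python
from collections import Counter
from typing import Dict, List, Tuple

Pair = Tuple[int, int]


def merge(ids: List[int], pair: Pair, new_id: int) -> List[int]:
    """Replace every non-overlapping occurrence of `pair`, left to right, with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out


def train_toy_bpe(text: str, num_merges: int) -> Tuple[List[int], Dict[Pair, int]]:
    ids = list(text.encode('utf-8'))  # base vocabulary: byte values 0..255
    merges: Dict[Pair, int] = {}
    for step in range(num_merges):
        if len(ids) < 2:
            break
        pair = Counter(zip(ids, ids[1:])).most_common(1)[0][0]
        new_id = 256 + step
        ids = merge(ids, pair, new_id)
        merges[pair] = new_id
    return ids, merges


def decode_toy(ids: List[int], merges: Dict[Pair, int]) -> str:
    vocab = {i: bytes([i]) for i in range(256)}
    for (a, b), new_id in merges.items():  # dicts keep insertion (merge) order
        vocab[new_id] = vocab[a] + vocab[b]
    return b''.join(vocab[i] for i in ids).decode('utf-8')


text = "low lower lowest naïve"
ids, merges = train_toy_bpe(text, num_merges=5)
print(len(text.encode('utf-8')), "bytes ->", len(ids), "tokens")
print(decode_toy(ids, merges))
```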

---

## 6. Text Generation

### 6.1 Different Decoding Strategies

In [None]:
def generate_text(
    model: GPT2Model,
    prompt: str,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    do_sample: bool = True,
) -> str:
    """
    Generate text from a prompt.
    
    Args:
        model: GPT-2 model
        prompt: Starting text
        max_new_tokens: Number of tokens to generate
        temperature: Sampling temperature
        top_k: Top-k filtering
        top_p: Nucleus sampling
        do_sample: Whether to sample
        
    Returns:
        Generated text
    """
    input_ids = encode(prompt).to(device)
    
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
    )
    
    return decode(output_ids)


# Test generation with different strategies
prompt = "The future of artificial intelligence is"

print("=" * 70)
print("TEXT GENERATION WITH DIFFERENT STRATEGIES")
print("=" * 70)
print(f"\nPrompt: '{prompt}'\n")

strategies = [
    ('Greedy (argmax)', dict(do_sample=False)),
    ('Temperature 0.7', dict(temperature=0.7)),
    ('Top-k = 50', dict(top_k=50)),
    ('Top-p = 0.9', dict(top_p=0.9)),
    ('Combined', dict(temperature=0.8, top_k=50, top_p=0.95)),
]

for name, kwargs in strategies:
    print(f"\n[{name}]")
    text = generate_text(model, prompt, max_new_tokens=40, **kwargs)
    print(text)
    print("-" * 70)
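
To see exactly which tokens survive each filter, here is a scalar re-statement of the filtering steps in `GPT2Model.generate` on a four-token distribution. It is a stdlib sketch with illustrative names; like the torch code, top-p keeps the token whose probability pushes the running sum past `top_p`.

```python
import math
from typing import List, Optional


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # exp(-inf) == 0.0 for filtered tokens
    total = sum(exps)
    return [e / total for e in exps]


def filter_logits(logits: List[float], temperature: float = 1.0,
                  top_k: Optional[int] = None,
                  top_p: Optional[float] = None) -> List[float]:
    neg_inf = float('-inf')
    logits = [z / temperature for z in logits]
    if top_k is not None:
        kth = sorted(logits, reverse=True)[min(top_k, len(logits)) - 1]
        logits = [z if z >= kth else neg_inf for z in logits]  # ties with the k-th survive
    if top_p is not None:
        probs = softmax(logits)
        order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        keep, cumulative = set(), 0.0
        for i in order:
            if cumulative > top_p:  # matches the shifted mask in generate()
                break
            keep.add(i)
            cumulative += probs[i]
        logits = [z if i in keep else neg_inf for i, z in enumerate(logits)]
    return logits


logits = [math.log(p) for p in [0.5, 0.3, 0.15, 0.05]]
print(softmax(filter_logits(logits, top_k=3)))
print(softmax(filter_logits(logits, top_p=0.7)))
print(softmax(filter_logits(logits, temperature=0.5)))
```

With `top_p=0.7` the first token (0.5) is not enough, so the second (0.3) is also kept, and the surviving pair is renormalized to 0.625 / 0.375.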

In [None]:
def visualize_generation_strategies():
    """
    Visualize different text generation strategies.
    """
    fig = plt.figure(figsize=(18, 12))
    
    gs = gridspec.GridSpec(2, 2, hspace=0.3, wspace=0.25)
    
    # Simulated probability distribution
    np.random.seed(42)
    vocab_size = 20
    raw_logits = np.random.randn(vocab_size) * 2
    raw_logits[0] = 5  # One token much more likely
    raw_logits[1] = 3
    raw_logits[2] = 2.5
    
    def softmax(x):
        exp_x = np.exp(x - np.max(x))
        return exp_x / exp_x.sum()
    
    probs = softmax(raw_logits)
    tokens = [f'T{i}' for i in range(vocab_size)]
    
    # === TOP LEFT: Temperature ===
    ax1 = fig.add_subplot(gs[0, 0])
    
    temps = [0.5, 1.0, 1.5]
    colors = ['#e74c3c', '#3498db', '#27ae60']
    width = 0.25
    
    for i, (temp, color) in enumerate(zip(temps, colors)):
        temp_probs = softmax(raw_logits / temp)
        x = np.arange(vocab_size) + i * width
        ax1.bar(x, temp_probs, width, label=f'T={temp}', color=color, alpha=0.8)
    
    ax1.set_xlabel('Token', fontsize=11)
    ax1.set_ylabel('Probability', fontsize=11)
    ax1.set_title('Temperature: Controls Distribution Sharpness', fontsize=12, fontweight='bold')
    ax1.legend(fontsize=10)
    ax1.set_xticks(np.arange(vocab_size) + width)
    ax1.set_xticklabels(tokens, fontsize=8)
    
    # === TOP RIGHT: Top-k ===
    ax2 = fig.add_subplot(gs[0, 1])
    
    sorted_idx = np.argsort(probs)[::-1]
    sorted_probs = probs[sorted_idx]
    sorted_tokens = [tokens[i] for i in sorted_idx]
    
    k = 5
    colors_topk = ['#27ae60' if i < k else '#95a5a6' for i in range(vocab_size)]
    ax2.bar(range(vocab_size), sorted_probs, color=colors_topk)
    ax2.axvline(x=k-0.5, color='#e74c3c', linestyle='--', linewidth=2)
    ax2.text(k, max(sorted_probs)*0.8, f'k={k}', fontsize=11, color='#e74c3c', fontweight='bold')
    
    ax2.set_xlabel('Token (sorted by probability)', fontsize=11)
    ax2.set_ylabel('Probability', fontsize=11)
    ax2.set_title('Top-k: Keep Only k Most Likely Tokens', fontsize=12, fontweight='bold')
    ax2.set_xticks(range(vocab_size))
    ax2.set_xticklabels(sorted_tokens, fontsize=8, rotation=45)
    
    # === BOTTOM LEFT: Top-p (Nucleus) ===
    ax3 = fig.add_subplot(gs[1, 0])
    
    cumsum = np.cumsum(sorted_probs)
    p = 0.9
    cutoff = np.searchsorted(cumsum, p)
    
    colors_topp = ['#27ae60' if i <= cutoff else '#95a5a6' for i in range(vocab_size)]
    ax3.bar(range(vocab_size), sorted_probs, color=colors_topp)
    ax3.plot(range(vocab_size), cumsum, 'r-', linewidth=2, marker='o', markersize=4)
    ax3.axhline(y=p, color='#e74c3c', linestyle='--', linewidth=2)
    ax3.text(vocab_size-3, p+0.02, f'p={p}', fontsize=11, color='#e74c3c', fontweight='bold')
    
    ax3.set_xlabel('Token (sorted by probability)', fontsize=11)
    ax3.set_ylabel('Probability / Cumulative', fontsize=11)
    ax3.set_title('Top-p (Nucleus): Keep Until Cumulative ≥ p', fontsize=12, fontweight='bold')
    ax3.set_xticks(range(vocab_size))
    ax3.set_xticklabels(sorted_tokens, fontsize=8, rotation=45)
    
    # === BOTTOM RIGHT: Strategy Comparison ===
    ax4 = fig.add_subplot(gs[1, 1])
    ax4.set_xlim(0, 12)
    ax4.set_ylim(0, 10)
    ax4.axis('off')
    ax4.set_title('Strategy Comparison', fontsize=13, fontweight='bold')
    
    strategies_info = [
        ('Greedy', 'Always pick highest prob', 'Deterministic, repetitive', '#e74c3c'),
        ('Temperature', 'Scale logits before softmax', 'Controls randomness', '#3498db'),
        ('Top-k', 'Sample from k best tokens', 'Excludes unlikely tokens', '#27ae60'),
        ('Top-p', 'Keep top tokens until cumsum ≥ p', 'Adaptive vocabulary size', '#f39c12'),
    ]
    
    for i, (name, desc, note, color) in enumerate(strategies_info):
        y = 8 - i * 2
        
        rect = FancyBboxPatch((0.5, y-0.5), 3, 1.3, boxstyle="round,pad=0.02",
                              facecolor=color, edgecolor='black', linewidth=2)
        ax4.add_patch(rect)
        ax4.text(2, y, name, ha='center', va='center', 
                fontsize=10, fontweight='bold', color='white')
        
        ax4.text(4, y+0.2, desc, va='center', fontsize=9)
        ax4.text(4, y-0.3, f'→ {note}', va='center', fontsize=8, color='gray')
    
    plt.tight_layout()
    plt.show()

visualize_generation_strategies()
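
The three filters in the figure can be written as one logits-processing step. The sketch below is standalone and does not reproduce the notebook's `generate_text`. It applies temperature first, then top-k, then top-p, and that order is an assumption. Filtered entries are set to `-inf`, so softmax gives them zero probability.

```python
import torch
import torch.nn.functional as F


def filter_logits(logits: torch.Tensor, temperature: float = 1.0,
                  top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
    """Apply temperature, then top-k, then top-p; filtered entries become -inf.

    Works on a 1-D [V] tensor or a batched [..., V] tensor.
    """
    logits = logits / temperature  # T < 1 sharpens the softmax, T > 1 flattens it

    if top_k > 0:
        k = min(top_k, logits.size(-1))
        # Value of the k-th largest logit per row (trailing dim kept for broadcasting);
        # ties with that value are kept
        kth = torch.topk(logits, k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth, float('-inf'))

    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Probability mass of all strictly more likely tokens
        mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
        # Drop a token once the tokens before it already cover p; the token that
        # crosses p is kept, so the most likely token always survives
        remove_sorted = mass_before >= top_p
        # Undo the sort: scatter the mask back to the original vocabulary order
        remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
        logits = logits.masked_fill(remove, float('-inf'))

    return logits


# Usage: sample one token id from the filtered distribution
example_logits = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05]))
probs = F.softmax(filter_logits(example_logits, top_p=0.7), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # only index 0 or 1 can be drawn
```

With `top_p=0.7`, the first two tokens cover 0.8 of the mass, so tokens 2 and 3 are removed. Greedy decoding is the limiting case: take `argmax` instead of sampling.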

---

## 7. Zero-Shot Task Prompting

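Let's implement prompts for the paper's zero-shot tasks. Summarization, translation, and question answering follow the paper's setups in simplified form. Story completion and sentiment are illustrative extras that the paper does not evaluate.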
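
The paper describes two prompt layouts more specifically than the class below does. For translation, it conditions on example pairs in the format `english sentence = french sentence`, ends with a final `english sentence =`, and decodes greedily. For reading comprehension (CoQA), it conditions on the document, the conversation history, and a final token `A:`. The builders below are a sketch of those layouts. The newline separators, the `Q:` prefix, and the function names are assumptions.

```python
from typing import List, Tuple


def translation_prompt(pairs: List[Tuple[str, str]], source: str) -> str:
    """Example pairs as 'english = french' lines, then an open 'source =' line."""
    lines = [f"{en} = {fr}" for en, fr in pairs]
    lines.append(f"{source} =")  # the model continues with the translation
    return "\n".join(lines)


def reading_comprehension_prompt(document: str, history: List[Tuple[str, str]],
                                 question: str) -> str:
    """Document, then earlier Q/A turns, then the new question and a final 'A:'."""
    turns = [f"Q: {q}\nA: {a}" for q, a in history]
    turns.append(f"Q: {question}\nA:")
    return document + "\n\n" + "\n".join(turns)


print(translation_prompt([("Hello.", "Bonjour."), ("Thank you.", "Merci.")], "Good night."))
```

Either string can be passed to `generate_text`. The `do_sample=False` option used earlier corresponds to the greedy decoding the paper used for translation.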

In [None]:
class ZeroShotTasks:
    """
    Zero-shot task implementations using GPT-2.
    """
    
    def __init__(self, model: GPT2Model):
        self.model = model
    
    def summarize(self, article: str, max_tokens: int = 50) -> str:
        """
        Summarize an article using the TL;DR prompt.
        
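        From the paper: "To induce summarization behavior we add 
        the text TL;DR: after the article". The paper then generates
        100 tokens with top-k sampling (k = 2); the defaults below
        (temperature 0.7, top-p 0.9) are a looser choice for demos.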
        """
        prompt = f"{article}\n\nTL;DR:"
        return generate_text(
            self.model, prompt, max_new_tokens=max_tokens,
            temperature=0.7, top_p=0.9
        )
    
    def translate(self, text: str, target_lang: str = "French") -> str:
        """
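        Translate text to another language.
        
        Simplified prompt: the paper instead conditions on example pairs in
        the format "english sentence = french sentence", ends with a final
        "english sentence =", and decodes greedily.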
        """
        prompt = f"Translate English to {target_lang}:\n\n{text} =>"
        return generate_text(
            self.model, prompt, max_new_tokens=30,
            temperature=0.3, top_k=50
        )
    
    def answer_question(self, question: str, context: Optional[str] = None) -> str:
        """
        Answer a question, optionally with context.
        """
        if context:
            prompt = f"Context: {context}\n\nQ: {question}\nA:"
        else:
            prompt = f"Q: {question}\nA:"
        
        return generate_text(
            self.model, prompt, max_new_tokens=30,
            temperature=0.5, top_k=40
        )
    
    def complete_story(self, beginning: str, max_tokens: int = 100) -> str:
        """
        Complete a story given a beginning.
        """
        return generate_text(
            self.model, beginning, max_new_tokens=max_tokens,
            temperature=0.9, top_p=0.95
        )
    
    def sentiment(self, review: str) -> str:
        """
        Analyze sentiment of a review.
        """
        prompt = f"Review: {review}\nSentiment (positive/negative):"
        return generate_text(
            self.model, prompt, max_new_tokens=5,
            temperature=0.3, top_k=10
        )


# Test zero-shot tasks
tasks = ZeroShotTasks(model)

print("=" * 70)
print("ZERO-SHOT TASK DEMONSTRATIONS")
print("=" * 70)

# Summarization
article = """Artificial intelligence has made remarkable progress in recent years. 
Deep learning models can now understand language, recognize images, and even generate 
creative content. Companies around the world are investing billions in AI research 
and development. However, concerns about AI safety, job displacement, and ethical 
implications continue to be debated by researchers and policymakers."""

print("\n[SUMMARIZATION]")
print(f"Article: {article[:100]}...")
summary = tasks.summarize(article)
print(f"Summary: {summary}")

# Question Answering
print("\n[QUESTION ANSWERING]")
question = "What is the capital of France?"
answer = tasks.answer_question(question)
print(f"Q: {question}")
print(f"A: {answer}")

# Story Completion
print("\n[STORY COMPLETION]")
beginning = "Once upon a time, in a distant galaxy, there lived a robot who dreamed of"
story = tasks.complete_story(beginning, max_tokens=50)
print(story)
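
If `generate_text` returns the decoded prompt plus its continuation, the printed summary and answer include the whole prompt. The notebook's `decode(output_ids)` return suggests this, but it is an assumption here. The hypothetical helper below keeps only the continuation and cuts it at a stop string.

```python
from typing import Optional


def extract_completion(full_text: str, prompt: str, stop: Optional[str] = "\n") -> str:
    """Return only the model's continuation, cut at the first `stop` string."""
    # Assumption: full_text = prompt + continuation, so strip the echoed prompt
    if full_text.startswith(prompt):
        full_text = full_text[len(prompt):]
    completion = full_text.lstrip()
    # Short answers usually end at the first newline (e.g. before the next "Q:")
    if stop and stop in completion:
        completion = completion[:completion.index(stop)]
    return completion.strip()


qa_prompt = "Q: What is the capital of France?\nA:"
print(extract_completion(qa_prompt + " Paris.\nQ: And Spain?", qa_prompt))
```

For example, `extract_completion(tasks.answer_question(question), f"Q: {question}\nA:")` would give just the answer text. For multi-sentence summaries, pass `stop=None` or `stop="\n\n"` so the output is not cut at the first line break.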

---

## 8. Fine-Tuning Example

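Let's fine-tune GPT-2 on a small custom dataset. This demonstrates the training loop on a tiny, repeated corpus. It updates `model` in place, so every later cell uses the fine-tuned weights.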
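
The toy walk-through below shows how `TextDataset` and `collate_fn` in the next cell turn token ids into training pairs. Each text is cut into windows of `max_length + 1` tokens that overlap by one token. The window is then shifted by one to form `(input_ids, labels)`, so every token after the first appears as a label exactly once. Batches are right-padded, and labels use `-100`. The check at the end assumes the model computes its loss with `F.cross_entropy` and its default `ignore_index=-100`. The model's forward pass is not shown in this section.

```python
import torch
import torch.nn.functional as F

# Toy token ids stand in for encode(text); max_length = 4
tokens = torch.arange(10)
max_length = 4

# Same windowing as TextDataset: windows of max_length + 1 tokens, stepping by max_length
chunks = [tokens[i:i + max_length + 1] for i in range(0, len(tokens) - 1, max_length)]
pairs = [(c[:-1], c[1:]) for c in chunks if len(c) > 1]  # (input_ids, labels), shifted by one

# Right-pad like collate_fn: inputs with token 0, labels with -100
max_len = max(len(x) for x, _ in pairs)
input_ids = torch.zeros(len(pairs), max_len, dtype=torch.long)
labels = torch.full((len(pairs), max_len), -100, dtype=torch.long)
for i, (x, y) in enumerate(pairs):
    input_ids[i, :len(x)] = x
    labels[i, :len(y)] = y

# -100 is F.cross_entropy's default ignore_index, so padded positions add no loss
vocab = 16
g = torch.Generator().manual_seed(0)
logits = torch.randn(len(pairs), max_len, vocab, generator=g)
padded_loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1))
mask = labels != -100
manual_loss = F.cross_entropy(logits[mask], labels[mask])  # loss over real tokens only

print(input_ids.tolist())  # last row is [8, 0, 0, 0]
print(labels.tolist())     # last row is [9, -100, -100, -100]
```

Right padding is harmless under causal masking: each real position attends only to itself and earlier positions, never to the padding that follows it.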

In [None]:
class TextDataset(Dataset):
    """
    Simple text dataset for fine-tuning.
    """
    
    def __init__(self, texts: List[str], max_length: int = 256):
        self.examples = []
        
        for text in texts:
            tokens = encode(text).reshape(-1)  # 1-D ids; squeeze() would yield a 0-d tensor for a 1-token text
            
            # Split into chunks of max_length
            for i in range(0, len(tokens) - 1, max_length):
                chunk = tokens[i:i + max_length + 1]
                if len(chunk) > 1:
                    self.examples.append(chunk)
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        tokens = self.examples[idx]
        return {
            'input_ids': tokens[:-1],
            'labels': tokens[1:],
        }


def collate_fn(batch):
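    """Right-pad a batch of variable-length examples into [B, max_len] tensors."""
    max_len = max(len(x['input_ids']) for x in batch)
    
    # Pad inputs with token 0; under causal masking, real positions never attend to later padding
    input_ids = torch.zeros(len(batch), max_len, dtype=torch.long)
    # Pad labels with -100, F.cross_entropy's default ignore_index (assumes the model's loss uses it)
    labels = torch.full((len(batch), max_len), -100, dtype=torch.long)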
    
    for i, x in enumerate(batch):
        length = len(x['input_ids'])
        input_ids[i, :length] = x['input_ids']
        labels[i, :length] = x['labels']
    
    return {'input_ids': input_ids, 'labels': labels}


def train_epoch(model, dataloader, optimizer, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    
    for batch in tqdm(dataloader, desc='Training'):
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        
        outputs = model(input_ids, labels=labels)
        loss = outputs['loss']
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)


# Example: Fine-tune on Shakespeare
shakespeare_texts = [
    "To be, or not to be, that is the question: Whether 'tis nobler in the mind to suffer",
    "The slings and arrows of outrageous fortune, Or to take arms against a sea of troubles",
    "All the world's a stage, And all the men and women merely players",
    "They have their exits and their entrances, And one man in his time plays many parts",
    "What's in a name? That which we call a rose By any other name would smell as sweet",
    "Romeo, Romeo, wherefore art thou Romeo? Deny thy father and refuse thy name",
    "But soft, what light through yonder window breaks? It is the east, and Juliet is the sun",
    "Good night, good night! Parting is such sweet sorrow, That I shall say good night till it be morrow",
] * 10  # Repeat for more data

print("=" * 70)
print("FINE-TUNING DEMONSTRATION")
print("=" * 70)

# Create dataset and dataloader
batch_size = 4
dataset = TextDataset(shakespeare_texts, max_length=64)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

print(f"\nDataset size: {len(dataset)} examples")
print(f"Batch size: {batch_size}")

# Setup optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)

# Train for a few epochs (demonstration only)
n_epochs = 2
print(f"\nTraining for {n_epochs} epochs...")

for epoch in range(n_epochs):
    loss = train_epoch(model, dataloader, optimizer, device)
    print(f"Epoch {epoch + 1}/{n_epochs}, Loss: {loss:.4f}")

# Test generation after fine-tuning; train_epoch left the model in train mode, so disable dropout first
model.eval()
print("\n[Generation after fine-tuning]")
prompt = "To be, or not to be,"
output = generate_text(model, prompt, max_new_tokens=30, temperature=0.8)
print(output)
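
The optimizer above applies `weight_decay=0.01` to every parameter, including biases and LayerNorm gains. A common convention, used for example in nanoGPT's `configure_optimizers`, decays only parameters with two or more dimensions. The sketch below assumes that rule. It uses a hypothetical toy model in place of `GPT2Model`.

```python
import torch
import torch.nn as nn


def build_param_groups(model: nn.Module, weight_decay: float = 0.01):
    """Decay weight matrices (dim >= 2); exempt biases and LayerNorm params (dim < 2)."""
    # model.parameters() yields each Parameter once, so a tied wte/lm_head is not duplicated
    params = [p for p in model.parameters() if p.requires_grad]
    decay = [p for p in params if p.dim() >= 2]
    no_decay = [p for p in params if p.dim() < 2]
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]


# Hypothetical stand-in model, not the notebook's GPT2Model
toy = nn.Sequential(nn.Embedding(10, 8), nn.LayerNorm(8), nn.Linear(8, 10))
groups = build_param_groups(toy)
toy_optimizer = torch.optim.AdamW(groups, lr=5e-5)
```

Under this rule, embedding matrices count as 2-D and are still decayed. Passing `build_param_groups(model)` to `AdamW` in place of `model.parameters()` would be the only change to the cell above.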

---

## 9. Model Architecture Visualization

In [None]:
def visualize_complete_architecture():
    """
    Complete visualization of GPT-2 architecture.
    """
    fig = plt.figure(figsize=(20, 16))
    ax = fig.add_subplot(111)
    ax.set_xlim(0, 24)
    ax.set_ylim(0, 22)
    ax.axis('off')
    
    ax.text(12, 21.5, 'GPT-2 Complete Architecture', fontsize=18, fontweight='bold', ha='center')
    
    # === INPUT ===
    rect = FancyBboxPatch((9, 19.5), 6, 1, boxstyle="round,pad=0.03",
                          facecolor='#3498db', edgecolor='black', linewidth=2)
    ax.add_patch(rect)
    ax.text(12, 20, 'Input Token IDs [B, T]', ha='center', va='center', 
            fontsize=11, color='white', fontweight='bold')
    
    ax.annotate('', xy=(12, 18.5), xytext=(12, 19.4),
               arrowprops=dict(arrowstyle='->', color='black', lw=2))
    
    # === EMBEDDINGS ===
    # Token embedding
    rect = FancyBboxPatch((5, 17), 5, 1.2, boxstyle="round,pad=0.03",
                          facecolor='#9b59b6', edgecolor='black', linewidth=2)
    ax.add_patch(rect)
    ax.text(7.5, 17.6, 'Token Embedding', ha='center', va='center', 
            fontsize=10, color='white', fontweight='bold')
    ax.text(7.5, 17.2, 'wte: [50257, 768]', ha='center', fontsize=8, color='white')
    
    # Position embedding
    rect = FancyBboxPatch((14, 17), 5, 1.2, boxstyle="round,pad=0.03",
                          facecolor='#9b59b6', edgecolor='black', linewidth=2)
    ax.add_patch(rect)
    ax.text(16.5, 17.6, 'Position Embedding', ha='center', va='center', 
            fontsize=10, color='white', fontweight='bold')
    ax.text(16.5, 17.2, 'wpe: [1024, 768]', ha='center', fontsize=8, color='white')
    
    # Add
    circle = plt.Circle((12, 15.5), 0.4, facecolor='#27ae60', edgecolor='black', linewidth=2)
    ax.add_patch(circle)
    ax.text(12, 15.5, '+', ha='center', va='center', fontsize=16, color='white', fontweight='bold')
    
    ax.annotate('', xy=(12, 15.1), xytext=(7.5, 16.9),
               arrowprops=dict(arrowstyle='->', color='black', lw=1.5))
    ax.annotate('', xy=(12, 15.1), xytext=(16.5, 16.9),
               arrowprops=dict(arrowstyle='->', color='black', lw=1.5))
    
    # Dropout
    ax.text(13, 15.5, 'Dropout', fontsize=9, va='center')
    
    # === TRANSFORMER BLOCK ===
    rect = FancyBboxPatch((4, 5), 16, 9, boxstyle="round,pad=0.1",
                          facecolor='#f5f5f5', edgecolor='#2c3e50', linewidth=3)
    ax.add_patch(rect)
    ax.text(12, 13.5, 'Transformer Block × 12', ha='center', fontsize=12, fontweight='bold')
    
    # Pre-LN 1
    rect = FancyBboxPatch((5.5, 11.5), 3, 1, boxstyle="round,pad=0.02",
                          facecolor='#f39c12', edgecolor='black', linewidth=1.5)
    ax.add_patch(rect)
    ax.text(7, 12, 'LayerNorm', ha='center', va='center', fontsize=9, color='white', fontweight='bold')
    
    # Multi-Head Attention
    rect = FancyBboxPatch((9.5, 11), 5, 2, boxstyle="round,pad=0.02",
                          facecolor='#e74c3c', edgecolor='black', linewidth=1.5)
    ax.add_patch(rect)
    ax.text(12, 12.3, 'Multi-Head', ha='center', va='center', fontsize=9, color='white', fontweight='bold')
    ax.text(12, 11.7, 'Self-Attention', ha='center', va='center', fontsize=9, color='white')
    ax.text(12, 11.2, '(12 heads)', ha='center', va='center', fontsize=8, color='white')
    
    ax.annotate('', xy=(9.4, 12), xytext=(8.6, 12),
               arrowprops=dict(arrowstyle='->', color='black', lw=1.5))
    
    # Residual 1
    ax.plot([5, 5], [12, 9.5], 'g-', linewidth=3)
    ax.plot([5, 16], [9.5, 9.5], 'g-', linewidth=3)
    
    circle = plt.Circle((16, 9.5), 0.3, facecolor='#27ae60', edgecolor='black', linewidth=1.5)
    ax.add_patch(circle)
    ax.text(16, 9.5, '+', ha='center', va='center', fontsize=12, color='white', fontweight='bold')
    
    ax.annotate('', xy=(15.7, 9.5), xytext=(14.6, 11),
               arrowprops=dict(arrowstyle='->', color='black', lw=1.5))
    
    # Pre-LN 2
    ax.plot([16, 16], [9.2, 8.3], 'k-', linewidth=2)
    ax.plot([16, 7], [8.3, 8.3], 'k-', linewidth=2)
    
    rect = FancyBboxPatch((5.5, 7.5), 3, 1, boxstyle="round,pad=0.02",
                          facecolor='#f39c12', edgecolor='black', linewidth=1.5)
    ax.add_patch(rect)
    ax.text(7, 8, 'LayerNorm', ha='center', va='center', fontsize=9, color='white', fontweight='bold')
    
    # MLP
    rect = FancyBboxPatch((9.5, 6.5), 5, 2.5, boxstyle="round,pad=0.02",
                          facecolor='#3498db', edgecolor='black', linewidth=1.5)
    ax.add_patch(rect)
    ax.text(12, 8.2, 'MLP', ha='center', va='center', fontsize=10, color='white', fontweight='bold')
    ax.text(12, 7.6, 'Linear: 768→3072', ha='center', va='center', fontsize=8, color='white')
    ax.text(12, 7.1, 'GELU', ha='center', va='center', fontsize=8, color='white')
    ax.text(12, 6.6, 'Linear: 3072→768', ha='center', va='center', fontsize=8, color='white')
    
    ax.annotate('', xy=(9.4, 7.75), xytext=(8.6, 7.75),
               arrowprops=dict(arrowstyle='->', color='black', lw=1.5))
    
    # Residual 2
    ax.plot([5, 5], [8, 5.5], 'g-', linewidth=3)
    ax.plot([5, 16], [5.5, 5.5], 'g-', linewidth=3)
    
    circle = plt.Circle((16, 5.5), 0.3, facecolor='#27ae60', edgecolor='black', linewidth=1.5)
    ax.add_patch(circle)
    ax.text(16, 5.5, '+', ha='center', va='center', fontsize=12, color='white', fontweight='bold')
    
    ax.annotate('', xy=(15.7, 5.5), xytext=(14.6, 6.4),
               arrowprops=dict(arrowstyle='->', color='black', lw=1.5))
    
    # === FINAL LAYER NORM ===
    ax.annotate('', xy=(12, 3.7), xytext=(12, 4.9),
               arrowprops=dict(arrowstyle='->', color='black', lw=2))
    
    rect = FancyBboxPatch((9, 2.5), 6, 1, boxstyle="round,pad=0.03",
                          facecolor='#f39c12', edgecolor='black', linewidth=2)
    ax.add_patch(rect)
    ax.text(12, 3, 'Final LayerNorm', ha='center', va='center', 
            fontsize=10, color='white', fontweight='bold')
    
    # === LM HEAD ===
    ax.annotate('', xy=(12, 1.3), xytext=(12, 2.4),
               arrowprops=dict(arrowstyle='->', color='black', lw=2))
    
    rect = FancyBboxPatch((9, 0.2), 6, 1, boxstyle="round,pad=0.03",
                          facecolor='#2c3e50', edgecolor='black', linewidth=2)
    ax.add_patch(rect)
    ax.text(12, 0.7, 'LM Head → Logits [B, T, 50257]', ha='center', va='center', 
            fontsize=10, color='white', fontweight='bold')
    
    # Weight tying note
    ax.annotate('', xy=(7, 1), xytext=(7, 17),
               arrowprops=dict(arrowstyle='<->', color='#9b59b6', lw=2, ls='--'))
    ax.text(3, 9, 'Weight\nTying', ha='center', fontsize=9, color='#9b59b6', fontweight='bold')
    
    # Legend
    # Legend, placed right of the block outline (which ends at x ≈ 20.1)
    ax.text(21, 12, 'Key:', fontsize=10, fontweight='bold')
    
    legend_items = [
        ('Pre-LN', '#f39c12'),
        ('Attention', '#e74c3c'),
        ('MLP', '#3498db'),
        ('Residual', '#27ae60'),
    ]
    
    for i, (name, color) in enumerate(legend_items):
        y = 11 - i * 1.2
        rect = FancyBboxPatch((21, y-0.3), 2, 0.6, boxstyle="round,pad=0.02",
                              facecolor=color, edgecolor='black', linewidth=1)
        ax.add_patch(rect)
        ax.text(22, y, name, ha='center', va='center', fontsize=8, color='white')
    
    plt.tight_layout()
    plt.show()

visualize_complete_architecture()
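
The shapes annotated in the diagram are enough to recover GPT-2 small's parameter count: wte [50257, 768], wpe [1024, 768], 12 blocks, and a 768→3072→768 MLP. The helper below is a sketch with three assumptions. Every linear layer has a bias, each LayerNorm has a gain and a bias, and the LM head has no bias and is tied to wte. Under those assumptions it gives 124,439,808 parameters. The paper's Table 2 lists this model as 117M.

```python
def gpt2_param_count(vocab_size: int = 50257, n_positions: int = 1024,
                     n_embd: int = 768, n_layer: int = 12, tied: bool = True) -> int:
    """Count parameters of a GPT-2 style model from its configuration."""
    d = n_embd
    embeddings = vocab_size * d + n_positions * d    # wte + wpe
    ln = 2 * d                                       # LayerNorm gain + bias
    attn = (d * 3 * d + 3 * d) + (d * d + d)         # fused QKV projection + output projection
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)      # expand to 4d, project back to d
    block = 2 * ln + attn + mlp                      # ln_1, attention, ln_2, MLP
    total = embeddings + n_layer * block + ln        # plus the final LayerNorm
    if not tied:
        total += vocab_size * d                      # separate bias-free LM head
    return total


print(f"{gpt2_param_count():,}")  # 124,439,808
```

Weight tying saves one 50257 × 768 matrix, about 38.6M parameters, which is nearly a third of the small model.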

---

## 10. Summary

### What We Built

| Component | Description |
|-----------|-------------|
| **GPT2Config** | Configuration dataclass with all hyperparameters |
| **LayerNorm** | Custom layer normalization |
| **CausalSelfAttention** | Multi-head attention with causal masking |
| **MLP** | Position-wise feed-forward network |
| **Block** | Transformer block with Pre-LN |
| **GPT2Model** | Complete model with generation |

### Key Features

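1. **Pre-LayerNorm**: LayerNorm moved to the input of each sub-block, plus a final LayerNorm after the last block; this tends to train more stably than the original Post-LN Transformer
2. **Scaled Initialization**: residual-layer weights scaled by 1/√N at initialization, where N is the number of residual layers (two per block, so 1/√(2 · n_layer))
3. **KV-Cache**: cached keys and values let each generation step process only the newest token instead of re-encoding the whole prefix
4. **Multiple Decoding Strategies**: Greedy, temperature, top-k, top-p
5. **Weight Tying**: the LM head shares its weight matrix with the token embedding (wte)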
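
Weight tying and the scaled initialization can each be shown in a few lines of PyTorch. This is a minimal illustration with assumed toy sizes, not the notebook's `GPT2Model`. The 0.02 base standard deviation follows common GPT-2 reimplementations such as nanoGPT; the paper itself specifies only the 1/√N factor.

```python
import math
import torch
import torch.nn as nn

n_layer = 12                      # GPT-2 small depth
vocab_size, n_embd = 100, 16      # toy sizes (GPT-2 small uses 50257 and 768)

# Weight tying: the LM head reuses the token-embedding matrix, shape [vocab_size, n_embd]
wte = nn.Embedding(vocab_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)
lm_head.weight = wte.weight       # one shared Parameter; both uses send gradients to it

# Scaled init: N = 2 * n_layer residual layers (attention + MLP per block),
# so residual output projections get std = base_std / sqrt(N)
base_std = 0.02                                  # assumed base std (reimplementation convention)
resid_std = base_std / math.sqrt(2 * n_layer)    # ≈ 0.00408 for n_layer = 12
c_proj = nn.Linear(4 * n_embd, n_embd)           # e.g. the MLP's output projection
nn.init.normal_(c_proj.weight, mean=0.0, std=resid_std)
nn.init.zeros_(c_proj.bias)
```

Because the head and the embedding are the same `Parameter`, `model.parameters()` yields it once. The shared matrix is therefore counted and optimized only once.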

### Usage Patterns

```python
# Load pretrained model
model = load_pretrained_gpt2('gpt2')

# Generate text
text = generate_text(model, "The future of AI", max_new_tokens=50)

# Zero-shot tasks
tasks = ZeroShotTasks(model)
summary = tasks.summarize(article)
answer = tasks.answer_question(question)

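# Fine-tune
dataset = TextDataset(texts)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
train_epoch(model, dataloader, optimizer, device)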
```

---

## Complete GPT-2 Series

This concludes our 5-part deep dive into GPT-2:

1. **Part I: Genesis** - The vision and WebText dataset
2. **Part II: Architecture** - Pre-LN, scaled init, all modifications
3. **Part III: Zero-Shot** - Task prompting and emergence
4. **Part IV: Training** - Compute, scaling laws, infrastructure
5. **Part V: Implementation** - Complete working code

---

## References

1. Radford et al. (2019). [Language Models are Unsupervised Multitask Learners](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
2. Karpathy. [nanoGPT](https://github.com/karpathy/nanoGPT)
3. Hugging Face. [Transformers Library](https://github.com/huggingface/transformers)
4. OpenAI. [GPT-2 GitHub](https://github.com/openai/gpt-2)