# GPT-2 from Scratch

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/gpt2.ipynb)

This notebook implements **GPT-2 (Generative Pre-trained Transformer)** from scratch.

Key differences from BERT:
1. **Decoder-only architecture**: Uses masked self-attention (causal mask) so each token can attend only to itself and earlier tokens, never to future ones.
2. **Objective**: Causal Language Modeling (predict next token).
3. **Layer Normalization**: Pre-norm formulation (LayerNorm is applied *before* the sub-layer).

In [None]:
!pip install torch matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 1. Causal Self-Attention

The core of GPT is the **causal mask**, which ensures that position $i$ can only attend to positions $j \le i$.

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where each position sees only itself and the past.

    Args:
        d_model: Width of the input and output embeddings.
        n_heads: Number of attention heads; must divide ``d_model`` evenly.
        max_len: Longest sequence length covered by the precomputed causal mask.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, max_len):
        super().__init__()
        # Fail fast: an uneven split would only surface later as a cryptic
        # shape error from ``view`` inside ``forward``.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # Merged Q, K, V projection for efficiency
        self.c_attn = nn.Linear(d_model, 3 * d_model)
        self.c_proj = nn.Linear(d_model, d_model)

        # Causal mask: lower-triangular matrix of ones, shaped
        # (1, 1, max_len, max_len) so it broadcasts over batch and heads.
        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer("bias", torch.tril(torch.ones(max_len, max_len))
                                     .view(1, 1, max_len, max_len))

    def forward(self, x):
        """Apply causal multi-head attention.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the mask size given at construction.
        """
        batch, seq_len, d_model = x.shape
        max_len = self.bias.size(-1)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the causal mask size {max_len}"
            )

        # Calculate Q, K, V with one matmul, then split along the feature axis
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.d_model, dim=2)

        # Reshape to (batch, n_heads, seq_len, d_k) for per-head attention
        q = q.view(batch, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(batch, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(batch, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention scores: (batch, n_heads, seq_len, seq_len)
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.d_k))

        # Apply the causal mask: future positions get -inf so softmax assigns
        # them exactly zero weight. The diagonal is always kept, so no row is
        # fully masked.
        mask = self.bias[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax over keys, then weighted sum of values
        att = F.softmax(scores, dim=-1)
        y = att @ v

        # Reassemble heads back into (batch, seq_len, d_model)
        y = y.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        return self.c_proj(y)

## 2. GPT-2 Block (Pre-Norm)

GPT-2 uses **Pre-LayerNorm**: `x = x + Sublayer(LayerNorm(x))`.
This is more stable to train than Post-LayerNorm (used in BERT), `x = LayerNorm(x + Sublayer(x))`, especially for deep networks.

In [None]:
class GPTBlock(nn.Module):
    """One pre-norm transformer block: causal attention, then a GELU MLP.

    Each sub-layer reads a LayerNorm-ed copy of the stream and its output is
    added back onto the un-normalized residual.
    """

    def __init__(self, d_model, n_heads, d_ff, max_len):
        super().__init__()
        # Attention branch
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, max_len)

        # Feed-forward branch: expand to d_ff, apply GELU, project back
        self.ln2 = nn.LayerNorm(d_model)
        expand = nn.Linear(d_model, d_ff)
        activation = nn.GELU()
        contract = nn.Linear(d_ff, d_model)
        self.mlp = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Run both residual sub-layers and return the updated stream."""
        # Norm -> Attention -> add to residual
        attended = self.attn(self.ln1(x))
        x = x + attended

        # Norm -> MLP -> add to residual
        transformed = self.mlp(self.ln2(x))
        return x + transformed

## 3. Full GPT-2 Model

In [None]:
class GPT2(nn.Module):
    """Decoder-only GPT-2 style language model.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Embedding / hidden width.
        n_heads: Attention heads per block.
        d_ff: Hidden width of each block's MLP.
        n_layers: Number of stacked ``GPTBlock`` layers.
        max_len: Maximum sequence length (size of the position table).
    """

    def __init__(self, vocab_size, d_model, n_heads, d_ff, n_layers, max_len):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList([
            GPTBlock(d_model, n_heads, d_ff, max_len) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying: the embedding and the output projection share one
        # parameter tensor (both are shaped (vocab_size, d_model)).
        self.token_emb.weight = self.head.weight

        self.max_len = max_len

    def forward(self, idx):
        """Compute next-token logits for every position.

        Args:
            idx: Long tensor of token ids, shape ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.
        """
        batch, seq_len = idx.shape
        assert seq_len <= self.max_len, f"Seq len {seq_len} > max {self.max_len}"

        pos_idx = torch.arange(seq_len, device=idx.device)

        # Token + positional embeddings; the (seq_len, d_model) position
        # table broadcasts over the batch dimension.
        x = self.token_emb(idx) + self.pos_emb(pos_idx)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0):
        """Generate new tokens by repeatedly predicting the next token.

        Runs with autograd disabled, so no computation graph is built while
        sampling.

        Args:
            idx: Long tensor of prompt token ids, shape ``(batch, seq_len)``.
            max_new_tokens: Number of tokens to append.
            temperature: Softmax temperature. Values above 0 sample from the
                scaled distribution; exactly 0 selects the most likely token
                (greedy decoding) instead of dividing by zero.

        Returns:
            Long tensor of shape ``(batch, seq_len + max_new_tokens)``.

        Raises:
            ValueError: If ``temperature`` is negative.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")

        for _ in range(max_new_tokens):
            # Crop to the last max_len tokens so the position table is never
            # indexed out of range.
            idx_cond = idx if idx.size(1) <= self.max_len else idx[:, -self.max_len:]

            # Forward pass; only the final position predicts the next token
            logits = self(idx_cond)[:, -1, :]

            if temperature == 0:
                # Limit of softmax(logits / T) as T -> 0+: pick the argmax
                idx_next = logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(logits / temperature, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

            idx = torch.cat((idx, idx_next), dim=1)

        return idx

# Build a deliberately tiny GPT-2 for the demo and move it to the chosen device
demo_config = dict(
    vocab_size=1000,
    d_model=128,
    n_heads=4,
    d_ff=256,
    n_layers=4,
    max_len=64,
)
model = GPT2(**demo_config).to(device)

# Tally every element across all parameter tensors
n_params = sum(p.numel() for p in model.parameters())
print(f"GPT-2 Initialized with {n_params:,} parameters")

## 4. Test Generation

In [None]:
# Sample from the untrained model, seeding the sequence with a single token id 0
start_tokens = torch.zeros(1, 1, dtype=torch.long, device=device)
generated = model.generate(start_tokens, 20)

# Only one sequence is in the batch; show its token ids as a plain list
token_ids = generated[0].tolist()
print(f"Generated sequence (indices): {token_ids}")
print("\n(Note: The model is random, so output is meaningless noise)")