# Topic 9: The Transformer Architecture

## Learning Objectives

By the end of this notebook, you will:
- Understand the **complete transformer architecture** from 30,000 feet to implementation
- Learn **why** each component exists and what problem it solves
- Build an encoder-decoder transformer **from scratch**
- Understand encoder-only (BERT), decoder-only (GPT), and encoder-decoder (T5) variants
- Visualize information flow through the transformer
- Connect transformers to modern LLMs and understand their evolution
- Appreciate why transformers revolutionized AI

## The Big Picture: Why Transformers Changed Everything

### Before Transformers: The RNN/LSTM Era (Pre-2017)

**The dominant architecture**:
```
Input sequence → RNN/LSTM → Hidden states → Output
```

**Critical limitations**:
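1. **Sequential processing**: Token t can't be processed until token t-1 is done, so computation within a sequence can't be parallelized
2. **Vanishing gradients**: Gradients must pass through every intermediate step, making long-range dependencies hard to learn
3. **Memory bottleneck**: All past information must be compressed into a fixed-size hidden state
4. **Training time**: The sequential bottleneck leaves GPUs underused, so large models take a long time to train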

**Example failure**:
```
Input: "The cat, which was very fluffy and loved to play with yarn, sat on the mat."
Problem: By the time LSTM reaches "sat", it may have forgotten "cat" (the subject)
```

### The Transformer Revolution (2017)

**"Attention is All You Need"** (Vaswani et al., 2017)

**Key insight**: Replace recurrence with **self-attention**

**Revolutionary properties**:
1. **Parallel processing**: All tokens processed simultaneously
2. **Direct connections**: Any token can attend to any other token
3. **Scalability**: Can train 100B+ parameter models
4. **Transfer learning**: Pre-train once, fine-tune for many tasks

**Impact timeline**:
- **2017**: Original Transformer (machine translation)
- **2018**: BERT (masked language modeling) + GPT (autoregressive LM)
- **2019**: GPT-2, T5, RoBERTa
- **2020**: GPT-3 (175B params), Vision Transformers
- **2021**: Multimodal transformers (CLIP, DALL-E)
- **2022**: ChatGPT, Stable Diffusion
- **2023-2025**: GPT-4, Claude, LLaMA 1/2/3, Gemini, 100K+ context models

**Why transformers cannot be skipped**: They are the foundation of modern AI. Nearly every major breakthrough since 2017 (GPT, BERT, CLIP, AlphaFold 2, Stable Diffusion, ChatGPT, Claude) is built on transformers or on transformer-style attention layers.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import math

torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")

sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

## The Transformer Architecture: Overview

### The Complete Picture

```
┌─────────────────────────────────────────────────────────────┐
│                    TRANSFORMER                               │
│                                                              │
│  ┌──────────────────────┐      ┌──────────────────────┐   │
│  │      ENCODER         │      │      DECODER         │   │
│  │                      │      │                      │   │
│  │  Input Embedding     │      │  Output Embedding    │   │
│  │         +            │      │         +            │   │
│  │  Positional Encoding │      │  Positional Encoding │   │
│  │         ↓            │      │         ↓            │   │
│  │  ┌──────────────┐   │      │  ┌──────────────┐   │   │
│  │  │ Encoder Layer│ ×N│      │  │ Decoder Layer│ ×N│   │
│  │  │              │   │──────→  │              │   │   │
│  │  │ • Self-Attn  │   │ Context │ • Self-Attn  │   │   │
│  │  │ • Feed-Fwd   │   │      │  │ • Cross-Attn │   │   │
│  │  └──────────────┘   │      │  │ • Feed-Fwd   │   │   │
│  │         ↓            │      │  └──────────────┘   │   │
│  │  Encoder Output      │      │         ↓            │   │
│  └──────────────────────┘      │  Linear + Softmax    │   │
│                                 │         ↓            │   │
│                                 │  Output Probabilities│   │
│                                 └──────────────────────┘   │
└─────────────────────────────────────────────────────────────┘
```

### Key Components (We'll build each from scratch)

**1. Input/Output Embeddings**:
- Convert tokens to vectors
- Why? Neural networks need numeric inputs

**2. Positional Encoding**:
- Inject position information
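- Why? Self-attention is permutation-equivariant: shuffling the input tokens just shuffles the outputs, so without position information word order is invisible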

**3. Encoder Layer**:
- **Self-attention**: Understand relationships within input
- **Feed-forward**: Process each position independently
- Why? Build rich representations of input

**4. Decoder Layer**:
- **Masked self-attention**: Generate output autoregressively
- **Cross-attention**: Attend to encoder output
- **Feed-forward**: Process each position
- Why? Generate output while looking at input

**5. Residual Connections & Layer Norm**:
- Skip connections around each sub-layer
- Why? Enable deep networks, stabilize training
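To make the positional-encoding point concrete, here is a small check. It uses PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the from-scratch module built below (sizes and the permutation are arbitrary choices for illustration). With no positional information, reordering the input tokens reorders the outputs in exactly the same way, so the model has no way to tell word orders apart.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Built-in attention as a stand-in; its dropout defaults to 0.0
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
attn.eval()

x = torch.randn(1, 5, 16)              # (batch, seq_len, d_model), no positional encoding
perm = torch.tensor([3, 0, 4, 1, 2])   # an arbitrary reordering of the 5 tokens
xp = x[:, perm]                        # the same tokens in a different order

with torch.no_grad():
    out, _ = attn(x, x, x)
    out_perm, _ = attn(xp, xp, xp)

# Permuting the input permutes the output identically: order carries no signal
print(torch.allclose(out_perm, out[:, perm], atol=1e-6))
```

This is why positions must be injected explicitly, either by adding an encoding to the embeddings (as done below) or by modifying attention itself.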

### Three Architectural Variants

**Encoder-only (BERT)**:
- Use only encoder stack
- Bidirectional context
- Best for: Classification, NER, sentence embeddings

**Decoder-only (GPT)**:
- Use only decoder stack (without cross-attention)
- Causal (left-to-right) generation
- Best for: Text generation, language modeling

**Encoder-decoder (T5, BART)**:
- Full transformer with both stacks
- Best for: Translation, summarization, seq2seq tasks

## Building Block 1: Multi-Head Attention (Recap)

We've already covered attention in detail. Here's a compact implementation for reference.

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention mechanism.
    
    Key insight: Multiple attention patterns in parallel.
    """
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def split_heads(self, x):
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    
    def combine_heads(self, x):
        batch_size, num_heads, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
    
    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch, seq_len_q, d_model)
            key: (batch, seq_len_k, d_model)
            value: (batch, seq_len_v, d_model)
            mask: Optional attention mask
        
        Returns:
            output: (batch, seq_len_q, d_model)
            attention_weights: (batch, num_heads, seq_len_q, seq_len_k)
        """
        Q = self.split_heads(self.W_Q(query))
        K = self.split_heads(self.W_K(key))
        V = self.split_heads(self.W_V(value))
        
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        attn_output = torch.matmul(attention_weights, V)
        attn_output = self.combine_heads(attn_output)
        
        output = self.W_O(attn_output)
        
        return output, attention_weights

print("Multi-Head Attention: ✓")

## Building Block 2: Position-wise Feed-Forward Network

### What is it?

A simple 2-layer fully-connected network applied to **each position independently**:

$$FFN(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2$$

Or equivalently: Two `Conv1d` with kernel_size=1
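The `Conv1d` equivalence can be checked numerically: copy the weights of an `nn.Linear` into a `Conv1d` with `kernel_size=1` and compare outputs. This is a sketch with arbitrary sizes; note that `Conv1d` expects the channel (feature) axis before the sequence axis, hence the transposes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 16, 64

linear = nn.Linear(d_model, d_ff)
conv = nn.Conv1d(d_model, d_ff, kernel_size=1)

# Conv1d weight shape is (out_channels, in_channels, kernel_size) = (d_ff, d_model, 1)
with torch.no_grad():
    conv.weight.copy_(linear.weight.unsqueeze(-1))
    conv.bias.copy_(linear.bias)

x = torch.randn(2, 10, d_model)                      # (batch, seq_len, d_model)
out_linear = linear(x)                               # (batch, seq_len, d_ff)
out_conv = conv(x.transpose(1, 2)).transpose(1, 2)   # Conv1d wants (batch, channels, seq_len)

print(torch.allclose(out_linear, out_conv, atol=1e-6))
```

A kernel of size 1 only ever sees one position at a time, which is exactly what "position-wise" means.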

### Why do we need it?

**Attention aggregates information, but doesn't transform it deeply!**

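**Problem**: For fixed attention weights, attention is just a weighted sum of value vectors (a linear operation)
- It mixes information from different positions
- But apart from the softmax that produces the weights, it applies no non-linear transformation to the values themselves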

**Solution**: Feed-forward network adds:
1. **Non-linearity**: ReLU/GELU introduces non-linear transformations
2. **Depth**: Two layers with expansion (typically 4x) add capacity
3. **Position-wise processing**: Each position transformed independently

**Why independent processing?**
- Attention mixed information across positions
- FFN processes each position's representation
- Alternating attention + FFN = mixing + processing
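"Independent processing" can be verified directly: applying the FFN to the whole `(batch, seq_len, d_model)` tensor gives the same result as applying it to each position separately and stacking. The small `nn.Sequential` FFN below is an illustrative stand-in with arbitrary sizes, not the class defined later.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff, seq_len = 8, 32, 5

ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
x = torch.randn(2, seq_len, d_model)

with torch.no_grad():
    whole = ffn(x)  # applied to all positions at once
    # The same network applied to one position at a time, then re-stacked
    one_at_a_time = torch.stack([ffn(x[:, i]) for i in range(seq_len)], dim=1)

print(torch.allclose(whole, one_at_a_time, atol=1e-6))
```

No information crosses positions inside the FFN; all cross-position mixing happens in attention.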

**Standard configuration**:
- Hidden dimension = 4 × d_model (expansion factor)
- Activation: ReLU (original) or GELU (modern)
- Why 4x? A convention from the original paper (2048 = 4 × 512) that works well empirically: more capacity without an excessive parameter count
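The cost of the 4× expansion is easy to work out by hand. The sketch below counts weights and biases for one encoder-style layer (four attention projections, the two FFN linears, two LayerNorms), assuming the same layout as the layers built later in this notebook.

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias vector

d_model, d_ff = 512, 2048

attn = 4 * linear_params(d_model, d_model)                         # W_Q, W_K, W_V, W_O
ffn = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)  # W1/b1 and W2/b2
norms = 2 * (2 * d_model)                                          # two LayerNorms (gamma, beta)
layer = attn + ffn + norms

print(f"attention: {attn:,}")   # 1,050,624
print(f"FFN:       {ffn:,}")    # 2,099,712
print(f"layer:     {layer:,}")  # 3,152,384
print(f"FFN share: {ffn / layer:.1%}")  # 66.6%
```

Each FFN matrix is 4× the size of one attention projection, so the FFN ends up with about two-thirds of the layer's parameters.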

In [None]:
class PositionWiseFeedForward(nn.Module):
    """
    Position-wise feed-forward network.
    
    Applied to each position independently and identically.
    
    FFN(x) = act(xW1 + b1)W2 + b2, where act = ReLU in the original paper (GELU here)
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        """
        Args:
            d_model: Model dimension (e.g., 512)
            d_ff: Hidden dimension (typically 4 × d_model = 2048)
            dropout: Dropout probability
        """
        super(PositionWiseFeedForward, self).__init__()
        
        # Two linear layers with expansion
        # Why 4x expansion? Adds capacity while keeping model manageable
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
        # Modern transformers (e.g., BERT, GPT-2) often use GELU instead of ReLU
        # GELU is a smooth, ReLU-like activation that tends to work slightly better in practice
        self.activation = nn.GELU()
    
    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model)
        
        Returns:
            output: (batch, seq_len, d_model)
        """
        # x: (batch, seq_len, d_model)
        # -> (batch, seq_len, d_ff) [expand]
        # -> (batch, seq_len, d_model) [project back]
        
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        
        return x

# Test FFN
d_model = 512
d_ff = 2048  # 4x expansion

ffn = PositionWiseFeedForward(d_model, d_ff)
x = torch.randn(2, 10, d_model)
output = ffn(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nFFN parameters: {sum(p.numel() for p in ffn.parameters()):,}")
print(f"  W1: {d_model} × {d_ff} = {d_model * d_ff:,} (+ {d_ff:,} bias)")
print(f"  W2: {d_ff} × {d_model} = {d_ff * d_model:,} (+ {d_model:,} bias)")
print(f"\nNote: the FFN holds about two-thirds of each layer's parameters (embeddings aside)!")

## Building Block 3: Residual Connections & Layer Normalization

### Why Residual Connections?

**Problem**: Deep networks suffer from degradation
- Gradients vanish or explode
- Deeper doesn't always mean better

**Solution**: Skip connections (ResNet-style)
```python
output = x + SubLayer(x)
```

**Why this works**:
1. **Gradient flow**: Gradients can flow directly through skip connection
2. **Identity mapping**: Network can learn to do nothing if needed
3. **Easier optimization**: Each layer only learns residual (delta)
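The gradient-flow argument follows from the derivative of a residual block: d/dx [x + f(x)] = 1 + f'(x), so the identity term is always there. The sketch below uses a deliberately weak elementwise "sublayer" (0.5 · tanh, a made-up stand-in whose local gradient never exceeds 0.5) and stacks it 20 times, with and without skip connections.

```python
import torch

def weak_sublayer(t):
    # Hypothetical elementwise sublayer; its local gradient is at most 0.5
    return 0.5 * torch.tanh(t)

torch.manual_seed(0)
x = torch.randn(4, requires_grad=True)
num_layers = 20

# Plain stack: h = f(f(...f(x))); the gradient is a product of factors <= 0.5
h = x
for _ in range(num_layers):
    h = weak_sublayer(h)
grad_plain = torch.autograd.grad(h.sum(), x)[0]

# Residual stack: h = h + f(h); each layer contributes a factor 1 + f'(h) >= 1
h = x
for _ in range(num_layers):
    h = h + weak_sublayer(h)
grad_residual = torch.autograd.grad(h.sum(), x)[0]

print(grad_plain.abs().max().item())  # at most 0.5**20 (about 9.5e-7): vanished
print(grad_residual.min().item())     # at least 1.0: the skip path preserves it
```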

### Why Layer Normalization?

**Batch Norm vs Layer Norm**:
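- **Batch Norm**: Normalizes each feature across the batch, so its statistics depend on batch composition and padding (a poor fit for variable-length sequences)
- **Layer Norm**: Normalizes across the feature dimension of each token independently (works for any batch size or sequence length)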

**Layer Norm formula**:
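$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

Where $\mu$ and $\sigma^2$ are the mean and variance computed **per token, over its $d_{model}$ features**, $\epsilon$ is a small constant for numerical stability, and $\gamma$, $\beta$ are learned scale and shift vectors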
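The formula can be checked against `nn.LayerNorm` by computing it manually. This sketch relies on `nn.LayerNorm` initializing γ to ones and β to zeros, and on its default `eps=1e-5`; note that it uses the biased (divide-by-N) variance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8
x = torch.randn(2, 5, d_model)  # (batch, seq_len, d_model)
ln = nn.LayerNorm(d_model)      # gamma = ones, beta = zeros at init

mu = x.mean(dim=-1, keepdim=True)                 # one mean per token
var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)  # biased variance per token
manual = ln.weight * (x - mu) / torch.sqrt(var + ln.eps) + ln.bias

print(torch.allclose(manual, ln(x), atol=1e-5))
print(mu.shape)  # torch.Size([2, 5, 1]): statistics are per token, not per batch
```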

**Why it's crucial**:
1. **Stable training**: Prevents activation explosion/vanishing
2. **Faster convergence**: Keeps activation and gradient scales consistent from layer to layer
3. **Higher learning rates**: Can train with larger learning rates

### Pre-LN vs Post-LN

**Post-LN (Original Transformer)**:
```python
output = LayerNorm(x + SubLayer(x))
```

**Pre-LN (Modern, more stable)**:
```python
output = x + SubLayer(LayerNorm(x))
```

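**Why Pre-LN is often preferred**:
- More stable training for very deep networks (typically less sensitive to learning-rate warmup)
- Gradients flow directly through the un-normalized residual path
- Used in GPT-2, GPT-3, and most modern models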
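The notebook's `SublayerConnection` (next cell) is Post-LN. For comparison, a Pre-LN version needs only a one-line change in `forward`. `PreLNSublayerConnection` is a hypothetical name used for illustration and is not used elsewhere in this notebook. Pre-LN models such as GPT-2 also add a final LayerNorm after the last block, since the residual stream itself is never normalized.

```python
import torch
import torch.nn as nn

class PreLNSublayerConnection(nn.Module):
    """Pre-LN residual block (illustrative): x + Dropout(Sublayer(LayerNorm(x)))."""
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # Normalize only the sublayer's input; the residual path carries x unchanged
        return x + self.dropout(sublayer(self.norm(x)))

torch.manual_seed(0)
x = torch.randn(2, 10, 512)
block = PreLNSublayerConnection(d_model=512)
out = block(x, lambda t: torch.zeros_like(t))

# With a zero sublayer, Pre-LN is an exact identity (Post-LN would return LayerNorm(x))
print(torch.equal(out, x))
```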

In [None]:
class SublayerConnection(nn.Module):
    """
    Residual connection followed by layer norm (Post-LN, as in the original paper).
    
    Implements: LayerNorm(x + Dropout(Sublayer(x)))
    """
    def __init__(self, d_model, dropout=0.1):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, sublayer):
        """
        Args:
            x: Input (batch, seq_len, d_model)
            sublayer: Function (attention or FFN)
        
        Returns:
            output: (batch, seq_len, d_model)
        """
        # Apply sublayer, dropout, add residual, then normalize
        return self.norm(x + self.dropout(sublayer(x)))

# Demonstrate residual connection benefit
x = torch.randn(2, 10, 512)
sublayer = lambda x: torch.zeros_like(x)  # Sublayer that outputs zeros

sublayer_conn = SublayerConnection(d_model=512)
output = sublayer_conn(x, sublayer)

print(f"Sublayer output is all zeros: {bool((sublayer(x) == 0).all())}")
print(f"Output equals LayerNorm(x): {torch.allclose(output, sublayer_conn.norm(x), atol=1e-6)}")
print("\nEven if the sublayer outputs zeros, the input still flows through the residual path!")
print("This is one reason deep transformers train successfully.")

## The Encoder Layer: Putting It Together

### Structure

Each encoder layer has two sub-layers:
1. **Multi-head self-attention**: Understand relationships in input
2. **Position-wise feed-forward**: Process each position

Each sub-layer has:
- Residual connection
- Layer normalization
- Dropout

```
EncoderLayer:
  x -> [Self-Attention] -> Add & Norm -> 
       [Feed-Forward]   -> Add & Norm -> output
```

### Why This Design?

**Two-stage processing**:
1. **Attention**: Mix information across positions (global context)
2. **FFN**: Transform each position independently (local processing)

**Alternating global + local** = powerful hierarchical representations

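**Why multiple layers?** Probing studies of BERT-like models suggest a rough progression:
- Lower layers: Surface and syntactic patterns (word order, part of speech)
- Middle layers: Phrases and syntactic dependencies
- Upper layers: More semantic and task-specific information

Typical depths: 6 (original Transformer), 12 (BERT-base), 24 (BERT-large), 96 (GPT-3, decoder-only)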
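As a cross-check for the from-scratch layer below, PyTorch ships a built-in encoder layer with the same Post-LN structure. This sketch assumes PyTorch 1.10 or newer (for `norm_first`); its default activation is ReLU rather than the GELU used here, which does not change the parameter count.

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(
    d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1,
    batch_first=True,  # inputs shaped (batch, seq_len, d_model)
    norm_first=False,  # Post-LN, as in the original paper
)
x = torch.randn(2, 10, 512)
out = layer(x)

num_params = sum(p.numel() for p in layer.parameters())
print(out.shape)          # torch.Size([2, 10, 512])
print(f"{num_params:,}")  # 3,152,384
```

The count should match the from-scratch `EncoderLayer` in the next cell, since both contain the same four attention projections, two FFN linears, and two LayerNorms.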

In [None]:
class EncoderLayer(nn.Module):
    """
    Single encoder layer.
    
    Architecture:
      1. Multi-head self-attention
      2. Add & Norm
      3. Position-wise FFN
      4. Add & Norm
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        
        # Sub-layer 1: Self-attention
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Sub-layer 2: Feed-forward
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        
        # Sublayer connections (residual + norm)
        self.sublayer1 = SublayerConnection(d_model, dropout)
        self.sublayer2 = SublayerConnection(d_model, dropout)
    
    def forward(self, x, mask=None):
        """
        Args:
            x: Input (batch, seq_len, d_model)
            mask: Optional attention mask
        
        Returns:
            output: (batch, seq_len, d_model)
        """
        # Self-attention sub-layer
        # Why self-attention? Input attends to itself
        x = self.sublayer1(x, lambda x: self.self_attn(x, x, x, mask)[0])
        
        # Feed-forward sub-layer
        # Why FFN? Add depth and non-linearity
        x = self.sublayer2(x, self.feed_forward)
        
        return x

# Test encoder layer
d_model = 512
num_heads = 8
d_ff = 2048

encoder_layer = EncoderLayer(d_model, num_heads, d_ff)
x = torch.randn(2, 10, d_model)
output = encoder_layer(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nEncoder layer parameters: {sum(p.numel() for p in encoder_layer.parameters()):,}")
print("\nThis single layer can be stacked 6-12 times to form the encoder!")

## The Decoder Layer: More Complex

### Structure

Each decoder layer has **three** sub-layers:
1. **Masked self-attention**: Generate output autoregressively
2. **Cross-attention**: Attend to encoder output
3. **Position-wise feed-forward**: Process each position

```
DecoderLayer:
  x -> [Masked Self-Attention] -> Add & Norm -> 
       [Cross-Attention to Encoder] -> Add & Norm -> 
       [Feed-Forward] -> Add & Norm -> output
```

### Why Three Sub-layers?

**1. Masked Self-Attention**:
- **Purpose**: Understand relationships in output generated so far
- **Why masked?**: Can't look at future tokens (we're generating them!)
- **Causal masking**: Position i can only attend to positions ≤ i

**2. Cross-Attention**:
- **Purpose**: Look at encoder output (input sequence)
- **Query**: From decoder (what do I need?)
- **Key, Value**: From encoder (what information is available?)
- **Why crucial**: This is how decoder "reads" the input!

**3. Feed-Forward**:
- Same as encoder
- Process each position independently
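The shapes in cross-attention show the query/key split clearly. This sketch uses the built-in `nn.MultiheadAttention` with arbitrary sizes (8 target positions, 10 source positions); by default it returns attention weights averaged over heads.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads = 32, 4
cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

decoder_states = torch.randn(2, 8, d_model)   # queries: 8 target positions
encoder_output = torch.randn(2, 10, d_model)  # keys/values: 10 source positions

out, weights = cross_attn(decoder_states, encoder_output, encoder_output)
print(out.shape)      # torch.Size([2, 8, 32]): one vector per target position
print(weights.shape)  # torch.Size([2, 8, 10]): each target position over all source positions
```

The output always has the decoder's length; the encoder's length only appears in the attention matrix.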

### Information Flow

```
Decoder position i:
  1. Look at previous decoder outputs (positions 0...i)
  2. Look at ALL encoder outputs (entire input)
  3. Process combined information
  -> Predict next token
```

This is the magic of sequence-to-sequence learning!
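Causal masking in isolation: build a lower-triangular mask and apply it with the same `masked_fill(mask == 0, -inf)` convention used in `MultiHeadAttention` above. The random scores are a stand-in for Q·Kᵀ/√d_k.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 4
mask = torch.tril(torch.ones(T, T))  # 1 = allowed, 0 = future position (blocked)
scores = torch.randn(T, T)           # stand-in for Q @ K^T / sqrt(d_k)

weights = F.softmax(scores.masked_fill(mask == 0, float('-inf')), dim=-1)
print(mask)
print(weights)  # upper triangle is exactly 0; each row still sums to 1
```

Position 0 can only attend to itself, so its row is `[1, 0, 0, 0]`; the last position sees the whole prefix.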

In [None]:
class DecoderLayer(nn.Module):
    """
    Single decoder layer.
    
    Architecture:
      1. Masked multi-head self-attention
      2. Add & Norm
      3. Multi-head cross-attention to encoder
      4. Add & Norm
      5. Position-wise FFN
      6. Add & Norm
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        
        # Sub-layer 1: Masked self-attention
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Sub-layer 2: Cross-attention to encoder
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Sub-layer 3: Feed-forward
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        
        # Sublayer connections
        self.sublayer1 = SublayerConnection(d_model, dropout)
        self.sublayer2 = SublayerConnection(d_model, dropout)
        self.sublayer3 = SublayerConnection(d_model, dropout)
    
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: Decoder input (batch, tgt_len, d_model)
            encoder_output: Encoder output (batch, src_len, d_model)
            src_mask: Source attention mask
            tgt_mask: Target (causal) mask
        
        Returns:
            output: (batch, tgt_len, d_model)
        """
        # Sub-layer 1: Masked self-attention
        # Why masked? Can't attend to future tokens during generation
        x = self.sublayer1(x, lambda x: self.self_attn(x, x, x, tgt_mask)[0])
        
        # Sub-layer 2: Cross-attention to encoder
        # Query from decoder, Key/Value from encoder
        # This is where decoder "reads" the input!
        x = self.sublayer2(x, lambda x: self.cross_attn(x, encoder_output, encoder_output, src_mask)[0])
        
        # Sub-layer 3: Feed-forward
        x = self.sublayer3(x, self.feed_forward)
        
        return x

# Test decoder layer
decoder_layer = DecoderLayer(d_model=512, num_heads=8, d_ff=2048)

# Decoder input (target sequence so far)
tgt = torch.randn(2, 8, 512)

# Encoder output (source sequence)
encoder_out = torch.randn(2, 10, 512)

# Create causal mask for target
tgt_len = tgt.size(1)
tgt_mask = torch.tril(torch.ones(1, 1, tgt_len, tgt_len))

output = decoder_layer(tgt, encoder_out, tgt_mask=tgt_mask)

print(f"Decoder input shape: {tgt.shape}")
print(f"Encoder output shape: {encoder_out.shape}")
print(f"Decoder output shape: {output.shape}")
print(f"\nDecoder layer parameters: {sum(p.numel() for p in decoder_layer.parameters()):,}")
print("\nNote: Decoder is more complex than encoder (3 sub-layers vs 2)")

## Complete Encoder-Decoder Transformer

Now let's put everything together into the full transformer architecture!

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding."""
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class Transformer(nn.Module):
    """
    Complete Transformer model (Encoder-Decoder).
    
    This is the full architecture from 'Attention is All You Need'.
    """
    def __init__(self, 
                 src_vocab_size,
                 tgt_vocab_size,
                 d_model=512,
                 num_heads=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 d_ff=2048,
                 dropout=0.1,
                 max_len=5000):
        """
        Args:
            src_vocab_size: Source vocabulary size
            tgt_vocab_size: Target vocabulary size
            d_model: Model dimension (512 in original paper)
            num_heads: Number of attention heads (8 in original)
            num_encoder_layers: Encoder stack depth (6 in original)
            num_decoder_layers: Decoder stack depth (6 in original)
            d_ff: Feed-forward hidden dimension (2048 in original)
            dropout: Dropout probability
            max_len: Maximum sequence length
        """
        super(Transformer, self).__init__()
        
        self.d_model = d_model
        
        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        
        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        
        # Encoder stack
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])
        
        # Decoder stack
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])
        
        # Output projection
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
        
        self.dropout = nn.Dropout(dropout)
        
        # Initialize parameters
        self._init_parameters()
    
    def _init_parameters(self):
        """Initialize parameters with Xavier uniform."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def encode(self, src, src_mask=None):
        """
        Encode source sequence.
        
        Args:
            src: Source token indices (batch, src_len)
            src_mask: Source attention mask
        
        Returns:
            encoder_output: (batch, src_len, d_model)
        """
        # Embed and add positional encoding
        # Why sqrt(d_model)? Embeddings start small; scaling keeps them from being swamped by the PE (values in [-1, 1])
        x = self.src_embedding(src) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        
        # Pass through encoder layers
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        
        return x
    
    def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        """
        Decode target sequence.
        
        Args:
            tgt: Target token indices (batch, tgt_len)
            encoder_output: Encoder output (batch, src_len, d_model)
            src_mask: Source attention mask
            tgt_mask: Target (causal) mask
        
        Returns:
            decoder_output: (batch, tgt_len, d_model)
        """
        # Embed and add positional encoding
        x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        
        # Pass through decoder layers
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        
        return x
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Full forward pass.
        
        Args:
            src: Source sequence (batch, src_len)
            tgt: Target sequence (batch, tgt_len)
            src_mask: Source attention mask
            tgt_mask: Target (causal) mask
        
        Returns:
            logits: (batch, tgt_len, tgt_vocab_size)
        """
        # Encode
        encoder_output = self.encode(src, src_mask)
        
        # Decode
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        
        # Project to vocabulary
        logits = self.output_projection(decoder_output)
        
        return logits

print("Complete Transformer: ✓")

### Test the Complete Transformer
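Before running the test, it helps to predict the parameter count by hand. This sketch assumes the layer layout defined above: biases on every linear layer, separate source and target embeddings, no weight tying, and a positional encoding stored as a buffer (so it adds no parameters).

```python
def linear(n_in, n_out):
    return n_in * n_out + n_out  # weight + bias

def layer_norm(d):
    return 2 * d  # gamma + beta

def mha(d):
    return 4 * linear(d, d)  # W_Q, W_K, W_V, W_O

def ffn(d, d_ff):
    return linear(d, d_ff) + linear(d_ff, d)

def predicted_params(src_vocab, tgt_vocab, d=512, d_ff=2048, n_enc=6, n_dec=6):
    embeddings = (src_vocab + tgt_vocab) * d                   # nn.Embedding has no bias
    enc_layer = mha(d) + ffn(d, d_ff) + 2 * layer_norm(d)      # self-attn + FFN
    dec_layer = 2 * mha(d) + ffn(d, d_ff) + 3 * layer_norm(d)  # self-attn + cross-attn + FFN
    output = linear(d, tgt_vocab)                              # projection to vocabulary
    return embeddings + n_enc * enc_layer + n_dec * dec_layer + output

print(f"{predicted_params(10000, 10000):,}")  # 59,508,496
```

If the layout assumptions hold, this should agree with `total_params` printed by the next cell.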

In [None]:
# Create transformer
src_vocab_size = 10000
tgt_vocab_size = 10000

model = Transformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=512,
    num_heads=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    d_ff=2048,
    dropout=0.1
)

# Create dummy input
batch_size = 2
src_len = 10
tgt_len = 8

src = torch.randint(0, src_vocab_size, (batch_size, src_len))
tgt = torch.randint(0, tgt_vocab_size, (batch_size, tgt_len))

# Create causal mask for target
tgt_mask = torch.tril(torch.ones(1, 1, tgt_len, tgt_len))

# Forward pass
logits = model(src, tgt, tgt_mask=tgt_mask)

print("="*60)
print("COMPLETE TRANSFORMER TEST")
print("="*60)
print(f"\nSource sequence shape: {src.shape}")
print(f"Target sequence shape: {tgt.shape}")
print(f"Output logits shape: {logits.shape}")

print(f"\n" + "="*60)
print("MODEL STATISTICS")
print("="*60)

total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

# Break down by component
embed_params = sum(p.numel() for p in model.src_embedding.parameters()) + \
               sum(p.numel() for p in model.tgt_embedding.parameters())
encoder_params = sum(p.numel() for p in model.encoder_layers.parameters())
decoder_params = sum(p.numel() for p in model.decoder_layers.parameters())
output_params = sum(p.numel() for p in model.output_projection.parameters())

print(f"\nParameter breakdown:")
print(f"  Embeddings: {embed_params:,} ({100*embed_params/total_params:.1f}%)")
print(f"  Encoder: {encoder_params:,} ({100*encoder_params/total_params:.1f}%)")
print(f"  Decoder: {decoder_params:,} ({100*decoder_params/total_params:.1f}%)")
print(f"  Output projection: {output_params:,} ({100*output_params/total_params:.1f}%)")

print(f"\n" + "="*60)
print("ARCHITECTURE SUMMARY")
print("="*60)
print(f"Model dimension (d_model): 512")
print(f"Number of attention heads: 8")
print(f"Encoder layers: 6")
print(f"Decoder layers: 6")
print(f"Feed-forward dimension: 2048")
print(f"\nThis is the base configuration from 'Attention is All You Need' (with GELU instead of ReLU, and no embedding weight sharing).")

## Architectural Variants

### 1. Encoder-Only (BERT)

**Use only the encoder stack**

```python
# BERT architecture
x = embeddings + positional_encoding
for layer in encoder_layers:
    x = layer(x)  # Self-attention + FFN
output = x  # Use for classification, NER, etc.
```

**Why encoder-only?**
- **Bidirectional context**: Can attend to both past and future
- **Better representations**: Sees full context for each token
- **Best for**: Classification, NER, question answering

**Examples**: BERT, RoBERTa, DeBERTa

### 2. Decoder-Only (GPT)

**Use only the decoder stack (remove cross-attention)**

```python
# GPT architecture (pseudocode; layer norms omitted)
x = embeddings + positional_encoding
for layer in decoder_layers:
    x = x + masked_self_attention(x)  # Causal masking, no cross-attention
    x = x + ffn(x)
logits = output_projection(x)
```

**Why decoder-only?**
- **Autoregressive generation**: Natural for text generation
- **Simpler**: No encoder, no cross-attention
- **Scalable**: Easier to scale to 100B+ parameters
- **Best for**: Text generation, language modeling

**Examples**: GPT-2, GPT-3, GPT-4, LLaMA, Mistral

**Why GPT-style dominates in 2023-2025**:
- Pre-train on next-token prediction (simple, scalable)
- Can be fine-tuned for any task
- Reported emergent abilities at scale (e.g., in-context learning)
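A GPT-style block is essentially an encoder layer plus a causal mask. The sketch below uses the built-in `nn.TransformerEncoderLayer` (sizes arbitrary, dropout off) with a boolean mask where `True` marks blocked (future) positions, then checks causality directly: changing only the last token must leave every earlier output unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, d_model = 6, 16

# Encoder layer + causal mask = GPT-style block (no cross-attention)
block = nn.TransformerEncoderLayer(d_model, nhead=2, dim_feedforward=32,
                                   dropout=0.0, batch_first=True)
block.eval()

# Boolean mask: True = may NOT attend (strictly upper triangle = the future)
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

x = torch.randn(1, T, d_model)
x_changed = x.clone()
x_changed[:, -1] = torch.randn(d_model)  # alter only the last token

with torch.no_grad():
    out = block(x, src_mask=causal_mask)
    out_changed = block(x_changed, src_mask=causal_mask)

# Earlier positions cannot see the last token, so their outputs are unchanged
print(torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-6))
# The last position does see its own (changed) input
print(torch.allclose(out[:, -1], out_changed[:, -1]))
```

This property is what makes next-token pre-training valid: every position's prediction depends only on the prefix.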

### 3. Encoder-Decoder (T5)

**Full transformer with both stacks**

**Why encoder-decoder?**
- **Best for seq2seq**: Translation, summarization
- **Separate roles**: The encoder reads the whole input bidirectionally once; the decoder generates output of any length while cross-attending to it
- **Best for**: Tasks with clear input → output structure

**Examples**: T5, BART, mT5
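The three variants differ mainly in which attention masks they use. This sketch builds each mask with the `1 = may attend, 0 = blocked` convention of `MultiHeadAttention` above; `attention_masks` is a hypothetical helper, and padding masks are omitted for simplicity.

```python
import torch

def attention_masks(src_len, tgt_len):
    """Illustrative masks: 1 = may attend, 0 = blocked (padding ignored)."""
    bidirectional = torch.ones(src_len, src_len)       # encoder self-attention (BERT, T5 encoder)
    causal = torch.tril(torch.ones(tgt_len, tgt_len))  # decoder self-attention (GPT, T5 decoder)
    cross = torch.ones(tgt_len, src_len)               # decoder queries -> encoder keys (T5 only)
    return bidirectional, causal, cross

bi, causal, cross = attention_masks(src_len=4, tgt_len=3)
print(causal)
# tensor([[1., 0., 0.],
#         [1., 1., 0.],
#         [1., 1., 1.]])
print(bi.shape, causal.shape, cross.shape)
```

Encoder-only models use only the first mask, decoder-only models only the second, and encoder-decoder models use all three.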

In [None]:
# Quick comparison
print("TRANSFORMER ARCHITECTURAL VARIANTS\n")
print("=" * 70)

variants = [
    {
        'name': 'Encoder-Only (BERT)',
        'structure': 'Embeddings → Encoder × N → Task head',
        'attention': 'Bidirectional self-attention',
        'best_for': 'Classification, NER, embeddings',
        'examples': 'BERT, RoBERTa, DeBERTa'
    },
    {
        'name': 'Decoder-Only (GPT)',
        'structure': 'Embeddings → Decoder × N → LM head',
        'attention': 'Causal (masked) self-attention',
        'best_for': 'Text generation, language modeling',
        'examples': 'GPT-2/3/4, LLaMA, Mistral'
    },
    {
        'name': 'Encoder-Decoder (T5)',
        'structure': 'Encoder × N → Decoder × N',
        'attention': 'Bi + Cross + Causal attention',
        'best_for': 'Translation, summarization, seq2seq',
        'examples': 'T5, BART, mT5'
    }
]

for v in variants:
    print(f"\n{v['name']}")
    print("-" * 70)
    print(f"  Structure:  {v['structure']}")
    print(f"  Attention:  {v['attention']}")
    print(f"  Best for:   {v['best_for']}")
    print(f"  Examples:   {v['examples']}")

print("\n" + "=" * 70)
print("\nTrend in 2023-2025: Decoder-only models dominate!")
print("Why? Scalability + emergent abilities + simpler architecture")

## Mini Exercises

### Exercise 1: Count Attention Operations

For a sequence of length N with d_model dimension:
- How many attention operations in encoder-only with 12 layers?
- How many in decoder-only?
- How many in encoder-decoder (6 + 6 layers)?
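Hint: to see why the count matters, estimate what one attention operation costs. A naive implementation materializes a full N × N score matrix per head. The sketch below assumes fp32 scores, 8 heads, and batch size 1 (FlashAttention-style kernels avoid storing this matrix, but the compute is still quadratic). `score_matrix_bytes` is a hypothetical helper.

```python
def score_matrix_bytes(seq_len, num_heads, batch_size=1, bytes_per_elem=4):
    """Memory for one layer's attention scores: batch x heads x N x N."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem

for n in (512, 2048, 8192):
    mib = score_matrix_bytes(n, num_heads=8) / 2**20
    print(f"N={n:>5}: {mib:,.0f} MiB per attention operation")
```

Multiply by the number of attention operations you count for each architecture: quadrupling N multiplies the cost by 16.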

In [None]:
# YOUR CODE HERE


# SOLUTION
def show_solution_1():
    N = 512  # Sequence length
    
    print("Attention operation complexity: O(N²·d)\n")
    
    # Encoder-only (BERT)
    bert_layers = 12
    bert_attn_per_layer = 1  # Self-attention only
    bert_total = bert_layers * bert_attn_per_layer
    
    # Decoder-only (GPT)
    gpt_layers = 12
    gpt_attn_per_layer = 1  # Masked self-attention only
    gpt_total = gpt_layers * gpt_attn_per_layer
    
    # Encoder-decoder (T5)
    enc_layers = 6
    dec_layers = 6
    t5_total = (enc_layers * 1) + (dec_layers * 2)  # Decoder has self + cross
    
    print(f"For sequence length N={N}:\n")
    print(f"Encoder-only (BERT, 12 layers):")
    print(f"  {bert_total} attention operations")
    print(f"  All are bidirectional self-attention\n")
    
    print(f"Decoder-only (GPT, 12 layers):")
    print(f"  {gpt_total} attention operations")
    print(f"  All are causal self-attention\n")
    
    print(f"Encoder-decoder (T5, 6+6 layers):")
    print(f"  Encoder: {enc_layers} self-attention")
    print(f"  Decoder: {dec_layers} self-attention + {dec_layers} cross-attention")
    print(f"  Total: {t5_total} attention operations\n")
    
    print("Why this matters:")
    print("- Attention is O(N²) - the bottleneck for long sequences")
    print("- Modern optimizations: FlashAttention (exact, O(N) memory, still O(N²) compute), sparse attention")
    print("- This is why context length matters for cost!")

# Uncomment to see solution:
# show_solution_1()
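The solution above counts attention sublayers but never uses `d_model`. A rough per-sublayer cost estimate fills that in. The sketch below is a back-of-the-envelope model under stated assumptions, not a profiler. It counts multiply-accumulates (MACs) for the four d_model × d_model projections (Q, K, V, output) and the two N-dependent matrix products (QKᵀ and weights × V). It ignores softmax, masking, biases, and batching. Splitting into heads leaves these totals unchanged. The helper name `attention_macs` is hypothetical.

```python
def attention_macs(n_query: int, n_key: int, d_model: int) -> tuple[int, int]:
    """Rough multiply-accumulate (MAC) counts for one attention sublayer.

    Assumptions (illustrative only): batch size 1, dense scaled dot-product
    attention, four d_model x d_model projections (Q, K, V, output).
    Returns (projection MACs, MACs of the two n_query x n_key products).
    """
    # Q and output projections act on the query positions;
    # K and V projections act on the key/value positions.
    projections = 2 * n_query * d_model * d_model + 2 * n_key * d_model * d_model
    # Q @ K^T builds an (n_query x n_key) score matrix; weights @ V mixes values back.
    quadratic = 2 * n_query * n_key * d_model
    return projections, quadratic


# Self-attention: queries and keys come from the same length-N sequence.
for n in (512, 1536, 8192):
    proj, quad = attention_macs(n, n, d_model=768)
    print(f"N={n:>5}: projections={proj:.2e} MACs, N^2 terms={quad:.2e} MACs")
```

In this model the N² terms overtake the projections once 2N²d > 4Nd², that is, once N > 2·d_model (1,536 tokens for d_model = 768). At BERT-scale lengths the projections dominate. At long context lengths the quadratic term takes over, which is why context length drives cost.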

### Exercise 2: Implement a Simple GPT-style Decoder

Build a decoder-only model (GPT-style) by removing the cross-attention from decoder layers.

In [None]:
# YOUR CODE HERE


# SOLUTION
def show_solution_2():
    class GPTBlock(nn.Module):
        """GPT-style decoder block (no cross-attention)."""
        def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
            super(GPTBlock, self).__init__()
            
            self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
            self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
            
            self.sublayer1 = SublayerConnection(d_model, dropout)
            self.sublayer2 = SublayerConnection(d_model, dropout)
        
        def forward(self, x, mask=None):
            # Causal self-attention
            x = self.sublayer1(x, lambda x: self.self_attn(x, x, x, mask)[0])
            # Feed-forward
            x = self.sublayer2(x, self.feed_forward)
            return x
    
    class GPT(nn.Module):
        """Simple GPT-style language model."""
        def __init__(self, vocab_size, d_model=768, num_heads=12, 
                     num_layers=12, d_ff=3072, max_len=1024, dropout=0.1):
            super(GPT, self).__init__()
            
            self.embedding = nn.Embedding(vocab_size, d_model)
            self.pos_encoding = PositionalEncoding(d_model, max_len)
            
            self.blocks = nn.ModuleList([
                GPTBlock(d_model, num_heads, d_ff, dropout)
                for _ in range(num_layers)
            ])
            
            self.ln_f = nn.LayerNorm(d_model)  # Final layer norm
            self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
            
            # Tie weights (share embeddings and output projection)
            self.lm_head.weight = self.embedding.weight
        
        def forward(self, x):
            seq_len = x.size(1)
            
            # Create causal mask
            mask = torch.tril(torch.ones(1, 1, seq_len, seq_len, device=x.device))
            
            # Embed and add position
            x = self.embedding(x)
            x = self.pos_encoding(x)
            
            # Pass through blocks
            for block in self.blocks:
                x = block(x, mask)
            
            # Final norm and project to vocab
            x = self.ln_f(x)
            logits = self.lm_head(x)
            
            return logits
    
    # Test GPT
    vocab_size = 50257  # GPT-2 vocab size
    gpt = GPT(vocab_size, d_model=768, num_heads=12, num_layers=12)
    
    x = torch.randint(0, vocab_size, (2, 64))  # Batch of 2, sequence length 64
    logits = gpt(x)
    
    print("GPT-style Decoder:")
    print(f"Input: {x.shape}")
    print(f"Output logits: {logits.shape}")
    print(f"\nTotal parameters: {sum(p.numel() for p in gpt.parameters()):,}")
    print("\nKey differences from full transformer:")
    print("✓ No encoder (decoder-only)")
    print("✓ No cross-attention (only self-attention)")
    print("✓ Causal masking (can't see future)")
    print("✓ Tied embeddings (same weights for input/output)")

# Uncomment to see solution:
# show_solution_2()
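The causal mask in `GPT.forward` comes from `torch.tril`. It assumes the notebook's `MultiHeadAttention` blocks positions where the mask is 0 (the usual `masked_fill(mask == 0, ...)` convention). The standalone NumPy sketch below shows the mask's effect on attention weights and is independent of the notebook's classes; `causal_softmax` is a hypothetical helper. Every entry above the diagonal becomes exactly zero, and each row still sums to 1.

```python
import numpy as np


def causal_softmax(scores: np.ndarray) -> np.ndarray:
    """Row-wise softmax with a causal mask: position i only sees positions <= i."""
    n = scores.shape[-1]
    allowed = np.tril(np.ones((n, n), dtype=bool))          # same pattern as torch.tril(...)
    masked = np.where(allowed, scores, -np.inf)             # block future positions
    masked = masked - masked.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(masked)                                # exp(-inf) == 0 for blocked slots
    return weights / weights.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
w = causal_softmax(rng.normal(size=(5, 5)))
print(np.round(w, 2))
```

The first row can only attend to itself, so its single weight is 1. Later rows spread their weight over progressively more of the past.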

### Exercise 3: Visualize Attention Patterns

Extract and visualize attention patterns from different layers to see what the model learns.

In [None]:
# This would require running a trained model and extracting attention weights
# Left as an advanced exercise for you to explore with pre-trained models!

print("Attention visualization is a powerful debugging tool!")
print("\nWith a trained transformer, you can:")
print("1. See which tokens attend to each other")
print("2. Understand what each head learns")
print("3. Debug model behavior")
print("4. Gain insights into language structure")
print("\nTools: BertViz, exBERT, Transformers Interpret")
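You can rehearse the mechanics on toy data before moving to a trained model. The NumPy sketch below computes one head's attention matrix from random, untrained projection weights and prints it as a text heatmap. The token list, sizes, shading characters, and helper names (`attention_matrix`, `text_heatmap`) are hypothetical choices, and the pattern itself is meaningless because nothing was trained.

```python
import numpy as np


def attention_matrix(x: np.ndarray, w_q: np.ndarray, w_k: np.ndarray) -> np.ndarray:
    """softmax(Q K^T / sqrt(d_k)) for one head; rows are queries, columns are keys."""
    q, k = x @ w_q, x @ w_k
    scores = q @ k.T / np.sqrt(k.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


def text_heatmap(weights: np.ndarray, tokens: list[str]) -> str:
    """Render weights as shading characters (darker = more attention)."""
    shades = " .:-=+*#%@"
    width = max(len(t) for t in tokens)
    lines = [" " * (width + 1) + " ".join(t[0] for t in tokens)]  # column initials
    for tok, row in zip(tokens, weights):
        # Scale by the row maximum so each query's strongest key is darkest.
        cells = [shades[int(round(v / row.max() * (len(shades) - 1)))] for v in row]
        lines.append(f"{tok:>{width}} " + " ".join(cells))
    return "\n".join(lines)


rng = np.random.default_rng(0)
tokens = ["the", "cat", "sat", "on", "the", "mat"]
d_model, d_k = 16, 8
x = rng.normal(size=(len(tokens), d_model))            # stand-in token embeddings
w_q = rng.normal(size=(d_model, d_k))                  # untrained query projection
w_k = rng.normal(size=(d_model, d_k))                  # untrained key projection
attn = attention_matrix(x, w_q, w_k)
print(text_heatmap(attn, tokens))
```

With a pre-trained Hugging Face model, passing `output_attentions=True` returns one tensor of shape (batch, heads, seq_len, seq_len) per layer. After converting a single head's matrix to NumPy, you can render it with `text_heatmap` in the same way.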

## Key Takeaways

### Core Concepts

**1. Why transformers revolutionized AI**:
- **Parallelization**: Process all tokens simultaneously (vs sequential RNNs)
- **Long-range dependencies**: Direct connections via attention
- **Scalability**: Can train models with 100B+ parameters
- **Transfer learning**: Pre-train once, fine-tune for many tasks

**2. Complete architecture** (encoder-decoder):
- **Embeddings + Positional Encoding**: Convert tokens to positioned vectors
- **Encoder**: N layers of (self-attention + FFN)
- **Decoder**: N layers of (masked self-attention + cross-attention + FFN)
- **Residual connections + Layer norm**: Enable deep stacks

**3. Three architectural variants**:
- **Encoder-only (BERT)**: Bidirectional, best for understanding tasks
- **Decoder-only (GPT)**: Causal, best for generation (dominates 2023-2025)
- **Encoder-decoder (T5)**: Full transformer, best for seq2seq

**4. Key design principles**:
- **Attention**: Mix information across positions
- **FFN**: Process each position independently
- **Alternating**: Global mixing (attention) + per-position processing (FFN) = hierarchical features
- **Residuals**: Enable gradient flow in deep networks
- **Layer norm**: Stabilize training
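The split between attention (mixing across positions) and the FFN (processing each position independently) can be checked directly. The toy NumPy sketch below simplifies on purpose: random weights, a single head, and identity Q/K/V projections. Perturbing only the last token leaves the FFN outputs at every other position untouched, but it changes the attention outputs at those positions too.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, d_ff = 5, 8, 32
w1 = rng.normal(size=(d, d_ff))
w2 = rng.normal(size=(d_ff, d))


def ffn(x: np.ndarray) -> np.ndarray:
    """Position-wise FFN: the same two-layer ReLU MLP applied to every row (position)."""
    return np.maximum(x @ w1, 0.0) @ w2


def self_attention(x: np.ndarray) -> np.ndarray:
    """Single-head self-attention with identity Q/K/V projections (a simplification)."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x


x = rng.normal(size=(n, d))
x_edit = x.copy()
x_edit[-1] += 1.0                      # perturb only the last position

ffn_unchanged = np.allclose(ffn(x)[:-1], ffn(x_edit)[:-1])
attn_unchanged = np.allclose(self_attention(x)[:-1], self_attention(x_edit)[:-1])
print(f"FFN outputs at other positions unchanged:       {ffn_unchanged}")
print(f"Attention outputs at other positions unchanged: {attn_unchanged}")
```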

**5. Why each component matters**:
- **Multi-head attention**: Multiple perspectives on relationships
- **Feed-forward**: Non-linear transformation and capacity
- **Positional encoding**: Inject sequence order
- **Residual connections**: Deep stacks are very hard to train without them
- **Layer normalization**: Training is much less stable without it
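Positional encoding is needed because of a property of self-attention. Without position information, self-attention is permutation-equivariant: shuffling the tokens only shuffles the outputs, so the model cannot tell word orders apart. The NumPy sketch below checks this using single-head attention with identity projections and random stand-in position vectors. Both are simplifying assumptions; the notebook's `PositionalEncoding` plays the same role as the stand-in vectors.

```python
import numpy as np


def self_attention(x: np.ndarray) -> np.ndarray:
    """Single-head self-attention with identity Q/K/V projections (a simplification)."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))          # 4 token embeddings, no position information
pos = rng.normal(size=(4, 8))        # stand-in positional vectors, one per slot
perm = np.array([2, 0, 3, 1])        # reorder the tokens

# No positions: attention(shuffled) == shuffled(attention), so order is invisible.
no_pos_equivariant = np.allclose(self_attention(x[perm]), self_attention(x)[perm])
# Positions added per slot: moving a token changes what it computes.
with_pos_equivariant = np.allclose(self_attention(x[perm] + pos), self_attention(x + pos)[perm])
print(no_pos_equivariant, with_pos_equivariant)
```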

### Modern Landscape (2023-2025)

**Decoder-only models dominate**:
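- LLaMA and Mistral are openly documented decoder-only models; GPT-4 and Claude are widely understood to be decoder-only too, though their architectures are not fully published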
- Why? Simpler, scales better, emergent abilities

**Architectural innovations**:
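- **RoPE** (rotary position embeddings): Encodes relative position by rotating query and key vectors
- **Flash Attention**: Computes exact attention in tiles, cutting memory traffic and run time
- **GQA** (grouped-query attention): Groups of query heads share one key/value head, shrinking the KV cache
- **MoE** (mixture of experts): Routes each token to a few expert FFNs, adding parameters without a proportional increase in compute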
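GQA's efficiency gain is easiest to see in the key/value cache kept during generation, since only the K/V heads are cached. The sketch below uses a hypothetical configuration that does not correspond to any specific model's published numbers: 32 layers, 32 query heads of size 128, 16-bit values, and an 8,192-token context. It also assumes the cache stores one K and one V vector per layer, per KV head, per token. The helper `kv_cache_bytes` is hypothetical.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_value: int = 2, batch: int = 1) -> int:
    """Bytes needed to cache keys and values during generation.

    Factor 2 = one K and one V vector per layer, per KV head, per cached token.
    bytes_per_value=2 assumes fp16/bf16 storage.
    """
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_value * batch


# Hypothetical config: 32 layers, 32 query heads of size 128, 8,192-token context.
config = dict(n_layers=32, head_dim=128, seq_len=8192)
for label, kv_heads in [("MHA, 32 KV heads", 32), ("GQA, 8 KV heads", 8)]:
    gib = kv_cache_bytes(n_kv_heads=kv_heads, **config) / 2**30
    print(f"{label:<17} {gib:5.2f} GiB per sequence")
```

With 32 query heads and 8 KV heads, each group of 4 query heads shares one K/V pair. The cache therefore shrinks by a factor of 4 (from 4 GiB to 1 GiB in this hypothetical setup), and the freed memory can go to longer contexts or larger batches.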

**Context length explosion**:
- 2018: 512 tokens (BERT, GPT-1)
- 2023: 100K+ tokens (GPT-4 Turbo, Claude)
- 2024-2025: 1M+ tokens (e.g., Gemini 1.5)

### What's Next?

You've mastered the transformer! Next topics:
- **Efficient attention and scaling**: Flash Attention, GQA, MoE
- **Training at scale**: Distributed training, optimization
- **Building LLMs**: Complete implementations

Understanding transformers deeply is your foundation for all of modern AI!

## Further Reading

### Essential Papers
1. **Vaswani et al. (2017)**: "Attention is All You Need" (Original transformer)
2. **Devlin et al. (2018)**: "BERT" (Encoder-only)
3. **Radford et al. (2018/2019)**: "GPT/GPT-2" (Decoder-only)
4. **Brown et al. (2020)**: "GPT-3" (Scaling decoder-only)
5. **Raffel et al. (2019)**: "T5" (Encoder-decoder, text-to-text)

### Tutorials and Visualizations
6. **The Illustrated Transformer** (Jay Alammar): http://jalammar.github.io/illustrated-transformer/
7. **The Annotated Transformer** (Harvard NLP): http://nlp.seas.harvard.edu/annotated-transformer/
8. **The Transformer Family Version 2.0** (Lilian Weng): https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/

### Implementation Resources
- **PyTorch Transformer Tutorial**: Official docs
- **Hugging Face Transformers**: transformers library
- **minGPT** (Andrej Karpathy): Clean GPT implementation
- **nanoGPT** (Andrej Karpathy): Minimal GPT for learning

### Advanced Topics
- **Efficient Transformers**: Survey of optimization techniques
- **Scaling Laws**: How performance scales with size
- **Emergent Abilities**: Capabilities that arise at scale