# Lab 2.3.2: Building a Transformer Block

Now that you understand attention, let's build a complete Transformer encoder block!
This is the building block of BERT, and GPT-style decoder-only LLMs use nearly the same block with a causal attention mask.

## Learning Objectives

By the end of this lab, you will:
1. Build a complete Transformer encoder block with all components
2. Understand layer normalization and its placement (Pre-LN vs Post-LN)
3. Implement the feed-forward network (FFN) sublayer
4. Connect everything with residual connections

## Prerequisites

- Completed: Lab 2.3.1 (Attention Mechanisms)
- Understanding: Attention, residual connections, layer normalization

**Time:** 2 hours
**Difficulty:** ⭐⭐⭐ (Intermediate)

---

## ELI5: The Transformer Block

> **Imagine a very organized study group:**
>
> **Step 1: Discussion Round (Attention)**
> Everyone shares their ideas. Each person decides which ideas to pay attention to and updates their understanding.
>
> **Step 2: Personal Reflection (Feed-Forward)**
> Each person privately thinks about what they learned and processes it through their own understanding.
>
> **Step 3: Note-Taking (Residual)**
> Instead of replacing all old knowledge, you add the new insights to what you already knew.
>
> **Step 4: Summary (Layer Norm)**
> Everyone normalizes their notes to keep things balanced and prevent some topics from dominating.
>
> **Repeat for every layer** (6 in the original Transformer, 12 in BERT-base, 96 in GPT-3), refining the understanding each time!

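### The Architecture (Post-LN, as in the original Transformer)

The diagram below shows the original "Post-LN" layout, where each Layer Norm comes after its residual addition. The code in this lab defaults to the "Pre-LN" variant covered in Part 5, which moves each Layer Norm in front of its sublayer.

```
Input
  │
  ├──────────────┐
  │              │
  ▼              │
Multi-Head       │ (Residual Connection)
Attention        │
  │              │
  ├──────────────┘
  │
  ▼
Layer Norm
  │
  ├──────────────┐
  │              │
  ▼              │ (Residual Connection)
Feed-Forward     │
Network          │
  │              │
  ├──────────────┘
  │
  ▼
Layer Norm
  │
  ▼
Output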
```
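The diagram maps onto just a few lines of PyTorch. This minimal sketch wires it up using the built-in `nn.MultiheadAttention` and `nn.LayerNorm` as stand-ins for the pieces we build by hand later in this lab; the sizes are arbitrary small values chosen for illustration.

```python
import torch
import torch.nn as nn

d_model, num_heads, d_ff = 64, 4, 256

# Built-in modules stand in for the components implemented by hand later
attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
norm1, norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

def post_ln_block(x):
    attn_out, _ = attn(x, x, x)   # Multi-Head Attention (self-attention: query = key = value = x)
    x = norm1(x + attn_out)       # residual connection, then Layer Norm
    x = norm2(x + ffn(x))         # Feed-Forward Network, residual connection, then Layer Norm
    return x

x = torch.randn(2, 10, d_model)   # (batch, seq_len, d_model)
out = post_ln_block(x)
print(out.shape)                  # same shape as the input
```

Because every step preserves the `(batch, seq_len, d_model)` shape, blocks like this can be stacked any number of times.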

---

## Part 1: Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import time

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# DGX Spark detection
USE_BFLOAT16 = False
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory: {gpu_mem:.1f} GB")
    
    # Check for DGX Spark (Blackwell GB10) or other bfloat16-capable GPUs
    if gpu_mem > 100 or "GB10" in torch.cuda.get_device_name(0):
        USE_BFLOAT16 = True
        print("\n✨ DGX Spark detected! Will use bfloat16 for optimal performance.")
        print("   Note: bfloat16 will be applied explicitly to models, not globally.")
    elif torch.cuda.is_bf16_supported():
        USE_BFLOAT16 = True
        print("\n✅ bfloat16 supported on this GPU")

# Note: We do NOT set torch.set_default_dtype() globally as it can cause issues
# with pre-trained models. Instead, cast models explicitly when needed:
#   model = model.to(dtype=torch.bfloat16)

torch.manual_seed(42)

---

## Part 2: Multi-Head Attention (Recap)

Let's bring in our attention implementation from the previous notebook.

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism.
    
    This is the same implementation from the previous notebook (Lab 2.3.1),
    cleaned up for reuse here.
    """
    
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        """
        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            dropout: Dropout probability
        """
        super().__init__()
        
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
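    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch, seq_len_q, d_model)
            key, value: (batch, seq_len_k, d_model)
            mask: Optional tensor broadcastable to (batch, num_heads, seq_len_q, seq_len_k);
                  positions where mask == 0 are blocked from attention.

        Returns:
            Output tensor (batch, seq_len_q, d_model)
        """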
        batch_size = query.size(0)
        seq_len_q = query.size(1)
        seq_len_k = key.size(1)
        
        # Project and reshape for multi-head
        Q = self.W_q(query).view(batch_size, seq_len_q, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, seq_len_k, self.num_heads, self.d_k).transpose(1, 2)
        
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)
        
        # Apply attention to values
        context = torch.matmul(attention, V)
        
        # Concatenate heads
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.d_model)
        
        return self.W_o(context)

# Quick test
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)
out = mha(x, x, x)
print(f"Multi-Head Attention output shape: {out.shape}")
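As a sanity check on the math inside `forward`, the sketch below compares the manual scaled dot-product formula with `F.scaled_dot_product_attention`, PyTorch's built-in version of the same operation (available in PyTorch 2.0 and newer). Random Q, K, V tensors stand in for the projected heads, and the sizes are arbitrary.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, d_k = 2, 8, 10, 64
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)

# Same math as MultiHeadAttention.forward, without mask or dropout
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
manual = F.softmax(scores, dim=-1) @ V

# Built-in implementation (default scale is 1/sqrt(d_k), no dropout)
fused = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(manual, fused, atol=1e-5))
```

In practice the built-in function is usually the faster choice, since it can dispatch to fused kernels; the manual version is here to make each step visible.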

---

## Part 3: Feed-Forward Network (FFN)

### ELI5: The FFN

> **After the group discussion (attention), each person needs private thinking time.**
>
> The Feed-Forward Network is like a two-step thought process:
> 1. **Expand**: "Let me consider MANY possible interpretations" (expand to 4x dimensions)
> 2. **Contract**: "Now let me focus on what's most important" (back to original size)
>
> This happens independently for each word/position!
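"Independently for each position" can be checked directly: applying the FFN to the whole sequence gives the same result as applying it one position at a time, and permuting positions only permutes the outputs. This minimal sketch uses a plain `nn.Sequential` stand-in with small, arbitrary sizes rather than the `FeedForward` class defined next.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff, seq_len = 16, 64, 5
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, seq_len, d_model)  # (batch, seq_len, d_model)
full = ffn(x)                         # all positions in one call

# The same FFN applied to one position at a time
per_position = torch.stack([ffn(x[:, t]) for t in range(seq_len)], dim=1)
print(torch.allclose(full, per_position, atol=1e-5))

# Shuffling positions just shuffles the outputs: no information moves between positions
perm = torch.randperm(seq_len)
print(torch.allclose(ffn(x[:, perm]), full[:, perm], atol=1e-5))
```

This is the key contrast with attention, which is the only sublayer that mixes information across positions.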

In [None]:
class FeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network.
    
    FFN(x) = max(0, xW1 + b1)W2 + b2
    
    Also known as MLP (Multi-Layer Perceptron) in some papers.
    """
    
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        """
        Args:
            d_model: Model dimension
            d_ff: Feed-forward dimension (typically 4 * d_model)
            dropout: Dropout probability
        """
        super().__init__()
        
        # Default: 4x expansion
        if d_ff is None:
            d_ff = 4 * d_model
        
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        """
        Args:
            x: Input tensor (batch, seq_len, d_model)
            
        Returns:
            Output tensor (batch, seq_len, d_model)
        """
        # Expand → ReLU → Contract
        x = self.linear1(x)      # (batch, seq, d_ff)
        x = F.relu(x)            # Non-linearity
        x = self.dropout(x)
        x = self.linear2(x)      # (batch, seq, d_model)
        return x

# Inspect the expansion
d_model = 512
d_ff = 2048
ffn = FeedForward(d_model, d_ff)

print(f"FFN Architecture:")
print(f"  Input:  {d_model} dimensions")
print(f"  Hidden: {d_ff} dimensions (4x expansion)")
print(f"  Output: {d_model} dimensions")
print(f"  Parameters: {sum(p.numel() for p in ffn.parameters()):,}")

### Modern FFN Variants

The original Transformer uses ReLU, but modern models often use:
- **GELU** (used in BERT, GPT-2): Smoother than ReLU
- **SwiGLU** (used in LLaMA, PaLM): Gated activation
- **GeGLU** (used in T5 v1.1, among others): GELU-gated
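The 2/3 factor used by `FeedForwardSwiGLU` below comes from parameter matching: a standard FFN has two `d_model x d_ff` weight matrices while SwiGLU has three, so shrinking the hidden size to about 2/3 of `4 * d_model` keeps the totals roughly equal. This sketch counts weights only (`bias=False`, an assumption for clean arithmetic; LLaMA-style FFNs also omit biases), and `count` is a hypothetical helper.

```python
import torch.nn as nn

def count(module):
    return sum(p.numel() for p in module.parameters())

d_model = 512

# Standard FFN: two projections with hidden size 4 * d_model
d_ff_std = 4 * d_model
standard = nn.ModuleList([
    nn.Linear(d_model, d_ff_std, bias=False),
    nn.Linear(d_ff_std, d_model, bias=False),
])

# SwiGLU: three projections, so the hidden size shrinks to ~2/3
d_ff_glu = int(4 * d_model * 2 / 3)  # 1365
swiglu = nn.ModuleList([
    nn.Linear(d_model, d_ff_glu, bias=False),  # W1 (passed through Swish)
    nn.Linear(d_model, d_ff_glu, bias=False),  # W3 (gate)
    nn.Linear(d_ff_glu, d_model, bias=False),  # W2 (output)
])

print(f"Standard FFN: {count(standard):,} params")  # 2,097,152
print(f"SwiGLU FFN:   {count(swiglu):,} params")    # 2,096,640
```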

In [None]:
class FeedForwardGELU(nn.Module):
    """FFN with GELU activation (modern standard)."""
    
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        
        if d_ff is None:
            d_ff = 4 * d_model
        
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = self.linear1(x)
        x = F.gelu(x)  # GELU instead of ReLU
        x = self.dropout(x)
        x = self.linear2(x)
        return x


class FeedForwardSwiGLU(nn.Module):
    """
    FFN with SwiGLU activation (LLaMA, PaLM style).
    
    SwiGLU(x) = (Swish(xW1) * xW3) W2, where * is element-wise multiplication
    """
    
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        
        if d_ff is None:
            # For SwiGLU, we typically use 2/3 of the original d_ff
            # because we have an extra gate projection
            d_ff = int(4 * d_model * 2 / 3)
        
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.w3 = nn.Linear(d_model, d_ff)  # Gate projection
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # SwiGLU: element-wise multiply with gated activation
        gate = F.silu(self.w1(x))  # Swish activation
        x = gate * self.w3(x)      # Gating
        x = self.dropout(x)
        x = self.w2(x)
        return x

# Compare activations
x = torch.linspace(-3, 3, 100)

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(x.numpy(), F.relu(x).numpy(), label='ReLU', linewidth=2)
plt.plot(x.numpy(), F.gelu(x).numpy(), label='GELU', linewidth=2)
plt.plot(x.numpy(), F.silu(x).numpy(), label='SiLU/Swish', linewidth=2)
plt.legend()
plt.title('Activation Functions')
plt.xlabel('Input')
plt.ylabel('Output')
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
# Gradient comparison: d(activation)/d(input), computed with autograd
for name, fn in [('ReLU', F.relu), ('GELU', F.gelu), ('SiLU/Swish', F.silu)]:
    x_grad = x.clone().requires_grad_(True)
    fn(x_grad).sum().backward()
    plt.plot(x.numpy(), x_grad.grad.numpy(), label=f'{name} gradient', linewidth=2)
plt.legend()
plt.title('Gradients')
plt.xlabel('Input')
plt.ylabel('Gradient')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Note: GELU and SiLU have smooth gradients, which can help training stability!")

---

## Part 4: Residual Connections

### ELI5: Why Residual Connections?

> **Imagine you're learning to cook a new dish.**
>
> **Without residual connections:**
> Each lesson completely replaces what you knew before. If one lesson is bad, you forget everything!
>
> **With residual connections:**
> Each lesson ADDS to your existing knowledge. Bad lessons don't erase good ones.
> `new_knowledge = old_knowledge + lesson_learned`
>
> **Why this matters:**
> - Gradients flow easily through the skip path, which greatly reduces vanishing gradients
> - Learning the identity is easy: the sublayer just has to output zeros
> - Training is much more stable

### Mathematical View

```
Without residual: output = F(x)
With residual:    output = x + F(x)
```

The model learns F(x) = what to ADD, not what to REPLACE.
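A quick way to see why learning the identity is easy: if the sublayer outputs zeros, the residual block returns its input unchanged, and the gradient passes straight through the skip path. The sketch zero-initializes an `nn.Linear` as a stand-in sublayer purely for illustration.

```python
import torch
import torch.nn as nn

d_model = 8
sublayer = nn.Linear(d_model, d_model)
nn.init.zeros_(sublayer.weight)  # F(x) = 0 for every input
nn.init.zeros_(sublayer.bias)

x = torch.randn(3, d_model, requires_grad=True)
out = x + sublayer(x)            # residual: x + F(x)

print(torch.equal(out, x))       # True: the block is exactly the identity

out.sum().backward()
print(x.grad)                    # all ones: d(out)/dx = I + dF/dx = I here
```

With a non-zero sublayer the Jacobian is `I + dF/dx`, so the identity term keeps a direct gradient path open even when `dF/dx` is small.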

In [None]:
# Demonstrate the gradient flow difference

def gradient_flow_demo():
    """Show how residual connections help gradient flow."""
    
    num_layers = 50
    d_model = 64
    
    # Without residual connections
    class DeepNoResidual(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([
                nn.Linear(d_model, d_model) for _ in range(num_layers)
            ])
            
        def forward(self, x):
            for layer in self.layers:
                x = torch.tanh(layer(x))  # Tanh saturates easily
            return x
    
    # With residual connections
    class DeepWithResidual(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([
                nn.Linear(d_model, d_model) for _ in range(num_layers)
            ])
            
        def forward(self, x):
            for layer in self.layers:
                x = x + torch.tanh(layer(x))  # ADD instead of replace
            return x
    
    # Test gradient magnitudes
    no_res = DeepNoResidual()
    with_res = DeepWithResidual()
    
    x = torch.randn(1, d_model, requires_grad=True)
    
    # Forward + backward without residual
    out_no_res = no_res(x)
    out_no_res.sum().backward()
    grad_no_res = x.grad.norm().item()
    
    x = torch.randn(1, d_model, requires_grad=True)
    out_with_res = with_res(x)
    out_with_res.sum().backward()
    grad_with_res = x.grad.norm().item()
    
    print(f"Gradient magnitude through {num_layers} layers:")
    print(f"  Without residual: {grad_no_res:.2e}")
    print(f"  With residual:    {grad_with_res:.2e}")
    print(f"\n  Ratio: {grad_with_res / (grad_no_res + 1e-10):.1f}x stronger gradients with residual!")

gradient_flow_demo()

---

## Part 5: Layer Normalization

### ELI5: Layer Normalization

> **Imagine test scores in a class.**
> - One test is out of 100, another out of 10, another out of 1000
> - Hard to compare! "I got 80" means different things.
>
> **Layer normalization standardizes:**
> - For each token, shift and rescale its feature vector to mean=0, variance=1 (then apply a learned scale and shift)
> - Now everything is on the same scale
> - Training becomes more stable
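Concretely, LayerNorm computes its statistics per token over the feature (last) dimension, not across the batch. This sketch reproduces `nn.LayerNorm` by hand at initialization, when its learnable scale is 1 and its shift is 0; the input sizes and the deliberate offset/scale are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 6
x = torch.randn(2, 4, d_model) * 10 + 3  # (batch, seq_len, d_model), deliberately off-scale

# Statistics per token, over the feature dimension only
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, correction=0)  # LayerNorm uses the biased variance
eps = 1e-5                                       # nn.LayerNorm's default epsilon
manual = (x - mean) / torch.sqrt(var + eps)

ln = nn.LayerNorm(d_model)  # weight=1, bias=0 at initialization
print(torch.allclose(ln(x), manual, atol=1e-5))

# Every token now has mean ~0 and standard deviation ~1
print(manual.mean(dim=-1).abs().max().item())
print(manual.std(dim=-1, correction=0).mean().item())
```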

### Pre-LN vs Post-LN

The original Transformer used "Post-LN" (normalize after), but modern models prefer "Pre-LN" (normalize before).

In [None]:
# Post-LN (Original Transformer, BERT)
class PostLNBlock(nn.Module):
    """
    Post-Layer Normalization: normalize AFTER the residual addition.
    
    output = LayerNorm(x + Sublayer(x))
    """
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)
        
    def forward(self, x):
        return self.norm(x + self.sublayer(x))


# Pre-LN (GPT-2, GPT-3, LLaMA with RMSNorm, most modern models)
class PreLNBlock(nn.Module):
    """
    Pre-Layer Normalization: normalize BEFORE the sublayer.
    
    output = x + Sublayer(LayerNorm(x))
    """
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)
        
    def forward(self, x):
        return x + self.sublayer(self.norm(x))


print("Post-LN (Original):  output = LayerNorm(x + Sublayer(x))")
print("Pre-LN (Modern):     output = x + Sublayer(LayerNorm(x))")
print("\nPre-LN advantages:")
print("  - Better-behaved gradients at initialization")
print("  - Can often be trained with little or no learning-rate warmup (Xiong et al., 2020)")
print("  - Easier to train very deep models")
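To see the practical difference, this sketch stacks 12 toy residual layers each way and measures the per-token standard deviation of the output. The sublayers are hypothetical small MLPs standing in for attention and FFN. Post-LN re-normalizes after every addition, so its output stays at unit scale; Pre-LN never normalizes the residual stream itself, which is why Pre-LN stacks end with a final LayerNorm.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_layers = 64, 12

def make_sublayer():
    # Hypothetical stand-in for an attention or FFN sublayer
    return nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, d_model))

post_subs = nn.ModuleList([make_sublayer() for _ in range(num_layers)])
post_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(num_layers)])
pre_subs = nn.ModuleList([make_sublayer() for _ in range(num_layers)])
pre_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(num_layers)])

x = torch.randn(4, 16, d_model)
post, pre = x, x
with torch.no_grad():
    for sub, norm in zip(post_subs, post_norms):
        post = norm(post + sub(post))  # Post-LN: LayerNorm(x + Sublayer(x))
    for sub, norm in zip(pre_subs, pre_norms):
        pre = pre + sub(norm(pre))     # Pre-LN: x + Sublayer(LayerNorm(x))

# Per-token standard deviation over the feature dimension
post_std = post.std(dim=-1, correction=0).mean().item()
pre_std = pre.std(dim=-1, correction=0).mean().item()
print(f"Post-LN output std per token: {post_std:.3f}")  # ~1.0: renormalized by every layer
print(f"Pre-LN  output std per token: {pre_std:.3f}")   # not normalized; drifts with depth
```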

---

## Part 6: Complete Transformer Encoder Block

Now let's put it all together!

In [None]:
class TransformerEncoderBlock(nn.Module):
    """
    A single Transformer Encoder block.
    
    Architecture (Pre-LN):
        x -> LayerNorm -> MultiHeadAttention -> + -> LayerNorm -> FFN -> + -> output
        |                                      |    |                    |
        +--------------------------------------+    +--------------------+
                  (residual)                            (residual)
    """
    
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int = None,
        dropout: float = 0.1,
        activation: str = "gelu",
        pre_norm: bool = True
    ):
        """
        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            d_ff: Feed-forward dimension (default: each FFN's own default,
                  i.e. 4 * d_model, or about 2/3 of that for SwiGLU)
            dropout: Dropout probability
            activation: Activation function ("relu", "gelu", "swiglu")
            pre_norm: Use Pre-LN (True) or Post-LN (False)
        """
        super().__init__()
        
        self.pre_norm = pre_norm
        
        # Multi-head attention
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Feed-forward network. d_ff is passed through unchanged (possibly None)
        # so that FeedForwardSwiGLU can apply its own 2/3-sized default.
        if activation == "swiglu":
            self.ffn = FeedForwardSwiGLU(d_model, d_ff, dropout)
        elif activation == "gelu":
            self.ffn = FeedForwardGELU(d_model, d_ff, dropout)
        elif activation == "relu":
            self.ffn = FeedForward(d_model, d_ff, dropout)
        else:
            raise ValueError(f"Unknown activation: {activation!r}")
        
        # Layer norms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # Dropout for residual
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor (batch, seq_len, d_model)
            mask: Optional attention mask
            
        Returns:
            Output tensor (batch, seq_len, d_model)
        """
        if self.pre_norm:
            # Pre-LN: Normalize before sublayer
            # Attention block
            normed = self.norm1(x)
            attn_out = self.attention(normed, normed, normed, mask)
            x = x + self.dropout(attn_out)
            
            # FFN block
            normed = self.norm2(x)
            ffn_out = self.ffn(normed)
            x = x + self.dropout(ffn_out)
        else:
            # Post-LN: Normalize after residual
            # Attention block
            attn_out = self.attention(x, x, x, mask)
            x = self.norm1(x + self.dropout(attn_out))
            
            # FFN block
            ffn_out = self.ffn(x)
            x = self.norm2(x + self.dropout(ffn_out))
        
        return x

# Create and test a single block
block = TransformerEncoderBlock(
    d_model=512,
    num_heads=8,
    d_ff=2048,
    dropout=0.1,
    activation="gelu",
    pre_norm=True
)

x = torch.randn(2, 10, 512)
out = block(x)

print(f"Transformer Encoder Block:")
print(f"  Input shape:  {x.shape}")
print(f"  Output shape: {out.shape}")
print(f"  Parameters:   {sum(p.numel() for p in block.parameters()):,}")
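The `mask` argument follows the `masked_fill(mask == 0, ...)` convention in `MultiHeadAttention`, so it must broadcast to `(batch, heads, seq_q, seq_k)`. A minimal sketch of a padding mask built from hypothetical sequence lengths:

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len = 2, 4, 5
lengths = torch.tensor([5, 3])  # hypothetical: sequence 1 ends with 2 padding tokens

# True for real tokens, False for padding; shape (batch, seq_len)
key_is_real = torch.arange(seq_len)[None, :] < lengths[:, None]

# Reshape to (batch, 1, 1, seq_len) so it broadcasts over heads and query positions
mask = key_is_real[:, None, None, :].int()

scores = torch.randn(batch, heads, seq_len, seq_len)   # (batch, heads, seq_q, seq_k)
scores = scores.masked_fill(mask == 0, float("-inf"))  # same convention as MultiHeadAttention
weights = F.softmax(scores, dim=-1)

print(weights[1, 0, 0])  # the last two entries are exactly 0.0
```

The same `mask` could then be passed to the block as `block(x, mask=mask)` for an input of shape `(2, 5, 512)`.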

---

## Part 7: Stacking Blocks into a Full Encoder

Real Transformers stack multiple blocks:
- **BERT-base**: 12 blocks
- **BERT-large**: 24 blocks
- **GPT-3**: 96 blocks
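Since every block in a stack has the same shape, the parameter count is easy to work out by hand. The sketch below uses a hypothetical helper, `encoder_block_params`, that counts weights and biases for an encoder block like ours, and cross-checks it against `nn.TransformerEncoderLayer`, which holds the same set of weights.

```python
import torch.nn as nn

def encoder_block_params(d_model: int, d_ff: int) -> int:
    """Hypothetical helper: parameters in one encoder block (attention + FFN + 2 LayerNorms)."""
    attention = 4 * (d_model * d_model + d_model)               # W_q, W_k, W_v, W_o and biases
    ffn = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)  # linear1 and linear2 with biases
    norms = 2 * (2 * d_model)                                   # gamma and beta for 2 LayerNorms
    return attention + ffn + norms

per_block = encoder_block_params(768, 3072)
print(f"Per block: {per_block:,}")       # 7,087,872
print(f"12 blocks: {12 * per_block:,}")  # 85,054,464

# Cross-check against PyTorch's built-in layer
ref = nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072, batch_first=True)
print(sum(p.numel() for p in ref.parameters()) == per_block)  # True
```

That is about 85M parameters for 12 BERT-base-sized blocks; the commonly quoted ~110M for BERT-base also includes the embedding tables and the pooler, which this lab does not build.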

In [None]:
class TransformerEncoder(nn.Module):
    """
    Full Transformer Encoder: stack of encoder blocks.
    
    This is the kind of encoder stack BERT uses (note: the original BERT uses
    Post-LN, i.e. pre_norm=False; this class defaults to Pre-LN).
    """
    
    def __init__(
        self,
        num_layers: int,
        d_model: int,
        num_heads: int,
        d_ff: int = None,
        dropout: float = 0.1,
        activation: str = "gelu",
        pre_norm: bool = True
    ):
        """
        Args:
            num_layers: Number of encoder blocks to stack
            d_model: Model dimension
            num_heads: Number of attention heads
            d_ff: Feed-forward dimension
            dropout: Dropout probability
            activation: Activation function
            pre_norm: Use Pre-LN normalization
        """
        super().__init__()
        
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
                dropout=dropout,
                activation=activation,
                pre_norm=pre_norm
            )
            for _ in range(num_layers)
        ])
        
        # Final layer norm (for Pre-LN architecture)
        self.final_norm = nn.LayerNorm(d_model) if pre_norm else nn.Identity()
        
    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor (batch, seq_len, d_model)
            mask: Optional attention mask
            
        Returns:
            Output tensor (batch, seq_len, d_model)
        """
        for layer in self.layers:
            x = layer(x, mask)
        
        return self.final_norm(x)


# Create BERT-base style encoder
encoder = TransformerEncoder(
    num_layers=12,
    d_model=768,
    num_heads=12,
    d_ff=3072,
    dropout=0.1
)

print(f"BERT-base style Transformer Encoder:")
print(f"  Layers: 12")
print(f"  d_model: 768")
print(f"  Heads: 12")
print(f"  d_ff: 3072")
print(f"  Total parameters: {sum(p.numel() for p in encoder.parameters()):,}")

In [None]:
# Test with a realistic batch
batch_size = 8
seq_len = 128
d_model = 768

x = torch.randn(batch_size, seq_len, d_model)

# Move to GPU if available; eval() disables dropout for inference timing
encoder = encoder.to(device).eval()
x = x.to(device)

with torch.no_grad():
    # Warmup: the first calls include one-time setup costs (CUDA context, kernel selection)
    for _ in range(3):
        _ = encoder(x)

    # Time one forward pass (synchronize so we measure finished GPU work)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    out = encoder(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

print(f"\nForward pass:")
print(f"  Input:  {x.shape}")
print(f"  Output: {out.shape}")
print(f"  Time:   {elapsed*1000:.2f} ms")
print(f"  Throughput: {batch_size * seq_len / elapsed:.0f} tokens/sec")

---

## Part 8: Comparison with PyTorch Built-in

Let's compare our implementation with PyTorch's built-in layer: parameter counts first, then speed.

In [None]:
# PyTorch's built-in TransformerEncoderLayer
pytorch_layer = nn.TransformerEncoderLayer(
    d_model=512,
    nhead=8,
    dim_feedforward=2048,
    dropout=0.0,  # No dropout for comparison
    activation="gelu",
    batch_first=True,
    norm_first=True  # Pre-LN
)

# Our implementation
our_layer = TransformerEncoderBlock(
    d_model=512,
    num_heads=8,
    d_ff=2048,
    dropout=0.0,
    activation="gelu",
    pre_norm=True
)

print("Parameter count comparison:")
print(f"  PyTorch: {sum(p.numel() for p in pytorch_layer.parameters()):,}")
print(f"  Ours:    {sum(p.numel() for p in our_layer.parameters()):,}")
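Matching parameter counts don't guarantee matching outputs. A stronger check is to recompute a Pre-LN layer by hand from the built-in layer's own weights and compare the results. The sketch below does that with a hypothetical helper, `manual_pre_ln_forward`, on small arbitrary sizes. PyTorch packs the Q, K, V projections into a single `in_proj_weight`, so it is split into three; the same idea works for `our_layer` by copying those weights into its attention and FFN linear layers and its two LayerNorms.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def manual_pre_ln_forward(layer: nn.TransformerEncoderLayer, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: recompute a Pre-LN GELU encoder layer by hand (no dropout, no mask)."""
    d_model = layer.self_attn.embed_dim
    num_heads = layer.self_attn.num_heads
    d_k = d_model // num_heads
    B, T, _ = x.shape

    # Attention sublayer: x + Attention(LayerNorm(x))
    h = layer.norm1(x)
    w_q, w_k, w_v = layer.self_attn.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = layer.self_attn.in_proj_bias.chunk(3, dim=0)

    def split_heads(t):
        return t.view(B, T, num_heads, d_k).transpose(1, 2)

    q = split_heads(F.linear(h, w_q, b_q))
    k = split_heads(F.linear(h, w_k, b_k))
    v = split_heads(F.linear(h, w_v, b_v))
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    context = (scores.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, T, d_model)
    x = x + layer.self_attn.out_proj(context)

    # FFN sublayer: x + FFN(LayerNorm(x))
    h = layer.norm2(x)
    return x + layer.linear2(F.gelu(layer.linear1(h)))

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(
    d_model=64, nhead=4, dim_feedforward=256, dropout=0.0,
    activation="gelu", batch_first=True, norm_first=True,
).eval()
x = torch.randn(2, 10, 64)

with torch.no_grad():
    ref = layer(x)
    ours = manual_pre_ln_forward(layer, x)

max_diff = (ref - ours).abs().max().item()
print(f"Max abs difference: {max_diff:.2e}")  # tiny: float32 rounding only
```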

In [None]:
# Speed comparison in inference mode (eval() + no_grad).
# Note: in this mode PyTorch's built-in layer may dispatch to a fused "fast path" kernel.
x = torch.randn(32, 128, 512)

pytorch_layer = pytorch_layer.to(device).eval()
our_layer = our_layer.to(device).eval()
x = x.to(device)

with torch.no_grad():
    # Warmup
    for _ in range(10):
        _ = pytorch_layer(x)
        _ = our_layer(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # Time PyTorch
    start = time.perf_counter()
    for _ in range(100):
        _ = pytorch_layer(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    pytorch_time = time.perf_counter() - start

    # Time ours
    start = time.perf_counter()
    for _ in range(100):
        _ = our_layer(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    our_time = time.perf_counter() - start

print("\nSpeed comparison (100 iterations):")
print(f"  PyTorch: {pytorch_time*1000:.2f} ms")
print(f"  Ours:    {our_time*1000:.2f} ms")
print(f"  Ratio:   {our_time/pytorch_time:.2f}x")
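Hand-rolled timing loops are easy to get subtly wrong. `torch.utils.benchmark.Timer` handles warmup and CUDA synchronization for you; per its documentation it also defaults to a single PyTorch thread (`num_threads=1`), so CPU numbers may look slower than a bare loop. A minimal sketch on a small, arbitrarily sized built-in layer:

```python
import torch
import torch.nn as nn
from torch.utils import benchmark

layer = nn.TransformerEncoderLayer(d_model=128, nhead=4, dim_feedforward=512, batch_first=True).eval()
x = torch.randn(8, 32, 128)

def run():
    with torch.no_grad():
        return layer(x)

timer = benchmark.Timer(stmt="run()", globals={"run": run})
measurement = timer.timeit(20)  # returns a Measurement object
print(f"Mean time per call: {measurement.mean * 1e3:.3f} ms")
```

To compare two implementations, create one `Timer` per callable with the same inputs and compare their `.mean` (or `.median`) values.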

---

## Part 9: Memory Analysis for DGX Spark

Let's see what we can fit in our 128GB unified memory!

In [None]:
def estimate_transformer_memory(
    batch_size: int,
    seq_len: int,
    d_model: int,
    num_layers: int,
    num_heads: int,
    d_ff: int,
    dtype_bytes: int = 2  # bfloat16
) -> dict:
    """
    Rough estimate of training memory for a Transformer encoder.
    
    Assumes every tensor (weights, gradients, Adam states, activations) is stored
    in the same dtype, and ignores biases, embeddings, and framework overhead.
    Returns dictionary with memory breakdown.
    """
    # Parameters per layer
    # Attention: 4 * d_model^2 (W_q, W_k, W_v, W_o)
    attention_params = 4 * d_model * d_model
    # FFN: d_model * d_ff + d_ff * d_model
    ffn_params = 2 * d_model * d_ff
    # Layer norms: 2 * 2 * d_model (gamma, beta for 2 norms)
    norm_params = 4 * d_model
    
    total_params = num_layers * (attention_params + ffn_params + norm_params)
    param_memory = total_params * dtype_bytes
    
    # Gradients: one value per parameter
    gradient_memory = param_memory
    
    # Activations saved for the backward pass (a lower bound: real layers
    # also store Q/K/V, softmax outputs, dropout masks, etc.)
    # Input to each layer
    activation_per_layer = batch_size * seq_len * d_model * dtype_bytes
    # Attention scores (batch, heads, seq, seq)
    attention_scores = batch_size * num_heads * seq_len * seq_len * dtype_bytes
    # FFN hidden
    ffn_hidden = batch_size * seq_len * d_ff * dtype_bytes
    
    activations = num_layers * (activation_per_layer + attention_scores + ffn_hidden)
    
    # Optimizer states (Adam: first and second moment estimates, 2x params)
    optimizer_memory = 2 * param_memory
    
    total = param_memory + gradient_memory + activations + optimizer_memory
    
    return {
        "parameters": total_params,
        "param_memory_gb": param_memory / 1e9,
        "gradient_memory_gb": gradient_memory / 1e9,
        "activation_memory_gb": activations / 1e9,
        "optimizer_memory_gb": optimizer_memory / 1e9,
        "total_memory_gb": total / 1e9
    }

# Test different configurations
configs = [
    {"name": "BERT-base", "layers": 12, "d_model": 768, "heads": 12, "d_ff": 3072},
    {"name": "BERT-large", "layers": 24, "d_model": 1024, "heads": 16, "d_ff": 4096},
    {"name": "GPT-2 Medium", "layers": 24, "d_model": 1024, "heads": 16, "d_ff": 4096},
    {"name": "GPT-2 Large", "layers": 36, "d_model": 1280, "heads": 20, "d_ff": 5120},
]

print("Rough memory estimates for batch_size=32, seq_len=512 (bfloat16):")
print("=" * 70)

for config in configs:
    mem = estimate_transformer_memory(
        batch_size=32,
        seq_len=512,
        d_model=config["d_model"],
        num_layers=config["layers"],
        num_heads=config["heads"],
        d_ff=config["d_ff"]
    )
    
    fits = "✅" if mem["total_memory_gb"] < 128 else "❌"
    
    print(f"\n{config['name']}:")
    print(f"  Parameters:  {mem['parameters']/1e6:.1f}M")
    print(f"  Params mem:  {mem['param_memory_gb']:.2f} GB")
    print(f"  Gradients:   {mem['gradient_memory_gb']:.2f} GB")
    print(f"  Activations: {mem['activation_memory_gb']:.2f} GB")
    print(f"  Optimizer:   {mem['optimizer_memory_gb']:.2f} GB")
    print(f"  Total:       {mem['total_memory_gb']:.2f} GB {fits}")

### DGX Spark Advantage

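Based on these rough estimates, 128GB of unified memory (shared between CPU and GPU, so not all of it is available to the model) should let you:
- Train BERT-large with large batch sizes
- Fine-tune GPT-2 Large without memory-saving tricks
- Use longer sequences than typical discrete GPUs allow

On a typical 24GB GPU, the same configurations would usually need gradient checkpointing or much smaller batches.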
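Real mixed-precision training usually needs more than the bf16-only assumption in `estimate_transformer_memory`. A common rule of thumb, from the ZeRO paper's accounting for mixed-precision Adam, is about 16 bytes per parameter: 2 bytes each for half-precision weights and gradients, plus 12 bytes for fp32 master weights and the two Adam moments. The sketch applies this to the GPT-2 Large entry of the `configs` list, using the same per-layer formula (no biases, no embeddings); `training_state_gb` is a hypothetical helper.

```python
def training_state_gb(num_params: int) -> dict:
    """Hypothetical helper: model-state memory (GB) for mixed-precision Adam, activations excluded."""
    bytes_per_param = {
        "bf16 weights": 2,
        "bf16 gradients": 2,
        "fp32 master weights": 4,
        "fp32 Adam first moment": 4,
        "fp32 Adam second moment": 4,
    }
    return {name: num_params * nbytes / 1e9 for name, nbytes in bytes_per_param.items()}


# GPT-2 Large-sized stack, same per-layer formula as estimate_transformer_memory
d_model, d_ff, num_layers = 1280, 5120, 36
num_params = num_layers * (4 * d_model**2 + 2 * d_model * d_ff + 4 * d_model)

breakdown = training_state_gb(num_params)
for name, gb in breakdown.items():
    print(f"  {name:<24} {gb:6.2f} GB")
total = sum(breakdown.values())
print(f"{num_params / 1e6:.1f}M params -> {total:.2f} GB of model state (16 bytes/param)")
```

Activations come on top of this, and they grow with batch size and with the square of sequence length through the attention scores.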

---

## Try It Yourself: Exercises

### Exercise 1: Add Decoder Block

Create a `TransformerDecoderBlock` that includes:
1. Masked self-attention (causal)
2. Cross-attention to encoder outputs
3. Feed-forward network

<details>
<summary>Hint</summary>
The decoder has two attention layers: one for self-attention with causal mask, and one for cross-attention where queries come from decoder, keys/values from encoder.
</details>
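For the masked self-attention part, a causal mask has to follow the same `mask == 0` convention as `MultiHeadAttention`. A minimal sketch of building one with `torch.tril` and checking that no position attends to later positions (sizes arbitrary):

```python
import torch
import torch.nn.functional as F

seq_len = 5
# Lower-triangular: 1 where query position i may attend to key position j <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

scores = torch.randn(2, 8, seq_len, seq_len)  # (batch, heads, seq_q, seq_k)
scores = scores.masked_fill(causal_mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)

print(weights[0, 0])  # upper triangle is all zeros: no position attends to the future
```

The `(1, 1, seq_len, seq_len)` shape broadcasts over batch and heads, so the same mask can be reused for every sequence of that length.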

In [None]:
class TransformerDecoderBlock(nn.Module):
    """
    Transformer Decoder block with:
    1. Masked self-attention
    2. Cross-attention to encoder
    3. Feed-forward network
    
    TODO: Implement this!
    """
    
    def __init__(self, d_model, num_heads, d_ff=None, dropout=0.1):
        super().__init__()
        # YOUR CODE HERE
        pass
    
    def forward(self, x, encoder_output, self_mask=None, cross_mask=None):
        # YOUR CODE HERE
        pass

# Test your implementation:
# decoder_block = TransformerDecoderBlock(512, 8)
# decoder_input = torch.randn(2, 10, 512)
# encoder_output = torch.randn(2, 20, 512)
# out = decoder_block(decoder_input, encoder_output)

### Exercise 2: Implement Gradient Checkpointing

For very deep models, we can trade compute for memory by recomputing activations during the backward pass instead of storing them.
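Before tackling the exercise, here is how `torch.utils.checkpoint.checkpoint` behaves on a tiny stand-in module (a small MLP, not a Transformer block): the gradients are identical to a normal backward pass, but the intermediate activations inside the module are recomputed rather than stored. Passing `use_reentrant=False` selects the non-reentrant implementation that recent PyTorch versions recommend.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(32, 128), nn.GELU(), nn.Linear(128, 32))
x = torch.randn(4, 32, requires_grad=True)

# Regular forward/backward: intermediate activations are kept for backward
block(x).sum().backward()
grad_plain = x.grad.clone()
x.grad = None

# Checkpointed: the block's internals are re-run during backward
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()
grad_ckpt = x.grad.clone()

print(torch.allclose(grad_plain, grad_ckpt))  # True: same gradients
```

The memory savings only become visible at scale, when each checkpointed segment holds large activations such as attention scores and FFN hidden states.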

In [None]:
from torch.utils.checkpoint import checkpoint

class TransformerEncoderCheckpointed(nn.Module):
    """
    Transformer Encoder with gradient checkpointing.
    
    This saves memory by not storing all activations,
    at the cost of recomputing them during backward pass.
    
    TODO: Implement using torch.utils.checkpoint.checkpoint
    """
    
    def __init__(self, num_layers, d_model, num_heads, d_ff=None, dropout=0.1):
        super().__init__()
        # YOUR CODE HERE
        pass
    
    def forward(self, x, mask=None, use_checkpointing=True):
        # YOUR CODE HERE
        # Hint: use checkpoint(layer, x, mask, use_reentrant=False) instead of layer(x, mask)
        pass

---

## Common Mistakes

### Mistake 1: Wrong Residual Connection Order

In [None]:
# Mixed up: normalizing only the sublayer output before adding it
def wrong_residual(x, sublayer, norm):
    return x + norm(sublayer(x))  # Neither Pre-LN nor Post-LN: the sublayer input and the residual stream stay unnormalized

# Right (Pre-LN): Normalize before sublayer, add to original
def right_residual_preln(x, sublayer, norm):
    return x + sublayer(norm(x))  # The sublayer always sees normalized input

# Right (Post-LN): Normalize after residual
def right_residual_postln(x, sublayer, norm):
    return norm(x + sublayer(x))  # The block output is always normalized

print("Pre-LN:  output = x + sublayer(norm(x))")
print("Post-LN: output = norm(x + sublayer(x))")

### Mistake 2: Forgetting Final Layer Norm in Pre-LN

In [None]:
# Wrong: No final norm in Pre-LN architecture
class WrongPreLNEncoder(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = layers
        # Missing: self.final_norm = nn.LayerNorm(d_model)
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x  # Output is not normalized!

# Right: Include final layer norm
class RightPreLNEncoder(nn.Module):
    def __init__(self, layers, d_model):
        super().__init__()
        self.layers = layers
        self.final_norm = nn.LayerNorm(d_model)
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.final_norm(x)  # Properly normalized output

print("Pre-LN needs a final LayerNorm after all layers: the residual stream itself is never normalized inside the blocks.")

### Mistake 3: Wrong Dropout Placement

In [None]:
# Wrong: Dropout after residual addition
def wrong_dropout(x, sublayer_out, dropout):
    return dropout(x + sublayer_out)  # Drops both original AND new info!

# Right: Dropout on sublayer output before addition
def right_dropout(x, sublayer_out, dropout):
    return x + dropout(sublayer_out)  # Only drops new info, preserves original

print("Dropout should be applied to sublayer output BEFORE adding residual.")
print("This preserves the information in the skip connection.")

---

## Checkpoint

You've learned:
- ✅ How to build a complete Transformer encoder block
- ✅ The role of feed-forward networks and different activations
- ✅ Why residual connections enable deep networks
- ✅ Pre-LN vs Post-LN normalization strategies
- ✅ How to stack blocks into a full encoder
- ✅ Memory estimation for DGX Spark optimization

---

## Challenge (Optional)

Implement **RMSNorm** (Root Mean Square Normalization) as an alternative to LayerNorm:

```
RMSNorm(x) = x * scale / sqrt(mean(x^2) + epsilon)
```

This is used in LLaMA and is slightly cheaper than LayerNorm: it skips the mean subtraction (no re-centering) and has no bias term, only a learned scale.
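A worked example, assuming `scale = 1` and ignoring epsilon, shows the difference without giving away the implementation: RMSNorm only rescales, so an all-positive input stays all-positive, while LayerNorm also re-centers to zero mean.

$$
x = (1, 2, 3, 4), \qquad \operatorname{RMS}(x) = \sqrt{\tfrac{1^2 + 2^2 + 3^2 + 4^2}{4}} = \sqrt{7.5} \approx 2.739
$$

$$
\operatorname{RMSNorm}(x) = \frac{x}{\operatorname{RMS}(x)} \approx (0.365,\ 0.730,\ 1.095,\ 1.461)
$$

$$
\operatorname{LayerNorm}(x) = \frac{x - 2.5}{\sqrt{1.25}} \approx (-1.342,\ -0.447,\ 0.447,\ 1.342)
$$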

In [None]:
class RMSNorm(nn.Module):
    """
    Root Mean Square Normalization (used in LLaMA).
    
    TODO: Implement this!
    """
    
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        # YOUR CODE HERE
        pass
    
    def forward(self, x):
        # YOUR CODE HERE
        pass

# Test:
# rms = RMSNorm(512)
# x = torch.randn(2, 10, 512)
# out = rms(x)
# print(f"RMSNorm output shape: {out.shape}")

---

## Further Reading

- [On Layer Normalization in the Transformer Architecture](https://arxiv.org/abs/2002.04745) - Pre-LN analysis
- [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202) - SwiGLU and friends
- [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467) - RMSNorm paper
- [Deep Residual Learning](https://arxiv.org/abs/1512.03385) - Original ResNet paper

---

## Cleanup

In [None]:
# Clear memory
import gc

del encoder, block, pytorch_layer, our_layer
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

print("Memory cleared! Ready for the next notebook.")

---

## Next Up

In **Notebook 03: Positional Encoding Study**, we'll learn how Transformers understand word order:
- Sinusoidal positional encodings
- Rotary Position Embeddings (RoPE)
- ALiBi and other modern approaches

---

*Excellent work! You've now built the core of modern NLP from scratch!*