# Solutions: Lab 2.3.2 - Transformer Block

This notebook contains solutions to the exercises from notebook 02.

---

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.checkpoint import checkpoint

# Seed PyTorch's global RNG so the random weights and random test inputs
# created below are the same on every run.
torch.manual_seed(42)

## Exercise 1: Decoder Block

**Task:** Create a TransformerDecoderBlock with masked self-attention and cross-attention.

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention (from notebook 01).

    Query, key and value are each passed through their own linear
    projection, split into ``num_heads`` heads of width
    ``d_model // num_heads``, attended per head, merged back to
    ``d_model`` features and passed through an output projection.

    Args:
        d_model: Feature width of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        # Fail early: with a non-divisible pair, d_model // num_heads silently
        # truncates and the error only surfaces later as a failing .view().
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: Tensor of shape (batch, seq_q, d_model).
            key: Tensor of shape (batch, seq_k, d_model).
            value: Tensor of shape (batch, seq_k, d_model).
            mask: Optional tensor broadcastable to the score shape
                (batch, num_heads, seq_q, seq_k). Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, seq_q, d_model).
        """
        batch_size = query.size(0)
        seq_q, seq_k = query.size(1), key.size(1)

        # Project, then reshape to (batch, num_heads, seq, d_k)
        Q = self.W_q(query).view(batch_size, seq_q, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, seq_k, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, seq_k, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product scores: (batch, num_heads, seq_q, seq_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # NOTE: a query row whose keys are all masked becomes all -inf,
            # and softmax over such a row yields NaN.
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)

        # Weighted sum of values, then merge heads back to d_model features
        context = torch.matmul(attention, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_q, self.d_model)

        return self.W_o(context)


class TransformerDecoderBlock(nn.Module):
    """
    Pre-norm Transformer decoder block.

    Three residual sublayers are applied in order, each preceded by its
    own LayerNorm and followed by dropout:
    1. Masked self-attention over the decoder sequence
    2. Cross-attention from the decoder sequence to the encoder output
    3. Position-wise feed-forward network (Linear -> GELU -> Dropout -> Linear)
    """

    def __init__(self, d_model, num_heads, d_ff=None, dropout=0.1):
        super().__init__()

        # Hidden width of the feed-forward network defaults to 4x the model width
        d_ff = 4 * d_model if d_ff is None else d_ff

        # Attention sublayers: decoder-to-decoder, then decoder-to-encoder
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)

        # Feed-forward sublayer
        expand = nn.Linear(d_model, d_ff)
        contract = nn.Linear(d_ff, d_model)
        self.ffn = nn.Sequential(expand, nn.GELU(), nn.Dropout(dropout), contract)

        # One LayerNorm in front of each sublayer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        # Dropout shared by all three residual branches
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, self_mask=None, cross_mask=None):
        """
        Run the three sublayers, adding each result back onto its input.

        Args:
            x: Decoder input (batch, tgt_len, d_model)
            encoder_output: Encoder output (batch, src_len, d_model)
            self_mask: Mask passed to the self-attention (e.g., causal mask)
            cross_mask: Mask passed to the cross-attention (e.g., padding mask)

        Returns:
            Tensor with the same shape as ``x``.
        """
        # 1. Masked self-attention branch
        normed = self.norm1(x)
        x = x + self.dropout(self.self_attn(normed, normed, normed, self_mask))

        # 2. Cross-attention branch: queries from the decoder, keys/values from the encoder
        queries = self.norm2(x)
        attended = self.cross_attn(queries, encoder_output, encoder_output, cross_mask)
        x = x + self.dropout(attended)

        # 3. Feed-forward branch
        return x + self.dropout(self.ffn(self.norm3(x)))

# Smoke test: run random tensors through a single decoder block
decoder_block = TransformerDecoderBlock(512, 8)
decoder_input = torch.randn(2, 10, 512)
encoder_output = torch.randn(2, 20, 512)

# Lower-triangular mask: target position i may only attend to positions <= i
tgt_len = 10
causal_mask = torch.ones(tgt_len, tgt_len).tril().bool()
causal_mask = causal_mask[None, None]  # (1, 1, tgt, tgt)

out = decoder_block(
    decoder_input,
    encoder_output,
    self_mask=causal_mask,
)
print(f"Decoder input shape: {decoder_input.shape}")
print(f"Encoder output shape: {encoder_output.shape}")
print(f"Decoder block output shape: {out.shape}")
print("\nDecoder block implemented successfully!")

## Exercise 2: Gradient Checkpointing

**Task:** Implement gradient checkpointing to save memory.

In [None]:
class TransformerEncoderCheckpointed(nn.Module):
    """
    Stack of pre-norm Transformer encoder layers with optional gradient
    checkpointing.

    When checkpointing is active, each layer is run through
    ``torch.utils.checkpoint.checkpoint`` so its intermediate activations
    are recomputed in the backward pass instead of being stored, trading
    extra compute for lower memory use.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff=None, dropout=0.1):
        super().__init__()

        # Feed-forward hidden width defaults to 4x the model width
        d_ff = 4 * d_model if d_ff is None else d_ff

        stack = []
        for _ in range(num_layers):
            stack.append(self._make_layer(d_model, num_heads, d_ff, dropout))
        self.layers = nn.ModuleList(stack)
        self.final_norm = nn.LayerNorm(d_model)

    def _make_layer(self, d_model, num_heads, d_ff, dropout):
        """Build the sub-modules of one encoder layer as a ModuleDict."""
        attention = MultiHeadAttention(d_model, num_heads, dropout)
        feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        return nn.ModuleDict({
            'attn': attention,
            'ffn': feed_forward,
            'norm1': nn.LayerNorm(d_model),
            'norm2': nn.LayerNorm(d_model),
            'dropout': nn.Dropout(dropout),
        })

    def _forward_layer(self, layer, x, mask):
        """Apply one layer: self-attention branch, then feed-forward branch."""
        drop = layer['dropout']

        # Self-attention with residual connection
        attn_in = layer['norm1'](x)
        x = x + drop(layer['attn'](attn_in, attn_in, attn_in, mask))

        # Feed-forward with residual connection
        return x + drop(layer['ffn'](layer['norm2'](x)))

    def forward(self, x, mask=None, use_checkpointing=True):
        """
        Run every layer in order, then the final LayerNorm.

        Args:
            x: Input tensor
            mask: Attention mask forwarded to every layer
            use_checkpointing: Whether to use gradient checkpointing
                (only takes effect while the module is in training mode)
        """
        # Checkpointing is skipped in eval mode
        recompute = use_checkpointing and self.training

        for block in self.layers:
            if not recompute:
                x = self._forward_layer(block, x, mask)
                continue
            # The layer function is re-run during backward, so it must be
            # free of side effects.
            x = checkpoint(self._forward_layer, block, x, mask, use_reentrant=False)

        return self.final_norm(x)

# Smoke test: six-layer encoder, forward and backward with checkpointing
encoder = TransformerEncoderCheckpointed(num_layers=6, d_model=512, num_heads=8)

x = torch.randn(2, 100, 512, requires_grad=True)

# Training mode is needed: the encoder only checkpoints while self.training is set
encoder.train()
out = encoder(x, use_checkpointing=True)

# Backpropagate a scalar so a gradient reaches the input tensor
loss = out.sum()
loss.backward()

print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Gradient computed: {x.grad is not None}")
print("\nGradient checkpointing implemented successfully!")

## Exercise 3 (Challenge): RMSNorm

**Task:** Implement RMSNorm as an alternative to LayerNorm.

In [None]:
class RMSNorm(nn.Module):
    """
    Root Mean Square Normalization (used in LLaMA).

    RMSNorm(x) = x * scale / sqrt(mean(x^2) + epsilon)

    The mean is taken over the last dimension. Unlike LayerNorm there is
    no mean subtraction and no bias, only a learned per-feature scale.

    Args:
        d_model: Size of the last (feature) dimension.
        eps: Constant added under the square root for numerical stability.
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Normalize ``x`` over its last dimension and apply the learned scale.

        Args:
            x: Tensor whose last dimension has size ``d_model``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # In float16, x^2 overflows to inf once |x| exceeds ~255, which would
        # collapse the output to zero. Compute the statistic in float32 for
        # half-precision inputs; other dtypes take the original path unchanged.
        low_precision = x.dtype in (torch.float16, torch.bfloat16)
        x_stat = x.float() if low_precision else x

        # Compute RMS over the feature dimension
        rms = torch.sqrt(x_stat.pow(2).mean(dim=-1, keepdim=True) + self.eps)

        # Normalize, return to the input dtype, then scale
        normed = x_stat / rms
        if low_precision:
            normed = normed.to(x.dtype)
        return normed * self.scale

# Test
rms = RMSNorm(512)
x = torch.randn(2, 10, 512)
out = rms(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")

# Compare with LayerNorm
ln = nn.LayerNorm(512)

import time

x_large = torch.randn(32, 1000, 512)

# Forward-only benchmark: no_grad keeps autograd bookkeeping out of the
# measurement, and perf_counter is the monotonic high-resolution clock
# intended for timing short intervals.
with torch.no_grad():
    # Warm-up so one-time setup costs are not charged to whichever runs first
    _ = rms(x_large)
    _ = ln(x_large)

    # Time RMSNorm
    start = time.perf_counter()
    for _ in range(100):
        _ = rms(x_large)
    rms_time = time.perf_counter() - start

    # Time LayerNorm
    start = time.perf_counter()
    for _ in range(100):
        _ = ln(x_large)
    ln_time = time.perf_counter() - start

print(f"\nSpeed comparison (100 iterations):")
print(f"  RMSNorm:   {rms_time*1000:.2f} ms")
print(f"  LayerNorm: {ln_time*1000:.2f} ms")
print(f"  Speedup:   {ln_time/rms_time:.2f}x")

---

## Key Takeaways

1. **Decoder blocks** have three sublayers: masked self-attention, cross-attention, and FFN
2. **Gradient checkpointing** trades compute for memory - useful for large models
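3. **RMSNorm** skips LayerNorm's mean-centering and bias, so it needs less computation and works well in practice (a plain PyTorch implementation may still benchmark slower than the fused `nn.LayerNorm` kernel)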

---