# Day 9: Transformer Components - Part 2

This notebook explores the GELU activation function, position-wise feed-forward networks, and complete transformer blocks in both pre-norm and post-norm arrangements.

## Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math

# Plot appearance: apply matplotlib's 'seaborn-v0_8' style sheet and make
# seaborn's "husl" palette the default colour cycle.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Seed the torch and NumPy random number generators so randomly drawn
# tensors and randomly initialised weights repeat from run to run.
torch.manual_seed(42)
np.random.seed(42)

## 1. GELU Activation Function

Let's implement and analyze the GELU activation function:

In [None]:
def gelu_exact(x):
    """Evaluate GELU in closed form via the error function.

    Computes ``0.5 * x * (1 + erf(x / sqrt(2)))`` element-wise on the
    tensor ``x`` and returns a tensor of the same shape.
    """
    half_input = 0.5 * x
    erf_gate = 1 + torch.erf(x / math.sqrt(2))
    return half_input * erf_gate

def gelu_approx(x):
    """Evaluate the tanh-based approximation of GELU.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``
    element-wise on the tensor ``x``.
    """
    scale = math.sqrt(2/math.pi)
    tanh_arg = scale * (x + 0.044715 * x**3)
    return 0.5 * x * (1 + torch.tanh(tanh_arg))

def compare_activations():
    """Plot ReLU, GELU and Swish and report GELU's tanh-approximation error.

    Draws a 2x2 figure over 1000 evenly spaced inputs in [-4, 4]:

    1. the activation curves (ReLU, exact GELU, tanh-approximate GELU, Swish),
    2. their first derivatives obtained with autograd,
    3. the absolute error between exact and approximate GELU (log scale),
    4. the normal CDF next to ``GELU(x) / x``.

    Afterwards prints the maximum and mean approximation error. Returns
    nothing; the only effects are the displayed figure and the printed text.
    """
    x = torch.linspace(-4, 4, 1000)

    # Curves for the first panel.
    relu_out = F.relu(x)
    gelu_exact_out = gelu_exact(x)
    gelu_approx_out = gelu_approx(x)
    swish_out = x * torch.sigmoid(x)

    plt.figure(figsize=(15, 10))

    # Panel 1: activation functions.
    plt.subplot(2, 2, 1)
    plt.plot(x, relu_out, label='ReLU', linewidth=2)
    plt.plot(x, gelu_exact_out, label='GELU (Exact)', linewidth=2)
    plt.plot(x, gelu_approx_out, label='GELU (Approx)', linewidth=2, linestyle='--')
    plt.plot(x, swish_out, label='Swish', linewidth=2)
    plt.xlabel('Input')
    plt.ylabel('Output')
    plt.title('Activation Functions Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Panel 2: derivatives.
    plt.subplot(2, 2, 2)
    x_grad = torch.linspace(-4, 4, 1000, requires_grad=True)
    grid = x_grad.detach()
    gradient_curves = (
        ('ReLU', F.relu),
        ('GELU', F.gelu),
        ('Swish', lambda t: t * torch.sigmoid(t)),
    )
    for label, activation in gradient_curves:
        # Each output element depends only on its own input, so the gradient
        # of the sum is the element-wise derivative. Only first derivatives
        # are plotted, so no double-backward graph (create_graph) is needed.
        (grad,) = torch.autograd.grad(activation(x_grad).sum(), x_grad)
        plt.plot(grid, grad, label=label, linewidth=2)
    plt.xlabel('Input')
    plt.ylabel('Gradient')
    plt.title('Activation Function Gradients')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Panel 3: accuracy of the tanh approximation.
    plt.subplot(2, 2, 3)
    error = torch.abs(gelu_exact_out - gelu_approx_out)
    plt.plot(x, error, linewidth=2, color='red')
    plt.xlabel('Input')
    plt.ylabel('Absolute Error')
    plt.title('GELU Approximation Error')
    plt.grid(True, alpha=0.3)
    plt.yscale('log')

    # Panel 4: GELU(x) = x * Phi(x), so GELU(x) / x traces the normal CDF.
    plt.subplot(2, 2, 4)
    cdf = 0.5 * (1 + torch.erf(x / math.sqrt(2)))
    # GELU(x) / x is 0/0 at x == 0; substitute the CDF value there so the
    # curve stays finite if the grid ever contains zero.
    gelu_ratio = torch.where(x == 0, cdf, gelu_exact_out / x)
    plt.plot(x, cdf, label='Φ(x) - CDF', linewidth=2)
    plt.plot(x, gelu_ratio, label='GELU(x)/x', linewidth=2)
    plt.xlabel('Input')
    plt.ylabel('Value')
    plt.title('GELU Probabilistic Interpretation')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print statistics
    print(f"Maximum approximation error: {error.max():.6f}")
    print(f"Mean approximation error: {error.mean():.6f}")

# Run the activation comparison: shows the figure and prints the
# approximation-error statistics.
compare_activations()

## 2. Feed-Forward Network Implementation

Let's implement the position-wise feed-forward network:

In [None]:
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network.

    Computes ``w_2(dropout(activation(w_1(x))))``. Both linear layers act on
    the last dimension only, so every position is transformed independently.

    Args:
        d_model: Size of the last dimension of the input and of the output.
        d_ff: Width of the hidden layer.
        dropout: Dropout probability applied to the hidden activations.
        activation: Name of the hidden non-linearity, ``'gelu'`` or ``'relu'``.

    Raises:
        ValueError: If ``activation`` is not a supported name. The check runs
            at construction so a bad configuration fails immediately, and
            again in ``forward`` in case the attribute was changed afterwards.
    """

    # Supported activation names mapped to their functional implementations.
    _ACTIVATIONS = {'gelu': F.gelu, 'relu': F.relu}

    def __init__(self, d_model, d_ff, dropout=0.1, activation='gelu'):
        super().__init__()
        if activation not in self._ACTIVATIONS:
            raise ValueError(f"Unknown activation: {activation}")
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, x):
        """Apply the feed-forward network to ``x`` along its last dimension."""
        try:
            activation_fn = self._ACTIVATIONS[self.activation]
        except KeyError:
            raise ValueError(f"Unknown activation: {self.activation}") from None
        return self.w_2(self.dropout(activation_fn(self.w_1(x))))

def test_ffn_activations():
    """Compare GELU and ReLU inside otherwise identical feed-forward networks.

    Builds two ``PositionwiseFeedForward`` networks (d_model=128, d_ff=512),
    feeds both the same random batch of shape (8, 32, 128) and draws a 2x2
    figure: output histograms, fraction of exactly-zero hidden activations,
    histograms of the gradient of the summed output with respect to the
    input, and a bar chart of output mean / std / gradient norm. Finally
    prints the output statistics and sparsity values.

    Both networks are put in evaluation mode so dropout is disabled;
    otherwise each network would draw its own random dropout mask and the
    comparison would mix activation effects with dropout noise.
    """
    d_model = 128
    d_ff = 512
    seq_len = 32
    batch_size = 8

    # Shared input for both networks.
    x = torch.randn(batch_size, seq_len, d_model)

    # Same architecture, different hidden activation; dropout disabled.
    ffn_gelu = PositionwiseFeedForward(d_model, d_ff, activation='gelu').eval()
    ffn_relu = PositionwiseFeedForward(d_model, d_ff, activation='relu').eval()

    # These passes only feed statistics and plots, so skip graph building.
    with torch.no_grad():
        out_gelu = ffn_gelu(x)
        out_relu = ffn_relu(x)
        gelu_activations = F.gelu(ffn_gelu.w_1(x))
        relu_activations = F.relu(ffn_relu.w_1(x))

    # Fraction of hidden units that are exactly zero, as plain floats.
    gelu_sparsity = (gelu_activations == 0).float().mean().item()
    relu_sparsity = (relu_activations == 0).float().mean().item()

    def input_gradient(ffn):
        # Gradient of the summed output w.r.t. a fresh copy of the input, so
        # the two networks never share an accumulated .grad buffer.
        inputs = x.clone().requires_grad_(True)
        (grad,) = torch.autograd.grad(ffn(inputs).sum(), inputs)
        return grad

    grad_gelu = input_gradient(ffn_gelu)
    grad_relu = input_gradient(ffn_relu)

    plt.figure(figsize=(12, 8))

    # Output distributions
    plt.subplot(2, 2, 1)
    plt.hist(out_gelu.flatten(), bins=50, alpha=0.7, label='GELU', density=True)
    plt.hist(out_relu.flatten(), bins=50, alpha=0.7, label='ReLU', density=True)
    plt.xlabel('Output Value')
    plt.ylabel('Density')
    plt.title('Output Distributions')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Activation patterns
    plt.subplot(2, 2, 2)
    plt.bar(['GELU', 'ReLU'], [gelu_sparsity, relu_sparsity])
    plt.ylabel('Sparsity (Fraction of Zeros)')
    plt.title('Activation Sparsity')
    plt.grid(True, alpha=0.3)

    # Gradient analysis
    plt.subplot(2, 2, 3)
    plt.hist(grad_gelu.flatten(), bins=50, alpha=0.7, label='GELU', density=True)
    plt.hist(grad_relu.flatten(), bins=50, alpha=0.7, label='ReLU', density=True)
    plt.xlabel('Gradient Value')
    plt.ylabel('Density')
    plt.title('Gradient Distributions')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Statistics comparison: each entry is [GELU value, ReLU value].
    plt.subplot(2, 2, 4)
    stats = {
        'Mean': [out_gelu.mean().item(), out_relu.mean().item()],
        'Std': [out_gelu.std().item(), out_relu.std().item()],
        'Grad Norm': [grad_gelu.norm().item(), grad_relu.norm().item()]
    }

    x_pos = np.arange(len(stats))
    width = 0.35

    gelu_values = [pair[0] for pair in stats.values()]
    relu_values = [pair[1] for pair in stats.values()]

    plt.bar(x_pos - width/2, gelu_values, width, label='GELU')
    plt.bar(x_pos + width/2, relu_values, width, label='ReLU')
    plt.xlabel('Statistic')
    plt.ylabel('Value')
    plt.title('Statistical Comparison')
    plt.xticks(x_pos, list(stats))
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(f"GELU - Mean: {out_gelu.mean():.4f}, Std: {out_gelu.std():.4f}")
    print(f"ReLU - Mean: {out_relu.mean():.4f}, Std: {out_relu.std():.4f}")
    print(f"GELU Sparsity: {gelu_sparsity:.4f}")
    print(f"ReLU Sparsity: {relu_sparsity:.4f}")

# Run the feed-forward comparison: shows the figure and prints the
# output statistics and sparsity values.
test_ffn_activations()

## 3. Complete Transformer Block

Let's implement and test complete transformer blocks with different configurations:

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension.

    Computes ``gamma * (x - mean) / sqrt(var + eps) + beta`` where ``mean``
    and ``var`` are taken over the last dimension of ``x``. The variance is
    the biased (population) estimate and ``eps`` sits inside the square
    root, which is the same formula ``torch.nn.functional.layer_norm`` uses.

    Args:
        d_model: Size of the last dimension of the input.
        eps: Small constant added to the variance for numerical stability.
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(d_model))   # learnable shift
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` along its last dimension, then scale and shift."""
        mean = x.mean(-1, keepdim=True)
        # Biased variance: the Bessel-corrected estimator divides by (n - 1)
        # and therefore yields NaN when the last dimension has size 1.
        var = x.var(-1, unbiased=False, keepdim=True)
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta

class SimpleAttention(nn.Module):
    """Simplified multi-head self-attention for demonstration.

    Projects the input to queries, keys and values, splits them into
    ``num_heads`` heads of width ``d_model // num_heads``, applies scaled
    dot-product attention per head, merges the heads and applies an output
    projection. There is no masking and no attention dropout.

    Args:
        d_model: Size of the last dimension of the input and output.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads=8):
        super().__init__()
        # Fail at construction: otherwise the head split in forward() fails
        # with a shape error that does not mention the real cause.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (batch, seq, d_model) to (batch, heads, seq, d_k)."""
        return projected.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, seq_len, d_model)."""
        batch_size, seq_len, d_model = x.size()

        Q = self._split_heads(self.w_q(x), batch_size, seq_len)
        K = self._split_heads(self.w_k(x), batch_size, seq_len)
        V = self._split_heads(self.w_v(x), batch_size, seq_len)

        # Scaled dot-product attention; softmax runs over the key positions.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)

        # Merge heads back to (batch, seq_len, d_model) before projecting.
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.w_o(context)

class TransformerBlock(nn.Module):
    """One transformer block: self-attention followed by a feed-forward network.

    Each sub-layer sits on a residual connection with dropout applied to the
    sub-layer output. ``norm_first`` selects where layer normalization goes:

    * ``True`` (pre-norm): the sub-layer input is normalized and the result
      is added to the un-normalized residual stream.
    * ``False`` (post-norm): the sum of residual and sub-layer output is
      normalized.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1, norm_first=True):
        super().__init__()
        self.norm_first = norm_first

        # Sub-layers.
        self.attention = SimpleAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

        # One normalization per sub-layer; a single dropout module serves
        # both residual branches.
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape."""
        if not self.norm_first:
            # Post-norm: normalize after each residual addition.
            after_attention = self.norm1(x + self.dropout(self.attention(x)))
            ff_branch = self.dropout(self.feed_forward(after_attention))
            return self.norm2(after_attention + ff_branch)

        # Pre-norm: normalize only what is fed into each sub-layer.
        after_attention = x + self.dropout(self.attention(self.norm1(x)))
        ff_branch = self.dropout(self.feed_forward(self.norm2(after_attention)))
        return after_attention + ff_branch

def compare_transformer_blocks():
    """Compare pre-norm vs post-norm transformer blocks.

    Builds two 6-layer stacks of ``TransformerBlock`` (d_model=128, 8 heads,
    d_ff=512) that differ only in ``norm_first``, feeds both the same random
    batch of shape (4, 32, 128) and draws a 2x2 figure: mean per-position
    activation norm after every layer, norm of the gradient of the summed
    output with respect to the input, output histograms, and parameter
    counts. The same numbers are printed afterwards.

    Both stacks are put in evaluation mode so dropout is disabled; otherwise
    the norm trace and the gradient pass would each see different random
    dropout masks and the two stacks would not be compared on equal terms.
    """
    d_model = 128
    num_heads = 8
    d_ff = 512
    seq_len = 32
    batch_size = 4
    num_layers = 6

    def build_stack(norm_first):
        # A stack of identical blocks with dropout disabled.
        blocks = [
            TransformerBlock(d_model, num_heads, d_ff, norm_first=norm_first)
            for _ in range(num_layers)
        ]
        return nn.Sequential(*blocks).eval()

    pre_norm_blocks = build_stack(norm_first=True)
    post_norm_blocks = build_stack(norm_first=False)

    # Shared input for both stacks.
    x = torch.randn(batch_size, seq_len, d_model)

    def trace_norms(blocks):
        # Mean L2 norm over positions, for the input and after every block.
        # Nothing is differentiated here, so skip graph building.
        hidden = x
        norms = [hidden.norm(dim=-1).mean().item()]
        with torch.no_grad():
            for block in blocks:
                hidden = block(hidden)
                norms.append(hidden.norm(dim=-1).mean().item())
        return norms

    def output_and_input_grad_norm(blocks):
        # Full forward pass plus the norm of d(sum(output)) / d(input).
        inputs = x.clone().requires_grad_(True)
        output = blocks(inputs)
        (grad,) = torch.autograd.grad(output.sum(), inputs)
        return output.detach(), grad.norm().item()

    pre_norm_activations = trace_norms(pre_norm_blocks)
    post_norm_activations = trace_norms(post_norm_blocks)

    pre_norm_out, pre_norm_grad = output_and_input_grad_norm(pre_norm_blocks)
    post_norm_out, post_norm_grad = output_and_input_grad_norm(post_norm_blocks)

    pre_norm_params = sum(p.numel() for p in pre_norm_blocks.parameters())
    post_norm_params = sum(p.numel() for p in post_norm_blocks.parameters())

    plt.figure(figsize=(12, 8))

    # Activation norms
    plt.subplot(2, 2, 1)
    layers = list(range(num_layers + 1))
    plt.plot(layers, pre_norm_activations, 'o-', label='Pre-norm', linewidth=2)
    plt.plot(layers, post_norm_activations, 's-', label='Post-norm', linewidth=2)
    plt.xlabel('Layer')
    plt.ylabel('Activation Norm')
    plt.title('Activation Norms Through Layers')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Gradient analysis
    plt.subplot(2, 2, 2)
    plt.bar(['Pre-norm', 'Post-norm'], [pre_norm_grad, post_norm_grad])
    plt.ylabel('Input Gradient Norm')
    plt.title('Gradient Flow to Input')
    plt.grid(True, alpha=0.3)

    # Output distributions
    plt.subplot(2, 2, 3)
    plt.hist(pre_norm_out.flatten(), bins=50, alpha=0.7, label='Pre-norm', density=True)
    plt.hist(post_norm_out.flatten(), bins=50, alpha=0.7, label='Post-norm', density=True)
    plt.xlabel('Output Value')
    plt.ylabel('Density')
    plt.title('Output Distributions')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Parameter count
    plt.subplot(2, 2, 4)
    plt.bar(['Pre-norm', 'Post-norm'], [pre_norm_params, post_norm_params])
    plt.ylabel('Parameter Count')
    plt.title('Model Size Comparison')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(f"Pre-norm final activation norm: {pre_norm_activations[-1]:.4f}")
    print(f"Post-norm final activation norm: {post_norm_activations[-1]:.4f}")
    print(f"Pre-norm gradient norm: {pre_norm_grad:.4f}")
    print(f"Post-norm gradient norm: {post_norm_grad:.4f}")
    print(f"Pre-norm parameters: {pre_norm_params:,}")
    print(f"Post-norm parameters: {post_norm_params:,}")

# Run the pre-norm vs post-norm comparison: shows the figure and prints
# the activation norms, gradient norms and parameter counts.
compare_transformer_blocks()

## 4. Summary and Key Insights

This notebook demonstrated the key transformer components:

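1. **GELU Activation**: Gates each input by Φ(x), the standard normal CDF, giving a smooth activation whose gradient stays nonzero for moderately negative inputs, where ReLU's gradient is exactly zero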
2. **Feed-Forward Networks**: The same two-layer network (d_model → d_ff → d_model) is applied independently at every sequence position
3. **Pre-norm vs Post-norm**: Pre-norm generally provides better training stability

These components work together to create stable, trainable transformer architectures.