# Normalization & Activation Improvements: Pre-LN, RMSNorm, GeLU, SwiGLU

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/normalization_activations.ipynb)

This notebook implements from scratch the key normalization and activation improvements that make modern Transformers trainable and performant:

**Normalization:**
1. **Post-LayerNorm** — original (unstable for deep models)
2. **Pre-LayerNorm** — apply norm *before* sublayers (current standard)
3. **RMSNorm** — simpler and faster (used in LLaMA, Mistral)

**Activations:**
4. **ReLU** → **GeLU** → **SwiGLU** progression

We compare forward-pass activation magnitudes through deep stacks (a proxy for training stability), activation gradients, and computational cost.

In [None]:
!pip install torch matplotlib

In [None]:
import torch
import matplotlib.pyplot as plt
import math

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 0. Mathematical Foundations

### Layer Normalization

LayerNorm normalizes across the feature dimension for each token independently:

$$\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where $\mu = \frac{1}{d} \sum_i x_i$ and $\sigma^2 = \frac{1}{d} \sum_i (x_i - \mu)^2$.
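As a quick worked instance of the formula above, the following dependency-free sketch normalizes a single 4-dimensional token by hand. The helper name `layer_norm_1d` and the example vector are illustrative assumptions, and $\gamma = 1$, $\beta = 0$ are assumed so that only the normalization itself is visible.

```python
import math

def layer_norm_1d(x, gamma=1.0, beta=0.0, eps=1e-5):
    """LayerNorm for one token given as a plain list (illustrative helper)."""
    d = len(x)
    mu = sum(x) / d                               # mean over the feature dimension
    var = sum((xi - mu) ** 2 for xi in x) / d     # biased variance (divide by d)
    return [gamma * (xi - mu) / math.sqrt(var + eps) + beta for xi in x]

token = [1.0, 2.0, 3.0, 4.0]                      # mu = 2.5, sigma^2 = 1.25
normalized = layer_norm_1d(token)

# Check the two properties LayerNorm is meant to give: zero mean, unit variance.
out_mean = sum(normalized) / len(normalized)
out_var = sum((v - out_mean) ** 2 for v in normalized) / len(normalized)
print([round(v, 4) for v in normalized])
print(f"mean={out_mean:.6f}, var={out_var:.6f}")
```

The variance comes out marginally below 1 because $\epsilon$ sits inside the square root.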

### RMSNorm

RMSNorm simplifies by removing the mean centering:

$$\text{RMSNorm}(x) = \gamma \cdot \frac{x}{\text{RMS}(x)}, \quad \text{RMS}(x) = \sqrt{\frac{1}{d} \sum_i x_i^2 + \epsilon}$$

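With no $\beta$ (bias) and no $\mu$ (mean subtraction), RMSNorm does less work per token. The speedup is often quoted as roughly 10-30%, though the actual gain depends on hardware and implementation.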
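To make the difference from LayerNorm concrete, here is a small dependency-free sketch that applies RMSNorm to one token. The helper `rms_norm_1d` and the input vector are assumptions for illustration, with $\gamma = 1$ and $\epsilon$ placed inside the square root, as in the `rms_norm` function in Section 1.

```python
import math

def rms_norm_1d(x, gamma=1.0, eps=1e-5):
    """RMSNorm for one token given as a plain list (illustrative helper)."""
    d = len(x)
    rms = math.sqrt(sum(xi * xi for xi in x) / d + eps)   # no mean subtraction
    return [gamma * xi / rms for xi in x]

token = [1.0, 2.0, 3.0, 4.0]          # mean 2.5, mean of squares 7.5
token_out = rms_norm_1d(token)

# The output has RMS close to 1, but its mean is NOT moved to zero.
token_out_rms = math.sqrt(sum(v * v for v in token_out) / len(token_out))
token_out_mean = sum(token_out) / len(token_out)
print(f"RMS={token_out_rms:.6f}, mean={token_out_mean:.4f}")
```

Because the mean is left untouched, a token with a large offset keeps that offset (rescaled) after RMSNorm, which is the main behavioral difference from LayerNorm.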

### Pre-LN vs Post-LN

**Post-LN** (original): `LayerNorm(x + Sublayer(x))`

**Pre-LN** (modern): `x + Sublayer(LayerNorm(x))`

Pre-LN leaves the residual connection as an unnormalized identity path from input to output, so gradients flow more easily through deep networks.
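One way to see the "direct path" is to switch the sublayer off entirely. In the sketch below, `zero_sublayer` is a hypothetical sublayer that outputs zeros, `ln_plain` is LayerNorm with $\gamma = 1$, $\beta = 0$, and `readout` is an arbitrary fixed random vector used only to build a scalar loss; all names are assumptions for illustration. Under these assumptions a Pre-LN step reduces to the identity map, while a Post-LN step still rescales the signal and its gradient.

```python
import torch

def ln_plain(x, eps=1e-5):
    """LayerNorm with gamma=1, beta=0 (kept parameter-free for brevity)."""
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps)

def post_ln_step(x, sublayer):
    return ln_plain(x + sublayer(x))      # norm sits on the residual path

def pre_ln_step(x, sublayer):
    return x + sublayer(ln_plain(x))      # residual path is untouched

def zero_sublayer(h):
    """Hypothetical sublayer that contributes nothing."""
    return torch.zeros_like(h)

torch.manual_seed(0)
x_in = (3.0 * torch.randn(4, 8) + 1.0).requires_grad_(True)
readout = torch.randn(4, 8)               # fixed random weights for a scalar loss

grad_pre = torch.autograd.grad((pre_ln_step(x_in, zero_sublayer) * readout).sum(), x_in)[0]
grad_post = torch.autograd.grad((post_ln_step(x_in, zero_sublayer) * readout).sum(), x_in)[0]

pre_is_identity = torch.allclose(pre_ln_step(x_in, zero_sublayer), x_in)
post_is_identity = torch.allclose(post_ln_step(x_in, zero_sublayer), x_in)
pre_grad_passthrough = torch.allclose(grad_pre, readout)
post_grad_passthrough = torch.allclose(grad_post, readout)

print('Pre-LN output equals input:      ', pre_is_identity)
print('Post-LN output equals input:     ', post_is_identity)
print('Pre-LN gradient passes unchanged:', pre_grad_passthrough)
print('Post-LN gradient passes unchanged:', post_grad_passthrough)
```

This only illustrates the structure of the two residual paths; it does not by itself show how either variant behaves during real training.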

### Activation Functions

$$\text{ReLU}(x) = \max(0, x)$$

$$\text{GeLU}(x) = x \cdot \Phi(x) \approx 0.5 \cdot x \cdot \left(1 + \tanh\left(\sqrt{\frac{2}{\pi}}(x + 0.044715 x^3)\right)\right)$$

$$\text{Swish}(x) = x \cdot \sigma(x)$$

$$\text{SwiGLU}(x, W, V) = \text{Swish}(xW) \odot (xV)$$

SwiGLU uses a **gating mechanism**: one linear projection (the gate, passed through Swish) controls *how much* of each hidden feature passes, and the other provides *what* values are passed.
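The GeLU formula above gives both the exact form $x \cdot \Phi(x)$ and a tanh approximation. The following standard-library sketch measures how far apart they are on a grid over $[-4, 4]$; the function names and the grid are assumptions chosen for illustration.

```python
import math

def gelu_exact(x):
    """x * Phi(x), with the Gaussian CDF written via the error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    """Tanh approximation of GeLU, as in the formula above."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

grid = [i / 100.0 for i in range(-400, 401)]      # -4.00, -3.99, ..., 4.00
max_abs_err = max(abs(gelu_exact(v) - gelu_tanh(v)) for v in grid)
print(f"max |exact - tanh| on [-4, 4]: {max_abs_err:.2e}")
```

The gap is small enough that the two forms are usually treated as interchangeable for plotting and intuition.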

## 1. LayerNorm vs RMSNorm

In [None]:
def layer_norm(x, gamma, beta, eps=1e-5):
    """Standard Layer Normalization."""
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    x_norm = (x - mean) / torch.sqrt(var + eps)
    return gamma * x_norm + beta

def rms_norm(x, gamma, eps=1e-5):
    """RMS Normalization — no mean centering, no bias."""
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    return gamma * (x / rms)

# Compare outputs
torch.manual_seed(42)
d_model = 8
x = torch.randn(2, 4, d_model, device=device)  # (batch, seq, d_model)
gamma = torch.ones(d_model, device=device)
beta = torch.zeros(d_model, device=device)

ln_out = layer_norm(x, gamma, beta)
rms_out = rms_norm(x, gamma)

print('Input x[0,0]:', x[0, 0].cpu())
print('LayerNorm:   ', ln_out[0, 0].cpu())
print('RMSNorm:     ', rms_out[0, 0].cpu())

print(f'\nLayerNorm stats: mean={ln_out[0,0].mean():.6f}, std={ln_out[0,0].std(unbiased=False):.4f}')
print(f'RMSNorm stats:   mean={rms_out[0,0].mean():.6f}, RMS={torch.sqrt((rms_out[0,0]**2).mean()):.4f}')

In [None]:
# Verify: LayerNorm output has mean=0, std=1
#         RMSNorm output has RMS=1 (but mean may not be 0)
print('LayerNorm guarantees:')
print(f'  Output mean ≈ 0: {ln_out.mean(dim=-1)[0].cpu()}')
print(f'  Output std  ≈ 1: {ln_out.std(dim=-1, unbiased=False)[0].cpu()}')

print('\nRMSNorm guarantees:')
print(f'  Output RMS  ≈ 1: {torch.sqrt((rms_out**2).mean(dim=-1))[0].cpu()}')
print(f'  Output mean:     {rms_out.mean(dim=-1)[0].cpu()}  (NOT necessarily 0)')
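As an additional sanity check, a from-scratch LayerNorm can be compared against PyTorch's built-in `torch.nn.functional.layer_norm`. The sketch below repeats the same computation under the name `layer_norm_scratch` so that it runs on its own; the test tensor and the non-trivial `gamma_chk` / `beta_chk` values are assumptions chosen so that scale and shift are actually exercised.

```python
import torch
import torch.nn.functional as F

def layer_norm_scratch(x, gamma, beta, eps=1e-5):
    """Same computation as the from-scratch layer_norm above."""
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta

torch.manual_seed(0)
d_chk = 8
x_chk = torch.randn(2, 4, d_chk)
gamma_chk = torch.rand(d_chk) + 0.5      # non-trivial scale
beta_chk = torch.randn(d_chk)            # non-trivial shift

ours = layer_norm_scratch(x_chk, gamma_chk, beta_chk)
reference = F.layer_norm(x_chk, (d_chk,), weight=gamma_chk, bias=beta_chk, eps=1e-5)

max_diff = (ours - reference).abs().max().item()
print(f'max abs difference vs F.layer_norm: {max_diff:.2e}')
```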

In [None]:
# Speed comparison (conceptual — count operations)
print('Operation count comparison (for d features):')
print('\nLayerNorm:')
print('  1. Compute mean:     d additions + 1 division')
print('  2. Subtract mean:    d subtractions')
print('  3. Compute variance: d multiplications + d additions + 1 division')
print('  4. Normalize:        d divisions (reuses x - mean from step 2)')
print('  5. Scale + shift:    d multiplications + d additions')
print('  Total: ~7d operations')

print('\nRMSNorm:')
print('  1. Compute x²:      d multiplications')
print('  2. Mean of x²:      d additions + 1 division')
print('  3. Square root:     1 operation')
print('  4. Normalize:       d divisions')
print('  5. Scale:           d multiplications')
print('  Total: ~4d operations')
print('\n→ RMSNorm needs roughly 40% fewer arithmetic operations (no mean subtraction, no bias addition)')
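Operation counts are only a rough guide, since normalization layers are often limited by memory traffic rather than arithmetic. A minimal wall-clock sketch, assuming CPU tensors and a hypothetical helper `seconds_per_call`, could look like the following. The tensor shape and repeat count are arbitrary choices, and on a GPU one would also need `torch.cuda.synchronize()` around the timed region.

```python
import time
import torch

def ln_fn(x, gamma, beta, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta

def rms_fn(x, gamma, eps=1e-5):
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    return gamma * (x / rms)

def seconds_per_call(fn, *args, repeats=20):
    """Average wall-clock seconds per call (hypothetical helper, CPU only)."""
    fn(*args)                              # warm-up call, excluded from timing
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) / repeats

torch.manual_seed(0)
d_bench = 512
x_bench = torch.randn(32, 128, d_bench)
gamma_b, beta_b = torch.ones(d_bench), torch.zeros(d_bench)

t_ln = seconds_per_call(ln_fn, x_bench, gamma_b, beta_b)
t_rms = seconds_per_call(rms_fn, x_bench, gamma_b)
print(f'LayerNorm: {t_ln * 1e3:.3f} ms/call')
print(f'RMSNorm:   {t_rms * 1e3:.3f} ms/call')
```

Measured ratios will vary from machine to machine, so treat any single run as indicative rather than as a benchmark.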

## 2. Post-LN vs Pre-LN Transformer Blocks

The placement of normalization has a dramatic effect on training stability.

In [None]:
def simple_attention(x, W_Q, W_K, W_V):
    """Simplified single-head attention for demonstration."""
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    d_k = Q.shape[-1]
    scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)
    return torch.bmm(weights, V)

def ffn(x, W1, b1, W2, b2, activation_fn):
    """Feed-forward network with configurable activation."""
    return activation_fn(x @ W1 + b1) @ W2 + b2

def post_ln_block(x, attn_params, ffn_params, norm_params, activation_fn):
    """Post-LayerNorm: LayerNorm(x + Sublayer(x))"""
    W_Q, W_K, W_V = attn_params
    W1, b1, W2, b2 = ffn_params
    gamma1, beta1, gamma2, beta2 = norm_params
    
    # Attention sublayer with Post-LN
    attn_out = simple_attention(x, W_Q, W_K, W_V)
    x = layer_norm(x + attn_out, gamma1, beta1)  # norm AFTER residual
    
    # FFN sublayer with Post-LN
    ffn_out = ffn(x, W1, b1, W2, b2, activation_fn)
    x = layer_norm(x + ffn_out, gamma2, beta2)  # norm AFTER residual
    
    return x

def pre_ln_block(x, attn_params, ffn_params, norm_params, activation_fn):
    """Pre-LayerNorm: x + Sublayer(LayerNorm(x))"""
    W_Q, W_K, W_V = attn_params
    W1, b1, W2, b2 = ffn_params
    gamma1, beta1, gamma2, beta2 = norm_params
    
    # Attention sublayer with Pre-LN
    x_norm = layer_norm(x, gamma1, beta1)  # norm BEFORE sublayer
    attn_out = simple_attention(x_norm, W_Q, W_K, W_V)
    x = x + attn_out  # clean residual path
    
    # FFN sublayer with Pre-LN
    x_norm = layer_norm(x, gamma2, beta2)  # norm BEFORE sublayer
    ffn_out = ffn(x_norm, W1, b1, W2, b2, activation_fn)
    x = x + ffn_out  # clean residual path
    
    return x

print('Post-LN: x → Sublayer → Add → Norm → next')
print('Pre-LN:  x → Norm → Sublayer → Add → next')
print('\nPre-LN gives the gradient a clean highway through residual connections.')

In [None]:
def init_block_params(d_model, d_ff, device):
    """Initialize parameters for one transformer block."""
    scale = 0.02
    attn_params = (
        torch.randn(d_model, d_model, device=device) * scale,
        torch.randn(d_model, d_model, device=device) * scale,
        torch.randn(d_model, d_model, device=device) * scale,
    )
    ffn_params = (
        torch.randn(d_model, d_ff, device=device) * scale,
        torch.zeros(d_ff, device=device),
        torch.randn(d_ff, d_model, device=device) * scale,
        torch.zeros(d_model, device=device),
    )
    norm_params = (
        torch.ones(d_model, device=device),
        torch.zeros(d_model, device=device),
        torch.ones(d_model, device=device),
        torch.zeros(d_model, device=device),
    )
    return attn_params, ffn_params, norm_params

# Compare activation magnitudes through deep stacks
torch.manual_seed(42)
d_model = 32
d_ff = 128
n_layers = 20
batch, seq_len = 2, 8
relu = torch.relu

x = torch.randn(batch, seq_len, d_model, device=device)

# Track norms through layers
post_ln_norms = [x.norm().item()]
pre_ln_norms = [x.norm().item()]

x_post = x.clone()
x_pre = x.clone()

for layer in range(n_layers):
    attn_p, ffn_p, norm_p = init_block_params(d_model, d_ff, device)
    
    x_post = post_ln_block(x_post, attn_p, ffn_p, norm_p, relu)
    post_ln_norms.append(x_post.norm().item())
    
    x_pre = pre_ln_block(x_pre, attn_p, ffn_p, norm_p, relu)
    pre_ln_norms.append(x_pre.norm().item())

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(post_ln_norms, 'o-', linewidth=2, label='Post-LN', color='red')
ax.plot(pre_ln_norms, 's-', linewidth=2, label='Pre-LN', color='blue')
ax.set_xlabel('Layer')
ax.set_ylabel('Activation Norm')
ax.set_title('Activation Magnitude Through Deep Transformer Stack\n(Pre-LN accumulates sublayer outputs; Post-LN is renormalized at every layer)')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
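When reading a plot like the one above, it helps to know what LayerNorm does to the overall norm. With $\gamma = 1$ and $\beta = 0$ (the values used at initialization here), every normalized token has norm close to $\sqrt{d}$, so the output of a Post-LN block has total norm close to $\sqrt{\text{batch} \cdot \text{seq} \cdot d}$ whatever the input scale. The sketch below checks this on a deliberately large input; the names `n_batch`, `n_tokens`, `width`, and `x_big` are assumptions for illustration.

```python
import math
import torch

def ln_unit(x, eps=1e-5):
    """LayerNorm with gamma=1, beta=0."""
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps)

torch.manual_seed(0)
n_batch, n_tokens, width = 2, 8, 32
x_big = 5.0 * torch.randn(n_batch, n_tokens, width)    # large input scale on purpose

normed = ln_unit(x_big)
per_token_norm = normed.norm(dim=-1)                    # each entry is close to sqrt(width)
total_norm = normed.norm().item()
expected_total = math.sqrt(n_batch * n_tokens * width)

print(f'input norm:          {x_big.norm().item():.2f}')
print(f'norm after LN:       {total_norm:.4f}')
print(f'sqrt(batch*seq*d):   {expected_total:.4f}')
```

A Pre-LN block has no such constraint on its output, because the residual stream itself is never normalized.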

## 3. Pre-RMSNorm Block (Modern Standard)

The modern LLM standard combines Pre-LN placement with RMSNorm:

In [None]:
def pre_rmsnorm_block(x, attn_params, ffn_params, gamma1, gamma2, activation_fn):
    """Modern transformer block: Pre-RMSNorm (used in LLaMA, Mistral)."""
    W_Q, W_K, W_V = attn_params
    W1, b1, W2, b2 = ffn_params
    
    # RMSNorm before attention
    x_norm = rms_norm(x, gamma1)
    attn_out = simple_attention(x_norm, W_Q, W_K, W_V)
    x = x + attn_out
    
    # RMSNorm before FFN
    x_norm = rms_norm(x, gamma2)
    ffn_out = ffn(x_norm, W1, b1, W2, b2, activation_fn)
    x = x + ffn_out
    
    return x

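# Track norms through layers
torch.manual_seed(42)
# Redraw the input so the RNG advances exactly as in the previous cell;
# the blocks below then get the same weights as the Post-LN/Pre-LN runs.
x_rms = torch.randn(batch, seq_len, d_model, device=device)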
rms_norms_list = [x_rms.norm().item()]

for layer in range(n_layers):
    attn_p, ffn_p, _ = init_block_params(d_model, d_ff, device)
    gamma1 = torch.ones(d_model, device=device)
    gamma2 = torch.ones(d_model, device=device)
    x_rms = pre_rmsnorm_block(x_rms, attn_p, ffn_p, gamma1, gamma2, relu)
    rms_norms_list.append(x_rms.norm().item())

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(post_ln_norms, 'o-', linewidth=2, label='Post-LayerNorm', color='red')
ax.plot(pre_ln_norms, 's-', linewidth=2, label='Pre-LayerNorm', color='blue')
ax.plot(rms_norms_list, '^-', linewidth=2, label='Pre-RMSNorm', color='green')
ax.set_xlabel('Layer')
ax.set_ylabel('Activation Norm')
ax.set_title('Normalization Strategy Comparison (20 layers)')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
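The functional `rms_norm` above takes `gamma` as an argument. For training, one would typically wrap it in a module with a learnable scale. The class below, `RMSNormSketch`, is a minimal sketch under that assumption; it is not taken from any particular model's code, and the names `norm_layer`, `x_in`, and `y_out` are illustrative.

```python
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    """RMSNorm with a learnable per-feature scale and no bias (illustrative)."""

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(d_model))   # initialized to 1

    def forward(self, x):
        # Root mean square over the feature dimension, eps inside the sqrt
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * (x / rms)

torch.manual_seed(0)
norm_layer = RMSNormSketch(d_model=32)
x_in = torch.randn(2, 8, 32)
y_out = norm_layer(x_in)

n_params = sum(p.numel() for p in norm_layer.parameters())
out_rms = y_out.pow(2).mean(dim=-1).sqrt()               # per-token RMS of the output
print(f'learnable parameters: {n_params}')
print(f'output RMS range: [{out_rms.min().item():.4f}, {out_rms.max().item():.4f}]')
```

A LayerNorm module of the same width would carry twice as many parameters, since it also learns $\beta$.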

## 4. Activation Functions: ReLU → GeLU → SwiGLU

In [None]:
def relu(x):
    """ReLU: max(0, x) — original Transformer activation."""
    return torch.clamp(x, min=0)

def gelu(x):
    """GeLU: x * Phi(x), a smooth probabilistic activation (tanh approximation).
    Used in BERT, GPT-2, GPT-3."""
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def swish(x):
    """Swish: x * sigmoid(x) — smooth self-gated activation."""
    return x * torch.sigmoid(x)

# Visualize all three
x_range = torch.linspace(-4, 4, 200, device=device)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Activation functions
activations = [
    ('ReLU', relu(x_range)),
    ('GeLU', gelu(x_range)),
    ('Swish', swish(x_range)),
]

for name, y in activations:
    axes[0].plot(x_range.cpu(), y.cpu(), linewidth=2, label=name)
axes[0].axhline(y=0, color='gray', linestyle='--', alpha=0.3)
axes[0].axvline(x=0, color='gray', linestyle='--', alpha=0.3)
axes[0].set_xlabel('x')
axes[0].set_ylabel('f(x)')
axes[0].set_title('Activation Functions')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Gradients (computed with autograd)
for name, fn in [('ReLU', relu), ('GeLU', gelu), ('Swish', swish)]:
    x_grad = x_range.clone().requires_grad_(True)  # fresh leaf tensor per activation
    y = fn(x_grad)
    grad = torch.autograd.grad(y.sum(), x_grad)[0]
    axes[1].plot(x_range.cpu(), grad.cpu(), linewidth=2, label=name)

axes[1].axhline(y=0, color='gray', linestyle='--', alpha=0.3)
axes[1].axvline(x=0, color='gray', linestyle='--', alpha=0.3)
axes[1].set_xlabel('x')
axes[1].set_ylabel("f'(x)")
axes[1].set_title('Gradients — GeLU/Swish are smooth (no hard cutoff at 0)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print('Key differences:')
print('  ReLU:  Hard cutoff at 0; units whose inputs stay negative get zero gradient ("dead neurons")')
print('  GeLU:  Smooth transition — small negative values get small (nonzero) gradients')
print('  Swish: Similar to GeLU, slightly different shape (basis for SwiGLU)')
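The difference at negative inputs can also be read off from closed-form derivatives. The sketch below uses only the standard library and, as an assumption, the exact GeLU $x \cdot \Phi(x)$ rather than the tanh approximation used in the plots; the function names are illustrative.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def relu_grad(x):
    """Derivative of max(0, x); the value at exactly 0 is a convention."""
    return 1.0 if x > 0 else 0.0

def gelu_grad(x):
    """d/dx [x * Phi(x)] = Phi(x) + x * phi(x), with the exact Gaussian CDF/PDF."""
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf

def swish_grad(x):
    """d/dx [x * sigmoid(x)] = s + x * s * (1 - s), where s = sigmoid(x)."""
    s = sigmoid(x)
    return s + x * s * (1.0 - s)

for v in (-3.0, -1.0, -0.1, 0.5):
    print(f'x={v:+.1f}  ReLU\'={relu_grad(v):.4f}  '
          f'GeLU\'={gelu_grad(v):+.4f}  Swish\'={swish_grad(v):+.4f}')
```

For negative inputs the ReLU derivative is exactly zero, while the GeLU and Swish derivatives are small but nonzero (and can be slightly negative), so a weight update can still move the unit.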

## 5. SwiGLU — The Modern FFN

SwiGLU replaces the standard FFN with a gated linear unit:

**Standard FFN**: $\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2$

**SwiGLU FFN**: $\text{SwiGLU}(x) = (\text{Swish}(xW_{\text{gate}}) \odot xW_{\text{up}}) W_{\text{down}}$

The key idea: one projection ($W_{\text{gate}}$) learns *what to let through*, another ($W_{\text{up}}$) learns *what values to pass*.
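A tiny hand-sized example may make the gating concrete. The sketch below uses plain Python lists, and its weights are made up so that the three hidden units get gate pre-activations of -6, 0, and 4. The names `x_tok`, `Wg`, `Wu`, `Wd`, `swish_scalar`, and `matvec` are assumptions for illustration only.

```python
import math

def swish_scalar(z):
    """Swish for a single number: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))

def matvec(x, W):
    """Row vector x (length n) times matrix W (n x m), as plain lists."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

x_tok = [1.0, 2.0]                        # one token, d_model = 2
Wg = [[-6.0, 0.0, 2.0],                   # gate projection (d_model x d_ff)
      [ 0.0, 0.0, 1.0]]
Wu = [[1.0, 1.0, 1.0],                    # up projection (d_model x d_ff)
      [1.0, 1.0, 1.0]]
Wd = [[1.0, 0.0],                         # down projection (d_ff x d_model)
      [0.0, 1.0],
      [1.0, 1.0]]

gate_pre = matvec(x_tok, Wg)              # [-6.0, 0.0, 4.0]
gate = [swish_scalar(z) for z in gate_pre]
up = matvec(x_tok, Wu)                    # [3.0, 3.0, 3.0]
hidden = [g * u for g, u in zip(gate, up)]
out_tok = matvec(hidden, Wd)

print('gate:  ', [round(g, 4) for g in gate])
print('hidden:', [round(h, 4) for h in hidden])
print('out:   ', [round(o, 4) for o in out_tok])
```

All three hidden units receive the same "up" value, yet the first is almost fully suppressed, the second is exactly zero, and the third passes through scaled by its gate.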

In [None]:
def ffn_relu(x, W1, b1, W2, b2):
    """Original FFN: ReLU(xW1 + b1)W2 + b2"""
    return relu(x @ W1 + b1) @ W2 + b2

def ffn_gelu(x, W1, b1, W2, b2):
    """GeLU FFN: GeLU(xW1 + b1)W2 + b2"""
    return gelu(x @ W1 + b1) @ W2 + b2

def ffn_swiglu(x, W_gate, W_up, W_down):
    """SwiGLU FFN: (Swish(x @ W_gate) * (x @ W_up)) @ W_down
    
    Note: SwiGLU uses 3 weight matrices (no bias) instead of 2 matrices + 2 biases.
    To keep parameter count similar, d_ff is typically scaled to about 2/3 of the standard size.
    """
    gate = swish(x @ W_gate)  # Controls how much of each hidden feature passes
    up = x @ W_up              # Values to pass
    return (gate * up) @ W_down

# Compare
torch.manual_seed(42)
d_model = 16
d_ff = 64  # Standard FFN expansion factor
d_ff_swiglu = int(d_ff * 2 / 3)  # SwiGLU uses smaller d_ff (3 matrices instead of 2)

x = torch.randn(2, 4, d_model, device=device)

# Standard FFN params
W1 = torch.randn(d_model, d_ff, device=device) * 0.1
b1 = torch.zeros(d_ff, device=device)
W2 = torch.randn(d_ff, d_model, device=device) * 0.1
b2 = torch.zeros(d_model, device=device)

# SwiGLU params
W_gate = torch.randn(d_model, d_ff_swiglu, device=device) * 0.1
W_up = torch.randn(d_model, d_ff_swiglu, device=device) * 0.1
W_down = torch.randn(d_ff_swiglu, d_model, device=device) * 0.1

out_relu = ffn_relu(x, W1, b1, W2, b2)
out_gelu = ffn_gelu(x, W1, b1, W2, b2)
out_swiglu = ffn_swiglu(x, W_gate, W_up, W_down)

print('FFN output shapes:')
print(f'  ReLU FFN:  {out_relu.shape}')
print(f'  GeLU FFN:  {out_gelu.shape}')
print(f'  SwiGLU FFN: {out_swiglu.shape}')

# Parameter count comparison
relu_params = d_model * d_ff + d_ff + d_ff * d_model + d_model  # W1, b1, W2, b2
swiglu_params = d_model * d_ff_swiglu * 3  # W_gate, W_up, W_down (no biases)

print(f'\nParameter counts:')
print(f'  ReLU/GeLU FFN (d_ff={d_ff}):  {relu_params}')
print(f'  SwiGLU FFN (d_ff={d_ff_swiglu}): {swiglu_params}')
print(f'  Ratio: {swiglu_params/relu_params:.2f}x')
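The 2/3 factor follows from matching weight counts: ignoring biases, a standard FFN has $2 \cdot d_{\text{model}} \cdot d_{\text{ff}}$ weights and a SwiGLU FFN has $3 \cdot d_{\text{model}} \cdot h$, so equality gives $h = \tfrac{2}{3} d_{\text{ff}}$. The sketch below works this through for a larger, hypothetical width. The value `model_dim = 4096`, the 4x expansion, and rounding up to a multiple of 256 are all assumptions, the last being a convention seen in some LLaMA-style implementations.

```python
def swiglu_hidden_dim(hidden_standard, multiple_of=256):
    """2/3 of the standard hidden size, rounded up to a multiple (assumed convention)."""
    h = int(2 * hidden_standard / 3)
    return multiple_of * ((h + multiple_of - 1) // multiple_of)

model_dim = 4096                          # hypothetical model width
hidden_standard = 4 * model_dim           # standard 4x FFN expansion
hidden_swiglu = swiglu_hidden_dim(hidden_standard)

standard_weights = 2 * model_dim * hidden_standard    # W1 and W2, biases ignored
swiglu_weights = 3 * model_dim * hidden_swiglu        # W_gate, W_up, W_down
weight_ratio = swiglu_weights / standard_weights

print(f'standard hidden size: {hidden_standard}')
print(f'SwiGLU hidden size:   {hidden_swiglu}')
print(f'weight ratio:         {weight_ratio:.4f}')
```

The ratio lands slightly above 1 only because of the rounding step; without it the two counts would match up to integer truncation.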

In [None]:
# Visualize the gating mechanism in SwiGLU
torch.manual_seed(42)
x_demo = torch.randn(1, 8, d_model, device=device)

gate_values = swish(x_demo @ W_gate)  # How much to let through
up_values = x_demo @ W_up             # What to let through
gated_output = gate_values * up_values # The product

fig, axes = plt.subplots(1, 3, figsize=(18, 4))

titles = [
    f'Gate: Swish(x @ W_gate)\nControls "how much"',
    f'Up: x @ W_up\nControls "what"',
    f'Output: gate * up\nGated result',
]
data = [gate_values, up_values, gated_output]

for ax, title, d in zip(axes, titles, data):
    im = ax.imshow(d[0].detach().cpu().numpy(), cmap='RdBu', aspect='auto')
    ax.set_xlabel(f'Hidden dim (d_ff={d_ff_swiglu})')
    ax.set_ylabel('Token position')
    ax.set_title(title)
    plt.colorbar(im, ax=ax)

plt.suptitle('SwiGLU Gating Mechanism', fontsize=13)
plt.tight_layout()
plt.show()

## 6. Modern Transformer Block

Putting it all together: the modern LLM block uses **Pre-RMSNorm + SwiGLU FFN**.

In [None]:
def modern_transformer_block(x, attn_params, swiglu_params, gamma1, gamma2):
    """Modern transformer block: Pre-RMSNorm + SwiGLU (LLaMA-style)."""
    W_Q, W_K, W_V = attn_params
    W_gate, W_up, W_down = swiglu_params
    
    # Pre-RMSNorm attention
    x_norm = rms_norm(x, gamma1)
    attn_out = simple_attention(x_norm, W_Q, W_K, W_V)
    x = x + attn_out
    
    # Pre-RMSNorm SwiGLU FFN
    x_norm = rms_norm(x, gamma2)
    ffn_out = ffn_swiglu(x_norm, W_gate, W_up, W_down)
    x = x + ffn_out
    
    return x

def original_transformer_block(x, attn_params, ffn_params, norm_params):
    """Original 2017 transformer block: Post-LayerNorm + ReLU FFN."""
    return post_ln_block(x, attn_params, ffn_params, norm_params, relu)

# Compare the two architectures
torch.manual_seed(42)
d_model = 32
d_ff = 128
d_ff_sg = int(d_ff * 2 / 3)
n_layers = 20

x = torch.randn(2, 8, d_model, device=device)
x_orig = x.clone()
x_modern = x.clone()

orig_norms = [x_orig.norm().item()]
modern_norms = [x_modern.norm().item()]

for layer in range(n_layers):
    scale = 0.02
    
    # Original block params: same draws as before, via init_block_params (Section 2)
    attn_p, ffn_p, norm_p = init_block_params(d_model, d_ff, device)
    
    # SwiGLU params
    swiglu_p = (
        torch.randn(d_model, d_ff_sg, device=device) * scale,
        torch.randn(d_model, d_ff_sg, device=device) * scale,
        torch.randn(d_ff_sg, d_model, device=device) * scale,
    )
    gamma1 = torch.ones(d_model, device=device)
    gamma2 = torch.ones(d_model, device=device)
    
    x_orig = original_transformer_block(x_orig, attn_p, ffn_p, norm_p)
    orig_norms.append(x_orig.norm().item())
    
    x_modern = modern_transformer_block(x_modern, attn_p, swiglu_p, gamma1, gamma2)
    modern_norms.append(x_modern.norm().item())

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(orig_norms, 'o-', linewidth=2, label='Original (Post-LN + ReLU)', color='red')
ax.plot(modern_norms, 's-', linewidth=2, label='Modern (Pre-RMSNorm + SwiGLU)', color='green')
ax.set_xlabel('Layer')
ax.set_ylabel('Activation Norm')
ax.set_title('Original vs Modern Transformer Block\n(20 layers, forward pass)')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Summary comparison
print('=' * 80)
print('EVOLUTION: Normalization & Activation in Transformers')
print('=' * 80)
print(f'{"Component":<20} {"Original (2017)":<25} {"Modern (2024)":<25}')
print('-' * 80)
print(f'{"Norm placement":<20} {"Post-LN":<25} {"Pre-LN":<25}')
print(f'{"Norm type":<20} {"LayerNorm":<25} {"RMSNorm":<25}')
print(f'{"Activation":<20} {"ReLU":<25} {"SwiGLU":<25}')
print(f'{"FFN structure":<20} {"Linear→ReLU→Linear":<25} {"Gate+Up→SwiGLU→Down":<25}')
print(f'{"Norm params":<20} {"gamma + beta":<25} {"gamma only":<25}')
print(f'{"Deep-stack training":<20} {"Sensitive to LR/warmup":<25} {"More stable":<25}')
print(f'{"Used in":<20} {"Original Transformer":<25} {"LLaMA, Mistral, Gemma":<25}')
print('=' * 80)

## Summary

In this notebook we implemented from scratch:

**Normalization:**
1. **LayerNorm**: normalizes by mean and variance, with learnable $\gamma$ and $\beta$
2. **RMSNorm**: normalizes by RMS only, with no mean centering and no bias. It does less arithmetic per token; the speedup is often quoted as 10-30% but depends on hardware and implementation.
3. **Pre-LN**: places normalization *before* sublayers, giving residual connections a clean gradient path. This makes very deep stacks (100+ layers) easier to train.

**Activations:**
4. **ReLU → GeLU**: a smooth activation keeps small nonzero gradients for negative inputs, which mitigates the "dead neuron" problem
5. **SwiGLU**: gated linear unit with Swish activation. Two projections: one gates "how much", one provides "what". It is reported to perform better than ReLU or GeLU FFNs at a similar parameter count.

**Modern standard (LLaMA, Mistral, Gemma):** Pre-RMSNorm + SwiGLU FFN