# Normalization Layers: Mathematical Derivations

This notebook derives the forward and backward passes for two normalization techniques used in transformers:

1. **LayerNorm** (Ba et al., 2016) - Standard in the original Transformer
2. **RMSNorm** (Zhang & Sennrich, 2019) - Used in LLaMA, Gemma, and modern architectures

We'll derive the math from first principles and verify against our implementations.

## Glossary of Terms

| Term | Definition |
|------|------------|
| **Feature dimension** | The last axis of a tensor representing the learned attributes of each token. In a tensor of shape `(batch, sequence, features)`, the feature dimension has size `features` (often called `d_model`, typically 512-4096). Each position along this axis represents a different learned characteristic. |
| **Batch dimension** | The first axis of a tensor, representing independent samples processed together for efficiency. Shape: `(batch, ...)`. |
| **Sequence dimension** | The second axis in transformer tensors, representing positions in the input sequence (e.g., words in a sentence). Shape: `(batch, sequence, features)`. |
| **Normalization** | Rescaling values to have specific statistical properties (e.g., zero mean, unit variance). Stabilizes training by preventing activations from growing unboundedly. |
| **Internal covariate shift** | The phenomenon where the distribution of layer inputs changes during training as earlier layers update their weights, making optimization harder. |
| **Learnable parameters** | Weights ($\gamma$, $\beta$) that are updated during training via gradient descent, allowing the network to adapt the normalization. |
| **Upstream gradient** | The gradient $\frac{\partial L}{\partial y}$ flowing backward from later layers. We use this to compute gradients for earlier layers. |
| **Epsilon ($\epsilon$)** | A small constant (e.g., $10^{-5}$) added to denominators to prevent division by zero when the variance (or, for RMSNorm, the mean square) is near zero. |
| **Affine transformation** | A linear transformation followed by a translation: $y = \gamma x + \beta$. The scale ($\gamma$) and shift ($\beta$) in normalization layers. |
| **RMS (Root Mean Square)** | $\sqrt{\frac{1}{n}\sum x_i^2}$ - the quadratic mean of values. Unlike standard deviation, it doesn't subtract the mean first. |

## Formulas and Theorems

### Layer Normalization

| Formula | Description |
|---------|-------------|
| $\mu = \frac{1}{d} \sum_{i=1}^{d} x_i$ | Mean over feature dimension |
| $\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$ | Variance over feature dimension |
| $\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$ | Normalized features |
| $y = \gamma \odot \hat{x} + \beta$ | LayerNorm output (with learnable $\gamma, \beta$) |
| $\frac{\partial L}{\partial \gamma_i} = \sum \bar{y}_i \hat{x}_i$ | Gradient w.r.t. scale parameter (summed over batch and sequence positions) |
| $\frac{\partial L}{\partial \beta_i} = \sum \bar{y}_i$ | Gradient w.r.t. shift parameter (summed over batch and sequence positions) |
| $\frac{\partial L}{\partial x} = \frac{1}{\sigma}\left( \bar{\hat{x}} - \text{mean}(\bar{\hat{x}}) - \hat{x} \cdot \text{mean}(\bar{\hat{x}} \odot \hat{x}) \right)$ | Backward pass, where $\bar{\hat{x}} = \bar{y} \odot \gamma$, $\sigma = \sqrt{\text{Var}(x) + \epsilon}$, and means are taken over the feature dimension |

### RMS Normalization

| Formula | Description |
|---------|-------------|
| $\text{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}$ | Root Mean Square |
| $\hat{x} = \frac{x}{\text{RMS}(x)}$ | RMS-normalized features |
| $y = \gamma \odot \hat{x}$ | RMSNorm output (no $\beta$ parameter) |
| $\frac{\partial L}{\partial \gamma_i} = \sum \bar{y}_i \hat{x}_i$ | Gradient w.r.t. scale parameter (summed over batch and sequence positions) |
| $\frac{\partial L}{\partial x} = \frac{1}{r}\left( \bar{\hat{x}} - \hat{x} \cdot \text{mean}(\bar{\hat{x}} \odot \hat{x}) \right)$ | Backward pass, where $r = \text{RMS}(x)$ (simpler than LayerNorm) |

### Key Identities

| Identity | Description |
|----------|-------------|
| $\text{Var}(x) = \text{E}[x^2] - \text{E}[x]^2$ | Variance decomposition |
| $\text{RMS}(x)^2 = \text{Var}(x) + \mu^2$ | Relationship between RMS and variance (with $\epsilon$ omitted from RMS) |
| $\frac{\partial}{\partial x_i} \sqrt{f(x)} = \frac{1}{2\sqrt{f(x)}} \frac{\partial f}{\partial x_i}$ | Derivative of square root (chain rule) |
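
The identities in the table can be spot-checked numerically. The sketch below uses an arbitrary example vector (a hypothetical value with nothing special about it) and checks the variance decomposition, the RMS/variance relation with $\epsilon$ omitted, and the square-root derivative against a central finite difference.

```python
import numpy as np

# Hypothetical example vector; any real-valued vector works
x = np.array([1.0, -2.0, 0.5, 3.0])
mu = x.mean()
var = x.var()                 # population variance (divides by d)
ms = (x ** 2).mean()          # mean of squares = RMS(x)^2 with epsilon omitted

# Variance decomposition: Var(x) = E[x^2] - E[x]^2
print(np.isclose(var, ms - mu ** 2))
# RMS(x)^2 = Var(x) + mu^2
print(np.isclose(ms, var + mu ** 2))

# Derivative of sqrt(f) with f(x) = mean(x^2):
# d/dx_i sqrt(f) = (df/dx_i) / (2 sqrt(f)), where df/dx_i = 2 x_i / d
i, h = 2, 1e-6
df_dxi = 2 * x[i] / len(x)
analytic = df_dxi / (2 * np.sqrt(ms))

# Central finite difference of sqrt(mean(x^2)) in coordinate i
x_p, x_m = x.copy(), x.copy()
x_p[i] += h
x_m[i] -= h
numeric = (np.sqrt((x_p ** 2).mean()) - np.sqrt((x_m ** 2).mean())) / (2 * h)
print(analytic, numeric)
```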

## Prerequisites

This notebook assumes familiarity with:

### 1. Mean and Variance

For a vector $\mathbf{x} = [x_1, x_2, ..., x_d]$:

**Mean (Expected Value):**
$$\mu = \text{E}[x] = \frac{1}{d} \sum_{i=1}^{d} x_i$$

**Variance:** The average squared deviation from the mean:
$$\text{Var}(x) = \sigma^2 = \text{E}[(x - \mu)^2] = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$$

**Standard Deviation:** $\sigma = \sqrt{\text{Var}(x)}$

**Alternative variance formula:** $\text{Var}(x) = \text{E}[x^2] - \text{E}[x]^2$
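
One practical detail: the variance here divides by $d$ (the population variance), not $d - 1$. In NumPy, `np.var` uses `ddof=0` by default, which matches this definition; `ddof=1` gives the unbiased sample variance instead. The example vector below is illustrative.

```python
import numpy as np

x = np.array([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])
d = len(x)

mu = x.sum() / d
var_manual = ((x - mu) ** 2).sum() / d    # the 1/d definition used by normalization layers

# np.var divides by d by default (ddof=0), matching the formula above
var_numpy = np.var(x)
# ddof=1 divides by d - 1 (unbiased sample variance), which is NOT what LayerNorm uses
var_sample = np.var(x, ddof=1)

# Alternative formula: E[x^2] - E[x]^2
var_alt = (x ** 2).mean() - mu ** 2
print(mu, var_manual, var_numpy, var_sample, var_alt)
```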

### 2. The Chain Rule for Multivariate Functions

When a variable $x_i$ affects the output through multiple paths, we sum the gradients from each path:

$$\frac{\partial L}{\partial x_i} = \sum_{\text{all paths}} \frac{\partial L}{\partial (\text{intermediate})} \cdot \frac{\partial (\text{intermediate})}{\partial x_i}$$

For normalization, $x_i$ affects the output through:
1. The direct path: $\hat{x}_i$ depends on $x_i$
2. The mean (LayerNorm only): $\mu$ depends on all $x_j$
3. The variance/RMS: $\sigma$ or $r$ depends on all $x_j$
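
To see the multiple-paths rule on the simplest piece of LayerNorm, consider centering alone, $c_j = x_j - \mu$. The sketch below estimates one column of $\partial c / \partial x$ by finite differences and compares it with the chain-rule prediction: $1$ from the direct path (only when $j = i$) plus $-\frac{1}{d}$ from the mean path for every $j$. The function name `centered` and the input values are illustrative.

```python
import numpy as np

def centered(x):
    # c_j = x_j - mean(x): depends on x_i directly (j == i) and via the mean (all j)
    return x - x.mean()

x = np.array([0.3, -1.2, 2.0, 0.7])
d, i, h = len(x), 1, 1e-6

x_p, x_m = x.copy(), x.copy()
x_p[i] += h
x_m[i] -= h
# Column i of the Jacobian dc/dx, estimated by central differences
col = (centered(x_p) - centered(x_m)) / (2 * h)

# Chain rule: the mean path gives -1/d for every j; the direct path adds 1 at j == i
expected = -np.ones(d) / d
expected[i] += 1.0
print(col)
print(expected)
```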

### 3. Element-wise Operations

We use $\odot$ to denote element-wise (Hadamard) multiplication:
$$(a \odot b)_i = a_i \cdot b_i$$

This is different from matrix multiplication or dot products.
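
In NumPy, `*` is element-wise and `@` is the matrix/dot product. A short illustration with made-up vectors, including how a per-feature vector such as $\gamma$ broadcasts over a `(batch, sequence, features)` tensor:

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])

hadamard = a * b      # element-wise (a ⊙ b): [4., 10., 18.]
dot = a @ b           # dot product: a scalar, 4 + 10 + 18 = 32

# With a (batch, seq, d) tensor and a (d,) vector, * broadcasts over the leading
# axes, which is how a per-feature gamma scales every token at once
x_hat = np.ones((2, 3, 3))
scaled = x_hat * a    # shape (2, 3, 3); each token's features become [1., 2., 3.]
print(hadamard, dot, scaled.shape)
```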

### 4. Gradient Notation

We use bar notation for gradients:
- $\bar{y} = \frac{\partial L}{\partial y}$ (upstream gradient)
- $\bar{x} = \frac{\partial L}{\partial x}$ (gradient we're computing)
- $\bar{\hat{x}} = \bar{y} \odot \gamma$ (gradient into the normalized value)

```python
import numpy as np
import sys
sys.path.insert(0, '..')
from ai_comps.normalization import LayerNorm, RMSNorm
```

---
## Part 1: Layer Normalization

### 1.1 Motivation

Deep networks suffer from **internal covariate shift**: the distribution of layer inputs changes during training as earlier layers update. LayerNorm stabilizes training by normalizing activations to zero mean and unit variance *within each token's feature vector*.

Unlike BatchNorm (which normalizes across the batch dimension), LayerNorm normalizes across the feature dimension, making it suitable for:
- Variable-length sequences (no batch statistics needed)
- Small batch sizes
- Autoregressive models
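
The difference in normalization axis is easy to see in code. This sketch compares LayerNorm-style statistics (over the last axis) with training-mode BatchNorm statistics (over the batch axis), both without affine parameters; the helper names and random data are illustrative. A LayerNorm output for one sample does not depend on the other samples in the batch, while a BatchNorm output does.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 6))   # (batch, features); values are hypothetical
eps = 1e-5

def layer_normalize(x):
    # One (mu, var) pair per sample, computed over the feature axis
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def batch_normalize(x):
    # One (mu, var) pair per feature, computed over the batch axis (training mode, no affine)
    mu = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

# LayerNorm of sample 0 is identical whether it is normalized alone or inside a batch
ln_alone = layer_normalize(x[:1])
ln_batch = layer_normalize(x)[:1]
# BatchNorm of sample 0 changes when the rest of the batch changes
bn_batch = batch_normalize(x)[:1]
bn_other = batch_normalize(np.vstack([x[:1], x[1:] + 3.0]))[:1]
print(np.abs(ln_alone - ln_batch).max(), np.abs(bn_batch - bn_other).max())
```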

### 1.2 Forward Pass

Given input $\mathbf{x} \in \mathbb{R}^d$ (a single token's features), LayerNorm computes:

**Step 1: Compute mean**
$$\mu = \frac{1}{d} \sum_{i=1}^{d} x_i$$

**Step 2: Compute variance**
$$\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$$

**Step 3: Normalize**
$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

where $\epsilon \approx 10^{-5}$ prevents division by zero.

**Step 4: Scale and shift (learnable)**
$$y_i = \gamma_i \hat{x}_i + \beta_i$$

The parameters $\gamma, \beta \in \mathbb{R}^d$ let the network re-scale and re-shift each feature after normalization, so it is not forced to keep exactly zero mean and unit variance.

**Compact form:**
$$y = \gamma \odot \frac{x - \mu}{\sigma} + \beta$$

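where, from here on, $\sigma$ denotes the $\epsilon$-stabilized standard deviation $\sqrt{\sigma^2 + \epsilon}$ (with $\sigma^2$ the variance from Step 2), and $\odot$ denotes element-wise multiplication.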
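
The four steps vectorize directly over a `(batch, seq, d_model)` tensor by taking statistics along the last axis with `keepdims=True`, so that $\mu$ and $\sigma$ broadcast back over the features. This is a minimal sketch with a hypothetical `layernorm` helper, not the `ai_comps` implementation:

```python
import numpy as np

def layernorm(x, gamma, beta, eps=1e-5):
    """Minimal LayerNorm forward over the last axis (illustrative sketch)."""
    mu = x.mean(axis=-1, keepdims=True)                  # Step 1: shape (..., 1)
    var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)   # Step 2: population variance
    x_hat = (x - mu) / np.sqrt(var + eps)                # Step 3: normalize
    return gamma * x_hat + beta                          # Step 4: per-feature scale and shift

rng = np.random.default_rng(1)
x = rng.normal(loc=3.0, scale=5.0, size=(2, 4, 8))       # (batch, seq, d_model)
gamma, beta = np.ones(8), np.zeros(8)
y = layernorm(x, gamma, beta)
print(y.mean(axis=-1).round(6))   # ≈ 0 for every token
print(y.var(axis=-1).round(4))    # ≈ 1 for every token
```

With $\gamma = 2$ and $\beta = 1$ instead, each token's output would have mean about 1 and variance about 4, which is the affine step at work.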

```python
# Verify forward pass
np.random.seed(42)
d = 4
x = np.random.randn(d).astype(np.float32)

# Manual computation
eps = 1e-5
mu = x.mean()
var = ((x - mu) ** 2).mean()
sigma = np.sqrt(var + eps)
x_hat = (x - mu) / sigma

print(f"Input x:     {x}")
print(f"Mean μ:      {mu:.6f}")
print(f"Variance σ²: {var:.6f}")
print(f"Std σ:       {sigma:.6f}")
print(f"Normalized:  {x_hat}")
print(f"\nVerify: mean(x_hat) = {x_hat.mean():.6f} (should be ≈0)")
print(f"Verify: var(x_hat)  = {x_hat.var():.6f} (should be ≈1)")
```

### 1.3 Backward Pass Derivation

This is where it gets interesting. Given upstream gradient $\frac{\partial L}{\partial y}$ (denoted $\bar{y}$), we need:

1. $\frac{\partial L}{\partial \gamma}$ and $\frac{\partial L}{\partial \beta}$ (parameter gradients)
2. $\frac{\partial L}{\partial x}$ (gradient to propagate backward)

#### Parameter gradients (easy)

From $y_i = \gamma_i \hat{x}_i + \beta_i$:

$$\frac{\partial L}{\partial \gamma_i} = \sum_{\text{batch, seq}} \bar{y}_i \cdot \hat{x}_i$$

$$\frac{\partial L}{\partial \beta_i} = \sum_{\text{batch, seq}} \bar{y}_i$$
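
Because $\gamma$ and $\beta$ are shared by every token, their gradients accumulate over the batch and sequence axes, leaving one value per feature. The sketch below uses random stand-ins for $\hat{x}$ and $\bar{y}$ (treating $\hat{x}$ as fixed, since it does not depend on $\gamma$ or $\beta$) and checks both sums against finite differences of $L = \sum y \odot \bar{y}$, whose gradient with respect to $y$ is $\bar{y}$.

```python
import numpy as np

rng = np.random.default_rng(2)
B, T, d = 2, 3, 5
x_hat = rng.normal(size=(B, T, d))   # stand-in for normalized activations (hypothetical)
dy = rng.normal(size=(B, T, d))      # upstream gradient dL/dy
gamma = rng.normal(size=d)
beta = rng.normal(size=d)

# Shared parameters: sum over batch and sequence, keep the feature axis
dgamma = (dy * x_hat).sum(axis=(0, 1))   # shape (d,)
dbeta = dy.sum(axis=(0, 1))              # shape (d,)

def loss(g, b):
    # L = sum(y * dy) with y = g * x_hat + b
    return ((g * x_hat + b) * dy).sum()

# Central differences along each unit vector e in feature space
h = 1e-6
num_dgamma = np.array([(loss(gamma + h * e, beta) - loss(gamma - h * e, beta)) / (2 * h)
                       for e in np.eye(d)])
num_dbeta = np.array([(loss(gamma, beta + h * e) - loss(gamma, beta - h * e)) / (2 * h)
                      for e in np.eye(d)])
print(np.abs(dgamma - num_dgamma).max(), np.abs(dbeta - num_dbeta).max())
```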

#### Input gradient (tricky)

The challenge: $\hat{x}_i$ depends on ALL $x_j$ through $\mu$ and $\sigma$.

Let's define intermediate quantities:
- $\bar{\hat{x}}_i = \bar{y}_i \cdot \gamma_i$ (gradient into normalized value)

Using the chain rule through $\hat{x}_i = \frac{x_i - \mu}{\sigma}$:

$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial x_i} + \sum_j \frac{\partial L}{\partial \hat{x}_j}\frac{\partial \hat{x}_j}{\partial \mu}\frac{\partial \mu}{\partial x_i} + \sum_j \frac{\partial L}{\partial \hat{x}_j}\frac{\partial \hat{x}_j}{\partial \sigma}\frac{\partial \sigma}{\partial x_i}$$

#### Computing each term

**Direct term** (holding $\mu$ and $\sigma$ fixed): $\frac{\partial \hat{x}_i}{\partial x_i} = \frac{1}{\sigma}$

**Through mean:** $\frac{\partial \mu}{\partial x_i} = \frac{1}{d}$, and $\frac{\partial \hat{x}_j}{\partial \mu} = -\frac{1}{\sigma}$

**Through variance:** This requires care. Let $v = \sigma^2 = \frac{1}{d}\sum_j(x_j - \mu)^2$

$$\frac{\partial v}{\partial x_i} = \frac{2(x_i - \mu)}{d}$$

$$\frac{\partial \sigma}{\partial v} = \frac{1}{2\sigma}$$

$$\frac{\partial \hat{x}_j}{\partial \sigma} = -\frac{x_j - \mu}{\sigma^2} = -\frac{\hat{x}_j}{\sigma}$$

#### Putting it together

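Substituting each term into the chain rule, and using $x_i - \mu = \sigma \hat{x}_i$ so that $\frac{\partial \sigma}{\partial x_i} = \frac{\partial \sigma}{\partial v}\frac{\partial v}{\partial x_i} = \frac{x_i - \mu}{d\,\sigma} = \frac{\hat{x}_i}{d}$:

- Direct: $\frac{\bar{\hat{x}}_i}{\sigma}$
- Through the mean: $\sum_j \bar{\hat{x}}_j \cdot \left(-\frac{1}{\sigma}\right) \cdot \frac{1}{d} = -\frac{1}{d\sigma}\sum_j \bar{\hat{x}}_j$
- Through the variance: $\sum_j \bar{\hat{x}}_j \cdot \left(-\frac{\hat{x}_j}{\sigma}\right) \cdot \frac{\hat{x}_i}{d} = -\frac{\hat{x}_i}{d\sigma}\sum_j \bar{\hat{x}}_j \hat{x}_j$

(The mean has no extra effect through $\sigma$, because $\frac{\partial v}{\partial \mu} = -\frac{2}{d}\sum_j (x_j - \mu) = 0$.)

Summing the three terms gives the clean form: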

$$\frac{\partial L}{\partial x_i} = \frac{1}{\sigma}\left( \bar{\hat{x}}_i - \frac{1}{d}\sum_j \bar{\hat{x}}_j - \frac{\hat{x}_i}{d}\sum_j \bar{\hat{x}}_j \hat{x}_j \right)$$

In vector form:

$$\boxed{\frac{\partial L}{\partial x} = \frac{1}{\sigma}\left( \bar{\hat{x}} - \text{mean}(\bar{\hat{x}}) - \hat{x} \cdot \text{mean}(\bar{\hat{x}} \odot \hat{x}) \right)}$$

where $\bar{\hat{x}} = \bar{y} \odot \gamma$.
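
To connect the three chain-rule terms with the boxed result, the sketch below computes each path separately for a random vector and compares their sum with the compact form. Variable names are illustrative; `g` stands for $\bar{\hat{x}}$. As a side effect, the entries of $\partial L / \partial x$ sum to zero, reflecting that LayerNorm ignores a constant shift of $x$.

```python
import numpy as np

rng = np.random.default_rng(3)
d, eps = 6, 1e-5
x = rng.normal(size=d)
gamma = rng.normal(size=d)
dy = rng.normal(size=d)

mu = x.mean()
sigma = np.sqrt(((x - mu) ** 2).mean() + eps)   # sigma includes epsilon, as in the text
x_hat = (x - mu) / sigma
g = dy * gamma                                  # gradient into x_hat

# The three chain-rule paths, written out separately
direct = g / sigma                                     # mu and sigma held fixed
via_mean = -np.full(d, g.sum() / (d * sigma))          # (-1/sigma) * (1/d), summed over j
via_sigma = -x_hat * (g * x_hat).sum() / (d * sigma)   # uses dsigma/dx_i = x_hat_i / d
dx_paths = direct + via_mean + via_sigma

# Boxed compact form
dx_boxed = (g - g.mean() - x_hat * (g * x_hat).mean()) / sigma
print(np.abs(dx_paths - dx_boxed).max(), dx_boxed.sum())
```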

```python
# Verify backward pass with numerical gradient check
def layernorm_forward(x, gamma, beta, eps=1e-5):
    # Statistics over the feature (last) axis, so batched inputs also work
    mu = x.mean(axis=-1, keepdims=True)
    var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)
    sigma = np.sqrt(var + eps)
    x_hat = (x - mu) / sigma
    y = gamma * x_hat + beta
    return y, x_hat, sigma

def layernorm_backward(dy, x_hat, sigma, gamma):
    """Analytical backward pass."""
    d = dy.shape[-1]
    dx_hat = dy * gamma

    # The key formula
    m1 = dx_hat.mean(axis=-1, keepdims=True)            # mean(dx_hat)
    m2 = (dx_hat * x_hat).mean(axis=-1, keepdims=True)  # mean(dx_hat * x_hat)
    dx = (dx_hat - m1 - x_hat * m2) / sigma

    # gamma and beta are shared across positions: sum over every axis except
    # the feature axis (summing over features too would wrongly give a scalar)
    dgamma = (dy * x_hat).reshape(-1, d).sum(axis=0)
    dbeta = dy.reshape(-1, d).sum(axis=0)
    return dx, dgamma, dbeta

# Numerical gradient check
x = np.random.randn(4).astype(np.float64)  # float64 for precision
gamma = np.random.randn(4)  # non-trivial gamma, so the check also covers dx_hat = dy * gamma
beta = np.random.randn(4)

y, x_hat, sigma = layernorm_forward(x, gamma, beta)
dy = np.random.randn(4).astype(np.float64)  # upstream gradient

# Analytical gradient
dx_analytical, _, _ = layernorm_backward(dy, x_hat, sigma, gamma)

# Numerical gradient (central differences of L = sum(y * dy))
eps_num = 1e-5
dx_numerical = np.zeros_like(x)
for i in range(len(x)):
    x_plus = x.copy(); x_plus[i] += eps_num
    x_minus = x.copy(); x_minus[i] -= eps_num
    y_plus, _, _ = layernorm_forward(x_plus, gamma, beta)
    y_minus, _, _ = layernorm_forward(x_minus, gamma, beta)
    dx_numerical[i] = ((y_plus - y_minus) * dy).sum() / (2 * eps_num)

print("Analytical gradient: ", dx_analytical)
print("Numerical gradient:  ", dx_numerical)
print(f"Max difference: {np.abs(dx_analytical - dx_numerical).max():.2e}")
```

---
## Part 2: RMS Normalization

### 2.1 Motivation

RMSNorm (Zhang & Sennrich, 2019) simplifies LayerNorm by removing the mean-centering step. The hypothesis: **re-centering is not essential; re-scaling is what matters.**

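Benefits:
- Faster: Zhang & Sennrich (2019) report running-time reductions of roughly 7% to 64% across the models they tested (the mean subtraction is skipped)
- Fewer operations in the backward pass
- Accuracy comparable to LayerNorm in their experiments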

### 2.2 Forward Pass

Given input $\mathbf{x} \in \mathbb{R}^d$:

**Step 1: Compute RMS (Root Mean Square)**
$$\text{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}$$

**Step 2: Normalize**
$$\hat{x}_i = \frac{x_i}{\text{RMS}(x)}$$

**Step 3: Scale (learnable)**
$$y_i = \gamma_i \hat{x}_i$$

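Note: there is no $\beta$ (shift) parameter. RMSNorm drops re-centering entirely, so it also omits the learnable shift that LayerNorm pairs with it and keeps only the re-scaling parameter $\gamma$.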

**Compact form:**
$$y = \gamma \odot \frac{x}{\text{RMS}(x)}$$
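
A batched RMSNorm forward is even shorter than LayerNorm. The sketch below (hypothetical helper names, not the `ai_comps` API) also checks a useful identity: since $\text{RMS}(x - \mu)^2 = \text{Var}(x) + \epsilon$ when both use the same $\epsilon$, LayerNorm without $\gamma, \beta$ is exactly RMSNorm applied to the mean-centered input.

```python
import numpy as np

def rmsnorm(x, gamma, eps=1e-6):
    """Minimal RMSNorm forward over the last axis (illustrative sketch)."""
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)   # Step 1
    return gamma * (x / rms)                                      # Steps 2 and 3

def layernorm_no_affine(x, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(4)
x = rng.normal(loc=2.0, size=(3, 8))   # non-zero mean, to make the difference visible
ones = np.ones(8)

y_rms = rmsnorm(x, ones)
# LayerNorm (no affine) == RMSNorm of the centered input, with the same epsilon
y_ln = layernorm_no_affine(x)
y_rms_centered = rmsnorm(x - x.mean(axis=-1, keepdims=True), ones)
print(y_rms.mean(axis=-1))   # not zero: RMSNorm keeps the mean direction
```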

```python
# Verify RMSNorm forward
x = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
eps = 1e-6

# Manual computation
rms = np.sqrt((x ** 2).mean() + eps)
x_hat = x / rms

print(f"Input x:    {x}")
print(f"RMS:        {rms:.6f}")
print(f"Normalized: {x_hat}")
print(f"\nNote: mean(x_hat) = {x_hat.mean():.4f} (NOT zero, unlike LayerNorm)")
print(f"But: RMS(x_hat) = {np.sqrt((x_hat**2).mean()):.4f} (IS one)")
```

### 2.3 Backward Pass Derivation

The backward pass is simpler than LayerNorm because there's no mean to track.

#### Parameter gradient

From $y_i = \gamma_i \hat{x}_i$:

$$\frac{\partial L}{\partial \gamma_i} = \sum_{\text{batch, seq}} \bar{y}_i \cdot \hat{x}_i$$

#### Input gradient

Let $r = \text{RMS}(x) = \sqrt{\frac{1}{d}\sum_j x_j^2 + \epsilon}$

Then $\hat{x}_i = \frac{x_i}{r}$ and $y_i = \gamma_i \hat{x}_i$.

The gradient $\bar{x}_i = \frac{\partial L}{\partial x_i}$ has two paths:
1. Direct: through $\hat{x}_i = x_i / r$
2. Indirect: through $r$ which depends on all $x_j$

**Direct path:**
$$\frac{\partial \hat{x}_i}{\partial x_i}\bigg|_{r\text{ fixed}} = \frac{1}{r}$$

**Indirect path through $r$:**
$$\frac{\partial r}{\partial x_i} = \frac{x_i}{d \cdot r}$$

$$\frac{\partial \hat{x}_j}{\partial r} = -\frac{x_j}{r^2} = -\frac{\hat{x}_j}{r}$$

Combining with chain rule:
$$\frac{\partial L}{\partial x_i} = \frac{\bar{\hat{x}}_i}{r} - \frac{x_i}{d \cdot r^2} \sum_j \bar{\hat{x}}_j \hat{x}_j$$

where $\bar{\hat{x}}_i = \bar{y}_i \gamma_i$.

**Simplified form** (substituting $x_i / r = \hat{x}_i$ in the second term):
$$\boxed{\frac{\partial L}{\partial x} = \frac{1}{r}\left( \bar{\hat{x}} - \hat{x} \cdot \text{mean}(\bar{\hat{x}} \odot \hat{x}) \right)}$$

Compare to LayerNorm's backward:
$$\frac{\partial L}{\partial x} = \frac{1}{\sigma}\left( \bar{\hat{x}} - \text{mean}(\bar{\hat{x}}) - \hat{x} \cdot \text{mean}(\bar{\hat{x}} \odot \hat{x}) \right)$$

RMSNorm is missing the $- \text{mean}(\bar{\hat{x}})$ term (because no mean-centering in forward).
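
The missing $\text{mean}(\bar{\hat{x}})$ term has a geometric reading. With $\epsilon = 0$, RMSNorm's output is unchanged when $x$ is rescaled, so $\partial L / \partial x$ is orthogonal to $x$; LayerNorm is also unchanged by adding a constant to every entry, so its gradient is additionally orthogonal to $\mathbf{1}$ (its entries sum to zero). The sketch below checks both with random data, setting $\epsilon = 0$ as a simplifying assumption so the invariances are exact.

```python
import numpy as np

rng = np.random.default_rng(5)
d = 8
x = rng.normal(size=d)
gamma = rng.normal(size=d)
dy = rng.normal(size=d)
g = dy * gamma                    # gradient into x_hat

# RMSNorm backward with eps = 0 (output exactly invariant to rescaling x)
r = np.sqrt((x ** 2).mean())
x_hat_rms = x / r
dx_rms = (g - x_hat_rms * (g * x_hat_rms).mean()) / r

# LayerNorm backward with eps = 0 (invariant to rescaling AND to adding a constant)
mu = x.mean()
sigma = np.sqrt(((x - mu) ** 2).mean())
x_hat_ln = (x - mu) / sigma
dx_ln = (g - g.mean() - x_hat_ln * (g * x_hat_ln).mean()) / sigma

print("RMSNorm:   dx . x =", dx_rms @ x, " sum(dx) =", dx_rms.sum())
print("LayerNorm: dx . x =", dx_ln @ x, " sum(dx) =", dx_ln.sum())
```

Only LayerNorm's gradient sums to zero; RMSNorm's generally does not, which is exactly the $\text{mean}(\bar{\hat{x}})$ term it lacks.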

```python
# Verify RMSNorm backward pass
def rmsnorm_forward(x, gamma, eps=1e-6):
    # RMS over the feature (last) axis, so batched inputs also work
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    x_hat = x / rms
    y = gamma * x_hat
    return y, x_hat, rms

def rmsnorm_backward(dy, x_hat, rms, gamma):
    """Analytical backward pass."""
    d = dy.shape[-1]
    dx_hat = dy * gamma

    # The key formula (simpler than LayerNorm!)
    m = (dx_hat * x_hat).mean(axis=-1, keepdims=True)  # mean(dx_hat * x_hat)
    dx = (dx_hat - x_hat * m) / rms

    # gamma is shared across positions: sum over every axis except the feature axis
    dgamma = (dy * x_hat).reshape(-1, d).sum(axis=0)
    return dx, dgamma

# Numerical gradient check
x = np.random.randn(4).astype(np.float64)
gamma = np.random.randn(4)  # non-trivial gamma, so the check also covers dx_hat = dy * gamma

y, x_hat, rms = rmsnorm_forward(x, gamma)
dy = np.random.randn(4).astype(np.float64)

# Analytical gradient
dx_analytical, _ = rmsnorm_backward(dy, x_hat, rms, gamma)

# Numerical gradient (central differences of L = sum(y * dy))
eps_num = 1e-5
dx_numerical = np.zeros_like(x)
for i in range(len(x)):
    x_plus = x.copy(); x_plus[i] += eps_num
    x_minus = x.copy(); x_minus[i] -= eps_num
    y_plus, _, _ = rmsnorm_forward(x_plus, gamma)
    y_minus, _, _ = rmsnorm_forward(x_minus, gamma)
    dx_numerical[i] = ((y_plus - y_minus) * dy).sum() / (2 * eps_num)

print("Analytical gradient: ", dx_analytical)
print("Numerical gradient:  ", dx_numerical)
print(f"Max difference: {np.abs(dx_analytical - dx_numerical).max():.2e}")
```

---
## Part 3: Comparison and Insights

### Why does RMSNorm work without centering?

LayerNorm's centering ($x - \mu$) removes the DC component of the signal. But:

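1. **Transformers use residual connections**: in the common pre-norm layout, $x_{l+1} = x_l + f(\text{Norm}(x_l))$. Normalization only shapes the input to the sublayer $f$; the residual path carries $x_l$, including its mean, forward unchanged.

2. **No information is discarded**: because RMSNorm does not subtract $\mu$, any mean component survives in $\hat{x}$ (only rescaled), so the learnable $\gamma$ and later layers can use it if it matters.

3. **Empirically comparable**: Zhang & Sennrich (2019) reported performance comparable to LayerNorm across several tasks, including machine translation.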

### Computational comparison

```python
# Count operations (rough element-wise count; scalar ops like sqrt are ignored)
d = 512  # typical model dimension

print("Forward pass operations per vector:")
print(f"  LayerNorm: {d} (sum for μ) + {d} (sub) + {d} (sq) + {d} (sum for σ²) + {d} (div) + {d} (mul) + {d} (add) = ~{7*d}")
print(f"  RMSNorm:   {d} (sq) + {d} (sum) + {d} (div) + {d} (mul) = ~{4*d}")
print(f"\nLayerNorm needs ~{7/4:.2f}x as many operations as RMSNorm")
```

### Gradient flow comparison

Both normalizations help keep gradient magnitudes stable by controlling the scale of activations. Their Jacobians differ, however:

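**LayerNorm Jacobian** ($\partial y_i / \partial x_j$):
- Diagonal: $\frac{\gamma_i}{\sigma}\left(1 - \frac{1 + \hat{x}_i^2}{d}\right)$
- Off-diagonal: $-\frac{\gamma_i}{d\sigma}(1 + \hat{x}_i\hat{x}_j)$

**RMSNorm Jacobian:**
- Diagonal: $\frac{\gamma_i}{r}(1 - \frac{\hat{x}_i^2}{d})$
- Off-diagonal: $-\frac{\gamma_i \hat{x}_i \hat{x}_j}{d \cdot r}$

Neither Jacobian is full rank, and both are dense (off-diagonal entries are generally nonzero). LayerNorm's output is exactly unchanged by adding a constant to every $x_i$ and, as $\epsilon \to 0$, by rescaling $x$; so $\mathbf{1}$ and $\hat{x}$ lie in its null space and its rank is $d - 2$. RMSNorm is (as $\epsilon \to 0$) only invariant to rescaling, so it loses just the $x$ direction and has rank $d - 1$. Consequently, $\partial L / \partial x$ is always orthogonal to these directions.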
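
A numerical check helps pin down both the entries and the rank. This sketch assumes $\epsilon = 0$ so the invariances are exact, builds finite-difference Jacobians with a hypothetical `jacobian` helper, and compares them with the closed forms $\frac{\partial y_i}{\partial x_j} = \frac{\gamma_i}{\sigma}\left(\delta_{ij} - \frac{1 + \hat{x}_i\hat{x}_j}{d}\right)$ for LayerNorm and $\frac{\gamma_i}{r}\left(\delta_{ij} - \frac{\hat{x}_i\hat{x}_j}{d}\right)$ for RMSNorm. The $\gamma$ values are random but kept away from zero so they do not change the rank.

```python
import numpy as np

def layernorm(x, gamma, eps):
    mu = x.mean()
    return gamma * (x - mu) / np.sqrt(((x - mu) ** 2).mean() + eps)

def rmsnorm(x, gamma, eps):
    return gamma * x / np.sqrt((x ** 2).mean() + eps)

def jacobian(f, x, h=1e-6):
    # Column j holds the central-difference estimate of dy/dx_j
    cols = []
    for j in range(len(x)):
        e = np.zeros_like(x)
        e[j] = h
        cols.append((f(x + e) - f(x - e)) / (2 * h))
    return np.stack(cols, axis=1)

rng = np.random.default_rng(6)
d = 6
x = rng.normal(size=d)
gamma = rng.uniform(0.5, 1.5, size=d)   # hypothetical nonzero scales

J_ln = jacobian(lambda v: layernorm(v, gamma, 0.0), x)
J_rms = jacobian(lambda v: rmsnorm(v, gamma, 0.0), x)

# Closed forms: row i scaled by gamma_i / sigma (or gamma_i / r)
mu = x.mean()
sigma = np.sqrt(((x - mu) ** 2).mean())
xh = (x - mu) / sigma
J_ln_formula = (gamma / sigma)[:, None] * (np.eye(d) - (1 + np.outer(xh, xh)) / d)
r = np.sqrt((x ** 2).mean())
xr = x / r
J_rms_formula = (gamma / r)[:, None] * (np.eye(d) - np.outer(xr, xr) / d)

print("LayerNorm rank:", np.linalg.matrix_rank(J_ln_formula, tol=1e-8))
print("RMSNorm rank:  ", np.linalg.matrix_rank(J_rms_formula, tol=1e-8))
```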

---
## Part 4: Verify Against Implementation

```python
# Test our actual implementation
np.random.seed(123)
x = np.random.randn(2, 4, 8).astype(np.float32)  # (batch, seq, dim)
dy = np.random.randn(2, 4, 8).astype(np.float32)

# LayerNorm
ln = LayerNorm(d_model=8)
y_ln = ln.forward(x)
dx_ln = ln.backward(dy)

print("LayerNorm:")
print(f"  Input shape:  {x.shape}")
print(f"  Output shape: {y_ln.shape}")
print(f"  Output mean:  {y_ln.mean(axis=-1)[0,0]:.6f} (should be ≈0 while γ=1, β=0)")
print(f"  Output var:   {y_ln.var(axis=-1)[0,0]:.6f} (should be ≈1 while γ=1, β=0)")

# RMSNorm
rn = RMSNorm(d_model=8)
y_rn = rn.forward(x)
dx_rn = rn.backward(dy)

print("\nRMSNorm:")
print(f"  Input shape:  {x.shape}")
print(f"  Output shape: {y_rn.shape}")
print(f"  Output RMS:   {np.sqrt((y_rn**2).mean(axis=-1))[0,0]:.6f} (should be ≈1 while γ=1)")
```

---
## Summary

| Aspect | LayerNorm | RMSNorm |
|--------|-----------|----------|
| Forward | $\gamma \odot \frac{x-\mu}{\sigma} + \beta$ | $\gamma \odot \frac{x}{\text{RMS}(x)}$ |
| Parameters | $\gamma, \beta$ | $\gamma$ only |
| Centers mean? | Yes | No |
| Normalizes scale? | Yes (to unit var) | Yes (to unit RMS) |
| Backward complexity | Higher | Lower |
| Used in | BERT, GPT-2/3, Original Transformer | LLaMA, Gemma, T5 |

**Key insight**: The mean-centering in LayerNorm is mathematically elegant but practically unnecessary in deep residual networks.