# Layer Normalization (LayerNorm)

**What is LayerNorm?**  
Layer Normalization stabilizes and accelerates training by normalizing
the inputs across the **features** of each sample (not across the batch, as BatchNorm does).
It is especially important in **Transformers**, where it is applied before (Pre-LN) or after (Post-LN)
the attention and feed-forward sublayers.

---

## 🔹 How it works
For an input vector $ x \in \mathbb{R}^d $:

1. Compute the mean:
$$
\mu = \frac{1}{d} \sum_{i=1}^{d} x_i
$$

2. Compute the variance:
$$
\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2
$$

3. Normalize:
$$
\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}
$$

4. Scale and shift with learnable per-feature parameters $ \gamma, \beta \in \mathbb{R}^d $:
$$
y_i = \gamma_i \hat{x}_i + \beta_i
$$
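
The four steps above can be traced on a tiny input. This is a minimal sketch, assuming PyTorch is available; the vector `x`, the value of `eps`, and the identity initialization of `gamma` and `beta` are illustrative choices, not taken from any particular model. It compares the hand-computed result with `torch.nn.functional.layer_norm`.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 2.0, 3.0, 4.0])    # one sample with d = 4 features
eps = 1e-5

mu = x.mean()                              # step 1: mean over features
var = ((x - mu) ** 2).mean()               # step 2: biased variance (divide by d)
x_hat = (x - mu) / torch.sqrt(var + eps)   # step 3: normalize

gamma = torch.ones(4)                      # step 4: per-feature scale (identity here)
beta = torch.zeros(4)                      #         per-feature shift (zero here)
y = gamma * x_hat + beta

# Reference: PyTorch's functional LayerNorm over the last dimension
y_ref = F.layer_norm(x, normalized_shape=(4,), weight=gamma, bias=beta, eps=eps)

print("mu =", mu.item(), "var =", var.item())
print("matches F.layer_norm:", torch.allclose(y, y_ref, atol=1e-6))
```

Note that the variance divides by $d$, not $d-1$, which is why the sketch uses a plain mean of squared deviations.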

---

## 🔹 Key Properties
- Works **independently per sample** → does not depend on batch size.  
- Normalizes across **features (d)**, not across the batch.  
- Standard in **Transformer layers**, where it helps stabilize training.  
- Parameters:  
  - $ \gamma $: learnable scale  
  - $ \beta $: learnable bias  
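
These properties can be checked directly. The sketch below assumes PyTorch; the sizes (3 samples, 8 features) and the scale and offset applied to the random input are arbitrary. It shows that each sample ends up with roughly zero mean and unit variance over its own features, and that $\gamma$ and $\beta$ hold one value per feature.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(3, 8) * 5 + 2              # 3 samples, d = 8 features, arbitrary scale/offset
ln = nn.LayerNorm(8)                       # gamma starts at ones, beta at zeros

y = ln(x)
row_mean = y.mean(dim=-1)                  # one value per sample
row_var = y.var(dim=-1, unbiased=False)    # biased variance, matching the formula

print("per-sample mean:", row_mean)
print("per-sample var: ", row_var)
print("gamma shape:", ln.weight.shape, "| beta shape:", ln.bias.shape)
```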

---

## 🔹 Why Transformers use LayerNorm
- Handles variable sequence lengths and small batch sizes (where BatchNorm's batch statistics become noisy or unreliable).  
- Keeps activations stable, allowing very deep architectures to train.  
- Essential for attention-based models (BERT, GPT, etc.).  

---
✅ In practice: Every Transformer block applies LayerNorm around self-attention and feed-forward layers.


# 🧾 BatchNorm vs LayerNorm: Why Each Fits Different Models

## 🔹 Why BatchNorm is better in CNNs
- **Spatial consistency across batch helps**: In images, each channel has similar statistics across a batch.  
  BatchNorm leverages this by normalizing per channel across the batch and pixels.  
- **Acts as regularization**: Randomness in batch statistics acts like noise injection → improves generalization.  
- **Efficiency**: Very efficient to compute in CNNs since it operates channel-wise over the batch.  
- **If we use LayerNorm in CNNs**: We lose the batch-level regularization effect and often see slower or worse convergence.
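
To make "per channel across the batch and pixels" concrete, here is a minimal sketch, assuming PyTorch and a dummy `(N, C, H, W)` tensor whose sizes are arbitrary. It reproduces the training-mode output of `nn.BatchNorm2d` by pooling one mean and one variance per channel over the batch and spatial axes.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8) * 2 + 1        # (N, C, H, W) dummy image batch
bn = nn.BatchNorm2d(3)                     # train mode by default: uses batch statistics

# Manual version: one mean/variance per channel, pooled over batch and pixels
mu = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
manual = (x - mu) / torch.sqrt(var + bn.eps)

bn_out = bn(x)                             # weight = 1 and bias = 0 at initialization
print("matches nn.BatchNorm2d:", torch.allclose(bn_out, manual, atol=1e-5))
print("number of channel statistics:", mu.flatten().shape[0])
```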

---

## 🔹 Why LayerNorm is better in Transformers
- **Batch statistics don’t fit sequences**: In NLP/Transformers, inputs have variable sequence lengths and padding.  
  BatchNorm would give unstable or misleading statistics.  
- **Small/variable batch sizes**: Transformers often train with small or inconsistent batch sizes (long sequences).  
  BatchNorm becomes unstable, but LayerNorm works even with `batch_size = 1`.  
- **Stability in deep models**: Transformers stack many layers (dozens, sometimes more than a hundred).  
  LayerNorm keeps each token's activations on a consistent scale, which helps avoid exploding/vanishing signals.  
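- **Autoregressive inference**: GPT-style models often decode **one token at a time**.  
  A single token gives no batch statistics to compute, so BatchNorm would have to fall back on running averages, while LayerNorm needs only the token's own features.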
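
The padding point can be illustrated with a small experiment. This is a sketch under simplifying assumptions: token vectors are stacked as rows and fed to `nn.BatchNorm1d` as if each row were a batch entry, and padding is modeled as all-zero rows. The names `tokens` and `padded` are made up for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)
D = 4
ln = nn.LayerNorm(D)
bn = nn.BatchNorm1d(D)                     # train mode: normalizes with batch statistics

tokens = torch.randn(3, D)                 # three "real" token vectors
padded = torch.cat([tokens, torch.zeros(5, D)])   # same tokens plus 5 zero padding rows

# LayerNorm: each row is normalized on its own, so padding rows change nothing
ln_same = torch.allclose(ln(tokens), ln(padded)[:3], atol=1e-6)
# BatchNorm: padding rows shift the shared mean/variance, so real tokens change too
bn_same = torch.allclose(bn(tokens), bn(padded)[:3], atol=1e-6)

print("LayerNorm output unchanged by padding:", ln_same)
print("BatchNorm output unchanged by padding:", bn_same)
```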

---

## Summary
- **BatchNorm** → Great for **CNNs** (images), where batch statistics are meaningful and helpful.  
- **LayerNorm** → Essential for **Transformers / RNNs** (sequences), where batch statistics are unstable or undefined.  

👉 That’s why:  
- Vision models (ResNet, VGG, etc.) rely on **BatchNorm**.  
- Transformer models (BERT, GPT, T5, etc.) rely on **LayerNorm**.


## 🧪 Using `nn.LayerNorm` like a Transformer

Transformers normalize **each token’s hidden vector** (dimension $D$) in tensors of shape `(N, T, D)`:
- `N`: batch size
- `T`: sequence length (tokens/time)
- `D`: model width (hidden size)

We therefore set `normalized_shape=D` so PyTorch normalizes across the **last dimension** only.
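Below is a demo with dummy data: a Pre-LN feed-forward sublayer (the arrangement common in modern Transformers), where LayerNorm is applied to the sublayer's input and the residual connection adds the un-normalized input back.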


In [1]:
import torch
from torch import nn

torch.manual_seed(0)

N, T, D = 2, 5, 16         # batch, sequence length, hidden dim
hidden = torch.randn(N, T, D)

ln = nn.LayerNorm(D, eps=1e-5)  # normalize per token vector of size D
ff = nn.Sequential(
    nn.Linear(D, 4*D),
    nn.GELU(),
    nn.Linear(4*D, D),
)

# Pre-LN Transformer-style sublayer: LayerNorm -> feed-forward -> residual add
def transformer_block(x):
    # LayerNorm per token
    h = ln(x)
    # Feed-forward sublayer
    h = ff(h)
    # Residual connection
    return x + h

out = transformer_block(hidden)
loss = out.pow(2).mean()
loss.backward()

print("out shape:", out.shape)
print("ln.weight.shape:", ln.weight.shape)
print("grad on ln.weight (first 5):", ln.weight.grad[:5])


out shape: torch.Size([2, 5, 16])
ln.weight.shape: torch.Size([16])
grad on ln.weight (first 5): tensor([-0.0033, -0.0077,  0.0009,  0.0047,  0.0165])
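
To see what `normalized_shape=D` changes, the sketch below contrasts it with `normalized_shape=(T, D)`, which pools statistics over the whole sequence. It assumes PyTorch; the second layer is included only for contrast and is not the Transformer convention.

```python
import torch
from torch import nn

torch.manual_seed(0)
N, T, D = 2, 5, 16
hidden = torch.randn(N, T, D) * 3 + 1

per_token = nn.LayerNorm(D)                # statistics over D only (Transformer convention)
per_sequence = nn.LayerNorm((T, D))        # statistics over T*D values (contrast only)

a = per_token(hidden)
b = per_sequence(hidden)

# Per-token means are close to 0 only when normalizing over the last dimension alone
print("max |per-token mean|, LayerNorm(D):     ", a.mean(dim=-1).abs().max().item())
print("max |per-token mean|, LayerNorm((T, D)):", b.mean(dim=-1).abs().max().item())
print("weight shapes:", per_token.weight.shape, per_sequence.weight.shape)
```

With `(T, D)` the learnable parameters also take shape `(T, D)`, which ties the layer to a fixed sequence length.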


# 📘 Tiny Transformer Block Demo with LayerNorm

This example shows how to implement a **minimal Transformer encoder block** in PyTorch using:

- **`nn.MultiheadAttention`** for self-attention  
- **Residual connections** around attention and feed-forward layers  
- **`nn.LayerNorm`** (Pre-LN style) for stability  
- **A simple classifier head** on top of the Transformer block  

The input is shaped `(N, T, D)`:
- `N`: batch size  
- `T`: sequence length (number of tokens)  
- `D`: embedding dimension  

We run a single training step on dummy data to demonstrate:
1. Forward and backward passes work.
2. LayerNorm is applied correctly.
3. The model can output logits for classification.

This mirrors the structure of real Transformers (e.g., BERT, GPT) but in a **tiny, easy-to-read version**.


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# PyTorch Transformer Block with LayerNorm — full working demo

# Define a tiny Transformer block
class SmallTransformerBlock(nn.Module):
    def __init__(self, d_model=32, nhead=4, dim_ff=64):
        super().__init__()
        # Multi-head self-attention
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)
        # Feed-forward block
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(inplace=True),
            nn.Linear(dim_ff, d_model),
        )
        # Two LayerNorms (pre-norm style)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Self-attention with residual
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h)
        x = x + attn_out

        # Feed-forward with residual
        h = self.ln2(x)
        ff_out = self.ff(h)
        x = x + ff_out
        return x

# Create dummy token batch
torch.manual_seed(0)
N, T, D = 8, 10, 32  # batch size, sequence length, hidden dim
X = torch.randn(N, T, D)
y = torch.randint(0, 5, (N,))  # dummy labels for 5 classes

# Wrap block into a simple classifier
class SmallTransformerClassifier(nn.Module):
    def __init__(self, d_model=32, nhead=4, dim_ff=64, num_classes=5):
        super().__init__()
        self.block = SmallTransformerBlock(d_model, nhead, dim_ff)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        x = self.block(x)
        # Take mean over sequence dimension as pooling
        x = x.mean(dim=1)
        logits = self.head(x)
        return logits

# Instantiate model, loss, optimizer
model = SmallTransformerClassifier(d_model=D, nhead=4, dim_ff=64, num_classes=5)
criterion = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

print(model)

# TRAINING mode
model.train()
logits_train = model(X)
loss = criterion(logits_train, y)
opt.zero_grad()
loss.backward()
opt.step()
print(f"\n[TRAIN] logits shape: {logits_train.shape}, loss: {loss.item():.4f}")

# EVAL/INFERENCE mode
model.eval()
with torch.no_grad():
    logits_eval = model(X)
print(f"[EVAL ] logits shape: {logits_eval.shape}")

# check a LayerNorm’s params
ln1 = model.block.ln1
print(f"\nLayerNorm weight shape: {ln1.weight.shape}")
print(f"LayerNorm bias   shape: {ln1.bias.shape}")


SmallTransformerClassifier(
  (block): SmallTransformerBlock(
    (attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
    )
    (ff): Sequential(
      (0): Linear(in_features=32, out_features=64, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=64, out_features=32, bias=True)
    )
    (ln1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (ln2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  )
  (head): Linear(in_features=32, out_features=5, bias=True)
)

[TRAIN] logits shape: torch.Size([8, 5]), loss: 1.6507
[EVAL ] logits shape: torch.Size([8, 5])

LayerNorm weight shape: torch.Size([32])
LayerNorm bias   shape: torch.Size([32])
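
The demo switches between `model.train()` and `model.eval()`. For LayerNorm itself the mode makes no difference, because it keeps no running statistics. The sketch below, assuming PyTorch, checks this against `nn.BatchNorm1d`; the helper `train_eval_match` is a hypothetical name introduced for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(8, 32)
ln = nn.LayerNorm(32)
bn = nn.BatchNorm1d(32)

def train_eval_match(layer, inp):
    """Return True if the layer gives the same output in train and eval mode."""
    layer.train()
    out_train = layer(inp)
    layer.eval()
    with torch.no_grad():
        out_eval = layer(inp)
    return torch.allclose(out_train, out_eval, atol=1e-6)

ln_buffers = [name for name, _ in ln.named_buffers()]
bn_buffers = [name for name, _ in bn.named_buffers()]
ln_match = train_eval_match(ln, x)
bn_match = train_eval_match(bn, x)

print("LayerNorm buffers:", ln_buffers)
print("BatchNorm buffers:", bn_buffers)
print("LayerNorm train == eval:", ln_match)
print("BatchNorm train == eval:", bn_match)
```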


## 🔧 From-Scratch LayerNorm: forward + manual backward

We normalize **per sample across features** (last dimension).  
Given $x\in\mathbb{R}^{\dots \times D}$:

- $\mu=\text{mean}(x,\text{dim}=-1)$,
- $\sigma^2=\text{mean}\big((x-\mu)^2,\text{dim}=-1\big)$,
- $\hat{x}=(x-\mu)/\sqrt{\sigma^2+\varepsilon}$,
- $y=\gamma\hat{x}+\beta$.

**Backward (per sample, across features $D$)**  
Let $g=\frac{\partial \mathcal{L}}{\partial y}$ and $m=D$.  
Then
$$
\begin{aligned}
\frac{\partial \mathcal{L}}{\partial \beta} &= \sum_{\text{lead}} g, \\
\frac{\partial \mathcal{L}}{\partial \gamma} &= \sum_{\text{lead}} (g\odot \hat{x}), \\
\text{with } q &= g \odot \gamma, \\
\frac{\partial \mathcal{L}}{\partial x}
&= \frac{1}{m}\cdot\frac{1}{\sqrt{\sigma^2+\varepsilon}}\Big(
m\,q \;-\; \sum_{\text{feat}} q \;-\; \hat{x}\,\sum_{\text{feat}} (q\odot \hat{x})
\Big),
\end{aligned}
$$
where $\sum_{\text{lead}}$ sums over all leading (non-feature) axes, leaving a vector of length $D$, and $\sum_{\text{feat}}$ sums over the **feature** axis (keep dimensions for broadcasting).
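
One consequence of the input-gradient formula is easy to test: since $\sum_i \hat{x}_i = 0$, the three terms cancel when summed over the feature axis, so $\sum_i \partial\mathcal{L}/\partial x_i = 0$ for every sample. Equivalently, adding a constant to all features of a sample leaves the output unchanged. The sketch below, assuming PyTorch, checks both with autograd; the cubic loss and the random $\gamma$, $\beta$ are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
D = 6
x = torch.randn(4, D, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(D, dtype=torch.float64)
beta = torch.randn(D, dtype=torch.float64)

y = F.layer_norm(x, (D,), weight=gamma, bias=beta, eps=1e-5)
loss = (y ** 3).sum()                      # arbitrary scalar loss for illustration
(dx,) = torch.autograd.grad(loss, x)

# Property 1: the input gradient sums to zero over the feature axis
grad_sum = dx.sum(dim=-1).abs().max().item()

# Property 2: shifting every feature of a sample by a constant does not change y
y_shifted = F.layer_norm(x + 3.0, (D,), weight=gamma, bias=beta, eps=1e-5)
shift_invariant = torch.allclose(y, y_shifted, atol=1e-9)

print("max |sum_i dx_i| =", grad_sum)
print("shift invariant:", shift_invariant)
```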

Below is a minimal, transparent implementation plus a small gradient check against PyTorch autograd.


In [4]:
class LayerNormScratch:
    """LayerNorm over the last dimension only (… , D)."""
    def __init__(self, D, eps=1e-5, device="cpu", dtype=torch.float32):
        self.D = D
        self.eps = eps
        self.gamma = torch.ones(D, device=device, dtype=dtype)
        self.beta  = torch.zeros(D, device=device, dtype=dtype)
        self.dgamma = torch.zeros_like(self.gamma)
        self.dbeta  = torch.zeros_like(self.beta)
        self.cache = None

    def forward(self, x):
        mu  = x.mean(dim=-1, keepdim=True)
        var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)
        std = torch.sqrt(var + self.eps)
        xhat = (x - mu) / std
        y = xhat * self.gamma + self.beta
        self.cache = (xhat, std, x)
        return y

    def backward(self, dout):
        xhat, std, x = self.cache
        m = x.shape[-1]                       # number of features
        # param grads (sum over all non-feature dims)
        lead_axes = tuple(range(dout.dim()-1))
        self.dgamma = (dout * xhat).sum(dim=lead_axes)
        self.dbeta  = dout.sum(dim=lead_axes)
        # input grad
        q = dout * self.gamma
        q_sum = q.sum(dim=-1, keepdim=True)
        qxhat_sum = (q * xhat).sum(dim=-1, keepdim=True)
        dx = (1.0 / m) * (1.0 / std) * (m * q - q_sum - xhat * qxhat_sum)
        return dx


In [5]:
torch.manual_seed(0)

# 2D check
N, D = 4, 6
x = torch.randn(N, D, requires_grad=True)
ln = LayerNormScratch(D)

y = ln.forward(x)
loss = (y**2).mean()
dout = torch.autograd.grad(loss, y, create_graph=True)[0]
dx_manual = ln.backward(dout)

# autograd reference
gamma = ln.gamma.detach().clone().requires_grad_(True)
beta  = ln.beta.detach().clone().requires_grad_(True)
mu = x.mean(dim=-1, keepdim=True)
var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)
std = torch.sqrt(var + ln.eps)
xhat = (x - mu) / std
y_ref = xhat * gamma + beta
loss_ref = (y_ref**2).mean()
dx_auto, dgamma_auto, dbeta_auto = torch.autograd.grad(loss_ref, (x, gamma, beta))

print("[2D] max |dx diff|   =", (dx_manual - dx_auto).abs().max().item())
print("[2D] max |dgamma diff| =", (ln.dgamma - dgamma_auto).abs().max().item())
print("[2D] max |dbeta diff|  =", (ln.dbeta  - dbeta_auto ).abs().max().item())

# 3D check
N, T, D = 2, 5, 16
x3 = torch.randn(N, T, D, requires_grad=True)
ln3 = LayerNormScratch(D)

y3 = ln3.forward(x3)
loss3 = (y3**2).mean()
dout3 = torch.autograd.grad(loss3, y3, create_graph=True)[0]
dx3_manual = ln3.backward(dout3)

gamma3 = ln3.gamma.detach().clone().requires_grad_(True)
beta3  = ln3.beta.detach().clone().requires_grad_(True)
mu3 = x3.mean(dim=-1, keepdim=True)
var3 = ((x3 - mu3) ** 2).mean(dim=-1, keepdim=True)
std3 = torch.sqrt(var3 + ln3.eps)
xhat3 = (x3 - mu3) / std3
y3_ref = xhat3 * gamma3 + beta3
loss3_ref = (y3_ref**2).mean()
dx3_auto, dgamma3_auto, dbeta3_auto = torch.autograd.grad(loss3_ref, (x3, gamma3, beta3))

print("[3D] max |dx diff|   =", (dx3_manual - dx3_auto).abs().max().item())
print("[3D] max |dgamma diff| =", (ln3.dgamma - dgamma3_auto).abs().max().item())
print("[3D] max |dbeta diff|  =", (ln3.dbeta  - dbeta3_auto ).abs().max().item())

# ------------ Quick nn.LayerNorm demo ------------
hidden = torch.randn(N, T, D)
ln_pt = nn.LayerNorm(D)
out = ln_pt(hidden)
print("nn.LayerNorm out shape:", out.shape)


[2D] max |dx diff|   = 1.923558556882199e-08
[2D] max |dgamma diff| = 0.0
[2D] max |dbeta diff|  = 0.0
[3D] max |dx diff|   = 5.024730853619985e-09
[3D] max |dgamma diff| = 0.0
[3D] max |dbeta diff|  = 0.0
nn.LayerNorm out shape: torch.Size([2, 5, 16])
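
As a complementary check, the same backward formulas can be wrapped in a `torch.autograd.Function` and verified numerically with `torch.autograd.gradcheck`, which compares analytic gradients against finite differences. This is a sketch, not part of the demo above: the class name `LayerNormFn` is hypothetical, and double precision is used because `gradcheck` expects it.

```python
import torch

class LayerNormFn(torch.autograd.Function):
    """Hypothetical wrapper: LayerNorm over the last dim with a hand-written backward."""

    @staticmethod
    def forward(ctx, x, gamma, beta, eps):
        mu = x.mean(dim=-1, keepdim=True)
        var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)
        std = torch.sqrt(var + eps)
        xhat = (x - mu) / std
        ctx.save_for_backward(xhat, std, gamma)
        return xhat * gamma + beta

    @staticmethod
    def backward(ctx, dout):
        xhat, std, gamma = ctx.saved_tensors
        m = xhat.shape[-1]                             # number of features
        lead_axes = tuple(range(dout.dim() - 1))       # all non-feature dims
        dgamma = (dout * xhat).sum(dim=lead_axes)
        dbeta = dout.sum(dim=lead_axes)
        q = dout * gamma
        q_sum = q.sum(dim=-1, keepdim=True)
        qxhat_sum = (q * xhat).sum(dim=-1, keepdim=True)
        dx = (m * q - q_sum - xhat * qxhat_sum) / (m * std)
        return dx, dgamma, dbeta, None                 # no gradient for eps

torch.manual_seed(0)
x = torch.randn(2, 3, 5, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(5, dtype=torch.float64, requires_grad=True)
beta = torch.randn(5, dtype=torch.float64, requires_grad=True)

# gradcheck perturbs each input entry and compares against the analytic backward
ok = torch.autograd.gradcheck(
    lambda a, g, b: LayerNormFn.apply(a, g, b, 1e-5), (x, gamma, beta)
)
print("gradcheck passed:", ok)
```

Unlike the comparison against an autograd re-implementation, finite differences do not depend on the forward pass being written the same way twice, so this catches a different class of mistakes.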
