## Normalization

Normalization is a common strategy for mapping values with different ranges (distributions) onto a common scale.

In deep learning, normalization can:
- mitigate vanishing gradients
- speed up model convergence
- improve generalization

Depending on the scenario and network architecture, there are several normalization techniques:
1. Batch Norm
2. Layer Norm
3. Instance Norm
4. Group Norm
5. RMS Norm
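
As a rough map of how these techniques differ, the sketch below normalizes the same `[B, C, H, W]` tensor over the axes that each method reduces across. The helper `standardize` is a hypothetical name introduced only for illustration, and the learnable scale and shift parameters are left out.

```python
import torch

def standardize(x: torch.Tensor, dims: tuple, eps: float = 1e-5) -> torch.Tensor:
    """Zero-mean, unit-variance normalization over the given dims (no affine step)."""
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps)

torch.manual_seed(0)
x = torch.randn(8, 6, 4, 4) * 3.0 + 2.0  # [B, C, H, W], deliberately not standardized

batch_norm_like = standardize(x, (0, 2, 3))   # per channel, across the batch
layer_norm_like = standardize(x, (1, 2, 3))   # per sample, across all channels
instance_norm_like = standardize(x, (2, 3))   # per sample and per channel

# Group Norm: split C into G groups, then normalize each (sample, group) slice
G = 2
B, C, H, W = x.shape
group_norm_like = standardize(x.view(B, G, C // G, H, W), (2, 3, 4)).view(B, C, H, W)
```

The only thing that changes between the variants is the set of dimensions the statistics are computed over.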

In [1]:
import torch
import torch.nn as nn
import math
from einops import rearrange

### Batch Norm (BN)

**Intuition**: For each channel, compute the mean and variance across the batch (and spatial positions) before the activation function, and normalize the values to mean 0 and variance 1 (followed by a learnable scale and shift).

**Background**: When training with mini-batch gradient descent, each batch has a different distribution, so the input distribution of every network layer keeps shifting during training (*Internal Covariate Shift*, ICS). This makes training harder, slows down convergence, and can contribute to vanishing or exploding gradients.

BN **Effect**:
- Mitigates internal covariate shift and helps prevent vanishing / exploding gradients
- Allows a larger learning rate
- Reduces sensitivity to weight initialization
- Adds a mild regularization effect on the hidden layers' outputs

Potential **problems**:
- *Small batch size*: the batch mean / variance are noisy estimates of the statistics of the entire data distribution
- *Large batch size*: memory cost increases
- Not well suited to *variable-length inputs* (such as sequences)
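
To make the small-batch problem concrete, here is a minimal sketch that measures how much the per-batch mean fluctuates for two batch sizes. The synthetic `data` tensor and the helper `batch_mean_spread` are assumptions made up for this illustration, not part of any library.

```python
import torch

torch.manual_seed(0)
# Hypothetical "dataset": 10,000 samples of one feature with true mean 0 and variance 1
data = torch.randn(10_000)

def batch_mean_spread(batch_size: int, num_batches: int = 500) -> float:
    """Standard deviation of the per-batch mean across many random batches."""
    idx = torch.randint(0, data.numel(), (num_batches, batch_size))
    batch_means = data[idx].mean(dim=1)
    return batch_means.std().item()

spread_small = batch_mean_spread(batch_size=2)
spread_large = batch_mean_spread(batch_size=128)
print(f"batch size 2:   std of batch means = {spread_small:.3f}")
print(f"batch size 128: std of batch means = {spread_large:.3f}")
```

The spread of the batch mean shrinks roughly as $1/\sqrt{B}$, so very small batches give BN much noisier statistics to normalize with.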

![batch-norm-illustration](../assets/imgs/batch-norm-illustration.png)

#### Calculation

Assume the output of a layer has shape `[B, C, H, W]`, where `B` is the batch size, `C` the number of channels, and `H` and `W` the height and width of the feature map. We compute the mean and variance for each channel across the batch and spatial dimensions, i.e., the statistics are taken over all pixels of all samples within the same channel:

1. **Calculate mean** $\mu_c$: $$\mu_c = \frac{1}{B \times H \times W} \sum_{b=1}^{B} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{b,c,h,w}$$
2. **Calculate variance** $\sigma_c^2$: $$\sigma_c^2 = \frac{1}{B \times H \times W} \sum_{b=1}^{B} \sum_{h=1}^{H} \sum_{w=1}^{W} (x_{b,c,h,w} - \mu_c)^2$$
3. **Normalization**: $$\hat{x}_{b,c,h,w} = \frac{x_{b,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}$$
4. **Scale and shift**: $$y_{b,c,h,w} = \gamma_c \hat{x}_{b,c,h,w} + \beta_c$$
Where $\gamma_c$ and $\beta_c$ are learnable parameters (initially set to $\gamma_c = 1, \beta_c = 0$).
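
The four steps above, including the scale and shift, can be sketched as follows. The $\gamma_c$ and $\beta_c$ values are arbitrary numbers picked for illustration; in practice they are learned.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5)  # [B, C, H, W]

bn = nn.BatchNorm2d(3)  # affine=True by default: weight is gamma, bias is beta
with torch.no_grad():
    # Arbitrary values chosen for illustration only
    bn.weight.copy_(torch.tensor([1.0, 2.0, 0.5]))
    bn.bias.copy_(torch.tensor([0.0, -1.0, 3.0]))

y = bn(x)  # module is in training mode, so batch statistics are used

# Manual version of steps 1-4
mu = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), keepdim=True, unbiased=False)
x_hat = (x - mu) / torch.sqrt(var + bn.eps)
gamma = bn.weight.view(1, -1, 1, 1)  # broadcast over B, H, W
beta = bn.bias.view(1, -1, 1, 1)
y_manual = gamma * x_hat + beta
```

After the affine step, each output channel has mean $\beta_c$ and standard deviation close to $|\gamma_c|$ instead of 0 and 1.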

Since Batch Normalization is most commonly used with image inputs, I chose `BatchNorm2d` here for demonstration.

In [2]:
batch_size, num_features, height, width = 20, 100, 35, 45
test_input = torch.randn(batch_size, num_features, height, width)

In [3]:
# Torch calculation
# Without learnable parameters; a new module starts in training mode, so batch statistics are used
m = nn.BatchNorm2d(num_features, affine=False)
torch_output = m(test_input)

In [4]:
# Manual calculation
eps = 1e-5 # default eps of nn.BatchNorm2d
mu_c = test_input.mean(dim=(0,2,3), keepdim=True) # Over batch, H, W dimensions
var_c = test_input.var(dim=(0,2,3), keepdim=True, unbiased=False) # Use population (biased) variance
manual_output = (test_input - mu_c) / torch.sqrt(var_c + eps)

In [5]:
torch.allclose(torch_output, manual_output, atol=1e-5)

True
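
The comparison above holds in training mode, where the statistics of the current batch are used. As a complementary sketch, the block below looks at evaluation mode, where `BatchNorm2d` normalizes with the running statistics it accumulated during training (momentum 0.1 by default, with the running variance tracking the unbiased estimate). The input scale and shift are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(16, 3, 8, 8) * 2.0 + 5.0  # [B, C, H, W] with mean ~5, std ~2

bn = nn.BatchNorm2d(3, affine=False)  # running_mean starts at 0, running_var at 1
bn.train()
_ = bn(x)  # one forward pass in training mode updates the running statistics

batch_mean = x.mean(dim=(0, 2, 3))
batch_var_unbiased = x.var(dim=(0, 2, 3), unbiased=True)
expected_running_mean = 0.9 * torch.zeros(3) + 0.1 * batch_mean
expected_running_var = 0.9 * torch.ones(3) + 0.1 * batch_var_unbiased

bn.eval()
eval_output = bn(x)  # normalizes with the running statistics, not the batch statistics
manual_eval = (x - bn.running_mean.view(1, -1, 1, 1)) / torch.sqrt(
    bn.running_var.view(1, -1, 1, 1) + bn.eps
)
```

After a single update the running statistics are still far from the batch statistics, so the evaluation output here is not yet zero-mean.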

### Layer Norm (LN)

For sequence-based tasks, the batch size can be very small (even 1) and the input length varies, so Batch Normalization does not work well. **Layer Norm** was proposed to normalize over the features of a single instance.

The mean and standard-deviation are calculated over the last D dimensions, where D is the dimension of `normalized_shape`. For example, if `normalized_shape` is `(3, 5)` (a 2-dimensional shape), the mean and standard-deviation are computed over the last 2 dimensions of the input (i.e. `input.mean((-2, -1))`).
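
A minimal sketch of the `normalized_shape = (3, 5)` case described above, assuming an input with two arbitrary leading dimensions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 4, 3, 5)  # any leading dims, followed by the (3, 5) block to normalize

ln = nn.LayerNorm((3, 5), elementwise_affine=False)
y = ln(x)

# Same statistics computed by hand over the last two dimensions
mu = x.mean(dim=(-2, -1), keepdim=True)                  # shape: (2, 4, 1, 1)
var = x.var(dim=(-2, -1), keepdim=True, unbiased=False)  # shape: (2, 4, 1, 1)
y_manual = (x - mu) / torch.sqrt(var + ln.eps)
```

Each of the 2 × 4 leading positions gets its own mean and variance, computed from its 15 values.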

![layer-norm-illustration](../assets/imgs/layer-norm-illustration.png)

$$\mathbf{y} = \frac{\mathbf{x} - \text{E}(\mathbf{x})}{\sqrt{\text{Var}(\mathbf{x}) + \varepsilon}} * \gamma + \beta$$

Characteristics of LN:
- **Not dependent on Batch**
- Works well for sequence models (recurrent neural networks, Transformers), since its statistics are independent of batch size and sequence length
- Usually not as good as Batch Normalization on CNNs: because all channels are included in the same statistics, some channel-specific (local) features may be lost.
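
The batch independence can be checked directly. In this sketch, one sample is normalized once as part of a batch and once on its own; the tensor shapes are arbitrary, and Batch Norm is run in training mode for contrast.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch = torch.randn(4, 5, 10)  # [batch, seq_len, embedding_dim]
single = batch[:1]             # the first sample on its own

ln = nn.LayerNorm(10)
ln_in_batch = ln(batch)[:1]
ln_alone = ln(single)

# BatchNorm1d expects [batch, channels, length], so move embedding_dim to dim 1
bn = nn.BatchNorm1d(10, affine=False)  # training mode: uses statistics of the current batch
bn_in_batch = bn(batch.transpose(1, 2))[:1]
bn_alone = bn(single.transpose(1, 2))

ln_matches = torch.allclose(ln_in_batch, ln_alone, atol=1e-6)
bn_matches = torch.allclose(bn_in_batch, bn_alone, atol=1e-6)
print(ln_matches, bn_matches)
```

Layer Norm gives the same result for the sample either way, whereas the Batch Norm output changes with whatever else happens to be in the batch.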

In [6]:
# NLP Example
batch, seq_len, embedding_dim = 4, 5, 10
embedding = torch.randn(batch, seq_len, embedding_dim)
layer_norm = nn.LayerNorm(embedding_dim)
# Apply the module (gamma=1 and beta=0 at initialization, so the affine step is an identity here)
torch_output = layer_norm(embedding)

In [7]:
# Manual Layer Normalization
eps = 1e-5
# Compute mean and variance across the last dimension (embedding_dim)
mu = embedding.mean(dim=-1, keepdim=True)  # Shape: (4, 5, 1)
var = embedding.var(dim=-1, keepdim=True, unbiased=False)  # Shape: (4, 5, 1)
manual_output = (embedding - mu) / torch.sqrt(var + eps)

In [8]:
torch.allclose(torch_output, manual_output, atol=1e-5)

True

In [9]:
# Image Example
B, C, H, W = 20, 5, 10, 10
test_input = torch.randn(B, C, H, W)
# Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
layer_norm = nn.LayerNorm([C, H, W], elementwise_affine=False) # Remove learnable params
torch_output = layer_norm(test_input)

In [10]:
# Manual Layer Normalization
eps = 1e-5
mu = test_input.mean(dim=(1,2,3), keepdim=True)   # Shape: (B, 1, 1, 1), over channel,HW dimensions
var = test_input.var(dim=(1,2,3), keepdim=True, unbiased=False)  
manual_output = (test_input - mu) / torch.sqrt(var + eps)

In [11]:
torch.allclose(torch_output, manual_output, atol=1e-5)

True

### Instance Norm

Instance normalization was proposed for the *image style transfer* task - [paper](https://arxiv.org/abs/1607.08022).

In style transfer networks, each image's style is often only affected by that single image's feature distribution. Batch Normalization's "cross-sample" normalization cannot preserve the style differences of individual images. Therefore, normalization needs to be **performed separately on each image's feature map**, operating on the H×W dimensions while preserving the batch and channel dimensions.

![instance_norm](../assets/imgs/instance-norm-illustration.png)

$$\begin{align*}
\mu_{n,c} &= \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{n,c,h,w}\\
\sigma^2_{n,c} &= \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} (x_{n,c,h,w} - \mu_{n,c})^2\\
y_{n,c,h,w} &= \frac{x_{n,c,h,w} - \mu_{n,c}}{\sqrt{\sigma^2_{n,c} + \varepsilon}}
\end{align*}
$$

Characteristics of IN:
- Normalizes each sample and channel separately, preserving inter-sample and inter-channel differences
- Well suited to style transfer
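
One way to see the relationship to Batch Norm is to fold the batch into the channel axis: every (sample, channel) pair then becomes its own channel of a single-sample batch, and batch normalization with batch statistics produces the same result. This is a sketch with arbitrary shapes, using `torch.nn.functional.batch_norm` without running statistics.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 4, 3, 6, 6
x = torch.randn(B, C, H, W)

instance_out = nn.InstanceNorm2d(C, affine=False)(x)

# Treat every (sample, channel) pair as its own channel of a single-sample batch,
# then apply batch normalization with batch statistics (training=True)
folded = x.reshape(1, B * C, H, W)
folded_out = F.batch_norm(folded, running_mean=None, running_var=None, training=True, eps=1e-5)
batch_norm_view = folded_out.reshape(B, C, H, W)
```

Because no statistics are shared between samples, each image keeps its own per-channel normalization regardless of what else is in the batch.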

In [12]:
batch_size, num_features, height, width = 20, 100, 35, 45
test_input = torch.randn(batch_size, num_features, height, width)

In [13]:
# Torch calculation
# Without Learnable Parameters
m = nn.InstanceNorm2d(num_features, affine=False)
torch_output = m(test_input)

In [14]:
# Manual calculation
eps = 1e-5 # default eps of nn.InstanceNorm2d
mu_c = test_input.mean(dim=(2,3), keepdim=True) # Over H, W dimensions
var_c = test_input.var(dim=(2,3), keepdim=True, unbiased=False) # Use population (biased) variance
manual_output = (test_input - mu_c) / torch.sqrt(var_c + eps)

In [15]:
torch.allclose(torch_output, manual_output, atol=1e-5)

True

### Group Norm

Group normalization was proposed to address Batch Normalization's dependence on the *batch size*. When the batch size is small, BN's estimated mean and variance become inaccurate.

GN **groups channels** and then applies a Layer Norm-style normalization within each group.

![group-norm-illustration.png](../assets/imgs/group-norm-illustration.png)


$$\mu_{n,g} = \frac{1}{(C/G) H W} \sum_{c=g\frac{C}{G}}^{(g+1)\frac{C}{G}-1} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{n,c,h,w}$$

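$$\sigma_{n,g} = \sqrt{\frac{1}{(C/G) H W} \sum_{c=g\frac{C}{G}}^{(g+1)\frac{C}{G}-1} \sum_{h=1}^{H} \sum_{w=1}^{W} (x_{n,c,h,w} - \mu_{n,g})^2 + \varepsilon}$$

$$y_{n,c,h,w} = \frac{x_{n,c,h,w} - \mu_{n,g}}{\sigma_{n,g}}$$

Here $g \in \{0, \dots, G-1\}$ is the group that channel $c$ belongs to (channels are indexed from 0). GN preserves more channel structure than Layer Norm, while not relying on the batch size the way Batch Norm does.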
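
Two limiting cases help place Group Norm between the other methods: with a single group it reduces to Layer Norm over `(C, H, W)`, and with one channel per group it reduces to Instance Norm. A minimal check, with shapes chosen arbitrarily:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 4, 8, 5, 5
x = torch.randn(B, C, H, W)

# G = 1: one group holding all channels, i.e. Layer Norm over (C, H, W)
gn_one_group = nn.GroupNorm(1, C, affine=False)(x)
ln = nn.LayerNorm([C, H, W], elementwise_affine=False)(x)

# G = C: one channel per group, i.e. Instance Norm
gn_per_channel = nn.GroupNorm(C, C, affine=False)(x)
inn = nn.InstanceNorm2d(C, affine=False)(x)
```

Choosing a group count between 1 and `C` trades off between these two behaviors.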

In [16]:
batch_size, num_features, height, width = 20, 100, 35, 45
test_input = torch.randn(batch_size, num_features, height, width)
num_groups = 4
m = nn.GroupNorm(num_groups, num_features, affine=False)
torch_output = m(test_input)

In [17]:
eps = 1e-5 # default eps of nn.GroupNorm
assert num_features % num_groups == 0, f"Channels {num_features} must be divisible by num_groups {num_groups}"
channel_per_group = num_features // num_groups

# einops in place of torch.view: split channels into (groups, channels per group)
test_input_grouped = rearrange(test_input, 'b (ng gc) h w -> b ng gc h w', ng=num_groups, gc=channel_per_group)

mu_c = test_input_grouped.mean(dim=(2,3,4), keepdim=True) # Over the channels in each group and the H, W dimensions
var_c = test_input_grouped.var(dim=(2,3,4), keepdim=True, unbiased=False) # Use population (biased) variance
normalized_output = (test_input_grouped - mu_c) / torch.sqrt(var_c + eps)
manual_output = rearrange(normalized_output, 'b ng gc h w -> b (ng gc) h w')

In [18]:
torch.allclose(torch_output, manual_output, atol=1e-5)

True

### RMS Norm

RMS Norm (Root Mean Square Layer Normalization) is a simplification of Layer Norm: it **drops the mean subtraction** and rescales the vector by its root mean square only. In some large models (such as certain large language models), RMS Norm has been *reported to reduce training instability and to converge as well as or faster than Layer Norm*.

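$$
\text{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n}\sum_{i=1}^n a_i^2}, \qquad \bar{a}_i = \frac{a_i}{\text{RMS}(\mathbf{a})} g_i
$$

where $g_i$ is a learnable gain (initially 1). In practice a small $\varepsilon$ is added under the square root for numerical stability.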

- **Advantages**: Removes mean operations, reduces computation; provides better stability in some situations.
- **Disadvantages**: Compared to LN, RMS Norm lacks the "centering" process, which may impact feature distribution.
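
The lack of centering can be seen in a small sketch. `SimpleRMSNorm` below is a hypothetical minimal module written for this illustration (normalizing over the last dimension with a learnable gain, and an `eps` value chosen arbitrarily); it is not the library implementation.

```python
import torch
import torch.nn as nn

class SimpleRMSNorm(nn.Module):
    """Hypothetical minimal RMS Norm over the last dimension with a learnable gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(dim))  # no bias and no mean subtraction

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps) * self.gain

torch.manual_seed(0)
x = torch.randn(4, 5, 16) + 3.0  # [batch, seq_len, embedding_dim], non-zero mean on purpose

rms_out = SimpleRMSNorm(16)(x)
ln_out = nn.LayerNorm(16, elementwise_affine=False)(x)

out_rms = rms_out.pow(2).mean(dim=-1).sqrt()  # close to 1 for every token
rms_out_mean = rms_out.mean(dim=-1)           # stays clearly non-zero: no centering
ln_out_mean = ln_out.mean(dim=-1)             # close to 0: Layer Norm centers
```

Both outputs have a controlled scale, but only Layer Norm removes the offset of the input; RMS Norm keeps it, scaled down by the root mean square.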

In [19]:
batch_size, num_features, height, width = 20, 100, 35, 45
test_input = torch.randn(batch_size, num_features, height, width)
m = nn.RMSNorm([num_features, height, width], elementwise_affine=False)
torch_output = m(test_input)

In [20]:
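eps = torch.finfo(test_input.dtype).eps # nn.RMSNorm uses this when eps is left as None (the default)
mean_sq = torch.mean(test_input**2, dim=(1,2,3), keepdim=True) # Mean of squares, no mean subtraction
manual_output = test_input / torch.sqrt(mean_sq + eps)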

In [21]:
torch.allclose(torch_output, manual_output, atol=1e-5)

True