In [1]:
import torch
import torch.nn as nn

# [Batch Normalization](https://arxiv.org/pdf/1502.03167)

Given a mini-batch of activations $ x = \{x_1, x_2, \dots, x_m\} $, the batch normalization steps are as follows:

1. Compute the mini-batch mean:
$$ 
\mu_B = \frac{1}{m} \sum_{i=1}^m x_i 
$$

2. Compute the mini-batch variance:
$$ 
\sigma_B^2 = \frac{1}{m} \sum_{i=1}^m (x_i - \mu_B)^2
$$



3. Normalize the input:
$$
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}
$$

4. Scale and shift:
$$
y_i = \gamma \hat{x}_i + \beta
$$

where $ \gamma $ and $ \beta $ are learnable parameters that allow the model to undo normalization if needed, and $ \epsilon $ is a small constant for numerical stability.

---

### When input $ X \in \mathbb{R}^{B \times C \times H \times W} $ is a batch of image features:

BatchNorm is applied **per channel**, across batch and spatial dimensions:

1. Compute mean:
   $$
   \mu_c = \frac{1}{BHW} \sum_{b=1}^{B} \sum_{h=1}^{H} \sum_{w=1}^{W} X_{bchw}
   $$

2. Compute variance:
   $$
   \sigma_c^2 = \frac{1}{BHW} \sum_{b=1}^{B} \sum_{h=1}^{H} \sum_{w=1}^{W} \left(X_{bchw} - \mu_c\right)^2
   $$

3. Normalize:
   $$
   \hat{X}_{bchw} = \frac{X_{bchw} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}
   $$

4. Scale and shift:
   $$
   Y_{bchw} = \gamma_c \hat{X}_{bchw} + \beta_c
   $$

Where $ \gamma_c $, $ \beta_c $ are learnable (`affine`) parameters per channel.

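`momentum`: during training the running statistics are updated as an exponential moving average, $\hat{\theta}_{\text{new}} = (1 - \text{momentum}) \cdot \hat{\theta} + \text{momentum} \cdot \theta_t$, where $\hat{\theta}$ is the running estimate (e.g. the running mean, `running_mean` in PyTorch) and $\theta_t$ is the statistic observed on the current batch (e.g. `batch_mean`). In evaluation mode the running estimates are used instead of batch statistics.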


<center>
<img src="./assets/batch-norm.png" alt="" width="350"/>
</center>

---

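Variance computation — an equivalent form (exact in real arithmetic, but in floating point it can lose precision through cancellation when $|\mathbb{E}[x_c]|$ is large relative to the standard deviation):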
$$
\mathrm{Var}[x_c] = \mathbb{E}[(x_c)^2] - \left(\mathbb{E}[x_c]\right)^2
$$

In [2]:
class BatchNorm(nn.Module):
    """Batch Normalization over the channel dimension (dim 1).

    Accepts inputs of shape [B, C] or [B, C, *spatial] (e.g. [B, C, L] or
    [B, C, H, W]). Mean and variance are computed per channel over the batch
    and all remaining dimensions.

    Args:
        num_channels: number of channels C (size of dim 1 of the input).
        eps: small constant added to the variance for numerical stability.
        momentum: weight of the newest batch statistic in the exponential
            moving average of the running statistics.
        affine: if True, learn a per-channel scale (γ, ``scale``) and
            shift (β, ``shift``).
        track_running_stats: if True, keep running estimates of mean and
            variance in the buffers ``exp_mean`` / ``exp_var`` and use them in
            evaluation mode; if False, batch statistics are always used.
    """

    def __init__(self, num_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
        super().__init__()
        self.num_channels = num_channels
        self.eps = eps  # ϵ for numerical stability
        self.momentum = momentum  # weight of the new batch statistic in the moving average
        self.affine = affine
        self.track_running_stats = track_running_stats

        # Learnable parameters for scale and shift
        if self.affine:
            self.scale = nn.Parameter(torch.ones(num_channels))  # γ
            self.shift = nn.Parameter(torch.zeros(num_channels))  # β
        else:
            # No learnable parameters: None placeholders keep the attributes defined
            self.register_parameter("scale", None)
            self.register_parameter("shift", None)

        # Buffers to store exponential moving averages of mean and variance
        if self.track_running_stats:
            # Buffers are part of the module state (state_dict, .to(device)) but are not learnable
            self.register_buffer("exp_mean", torch.zeros(num_channels))
            self.register_buffer("exp_var", torch.ones(num_channels))
        else:
            self.register_buffer("exp_mean", None)
            self.register_buffer("exp_var", None)

    def forward(self, x):
        """Normalize ``x`` per channel and optionally apply the affine transform.

        Args:
            x: tensor of shape [B, C] or [B, C, *spatial] with C == num_channels.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size = x.shape[0]

        # Shape used to broadcast per-channel vectors [C] against x: [1, C, 1, ..., 1]
        param_shape = (1, self.num_channels) + (1,) * (x.dim() - 2)

        # Batch statistics are used while training, and always when no running stats are kept
        use_batch_stats = self.training or not self.track_running_stats

        if use_batch_stats:
            # Flatten everything after the channel dim: [B, C, N].
            # reshape (not view) so non-contiguous inputs also work.
            reshaped_x = x.reshape(batch_size, self.num_channels, -1)

            mean = reshaped_x.mean(dim=(0, 2))  # [C]
            # Biased variance (divide by n) for normalization. Computed directly
            # instead of E[x²] - E[x]², which can cancel catastrophically (and
            # even turn negative) in floating point.
            var = reshaped_x.var(dim=(0, 2), unbiased=False)  # [C]

            if self.training and self.track_running_stats:
                n = reshaped_x.shape[0] * reshaped_x.shape[2]  # values per channel
                # Running stats must not carry autograd history (it would keep
                # every past graph alive), so update them in place under no_grad.
                with torch.no_grad():
                    # The running variance stores the unbiased estimate n / (n - 1) * var,
                    # as in the original paper and torch.nn.BatchNorm*.
                    unbiased_var = var * (n / max(n - 1, 1))
                    self.exp_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                    self.exp_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)
        else:
            # Evaluation mode: use the stored moving averages
            mean = self.exp_mean
            var = self.exp_var

        # Normalize
        x_hat = (x - mean.view(param_shape)) / torch.sqrt(var.view(param_shape) + self.eps)

        # Apply the per-channel affine transformation
        if self.affine:
            x_hat = self.scale.view(param_shape) * x_hat + self.shift.view(param_shape)

        return x_hat

In [3]:
# Test: shape check plus a correctness check against PyTorch's built-in (training mode)
torch.manual_seed(0)
bn = BatchNorm(3)
x = torch.randn(4, 3, 8, 8)
y = bn(x)
assert torch.allclose(y, nn.BatchNorm2d(3)(x), atol=1e-5), "BatchNorm output differs from nn.BatchNorm2d"
print(f"{y.shape = }")

y.shape = torch.Size([4, 3, 8, 8])


# [Layer Normalization](https://arxiv.org/pdf/1607.06450)

Layer Normalization is a normalization technique used in neural networks to stabilize training, especially in recurrent networks (RNNs) and transformers.

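LayerNorm normalizes across the features of each individual example, so its statistics do not depend on the other examples in the batch and it behaves identically during training and inference.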

It is generally used for NLP tasks.

---

### When input $ X \in \mathbb{R}^{B \times d} $, i.e. a batch of feature vectors (e.g. embeddings), where:

- $ B $ is the batch size
- $ d $ is the feature dimension
- $ \gamma, \beta \in \mathbb{R}^{d} $ are learnable parameters

$$
\text{LN}(X) = \gamma \cdot \frac{X - \mathbb{E}_{d}[X]}{\sqrt{\mathrm{Var}_{d}[X] + \epsilon}} + \beta
$$

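Where (computed separately for each row of $ X $, with $ X_i $ denoting the $ i $-th feature of that row):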
$$
\mathbb{E}_{d}[X] = \frac{1}{d} \sum_{i=1}^{d} X_i \quad \text{(mean across features)}
$$

$$
\mathrm{Var}_{d}[X] = \frac{1}{d} \sum_{i=1}^{d} (X_i - \mathbb{E}_{d}[X])^2 \quad \text{(variance across features)}
$$

---

### When input $ X \in \mathbb{R}^{L \times B \times d} $, i.e. a sequence of feature vectors (e.g. transformer input), where:

- $ L $ is the sequence length
- $ B $ is the batch size
- $ d $ is the feature dimension
- $ \gamma, \beta \in \mathbb{R}^{d} $ are learnable parameters

$$
\text{LN}(X) = \gamma \cdot \frac{X - \mathbb{E}_{d}[X]}{\sqrt{\mathrm{Var}_{d}[X] + \epsilon}} + \beta
$$

Normalization is applied across the feature dimension $ d $, separately for each time step and each sample.

---

In Layer Normalization, two learnable parameters are introduced for each feature dimension:
- $ \textbf{Gain} $ (also called $ \it{scale} $): $ \gamma \in \mathbb{R}^{d} $
- $ \textbf{Bias} $ (also called $ \it{shift} $): $ \beta \in \mathbb{R}^{d} $

These parameters are applied after normalizing the input, to allow the model to restore or adjust the representation scale if needed.

After normalization, all inputs have:
- Mean ≈ 0
- Variance ≈ 1

That might limit the model's expressiveness, so we give it the flexibility to learn:

- A new scale (gain $ \gamma $) e.g., “stretch” certain features
- A new offset (bias $ \beta $) e.g., “shift” feature means

This allows the model to undo normalization if it wants, or apply custom transformations per feature.


<center>
<img src="./assets/lyr-norm.png" alt="" width="350"/>
</center>

In [4]:
class LayerNorm(nn.Module):
    """Layer Normalization over the trailing ``normalized_shape`` dimensions.

    Every slice spanning the last ``len(normalized_shape)`` dims of the input
    is shifted to zero mean and scaled to unit (biased) variance, then
    optionally multiplied by ``gain`` and offset by ``bias``, both of shape
    ``normalized_shape``.

    Args:
        normalized_shape: int or sequence of ints giving the trailing dims to normalize over.
        eps: small constant added to the variance for numerical stability.
        elementwise_affine: if True, learn ``gain`` (γ) and ``bias`` (β).
    """

    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True):
        super().__init__()
        # Always keep the normalized shape as a tuple, whether an int or a sequence was passed
        if isinstance(normalized_shape, int):
            self.normalized_shape = (normalized_shape,)
        else:
            self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if not self.elementwise_affine:
            # No learnable parameters: register None placeholders so that
            # `self.gain` / `self.bias` still exist as attributes
            self.register_parameter('gain', None)
            self.register_parameter('bias', None)
            return

        self.gain = nn.Parameter(torch.ones(normalized_shape))  # γ, initialized to 1
        self.bias = nn.Parameter(torch.zeros(normalized_shape))  # β, initialized to 0

    def forward(self, x):
        """Normalize ``x`` over its last ``len(self.normalized_shape)`` dims.

        Returns a tensor with the same shape as ``x``.
        """
        n_norm = len(self.normalized_shape)
        # Negative indices of the trailing dims, e.g. (-1,) for 1-D or (-2, -1) for 2-D
        reduce_dims = tuple(-k for k in range(n_norm, 0, -1))

        mu = x.mean(dim=reduce_dims, keepdim=True)
        sigma2 = x.var(dim=reduce_dims, unbiased=False, keepdim=True)  # biased (population) variance
        normalized = (x - mu) / torch.sqrt(sigma2 + self.eps)

        if not self.elementwise_affine:
            return normalized
        return self.gain * normalized + self.bias

In [5]:
# Test: LayerNorm over the last two dims (H x W), checked against PyTorch's built-in
torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)
normalized_shape = x.shape[2:] # H x W
ln = LayerNorm(normalized_shape)
y = ln(x)
assert torch.allclose(y, nn.LayerNorm(normalized_shape)(x), atol=1e-5), "LayerNorm output differs from nn.LayerNorm"
print(f"{y.shape = }")

y.shape = torch.Size([4, 3, 8, 8])


# [Instance Normalization](https://arxiv.org/pdf/1607.08022)

Instance Normalization (**IN**) is a normalization technique commonly used in computer vision, especially in style transfer and image generation.


Instance Normalization normalizes each sample and each channel independently, across spatial dimensions only.

Let the input be a 4-D tensor: $X \in \mathbb{R}^{B \times C \times H \times W}$ (e.g. batch of images)

For each sample $ b \in \{1, \dots, B\} $ and each channel $ c \in \{1, \dots, C\} $, compute:

$$
\mu_{bc} = \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} X_{bchw}
$$

$$
\sigma_{bc}^2 = \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} (X_{bchw} - \mu_{bc})^2
$$

$$
\text{IN}(X_{bchw}) = \gamma_c \cdot \frac{X_{bchw} - \mu_{bc}}{\sqrt{\sigma_{bc}^2 + \epsilon}} + \beta_c
$$

where:
- $ \gamma_c, \beta_c \in \mathbb{R} $ are learnable scale and shift parameters for each channel $ c $
- $ \epsilon $ is a small constant for numerical stability

<center>
<img src="./assets/inst-norm.png" alt="" width="350"/>
</center>

In [6]:
class InstanceNorm(nn.Module):
    """Instance Normalization: normalize each (sample, channel) slice over its spatial dims.

    Args:
        num_channels: number of channels C (size of dim 1 of the input).
        eps: small constant added to the variance for numerical stability.
        affine: if True, learn a per-channel scale (γ, ``gamma``) and shift (β, ``beta``).

    Input shape: [B, C, *spatial] with at least one spatial dimension, e.g.
    [B, C, L], [B, C, H, W] or [B, C, D, H, W]. The output has the same shape.
    """

    def __init__(self, num_channels, eps=1e-5, affine=True):
        super().__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine

        # Learnable parameters for scale and shift
        if self.affine:
            self.gamma = nn.Parameter(torch.ones(num_channels))  # scale (γ)
            self.beta = nn.Parameter(torch.zeros(num_channels))  # shift (β)
        else:
            # No learnable parameters: register None placeholders so that
            # `self.gamma` / `self.beta` still exist as attributes
            self.register_parameter('gamma', None)
            self.register_parameter('beta', None)

    def forward(self, x):
        """Normalize ``x`` of shape [B, C, *spatial]; returns a tensor of the same shape.

        Raises:
            ValueError: if ``x`` has fewer than 3 dimensions.
        """
        if x.dim() < 3:
            raise ValueError(
                f"InstanceNorm expects input of shape [B, C, *spatial], got {tuple(x.shape)}"
            )
        B, C = x.shape[:2]

        # Flatten to [B * C, N] with N = product of the spatial dims. reshape
        # (not view) so that non-contiguous inputs, e.g. after transpose, work.
        x_reshaped = x.reshape(B * C, -1)

        # Statistics per (sample, channel) pair
        mean = x_reshaped.mean(dim=1, keepdim=True)  # [B * C, 1]
        var = x_reshaped.var(dim=1, unbiased=False, keepdim=True)  # [B * C, 1], biased

        # Normalize, then restore the original shape
        x_norm = (x_reshaped - mean) / torch.sqrt(var + self.eps)  # [B * C, N]
        x_norm = x_norm.reshape(x.shape)

        # Apply affine transformation
        if self.affine:
            # γ and β are per channel: reshape to [1, C, 1, ..., 1] for broadcasting
            param_shape = (1, C) + (1,) * (x.dim() - 2)
            x_norm = self.gamma.view(param_shape) * x_norm + self.beta.view(param_shape)

        return x_norm

In [7]:
# Test: shape check plus a correctness check against PyTorch's built-in
torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)
i_norm = InstanceNorm(3)
y = i_norm(x)
assert torch.allclose(y, nn.InstanceNorm2d(3, affine=True)(x), atol=1e-5), "InstanceNorm output differs from nn.InstanceNorm2d"
print(f"{y.shape = }")

y.shape = torch.Size([4, 3, 8, 8])


# [Group Normalization](https://arxiv.org/pdf/1803.08494)

Group Normalization (**GN**) is a normalization technique introduced as a more stable alternative to Batch Normalization, especially when batch sizes are small.

Group Normalization splits the channels of the input tensor into groups and, for each sample, normalizes each group independently over the group's channels and the spatial dimensions. Like BatchNorm it stabilizes training, but because its statistics are computed per sample it does not depend on the batch size.

Let the input be a 4-D tensor: $X \in \mathbb{R}^{B \times C \times H \times W}$ (e.g. batch of images)

where:
- $ B $ is the batch size
- $ C $ is the number of channels
- $ H $, $ W $ are the spatial dimensions

Group Normalization divides the $ C $ channels into $ G $ groups, each containing $ C/G $ channels. For each sample $ b \in \{1, \dots, B\} $ and each group $ g \in \{1, \dots, G\} $, it computes the mean and variance over all values in that group (across the group’s channels and spatial dimensions):

$$
\mu_{bg} = \frac{1}{m} \sum_{i=1}^{m} X_{bg,i}
$$

$$
\sigma_{bg}^2 = \frac{1}{m} \sum_{i=1}^{m} \left(X_{bg,i} - \mu_{bg}\right)^2
$$

where:
$$
m = \frac{C}{G} \times H \times W
$$


Then, each element $ X_{bg,i} $ is normalized and transformed using learnable parameters $ \gamma \in \mathbb{R}^{C} $, $ \beta \in \mathbb{R}^{C} $:

$$
\text{GN}(X_{bg,i}) = \gamma_{c(i)} \cdot \frac{X_{bg,i} - \mu_{bg}}{\sqrt{\sigma_{bg}^2 + \epsilon}} + \beta_{c(i)}
$$

where $ c(i) $ maps element $ i $ back to its original channel index within the input.



<center>
<img src="./assets/grp-norm.png" alt="" width="350"/>
</center>

In [8]:
class GroupNorm(nn.Module):
    """Group Normalization: normalize groups of channels per sample.

    The C channels are split into ``num_groups`` consecutive groups of
    C / num_groups channels. Each group of each sample is normalized over its
    channels and all spatial dims, then an optional per-channel affine
    transform is applied.

    Args:
        num_channels: number of channels C (size of dim 1 of the input).
        num_groups: number of groups G; must divide ``num_channels``.
        eps: small constant added to the variance for numerical stability.
        affine: if True, learn a per-channel scale (γ, ``gamma``) and shift (β, ``beta``).

    Note: the argument order is (num_channels, num_groups), the reverse of
    ``torch.nn.GroupNorm(num_groups, num_channels)``.

    Input shape: [B, C] or [B, C, *spatial]. The output has the same shape.
    """

    def __init__(self, num_channels, num_groups, eps=1e-5, affine=True):
        super().__init__()
        assert num_channels % num_groups == 0, "num_channels must be divisible by num_groups"
        self.num_channels = num_channels
        self.num_groups = num_groups
        self.eps = eps
        self.affine = affine

        # Learnable parameters for scale and shift
        if self.affine:
            self.gamma = nn.Parameter(torch.ones(num_channels))  # scale (γ)
            self.beta = nn.Parameter(torch.zeros(num_channels))  # shift (β)
        else:
            # No learnable parameters: register None placeholders so that
            # `self.gamma` / `self.beta` still exist as attributes
            self.register_parameter('gamma', None)
            self.register_parameter('beta', None)

    def forward(self, x):
        """Normalize ``x`` of shape [B, C, *spatial]; returns a tensor of the same shape.

        Raises:
            ValueError: if ``x`` has fewer than 2 dimensions.
        """
        if x.dim() < 2:
            raise ValueError(
                f"GroupNorm expects input of shape [B, C, *spatial], got {tuple(x.shape)}"
            )
        B, C = x.shape[:2]
        spatial = tuple(x.shape[2:])
        G = self.num_groups

        # [B, G, C/G, *spatial]. reshape (not view) so non-contiguous inputs,
        # e.g. after transpose, also work. The explicit shape (no -1) still
        # fails if C is not divisible by G.
        grouped = x.reshape(B, G, C // G, *spatial)

        # Statistics over each group's channels and spatial dims, i.e. dims 2..end
        reduce_dims = tuple(range(2, grouped.dim()))
        mean = grouped.mean(dim=reduce_dims, keepdim=True)
        var = grouped.var(dim=reduce_dims, keepdim=True, unbiased=False)  # biased

        # Normalize, then restore the original [B, C, *spatial] layout
        x_norm = ((grouped - mean) / torch.sqrt(var + self.eps)).reshape(x.shape)

        # Apply affine transformation
        if self.affine:
            # γ and β are per channel: reshape to [1, C, 1, ..., 1] for broadcasting
            param_shape = (1, C) + (1,) * len(spatial)
            x_norm = self.gamma.view(param_shape) * x_norm + self.beta.view(param_shape)

        return x_norm

In [9]:
# Test: shape check plus a correctness check against PyTorch's built-in
# (note: nn.GroupNorm takes (num_groups, num_channels), the reverse order)
torch.manual_seed(0)
x = torch.randn(4, 32, 64, 64)
num_channels = 32
num_groups = 8
gn = GroupNorm(num_channels, num_groups)
y = gn(x)
assert torch.allclose(y, nn.GroupNorm(num_groups, num_channels)(x), atol=1e-5), "GroupNorm output differs from nn.GroupNorm"
print(f"{y.shape = }")

y.shape = torch.Size([4, 32, 64, 64])
