In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn

In [2]:
N, C, L = 2, 4, 6
x1 = torch.randn(C, L)
x2 = torch.randn(N, C, L)

## Demystifying Batchnorm & Layernorm

### Batchnorm

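For a tensor $x$ of shape $(N, C, L)$, the statistics are computed separately for each channel $c$, pooling over the batch index $i$ and the position index $j$:
$$ \mu_c = \frac{1}{NL} \sum_{i=1}^{N} \sum_{j=1}^{L} x_{icj} \\ \ \\ \sigma_c^2 = \frac{1}{NL} \sum_{i=1}^N \sum_{j=1}^L (x_{icj} - \mu_c)^2 \\ \ \\ \hat{x}_{icj} = \frac{x_{icj} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}$$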

In [13]:
mu = torch.mean(x2, [0, 2], keepdim=True)  # x2.sum(0).sum(1) / (N * L)
sig2 = torch.var(x2, [0, 2], correction=0, keepdim=True)  # no correction for bias (Bessel's correction) per equation above
xhat = (x2 - mu) / torch.sqrt(sig2 + 1e-5)
xhat

tensor([[[-0.4146,  0.5744,  0.2266,  0.3660, -0.0236,  0.7007],
         [-0.1406, -0.5448, -1.9718, -1.2147,  0.6173,  0.1095],
         [ 0.6311, -0.1577, -0.5117, -1.2719,  0.5677,  1.0311],
         [ 1.2559, -0.1773, -1.2809,  0.3551,  0.5591, -2.0358]],

        [[ 0.3466,  0.1193, -2.5732,  1.1626, -1.3817,  0.8970],
         [ 1.1253,  0.2677,  0.2603,  2.0825, -0.2255, -0.3650],
         [ 1.9877, -0.9347, -0.3170, -1.7596,  0.2667,  0.4683],
         [ 0.7194,  1.2219, -0.3333, -0.2791, -0.9944,  0.9895]]])

In [15]:
bn = nn.BatchNorm1d(4)
bn(x2)

tensor([[[-0.4146,  0.5744,  0.2266,  0.3660, -0.0236,  0.7007],
         [-0.1406, -0.5448, -1.9718, -1.2147,  0.6173,  0.1095],
         [ 0.6311, -0.1577, -0.5117, -1.2719,  0.5677,  1.0311],
         [ 1.2559, -0.1773, -1.2809,  0.3551,  0.5591, -2.0358]],

        [[ 0.3466,  0.1193, -2.5732,  1.1626, -1.3817,  0.8970],
         [ 1.1253,  0.2677,  0.2603,  2.0825, -0.2255, -0.3650],
         [ 1.9877, -0.9347, -0.3170, -1.7596,  0.2667,  0.4683],
         [ 0.7194,  1.2219, -0.3333, -0.2791, -0.9944,  0.9895]]],
       grad_fn=<NativeBatchNormBackward0>)
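
The two outputs above agree by eye; a quick numerical check makes that explicit. This is a minimal sketch on a freshly seeded tensor (not the `x2` from the cells above), assuming `nn.BatchNorm1d` is in its default training mode with the initial scale of 1 and shift of 0. The names `manual`, `builtin`, and `bn_matches` are made up for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, L = 2, 4, 6
x = torch.randn(N, C, L)

# Manual batch norm: one mean/variance per channel,
# pooled over the batch (dim 0) and the length (dim 2).
mu = x.mean(dim=(0, 2), keepdim=True)
var = x.var(dim=(0, 2), unbiased=False, keepdim=True)
manual = (x - mu) / torch.sqrt(var + 1e-5)

# Built-in layer: training mode by default, weight=1 and bias=0 at init.
bn = nn.BatchNorm1d(C)
builtin = bn(x)

bn_matches = torch.allclose(manual, builtin, atol=1e-5)
print(bn_matches)
```

The comparison only holds in training mode; after `bn.eval()` the layer would use its running statistics instead of the batch statistics.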

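### Layernorm

Layernorm, as used here with `nn.LayerNorm(6)`, normalizes over only the last dimension $L$: each of the $N \times C$ rows gets its own mean and variance.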

In [19]:
mu = x2.mean(2, keepdim=True)
sig2 = x2.var(2, keepdim=True, unbiased=False)

xhat = (x2 - mu) / torch.sqrt(sig2 + 1e-5)
xhat

tensor([[[-1.7468,  0.8993, -0.0312,  0.3420, -0.7006,  1.2373],
         [ 0.4467, -0.0241, -1.6857, -0.8041,  1.3292,  0.7379],
         [ 0.7450, -0.2630, -0.7154, -1.6868,  0.6640,  1.2561],
         [ 1.3171,  0.0387, -0.9458,  0.5136,  0.6955, -1.6192]],

        [[ 0.4425,  0.2705, -1.7667,  1.0599, -0.8652,  0.8590],
         [ 0.7118, -0.3038, -0.3125,  1.8454, -0.8879, -1.0531],
         [ 1.7311, -0.7539, -0.2286, -1.4554,  0.2677,  0.4391],
         [ 0.6205,  1.2459, -0.6893, -0.6219, -1.5118,  0.9566]]])

In [17]:
ln = nn.LayerNorm(6)
ln(x2)

tensor([[[-1.7468,  0.8993, -0.0312,  0.3420, -0.7006,  1.2373],
         [ 0.4467, -0.0241, -1.6857, -0.8041,  1.3292,  0.7379],
         [ 0.7450, -0.2630, -0.7154, -1.6868,  0.6640,  1.2561],
         [ 1.3171,  0.0387, -0.9458,  0.5136,  0.6955, -1.6192]],

        [[ 0.4425,  0.2705, -1.7667,  1.0599, -0.8652,  0.8590],
         [ 0.7118, -0.3038, -0.3125,  1.8454, -0.8879, -1.0531],
         [ 1.7311, -0.7539, -0.2286, -1.4554,  0.2677,  0.4391],
         [ 0.6205,  1.2459, -0.6893, -0.6219, -1.5118,  0.9566]]],
       grad_fn=<NativeLayerNormBackward0>)
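
The same kind of check works for LayerNorm. This is a sketch on a freshly seeded tensor rather than the `x2` above, assuming `nn.LayerNorm(L)` with its default `eps=1e-5` and initial scale and shift of 1 and 0. The names `manual`, `builtin`, and `ln_matches` are made up for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, L = 2, 4, 6
x = torch.randn(N, C, L)

# Manual layer norm over the last dimension only:
# one mean/variance for each (sample, channel) row.
mu = x.mean(dim=2, keepdim=True)
var = x.var(dim=2, unbiased=False, keepdim=True)
manual = (x - mu) / torch.sqrt(var + 1e-5)

# normalized_shape=L means "normalize over the trailing dimension of size L".
ln = nn.LayerNorm(L)
builtin = ln(x)

ln_matches = torch.allclose(manual, builtin, atol=1e-5)
print(ln_matches)
```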

### LayerNorm
LayerNorm is similar to BatchNorm. For a 2-D input of shape (B, C), BatchNorm calculates the sample statistics (mean and variance) over dim=0, that is, across the samples in the batch, separately for each feature. LayerNorm calculates the sample statistics over dim=1, that is, across the features of each individual sample.

More precisely, BatchNorm calculates sample statistics over all elements of all instances (samples) in a batch, for each feature independently, whereas LayerNorm calculates sample statistics across the features and elements of each instance (sample) independently.
![image.png](attachment:image.png)
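
To make the dim=0 versus dim=1 distinction concrete, here is a small sketch on a 2-D input. The shape `(B, C)` and the scaling of `x` are assumptions chosen for illustration, and both layers are used in training mode with their initial parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C = 8, 5
x = torch.randn(B, C) * 3 + 2  # deliberately not zero-mean / unit-variance

bn_out = nn.BatchNorm1d(C)(x)  # statistics over dim 0 (across samples)
ln_out = nn.LayerNorm(C)(x)    # statistics over dim 1 (across features)

# BatchNorm: every feature column is standardized.
col_means = bn_out.mean(dim=0)
col_vars = bn_out.var(dim=0, unbiased=False)

# LayerNorm: every sample row is standardized.
row_means = ln_out.mean(dim=1)
row_vars = ln_out.var(dim=1, unbiased=False)

print(col_means.shape, row_means.shape)
```

The columns of `ln_out` and the rows of `bn_out` are, in general, not standardized; each layer only guarantees zero mean and unit variance along its own axis.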

So, what we get with LayerNorm is standardized (zero-mean, unit-variance, loosely "unit-Gaussian") features for each token. The LayerNorm statistics have shape (B, T): we have T standardized tokens for each of the B samples in the batch. At least, they are standardized at initialization. LayerNorm has trainable *gamma* and *beta* (scale and shift) parameters like BatchNorm does, so after training the outputs may follow a different, learned distribution.
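
A short sketch of the per-token picture, assuming an input of shape `(B, T, C)` and `nn.LayerNorm(C)`; the sizes are arbitrary. In PyTorch, *gamma* and *beta* are exposed as `ln.weight` and `ln.bias`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, C = 2, 3, 8
x = torch.randn(B, T, C) * 4 - 1

ln = nn.LayerNorm(C)
out = ln(x)

# One mean and one variance per token: shape (B, T).
token_means = out.mean(dim=-1)
token_vars = out.var(dim=-1, unbiased=False)
print(token_means.shape)

# gamma (scale) and beta (shift) have one entry per feature
# and start at 1 and 0, so the layer is a pure standardization at init.
print(ln.weight.shape, ln.bias.shape)
```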

**Note:** Because LayerNorm statistics are independent across samples in a batch, LayerNorm does not need buffers (running mean and running var); the statistics can be calculated the same way at any time, regardless of batch size. This also means that LayerNorm works the same way during evaluation as during training.
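
The note above can be checked directly by listing each module's buffers and comparing training-mode and evaluation-mode outputs. This is a sketch with an arbitrary 2-D input; the variable names are made up for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 4) * 3 + 2

bn = nn.BatchNorm1d(4)
ln = nn.LayerNorm(4)

# BatchNorm registers running statistics as buffers; LayerNorm registers none.
bn_buffers = [name for name, _ in bn.named_buffers()]
ln_buffers = [name for name, _ in ln.named_buffers()]
print(bn_buffers, ln_buffers)

# LayerNorm: identical output in train and eval mode.
ln.train()
ln_train = ln(x)
ln.eval()
ln_eval = ln(x)

# BatchNorm: train mode uses batch statistics, eval mode uses running statistics.
bn.train()
bn_train = bn(x)
bn.eval()
bn_eval = bn(x)

print(torch.equal(ln_train, ln_eval), torch.allclose(bn_train, bn_eval))
```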

One reason LayerNorm may be preferred in this architecture is that BatchNorm can be bad for NLP applications (as I so painfully learnt working with RNNs in makemore). Variable sequence lengths can lead to variable batch lengths, which produces variation in the BatchNorm sample statistics and, in turn, instability during training.
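
A related effect is easy to demonstrate: in training mode, BatchNorm's output for a given sample depends on whatever else is in the batch, while LayerNorm's does not. The sketch below is an illustration of that coupling only, not a reproduction of the variable-length issue described above; the two batches are made up for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 5
sample = torch.randn(1, C)

# Same first sample, different batch companions.
batch_a = torch.cat([sample, torch.randn(3, C)])
batch_b = torch.cat([sample, torch.randn(3, C) * 10 + 5])

bn = nn.BatchNorm1d(C)  # training mode: uses the current batch's statistics
ln = nn.LayerNorm(C)

bn_a0, bn_b0 = bn(batch_a)[0], bn(batch_b)[0]
ln_a0, ln_b0 = ln(batch_a)[0], ln(batch_b)[0]

# BatchNorm: the first sample's output changes with its companions.
print(torch.allclose(bn_a0, bn_b0))
# LayerNorm: the first sample's output is unaffected by the rest of the batch.
print(torch.allclose(ln_a0, ln_b0, atol=1e-6))
```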

#### Deviation in Model Architecture
We will apply LayerNorm before each layer (stack/feed-forward) of the network instead of after it. Andrej says that this is the more common practice today.
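
The difference between the two placements can be sketched as follows. The wrapper classes `PreNormResidual` and `PostNormResidual` are hypothetical names, the residual connection is an assumption about the surrounding architecture, and `nn.Linear` is only a stand-in for the real sublayer.

```python
import torch
import torch.nn as nn


class PreNormResidual(nn.Module):
    """Hypothetical wrapper: normalize first, then apply the sublayer."""

    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x):
        # The residual path carries x through untouched.
        return x + self.sublayer(self.norm(x))


class PostNormResidual(nn.Module):
    """Hypothetical wrapper: apply the sublayer, add the residual, then normalize."""

    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x):
        return self.norm(x + self.sublayer(x))


torch.manual_seed(0)
B, T, C = 2, 3, 8
x = torch.randn(B, T, C)

pre = PreNormResidual(C, nn.Linear(C, C))
post = PostNormResidual(C, nn.Linear(C, C))

pre_out = pre(x)
post_out = post(x)
print(pre_out.shape, post_out.shape)
```

With the pre-norm arrangement the block's output is not itself normalized, since the unnormalized `x` is added back; with post-norm every token leaves the block standardized.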