# Understanding what BatchNorm1d does
Anthony Lee 2025-02-13

It was not clear to me what [the PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html) means by "the mean and standard-deviation are calculated per-dimension over the mini-batches", so this is an empirical test to understand it.

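From the test below, `BatchNorm1d` gives the same result as normalizing with the mean and variance computed over dimensions 0 and 2, which correspond to the batch and to the data length ("for each sample within each of its dimensions"). The match is exact only with the biased variance (`unbiased=False`); `torch.var` defaults to the unbiased estimator, which explains the roughly 2% gap between the two printed outputs.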
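The claim above can be checked compactly with `torch.allclose`. This is a minimal sketch on a hypothetical random tensor (the seed, shapes, and variable names are assumptions for illustration, not the data used in the cells below); it compares `BatchNorm1d` in training mode against a manual normalization with both variance estimators.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

# Hypothetical example tensor with shape (batch, channels, length)
x = torch.randn(5, 3, 5)

# A fresh BatchNorm1d is in training mode with weight=1 and bias=0
norm = torch.nn.BatchNorm1d(num_features=3)
expected = norm(x)

# One mean/variance per channel, reduced over batch (0) and length (2)
mean = x.mean(dim=(0, 2), keepdim=True)
var_biased = x.var(dim=(0, 2), unbiased=False, keepdim=True)
var_unbiased = x.var(dim=(0, 2), unbiased=True, keepdim=True)

manual_biased = (x - mean) / torch.sqrt(var_biased + norm.eps)
manual_unbiased = (x - mean) / torch.sqrt(var_unbiased + norm.eps)

# Only the biased estimator should reproduce the layer's output
matches_biased = torch.allclose(expected, manual_biased, atol=1e-5)
matches_unbiased = torch.allclose(expected, manual_unbiased, atol=1e-5)
print(matches_biased, matches_unbiased)
```

With 25 values per channel, the unbiased variant is off by a factor of sqrt(24 / 25), which is small but well outside the tolerance used here.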

### "per-dimension" == "for each sample within each of its dimensions"
One can think of the data as a 3D tensor, with x being the batch, y the channels, and z the data length. When we take the mean or variance "per-dimension" (i.e., along the z-axis), we get one statistic for each channel of each item in the batch, i.e., a tensor of shape `(batch_size, n_channel)`.
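As a small illustration of this first step, the sketch below reduces only along the data length and inspects the resulting shape. The tensor contents are a hypothetical `arange` ramp chosen so the numbers are easy to verify by hand.

```python
import torch

batch_size, n_channel, data_length = 5, 3, 5

# Hypothetical data: 0, 1, 2, ... laid out as (batch, channels, length)
x = torch.arange(batch_size * n_channel * data_length, dtype=torch.float32)
x = x.reshape(batch_size, n_channel, data_length)

# Reduce only along the data length (dim 2): one mean per sample and channel
per_sample_mean = x.mean(dim=2)

print(per_sample_mean.shape)  # torch.Size([5, 3])
print(per_sample_mean[0, 0])  # mean of 0, 1, 2, 3, 4 -> tensor(2.)
```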

### "over the mini-batches"
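When we then reduce these "per-dimension" statistics further, we combine all the statistics that belong to the same channel, so we end up with a vector of means whose length is the number of channels, because we are taking the mean "over" or "across" the batch/mini-batch. For the mean, this two-step average equals the mean taken over dimensions 0 and 2 at once because every sample has the same length; the variance has to be computed over both dimensions jointly, as in the test below.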
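The two-step view can be sketched as follows, again on a hypothetical random tensor (seed and names are assumptions). It checks that averaging the per-sample means over the batch equals the joint mean over dimensions 0 and 2. It also shows why the same shortcut does not work for the variance: the joint (biased) variance is the average within-sample variance plus the variance of the per-sample means.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
x = torch.randn(5, 3, 5)  # (batch, channels, length)

# Step 1: "per-dimension" means, shape (batch, channels)
per_sample_mean = x.mean(dim=2)
# Step 2: average "over the mini-batch", shape (channels,)
channel_mean = per_sample_mean.mean(dim=0)

# Same result as reducing over both dimensions at once
joint_mean = x.mean(dim=(0, 2))
means_agree = torch.allclose(channel_mean, joint_mean, atol=1e-6)

# The variance does not decompose this simply: averaging per-sample
# variances misses the spread of the per-sample means themselves.
mean_of_vars = x.var(dim=2, unbiased=False).mean(dim=0)
spread_of_means = per_sample_mean.var(dim=0, unbiased=False)
joint_var = x.var(dim=(0, 2), unbiased=False)
vars_decompose = torch.allclose(
    joint_var, mean_of_vars + spread_of_means, atol=1e-5
)

print(channel_mean.shape, means_agree, vars_decompose)
```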

In [2]:
import torch

# Generate some example data
batch_size, n_channel, data_length = 5, 3, 5
eps = 1e-9

input = torch.randint(low=int(1e1), high=int(1e5), size=(batch_size, n_channel, data_length), dtype=torch.float32)
print(f"Input shape: {input.shape}")
print()

# The BatchNorm1d function
norm = torch.nn.BatchNorm1d(num_features=n_channel, eps=eps)
output = norm(input)

print(f"BatchNorm1d output shape: {output.shape}")
print(output)

Input shape: torch.Size([5, 3, 5])

BatchNorm1d output shape: torch.Size([5, 3, 5])
tensor([[[ 1.8643, -1.4489, -1.3801, -0.3188,  0.4848],
         [ 1.7870, -0.7558, -0.0987,  0.0321,  1.2541],
         [ 1.7270,  0.0469, -1.4423, -0.0950, -0.0756]],

        [[-0.9905,  0.3496,  0.8653, -0.7460,  0.0162],
         [-0.7105, -0.7918, -0.9328,  0.8435, -0.1628],
         [-1.4659,  0.3299,  0.1079,  1.3647, -0.1284]],

        [[-1.2634, -0.7957, -0.6404,  0.2914,  2.0487],
         [-1.3422,  1.1924, -0.3426,  0.0965, -0.5516],
         [-0.5273, -0.2496, -0.3470,  1.4835,  1.0654]],

        [[-0.6816,  0.7755,  0.0096, -0.4295, -0.4612],
         [-0.7923,  1.6722, -0.8551,  0.9907, -1.3649],
         [ 1.4516, -0.0489,  0.5959, -1.4507,  0.3446]],

        [[-0.6476,  0.2967,  2.1950, -0.3604,  0.9670],
         [ 0.1861,  0.6920, -1.1769,  1.9294, -0.7983],
         [-1.2574, -1.2964, -0.3403,  1.4980, -1.2907]]],
       grad_fn=<NativeBatchNormBackward0>)


In [11]:
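# Reduce across the batch (0) and the data length (2), leaving one statistic per channel
dim = [0, 2]

mean = input.mean(dim=dim, keepdim=True)
# Note: Tensor.var defaults to the unbiased estimator (divides by N - 1), whereas
# BatchNorm1d normalizes with the biased one (divides by N). Pass unbiased=False
# for an exact match; as written, the result differs by a factor of sqrt(24 / 25).
variance = input.var(dim=dim, keepdim=True)

output = (input - mean) / torch.sqrt(variance + eps)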

print("Mean and variances with the keep dim:")
print(f"\tmean:\t{mean}\n\tvar:\t{variance}")
print()

print(f"The length below should match the size of dimension 1 of the input, which is {n_channel}")
print(f"mean and variance without keepdim=True:\n\tmean:\t{input.mean(dim=dim, keepdim=False)}\n\tvar:\t{input.var(dim=dim, keepdim=False)}")
print()

print("Keeping the dimension just makes the matrix calculations easier.")
print(f"Output shape: {output.shape}")
print(output)

Mean and variances with the keep dim:
	mean:	tensor([[[40511.1992],
         [41540.0391],
         [46157.1211]]])
	var:	tensor([[[5.7631e+08],
         [9.1064e+08],
         [7.3417e+08]]])

The length below should match the size of dimension 1 of the input, which is 3
mean and variance without keepdim=True:
	mean:	tensor([40511.1992, 41540.0391, 46157.1211])
	var:	tensor([5.7631e+08, 9.1064e+08, 7.3417e+08])

Keeping the dimension just makes the matrix calculations easier.
Output shape: torch.Size([5, 3, 5])
tensor([[[ 1.8266, -1.4196, -1.3522, -0.3124,  0.4750],
         [ 1.7509, -0.7405, -0.0967,  0.0314,  1.2287],
         [ 1.6921,  0.0459, -1.4131, -0.0931, -0.0740]],

        [[-0.9705,  0.3426,  0.8478, -0.7309,  0.0159],
         [-0.6962, -0.7758, -0.9139,  0.8265, -0.1595],
         [-1.4363,  0.3233,  0.1057,  1.3371, -0.1258]],

        [[-1.2379, -0.7796, -0.6275,  0.2855,  2.0073],
         [-1.3151,  1.1683, -0.3357,  0.0946, -0.5405],
         [-0.5166, -0.24