Each layer of a neural network receives inputs that follow some distribution, and this distribution changes during training. The main cause is that the parameters of the preceding layers are updated at every step; randomness such as parameter initialization and the input data also plays a part. This change in the distribution of the inputs to internal layers during training is called internal covariate shift.

Goal of batch normalization: reduce internal covariate shift by normalizing the inputs to each layer.

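Whitened inputs: linearly transformed to have zero mean and unit variance, and decorrelated. Full whitening is expensive, so BatchNorm simplifies it. Each feature is normalized independently to zero mean and unit variance using mini-batch statistics, and the features are not decorrelated.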
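The sketch below contrasts the two approaches. It assumes PyTorch and uses made-up correlated toy data. `standardize` normalizes each feature independently, which is what BatchNorm does per feature. `zca_whiten` is a hypothetical helper that uses ZCA whitening as one example of full whitening. It also decorrelates the features, which requires an eigendecomposition of the covariance matrix.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy data: 1000 samples of 2 strongly correlated features
z = torch.randn(1000, 1)
x = torch.cat([z, 0.8 * z + 0.2 * torch.randn(1000, 1)], dim=1) * torch.tensor([3.0, 0.5]) + 2.0


def covariance(x):
    # Biased covariance matrix of the columns of x
    xc = x - x.mean(0)
    return xc.T @ xc / x.shape[0]


def standardize(x, eps=1e-5):
    # Per-feature normalization: zero mean and unit variance,
    # but the correlation between features is left untouched
    return (x - x.mean(0)) / torch.sqrt(x.var(0, unbiased=False) + eps)


def zca_whiten(x, eps=1e-5):
    # Full (ZCA) whitening: zero mean, unit variance AND decorrelated features
    xc = x - x.mean(0)
    eigvals, eigvecs = torch.linalg.eigh(covariance(x))
    w = eigvecs @ torch.diag(1.0 / torch.sqrt(eigvals + eps)) @ eigvecs.T
    return xc @ w


print(covariance(standardize(x)))  # unit diagonal, large off-diagonal entries
print(covariance(zca_whiten(x)))   # approximately the identity matrix
```

Per-feature standardization needs only a mean and a variance per feature. Whitening needs the full covariance matrix and its eigendecomposition for every mini-batch.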

Allows higher learning rates with less risk of vanishing or exploding gradients.

Regularizing effect: can reduce, and in some cases remove, the need for dropout to mitigate overfitting.

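The `momentum` argument is different from the one used in optimizer classes and from the conventional notion of momentum. View it as the weight of the newest batch in an exponential moving average of the running statistics: `running_stat = (1 - momentum) * running_stat + momentum * batch_stat`. With `momentum=None`, a cumulative moving average is used instead.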
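The following code checks this update rule. It assumes PyTorch's `nn.BatchNorm1d`, whose running statistics start at mean 0 and variance 1. Note that the running variance is updated with the unbiased batch variance. Also note that `momentum=None` turns the update into a plain cumulative average of all batch statistics seen so far.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Exponential moving average with momentum=0.1
momentum = 0.1
bn = nn.BatchNorm1d(3, momentum=momentum)  # modules start in training mode
x = torch.randn(16, 3) * 2.0 + 5.0         # toy batch: N=16, C=3
bn(x)                                       # one forward pass updates the running stats

# running = (1 - momentum) * running + momentum * batch_stat
expected_mean = (1 - momentum) * torch.zeros(3) + momentum * x.mean(0)
expected_var = (1 - momentum) * torch.ones(3) + momentum * x.var(0, unbiased=True)
print(torch.allclose(bn.running_mean, expected_mean, atol=1e-5))  # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-5))    # True

# Cumulative moving average with momentum=None
bn_cma = nn.BatchNorm1d(3, momentum=None)
batches = [torch.randn(16, 3) for _ in range(4)]
for b in batches:
    bn_cma(b)
# The running mean equals the plain average of the four batch means
expected_cma = torch.stack([b.mean(0) for b in batches]).mean(0)
print(torch.allclose(bn_cma.running_mean, expected_cma, atol=1e-5))  # True
```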

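Learnable parameters: $\gamma$ (scale) and $\beta$ (shift), one pair per feature (channel) of the input. The output is $y = \gamma \hat{x} + \beta$, where $\hat{x}$ is the normalized input. Setting $\gamma = \sqrt{\mathrm{Var}[x] + \epsilon}$ and $\beta = \mathrm{E}[x]$ recovers the original activations, so the layer can represent the identity transform.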
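Here is a sketch of the identity case, assuming PyTorch and a toy batch. In training mode, set $\gamma$ (`weight`) to the batch standard deviation, computed from the biased variance plus `eps`, and set $\beta$ (`bias`) to the batch mean. With those values, `nn.BatchNorm1d` returns its input unchanged up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
eps = 1e-5
x = torch.randn(32, 4) * 3.0 + 1.0  # hypothetical batch: N=32, 4 features
bn = nn.BatchNorm1d(4, eps=eps)      # training mode: uses batch statistics

with torch.no_grad():
    # gamma = sqrt(Var[x] + eps), beta = E[x], using the biased batch variance
    bn.weight.copy_(torch.sqrt(x.var(0, unbiased=False) + eps))
    bn.bias.copy_(x.mean(0))

y = bn(x)
print(torch.allclose(y, x, atol=1e-4))  # True: the layer reproduces its input
```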
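In training mode (`module.train()`): normalize with the mean and biased variance of the current batch, and update the running statistics.
In evaluation mode (`module.eval()`): normalize with the running statistics, e.g. for validation and inference. If `track_running_stats=False`, no running statistics are kept and batch statistics are used in both modes.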
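The example below shows this, assuming PyTorch. The same toy batch gives different outputs in the two modes. After a single training step, the running statistics are still far from the batch statistics. They were initialized to mean 0 and variance 1 and updated once with the default `momentum=0.1`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(2)
x = torch.randn(8, 2) * 4.0 + 10.0  # toy batch with mean around 10

bn.train()
y_train = bn(x)  # normalized with this batch's mean/var; running stats updated
running_mean_after_train = bn.running_mean.clone()

bn.eval()
y_eval = bn(x)   # normalized with running_mean/running_var; no update

print(y_train.mean(0))  # close to 0 for every feature
print(y_eval.mean(0))   # far from 0: running stats have seen only one batch
print(torch.equal(bn.running_mean, running_mean_after_train))  # True: eval does not update
```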

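Batch Normalization is done over the `C` dimension, with one mean and one variance per channel. For `BatchNorm1d` with input `(N, C, L)`, the statistics are computed on `(N, L)` slices. For `BatchNorm2d` with input `(N, C, H, W)`, they are computed on `(N, H, W)` slices.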
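This sketch checks which dimensions are reduced, assuming PyTorch with random inputs. Manual normalization over dims `(0, 2)` for an `(N, C, L)` input should match `nn.BatchNorm1d` in training mode. Likewise, normalization over `(0, 2, 3)` for an `(N, C, H, W)` input should match `nn.BatchNorm2d`. Setting `affine=False` restricts the comparison to the normalization step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
eps = 1e-5

# BatchNorm1d on (N, C, L): one mean/var per channel, reduced over dims (0, 2)
x1 = torch.randn(4, 3, 10)
bn1 = nn.BatchNorm1d(3, eps=eps, affine=False)
y1 = bn1(x1)
ref1 = (x1 - x1.mean((0, 2), keepdim=True)) / torch.sqrt(
    x1.var((0, 2), unbiased=False, keepdim=True) + eps)

# BatchNorm2d on (N, C, H, W): one mean/var per channel, reduced over dims (0, 2, 3)
x2 = torch.randn(4, 3, 5, 5)
bn2 = nn.BatchNorm2d(3, eps=eps, affine=False)
y2 = bn2(x2)
ref2 = (x2 - x2.mean((0, 2, 3), keepdim=True)) / torch.sqrt(
    x2.var((0, 2, 3), unbiased=False, keepdim=True) + eps)

print(torch.allclose(y1, ref1, atol=1e-5))  # True
print(torch.allclose(y2, ref2, atol=1e-5))  # True
```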

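Why over the `C` dimension?
In a convolutional layer, each output channel is produced by the same filter applied at every spatial position. The weights are shared across positions, not across channels. To preserve this convolutional property, all elements of one feature map are normalized the same way. The statistics and $\gamma$, $\beta$ are therefore computed per channel, over the batch and all spatial positions.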

"""
Comparison of manual BatchNorm2d layer implementation in Python and
nn.BatchNorm2d
@author: ptrblck
"""

import torch
import torch.nn as nn


class MyBatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MyBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        # Weight of the current batch in the running-statistics update
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # Batch statistics are used in training mode, and also in eval mode
        # when no running statistics are tracked (running_mean is None then)
        if self.training or not self.track_running_stats:
            # Per-channel statistics over the (N, H, W) dimensions
            mean = input.mean([0, 2, 3])
            # use biased var to normalize
            var = input.var([0, 2, 3], unbiased=False)
            if self.training and self.track_running_stats:
                n = input.numel() / input.size(1)  # number of elements per channel
                with torch.no_grad():
                    self.running_mean = exponential_average_factor * mean \
                        + (1 - exponential_average_factor) * self.running_mean
                    # update running_var with unbiased var
                    self.running_var = exponential_average_factor * var * n / (n - 1) \
                        + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        # Broadcast the per-channel (C,) statistics to (N, C, H, W)
        input = (input - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

        return input