In [1]:
import torch
import warnings

warnings.filterwarnings("ignore")

### Batch Normalization

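- Helps train deep neural networks reliably (and fast) because the distribution of activations passed from one layer to the next stays roughly the same, so later layers spend no effort adapting to a shifting input distribution
- The batch mean and standard deviation are differentiable functions of the inputs, so normalizing with them can be placed inside the network and trained with backpropagation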
- We want the pre-activation values to be roughly Gaussian only at initialization
- We don't want the pre-activations to be forced to be Gaussian always; we would like the neural net to be able to move these distributions around: make them more diffuse or sharper, and make some neurons more or less trigger-happy
- We want the backpropagation algorithm to tell us how the distribution should move around
- This is achieved by normalizing with the batch statistics, $\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$, and then applying a learnable scale and shift, $y = \gamma \hat{x} + \beta$
- The examples in a batch are coupled mathematically in the forward and backward pass of a neural net
- Hence the pre-activations and the logits for any one example are functions of all the examples in the batch
- The pre-activation values change slightly depending on which other examples are present in the batch, so they jitter
- This jittering has a regularization effect
- Stabilizes training
- Allows for suboptimal initializations
- Evens out the slope of the loss function, resulting in faster optimization

In [10]:
class BatchNorm1d:
    """Batch normalization over the batch dimension (dim 0), modelled on ``torch.nn.BatchNorm1d``.

    In training mode (``self.training = True``) each feature is normalized with the
    mean and variance of the current batch, and running estimates of those
    statistics are updated with an exponential moving average. In eval mode
    (``self.training = False``) the running estimates are used instead, so the
    output for an example no longer depends on the other examples in the batch.

    Args:
        dim: number of features; ``gamma``, ``beta`` and the running buffers are created with this size.
        momentum: weight given to the current batch statistics in the running-average update.
        eps: added to the variance before the square root, for numerical stability.
    """

    def __init__(self, dim, momentum=0.1, eps=1e-5):
        self.eps = eps
        self.momentum = momentum
        self.training = True
        # params trained with backprop: scale (gamma) and shift (beta); at init the layer
        # outputs the plain normalized values.
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)
        # buffers (trained with a running 'momentum' update, not with backprop)
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self, x):
        """Normalize ``x`` feature-wise, then scale by ``gamma`` and shift by ``beta``.

        Assumes ``x`` is ``(batch, dim)``: statistics are reduced over dim 0 and
        ``gamma``/``beta`` broadcast against the last dimension. The result is also
        stored on ``self.out`` (presumably for later inspection of activations).

        Note: because the batch statistics keep the reduced dimension, the running
        buffers take shape ``(1, dim)`` after the first training-mode call.

        Raises:
            ValueError: in training mode when the batch holds fewer than 2 examples;
                the unbiased variance of a single sample is NaN and would permanently
                corrupt ``running_var``.
        """
        if self.training:
            if x.shape[0] < 2:
                raise ValueError(
                    f"BatchNorm1d needs more than 1 example per batch in training mode, "
                    f"got a batch of size {x.shape[0]}"
                )
            xmean = x.mean(0, keepdim=True)  # batch mean
            # batch variance; torch's default is the unbiased estimator (Bessel's correction)
            xvar = x.var(0, keepdim=True)
        else:
            xmean = self.running_mean
            xvar = self.running_var
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        self.out = self.gamma * xhat + self.beta
        if self.training:
            # The running buffers are statistics, not learned parameters. Updating them
            # outside autograd keeps them from carrying a computation graph (and its
            # memory) across every training step, and keeps eval-mode outputs from
            # depending on graphs that earlier backward passes have already freed.
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
        return self.out

    def parameters(self):
        """Return the tensors trained with backprop: ``(gamma, beta)``."""
        return (self.gamma, self.beta)