## Batch Norm(alization)

> Batch Normalization is a technique to improve training time and convergence properties of neural networks (it also has other effects, such as mild regularization)

## Normalization step

> Batch Norm normalizes incoming data using `mean` and `variance` calculated __per feature, over a single batch__

$$
\hat{X} = \frac{X - \mu(X)}{\sqrt{\sigma^2(X)}}
$$

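After this operation each feature is zero-centered with a variance of 1 (over the batch), which is how layers prefer their inputs (see input data normalization). In practice a small constant $\epsilon$ is added to the variance under the square root to avoid division by zero.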
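A minimal sketch of the normalization step on made-up data, assuming a `(batch, features)` tensor and an illustrative `eps = 1e-5` (the value is an assumption). Statistics are computed per feature over the batch dimension, so afterwards every feature has mean close to 0 and variance close to 1, regardless of its original scale.

```python
import torch

torch.manual_seed(0)
# Made-up batch: 8 examples, 3 features on very different scales
X = torch.randn(8, 3) * torch.tensor([1.0, 10.0, 100.0]) + torch.tensor([0.0, 5.0, -50.0])

eps = 1e-5  # assumed small constant to avoid division by zero
mu = X.mean(dim=0, keepdim=True)                  # per-feature mean, shape (1, 3)
var = X.var(dim=0, unbiased=False, keepdim=True)  # per-feature (biased) variance, shape (1, 3)
X_hat = (X - mu) / torch.sqrt(var + eps)

print(X_hat.mean(dim=0))                 # close to 0 for every feature
print(X_hat.var(dim=0, unbiased=False))  # close to 1 for every feature
```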

## Reparametrization step

After we normalize inputs we apply __per-feature parameters that change the variance and the mean: `gamma` (scale) and `beta` (shift)__:

$$
X_f = \gamma\hat{X} + \beta
$$

Gamma and beta are both learnable parameters, learned through gradient descent, just like our model weights. Learning them lets the model choose how to spread (gamma) and shift (beta) the normalized data so that the following layers can use it best. In particular, $\gamma = \sqrt{\sigma^2(X)}$ and $\beta = \mu(X)$ would undo the normalization entirely, so the layer can still represent the identity.

> `gamma` and `beta` have the same size as the number of input features!

Let me make that even more clear so you don't overlook it...

> Gamma and beta are tensors with one element per input feature (as many elements as there are features in the input to the layer). That's because the mean and standard deviation are computed for each feature across the batch... NOT THE MEAN OR STD DEV OVER ALL FEATURES WITHIN EACH EXAMPLE. Gamma and beta are not scalars!
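To make the shapes concrete, here is a sketch on a made-up batch of 16 examples with 5 features (the `eps` value is an assumption). BatchNorm statistics reduce over the batch dimension and give one value per feature; reducing over the features of each example instead gives one value per example, which is not what BatchNorm does.

```python
import torch

torch.manual_seed(0)
X = torch.randn(16, 5)  # 16 examples, 5 features (made-up data)
eps = 1e-5              # assumed small constant

# BatchNorm statistics: one value PER FEATURE, computed over the batch (dim 0)
mu = X.mean(dim=0)                  # shape (5,)
var = X.var(dim=0, unbiased=False)  # shape (5,)

# NOT BatchNorm: one value per example, computed over its features (dim 1)
per_example_mu = X.mean(dim=1)      # shape (16,)

# gamma and beta: one learnable value per feature, broadcast over the batch
gamma = torch.ones(5)
beta = torch.zeros(5)
out = gamma * (X - mu) / torch.sqrt(var + eps) + beta

print(mu.shape, per_example_mu.shape, gamma.shape, out.shape)
```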

Here are all the steps one should take:

$$
\hat{X} = \frac{X - \mu(X)}{\sqrt{\sigma^2(X)}}
$$

$$
X_f = \gamma\hat{X} + \beta
$$
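The two steps above can be checked against PyTorch's `torch.nn.functional.batch_norm` in training mode with no running statistics. This is a sketch with arbitrary, made-up `gamma` and `beta` values; note that the normalization itself uses the biased (divide by `N`) batch variance.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(32, 4)        # batch of 32 examples, 4 features
gamma = torch.rand(4) + 0.5   # arbitrary per-feature scales (illustrative values)
beta = torch.randn(4)         # arbitrary per-feature shifts (illustrative values)
eps = 1e-5

# Step 1: normalize per feature with batch statistics
mu = X.mean(dim=0)
var = X.var(dim=0, unbiased=False)
X_hat = (X - mu) / torch.sqrt(var + eps)

# Step 2: scale and shift
manual = gamma * X_hat + beta

reference = F.batch_norm(
    X, running_mean=None, running_var=None,
    weight=gamma, bias=beta, training=True, eps=eps,
)
print(torch.allclose(manual, reference, atol=1e-5))
```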

## Train vs evaluation behaviour

> As with dropout, __the behaviour of the batch normalisation layer differs between training and evaluation__

Why?

- __We might run inference on a single sample or on small batches, for which batch statistics are unreliable (for a single sample, the variance of every feature is zero)__
- __Neural network outputs COULD change for the same sample IF it is placed in a different batch__ (we usually want deterministic outputs during inference)

To handle this, __we keep a running mean and a running variance__ gathered throughout training. Those values are used instead of the `mean` and `variance` calculated from the batch, hence (notice the missing `(X)`: $\mu$ and $\sigma^2$ are now fixed estimates, not functions of the current batch):

$$
\hat{X} = \frac{X - \mu}{\sqrt{\sigma^2}}
$$

$$
X_f = \gamma\hat{X} + \beta
$$
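A sketch of this behaviour with PyTorch's `torch.nn.BatchNorm1d`, assuming a few training-mode passes on made-up data to populate the running statistics. In eval mode the output for a sample is the same whether it is passed alone or inside a batch, and it matches the formula above with the stored `running_mean` and `running_var` (PyTorch calls gamma `weight` and beta `bias`).

```python
import torch

torch.manual_seed(0)
layer = torch.nn.BatchNorm1d(4)

# A few training-mode passes on made-up data populate running_mean / running_var
layer.train()
for _ in range(10):
    layer(torch.randn(32, 4) * 2 + 5)

layer.eval()
sample = torch.randn(1, 4)
with torch.no_grad():
    alone = layer(sample)
    in_batch = layer(torch.cat([sample, torch.randn(7, 4)]))[:1]
    # Eval mode: normalization uses the running statistics only
    manual = (sample - layer.running_mean) / torch.sqrt(layer.running_var + layer.eps)
    manual = manual * layer.weight + layer.bias

print(torch.allclose(alone, in_batch), torch.allclose(alone, manual, atol=1e-5))
```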

To obtain those values, during each training forward pass we update an exponential __running average__ with momentum $m$, given by (for timestep `i`):

$$
\mu_i = m \times \mu_{i-1} + (1 - m) \times \mu(X)
$$

$$
\sigma^2_i = m \times \sigma^2_{i-1} + (1 - m) \times \sigma^2(X)
$$
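One caveat when comparing with PyTorch: according to its documentation, `torch.nn.BatchNorm1d` defines `momentum` the other way round, as the weight of the NEW batch statistic (default `0.1`), so it corresponds to $1 - m$ in the formulas above. It also stores the unbiased batch variance in `running_var`. A sketch checking one update step on made-up data:

```python
import torch

torch.manual_seed(0)
X = torch.randn(8, 3)

# PyTorch's momentum = weight of the new statistic, i.e. m = 1 - momentum
layer = torch.nn.BatchNorm1d(3, momentum=0.1)
layer.train()
layer(X)  # one training forward pass updates the running statistics

m = 1 - layer.momentum
# Initial running_mean is 0 and running_var is 1
expected_mean = m * torch.zeros(3) + (1 - m) * X.mean(dim=0)
# running_var is updated with the UNBIASED batch variance
expected_var = m * torch.ones(3) + (1 - m) * X.var(dim=0, unbiased=True)

print(torch.allclose(layer.running_mean, expected_mean, atol=1e-6))
print(torch.allclose(layer.running_var, expected_var, atol=1e-6))
```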

> __Those values should be kept inside a model, but they shouldn't be trained!__
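In PyTorch this is what `register_buffer` is for: buffers are stored in `state_dict()` (and moved by `.to()`), but they are not returned by `.parameters()`, so an optimizer never updates them. A sketch with a hypothetical minimal module (`RunningStats` is an illustrative name, not a library class):

```python
import torch


class RunningStats(torch.nn.Module):
    # Hypothetical minimal module: one trainable parameter, one non-trained buffer
    def __init__(self, num_features):
        super().__init__()
        self.gamma = torch.nn.Parameter(torch.ones(1, num_features))
        self.register_buffer("running_mean", torch.zeros(1, num_features))


module = RunningStats(3)
param_names = [name for name, _ in module.named_parameters()]
buffer_names = [name for name, _ in module.named_buffers()]

print(param_names)                       # only the trainable parameter
print(buffer_names)                      # the buffer, invisible to the optimizer
print(list(module.state_dict().keys()))  # both are saved with the model
```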

## Things to note
- In the original paper, batch normalisation is applied after linear (or convolutional) layers, before the activation function (see Tips below for the alternative).

## Batch Normalisation in PyTorch

To initialise the batch norm layers, you need to specify the number of incoming features and that's it. This is required so that the layer can initialise the parameters gamma and beta (called `weight` and `bias` in PyTorch) with the correct shape (vectors of length = number of features), so that they are ready to use in the first forward pass.

Check out the docs [here](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)
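A quick sketch inspecting a freshly created `torch.nn.BatchNorm1d(256)`: gamma (`weight`) and beta (`bias`) are vectors of length 256, initialised so that the layer starts out as a plain normalization, and the running statistics start at mean 0 and variance 1.

```python
import torch

layer = torch.nn.BatchNorm1d(256)

# gamma is called `weight`, beta is called `bias`; one entry per feature
print(layer.weight.shape, layer.bias.shape)

# Default initialisation: gamma = 1, beta = 0
print(bool((layer.weight == 1).all()), bool((layer.bias == 0).all()))

# Running statistics start at mean 0 and variance 1
print(bool((layer.running_mean == 0).all()), bool((layer.running_var == 1).all()))
```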

Below, we implement a simple neural network with batchnorm layers.

```python
import torch

class MyNetwork(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(784, 256),
            torch.nn.BatchNorm1d(256),  # one gamma/beta pair per feature (256)
            torch.nn.ReLU(),
            torch.nn.Linear(256, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 10),
            # Softmax over the class dimension; drop it if you train with
            # torch.nn.CrossEntropyLoss, which expects raw logits
            torch.nn.Softmax(dim=1),
        )

    def forward(self, x):
        return self.layers(x)
```

Again, make sure your model and its layers are in the right mode by using your model's `.train()` and `.eval()` methods.
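Forgetting `.eval()` is easy to notice with a single sample: in training mode `torch.nn.BatchNorm1d` needs batch statistics and raises a `ValueError` for a batch of one, while in eval mode it uses the running statistics and works. A small sketch:

```python
import torch

layer = torch.nn.BatchNorm1d(4)
single = torch.randn(1, 4)

layer.train()
raised = False
try:
    layer(single)
except ValueError as err:
    # Training mode needs batch statistics, which a single sample cannot provide
    raised = True
    print("train mode:", err)

layer.eval()
out = layer(single)  # eval mode uses running statistics, so one sample is fine
print("eval mode:", out.shape)
```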

## Exercise

> Create `BatchNorm1d` layer!

- `__init__`:
    - Save all `__init__` arguments as attributes
    - In `__init__` create `gamma` and `beta` as `torch.nn.Parameter` instances
    - `gamma` should be ones and `beta` should be zeros, both of shape `(1, num_features)` (so that initially the layer outputs the normalized data unchanged: mean 0, variance 1)
    - Running mean and running variance buffers are created for you (we will talk about them after the exercise)
- `forward`:
    - `if not self.training` (evaluation phase): subtract `self.running_mean` and divide by the square root of `self.running_var + self.eps`, save the result as `X_hat`
    - For training do the following:
        - Calculate mean and variance __across the batch (0) dimension__ (you have to unsqueeze them afterwards or use `keepdim=True`)
        - Update `self.running_mean` with the formula given above (essentially a running average), without tracking gradients
        - Do the same for variance
        - `X_hat` is now `X` minus the mean (__of this batch!__), divided by the square root of the variance plus `self.eps` (__of this batch!__) (look at the formulas if you are not sure!)
    - `return` `X_hat` multiplied by `self.gamma` and with `self.beta` added

```python
import torch


class BatchNorm1d(torch.nn.Module):
    def __init__(self, num_features, momentum: float = 0.9, eps: float = 1e-7):
        super().__init__()
        self.num_features = num_features

        self.gamma = torch.nn.Parameter(torch.ones(1, self.num_features))
        self.beta = torch.nn.Parameter(torch.zeros(1, self.num_features))

        # You can use self.running_mean and self.running_var right now
        # (running mean starts at 0, running variance at 1)
        self.register_buffer("running_mean", torch.zeros(1, self.num_features))
        self.register_buffer("running_var", torch.ones(1, self.num_features))

        self.momentum = momentum
        self.eps = eps

    def forward(self, X):
        if not self.training:
            # Evaluation: normalize with the running statistics gathered during training
            X_hat = (X - self.running_mean) / torch.sqrt(self.running_var + self.eps)
        else:
            # Training: per-feature statistics over the batch (0) dimension
            mean = X.mean(dim=0, keepdim=True)
            var = ((X - mean) ** 2).mean(dim=0, keepdim=True)

            # Update running mean and variance (buffers are not trained, so no gradient tracking)
            with torch.no_grad():
                self.running_mean *= self.momentum
                self.running_mean += (1 - self.momentum) * mean

                self.running_var *= self.momentum
                self.running_var += (1 - self.momentum) * var

            X_hat = (X - mean) / torch.sqrt(var + self.eps)

        return X_hat * self.gamma + self.beta
```

```python
input_data = torch.randn(5, 5)
layer = BatchNorm1d(5)

output = layer(input_data)
# In training mode every feature (column) of the output should have
# mean close to 0 and (biased) variance close to 1
output.mean(dim=0), output.var(dim=0, unbiased=False)
```

# Why does BatchNorm work?

## Internal Covariate Shift

In the original paper (link [here](https://arxiv.org/abs/1502.03167)), the authors claimed that `BatchNorm` improves neural network training by reducing so-called __Internal Covariate Shift__ (ICS).

### What is it?

- The change in the distribution of each layer's inputs during training, caused by updates to the weights of the preceding layers

Imagine we have layers `a->b->c->d->e`:

- If we change weights in `b`, the inputs of `c`, `d` and `e` change as well
- We are constantly changing weights

Neural networks are trained with gradients obtained from `backpropagation`, hence:

- `e` layer gradients are computed based on the __current__ values of `a->b->c->d`
- `d` layer gradients are computed based on the __current__ values of `a->b->c`

But all layers are updated in the same step, so, except for `a`, every layer is optimized using information about its inputs __from the previous step__.

> `BatchNormalization` decouples most important statistics and interactions between layers, making the optimization "more on point"

### Debunked?

To this day, people are not sure if reducing ICS is the main effect.

> If reducing ICS were all that mattered, it would be enough to normalize the data after each layer (without `gamma` and `beta`)!

Some experiments show that ICS is not reduced __at all__ (unlikely), while others point to a reduction (much more likely).

## Moving away from saturation

- Some activations saturate (e.g. `sigmoid`): their gradients become tiny for inputs far from zero
- If we normalize the data, those activations usually move away from the saturation regime
- Once again, the motivation for `beta` and `gamma` is not clear in this case
- Drastic help also with non-saturating functions like `ReLU` (though the output is, once again, zero-centered)
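A small numerical sketch of the saturation argument, assuming made-up pre-activations drawn from $\mathcal{N}(20, 10^2)$ (an arbitrary choice to push most values into saturation). The average sigmoid derivative is tiny before normalization and much larger after it, which means larger gradients flow back through the activation.

```python
import torch

torch.manual_seed(0)
# Made-up pre-activations with a large mean and spread (assumption for illustration)
pre_activations = torch.randn(10_000, 1) * 10 + 20


def mean_sigmoid_grad(x):
    # Derivative of the sigmoid: s * (1 - s), at most 0.25 (at x = 0)
    s = torch.sigmoid(x)
    return (s * (1 - s)).mean().item()


# Per-feature normalization over the batch dimension (no gamma/beta)
mu = pre_activations.mean(dim=0)
std = pre_activations.std(dim=0, unbiased=False)
normalized = (pre_activations - mu) / std

grad_before = mean_sigmoid_grad(pre_activations)
grad_after = mean_sigmoid_grad(normalized)
print(f"mean sigmoid'(x) before: {grad_before:.4f}, after: {grad_after:.4f}")
```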


## Easier control of statistics

Removing ICS alone could be done just by normalizing each layer's inputs, __but we also have `gamma` and `beta` parameters__.

> The argument is that they make it easier for the neural network to __control the first and second moments of the output distribution (mean and variance)__

### Thought experiment

> Let's see how many parameters the neural network would need to control in order to change the `mean` and `variance` of the outputs of a `nn.Linear(100, 10)` layer:

- Each output neuron has `100` inputs, hence `100` weights
- There are `10` output neurons
- `1000` weights in total (plus `10` biases) to control __BOTH MEAN AND VARIANCE__ simultaneously

> What happens when we add a BatchNorm layer?

Now it is guaranteed that the normalized output will have `mean=0` and `std=1` (we can think of it as a "checkpointing system" for `mean` and `std`).

- To change the mean of the outputs we only need to change the `10` values of `beta` (and the `10` values of `gamma` for the variance; __the two are independent__)
- That __should be__ `100x` easier to control (given that first-order optimization methods are not the smartest and have no information about function curvature)
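The counts in this thought experiment can be checked directly, assuming PyTorch's layers (where gamma is stored as `weight` and beta as `bias`):

```python
import torch

linear = torch.nn.Linear(100, 10)
batch_norm = torch.nn.BatchNorm1d(10)

# Linear(100, 10): 100 * 10 weights + 10 biases
linear_params = sum(p.numel() for p in linear.parameters())
# BatchNorm1d(10): gamma is stored as `weight`, beta as `bias`
gamma_params = batch_norm.weight.numel()
beta_params = batch_norm.bias.numel()

print(linear_params, gamma_params, beta_params)  # 1010 10 10
```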

## Smoothing out optimization landscape

Due to the above and the normalization itself, `BatchNorm` smooths the loss landscape:
- In general, the loss landscape of a neural network can be sharp and have many flat and problematic regions
- This leads to large/minuscule gradients, which cause __vanishing gradients (dying gradients) or exploding gradients (the inverse phenomenon)__
- __`BatchNorm` makes the loss change at a smaller rate, hence the gradient is also more "stable"__ (the so-called Lipschitzness of the function)

Due to the above:
- We can use larger learning rates
- We are less dependent on the choice of hyperparameters
- We are less dependent on the initialization of neural network weights
- __Training is faster (fewer epochs)__

### First order optimization

`SGD` is blind to second-order information such as the Hessian (the matrix of second derivatives, i.e. derivatives of derivatives).

> What does the Hessian tell the optimization algorithm?

__How the function curves__:

![](./images/curvature.png)

- negative Hessian values: the function curves down, hence you will minimize the function even faster
- zero Hessian values: the function is locally linear, so you keep going downhill at the same rate, which is also good for optimization
- positive Hessian values: the function curves up, hence you will reach the minimum shortly and any further step will __increase the loss value__

> Those values define how much you can learn from first order methods

> BatchNorm simplifies this landscape implicitly, as the second-order derivatives will be closer to zero much of the time (hence gradient steps are safer)

> __One could use those values to approximate how much a certain step will reduce the loss function__

> __It is hard/infeasible to calculate and use this information for neural networks due to the sheer number of parameters and hardware constraints__

__All of this holds only locally (near the current point), but it is still informative!__

- At the start of training, first-order terms play the major role, hence we quickly move downhill
- As the network gets better at the task, further optimization becomes progressively harder and the second-order curvature becomes more important than the first-order term
- __Unlucky step__: if you minimize according to the gradient with some step size, you may end up in a positive curvature regime and actually increase your loss instead of decreasing it!
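The unlucky step can be reproduced on a hypothetical one-dimensional loss $f(x) = 2x^2$ (constant curvature $f''(x) = 4$, chosen only for illustration). For a gradient step $x \to x - \eta g$, the second-order Taylor estimate of the loss change is $-\eta g^2 + \tfrac{1}{2}\eta^2 h g^2$, where $g$ is the gradient and $h$ the curvature; for a quadratic this estimate is exact. The sketch below computes both terms with `torch.autograd.grad` and compares them with the actual change.

```python
import torch


def loss_fn(x):
    # Hypothetical 1D loss with constant positive curvature: f(x) = 2x^2, f''(x) = 4
    return 2 * x ** 2


def step_analysis(x0, lr):
    x = torch.tensor(x0, requires_grad=True)
    loss = loss_fn(x)
    (grad,) = torch.autograd.grad(loss, x, create_graph=True)  # first derivative
    (hess,) = torch.autograd.grad(grad, x)                     # second derivative
    g, h = grad.item(), hess.item()

    # Second-order Taylor estimate of the loss change after x -> x - lr * g
    first_order = -lr * g ** 2
    second_order = 0.5 * lr ** 2 * h * g ** 2
    with torch.no_grad():
        actual = (loss_fn(x - lr * grad) - loss).item()
    return first_order, second_order, actual


for lr in (0.1, 0.6):
    first, second, actual = step_analysis(1.0, lr)
    print(f"lr={lr}: first-order {first:+.2f}, curvature {second:+.2f}, actual {actual:+.2f}")
```

With `lr=0.1` the loss decreases; with `lr=0.6` the first-order term alone predicts an even larger decrease, but the positive curvature term dominates and the loss actually increases.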

## BatchNormalization as a regularization

- It introduces randomness (noise) into the data, as statistics are calculated on a per-batch basis (the mean adds noise, the variance multiplies it)
- __Statistics are batch dependent__: if the batch size is too small, they __might be unreliable__ (see the challenges for other normalization schemes that mitigate this problem)
- Noise robustness: works a little like Dropout (but is less severe)
- This forces the neural network to learn different features (as the ones it depended on might be too noisy to use for some batches)
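A sketch of where this noise comes from: in training mode, the output for the same sample depends on which (made-up) batch it happens to be in, because the batch statistics change.

```python
import torch

torch.manual_seed(0)
layer = torch.nn.BatchNorm1d(4)
layer.train()  # training mode: statistics come from the current batch

sample = torch.randn(1, 4)
# The same sample placed in two different made-up batches
batch_a = torch.cat([sample, torch.randn(7, 4)])
batch_b = torch.cat([sample, torch.randn(7, 4) * 3 + 1])

with torch.no_grad():
    out_a = layer(batch_a)[0]
    out_b = layer(batch_b)[0]

print(out_a)
print(out_b)  # differs from out_a: the rest of the batch acts as noise on this sample
```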

## Tips

> Where to apply?

- Originally applied before the activation
- Some report that it works better __after__ the activation (as the inputs to the next learnable linear layer are then controlled and unaffected by the activation)
- The difference isn't drastic
- __Some people apply BatchNormalization both before and after the activation!__

> What about `Dropout`?

- Dropout randomizes the statistics gathered by `BatchNorm`
- An alternative, `AlphaDropout`, preserves the mean and standard deviation of zero-mean, unit-variance inputs, so combining them is feasible
- __We rarely use ANY `Dropout` in tandem with `BatchNorm`!__ (it might actually make your scores worse!)
- In general: BatchNorm for convolutional and linear layers; __if you need more regularization, apply dropout on `linear` layers ONLY__
- Opinions conflict, but this is the common practice (experiment for your own task)

```python
import torch

before_block = torch.nn.Sequential(
    torch.nn.Linear(10, 10), torch.nn.BatchNorm1d(10), torch.nn.ReLU()
)

after_block = torch.nn.Sequential(
    torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.BatchNorm1d(10)
)
```
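A sketch of what the next layer would see from each block, on made-up data in training mode (the default for a new module). With BatchNorm before ReLU the outputs are non-negative, so their mean is positive; with BatchNorm after ReLU each output feature is normalized to mean close to 0 and variance close to 1.

```python
import torch

torch.manual_seed(0)
before_block = torch.nn.Sequential(
    torch.nn.Linear(10, 10), torch.nn.BatchNorm1d(10), torch.nn.ReLU()
)
after_block = torch.nn.Sequential(
    torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.BatchNorm1d(10)
)

X = torch.randn(256, 10)  # made-up batch
with torch.no_grad():
    out_before = before_block(X)
    out_after = after_block(X)

# BN before ReLU: outputs are non-negative (positive mean per feature)
print(out_before.min().item() >= 0, out_before.mean(dim=0))
# BN after ReLU: outputs are normalized per feature
print(out_after.mean(dim=0), out_after.var(dim=0, unbiased=False))
```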

## Summary

- BatchNormalization is a technique which improves training time and convergence
- You normalize your batch during training
- You multiply the output by `gamma` and add `beta`, which are parameters controlled by the neural network (so it can learn the `mean` and `variance` more easily)
- You gather running mean and running variance statistics as you go and use them at test time, because:
    - Every example will be classified independently of the batch it is in (good deterministic behaviour)
    - It allows you to use a single example during inference
- BatchNormalization needs many examples in the batch: the more there are, the more reliable the statistics
- Exact reasons why it works are unclear but include:
    - Internal Covariate Shift
    - Easier control over first/second order moments (mean and variance)
    - Simplified loss landscape by reducing second order interactions during forward and backward
- It is partially a regularization technique due to the noise it injects
- It probably will not work best with Dropout
- It can be used before or after the activation function (after seems to work slightly better); the difference is rarely drastic

## Challenges

- Try BatchNorm in the previous notebook (dying/vanishing gradients). What did you observe?

__Remember to check those challenges AFTER you know what the convolution is and how it plays with BatchNorm!__

- What is [Instance Normalization](https://stackoverflow.com/questions/45463778/instance-normalisation-vs-batch-normalisation)?
- What is [Layer Normalization](https://leimao.github.io/blog/Layer-Normalization/)?
- What is [Group Normalization](https://towardsdatascience.com/what-is-group-normalization-45fe27307be7)?