## Motivation
1. **Internal covariate shift**: the change in the distribution of network activations caused by changes in the network parameters during training. It **slows down training by requiring lower learning rates and careful parameter initialization**.
2. The problem is worse in deep networks, where the internal covariate shift is amplified layer by layer, and when the activation functions saturate.

## Advantages of Batch Normalization
1. Mitigates internal covariate shift.
2. Allows us to use much higher learning rates.
3. Allows us to be less careful about initialization and the choice of activation functions.
4. Acts as a regularizer, in some cases eliminating the need for Dropout.

## Towards Reducing Internal Covariate Shift
1. Network training converges faster if its inputs are whitened, i.e., linearly transformed to have zero mean and unit variance, and decorrelated.

2. It would be advantageous to achieve the same whitening of the inputs of each layer.

    Consider a network computing 
$$l=F_2(F_1(u,\theta_1),\theta_2)$$
where $F_1$ and $F_2$ are arbitrary transformations, and the parameters $\theta_1$, $\theta_2$ are to be learnt so as to minimize the loss $l$. Learning $\theta_2$ can be viewed as if the inputs $x=F_1(u, \theta_1)$ are fed into the network
$$l=F_2(x,\theta_2).$$
So if fixing the input distribution works for the whole network, it should also benefit the sub-network.

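3. The optimization step must take the normalization into account; otherwise, parameter updates may be cancelled out by the normalization procedure. For example, if a layer adds a learned bias $b$ and the normalization then subtracts the mean, a gradient step that ignores the normalization changes $b$ without changing the layer's output or the loss.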

4. Full whitening of the layer inputs is costly and not everywhere differentiable, so the authors make two necessary simplifications:
    - Each scalar feature is normalized independently, so that it has zero mean and unit variance. Such normalization speeds up convergence even if the features are not decorrelated.
    - Mini-batches are used to estimate the mean and variance of each activation.

5. Simply normalizing each input of a layer may change what the layer can represent. The authors therefore introduce, for each activation, a pair of learnable parameters that scale and shift the normalized value, so that the normalizing transformation can represent the identity transform.
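
The last point can be checked numerically. The following is a minimal sketch, not the authors' code: the toy batch `x`, the constant `eps`, and all variable names are assumptions made for illustration. It normalizes each feature over the batch and then shows that one particular choice of scale and shift undoes the normalization.

```python
import torch

torch.manual_seed(0)
# Toy mini-batch (assumed for illustration): 8 examples, 3 features,
# deliberately far from zero mean and unit variance
x = torch.randn(8, 3) * 5.0 + 2.0

eps = 1e-5
mu = x.mean(dim=0)                       # per-feature mean over the batch
var = ((x - mu) ** 2).mean(dim=0)        # per-feature (biased) variance
x_hat = (x - mu) / torch.sqrt(var + eps)

# Choosing gamma = sqrt(var + eps) and beta = mu inverts the normalization,
# so the scale-and-shift step is able to represent the identity transform
gamma = torch.sqrt(var + eps)
beta = mu
y = gamma * x_hat + beta

identity_recovered = torch.allclose(y, x, atol=1e-4)
print(identity_recovered)
```

In a trained network the scale and shift are learned rather than set by hand; the sketch only shows that the identity is within reach of the parameterization.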

## Batch Normalization Algorithm

Formally, let $\mathbf{x} \in \mathcal{B}$ denote an input to batch normalization ($\mathrm{BN}$) drawn from a minibatch $\mathcal{B}$. Batch normalization transforms $\mathbf{x}$ according to the following expression:
$$\mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_\mathcal{B}}{\hat{\boldsymbol{\sigma}}_\mathcal{B}} + \boldsymbol{\beta}.$$

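where $\hat{\boldsymbol{\mu}}_\mathcal{B}$ is the sample mean and $\hat{\boldsymbol{\sigma}}_\mathcal{B}$ is the sample standard deviation of the minibatch $\mathcal{B}$, and $\boldsymbol{\gamma}$ and $\boldsymbol{\beta}$ are the learnable scale and shift parameters, which have the same shape as $\mathbf{x}$.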

We calculate $\hat{\boldsymbol{\mu}}_\mathcal{B}$ and $\hat{\boldsymbol{\sigma}}_\mathcal{B}$ as follows:
$$
\begin{aligned} \hat{\boldsymbol{\mu}}_\mathcal{B} &= \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x},\\
\hat{\boldsymbol{\sigma}}_\mathcal{B}^2 &= \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} (\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}})^2 + \epsilon.\end{aligned}
$$

Note that we add a small constant $\epsilon>0$ to the variance estimate to ensure that we never attempt division by zero, even in cases where the empirical variance estimate might vanish.
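
As a small worked example of these formulas, the sketch below applies them by hand to a two-example batch and compares the result with PyTorch's `nn.BatchNorm1d` in training mode. The input values and the choices of `gamma` and `beta` are arbitrary assumptions picked so that the mean and variance are easy to verify by hand.

```python
import torch
from torch import nn

# Two examples with two features each (values assumed for illustration)
x = torch.tensor([[1.0, 2.0],
                  [3.0, 6.0]])
eps = 1e-5

mu = x.mean(dim=0)                   # per-feature mean: [2, 4]
var = ((x - mu) ** 2).mean(dim=0)    # per-feature variance: [1, 4]
sigma = torch.sqrt(var + eps)        # epsilon is added to the variance

gamma = torch.tensor([2.0, 0.5])     # arbitrary scale, for illustration
beta = torch.tensor([0.0, 1.0])      # arbitrary shift, for illustration
y_manual = gamma * (x - mu) / sigma + beta

# Built-in layer with the same scale, shift and epsilon, in training mode
bn = nn.BatchNorm1d(2, eps=eps)
with torch.no_grad():
    bn.weight.copy_(gamma)
    bn.bias.copy_(beta)
bn.train()
y_builtin = bn(x).detach()

matches_builtin = torch.allclose(y_manual, y_builtin, atol=1e-5)
print(matches_builtin)
```

Up to the small effect of $\epsilon$, the normalized values are $\pm 1$ in both features, so the output is close to $\pm\gamma + \beta$ per feature.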

## Batch Normalization at Test Time
A batch normalization layer behaves differently at training time (normalized by minibatch statistics) and at test time (normalized by dataset statistics, in practice estimated by moving averages accumulated during training).

There are two reasons:
1. The noise in the sample mean and sample variance that comes from estimating them on minibatches is no longer desirable once the model is trained. We do not want a prediction to change just because of which batch the example happens to be in.
2. We might not have the luxury of computing per-batch normalization statistics. For example, we might need to apply the model to make one prediction at a time.
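
Both reasons can be observed with PyTorch's built-in `nn.BatchNorm1d`. The sketch below is only illustrative: the random data, the 50 warm-up batches, and the variable names are assumptions. It checks that in evaluation mode one example gets the same output whether it is processed alone or inside a batch, whereas in training mode its output depends on the other examples in the batch.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(16, 4) * 3.0 + 1.0      # toy data, assumed for illustration

# Let the layer accumulate running statistics over a few training batches
bn.train()
for _ in range(50):
    bn(torch.randn(16, 4) * 3.0 + 1.0)

sample = x[:1]                          # a single example

# Evaluation mode: running statistics are used, so batch mates do not matter
bn.eval()
with torch.no_grad():
    alone = bn(sample)
    in_batch = bn(x)[:1]
same_in_eval = torch.allclose(alone, in_batch)

# Training mode: batch statistics are used, so the output depends on the batch
bn.train()
with torch.no_grad():
    out_a = bn(torch.cat([sample, x[1:8]]))[:1]
    out_b = bn(torch.cat([sample, x[8:15]]))[:1]
differs_in_train = not torch.allclose(out_a, out_b)

print(same_in_eval, differs_in_train)
```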

## Batch Normalization for CNN
For convolutional layers, all activations in a given channel are normalized using that channel's statistics, computed over the batch and both spatial dimensions, and all activations in the channel share the same scale and shift parameters.
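
A minimal sketch of the per-channel computation, assuming a toy input of shape `(batch, channels, height, width)`; the tensor sizes and names are illustrative. It reduces over every dimension except the channel dimension and compares the result with `nn.BatchNorm2d` in training mode, whose scale and shift start at 1 and 0.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Toy feature map (assumed): batch of 8, 3 channels, 5x5 spatial size
x = torch.randn(8, 3, 5, 5) * 2.0 + 0.5
eps = 1e-5

# One mean and one variance per channel, reduced over batch, height and width
mean = x.mean(dim=(0, 2, 3), keepdim=True)                # shape (1, 3, 1, 1)
var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
y_manual = (x - mean) / torch.sqrt(var + eps)

bn = nn.BatchNorm2d(3, eps=eps)
bn.train()
y_builtin = bn(x).detach()

matches_builtin = torch.allclose(y_manual, y_builtin, atol=1e-5)
print(matches_builtin, tuple(bn.weight.shape))
```

The built-in layer holds one scale and one shift value per channel, which is why `bn.weight` has as many entries as there are channels.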

## Implementation from Scratch

In [1]:
import torch
from torch import nn

In [10]:

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # Use `is_grad_enabled` to determine whether the current mode is training
    # mode or prediction mode
    if not torch.is_grad_enabled():
        # If it is prediction mode, directly use the mean and variance
        # obtained by moving average
        X_hat = (X - moving_mean) / torch.sqrt(moving_var+eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # When using a fully-connected layer, calculate the mean and
            # variance of each feature over the batch dimension (dim=0)
            mean = X.mean(dim=0)
            var = ((X - mean)**2).mean(dim=0)
        else:
            # When using a two-dimensional convolutional layer, calculate the
            # mean and variance of each channel (dim=1) over the batch and
            # spatial dimensions. Here we keep the dimensions of `X` so that
            # the broadcasting operation can be carried out later
            mean = X.mean(dim=[0,2,3], keepdim=True)
            var = ((X-mean)**2).mean(dim=[0,2,3],keepdim=True)
        # In training mode, the current mean and variance are used for the
        # standardization
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the mean and variance using moving average
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # Scale and shift
    return Y, moving_mean.data, moving_var.data

In [11]:
class BatchNorm(nn.Module):
    # `num_features`: the number of outputs for a fully-connected layer
    # or the number of output channels for a convolutional layer. `num_dims`:
    # 2 for a fully-connected layer and 4 for a convolutional layer
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims==2:
            # this is a fully-connected layer
            shape = (1, num_features)
        else:
            # this is a convolutional layer
            shape = (1, num_features, 1, 1)
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # The variables that are not model parameters are initialized to 0 and 1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)
    
    def forward(self, X):
        # If `X` is not on the main memory, copy `moving_mean` and
        # `moving_var` to the device where `X` is located
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # Save the updated `moving_mean` and `moving_var`
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,
            eps=1e-5, momentum=0.9)
        return Y
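
The class above stores the moving statistics as plain tensor attributes and moves them to the right device by hand. One possible alternative, shown here only as a sketch, is to register them as buffers with `nn.Module.register_buffer`, so that `.to(device)` and `state_dict()` handle them automatically. The class name `RunningStats` is hypothetical and the forward pass is omitted.

```python
import torch
from torch import nn


class RunningStats(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, num_features):
        super().__init__()
        # Buffers are not model parameters (no gradients), but they move with
        # the module on `.to(device)` and are saved in `state_dict()`
        self.register_buffer("moving_mean", torch.zeros(1, num_features))
        self.register_buffer("moving_var", torch.ones(1, num_features))


stats = RunningStats(4)
buffer_names = [name for name, _ in stats.named_buffers()]
state_keys = list(stats.state_dict().keys())
num_params = len(list(stats.parameters()))
print(buffer_names, state_keys, num_params)
```

With plain attributes, as in the from-scratch class, the moving statistics are not part of `state_dict()`, which matters if the model is saved and reloaded for inference.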

## Applying Batch Normalization in LeNet

In [12]:
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4),
                    nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
                    nn.Conv2d(6, 16,
                              kernel_size=5), BatchNorm(16, num_dims=4),
                    nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
                    nn.Flatten(), nn.Linear(16 * 4 * 4, 120),
                    BatchNorm(120, num_dims=2), nn.Sigmoid(),
                    nn.Linear(120, 84), BatchNorm(84, num_dims=2),
                    nn.Sigmoid(), nn.Linear(84, 10))
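
The input size `16 * 4 * 4` of the first linear layer can be checked by tracing shapes through the convolutional part. The sketch below assumes single-channel 28x28 inputs (the native Fashion-MNIST size; the local data loader may differ) and leaves out the normalization and activation layers, since they do not change tensor shapes.

```python
import torch
from torch import nn

# Shape-changing layers of the convolutional part only
features = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten())

x = torch.randn(2, 1, 28, 28)   # assumed input: 2 grayscale 28x28 images
shapes = []
for layer in features:
    x = layer(x)
    shapes.append((layer.__class__.__name__, tuple(x.shape)))

for name, shape in shapes:
    print(f"{name:10s} -> {shape}")
```

Each 5x5 convolution without padding shrinks the spatial size by 4 and each pooling layer halves it: 28, 24, 12, 8, 4. Flattening 16 channels of size 4x4 gives 256 features.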

## Dataset

In [13]:
import sys
sys.path.append("../dlutils")
from dataset import load_fashion_mnist_dataset

batch_size = 128
train_loader, test_loader = load_fashion_mnist_dataset(batch_size=batch_size)

## Loss

In [14]:
loss = torch.nn.CrossEntropyLoss()

## Optimizer

In [15]:
optimizer = torch.optim.SGD(net.parameters(), lr = 1)

## Train

In [16]:
from train import train_3ch

num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss.to(device)
net.to(device)
train_3ch(net, loss, num_epochs, train_loader, optimizer, test_loader, device)

epoch 0, training loss 0.004905, training accuracy 0.743550, testing loss 0.005218, testing accuracy 0.731900
epoch 1, training loss 0.006141, training accuracy 0.702183, testing loss 0.006472, testing accuracy 0.697000
epoch 2, training loss 0.003736, training accuracy 0.817383, testing loss 0.004050, testing accuracy 0.805300
epoch 3, training loss 0.002724, training accuracy 0.872633, testing loss 0.003062, testing accuracy 0.859000
epoch 4, training loss 0.002630, training accuracy 0.878100, testing loss 0.003032, testing accuracy 0.865700
epoch 5, training loss 0.007594, training accuracy 0.669683, testing loss 0.008074, testing accuracy 0.658600
epoch 6, training loss 0.002788, training accuracy 0.863167, testing loss 0.003240, testing accuracy 0.848100
epoch 7, training loss 0.002742, training accuracy 0.872867, testing loss 0.003162, testing accuracy 0.860200
epoch 8, training loss 0.002711, training accuracy 0.873067, testing loss 0.003213, testing accuracy 0.860300
epoch 9, t

## Concise Implementation

In [18]:
net2 = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6),
                    nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
                    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16),
                    nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
                    nn.Flatten(), nn.Linear(256, 120), nn.BatchNorm1d(120),
                    nn.Sigmoid(), nn.Linear(120, 84), nn.BatchNorm1d(84),
                    nn.Sigmoid(), nn.Linear(84, 10))

In [19]:
optimizer = torch.optim.SGD(net2.parameters(), lr=1)
train_3ch(net2, loss, num_epochs, train_loader, optimizer, test_loader, device)

epoch 0, training loss 0.003689, training accuracy 0.827050, testing loss 0.003929, testing accuracy 0.815200
epoch 1, training loss 0.005375, training accuracy 0.735267, testing loss 0.005571, testing accuracy 0.730400
epoch 2, training loss 0.003833, training accuracy 0.802933, testing loss 0.004114, testing accuracy 0.793500
epoch 3, training loss 0.003585, training accuracy 0.824850, testing loss 0.003918, testing accuracy 0.815700
epoch 4, training loss 0.002935, training accuracy 0.858333, testing loss 0.003196, testing accuracy 0.846100
epoch 5, training loss 0.002799, training accuracy 0.864117, testing loss 0.003153, testing accuracy 0.851600
epoch 6, training loss 0.002951, training accuracy 0.846683, testing loss 0.003346, testing accuracy 0.827000
epoch 7, training loss 0.002646, training accuracy 0.870883, testing loss 0.003009, testing accuracy 0.859100
epoch 8, training loss 0.002545, training accuracy 0.876617, testing loss 0.002983, testing accuracy 0.863800
epoch 9, t