# Batch Normalization

## Build from Scratch

In [2]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import renyan_utils as ry

In [11]:
def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):
    if not is_training:
        X_hat = (X - moving_mean) / (torch.sqrt(moving_var + eps))
    else:
        assert len(X.shape) in (2, 4)
        # for fully connected layer
        if len(X.shape) == 2:
            mean = X.mean(dim = 0) # mean of each column
            var = ((X - mean) ** 2).mean(dim = 0)
        # for conv2d, we calculate mean and var for axis=1
        else:
            mean = X.mean(dim = 0, keepdim = True).mean(dim = 2, keepdim = True).mean(dim = 3, keepdim = True)
            var = ((X - mean) ** 2).mean(dim = 0, keepdim = True).mean(dim = 2, keepdim = True).mean(dim = 3, keepdim = True)
        X_hat = (X - mean) / torch.sqrt(var + eps)
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta
    return Y, moving_mean, moving_var

In [24]:
class BatchNorm(nn.Module):
    def __init__(self, num_features, num_dims):
        super(BatchNorm, self).__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.zeros(shape)
    
    def forward(self, X):
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        Y, self.moving_mean, self.moving_var = batch_norm(self.training, X, self.gamma, self.beta, self.moving_mean, self.moving_var, eps = 1e-5, momentum = 0.9)
        return Y

## BN version LeNet

In [25]:
net = nn.Sequential(nn.Conv2d(1, 6, 5),
                   BatchNorm(6, num_dims = 4),
                   nn.Sigmoid(),
                   nn.MaxPool2d(2, 2),
                   nn.Conv2d(6, 16, 5),
                   BatchNorm(16, num_dims = 4),
                   nn.Sigmoid(),
                   nn.MaxPool2d(2, 2),
                   ry.FlattenLayer(),
                   nn.Linear(16*4*4, 120),
                   BatchNorm(120, num_dims = 2),
                   nn.Sigmoid(),
                   nn.Linear(120, 84),
                   BatchNorm(84, num_dims = 2),
                   nn.Sigmoid(),
                   nn.Linear(84, 10))

In [26]:
batch_size = 256
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_iter, test_iter = ry.load_data_fashion_mnist_resize(batch_size)
lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr = lr)

In [27]:
ry.train_mnist_net(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

training on cpu
epoch 1, loss 0.9907, train acc 0.791, test acc 0.840, time 18.2 sec
epoch 2, loss 0.2242, train acc 0.867, test acc 0.838, time 19.1 sec
epoch 3, loss 0.1207, train acc 0.880, test acc 0.866, time 20.8 sec
epoch 4, loss 0.0817, train acc 0.888, test acc 0.866, time 22.5 sec
epoch 5, loss 0.0608, train acc 0.893, test acc 0.874, time 22.5 sec


In [33]:
# from torchsummary import summary
# img, lbl = next(iter(test_iter))
# summary(net, (1, 28, 28))

## With Tools

In [35]:
net = nn.Sequential(nn.Conv2d(1, 6, 5),
                   nn.BatchNorm2d(6),
                   nn.Sigmoid(),
                   nn.MaxPool2d(2, 2),
                   nn.Conv2d(6, 16, 5),
                   nn.BatchNorm2d(16),
                   nn.Sigmoid(),
                   nn.MaxPool2d(2, 2),
                   ry.FlattenLayer(),
                   nn.Linear(16*4*4, 120),
                   nn.BatchNorm1d(120),
                   nn.Sigmoid(),
                   nn.Linear(120, 84),
                   nn.BatchNorm1d(84),
                   nn.Sigmoid(),
                   nn.Linear(84, 10))

In [36]:
ry.train_mnist_net(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

training on cpu
epoch 1, loss 2.3329, train acc 0.068, test acc 0.064, time 12.9 sec
epoch 2, loss 1.1664, train acc 0.068, test acc 0.065, time 12.4 sec
epoch 3, loss 0.7776, train acc 0.068, test acc 0.065, time 13.2 sec
epoch 4, loss 0.5832, train acc 0.068, test acc 0.064, time 14.6 sec
epoch 5, loss 0.4666, train acc 0.068, test acc 0.064, time 14.7 sec
