# Batch Normalization

## Build from Scratch

In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import renyan_utils as ry

In [13]:
def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Apply batch normalization to ``X`` and update the running statistics.

    In inference mode (``is_training`` is false) ``X`` is normalized with the
    supplied ``moving_mean`` / ``moving_var`` and those are returned unchanged.

    In training mode the statistics are computed from ``X`` itself:

    * 2-D input: mean/variance are taken over axis 0 (one value per column).
    * 4-D input: mean/variance are taken over axes 0, 2 and 3, leaving one
      value per index of axis 1, kept with shape ``(1, C, 1, 1)`` so that it
      broadcasts against ``X``.

    The variance is the biased (divide-by-N) estimate.

    Args:
        is_training: Selects batch statistics (truthy) or moving statistics.
        X: Input tensor; must be 2-D or 4-D when ``is_training`` is truthy.
        gamma: Scale applied to the normalized input (must broadcast with X).
        beta: Shift added after scaling (must broadcast with X).
        moving_mean: Running mean used for inference and updated in training.
        moving_var: Running variance used for inference and updated in training.
        eps: Small constant added to the variance before the square root.
        momentum: Weight given to the previous moving statistics in the update.

    Returns:
        A tuple ``(Y, moving_mean, moving_var)`` with the normalized, scaled
        and shifted output plus the (possibly updated) moving statistics.

    Raises:
        AssertionError: In training mode, if ``X`` is neither 2-D nor 4-D.
    """
    if not is_training:
        # Inference: rely on the statistics accumulated during training.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # Fully connected layer: one statistic per column.
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # Conv2d layer: reduce over every axis except axis 1. Both mean
            # and variance must use the same axes (0, 2, 3); keepdim=True
            # keeps them broadcastable against X.
            reduce_dims = (0, 2, 3)
            mean = X.mean(dim=reduce_dims, keepdim=True)
            var = ((X - mean) ** 2).mean(dim=reduce_dims, keepdim=True)
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Exponential moving average of the batch statistics.
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta
    return Y, moving_mean, moving_var