## BatchNorm Implementation in PyTorch

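BatchNorm (batch normalization) is a technique used in deep learning to stabilize and speed up training by normalizing the inputs to a layer; it often improves generalization as well.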

The concept of BatchNorm is the same for both training and inference, but the two modes differ in which statistics are used for normalization:

### Training

**Step 1) Normalize net inputs**

- Compute the mean and variance of each feature over the entire batch:
  
  mean = (1 / N) * Σ (X_i), where N is the batch size.

  var = (1 / N) * Σ (X_i - mean)², where X_i is each individual element in the batch.

- Standardize the net input:
  
  X̂_i = (X_i - mean) / sqrt(var + ε), where ε is a small constant added for numerical stability.

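- A **running mean** and **running variance** are also maintained during training, as an exponential moving average of the batch statistics. They are needed because during inference we do not calculate the mean/variance over the batch and simply use the running mean and variance instead.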
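
As a minimal sketch of Step 1, assuming a toy batch of 4 samples with 3 features, the batch statistics, the standardization, and a single running-statistics update might look like the following. The tensor values and the names `X_hat`, `eps`, and `momentum` are illustrative choices, and the running variance here uses the 1/N variance for simplicity.

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 3) * 2.0 + 5.0   # toy batch: N=4 samples, 3 features

# Per-feature statistics over the batch dimension (dim=0)
mean = X.mean(dim=0, keepdim=True)
var = ((X - mean) ** 2).mean(dim=0, keepdim=True)   # divides by N, as in the formula

eps = 1e-5   # assumed value; any small positive constant serves the same purpose
X_hat = (X - mean) / torch.sqrt(var + eps)

# One exponential-moving-average update, starting from mean=0 and var=1
momentum = 0.1
running_mean = torch.zeros(1, 3)
running_var = torch.ones(1, 3)
running_mean = momentum * mean + (1 - momentum) * running_mean
running_var = momentum * var + (1 - momentum) * running_var

print(X_hat.mean(dim=0))                 # close to 0 for every feature
print(X_hat.var(dim=0, unbiased=False))  # close to 1 for every feature
```

With a momentum of 0.1, each update moves the running estimate only 10% of the way toward the current batch statistic, so a single unusual batch has limited influence.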

**Step 2) Scaling**

- We use learnable parameters γ (gamma) and β (beta), which are updated via backpropagation:
  
  Y_i = γ * X̂_i + β, where γ and β are trainable parameters.
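
A small sketch of Step 2, assuming hand-picked values for γ and β (in an actual layer they would typically start at 1 and 0 and then be learned): after scaling and shifting, each feature should have a mean close to β and a standard deviation close to γ.

```python
import torch

torch.manual_seed(0)
X = torch.randn(8, 2)

# Step 1: standardize each feature (1/N variance, small eps for stability)
mean = X.mean(dim=0, keepdim=True)
var = X.var(dim=0, unbiased=False, keepdim=True)
X_hat = (X - mean) / torch.sqrt(var + 1e-5)

# Step 2: per-feature scale and shift (illustrative values, not learned here)
gamma = torch.tensor([2.0, 0.5])
beta = torch.tensor([1.0, -3.0])
Y = gamma * X_hat + beta

print(Y.mean(dim=0))                 # close to beta
print(Y.std(dim=0, unbiased=False))  # close to gamma
```

This is why γ and β let the network undo the normalization if that turns out to be useful: they can restore any per-feature mean and scale.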

### Inference

**Step 1) Normalize net inputs**

- During inference, the mean and variance are replaced by the **running mean** and **running variance** that were computed during training:
  
  X̂_i = (X_i - running_mean) / sqrt(running_variance + ε)

**Step 2) Scaling**

- The scaling and shifting still happen using the learned parameters γ and β:
  
  Y_i = γ * X̂_i + β
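
One consequence of using fixed running statistics can be sketched as follows: at inference, the output for a sample does not depend on which other samples share its batch, so a single sample can be processed on its own. The values of `running_mean`, `running_var`, `gamma`, and `beta` below are made up for illustration, and `batchnorm_inference` is a hypothetical helper name.

```python
import torch

# Hypothetical statistics and parameters, standing in for values obtained during training
running_mean = torch.tensor([[0.5, -1.0, 2.0]])
running_var = torch.tensor([[4.0, 0.25, 1.0]])
gamma = torch.tensor([1.0, 2.0, 0.5])
beta = torch.tensor([0.0, 1.0, -1.0])
eps = 1e-5

def batchnorm_inference(X):
    # Normalize with the stored statistics, then scale and shift
    X_hat = (X - running_mean) / torch.sqrt(running_var + eps)
    return gamma * X_hat + beta

torch.manual_seed(0)
batch = torch.randn(6, 3)
full = batchnorm_inference(batch)         # whole batch at once
single = batchnorm_inference(batch[:1])   # first sample on its own

# The first row of the batched result matches the single-sample result
print(torch.allclose(full[:1], single))
```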


In [51]:
import torch

In [None]:
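class BatchNorm1D(torch.nn.Module):
  def __init__(self,num_features,momentum=0.1,eps=1e-5):
    super().__init__()
    self.num_features=num_features
    # Learnable scale (gamma) and shift (beta), updated via backpropagation
    self.gamma=torch.nn.Parameter(torch.ones(num_features))
    self.beta=torch.nn.Parameter(torch.zeros(num_features))
    # Buffers are not trained, but they move with .to(device) and are saved in state_dict
    self.register_buffer("running_mean",torch.zeros(1,num_features))
    self.register_buffer("running_var",torch.ones(1,num_features))
    self.momentum=momentum
    self.eps=eps  # small constant for numerical stability

  # Implement forward() rather than __call__ so that torch.nn.Module.__call__ still runs hooks
  def forward(self,X):
    if self.training:
      mean=torch.mean(X,dim=0,keepdim=True)
      # Biased variance (divide by N), as in the formula above
      var=torch.var(X,dim=0,unbiased=False,keepdim=True)
      # Update the running statistics without tracking gradients
      with torch.no_grad():
        self.running_mean=mean*self.momentum + self.running_mean*(1-self.momentum)
        # Unbiased variance (divide by N-1) for the running estimate
        unbiased_var=torch.var(X,dim=0,unbiased=True,keepdim=True)
        self.running_var=unbiased_var*self.momentum + self.running_var*(1-self.momentum)
    else:
      mean,var=self.running_mean,self.running_var
    denominator=(var + self.eps)**(1/2)
    z=(X-mean)/denominator
    return self.gamma*z+self.beta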
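
As a sanity check on the formulas, the following sketch compares a hand-written training-mode forward pass against PyTorch's built-in `torch.nn.BatchNorm1d`. It assumes the built-in defaults (`eps=1e-5`, `momentum=0.1`), and that the built-in layer normalizes with the biased (1/N) variance while updating `running_var` with the unbiased (1/(N-1)) estimate. The helper name `manual_batchnorm_train` is made up for this example.

```python
import torch

def manual_batchnorm_train(X, gamma, beta, eps=1e-5):
    # Training-mode BatchNorm over the batch dimension, using the 1/N variance
    mean = X.mean(dim=0, keepdim=True)
    var = X.var(dim=0, unbiased=False, keepdim=True)
    return gamma * (X - mean) / torch.sqrt(var + eps) + beta

torch.manual_seed(0)
X = torch.randn(10, 5)

reference = torch.nn.BatchNorm1d(5)   # weight (gamma) starts at 1, bias (beta) at 0
reference.train()
expected = reference(X)
actual = manual_batchnorm_train(X, reference.weight, reference.bias)
outputs_match = torch.allclose(actual, expected, atol=1e-5)

# Running statistics after one step, starting from mean=0 and var=1
momentum = 0.1
manual_running_mean = momentum * X.mean(dim=0)
manual_running_var = momentum * X.var(dim=0, unbiased=True) + (1 - momentum) * 1.0
mean_matches = torch.allclose(manual_running_mean, reference.running_mean, atol=1e-5)
var_matches = torch.allclose(manual_running_var, reference.running_var, atol=1e-5)

print(outputs_match, mean_matches, var_matches)
```

If all three checks agree, the hand-written version reproduces both the normalized output and the running-statistics update of the built-in layer for this input.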


In [53]:
batchnorm=BatchNorm1D(5)

In [54]:
batchnorm.train()
print(f"Before Training: Running Mean {batchnorm.running_mean}")
print(f"Before Training: Running Variance {batchnorm.running_var}\n")

batchnorm(torch.randn(10,5))


print(f"After Training: Running Mean {batchnorm.running_mean}")
print(f"After Training: Running Variance {batchnorm.running_var}\n")

batchnorm(torch.randn(10,5))


print(f"After More Training: Running Mean {batchnorm.running_mean}")
print(f"After More Training: Running Variance {batchnorm.running_var}")

Before Training: Running Mean tensor([[0., 0., 0., 0., 0.]])
Before Training: Running Variance tensor([[1., 1., 1., 1., 1.]])

After Training: Running Mean tensor([[-0.0108, -0.0365, -0.0032,  0.0061,  0.0076]])
After Training: Running Variance tensor([[1.0004, 0.9924, 0.9642, 1.0476, 0.9745]])

After More Training: Running Mean tensor([[ 0.0036, -0.0325,  0.0338, -0.0244,  0.0169]])
After More Training: Running Variance tensor([[0.9529, 0.9682, 0.9505, 0.9808, 0.9394]])


In [55]:
batchnorm.eval()
print(f"Before Inference: Running Mean {batchnorm.running_mean}")
print(f"Before Inference: Running Variance {batchnorm.running_var}\n")

batchnorm(torch.randn(10,5))


print(f"After Inference: Running Mean {batchnorm.running_mean}")
print(f"After Inference: Running Variance {batchnorm.running_var}")

Before Inference: Running Mean tensor([[ 0.0036, -0.0325,  0.0338, -0.0244,  0.0169]])
Before Inference: Running Variance tensor([[0.9529, 0.9682, 0.9505, 0.9808, 0.9394]])

After Inference: Running Mean tensor([[ 0.0036, -0.0325,  0.0338, -0.0244,  0.0169]])
After Inference: Running Variance tensor([[0.9529, 0.9682, 0.9505, 0.9808, 0.9394]])
