# Introduction

Batch normalization is a technique that improves the training of deep neural networks by normalizing the inputs to each layer using statistics computed over the current mini-batch. It helps stabilize learning, allows higher learning rates, and can reduce sensitivity to weight initialization. Below, we first apply batch normalization to a small hand-written tensor and then add it to a fully connected network trained on MNIST.

The following equations describe the process of batch normalization:

### Compute the Mean

For each feature (across the batch), the mean is calculated as:

\begin{equation}
\mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i
\end{equation}

where:
- $x_i$ is the input feature of the $i$-th example in the mini-batch.
- $m$ is the number of examples in the mini-batch.
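As a quick check of the mean formula, the sketch below computes $\mu_B$ by hand for a small mini-batch of $m = 7$ examples and 3 features (the same values used in the `nn.BatchNorm1d` example further down) and compares it with `torch.mean`. The variable names are illustrative choices.

```python
import torch

# Toy mini-batch: m = 7 examples (rows), 3 features (columns)
x = torch.tensor([[1.0, 2.0, 3.0]] + [[4.0, 5.0, 6.0]] * 5 + [[7.0, 8.0, 9.0]])

m = x.shape[0]              # number of examples in the mini-batch
mu_B = x.sum(dim=0) / m     # (1/m) * sum_i x_i, computed separately per feature
print(mu_B)                 # tensor([4., 5., 6.])

# Same result with the built-in mean over the batch dimension
assert torch.allclose(mu_B, x.mean(dim=0))
```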

### Compute the Variance

The variance for each feature (across the batch) is calculated as:

\begin{equation}
\sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2
\end{equation}
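Note that the formula divides by $m$, not $m - 1$: this is the biased variance, which is also what PyTorch's batch normalization uses to normalize in training mode. The sketch below, on the same illustrative 7-by-3 mini-batch, contrasts it with the default `torch.var`, which applies Bessel's correction and divides by $m - 1$.

```python
import torch

# Same toy mini-batch as before: 7 examples, 3 features
x = torch.tensor([[1.0, 2.0, 3.0]] + [[4.0, 5.0, 6.0]] * 5 + [[7.0, 8.0, 9.0]])
m = x.shape[0]
mu_B = x.mean(dim=0)

# Biased (1/m) variance per feature, exactly as in the equation above.
# Deviations are -3, 0, 0, 0, 0, 0, 3 in every column, so the result is 18/7.
var_B = ((x - mu_B) ** 2).sum(dim=0) / m
print(var_B)                # approx. 2.5714 (= 18/7) for every feature
assert torch.allclose(var_B, x.var(dim=0, unbiased=False))

# The default torch.var is the unbiased estimator: 18 / (m - 1) = 3.0
print(x.var(dim=0))
```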

### Normalize the Input

Each input feature is normalized using the computed mean and variance:

\begin{equation}
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}
\end{equation}

where:
- $\epsilon$ is a small constant added to the variance to prevent division by zero.
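Continuing with the same illustrative mini-batch, the sketch below applies the normalization step with $\epsilon = 10^{-5}$ (the default `eps` of `nn.BatchNorm1d`). With $\sigma_B^2 = 18/7$, the first and last rows map to $\mp 3/\sqrt{18/7} \approx \mp 1.8708$ and the middle rows to 0. Because every row of that tensor has the form $(t, t+1, t+2)$, each output of the linear layer in the `nn.BatchNorm1d` example is an affine function of $t$, which is why roughly the same magnitude, 1.8708, shows up in that example's output.

```python
import torch

# Same toy mini-batch: 7 examples, 3 features
x = torch.tensor([[1.0, 2.0, 3.0]] + [[4.0, 5.0, 6.0]] * 5 + [[7.0, 8.0, 9.0]])
eps = 1e-5                                   # default eps of nn.BatchNorm1d

mu_B = x.mean(dim=0)                         # per-feature batch mean
var_B = x.var(dim=0, unbiased=False)         # per-feature biased variance
x_hat = (x - mu_B) / torch.sqrt(var_B + eps) # normalized input

print(x_hat[:, 0])  # first row approx. -1.8708, last row approx. +1.8708, rest 0
# Each normalized feature now has mean 0 and variance just below 1 (because of eps)
print(x_hat.mean(dim=0), x_hat.var(dim=0, unbiased=False))
```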

### Scale and Shift

Finally, the normalized input is scaled and shifted using learnable parameters $\gamma$ and $\beta$:

\begin{equation}
y_i = \gamma \hat{x}_i + \beta
\end{equation}

where:
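- $\gamma$ and $\beta$ are learned during training (one pair per feature) and let the network rescale and shift the normalized values. In particular, $\gamma = \sqrt{\sigma_B^2 + \epsilon}$ and $\beta = \mu_B$ recover the original input, so normalization does not restrict what the layer can represent.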
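Putting the four steps together, the following sketch implements training-mode batch normalization by hand and checks it against `nn.BatchNorm1d`, which stores the learnable $\gamma$ and $\beta$ as `weight` and `bias`. The function name `batch_norm_train`, the random input, and the particular $\gamma$/$\beta$ values are illustrative assumptions; the running statistics that `nn.BatchNorm1d` also maintains are not modeled here.

```python
import torch
import torch.nn as nn


def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Training-mode batch norm over dim 0 of a (batch, features) tensor."""
    mu_B = x.mean(dim=0)                          # step 1: per-feature mean
    var_B = x.var(dim=0, unbiased=False)          # step 2: biased variance
    x_hat = (x - mu_B) / torch.sqrt(var_B + eps)  # step 3: normalize
    return gamma * x_hat + beta                   # step 4: scale and shift


torch.manual_seed(0)
x = torch.randn(8, 3)          # illustrative mini-batch: 8 examples, 3 features

bn = nn.BatchNorm1d(3)         # weight (gamma) starts at 1, bias (beta) at 0
with torch.no_grad():          # use non-trivial gamma/beta so the check is meaningful
    bn.weight.copy_(torch.tensor([0.5, 1.0, 2.0]))
    bn.bias.copy_(torch.tensor([0.1, -0.2, 0.3]))

bn.train()                     # training mode: normalize with batch statistics
y_ref = bn(x)
y_manual = batch_norm_train(x, bn.weight, bn.bias, eps=bn.eps)
print(torch.allclose(y_manual, y_ref, atol=1e-5))
```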

# nn.BatchNorm1d

```python
import torch
import torch.nn as nn

# Small example tensor simulating a batch of 7 examples with 3 features
input_tensor = torch.tensor([[1.0, 2.0, 3.0],
                             [4.0, 5.0, 6.0],
                             [4.0, 5.0, 6.0],
                             [4.0, 5.0, 6.0],
                             [4.0, 5.0, 6.0],
                             [4.0, 5.0, 6.0],
                             [7.0, 8.0, 9.0]])

# A minimal network: Linear -> BatchNorm1d -> ReLU
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 3)          # Linear layer
        self.bn1 = nn.BatchNorm1d(3)        # BatchNorm1d for 3 features
        self.relu = nn.ReLU()               # ReLU activation function

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)  # Apply batch normalization after the linear layer
        x = self.relu(x)
        return x

# Instantiate the network (a new module starts in training mode,
# so bn1 normalizes with the statistics of this batch)
model = SimpleNet()

# Forward pass through the network
output = model(input_tensor)

print("Input Tensor:")
print(input_tensor)
print("\nOutput Tensor after BatchNorm1d and ReLU:")
print(output)
```


```
Input Tensor:
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [4., 5., 6.],
        [4., 5., 6.],
        [4., 5., 6.],
        [4., 5., 6.],
        [7., 8., 9.]])

Output Tensor after BatchNorm1d and ReLU:
tensor([[1.8708e+00, 1.8708e+00, 1.8708e+00],
        [4.2488e-08, 0.0000e+00, 0.0000e+00],
        [4.2488e-08, 0.0000e+00, 0.0000e+00],
        [4.2488e-08, 0.0000e+00, 0.0000e+00],
        [4.2488e-08, 0.0000e+00, 0.0000e+00],
        [4.2488e-08, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00]], grad_fn=<ReluBackward0>)
```
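The example above runs in training mode, so `bn1` normalizes with the statistics of the current batch. `nn.BatchNorm1d` also keeps running estimates (`running_mean`, `running_var`), updated with `momentum=0.1` by default, and uses them instead of batch statistics after `.eval()`; this is why the MNIST demo switches modes before testing. The sketch below shows one update on the same toy tensor. Per the PyTorch documentation, the running variance is updated with the unbiased batch variance (here 3.0) even though normalization itself uses the biased one.

```python
import torch
import torch.nn as nn

# Same toy batch as above: 7 examples, 3 features
x = torch.tensor([[1.0, 2.0, 3.0]] + [[4.0, 5.0, 6.0]] * 5 + [[7.0, 8.0, 9.0]])

bn = nn.BatchNorm1d(3)   # defaults: momentum=0.1, running_mean=0, running_var=1
bn.train()               # already the default for a freshly created module
_ = bn(x)                # one training-mode forward pass updates the buffers

# running_mean = 0.9 * 0 + 0.1 * [4, 5, 6]  -> approx. [0.4, 0.5, 0.6]
# running_var  = 0.9 * 1 + 0.1 * 3.0        -> approx. 1.2 for each feature
print(bn.running_mean, bn.running_var)

bn.eval()                # eval mode: normalize with the running estimates
y_eval = bn(x)           # buffers are not updated in eval mode
# With the default gamma = 1 and beta = 0, the output is just the normalized input
expected = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
print(torch.allclose(y_eval, expected, atol=1e-5))
```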


# Demo: MNIST 

```python
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets

def create_datasets(batch_size):

    # percentage of training set to use as validation
    valid_size = 0.2

    # convert PIL images to float tensors with values in [0, 1]
    transform = transforms.ToTensor()

    # choose the training and test datasets
    train_data = datasets.MNIST(root='data',
                                train=True,
                                download=True,
                                transform=transform)

    test_data = datasets.MNIST(root='data',
                               train=False,
                               download=True,
                               transform=transform)

    # obtain training indices that will be used for validation
    num_train = len(train_data)
    indices = list(range(num_train))
    np.random.shuffle(indices)
    split = int(np.floor(valid_size * num_train))
    train_idx, valid_idx = indices[split:], indices[:split]

    # define samplers for obtaining training and validation batches
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # load training data in batches
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               sampler=train_sampler,
                                               num_workers=0)

    # load validation data in batches
    valid_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               sampler=valid_sampler,
                                               num_workers=0)

    # load test data in batches
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=batch_size,
                                              num_workers=0)

    return train_loader, test_loader, valid_loader
```
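The split sizes follow directly from the code above. MNIST has 60,000 training and 10,000 test images, so with `valid_size = 0.2` and `batch_size = 64` the arithmetic below gives 750 training batches per epoch, matching the `Step [.../750]` lines in the training log. It also shows that no loader ends with a batch of size 1, which matters because `nn.BatchNorm1d` cannot compute a batch variance from a single example in training mode.

```python
import math

num_train = 60000     # size of the MNIST training split
num_test = 10000      # size of the MNIST test split
valid_size = 0.2
batch_size = 64

split = int(math.floor(valid_size * num_train))  # 12000 validation indices
n_train = num_train - split                      # 48000 training indices

# DataLoader keeps the last partial batch by default (drop_last=False)
train_batches = math.ceil(n_train / batch_size)  # 750, all full batches
valid_batches = math.ceil(split / batch_size)    # 188, last batch has 32 images
test_batches = math.ceil(num_test / batch_size)  # 157, last batch has 16 images
print(split, n_train, train_batches, valid_batches, test_batches)
```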

```python
# Hyperparameters
batch_size = 64
learning_rate = 0.001
num_epochs = 10

train_loader, test_loader, valid_loader = create_datasets(batch_size)
```
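`create_datasets` also returns `valid_loader`, but the training cell below never uses it. One way to put it to work is a small evaluation helper such as the hypothetical `evaluate` function sketched here: it switches the model to evaluation mode (so batch normalization uses its running statistics), averages the loss over all examples, and computes accuracy. The random MNIST-shaped data and the small `nn.Sequential` model are stand-ins that only make the sketch runnable on its own.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, criterion):
    """Hypothetical helper: return (mean loss, accuracy in %) over `loader`."""
    was_training = model.training
    model.eval()                        # BatchNorm uses its running statistics
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images)
            # criterion returns the batch mean, so weight it by the batch size
            total_loss += criterion(outputs, labels).item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    model.train(was_training)           # restore the caller's mode
    return total_loss / total, 100.0 * correct / total


# Stand-in data shaped like MNIST: 100 random 1x28x28 "images", 10 classes
torch.manual_seed(0)
fake_data = TensorDataset(torch.randn(100, 1, 28, 28),
                          torch.randint(0, 10, (100,)))
loader = DataLoader(fake_data, batch_size=32)

# Small stand-in model with batch normalization
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(28 * 28, 64),
                    nn.BatchNorm1d(64),
                    nn.ReLU(),
                    nn.Linear(64, 10))

val_loss, val_acc = evaluate(net, loader, nn.CrossEntropyLoss())
print(f"validation loss {val_loss:.4f}, accuracy {val_acc:.2f}%")
```

In the training cell, a call such as `val_loss, val_acc = evaluate(model, valid_loader, criterion)` at the end of each epoch would report validation metrics; because the helper restores the previous mode, training continues unchanged afterwards.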

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Neural network with batch normalization after each hidden linear layer
class NeuralNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.BatchNorm1d(256),  # Batch Normalization
            nn.ReLU()
        )
        self.layer2 = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),  # Batch Normalization
            nn.ReLU()
        )
        self.layer3 = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),   # Batch Normalization
            nn.ReLU()
        )
        self.output_layer = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)  # Flatten the input tensor
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.output_layer(x)
        return x

# Model, Loss, and Optimizer
model = NeuralNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training Loop
for epoch in range(num_epochs):
    model.train()  # Training mode: BatchNorm uses batch statistics and updates running estimates
    for i, (images, labels) in enumerate(train_loader):
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

# Testing the Model
model.eval()  # Evaluation mode: BatchNorm uses its running estimates
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # index of the highest logit
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Accuracy of the model on the {total} test images: {100 * correct / total:.2f}%')
```


Epoch [1/10], Step [100/750], Loss: 0.4536
Epoch [1/10], Step [200/750], Loss: 0.3959
Epoch [1/10], Step [300/750], Loss: 0.1229
Epoch [1/10], Step [400/750], Loss: 0.2238
Epoch [1/10], Step [500/750], Loss: 0.1535
Epoch [1/10], Step [600/750], Loss: 0.1737
Epoch [1/10], Step [700/750], Loss: 0.2120
Epoch [2/10], Step [100/750], Loss: 0.0279
Epoch [2/10], Step [200/750], Loss: 0.1405
Epoch [2/10], Step [300/750], Loss: 0.0140
Epoch [2/10], Step [400/750], Loss: 0.0823
Epoch [2/10], Step [500/750], Loss: 0.0457
Epoch [2/10], Step [600/750], Loss: 0.0279
Epoch [2/10], Step [700/750], Loss: 0.0694
Epoch [3/10], Step [100/750], Loss: 0.0092
Epoch [3/10], Step [200/750], Loss: 0.0478
Epoch [3/10], Step [300/750], Loss: 0.0518
Epoch [3/10], Step [400/750], Loss: 0.2893
Epoch [3/10], Step [500/750], Loss: 0.0238
Epoch [3/10], Step [600/750], Loss: 0.0460
Epoch [3/10], Step [700/750], Loss: 0.1030
Epoch [4/10], Step [100/750], Loss: 0.0536
Epoch [4/10], Step [200/750], Loss: 0.0299
Epoch [4/10