This notebook compares the performance of a neural network that uses Batch Normalization to one that does not.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Preprocessing applied to every MNIST image: convert to a tensor, then
# normalise the single channel with mean 0.5 and std 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split (download=True, stored under ./data), served in
# shuffled mini-batches of 64
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Define a simple neural network
class Net(nn.Module):
    """Three-layer fully connected classifier (784 -> 512 -> 256 -> 10).

    When ``use_batch_norm`` is true, a ``BatchNorm1d`` layer follows the ReLU
    of each of the two hidden layers; otherwise the network is a plain MLP.
    """

    def __init__(self, use_batch_norm=False):
        super().__init__()
        self.use_batch_norm = use_batch_norm

        # Fully connected stack
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

        # The normalisation layers only exist when they were requested
        if use_batch_norm:
            self.bn1 = nn.BatchNorm1d(512)
            self.bn2 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return the 10 output logits."""
        hidden = x.view(-1, 28 * 28)

        # Each hidden stage is Linear -> ReLU, optionally followed by its BatchNorm
        for linear, bn_name in ((self.fc1, "bn1"), (self.fc2, "bn2")):
            hidden = torch.relu(linear(hidden))
            if self.use_batch_norm:
                hidden = getattr(self, bn_name)(hidden)

        return self.fc3(hidden)

# Function to train the model
def train_model(model, criterion, optimizer, epochs=3, dataloader=None):
    """Train ``model`` batch by batch and print the running loss.

    Args:
        model: Module to train; it is switched to training mode first.
        criterion: Callable returning a scalar loss tensor from
            ``(outputs, labels)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once
            per batch.
        epochs: Number of full passes over the data.
        dataloader: Iterable of ``(inputs, labels)`` batches. When omitted,
            the notebook-level ``trainloader`` is used.

    Every 100 batches the mean loss of those batches is printed as
    ``[epoch, batch] loss: value``. Nothing is returned.
    """
    if dataloader is None:
        dataloader = trainloader

    # Force training mode: a model left in eval mode (e.g. after an
    # evaluation pass) would otherwise train with frozen BatchNorm statistics.
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(dataloader):
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:
                # Report the average over the last 100 batches, then reset
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 100))
                running_loss = 0.0

# Re-seed before each run so both models are built and trained from the same
# RNG state; without this the comparison is neither fair nor reproducible.
SEED = 0

# Training without batch normalization
torch.manual_seed(SEED)
model_no_bn = Net(use_batch_norm=False)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_no_bn.parameters(), lr=0.01, momentum=0.9)

print("Training without Batch Normalization:")
train_model(model_no_bn, criterion, optimizer)

# Training with batch normalization (same seed, same optimizer settings)
torch.manual_seed(SEED)
model_with_bn = Net(use_batch_norm=True)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_with_bn.parameters(), lr=0.01, momentum=0.9)

print("\nTraining with Batch Normalization:")
train_model(model_with_bn, criterion, optimizer)

Training without Batch Normalization:
[1,   100] loss: 1.166
[1,   200] loss: 0.446
[1,   300] loss: 0.361
[1,   400] loss: 0.323
[1,   500] loss: 0.288
[1,   600] loss: 0.267
[1,   700] loss: 0.231
[1,   800] loss: 0.246
[1,   900] loss: 0.210
[2,   100] loss: 0.200
[2,   200] loss: 0.174
[2,   300] loss: 0.171
[2,   400] loss: 0.157
[2,   500] loss: 0.155
[2,   600] loss: 0.145
[2,   700] loss: 0.135
[2,   800] loss: 0.146
[2,   900] loss: 0.145
[3,   100] loss: 0.132
[3,   200] loss: 0.127
[3,   300] loss: 0.109
[3,   400] loss: 0.117
[3,   500] loss: 0.110
[3,   600] loss: 0.106
[3,   700] loss: 0.099
[3,   800] loss: 0.094
[3,   900] loss: 0.095

Training with Batch Normalization:
[1,   100] loss: 0.464
[1,   200] loss: 0.253
[1,   300] loss: 0.219
[1,   400] loss: 0.172
[1,   500] loss: 0.183
[1,   600] loss: 0.152
[1,   700] loss: 0.161
[1,   800] loss: 0.149
[1,   900] loss: 0.143
[2,   100] loss: 0.118
[2,   200] loss: 0.106
[2,   300] loss: 0.098
[2,   400] loss: 0.102
[2,   

In [None]:
# Same preprocessing as for training: tensor conversion, then normalisation
# with mean 0.5 and std 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST test split (train=False), read in order in batches of 64
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# Function to calculate accuracy
def get_accuracy(model, dataloader):
    """Return the fraction of samples in ``dataloader`` that ``model`` gets right.

    The model is put in eval mode for the duration of the call, so BatchNorm
    layers use their running statistics and do not update them from the
    evaluation data. The previous train/eval mode is restored on exit.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        dataloader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        ``correct / total`` as a float.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    was_training = model.training
    model.eval()

    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels in dataloader:
                outputs = model(inputs)
                # Predicted class = index of the largest score in each row
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Hand the model back in the mode the caller left it in
        model.train(was_training)

    return correct / total

# Evaluate both trained models on the test loader, then report the results
acc_no_bn = get_accuracy(model_no_bn, testloader)
acc_with_bn = get_accuracy(model_with_bn, testloader)

print("\nAccuracy without Batch Normalization: {:.2%}".format(acc_no_bn))
print("Accuracy with Batch Normalization: {:.2%}".format(acc_with_bn))



Accuracy without Batch Normalization: 96.77%
Accuracy with Batch Normalization: 97.33%


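Conclusion: with Batch Normalization the training loss falls faster (0.464 vs. 1.166 after the first 100 batches), and test accuracy is slightly higher (97.33% vs. 96.77%).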