# In [None]:
# MNIST classification with Batch Normalization and various model architectures

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as T
import torchvision.datasets as datasets

from torch.autograd import Variable
from torch.utils import data
from torch.utils.data import DataLoader

import seaborn as sn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import itertools

# Preprocessing pipeline: convert each image to a tensor, then normalize with
# mean 0.5 and std 0.5 so that pixel values end up in the range [-1, 1].
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST training and test splits, stored under ./data (download requested).
trainset = datasets.MNIST('./data', train=True, transform=transform, download=True)
testset = datasets.MNIST('./data', train=False, transform=transform, download=True)

# Mini-batch loaders: the training split is shuffled, the test split is not.
trainloader = DataLoader(dataset=trainset, batch_size=32, shuffle=True)
testloader = DataLoader(dataset=testset, batch_size=32, shuffle=False)

# Define Trainer class to handle training and testing
class Trainer():
    """Train a classifier and report its accuracy on a test loader.

    The constructor stores its arguments unchanged:

    * ``trainloader`` / ``testloader`` -- iterables yielding ``(inputs, labels)``
      batches; ``testloader.dataset`` must support ``len()``.
    * ``net`` -- the ``nn.Module`` being trained.
    * ``optimizer`` -- optimizer stepping ``net``'s parameters.
    * ``criterion`` -- callable ``criterion(output, labels)`` returning the loss.

    Batches are moved to whatever device ``net``'s parameters live on, so the
    same code works for a model on the GPU or on the CPU.
    """

    def __init__(self, trainloader, testloader, net, optimizer, criterion):
        self.trainloader = trainloader
        self.testloader = testloader
        self.net = net
        self.optimizer = optimizer
        self.criterion = criterion

    def _device(self):
        """Return the device of the network's first parameter (CPU if it has none)."""
        first_param = next(self.net.parameters(), None)
        return torch.device('cpu') if first_param is None else first_param.device

    def train(self, epoch=100):
        """Train the network for ``epoch`` passes over ``trainloader``.

        Every 500 batches (including the first batch of each epoch) the mean
        loss since the previous report is printed and ``test()`` is run.
        """
        device = self._device()
        self.net.train()
        for e in range(epoch):
            running_loss = 0.0
            batches_since_report = 0
            for i, batch in enumerate(self.trainloader, 0):
                inputs, labels = batch[0].to(device), batch[1].to(device)
                self.optimizer.zero_grad()             # Clear previous gradients
                output = self.net(inputs)              # Forward pass
                loss = self.criterion(output, labels)  # Compute loss
                loss.backward()                        # Backpropagation
                self.optimizer.step()                  # Update parameters
                running_loss += loss.item()
                batches_since_report += 1
                if i % 500 == 0:
                    # Average over the batches actually accumulated; at i == 0
                    # that is a single batch, not 500.
                    print('[%d, %5d] loss: %.3f' %
                          (e + 1, i + 1, running_loss / batches_since_report))
                    running_loss = 0.0
                    batches_since_report = 0
                    # test() restores training mode before returning, so the
                    # following batches still train with batch statistics.
                    self.test()
        print('Finished Training')

    def test(self):
        """Print the network's accuracy over ``testloader``.

        The network is put in eval mode for the duration of the pass and its
        previous train/eval mode is restored afterwards, even if evaluation
        raises. Gradients are not tracked during evaluation.
        """
        device = self._device()
        was_training = self.net.training
        self.net.eval()
        correct = 0
        try:
            with torch.no_grad():
                for inputs, labels in self.testloader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    output = self.net(inputs)
                    pred = output.max(1, keepdim=True)[1]  # Predicted label
                    correct += pred.eq(labels.view_as(pred)).sum().item()
        finally:
            self.net.train(was_training)
        total = len(self.testloader.dataset)
        # Guard against an empty dataset instead of dividing by zero.
        accuracy = 100. * correct / total if total else 0.0
        print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
            correct, total, accuracy))



# --------- Model 1: 2-Layer Fully Connected Network with BatchNorm ---------

class MNIST_Net(nn.Module):
    """Two-layer fully connected classifier with batch normalization.

    Pipeline: flatten to 784 features -> Linear(784, 30) -> BatchNorm1d(30)
    -> ReLU -> Linear(30, 10). The output holds one raw score per class.
    """

    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(in_features=28 * 28, out_features=30)  # input -> hidden
        self.bn0 = nn.BatchNorm1d(num_features=30)                  # normalizes hidden units
        self.fc1 = nn.Linear(in_features=30, out_features=10)       # hidden -> 10 classes
        self.act = nn.ReLU()

    def forward(self, x):
        """Return class scores of shape (N, 10) for input reshaped to (N, 784)."""
        flat = x.view(-1, 28 * 28)
        hidden = self.act(self.bn0(self.fc0(flat)))
        return self.fc1(hidden)

# Model 1 setup: cross-entropy loss, network placed on the GPU, Adam with lr = 1e-3.
criterion = nn.CrossEntropyLoss()
mnist_net = MNIST_Net().cuda()
optimizer = optim.Adam(params=mnist_net.parameters(), lr=1e-3)

# Trainer arguments in positional order:
# trainloader, testloader, net, optimizer, criterion.
trainer = Trainer(trainloader, testloader, mnist_net, optimizer, criterion)

# Train for four epochs, then report the final test accuracy.
trainer.train(4)
trainer.test()

# Count the number of trainable parameters in the model
def count_parameters(model):
    """Return the total element count of the model's parameters that require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# A bare expression is only echoed when it is the last statement of a notebook
# cell; print the count explicitly so it is always shown.
print('Trainable parameters:', count_parameters(mnist_net))



# --------- Model 2: Conv + Fully Connected + BatchNorm ---------

class MNIST_Net(nn.Module):
    """Single-convolution classifier with batch normalization.

    Pipeline: Conv2d(1 -> 8, kernel 6, stride 2) -> BatchNorm2d(8) -> ReLU
    -> flatten -> Linear(8*12*12, 10). The linear layer's input size matches
    the 12x12 feature maps the convolution produces from 28x28 images.
    """

    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv2d(1, 8, kernel_size=6, stride=2)
        self.conv0_bn = nn.BatchNorm2d(num_features=8)   # normalizes the 8 conv channels
        self.fc0 = nn.Linear(in_features=8 * 12 * 12, out_features=10)
        self.act = nn.ReLU()

    def forward(self, x):
        """Return class scores of shape (N, 10) for a batch of N images."""
        features = self.act(self.conv0_bn(self.conv0(x)))
        flat = features.view(features.shape[0], -1)      # one row per sample
        return self.fc0(flat)

# Model 2 setup: network on the GPU, cross-entropy loss, Adam with lr = 0.001.
mnist_net = MNIST_Net().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mnist_net.parameters(), lr=0.001)

trainer = Trainer(trainloader=trainloader,
                  testloader=testloader,
                  net=mnist_net,
                  criterion=criterion,
                  optimizer=optimizer)

# Train for four epochs, then report the final test accuracy.
trainer.train(epoch=4)
trainer.test()
# A bare expression is only echoed when it is the last statement of a notebook
# cell; print the count explicitly so it is always shown.
print('Trainable parameters:', count_parameters(mnist_net))

# --------- Model 3: Conv + Pool + Fully Connected + BatchNorm ---------

class MNIST_Net(nn.Module):
    """Convolution + max-pooling classifier with batch normalization.

    Pipeline: Conv2d(1 -> 8, kernel 6, stride 2) -> MaxPool2d(2, stride 2)
    -> BatchNorm2d(8) -> ReLU -> flatten -> Linear(8*6*6, 10). The linear
    layer's input size matches the 6x6 pooled feature maps obtained from
    28x28 images.
    """

    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv2d(1, 8, kernel_size=6, stride=2)
        self.pool0 = nn.MaxPool2d(2, stride=2)           # halves each spatial dimension
        self.conv0_bn = nn.BatchNorm2d(num_features=8)   # applied after pooling
        self.fc0 = nn.Linear(in_features=8 * 6 * 6, out_features=10)
        self.act = nn.ReLU()

    def forward(self, x):
        """Return class scores of shape (N, 10) for a batch of N images."""
        pooled = self.pool0(self.conv0(x))
        features = self.act(self.conv0_bn(pooled))
        flat = features.view(features.shape[0], -1)      # one row per sample
        return self.fc0(flat)

# Model 3 setup: cross-entropy loss, network placed on the GPU, Adam with lr = 1e-3.
criterion = nn.CrossEntropyLoss()
mnist_net = MNIST_Net().cuda()
optimizer = optim.Adam(params=mnist_net.parameters(), lr=1e-3)

# Trainer arguments in positional order:
# trainloader, testloader, net, optimizer, criterion.
trainer = Trainer(trainloader, testloader, mnist_net, optimizer, criterion)

# Train for four epochs, then report the final test accuracy.
trainer.train(4)
trainer.test()