# References

https://machinelearningmastery.com/how-to-develop-a-cnn-from-scratch-for-cifar-10-photo-classification/
https://www.geeksforgeeks.org/cifar-10-image-classification-in-tensorflow/
https://www.stefanfiott.com/machine-learning/cifar-10-classifier-using-cnn-in-pytorch/
https://www.33rdsquare.com/convolutional-neural-network-pytorch-implementation-on-cifar10-dataset/
https://www.geeksforgeeks.org/building-a-convolutional-neural-network-using-pytorch/
https://ncl.instructure.com/courses/55046/files/8928967?module_item_id=3535263
https://ncl.instructure.com/courses/55046/files/8913068?module_item_id=3535259
https://ncl.instructure.com/courses/55046/files/8942227?module_item_id=3540677

In [None]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

class NeuralNet(nn.Module):
    """ 
        VGG-style CNN for CIFAR-10 classification, assembled from the sources listed in the References section.

        Layout: three convolutional stages (block1, block2, block3) followed by an MLP head (classifier).
        Each stage applies two 3x3 convolutions with padding=1 (spatial size preserved), each followed by
        ReLU and, optionally, BatchNorm2d; the stage ends with a 2x2 max pooling that halves the spatial size.
        Channel widths go 3 -> 32 -> 64 -> 128. The head flattens 128 * 4 * 4 features, which assumes
        32x32 inputs (32 -> 16 -> 8 -> 4 after the three poolings), and emits 10 logits.

        @param batch_normalization: If True, add BatchNorm2d after every conv+ReLU pair. It is used to answer Q2
      """
    def __init__(self, batch_normalization=False):
        super().__init__()
        self.batch_normalization = batch_normalization

        def make_stage(in_ch, out_ch):
            # Keep the module order Conv -> ReLU [-> BN] (twice) -> MaxPool: the Sequential indices
            # determine the state_dict keys of checkpoints saved with torch.save.
            modules = []
            for src, dst in ((in_ch, out_ch), (out_ch, out_ch)):
                modules.extend([nn.Conv2d(src, dst, 3, padding=1), nn.ReLU()])
                if batch_normalization:
                    modules.append(nn.BatchNorm2d(dst))
            modules.append(nn.MaxPool2d(2, 2))
            return nn.Sequential(*modules)

        self.block1 = make_stage(3, 32)
        self.block2 = make_stage(32, 64)
        self.block3 = make_stage(64, 128)

        hidden_units = 128
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_units, 10),
        )

    def forward(self, x):
        """Run the three convolutional stages, then the classifier head, and return the class logits."""
        for stage in (self.block1, self.block2, self.block3, self.classifier):
            x = stage(x)
        return x

def load_dataset(root='./data', batch_size=64, num_workers=2):
    """ 
        Load the CIFAR-10 train and test splits and wrap them in DataLoaders.

        Images are converted to tensors and every RGB channel is normalized with mean 0.5 and std 0.5,
        which maps ToTensor's [0, 1] pixel range to [-1, 1]. Both splits are downloaded into `root`
        if they are not already present there.

        @param root: Directory where CIFAR-10 is stored / downloaded, default is './data'
        @param batch_size: Mini-batch size used by both loaders, default is 64
        @param num_workers: Number of DataLoader worker processes per loader, default is 2
        @return: (trainloader, testloader); the train loader reshuffles every epoch, the test loader does not
      """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # Evaluation order is fixed so each epoch's test loss is computed over the same batches.
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return trainloader, testloader


def count_parameters(model):
    """ 
        Return the total number of scalar values across all of `model`'s parameters, trainable or not.
        Registered buffers (e.g. BatchNorm running statistics) are not parameters and are not counted.
        It is used to keep track of the complexity of the model.
      """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total

def _train_one_epoch(model, loader, loss_fun, optimizer, device):
    """Run one optimisation pass over `loader` and return the mean per-batch training loss."""
    model.train()
    running_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = loss_fun(model(inputs), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate(model, loader, loss_fun, device):
    """Return the mean per-batch loss of `model` on `loader` in eval mode, without tracking gradients."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            total_loss += loss_fun(model(inputs), labels).item()
    return total_loss / len(loader)


def train_model(model, trainloader, testloader, epochs=50, device='cuda', early_stopping_patience=10):
    """ 
        Train the model with early stopping and return its loss history together with the best model.

        After each epoch the model is evaluated on `testloader`; whenever the mean test loss improves, a
        deep copy of the weights is kept. Training stops after `early_stopping_patience` consecutive epochs
        without improvement, and the best snapshot (if any epoch produced one) is loaded back into `model`.

        NOTE(review): early stopping and best-model selection use `testloader`, so the reported test loss
        is optimistically biased; a separate validation split would avoid this.

        @param model: The model to train (moved to `device` and updated in place)
        @param trainloader: The DataLoader for the training set
        @param testloader: The DataLoader used for per-epoch evaluation / early stopping
        @param epochs: The maximum number of epochs to train the model, default is 50
        @param device: The device to use for training, default is 'cuda'
        @param early_stopping_patience: The number of epochs without improvement before stopping, default is 10
        @return: (train_losses, test_losses, model) where the lists hold one mean per-batch loss per epoch run

        @info : loss function is CrossEntropyLoss and optimizer is Adam (default hyperparameters)
      """
    loss_fun = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    model = model.to(device)
    train_losses, test_losses = [], []
    best_loss = float('inf')
    patience = 0
    best_model_state = None

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, trainloader, loss_fun, optimizer, device)
        test_loss = _evaluate(model, testloader, loss_fun, device)
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        # Report every epoch, including the one that triggers early stopping.
        print(f"Epoch {epoch + 1}/{epochs}: Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

        if test_loss < best_loss:
            best_loss = test_loss
            patience = 0
            # state_dict() returns references to the live tensors; without a deep copy, later
            # optimizer steps would overwrite this "best" snapshot.
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            patience += 1
            if patience >= early_stopping_patience:
                print("Early stopping triggered.")
                break

    # No snapshot exists when no epoch ran or the loss never improved (e.g. NaN losses).
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return train_losses, test_losses, model

def plot_convergence(train_losses, test_losses, title, filename):
    """ 
        Plot the per-epoch train and test loss curves, save the figure to `filename`, and display it.
        It is used to visualize the training process and compare my models.

        @param train_losses: Mean training loss for each epoch, in order
        @param test_losses: Mean test loss for each epoch, in order
        @param title: Title of the figure
        @param filename: Path the figure is saved to
      """
    fig = plt.figure(figsize=(10, 5))
    # Number epochs from 1 so the x-axis matches the 1-based "Epoch k/N" lines printed during training.
    plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
    plt.plot(range(1, len(test_losses) + 1), test_losses, label='Test Loss')
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(filename)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def plot_feature_filters(model, test_loader, device, max_maps=16):
    """ 
        Plot the feature maps (layer activations) produced by each ReLU of the convolutional blocks
        for a single test image, one figure per ReLU, i.e. one per convolution.
        Despite the function name, these are the layers' outputs, not the learned convolution kernels.
        We can see the dimensions of the feature maps and how they change as we go deeper into the network.

        @param model: Model exposing `block1`, `block2`, `block3`; it must already be on `device`
                      (only the input image is moved here). It is switched to eval mode.
        @param test_loader: DataLoader whose first batch supplies the image (its first sample is used)
        @param device: Device the forward pass runs on
        @param max_maps: Maximum number of channels drawn per layer on a square grid, default is 16 (4x4)
      """
    model.eval()
    test_images, _ = next(iter(test_loader))
    test_image = test_images[0].unsqueeze(0).to(device)

    activations = []
    with torch.no_grad():
        x = test_image
        for layer in [*model.block1, *model.block2, *model.block3]:
            x = layer(x)
            # Capture after each ReLU; when BatchNorm is enabled it follows the ReLU, so these are pre-BN maps.
            if isinstance(layer, nn.ReLU):
                activations.append(x.cpu().numpy())

    grid = int(np.ceil(np.sqrt(max_maps)))
    for i, activation in enumerate(activations):
        fig = plt.figure(figsize=(15, 15))
        num_filters = activation.shape[1]
        for j in range(min(num_filters, max_maps)):
            plt.subplot(grid, grid, j + 1)
            plt.imshow(activation[0, j, :, :], cmap='viridis')
            plt.axis('off')
        plt.suptitle(f"Layer {i + 1} Feature Maps")
        plt.show()
        # Release each figure once shown so the per-layer figures do not accumulate.
        plt.close(fig)

def run_experiment():
    """ 
        Main entry point: compare the CNN trained without and with batch normalization.

        Loads CIFAR-10, then for each configuration (without BN first, then with BN) builds a NeuralNet,
        prints its parameter count, trains it, and saves the best weights to a .pth file. Afterwards it
        plots and saves both convergence curves, and finally plots the feature maps of the
        batch-normalized model.
      """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    trainloader, testloader = load_dataset()

    # (use batch norm, progress message, checkpoint path, plot title, plot path)
    configurations = [
        (False, "Training without batch normalization...", 'model_without_bn.pth',
         "Model Convergence Without Batch Normalization", "convergence_no_bn.png"),
        (True, "Training with batch normalization...", 'model_with_bn.pth',
         "Model Convergence With Batch Normalization", "convergence_with_bn.png"),
    ]

    histories = []
    best_model = None
    for use_bn, banner, checkpoint_path, plot_title, plot_path in configurations:
        print(banner)
        net = NeuralNet(batch_normalization=use_bn)
        print(f"Total parameters: {count_parameters(net)}")
        train_hist, test_hist, best_model = train_model(
            net, trainloader, testloader, epochs=50, device=device
        )
        torch.save(best_model.state_dict(), checkpoint_path)
        histories.append((train_hist, test_hist, plot_title, plot_path))

    # Plot convergence
    for train_hist, test_hist, plot_title, plot_path in histories:
        plot_convergence(train_hist, test_hist, plot_title, plot_path)

    # Plot feature filters; the last configuration trained is the batch-normalized one.
    print("Plotting feature filters from the best model with batch normalization...")
    plot_feature_filters(best_model, testloader, device)

if __name__ == '__main__':
    # Run the full experiment only when this file is executed directly, not when it is imported.
    run_experiment()
