# Comparing Deep Learning Architectures on MNIST: Accuracy vs Energy Efficiency

## Abstract
This tutorial implements and compares four neural network architectures on the MNIST handwritten digit dataset: Multilayer Perceptron (MLP), Convolutional Neural Network (CNN), Recurrent Neural Network (RNN), and Spiking Neural Network (SNN). The models are evaluated on accuracy and on computational efficiency (training time, as a proxy for energy usage), and the impact of hyper-parameters is explored. The tutorial uses common machine learning libraries (PyTorch, snnTorch, Matplotlib, and NumPy) to provide a clear pipeline for training, visualisation, and critical analysis. By the end, readers will understand the trade-offs between model complexity, accuracy, and efficiency in resource-constrained environments.

## Learning Objectives
1. Implement MLP, CNN, RNN, and SNN models using PyTorch.
2. Train and evaluate the models on MNIST.
3. Analyse hyper-parameter effects on accuracy and training time.
4. Visualise results with graphs and tables.

## Table of Contents

1. [Setup: Installing Libraries](#1-setup-installing-libraries)
2. [Loading and Preprocessing MNIST Data](#2-loading-and-preprocessing-mnist-data)
3. [Model Architectures](#3-model-architectures)
   - [3.1 MLP](#31-mlp)
   - [3.2 CNN](#32-cnn)
   - [3.3 RNN](#33-rnn)
   - [3.4 SNN](#34-snn)
4. [Training Loop and Metrics](#4-training-loop-and-metrics)
5. [Hyper-parameter Analysis](#5-hyper-parameter-analysis)
6. [Results: Accuracy vs. Training Time](#6-results-accuracy-vs-training-time)
7. [Conclusion](#7-conclusion)
8. [Future Work](#8-future-work)


<h3>Differences from Existing Tutorials</h3>

<table style="width:100%; text-align: left; border-collapse: collapse; border: 2px solid black;">
    <tr>
        <th style="border: 2px solid black; padding: 8px;">Aspect</th>
        <th style="border: 2px solid black; padding: 8px;">Existing Tutorials</th>
        <th style="border: 2px solid black; padding: 8px;">This Tutorial</th>
    </tr>
    <tr>
        <td style="border: 2px solid black; padding: 8px;"><b>Models Covered</b></td>
        <td style="border: 2px solid black; padding: 8px;">Typically MLP, CNN, or RNN</td>
        <td style="border: 2px solid black; padding: 8px;">Adds SNN for bio-inspired efficiency analysis</td>
    </tr>
    <tr>
        <td style="border: 2px solid black; padding: 8px;"><b>Evaluation Metrics</b></td>
        <td style="border: 2px solid black; padding: 8px;">Focus on accuracy</td>
        <td style="border: 2px solid black; padding: 8px;">Includes training time as energy efficiency proxy</td>
    </tr>
    <tr>
        <td style="border: 2px solid black; padding: 8px;"><b>Hyperparameters</b></td>
        <td style="border: 2px solid black; padding: 8px;">Limited to learning rate/epochs</td>
        <td style="border: 2px solid black; padding: 8px;">Tests optimizer choices and layer configurations</td>
    </tr>
    <tr>
        <td style="border: 2px solid black; padding: 8px;"><b>Visualization</b></td>
        <td style="border: 2px solid black; padding: 8px;">Basic accuracy plots</td>
        <td style="border: 2px solid black; padding: 8px;">Comparative tables and multi-model graphs</td>
    </tr>
</table>


### 1. Setup: Installing Libraries
Install the required libraries with pip if necessary, using the command `pip install torch torchvision matplotlib numpy pandas snntorch`, and then import them. (`pandas` is used for the results table in Section 6.)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import time
import snntorch as snn  # For SNN

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

### 2. Loading and Preprocessing MNIST Data
The following code downloads the MNIST data and preprocesses it. Models must be evaluated on data they were not trained on, so the dataset is used as two separate splits: a training set (60,000 images) and a test set (10,000 images).

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    # mean=0.1307 and std=0.3081 are precomputed statistics of the MNIST training set,
    # widely adopted in the machine learning community (LeCun et al., 1998; PyTorch Documentation).
    transforms.Normalize((0.1307,), (0.3081,))
])

train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, transform=transform)

batch_size = 64
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
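
To make the preprocessing concrete, the following standard-library sketch reproduces the arithmetic behind `transforms.Normalize` and the number of batches per epoch. The helper names `normalize` and `num_batches` are illustrative and not part of PyTorch; the sketch assumes pixel values already scaled to [0, 1] (as `ToTensor` does) and that the last, smaller batch is kept (the `DataLoader` default).

```python
import math

MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # same constants as in the transform above

def normalize(pixel):
    """Apply the same affine map as transforms.Normalize to one pixel in [0, 1]."""
    return (pixel - MNIST_MEAN) / MNIST_STD

def num_batches(num_samples, batch_size):
    """Batches per epoch when the final, smaller batch is kept."""
    return math.ceil(num_samples / batch_size)

print(f"black pixel (0.0) -> {normalize(0.0):.4f}")
print(f"white pixel (1.0) -> {normalize(1.0):.4f}")
print("train batches per epoch:", num_batches(60_000, 64))
print("test batches per epoch: ", num_batches(10_000, 64))
```

After normalisation the background pixels are slightly negative and the brightest strokes are close to 2.8, which keeps the inputs roughly centred around zero.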

### 3. Model Architectures
#### 3.1 MLP
An MLP is a foundational neural network architecture with fully connected ("dense") layers. It processes input data through sequential linear transformations and non-linear activations[[1]](#references). Simple MLPs, however, struggle with spatial data such as images because they lack an inductive bias for grid structure.<br>
The MLP is structured in the following way:<br>
`Flatten the 28×28 MNIST image into a 784-element vector` → `512-neuron hidden layer (ReLU activation)` → `10 neurons (output)`


In [None]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),               # Converts 28x28 image to 1D vector
            nn.Linear(28*28, 512),      # Fully connected layer
            nn.ReLU(),                  # Non-linear activation
            nn.Linear(512, 10)          # Output Layer
        )
    def forward(self, x):
        return self.layers(x)
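
As a rough sense of model size, the number of trainable parameters in the MLP above can be counted by hand. This is a small sketch using only plain Python; `linear_params` is an illustrative helper, not a PyTorch function. It relies on the fact that a fully connected layer has one weight per input-output pair plus one bias per output.

```python
def linear_params(in_features, out_features):
    """Weights (in * out) plus biases (out) of one fully connected layer."""
    return in_features * out_features + out_features

hidden = linear_params(28 * 28, 512)  # 784 -> 512
output = linear_params(512, 10)       # 512 -> 10
mlp_params = hidden + output

print(f"hidden layer: {hidden:,} parameters")
print(f"output layer: {output:,} parameters")
print(f"total:        {mlp_params:,} parameters")
```

Almost all of the parameters sit in the first layer, because every one of the 784 pixels is connected to every hidden neuron.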

#### 3.2 CNN
CNNs excel at image tasks by using convolutional layers to detect spatial patterns (edges, textures) hierarchically. They use parameter sharing (kernels) and pooling to reduce dimensionality while preserving spatial relationships[[2]](#references).<br>
The CNN is structured in the following way:<br>
`Conv2D (3×3, 32 channels)` → `ReLU` → `MaxPool` → `Conv2D (3×3, 64 channels)` → `ReLU` → `MaxPool` → `Flatten` → `Dense Layer`


In [None]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(1, 32, 3),  # 1 input channel (grayscale), 32 output channels
            nn.ReLU(), 
            nn.MaxPool2d(2),      # Reduces feature map size by half   
            nn.Conv2d(32, 64, 3), # Deeper feature extraction
            nn.ReLU(),
            nn.MaxPool2d(2),       
            nn.Flatten(),         # Prepare for dense layer
            nn.Linear(64*5*5, 10) # Final classification
        )
    def forward(self, x):
        return self.layers(x)
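
The `64*5*5` input size of the final dense layer follows from how each layer changes the feature-map size. The sketch below traces that arithmetic in plain Python; `conv_out` and `pool_out` are illustrative helpers, and it assumes the settings used in the class above (3×3 kernels, stride 1, no padding, and 2×2 max pooling that rounds down).

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel):
    """Output width/height of non-overlapping max pooling (rounds down)."""
    return size // kernel

size = 28
size = conv_out(size, 3)   # first 3x3 convolution:  28 -> 26
size = pool_out(size, 2)   # first 2x2 max pool:     26 -> 13
size = conv_out(size, 3)   # second 3x3 convolution: 13 -> 11
size = pool_out(size, 2)   # second 2x2 max pool:    11 -> 5

flattened = 64 * size * size  # 64 channels of 5x5 feature maps
print("final feature-map size:", size)
print("flattened features:", flattened)
```

If the kernel size, padding, or number of pooling layers is changed, the input size of the final `nn.Linear` has to be recomputed in the same way.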

#### 3.3 RNN
RNNs process sequential data (e.g., time series, text) by maintaining a hidden state that captures temporal dependencies[[3]](#references). Here, we treat each row of the MNIST image as a "time step" to demonstrate the flexibility of RNNs.<br>
The RNN is implemented using the following structure:<br>
`Input image (28×28)` → `Squeeze channel dimension` → `Process 28 rows as a sequence (28 time steps)` → `RNN layer (hidden size 128)` → `Final hidden state` → `Dense layer`


In [None]:
class RNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.RNN(28, 128, batch_first=True) # Input size 28 (per row), hidden size 128
        self.fc = nn.Linear(128, 10)
    def forward(self, x):
        x = x.squeeze(1)  # Remove channel dimension
        out, _ = self.rnn(x) # Process rows as a sequence
        return self.fc(out[:, -1, :]) # Use last time step's output
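
To illustrate what "maintaining a hidden state" means, here is a minimal pure-Python sketch of the recurrence that a simple (Elman) RNN with a tanh activation applies at each time step. It is not how `nn.RNN` is implemented internally: the function names are illustrative, the two bias vectors of `nn.RNN` are merged into a single `b` for brevity, and the tiny weights are made up for the example.

```python
import math

def rnn_step(x_t, h_prev, W_ih, W_hh, b):
    """One time step: h_t = tanh(W_ih @ x_t + W_hh @ h_prev + b)."""
    h_new = []
    for j in range(len(h_prev)):
        s = b[j]
        s += sum(W_ih[j][i] * x_t[i] for i in range(len(x_t)))
        s += sum(W_hh[j][k] * h_prev[k] for k in range(len(h_prev)))
        h_new.append(math.tanh(s))
    return h_new

def rnn_forward(sequence, W_ih, W_hh, b):
    """Run the recurrence over a whole sequence, starting from a zero state."""
    h = [0.0] * len(b)
    for x_t in sequence:
        h = rnn_step(x_t, h, W_ih, W_hh, b)
    return h  # final hidden state, analogous to out[:, -1, :] above

# Toy example: input size 2, hidden size 1, sequence of 2 time steps
W_ih = [[1.0, 0.0]]
W_hh = [[0.5]]
b = [0.0]
final_h = rnn_forward([[1.0, 0.0], [0.0, 0.0]], W_ih, W_hh, b)
print("final hidden state:", final_h)
```

In the MNIST model, each `x_t` is one image row of 28 pixels and the hidden state has 128 entries, so the final state summarises all 28 rows before classification.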

#### 3.4 SNN
SNNs mimic biological neurones by transmitting information via discrete spikes over time. Because their activations are sparse, they can be highly energy efficient and are well suited to neuromorphic hardware[[4]](#references). The implementation below uses `snntorch`, a Python library for SNNs built on PyTorch. For simplicity, the forward pass simulates only a single time step per image.<br>
The SNN has the following structure:<br>
`Flattened input` → `Dense layer` → `Leaky integrate-and-fire (LIF) neuron` → `Dense output`

In [None]:
class SNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 512)
        self.lif1 = snn.Leaky(beta=0.9) #LIF neuron model
        self.fc2 = nn.Linear(512, 10)
    def forward(self, x):
        x = x.view(-1, 28*28)
        mem = self.lif1.init_leaky() # Initialize membrane potential
        spk, mem = self.lif1(self.fc1(x), mem) # Update membrane/spike
        return self.fc2(spk) # Classify spikes
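
The behaviour of a leaky integrate-and-fire neuron over several time steps can be sketched in a few lines of plain Python. This is a simplified textbook version, not the snnTorch implementation: `lif_run` is an illustrative name, a threshold of 1.0 and a subtractive reset are assumed, and the exact ordering of integration, spiking, and reset inside `snn.Leaky` may differ.

```python
def lif_run(currents, beta=0.9, threshold=1.0):
    """Simulate one LIF neuron; returns a list of 0/1 spikes, one per time step."""
    mem = 0.0  # membrane potential starts at rest
    spikes = []
    for current in currents:
        mem = beta * mem + current  # leak (decay by beta), then integrate the input
        if mem > threshold:
            spikes.append(1)        # fire a spike
            mem -= threshold        # subtractive reset (assumed here)
        else:
            spikes.append(0)
    return spikes

# A constant input of 0.4 per step is too weak to fire immediately,
# but the membrane potential accumulates until it crosses the threshold.
spike_train = lif_run([0.4] * 6)
print(spike_train)  # [0, 0, 1, 0, 0, 1]
```

The sparse output (two spikes in six steps) is what the efficiency argument for SNNs rests on: downstream computation is only needed when a spike occurs.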

### 4. Training Loop and Metrics
The training loop is the core process in which the model learns from data. It involves:
1. **Forward pass**: Computes predictions.
2. **Loss calculation**: Measures error using a loss function.
3. **Backward pass**: Computes gradients via backpropagation.
4. **Optimisation**: Updates model weights using an optimiser (e.g., Adam).

In this tutorial we also track training time as a proxy for energy efficiency, since on the same hardware shorter training times often correlate with lower computational and energy costs.

In [None]:
def train_model(model, optimizer, epochs=5):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    train_time = 0  # Track total training time

    for epoch in range(epochs):
        start_time = time.time()
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)  # Move data to GPU/CPU
            optimizer.zero_grad()              # Reset gradients
            output = model(X)                  # Forward pass
            loss = criterion(output, y)        # Compute loss
            loss.backward()                    # Backpropagation
            optimizer.step()                   # Update weights
        epoch_time = time.time() - start_time
        train_time += epoch_time
        print(f"Epoch {epoch+1}/{epochs}, Time: {epoch_time:.2f}s")
    
    return train_time  # Return total training time
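
The four steps of the loop can also be traced by hand on a one-parameter model. The sketch below is a toy illustration only: it fits `y = w * x` with plain gradient descent and a hand-derived gradient instead of Adam and autograd, and `train_toy` is a hypothetical helper, not part of the tutorial's pipeline. It also shows the same pattern of timing the loop.

```python
import time

def train_toy(xs, ys, lr=0.1, epochs=50):
    """Fit y = w * x by gradient descent on the squared error; returns (w, seconds)."""
    w = 0.0
    start = time.perf_counter()
    for _ in range(epochs):
        for x, y in zip(xs, ys):
            pred = w * x                 # 1. forward pass
            loss = (pred - y) ** 2       # 2. loss calculation (squared error)
            grad = 2 * (pred - y) * x    # 3. backward pass: d(loss)/d(w)
            w -= lr * grad               # 4. optimisation: gradient descent update
    elapsed = time.perf_counter() - start
    return w, elapsed

# The data follow y = 2x, so w should approach 2.0
w, elapsed = train_toy([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
print(f"learned w = {w:.6f} in {elapsed:.6f}s")
```

In `train_model`, PyTorch performs step 3 automatically via `loss.backward()`, and the optimiser replaces the manual update in step 4.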

#### Evaluation Function
After training, we evaluate model performance on the test set to measure generalisation accuracy. Because the test images were not seen during training, this helps reveal whether the model has overfitted to the training data.

In [None]:
def test_model(model):
    model.eval()  # Set model to evaluation mode (changes dropout/batch-norm behaviour, if present)
    correct = 0
    
    with torch.no_grad():  # Disable gradient tracking for efficiency
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            output = model(X)
            pred = output.argmax(dim=1)  # Get predicted class
            correct += (pred == y).sum().item()  # Count correct predictions
    
    return correct / len(test_data)  # Return accuracy
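
The accuracy computation above boils down to taking the index of the largest output score per sample and comparing it with the label. The following plain-Python sketch mirrors that logic on a hand-made batch; the scores and labels are invented for illustration and are not model outputs.

```python
def argmax(values):
    """Index of the largest value (the predicted class)."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(logits_batch, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    correct = sum(1 for logits, y in zip(logits_batch, labels) if argmax(logits) == y)
    return correct / len(labels)

# Four samples, three classes (illustrative values)
logits_batch = [
    [0.1, 2.0, -1.0],  # predicts class 1
    [1.5, 0.2, 0.3],   # predicts class 0
    [0.0, 0.1, 0.9],   # predicts class 2
    [0.4, 0.3, 0.2],   # predicts class 0 (label is 1, so this one is wrong)
]
labels = [1, 0, 2, 1]

acc = accuracy(logits_batch, labels)
print(f"accuracy = {acc:.2f}")  # accuracy = 0.75
```

Note that no softmax is needed for prediction: the largest raw score is also the largest probability.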

### 5. Hyper-parameter Analysis
Hyper-parameters are settings that control the learning process. In this tutorial, we analyse:
1. Optimiser choice (e.g., Adam vs SGD).
2. Learning rate (step size for weight updates).
3. Batch size (number of samples processed per iteration).

By default, all models use `Adam` with `lr=0.001` and `batch_size=64`. However, you can use the following code to test different variations:

In [None]:
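# Example: testing different learning rates on the MLP
# A fresh model is created for each run so that the results are comparable
learning_rates = [0.001, 0.01, 0.1]
for lr in learning_rates:
    model = MLP()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_time = train_model(model, optimizer)
    accuracy = test_model(model)
    print(f"lr={lr}: Accuracy={accuracy:.3f}, Time={train_time:.1f}s")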
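
To vary several hyper-parameters at once, one option is to enumerate every combination up front and then train a fresh model for each. The sketch below only builds the list of configurations with the standard library; the dictionary keys and the batch sizes 32 and 128 are illustrative choices, not values prescribed by this tutorial.

```python
from itertools import product

optimizers = ["Adam", "SGD"]
learning_rates = [0.001, 0.01, 0.1]
batch_sizes = [32, 64, 128]  # 64 is the tutorial default; the others are examples

# One dictionary per combination of optimiser, learning rate, and batch size
configs = [
    {"optimizer": opt, "lr": lr, "batch_size": bs}
    for opt, lr, bs in product(optimizers, learning_rates, batch_sizes)
]

print(len(configs), "configurations to test")
print("first:", configs[0])
print("last: ", configs[-1])
```

Each configuration would need its own newly created model and optimiser. Because `train_model` reads the global `train_loader`, changing the batch size also means rebuilding the `DataLoader` before training.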

The following code trains and evaluates all four models with the default Adam optimiser and a learning rate of 0.001:

In [None]:
models = {'MLP': MLP(), 'CNN': CNN(), 'RNN': RNN(), 'SNN': SNN()}
results = []

for name, model in models.items():
    # Default: Adam optimizer, lr=0.001
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    train_time = train_model(model, optimizer)
    accuracy = test_model(model)
    results.append({
        'Model': name,
        'Accuracy': accuracy,
        'Training Time (s)': train_time
    })
    print(f"{name}: Accuracy={accuracy:.3f}, Time={train_time:.1f}s")

### 6. Results: Accuracy vs. Training Time
To make the results clearer, we generate a table and graphs of the accuracy and training time of each architecture. The following code produces them:

In [None]:
import pandas as pd

# Generate comparison table
df = pd.DataFrame(results)
print("Performance Comparison:\n", df)

# Plot results
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
df.plot(x='Model', y='Accuracy', kind='bar', ax=ax[0], title='Accuracy Comparison')
df.plot(x='Model', y='Training Time (s)', kind='bar', ax=ax[1], title='Training Time Comparison')
plt.show()

### 7. Conclusion
After running all the above cells, we can draw the following conclusions:
- **MLP**: Trains quickly, but its accuracy is the lowest because of its limited spatial processing.
- **CNN**: Achieves the highest accuracy, but has the slowest training time because of its convolutional operations.
- **RNN**: Performs moderately well; it is designed for sequential data rather than images.
- **SNN**: One might expect the SNN to have the shortest training time given its efficiency-oriented design; in practice, it balances speed and accuracy. To improve either, you could try hybrid SNN architectures, as mentioned in the [Future Work](#8-future-work) section.

The CNN performs best in terms of accuracy, while the MLP and SNN train faster.

### 8. Future Work
- Test more hyper-parameters (e.g., batch size, layer widths) and see how they affect accuracy and training time.
- Measure actual energy consumption using hardware tools.
- Explore hybrid models, such as a combined CNN-SNN, to achieve high accuracy with low energy consumption.

### References  
1. **MLP**:  
   Goodfellow, I., Bengio, Y., & Courville, A. (2016). [*Deep Learning*](https://www.deeplearningbook.org). MIT Press.

   PyTorch Linear Layer Documentation: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html  

2. **CNN**:  
   LeCun, Y., Bottou, L., Bengio, Y., & Haffner, P. (1998). [Gradient-Based Learning Applied to Document Recognition](http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf). *Proceedings of the IEEE*.

   PyTorch Conv2d Documentation: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

3. **RNN**:  
   Graves, A. (2012). [Supervised Sequence Labelling with Recurrent Neural Networks](https://www.cs.toronto.edu/~graves/preprint.pdf). Springer.

   PyTorch RNN Documentation: https://pytorch.org/docs/stable/generated/torch.nn.RNN.html

4. **SNN**:  
   Eshraghian, J. K., et al. (2021). [Training Spiking Neural Networks Using Lessons from Deep Learning](https://arxiv.org/abs/2109.12894). *arXiv*.

5. **PyTorch Documentation**:  
   [PyTorch Official Documentation](https://pytorch.org/docs/stable/index.html)  

6. **Adam Optimizer**:  
   Kingma, D. P., & Ba, J. (2014). [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980). *arXiv*.  

7. **MNIST Dataset**:  
   LeCun, Y., Cortes, C., & Burges, C. (1998). [The MNIST Database of Handwritten Digits](http://yann.lecun.com/exdb/mnist/).  
   [PyTorch MNIST Loading](https://pytorch.org/vision/stable/datasets.html#mnist)  

8. **SNNTorch Library**:  
   [SNNTorch Documentation](https://snntorch.readthedocs.io/)  