## Step 0: Colab env
Installing the pinned dependencies for the Colab environment.

In [None]:
!pip install archspec==0.2.3 \
asttokens==2.4.1 \
boltons==23.0.0 \
Brotli==1.0.9 \
charset-normalizer==2.0.4 \
colorama==0.4.6 \
comm==0.2.2 \
cryptography==42.0.5 \
decorator==4.4.2 \
distro==1.9.0 \
executing==2.0.1 \
frozendict==2.4.2 \
idna==3.7 \
jedi==0.19.1 \
jsonpatch==1.33 \
jsonpointer==2.1 \
nest-asyncio==1.6.0 \
packaging==23.2 \
parso==0.8.4 \
platformdirs==4.2.1 \
pluggy==1.0.0 \
pure-eval==0.2.2 \
pycosat==0.6.6 \
pycparser==2.21 \
Pygments==2.18.0 \
PySocks==1.7.1 \
requests==2.32.3 \
ruamel.yaml==0.17.21 \
stack-data==0.6.3 \
tqdm==4.66.4 \
traitlets==5.14.3 \
truststore==0.8.0 \
urllib3==2.2.2 \
zstandard==0.22.0

## Step 1: Import Libraries and Set Up the Environment
We start by importing the libraries and selecting the compute device, using the GPU when one is available.

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast

# NOTE(review): float64 is made the global default floating-point dtype here.
# Per the PyTorch autocast documentation, float64 ops are not autocast-eligible,
# so confirm this is intended alongside the mixed-precision training loop below.
torch.set_default_dtype(torch.float64)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

## Step 2: Load and Preprocess the MNIST Dataset
We load the MNIST dataset and apply the tensor conversion and normalization transforms.

In [11]:
# Preprocessing: turn each PIL image into a tensor, then normalize it with
# mean 0.1307 and standard deviation 0.3081.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]
transform = transforms.Compose(preprocessing_steps)

# Fetch both MNIST splits into ./data (downloading them if needed); every
# sample is passed through `transform`.
dataset_options = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **dataset_options)
test_dataset = datasets.MNIST(train=False, **dataset_options)

# Batch the data using 4 worker processes and pinned host memory; only the
# training split is shuffled.
loader_options = dict(num_workers=4, pin_memory=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=4096, shuffle=True, **loader_options)
test_loader = DataLoader(dataset=test_dataset, batch_size=1024, shuffle=False, **loader_options)

## Step 3: Initialize the MLP Model
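Here we create our MLP model with dimensions suited to the MNIST dataset.
The input is a flattened 28x28 image (784 pixels) and the output size is 10, one score per digit class.

**Parameters:** `width` is the list of layer sizes, from the input dimension to the output dimension.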

 

In [12]:
class MLP(nn.Module):
    """Fully connected network: Linear layers with ReLU between them.

    Args:
        width: Sequence of layer sizes. Consecutive entries give the input and
            output size of each Linear layer; no activation follows the last
            Linear layer.
    """

    def __init__(self, width):
        super().__init__()
        modules = []
        for position, (fan_in, fan_out) in enumerate(zip(width[:-1], width[1:])):
            # Every Linear layer except the first is preceded by a ReLU, which
            # yields the Linear -> ReLU -> ... -> Linear ordering.
            if position > 0:
                modules.append(nn.ReLU())
            modules.append(nn.Linear(fan_in, fan_out))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        return self.network(x)

In [13]:
# Build the classifier (784 inputs, hidden layers of 32 and 16 units, 10 outputs)
# and move it onto the selected device.
model = MLP(width=[784, 32, 16, 10])
model = model.to(device)

## Step 4: Define the Loss Function and Optimizer
We use `CrossEntropyLoss` as the classification loss and the Adam optimizer to train the model.

In [14]:
# Cross-entropy objective, minimized with Adam at a learning rate of 0.002.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.002)

## Step 5: Train the Model
The cells below contain the training loop, where the model is trained for a specified number of epochs.
The loss and the test accuracy are printed after each epoch to monitor the training process.

In [None]:
# Query the GPU visible to this runtime with the `nvidia-smi` command-line tool.
!nvidia-smi

In [None]:
# Training loop parameters
num_epochs = 5
train_losses = []      # mean training loss per sample, one entry per epoch
test_accuracies = []   # test accuracy in percent, one entry per epoch

# Mixed precision is a CUDA feature here. Disabling it explicitly when running
# on another device avoids the "CUDA is not available" warnings that a
# hard-coded 'cuda' scaler/autocast would emit.
use_amp = device.type == 'cuda'
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

for epoch in range(num_epochs):
    # ---- Train for one epoch ----
    model.train()
    running_loss = 0.0   # sum of per-sample losses seen so far in this epoch
    num_train_samples = 0

    for images, labels in tqdm(train_loader, desc=f'MLP Epoch [{epoch+1}/{num_epochs}]'):
        # Flatten each image into a vector before feeding the MLP
        images = images.view(images.size(0), -1).to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        with torch.amp.autocast('cuda', enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Scale the loss before backward; the scaler unscales gradients in step()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # `loss` is the batch mean, so weight it by the batch size: the final
        # batch can be smaller than the others and must not be over-weighted.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        num_train_samples += batch_size

    # Save training loss (true per-sample mean over the epoch)
    epoch_loss = running_loss / num_train_samples
    train_losses.append(epoch_loss)

    # ---- Evaluate on the test set ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.view(images.size(0), -1).to(device)
            labels = labels.to(device)
            outputs = model(images)
            # Index of the largest output is the predicted class
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_accuracy = 100 * correct / total
    test_accuracies.append(test_accuracy)

    print(f'MLP Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

# One figure with two side-by-side panels: training loss (left) and test
# accuracy (right), both plotted against the 1-based epoch number.
epoch_numbers = range(1, num_epochs+1)
fig, (loss_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: training loss
loss_ax.plot(epoch_numbers, train_losses, label='Training Loss')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Loss over Epochs')
loss_ax.legend()

# Right panel: test accuracy
accuracy_ax.plot(epoch_numbers, test_accuracies, label='Test Accuracy')
accuracy_ax.set_xlabel('Epochs')
accuracy_ax.set_ylabel('Accuracy (%)')
accuracy_ax.set_title('Test Accuracy over Epochs')
accuracy_ax.legend()

plt.show()

## Step 6: Evaluate the Model
Finally, we inspect the trained model on the test set by displaying a few test images together with their predicted and true labels. The test accuracy itself is calculated and printed after each epoch in Step 5.


In [None]:
# Visualize a few test set predictions
import matplotlib.pyplot as plt
import numpy as np

model.eval()

# Take the first batch of the test set and flatten each image for the MLP
images, labels = next(iter(test_loader))
images = images.view(images.size(0), -1).to(device)
labels = labels.to(device)

# Inference only: no_grad() skips building the autograd graph for this pass
with torch.no_grad():
    outputs = model(images)
_, predicted = torch.max(outputs, 1)

# Move everything back to the CPU as NumPy arrays; images regain their 28x28 shape
images = images.view(-1, 28, 28).cpu().numpy()
labels = labels.cpu().numpy()
predicted = predicted.cpu().numpy()

# Show the first 10 images (2 rows x 5 columns) with predicted and true labels
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for i, ax in enumerate(axes.flat):
    ax.imshow(images[i], cmap='gray')
    ax.set_title(f'Pred: {predicted[i]}, True: {labels[i]}')
    ax.axis('off')

plt.show()
