# Setup

In [52]:
# Import libraries

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets, transforms

In [55]:
# Fix the global RNG seed so weight initialisation and DataLoader shuffling are
# repeatable after Restart Kernel -> Run All. torch.manual_seed seeds the CPU
# and all CUDA generators (some GPU kernels may still be non-deterministic).
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [143]:
# Load data

# Preprocessing applied to every image: convert to a tensor, then normalise
# with mean 0.5 and std 0.5 on the single grayscale channel.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Build the train split first, then the test split, sharing the same root
# directory and preprocessing; both are downloaded if not already present.
train_dataset, test_dataset = (
    datasets.MNIST(root="../data", train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

In [None]:
batch_size = 64

# Only the training split is shuffled; the test split is iterated in its
# stored order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Training

In [146]:
# Model

class DigitNet(nn.Module):
    """Small convolutional classifier for 1x28x28 images with 10 output classes.

    Two (3x3 conv -> ReLU -> 2x2 max-pool) stages are followed by two fully
    connected layers. The network returns raw logits; no softmax is applied.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, X):
        """Compute class logits for a batch of images.

        Args:
            X: Tensor of shape (N, 1, 28, 28), or a single unbatched image of
                shape (1, 28, 28).

        Returns:
            Tensor of shape (N, 10) holding one logit per class (N == 1 for an
            unbatched input).

        Raises:
            RuntimeError: If the spatial size is not 28x28, because the
                flattened feature size then no longer matches ``fc1``.
        """
        # Give a single image a batch dimension so the per-sample flatten below
        # always operates on (N, C, H, W).
        if X.dim() == 3:
            X = X.unsqueeze(0)

        x = self.pool1(F.relu(self.conv1(X)))  # 1x28x28 -> conv -> 32x28x28 -> maxpool -> 32x14x14
        x = self.pool2(F.relu(self.conv2(x)))  # 32x14x14 -> conv -> 64x14x14 -> maxpool -> 64x7x7

        # Flatten per sample. A reshape with an inferred batch size (-1) would
        # silently turn surplus features from a wrongly sized image into extra
        # "samples"; keeping the batch dimension fixed makes fc1 fail loudly.
        x = x.flatten(start_dim=1)  # 64x7x7 -> 3136

        x = F.relu(self.fc1(x))  # 3136 -> 128
        x = self.fc2(x)  # 128 -> 10
        return x

In [147]:
# Build the network, then move its parameters onto the selected device.
model = DigitNet()
model = model.to(device)
model

DigitNet(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=3136, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [157]:
# Total number of scalar entries across all of the model's parameter tensors.
print(f"Parameter count: {sum(param.numel() for param in model.parameters())}")

Parameter count: 421642


In [148]:
# Cross-entropy loss on the model's raw logits, minimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [152]:
def train_step():
    """Train ``model`` for one pass over ``train_dataloader``.

    Uses the notebook-level ``model``, ``train_dataloader``, ``criterion``,
    ``optimizer`` and ``device``; the model's parameters are updated in place.

    Returns:
        tuple: ``(train_loss, train_acc)``. ``train_loss`` is a 0-dim tensor
        (detached from autograd) with the mean loss per sample, and
        ``train_acc`` is a float in [0, 1] giving the fraction of correctly
        classified samples.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    model.train()

    loss_sum = torch.zeros((), device=device)
    n_correct = 0
    n_seen = 0

    for X, y in train_dataloader:
        X = X.to(device)
        y = y.to(device)

        # Forward pass
        y_logits = model(X)
        loss = criterion(y_logits, y)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a detached copy so the running total does not keep every
        # batch's autograd history alive. Weighting by batch size (criterion is
        # assumed to use mean reduction, as nn.CrossEntropyLoss() does by
        # default) stops a smaller final batch from being over-represented.
        n_batch = len(y)
        loss_sum += loss.detach() * n_batch

        # softmax is monotonic, so argmax over the logits gives the same class.
        n_correct += (y_logits.argmax(dim=1) == y).sum().item()
        n_seen += n_batch

    train_loss = loss_sum / n_seen
    train_acc = n_correct / n_seen
    return train_loss, train_acc

In [153]:
def test_step():
    model.eval()

    test_loss = 0.0
    test_acc = 0.0

    for batch, (X, y) in enumerate(test_dataloader):
        X = X.to(device)
        y = y.to(device)
        
        # Forward pass
        y_logits = model(X)
        y_pred_probs = torch.softmax(y_logits, dim=1)
        y_preds = torch.argmax(y_pred_probs, dim=1)

        # Loss
        loss = criterion(y_logits, y)
        test_loss += loss

        # Accuracy
        acc = (y_preds == y).sum().item() / len(y)
        test_acc += acc
    
    test_loss /= len(test_dataloader)
    test_acc /= len(test_dataloader)
    return test_loss, test_acc

In [None]:
# Initialize metrics

# Per-epoch history, plus a text snapshot of the architecture that produced it.
metrics = {
    "model": str(model),
    "train_losses": [],
    "train_accuracies": [],
    "test_losses": [],
    "test_accuracies": [],
}

metrics

In [None]:
# Training

epochs = 3

# Report about ten times per run, and at least once per epoch for short runs.
# Integer arithmetic avoids the float modulo of `epoch % (epochs / 10)`, which
# is inexact and skips intermediate epochs whenever epochs < 10.
log_every = max(1, epochs // 10)

for epoch in range(epochs):
    train_loss, train_acc = train_step()
    test_loss, test_acc = test_step()

    # Losses come back as 0-dim tensors; store plain Python floats.
    train_loss = train_loss.item()
    test_loss = test_loss.item()

    metrics["train_losses"].append(train_loss)
    metrics["train_accuracies"].append(train_acc)
    metrics["test_losses"].append(test_loss)
    metrics["test_accuracies"].append(test_acc)

    if epoch % log_every == 0 or epoch == epochs - 1:
        # Accuracies are fractions in [0, 1]; the `%` format spec scales them
        # by 100 so the printed percentage is correct.
        print(
            f"Epoch: {epoch} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2%} | "
            f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2%}"
        )

In [None]:
def plot_metrics(metrics: dict):
    """Plot train/test loss and accuracy curves, one point per recorded epoch.

    Args:
        metrics: Mapping containing the keys "train_losses", "test_losses",
            "train_accuracies" and "test_accuracies", each a sequence of
            numbers of equal length. Any other keys are ignored.
    """
    epoch_axis = range(0, len(metrics["train_losses"]))

    # Explicit Figure/Axes objects instead of the pyplot state machine, so each
    # panel is configured through its own handle.
    fig, (ax_loss, ax_acc) = plt.subplots(2, 1, figsize=(15, 6))

    ax_loss.set_title("Loss")
    ax_loss.plot(epoch_axis, metrics["train_losses"], label="Train Losses")
    ax_loss.plot(epoch_axis, metrics["test_losses"], label="Test Losses")
    ax_loss.set_xlabel("Epoch")
    ax_loss.set_ylabel("Loss")
    ax_loss.legend()

    ax_acc.set_title("Accuracy")
    ax_acc.plot(epoch_axis, metrics["train_accuracies"], label="Train Accuracies")
    ax_acc.plot(epoch_axis, metrics["test_accuracies"], label="Test Accuracies")
    ax_acc.set_xlabel("Epoch")
    ax_acc.set_ylabel("Accuracy")
    ax_acc.legend()

    # Without a layout pass the lower panel's title overlaps the upper panel's
    # x-axis labels.
    fig.tight_layout()
    plt.show()

# Render the loss and accuracy curves collected during training.
plot_metrics(metrics=metrics)