This notebook tests training a CNN classifier and builds a centralised baseline on the MNIST dataset.

In [None]:
from itertools import islice

from datasets import load_dataset
from matplotlib import pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from common.nn import MnistClassifier

In [None]:
# Load MNIST via Hugging Face `datasets`, returning torch tensors.
BATCH_SIZE = 32

dataset = load_dataset("mnist").with_format("torch")
train = dataset["train"]
test = dataset["test"]
# Shuffle the training set. Without it, every epoch replays the same batches in the same
# order, and any per-epoch batch cap always trains on the identical leading slice of the data.
# With no explicit generator, the shuffle order is drawn from torch's global RNG when the
# loader is iterated, so a torch.manual_seed(...) before training keeps runs reproducible.
train_dataloader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation order does not affect the metrics, so the test set stays in its stored order.
test_dataloader = DataLoader(test, batch_size=BATCH_SIZE)

In [None]:
# Visualise one training example to sanity-check the loaded data.
# Index the dataset once: each `train[i]` access decodes the example again.
sample = train[0]
sample_feature, sample_label = sample["image"], sample["label"]

fig, ax = plt.subplots()
ax.imshow(sample_feature, cmap="gray")  # MNIST digits are single-channel, so render in grayscale
ax.set_title(f"Label: {sample_label}")
ax.axis("off")
print(f"Feature shape: {sample_feature.shape} | Label {sample_label}")

In [None]:
# Train the Classifier.
# Only a capped number of batches is used per epoch to keep the baseline quick to run.
NUM_EPOCHS = 10
MAX_BATCHES_PER_EPOCH = 100  # set to None to train on the full training set each epoch

torch.manual_seed(0)  # seeds weight init (and the train loader's shuffle order, if enabled)
net = MnistClassifier()
cross_entropy_loss = nn.CrossEntropyLoss()
optim = torch.optim.Adam(net.parameters())

losses = []  # per-batch training loss, in step order
net.train()  # explicit, since the evaluation cell may leave the model in eval mode
for epoch in tqdm(range(NUM_EPOCHS)):
    # islice stops after MAX_BATCHES_PER_EPOCH batches (or exhausts the loader if None)
    # without fetching an extra batch just to discard it.
    for batch in islice(train_dataloader, MAX_BATCHES_PER_EPOCH):
        X, y = batch["image"], batch["label"]
        # Predict labels (logits) and compute the loss
        py = net(X)
        loss = cross_entropy_loss(py, y)
        # Gradient step
        optim.zero_grad()
        loss.backward()
        optim.step()
        losses.append(loss.item())

In [None]:
# Training loss per batch, across all epochs in order.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Training loss", xlabel="Training step (batch)", ylabel="Cross-entropy loss");

In [None]:
# Evaluate the trained classifier on the held-out test set.
net.eval()  # switch off any training-only behaviour (e.g. dropout) if MnistClassifier has it
total_loss, correct = 0.0, 0
size, num_batches = len(test_dataloader.dataset), len(test_dataloader)
# No gradients are needed for evaluation; skipping autograd saves memory and time.
with torch.no_grad():
    for batch in test_dataloader:
        X, y = batch["image"], batch["label"]
        # Predict labels (logits)
        py = net(X)
        # CrossEntropyLoss returns the batch *mean*. Re-weight it by batch size so a smaller
        # final batch is not over-weighted when the test size isn't a multiple of the batch size.
        total_loss += cross_entropy_loss(py, y).item() * len(y)
        # Count correct top-1 predictions
        correct += (py.argmax(1) == y).sum().item()
average_loss = total_loss / size
accuracy = correct / size
{"test_loss": average_loss, "test_accuracy": accuracy}