## Step 1 - Packages

In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from importlib import reload
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from torchvision.transforms.functional import pil_to_tensor
from torchsummary import summary
import pathlib
import sys

# Dataset root: the "data" directory one level above the current working directory.
data_path = pathlib.Path.cwd().joinpath("..", "data")

In [None]:
# Report the interpreter and PyTorch versions the notebook is running on.
python_version = sys.version
torch_version = torch.__version__
print(f"Python version -> {python_version}")  # 3.12.3
print(f"torch version -> {torch_version}")  # 2.8.0+cu128

## Step 2 - MNIST Data

In [None]:
# Load the MNIST training set. pil_to_tensor keeps the raw pixel values (no
# rescaling), which is why batches are divided by 256.0 before use.
# download=True fetches the files when they are not already under data_path.
mnist = torchvision.datasets.MNIST(
    root=data_path, transform=pil_to_tensor, train=True, download=True
)
# Random 70% / 30% split into a training part and a held-out part.
train, test = torch.utils.data.random_split(mnist, [0.7, 0.3])

train_loader = torch.utils.data.DataLoader(
    train, batch_size=512, shuffle=True, num_workers=4
)
# The held-out loader is only used to average losses, so shuffling is unnecessary.
test_loader = torch.utils.data.DataLoader(
    test, batch_size=512, shuffle=False, num_workers=4
)

In [None]:
# Number of samples in the training part of the 70/30 split.
print(len(train))

In [None]:
def _materialize(split):
    """Copy every (image, label) pair of `split` into two dense float tensors.

    Returns an image tensor of shape (len(split), 28, 28) and a label tensor
    of shape (len(split), 1), filled sample by sample.
    """
    images = torch.empty(len(split), 28, 28)
    labels = torch.empty(len(split), 1)
    for idx, (image, target) in enumerate(split):
        images[idx] = image.reshape(28, 28)
        labels[idx] = target
    return images, labels


x_train, y_train = _materialize(train)
x_test, y_test = _materialize(test)

In [None]:
# Show a 3x3 grid of randomly chosen training digits with their labels.
for i in range(9):
    plt.subplot(3, 3, i + 1)
    # Bound the random index by the actual number of training samples.
    j = torch.randint(low=0, high=len(x_train), size=(1,)).item()
    digit = x_train[j]
    # Labels are stored as floats in a (N, 1) tensor; show a plain integer.
    label = int(y_train[j].item())
    plt.imshow(digit.reshape(28, 28, 1))
    plt.title(str(label))
plt.tight_layout()

## Step 3 - Preparing the data

In [None]:
# Divide both splits by the largest training pixel value, so the training
# maximum becomes 1; the held-out split reuses the training maximum.
# NOTE(review): batches drawn from the data loaders are scaled by 1/256.0
# instead of 1/xmax elsewhere in this file — confirm the small mismatch is acceptable.
xmax = x_train.max()
print(f"Before normalization : Min={x_train.min()}, max={xmax}")

x_train, x_test = x_train / xmax, x_test / xmax

print(f"After normalization  : Min={x_train.min()}, max={x_train.max()}")

## Step 4 - Create model

In [None]:
class ConvMNIST(torch.nn.Module):
    """Small convolutional classifier with a 10-way softmax output.

    Two conv / ReLU / max-pool / dropout stages are followed by a 100-unit
    dense layer and the output layer. All weighted layers are lazy, so their
    input sizes are inferred on the first forward pass. In this file the
    model is fed batches shaped (N, 1, 28, 28).

    NOTE(review): `forward` returns softmax probabilities, whereas
    torch.nn.CrossEntropyLoss expects unnormalized logits — confirm this
    pairing is intended wherever the model is trained with that loss.
    """

    @staticmethod
    def _conv_stage(out_channels):
        """Build a 3x3 conv -> ReLU -> 2x2 max-pool -> dropout(0.2) stage."""
        return torch.nn.Sequential(
            torch.nn.LazyConv2d(out_channels, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d((2, 2)),
            torch.nn.Dropout(0.2),
        )

    def __init__(self):
        super().__init__()
        self.input_layers = self._conv_stage(8)
        self.hidden_layer1 = self._conv_stage(16)
        self.hidden_layer2 = torch.nn.Sequential(
            torch.nn.Flatten(1),
            torch.nn.LazyLinear(100),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
        )
        self.out_layer = torch.nn.Sequential(
            torch.nn.LazyLinear(10),
            torch.nn.Softmax(-1),
        )
        # The same stage objects are chained into one pipeline; the modules
        # are shared with the attributes above, not copied.
        stages = (
            self.input_layers,
            self.hidden_layer1,
            self.hidden_layer2,
            self.out_layer,
        )
        self.all_layers = torch.nn.Sequential(*stages)

    def forward(self, input):
        """Run `input` through all stages and return the softmax output."""
        return self.all_layers(input)

In [None]:
# Instantiate the network and print its layer structure.
model = ConvMNIST()
print(model)

In [None]:
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import tqdm

# Criterion shared by evaluate() and the training loop.
# NOTE(review): CrossEntropyLoss expects raw logits, but ConvMNIST ends in a
# Softmax layer — confirm that applying softmax twice is intended.
loss_function = CrossEntropyLoss()

In [None]:
# Smoke test: push the last `digit` left over from the plotting cell (scaled
# by 1/256) through the model. This first forward pass also lets the lazy
# layers infer their input sizes.
model(digit.reshape(1, 1, 28, 28) / 256.0)

## Step 5 - Train the model

In [None]:
def evaluate(model, loader, loss_fn=None):
    """Return the average per-batch loss of `model` over `loader`.

    The model is put in inference mode for the duration of the call, so that
    dropout does not perturb the measurement, and its previous mode is
    restored afterwards. No gradients are recorded.

    Args:
        model: torch.nn.Module called on each image batch.
        loader: sized iterable of (digits, labels) batches; digits are
            divided by 256.0 before being passed to the model.
        loss_fn: callable ``(prediction, labels) -> scalar tensor``. Defaults
            to the module-level ``loss_function``.

    Returns:
        float: mean of the batch losses (0.0 for an empty loader).
    """
    if loss_fn is None:
        loss_fn = loss_function

    num_batches = len(loader)
    ce_loss_tot = 0.0

    was_training = model.training
    # Called through the class so it still works if a plain attribute named
    # `train` has been assigned on the instance and hides the method.
    torch.nn.Module.train(model, False)
    try:
        with torch.no_grad():
            for digits, labels in loader:
                prediction = model(digits / 256.0)
                ce_loss_tot += loss_fn(prediction, labels).item() / num_batches
    finally:
        torch.nn.Module.train(model, was_training)
    return ce_loss_tot

In [None]:
# Start from a fresh, untrained network.
model = ConvMNIST()
# The lazy layers have no concrete parameters until the first forward pass;
# the PyTorch docs recommend a dry run before the optimizer is attached.
with torch.no_grad():
    model(torch.zeros(1, 1, 28, 28))
optimizer = Adam(model.parameters(), lr=0.01, betas=(0.9, 0.999))

In [None]:
# Baseline: average batch loss of the untrained model on both splits.
train_loss_epoch = evaluate(model, train_loader)
test_loss_epoch = evaluate(model, test_loader)
print(train_loss_epoch, test_loss_epoch)

Predictions for untrained model:

In [None]:
# Predict the whole held-out split in one pass. Dropout is switched off and no
# autograd graph is built for this inference; the previous mode is restored
# afterwards so that training is unaffected.
was_training = model.training
model.eval()
with torch.no_grad():
    # -1 lets the batch size follow the actual size of the split.
    y_softmax = model(x_test.reshape(-1, 1, 28, 28))
model.train(was_training)
y_predictions = torch.argmax(y_softmax, dim=-1)
misclassified_indices = torch.where(y_predictions != y_test.squeeze(-1))
print(
    confusion_matrix(y_test, y_pred=y_predictions)
)  # order matters! (actual, predicted)

In [None]:
# Train for a fixed number of epochs, recording the average batch loss on the
# training split and on the held-out split after every epoch.
epochs = 16
epochs_bar = tqdm.tqdm(range(epochs))
train_loss = []
val_loss = []
num_train_batches = len(train_loader)
for epoch in epochs_bar:
    # `train` is a method of nn.Module: calling it enables dropout for the
    # optimization steps (assigning a bool to it would only hide the method).
    model.train()
    train_loss_epoch = 0.0
    for digits, labels in train_loader:
        prediction = model(digits / 256.0)
        cross_entropy_loss = loss_function(prediction, labels)
        optimizer.zero_grad()
        cross_entropy_loss.backward()
        optimizer.step()
        train_loss_epoch += cross_entropy_loss.item() / num_train_batches
    train_loss.append(train_loss_epoch)

    # Validate with dropout disabled; evaluate() already runs without gradients.
    model.eval()
    test_loss_epoch = evaluate(model, test_loader)
    val_loss.append(test_loss_epoch)
    epochs_bar.set_description(
        f"train loss: {train_loss_epoch:.4f}, val loss: {test_loss_epoch:.4f}"
    )

In [None]:
# Switch to inference mode (dropout off). `train` is a method of nn.Module:
# assigning a bool to it only hides the method on this instance without
# changing the mode, so drop any such stray attribute and then call eval().
model.__dict__.pop("train", None)
model.eval()

In [None]:
# Loss curves per epoch; labels and a legend make the two curves distinguishable.
plt.plot(train_loss, label="train")
plt.plot(val_loss, label="validation")
plt.xlabel("epoch")
plt.ylabel("CrossEntropyLoss")
plt.legend()

## Step 6 - Evaluate

In [None]:
# Average batch loss of the trained model on the held-out split.
score = evaluate(model, test_loader)

In [None]:
# Display the shape of the held-out image tensor.
x_test.shape

In [None]:
# Indices of held-out samples the current model gets wrong. Predictions are
# recomputed here rather than read from `y_predictions`, which at this point
# still holds the output of an earlier cell.
# NOTE(review): assumes the model is already in eval mode; otherwise dropout
# perturbs these predictions.
with torch.no_grad():
    current_predictions = model(x_test.reshape(-1, 1, 28, 28)).argmax(dim=-1)
errors = torch.nonzero(current_predictions != y_test.squeeze(-1)).flatten().tolist()

In [None]:
# Predict the whole held-out split in one pass without building an autograd
# graph; -1 lets the batch size follow the actual size of the split.
# NOTE(review): assumes the model is already in eval mode; otherwise dropout
# perturbs these predictions.
with torch.no_grad():
    y_softmax = model(x_test.reshape(-1, 1, 28, 28))
y_predictions = torch.argmax(y_softmax, dim=-1)
misclassified_indices = torch.where(y_predictions != y_test.squeeze(-1))
print(
    confusion_matrix(y_test, y_pred=y_predictions)
)  # order matters! (actual, predicted)

In [None]:
# 5x5 grid of random held-out digits; the title is green with a "v" when the
# prediction matches the label and red with an "x" otherwise.
# NOTE(review): assumes the model is already in eval mode; otherwise dropout
# perturbs these predictions.
plt.figure(figsize=(15, 15))
with torch.no_grad():
    for i in range(25):
        plt.subplot(5, 5, i + 1)
        # Bound the random index by the actual number of held-out samples.
        j = torch.randint(low=0, high=len(x_test), size=(1,)).item()
        digit = x_test[j]
        y_pred = model(digit.reshape(1, 1, 28, 28)).argmax()
        label = int(y_test[j].item())
        plt.imshow(digit.reshape(28, 28, 1))
        marker, color = ("v", "g") if label == y_pred.item() else ("x", "r")
        plt.title(f"True: {label}, pred: {y_pred}, {marker}", color=color)

plt.tight_layout()

In [None]:
# Grid of up to 25 misclassified held-out digits, titled with the true label,
# the predicted class and the highest softmax output for that sample.
plt.figure(figsize=(15, 15))
wrong = misclassified_indices[0]
# Fewer than 25 samples may be misclassified; never index past the end.
for i in range(min(25, len(wrong))):
    plt.subplot(5, 5, i + 1)
    j = wrong[i]
    digit = x_test[j]
    label = y_test[j].item()
    plt.imshow(digit.reshape(28, 28, 1))
    plt.title(f"True: {label}, pred: {y_predictions[j]}, {y_softmax[j].max():.4f}")
plt.tight_layout()

In [None]:
# Display the prediction left over from the last iteration of the random-sample grid.
y_pred