This notebook trains and builds the model that our webapp will ultimately serve to users.

Import required libraries.

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchmetrics
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torch.utils.data import TensorDataset
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from torchvision import datasets, transforms
from torchvision.transforms import functional
from torch.utils.data import DataLoader, ConcatDataset
import gradio as gr
import onnx
import tensorflow_datasets as tfds

# Per the PyTorch docs, 'high' lets float32 matrix multiplications trade some
# internal precision for speed on hardware that supports it.
torch.set_float32_matmul_precision('high')

Define neural network architecture.

In [None]:
class DigitClassifier(nn.Module):
    """Small CNN that maps digit images to log-probabilities over 10 classes.

    The network is two conv/ReLU/max-pool stages followed by a two-layer
    fully connected head ending in ``LogSoftmax``; pair it with ``NLLLoss``.
    Input is expected as ``(N, 1, 28, 28)`` (the head's ``64*7*7`` input size
    fixes the spatial resolution); output is ``(N, 10)``.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: (1, 28, 28) -> (32, 14, 14) -> (64, 7, 7).
        conv_base = [
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]

        # Classifier head: flatten -> 128 hidden units -> dropout -> 10 log-probs.
        linear_head = [
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10),
            nn.LogSoftmax(dim=1),
        ]

        # Keep everything in one flat Sequential so the state_dict keys stay
        # ``model.<index>.*`` with the same indices as the saved weights.
        self.model = nn.Sequential(*conv_base, *linear_head)

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        return self.model(x)

Set up the PyTorch Lightning module for training the digit classifier.

In [None]:
class LitDigitClassifier(pl.LightningModule):
    """Lightning wrapper that trains ``DigitClassifier`` with NLL loss and Adam.

    Per-batch losses and accuracies are collected in plain Python lists during
    an epoch, then averaged, logged to the TensorBoard logger via
    ``add_scalars`` and cleared in the ``on_*_epoch_end`` hooks.

    The validation loss is additionally logged as ``val_loss`` (not sent to
    the logger) so callbacks such as ``EarlyStopping`` can monitor it.
    """

    def __init__(self):
        super().__init__()
        self.model = DigitClassifier()
        self.loss = nn.NLLLoss()
        # Per-batch values for the current epoch; drained in the epoch-end hooks.
        self.train_losses = []
        self.test_losses = []
        self.epoch_train_accs = []
        self.epoch_test_accs = []
        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=10)
        self.test_acc = torchmetrics.Accuracy(task='multiclass', num_classes=10)

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        return self.model(x)

    def on_fit_start(self):
        # Note: this runs when fitting starts, i.e. after __init__ has already
        # created (and randomly initialised) the network weights.
        pl.seed_everything(42, workers=True)

    def on_train_start(self):
        # Lightning calls this hook after its pre-training validation sanity
        # check. Drop whatever those batches appended so they do not leak into
        # the first epoch's averaged test metrics.
        self.test_losses.clear()
        self.epoch_test_accs.clear()
        self.test_acc.reset()

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and record loss/accuracy."""
        X, y = batch
        y_pred = self.model(X)
        loss = self.loss(y_pred, y)

        self.train_losses.append(loss.item())
        self.epoch_train_accs.append(self.train_acc(y_pred, y).item())

        return loss

    def _eval_step(self, batch):
        """Shared validation/test logic: record loss and accuracy, return the loss."""
        X, y = batch
        y_pred = self.model(X)
        loss = self.loss(y_pred, y)

        self.test_losses.append(loss.item())
        self.epoch_test_accs.append(self.test_acc(y_pred, y).item())

        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and expose ``val_loss`` to callbacks."""
        loss = self._eval_step(batch)
        self.log('val_loss', loss, logger=False)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch so ``trainer.test`` can feed ``on_test_epoch_end``."""
        loss = self._eval_step(batch)
        self.log('test_loss', loss, logger=False)

    def configure_optimizers(self):
        """Adam over the wrapped network's parameters with a fixed 1e-4 learning rate."""
        optimizer = optim.Adam(self.model.parameters(), lr=1e-4)
        return optimizer

    @staticmethod
    def _drain_mean(values):
        """Return the mean of ``values`` (NaN if empty) and empty the list in place."""
        # Guarding the empty case avoids NumPy's "mean of empty slice" warning
        # when e.g. validation did not run during this epoch.
        mean = float(np.mean(values)) if values else float('nan')
        values.clear()
        return mean

    def on_train_epoch_end(self):
        """Log epoch-averaged train and test loss/accuracy, then reset the buffers."""
        avg_train_loss = self._drain_mean(self.train_losses)
        avg_train_acc = self._drain_mean(self.epoch_train_accs)
        self.train_acc.reset()

        avg_test_loss = self._drain_mean(self.test_losses)
        avg_test_acc = self._drain_mean(self.epoch_test_accs)
        self.test_acc.reset()

        self.logger.experiment.add_scalars('loss', { 'train': avg_train_loss, 'test': avg_test_loss }, self.current_epoch)
        self.logger.experiment.add_scalars('accuracy', { 'train': avg_train_acc, 'test': avg_test_acc }, self.current_epoch)

    def on_test_epoch_end(self):
        """Log epoch-averaged test loss/accuracy, then reset the buffers."""
        avg_loss = self._drain_mean(self.test_losses)
        avg_acc = self._drain_mean(self.epoch_test_accs)
        self.test_acc.reset()

        self.logger.experiment.add_scalars('loss', { 'test': avg_loss }, self.current_epoch)
        self.logger.experiment.add_scalars('accuracy', { 'test': avg_acc }, self.current_epoch)


In [None]:
# # Create a new instance of your model architecture
# model = LitDigitClassifier()

# try:
#     # Load the state_dict from the file
#     model.model.load_state_dict(torch.load('model_weights.pth'))
# except:
#     print('Unable to load previous model')

# # Set the model to evaluation mode
# model.model.eval()

In [None]:
class InvertImage:
    """Callable transform that maps ``x`` to ``1 - x``.

    Used after ``ToTensor`` in the transform pipelines to flip foreground and
    background intensities.
    """

    def __call__(self, x):
        inverted = 1 - x
        return inverted
    
class CorruptedMNISTDataset(torch.utils.data.Dataset):
    """Thin wrapper that returns each label as a plain Python number.

    ``base_dataset`` must yield ``(image, label)`` pairs whose label supports
    ``.item()`` (e.g. a 0-dim tensor). Images are passed through untouched.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        image, label_tensor = self.base_dataset[index]
        # Unwrap the label so it has the same type as the labels that the
        # other datasets in the concatenation produce.
        return image, label_tensor.item()

def tf_to_torch(dataset):
    """Materialise a TFDS dataset as an in-memory ``TensorDataset``.

    Every example's ``'image'`` is cast to float32, divided by 255, has its
    last axis moved to the front, and is normalised with mean 0.1307 and
    std 0.3081. Every ``'label'`` becomes a tensor.

    Assumes each image is laid out height x width x channels (the permute only
    makes sense for a 3-D array) — TODO confirm against the TFDS schema.

    Returns:
        ``TensorDataset`` of the stacked images and the stacked labels.
    """
    image_tensors, label_tensors = [], []
    for record in tfds.as_numpy(dataset):
        scaled = record['image'].astype(np.float32) / 255
        channels_first = torch.from_numpy(scaled).permute((2, 0, 1))
        image_tensors.append(functional.normalize(channels_first, (0.1307, ), (0.3081, )))
        label_tensors.append(torch.tensor(record['label']))

    stacked_images = torch.stack(image_tensors)
    stacked_labels = torch.stack(label_tensors)
    return TensorDataset(stacked_images, stacked_labels)

Load the MNIST dataset along with its inverted, spatially augmented, and corrupted variants.

In [None]:
# Corruption variants; each one is loaded from TFDS as `mnist_corrupted/<name>`.
corrupted_dataset_paths = [
    'shot_noise', 'impulse_noise', 'glass_blur', 'motion_blur', 'shear',
    'scale', 'rotate', 'brightness', 'translate', 'stripe',
    'fog', 'spatter', 'dotted_line', 'zigzag', 'canny_edges',
]


def _random_shift_and_scale():
    """Fresh random translate/scale augmentation (no rotation)."""
    # Rotation is deliberately left out: it can cause some weird issues.
    return transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.9, 1.1))


def _mnist_split(train, transform):
    """Torchvision MNIST split stored under ./data, downloaded if missing."""
    return datasets.MNIST(root='./data', train=train, download=True, transform=transform)


def _corrupted_mnist_split(split):
    """Concatenate every corrupted variant of the given TFDS split."""
    variants = []
    for corruption in corrupted_dataset_paths:
        tf_split = tfds.load(
            f'mnist_corrupted/{corruption}',
            split=split,
            shuffle_files=False,
            download=True,
            data_dir='./data',
            with_info=False,
        )
        variants.append(CorruptedMNISTDataset(tf_to_torch(tf_split)))
    return ConcatDataset(variants)


# --- Transform pipelines -----------------------------------------------------
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Same as above, but with intensities flipped before normalisation.
inv_minst_transform = transforms.Compose([
    transforms.ToTensor(),
    InvertImage(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

spatial_mnist_transform = transforms.Compose([_random_shift_and_scale(), mnist_transform])
inv_spatial_mnist_transform = transforms.Compose([_random_shift_and_scale(), inv_minst_transform])

# --- Training data: plain + inverted + augmented + augmented-inverted + corrupted
original_train_dataset = _mnist_split(True, mnist_transform)
inv_train_dataset = _mnist_split(True, inv_minst_transform)
spatial_train_dataset = _mnist_split(True, spatial_mnist_transform)
inv_spatial_train_dataset = _mnist_split(True, inv_spatial_mnist_transform)
corrupted_mnist_train_dataset = _corrupted_mnist_split('train')
train_dataset = ConcatDataset([
    original_train_dataset,
    inv_train_dataset,
    spatial_train_dataset,
    inv_spatial_train_dataset,
    corrupted_mnist_train_dataset,
])

# --- Test data: plain + inverted + corrupted (no random augmentation) ---------
original_test_dataset = _mnist_split(False, mnist_transform)
inv_test_dataset = _mnist_split(False, inv_minst_transform)
corrupted_mnist_test_dataset = _corrupted_mnist_split('test')
test_dataset = ConcatDataset([
    original_test_dataset,
    inv_test_dataset,
    corrupted_mnist_test_dataset,
])

# Only the training loader shuffles; both load in the main process.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

In [None]:
def count_labels(dataset, num_classes=10):
    """Count how many samples of each class ``dataset`` yields.

    Args:
        dataset: iterable of ``(sample, label)`` pairs where ``label`` is an
            integer index in ``[0, num_classes)``.
        num_classes: number of distinct labels.

    Returns:
        Float NumPy array of length ``num_classes`` with the per-class counts.
    """
    counts = np.zeros(num_classes)
    # Iterating fetches (and transforms) every sample just to read its label,
    # so this is slow on the full training set.
    for _, label in dataset:
        counts[label] += 1
    return counts


def plot_class_counts(dataset, title, num_classes=10):
    """Draw a bar chart of the class distribution of ``dataset``.

    Args:
        dataset: iterable of ``(sample, label)`` pairs.
        title: title shown above the chart.
        num_classes: number of distinct labels.

    Returns:
        The per-class counts that were plotted.
    """
    counts = count_labels(dataset, num_classes)
    labels = np.arange(num_classes)

    _, ax = plt.subplots()
    ax.bar(labels, counts)
    ax.set_xticks(labels)
    ax.set_xlabel('Class Labels')
    ax.set_ylabel('Counts')
    ax.set_title(title)
    plt.show()

    return counts


# `class_labels` / `class_counts` stay defined at notebook level, as before;
# `class_counts` ends up holding the test-set counts.
class_labels = np.arange(10)
class_counts = plot_class_counts(train_dataset, 'Counts of Each Class in the Training Dataset')
class_counts = plot_class_counts(test_dataset, 'Counts of Each Class in the Test Dataset')

Preview some samples from the corrupted MNIST test set.

In [None]:
def unnormalize(tensor, mean=0.1307, std=0.3081):
    """Undo ``Normalize(mean, std)`` and return a NumPy array for plotting.

    Args:
        tensor: normalised image tensor; it is not modified.
        mean: mean that was subtracted during normalisation.
        std: standard deviation that was divided by during normalisation.

    Returns:
        NumPy array holding ``tensor * std + mean`` with all size-1
        dimensions squeezed out (so a ``(1, H, W)`` image becomes ``(H, W)``).
    """
    # detach() drops autograd tracking and cpu() makes this work for tensors
    # on an accelerator. The arithmetic below allocates a new array, so the
    # caller's tensor is never written to.
    img = tensor.detach().cpu().numpy()
    return (img * std + mean).squeeze()

def preview_dataset(dataset):
    """Show a 3x3 grid of randomly drawn samples from ``dataset``.

    Each panel displays one un-normalised image in grayscale with its label
    as the title. Indices are drawn with ``torch.randint``, so the same
    sample may appear more than once.
    """
    _, grid = plt.subplots(3, 3, figsize=(6, 6))
    for panel in grid.ravel():
        sample_index = torch.randint(0, len(dataset), (1,)).item()
        image, target = dataset[sample_index]
        panel.imshow(unnormalize(image), cmap='gray')
        panel.set_title(f'Label: {target}')
        panel.axis('off')

    plt.tight_layout()
    plt.show()


preview_dataset(corrupted_mnist_test_dataset)

Train the model, validating on the test set after every epoch. The per-epoch train and test losses and accuracies are logged to TensorBoard.

In [None]:
# Seed before the model is constructed so the initial weights are reproducible
# too; the module's own on_fit_start hook only seeds once fitting begins.
pl.seed_everything(42, workers=True)

model = LitDigitClassifier()
logger = TensorBoardLogger('lightning_logs', name='mnist')
# Stop once val_loss has failed to improve by at least 1e-3 for 5 checks in a row.
early_stop_callback = EarlyStopping(monitor='val_loss', min_delta=1e-3, patience=5, verbose=True)
trainer = pl.Trainer(
    max_epochs=200,
    min_epochs=3,
    precision='32',
    # "auto" uses the available GPUs when present and falls back to CPU
    # otherwise. The previous accelerator="gpu" with
    # devices=torch.cuda.device_count() requested zero devices on a machine
    # without CUDA.
    accelerator="auto",
    devices="auto",
    logger=logger,
    callbacks=[early_stop_callback]
)
# NOTE(review): the test loader doubles as the validation loader, so early
# stopping is tuned on the test data and the reported test metrics are
# optimistic. Consider holding out a separate validation split.
trainer.fit(model, train_loader, test_loader)

In [None]:
# Save only the wrapped DigitClassifier's weights (model.model), not the
# Lightning module, so they can be loaded into a bare DigitClassifier.
torch.save(model.model.state_dict(), 'model_weights.pth')

In [None]:
# Create a fresh instance of the plain (non-Lightning) architecture
digitClassifier = DigitClassifier()

# Load the saved weights. map_location='cpu' keeps this working on machines
# without the device the tensors were saved from.
digitClassifier.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))

# Set the model to evaluation mode (turns dropout off for inference)
digitClassifier.eval()

In [None]:
# `unnormalize` is already defined in the dataset-preview cell earlier in this
# notebook; it is reused here instead of being redefined.

# Display a grid of random training samples with the model's predictions
grid_size = 3
num_samples = grid_size * grid_size
fig, axes = plt.subplots(grid_size, grid_size, figsize=(6, 6))

digitClassifier.eval()
for ax in axes.ravel():
    idx = torch.randint(0, len(train_dataset), (1,)).item()
    img, label = train_dataset[idx]
    with torch.no_grad():
        # The network outputs log-probabilities; exp() converts them back to
        # probabilities before taking the most likely class.
        prediction = torch.exp(digitClassifier(img.unsqueeze(0)))
        predicted_label = torch.argmax(prediction, dim=1).item()
    ax.imshow(unnormalize(img), cmap='gray')
    ax.set_title(f'Actual: {label}, Predicted: {predicted_label}')
    ax.axis('off')

plt.tight_layout()
plt.show()

Experiment with a Gradio interface.

In [None]:
def classify(image):
    """Return the model's confidence for each digit given a sketched image.

    Args:
        image: the image handed over by the Gradio input component, or
            ``None`` when nothing has been drawn yet.

    Returns:
        Dict mapping ``'0'``..``'9'`` to the predicted probability of that
        digit (all zeros when ``image`` is ``None``).
    """
    if image is None:
        return {str(i): 0 for i in range(10)}

    # Apply the same preprocessing as the training pipelines. The model was
    # only ever trained on inputs normalised with these statistics, so feeding
    # it raw [0, 1] pixels would shift the input distribution.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    digitClassifier.eval()
    with torch.no_grad():
        X = preprocess(image).unsqueeze(0)
        # The network returns log-probabilities; exp() recovers probabilities.
        y = torch.exp(digitClassifier(X)).tolist()[0]
    return {str(i): y[i] for i in range(10)}
    
# 28x28 drawing canvas, matching the resolution the classifier is built for.
# NOTE(review): `shape` / `invert_colors` look like Gradio 3.x-era arguments —
# confirm against the installed Gradio version.
sketchpad = gr.Sketchpad(shape=(28, 28), invert_colors=False)
# Output widget limited to the three highest-confidence classes.
label = gr.components.Label(num_top_classes=3)
# Wire classify() between the sketchpad and the label. live=True presumably
# re-runs the prediction as the drawing changes — TODO confirm.
interface = gr.Interface(classify, sketchpad, label, live=True)

In [None]:
# Start the Gradio app for the `interface` object built in the previous cell.
interface.launch()

Finally, export our trained model via ONNX to our webapp's `assets` folder.

In [None]:
# Export in evaluation mode, using one random 1x1x28x28 tensor as the example
# input handed to the exporter.
digitClassifier.eval()
dummy_input = torch.randn(1, 1, 28, 28)

# Dimension 0 of both the input and the output is exported as a
# variable-length axis named 'batch_size'.
batch_axis = {0: 'batch_size'}

torch.onnx.export(
    digitClassifier,             # model being exported
    dummy_input,                 # example model input
    "src/assets/mnist.onnx",     # destination file
    input_names=['input'],       # name given to the graph input
    output_names=['output'],     # name given to the graph output
    dynamic_axes={
        'input': dict(batch_axis),
        'output': dict(batch_axis),
    },
    verbose=True,
)

In [None]:
# Read the exported file back in
onnx_model = onnx.load("src/assets/mnist.onnx")

# Validate the loaded model with ONNX's checker
onnx.checker.check_model(onnx_model)

# Render the graph as human-readable text and show it
graph_description = onnx.helper.printable_graph(onnx_model.graph)
print(graph_description)