This notebook trains and builds the model that our web app will ultimately serve to users.

Import required libraries.

In [11]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset
from tqdm import tqdm

Define neural network architecture.

In [12]:
class DigitClassifier(nn.Module):
    """Small CNN that classifies 28x28 single-channel digit images into 10 classes.

    The forward pass ends in ``LogSoftmax`` over dim 1, so it returns
    log-probabilities and is meant to be trained with ``nn.NLLLoss``.
    The ``Linear(64 * 12 * 12, ...)`` layer fixes the input spatial size at 28x28.
    """

    def __init__(self):
        super().__init__()
        # NOTE: the order of layers in this Sequential determines the state_dict
        # keys (``model.0.weight``, ``model.9.weight``, ...). Reordering or
        # inserting layers breaks loading of previously saved weights.
        self.model = nn.Sequential(
            # Input: (N, 1, 28, 28)
            # First convolution block: 3x3 kernel, no padding -> spatial dims shrink by 2
            nn.Conv2d(1, 32, kernel_size=3, stride=1),   # (N, 32, 26, 26)
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # Second convolution block
            nn.Conv2d(32, 64, kernel_size=3, stride=1),  # (N, 64, 24, 24)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),                             # (N, 64, 12, 12): halves spatial dims
            # Classification head
            nn.Dropout2d(0.7),                           # zeroes entire feature maps (channels)
            nn.Flatten(),                                # (N, 64 * 12 * 12) = (N, 9216)
            nn.Linear(64 * 12 * 12, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10),                          # one score per digit class 0-9
            nn.LogSoftmax(dim=1),                        # log-probabilities over the 10 classes
        )

    def forward(self, x):
        """Return (N, 10) log-probabilities for a batch ``x`` of shape (N, 1, 28, 28)."""
        return self.model(x)

Load the MNIST dataset and build the train/test data loaders.

In [13]:
# Convert to a tensor and normalize with the commonly used MNIST mean/std.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Random rotation/shift/scale jitter followed by the plain transform. It is
# re-sampled every time an image is fetched. It is used only for the augmented
# copy of the training set.
data_augmentation = transforms.Compose([
    transforms.RandomRotation(degrees=15),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    mnist_transform
])

DATA_ROOT = './data'
BATCH_SIZE = 32
NUM_WORKERS = 8


def _load_mnist(train, transform):
    """Return the MNIST train or test split (downloaded into DATA_ROOT if missing) with `transform` applied."""
    return datasets.MNIST(root=DATA_ROOT, train=train, download=True, transform=transform)


# Training set = the clean images plus a randomly augmented copy of them.
# The "original" half must use the plain transform; otherwise both halves are
# augmented and the model never sees un-jittered digits.
original_train_dataset = _load_mnist(train=True, transform=mnist_transform)
augmented_train_dataset = _load_mnist(train=True, transform=data_augmentation)
train_dataset = ConcatDataset([original_train_dataset, augmented_train_dataset])
test_dataset = _load_mnist(train=False, transform=mnist_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

Define training and testing functions.

In [14]:
def train(model, device, train_loader, optimizer, criterion):
    model.train()
    total_correct = 0
    total_samples = 0
    train_loss = 0.0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        

        train_loss += loss.item() * data.size(0)
        
        _, pred = torch.max(output, 1)
        total_correct += (pred == target).sum().item()
        total_samples += data.size(0)
    
    train_loss /= len(train_loader.dataset)
    accuracy = 100. * total_correct / total_samples
    return train_loss, accuracy

def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0.0
    total_correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            test_loss += loss.item() * data.size(0)
            _, pred = torch.max(output, 1)
            total_correct += (pred == target).sum().item()
    
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * total_correct / len(test_loader.dataset)
    return test_loss, accuracy


Train and test the model, then plot the train and test loss and accuracy per epoch.

In [15]:
SEED = 0
# Seed torch's global RNG so that weight init, shuffling and augmentation draws
# are repeatable across runs. GPU (cuDNN) kernels may still add some
# nondeterminism.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DigitClassifier().to(device)
# The model ends in LogSoftmax, so NLLLoss on its output is the cross-entropy.
# reduction='sum' sums (rather than averages) the loss over each batch.
criterion = nn.NLLLoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), lr=1e-3)

print(f'device: {device.type}, train samples: {len(train_loader.dataset)}, test samples: {len(test_loader.dataset)}')

num_epochs = 100
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []
progress = tqdm(range(1, num_epochs + 1), desc='Training model...')
for epoch in progress:
    train_loss, train_accuracy = train(model, device, train_loader, optimizer, criterion)
    test_loss, test_accuracy = test(model, device, test_loader, criterion)

    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)

    # Show the latest metrics on the progress bar instead of printing a line per epoch.
    progress.set_postfix(
        train_loss=f'{train_loss:.4f}',
        test_loss=f'{test_loss:.4f}',
        test_acc=f'{test_accuracy:.2f}%',
    )
print('Complete.')

# Both metric lists hold one value per epoch, so plot them against epoch numbers.
epochs = range(1, num_epochs + 1)

# Plot loss
fig, ax = plt.subplots()
ax.plot(epochs, train_losses, label='Train Loss')
ax.plot(epochs, test_losses, label='Test Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='Loss per Epoch')
ax.legend()
plt.show()

# Plot accuracy (train/test return it as a percentage)
fig, ax = plt.subplots()
ax.plot(epochs, train_accuracies, label='Train Accuracy')
ax.plot(epochs, test_accuracies, label='Test Accuracy')
ax.set(xlabel='Epoch', ylabel='Accuracy (%)', title='Accuracy per Epoch')
ax.legend()
plt.show()

device: cuda, train samples: 120000, test samples: 10000


Training model...:   0%|          | 0/100 [00:00<?, ?it/s]