In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

from torchvision.datasets import MNIST
import torchvision.transforms as T

from sklearn.model_selection import train_test_split

In [2]:
# parameters
# Seed torch's random number generators so that weight initialisation and
# DataLoader shuffling are repeatable across kernel restarts.
SEED = 42
torch.manual_seed(SEED)

# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
NUM_EPOCHS = 10      # passes over the training set
ALPHA = 0.01         # learning rate handed to the optimizer
VAL_SIZE = 10_000    # number of samples held out for validation
BATCH_SIZE = 32      # samples per batch for every DataLoader

# Network dimensions: flattened 28x28 input, two hidden layers, ten outputs.
NUM_FEATURES = 28 * 28
HIDDEN_1 = 100
HIDDEN_2 = 50
NUM_LABELS = 10

In [3]:
# Training portion of MNIST; it is split into train/validation subsets later.
# Only this call is allowed to download the data into ../datasets/.
train_val_dataset = MNIST(
    root='../datasets/',
    train=True,
    download=True,
    transform=T.ToTensor(),
)

# Held-out test portion, read from the same root directory without downloading.
test_dataset = MNIST(
    root='../datasets/',
    train=False,
    download=False,
    transform=T.ToTensor(),
)

In [4]:
# Hold out VAL_SIZE samples for validation. Stratifying on the labels keeps the
# class proportions the same in the training and validation index sets.
# A fixed random_state makes the split identical on every run, so validation
# metrics stay comparable across kernel restarts.
SPLIT_SEED = 42
stratify = train_val_dataset.targets.numpy()
train_idxs, val_idxs = train_test_split(
    range(len(train_val_dataset)),
    stratify=stratify,
    test_size=VAL_SIZE,
    random_state=SPLIT_SEED,
)

In [5]:
# Both subsets are index-based views over the same underlying train_val_dataset;
# they differ only in which sample indices they expose.
train_dataset = Subset(train_val_dataset, train_idxs)
val_dataset = Subset(train_val_dataset, val_idxs)

In [6]:
# Options shared by all three loaders.
_common_loader_options = dict(batch_size=BATCH_SIZE, num_workers=4)

# Training batches are reshuffled, and a final incomplete batch is dropped.
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    drop_last=True,
    **_common_loader_options,
)

# Evaluation loaders keep the dataset order and keep the last partial batch.
val_dataloader = DataLoader(val_dataset, **_common_loader_options)

test_dataloader = DataLoader(test_dataset, **_common_loader_options)

In [7]:
class Model(nn.Module):
    """Fully connected classifier with two ReLU-activated hidden layers.

    Layer widths are read from the module-level constants NUM_FEATURES,
    HIDDEN_1, HIDDEN_2 and NUM_LABELS when the model is instantiated.
    The final layer is linear: the model returns raw class scores and applies
    no softmax itself.
    """

    def __init__(self):
        super().__init__()
        # Linear layers are created in input-to-output order.
        first_hidden = nn.Linear(NUM_FEATURES, HIDDEN_1)
        second_hidden = nn.Linear(HIDDEN_1, HIDDEN_2)
        output_layer = nn.Linear(HIDDEN_2, NUM_LABELS)
        self.layers = nn.Sequential(
            first_hidden,
            nn.ReLU(),
            second_hidden,
            nn.ReLU(),
            output_layer,
        )

    def forward(self, features):
        """Return the class scores produced by the layer stack for ``features``."""
        scores = self.layers(features)
        return scores

In [8]:
def calculate_performance(model, criterion, dataloader):
    """Compute the average per-sample loss and the accuracy of ``model``.

    The model is switched to evaluation mode (and left in it) and every batch
    is evaluated under ``torch.inference_mode()``, so no gradients are tracked.

    Args:
        model: Module mapping a ``(batch, NUM_FEATURES)`` tensor to one score
            per class along dimension 1.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor. It must return the loss *summed* over the batch
            (e.g. ``reduction="sum"``): the batch losses are added up and
            divided by the total sample count, so a mean-reduced criterion
            would make the returned loss too small.
        dataloader: Iterable yielding ``(features, labels)`` batches. Features
            are reshaped to ``(-1, NUM_FEATURES)`` and both tensors are moved
            to ``DEVICE``.

    Returns:
        Tuple ``(average_loss, accuracy)`` of Python floats.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    num_samples = 0
    num_correct = 0
    loss_sum = 0.0

    with torch.inference_mode():
        for features, labels in dataloader:
            features = features.view(-1, NUM_FEATURES).to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = model(features)

            # The predicted class is the index of the highest score in each row.
            predictions = outputs.argmax(dim=1)
            num_correct += (predictions == labels).sum().item()

            # .item() already copies the scalar to the host.
            loss_sum += criterion(outputs, labels).item()
            num_samples += len(features)

    if num_samples == 0:
        # Without this guard the division below fails with an opaque
        # ZeroDivisionError.
        raise ValueError("calculate_performance received a dataloader with no samples")
    return loss_sum / num_samples, num_correct / num_samples

In [13]:
def train(model):
    """Train ``model`` with SGD for NUM_EPOCHS passes over ``train_dataloader``.

    Each step flattens the batch to ``(-1, NUM_FEATURES)``, moves it to
    ``DEVICE``, backpropagates the summed cross-entropy loss, clips the global
    gradient norm to 1 and applies the optimizer update. After every epoch the
    validation loss and accuracy from ``calculate_performance`` are printed.

    Args:
        model: Module to optimise in place; its parameters are updated.

    Returns:
        None.
    """
    # CrossEntropyLoss combines softmax with the cross-entropy loss, so the
    # model does not need a softmax layer of its own.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    optimizer = optim.SGD(model.parameters(), lr=ALPHA)

    for epoch in range(NUM_EPOCHS):
        # Evaluation at the end of the previous epoch may have changed the
        # mode, so switch back to training mode every epoch.
        model.train()
        for features, labels in train_dataloader:
            inputs = features.view(-1, NUM_FEATURES).to(DEVICE)
            targets = labels.to(DEVICE)

            # Clear gradients left over from the previous step.
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            # Rescale gradients so their combined norm is at most 1.
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            optimizer.step()

        val_loss, val_acc = calculate_performance(model, criterion, val_dataloader)
        print(f'Epoch: {epoch+1} | Validation Loss: {val_loss} | Validation Accuracy: {val_acc}')

In [14]:
# Instantiate the network and move it to DEVICE.
model = Model().to(DEVICE)

In [15]:
# Run the training loop; train() prints validation loss and accuracy after each epoch.
# Note: training goes slower because train() clips the gradients on every step.
train(model)

Epoch: 1 | Validation Loss: 0.5509428356170655 | Validation Accuracy: 0.8536
Epoch: 2 | Validation Loss: 0.41197297353744505 | Validation Accuracy: 0.8851
Epoch: 3 | Validation Loss: 0.36639982678890226 | Validation Accuracy: 0.8954
Epoch: 4 | Validation Loss: 0.3403232001900673 | Validation Accuracy: 0.9037
Epoch: 5 | Validation Loss: 0.32071058940887454 | Validation Accuracy: 0.9076
Epoch: 6 | Validation Loss: 0.3040276431441307 | Validation Accuracy: 0.9148
Epoch: 7 | Validation Loss: 0.29033055518865586 | Validation Accuracy: 0.9179
Epoch: 8 | Validation Loss: 0.275537036550045 | Validation Accuracy: 0.9198
Epoch: 9 | Validation Loss: 0.26166385180950164 | Validation Accuracy: 0.9264
Epoch: 10 | Validation Loss: 0.24922023360133172 | Validation Accuracy: 0.9283
