In [None]:
%matplotlib inline
import collections
from typing import Callable

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import torch
import torchsummary
import torchvision
import tqdm.notebook

# Constants

In [None]:
# Data
dataset_location: str = "../data"
batch_size: int = 256
train_validation_split: float = 0.7

# Torch
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed torch's global RNG so the train/validation split, loader shuffling and
# default weight initialisation are reproducible on Restart & Run All.
seed: int = 42
torch.manual_seed(seed)

# Training
epochs: int = 20

# Load data
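Load the MNIST dataset from torchvision. The transform pads each 28×28 digit by 2 pixels on every side to the 32×32 input size used by LeNet-5, converts it to a tensor, and normalises pixel values from [0, 1] to [-1, 1].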

In [None]:
# Preprocessing: pad each digit by 2 px per side (28×28 -> 32×32), convert to a
# [0, 1] tensor, then map to [-1, 1] via (x - 0.5) / 0.5.
transform = torchvision.transforms.Compose(
    transforms=[
        torchvision.transforms.Pad(padding=2),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=0.5, std=0.5),
    ]
)

In [None]:
# MNIST training set, split randomly into train/validation subsets by fraction.
train_validation_data = torchvision.datasets.MNIST(
    root=dataset_location,
    train=True,
    transform=transform,
    download=True,
)
train_data, validation_data = torch.utils.data.random_split(
    train_validation_data,
    lengths=[train_validation_split, 1 - train_validation_split],
)

# Held-out MNIST test set, preprocessed identically.
test_data = torchvision.datasets.MNIST(
    root=dataset_location,
    train=False,
    transform=transform,
    download=True,
)

num_classes = len(train_validation_data.classes)

In [None]:
def get_loader(dataset: torch.utils.data.Dataset, shuffle: bool = True) -> torch.utils.data.DataLoader:
    """Wrap ``dataset`` in a DataLoader that uses the notebook-wide ``batch_size``.

    Args:
        dataset: Dataset to batch.
        shuffle: Reshuffle the samples every epoch. Useful for training; the
            evaluation metrics in this notebook are order-independent sums.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``dataset``.
    """
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

train_loader = get_loader(train_data)
# Validation/test loss and accuracy are summed over the whole set, so order does
# not matter; iterating sequentially keeps evaluation order deterministic.
validation_loader = get_loader(validation_data, shuffle=False)
test_loader = get_loader(test_data, shuffle=False)

In [None]:
def get_sample() -> tuple[torch.Tensor, str]:
    """Return the first image of a fresh training batch and its class name.

    The leading channel dimension is squeezed away so the image can be passed
    straight to ``plt.imshow``.
    """
    images, targets = next(iter(train_loader))
    return images[0].squeeze(0), train_validation_data.classes[targets[0].item()]

image, label = get_sample()
print(f"Class: {label}")
# Single-channel digit: use a grey colormap rather than matplotlib's default false colour.
plt.imshow(image, cmap="gray")
plt.title(f"Class: {label}");

# Original LeNet-5
## Training

In [None]:
from original.lenet import LeNet5

# Model: build the network, move it to the configured device, and print a
# per-layer summary for a single 1×32×32 (padded MNIST) input.
# NOTE: the Modern section later re-imports a *different* `LeNet5` under the same
# name, so cells here that use `LeNet5.loss` must run before that import.
model = LeNet5()
model = model.to(device)
torchsummary.summary(model, input_size=(1, 32, 32))

In [None]:
from original.manual_scheduler import ManualLRScheduler

# Optimizer and scheduler
# Step schedule: presumably learning_rates[i] is used for counts[i] epochs and the
# last rate (one more entry than counts) is kept afterwards — TODO confirm against
# ManualLRScheduler.
learning_rates: list[float] = [5e-4, 2e-4, 1e-4, 5e-5, 1e-5]
counts: list[int] = [2, 3, 3, 4]

manual_lr_scheduler = ManualLRScheduler(learning_rates, counts)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rates[0])
# NOTE(review): LambdaLR sets lr = initial_lr * lr_lambda(epoch), so lr_lambda must
# return a *factor* relative to learning_rates[0], not an absolute rate — confirm
# that ManualLRScheduler.step does so.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=manual_lr_scheduler.step)

In [None]:
def train_step(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: "torch.optim.lr_scheduler.LRScheduler",
    train_loader: torch.utils.data.DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    accuracy_fn: Callable[[torch.Tensor, torch.Tensor], float],
    tqdm_description: str = ""
) -> tuple[float, float]:
    """Run one training epoch over ``train_loader`` and step ``scheduler`` once.

    Args:
        model: Network to optimise; it is put into training mode first.
        optimizer: Optimiser that updates ``model``'s parameters once per batch.
        scheduler: Learning-rate scheduler, stepped once at the end of the epoch.
        train_loader: Loader yielding ``(data, targets)`` batches with integer
            class targets (they are one-hot encoded with ``num_classes``).
        loss_fn: Called as ``loss_fn(predictions, one_hot_targets)``. Batch losses
            are summed and divided by the dataset size, so the result is a
            per-sample mean only if ``loss_fn`` sums over the batch.
        accuracy_fn: Called as ``accuracy_fn(predictions, targets)``; should
            return the number of correct predictions in the batch.
        tqdm_description: Label shown on the progress bar.

    Returns:
        ``(loss, accuracy)`` averaged over ``len(train_loader.dataset)`` samples.
    """
    # Explicitly enable training behaviour (matters for layers such as dropout or
    # batch norm, if present): an earlier evaluation pass may have left eval mode on.
    model.train()
    training_loss = training_accuracy = 0
    for data, targets in tqdm.tqdm(train_loader, desc=tqdm_description, ncols=100):
        data = data.to(device)
        targets = targets.to(device)
        y = torch.nn.functional.one_hot(targets, num_classes).float()

        # Forward pass
        optimizer.zero_grad()
        y_pred = model(data)
        # Keep the loss on the model's device; .item() below copies the scalar to the host.
        loss = loss_fn(y_pred, y)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Metrics
        training_loss += loss.item()
        training_accuracy += accuracy_fn(y_pred, targets)
    scheduler.step()
    num_samples = len(train_loader.dataset)
    return training_loss / num_samples, training_accuracy / num_samples


In [None]:
@torch.inference_mode()
def validation_step(
    model: torch.nn.Module,
    validation_loader: torch.utils.data.DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    accuracy_fn: Callable[[torch.Tensor, torch.Tensor], float],
    tqdm_description: str = "",
) -> tuple[float, float]:
    """Evaluate ``model`` on ``validation_loader`` without updating it.

    Runs under ``torch.inference_mode`` with the model in eval mode. The model's
    previous train/eval mode is restored afterwards, even if evaluation raises.

    Args:
        model: Network to evaluate.
        validation_loader: Loader yielding ``(data, targets)`` batches with
            integer class targets (one-hot encoded with ``num_classes``).
        loss_fn: Called as ``loss_fn(predictions, one_hot_targets)``. Batch losses
            are summed and divided by the dataset size, so the result is a
            per-sample mean only if ``loss_fn`` sums over the batch.
        accuracy_fn: Called as ``accuracy_fn(predictions, targets)``; should
            return the number of correct predictions in the batch.
        tqdm_description: Label shown on the progress bar.

    Returns:
        ``(loss, accuracy)`` averaged over ``len(validation_loader.dataset)`` samples.
    """
    # Without eval(), layers such as dropout or batch norm (if the model has any)
    # would behave as during training and skew the metrics.
    was_training = model.training
    model.eval()
    try:
        validation_loss = validation_accuracy = 0
        for data, targets in tqdm.tqdm(validation_loader, desc=tqdm_description, ncols=100):
            data = data.to(device)
            targets = targets.to(device)
            y = torch.nn.functional.one_hot(targets, num_classes).float()

            # Forward pass
            y_pred = model(data)
            loss = loss_fn(y_pred, y)

            # Metrics
            validation_loss += loss.item()
            validation_accuracy += accuracy_fn(y_pred, targets)
    finally:
        # Restore the caller's mode so a following training epoch is unaffected.
        model.train(was_training)
    num_samples = len(validation_loader.dataset)
    return validation_loss / num_samples, validation_accuracy / num_samples

In [None]:
def _record_epoch(history: dict[str, list[float]], split: str, loss: float, accuracy: float) -> None:
    """Print one epoch's metrics labelled with ``split`` and append them to ``history``."""
    print(f"{split} loss: {loss:.4f}, {split} accuracy: {accuracy:.2%}", flush=True)
    history["loss"].append(loss)
    history["accuracy"].append(accuracy)


def train(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: "torch.optim.lr_scheduler.LRScheduler",
    train_loader: torch.utils.data.DataLoader,
    train_history: dict[str, list[float]],
    validation_loader: torch.utils.data.DataLoader,
    validation_history: dict[str, list[float]],
    epochs: int,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    accuracy_fn: Callable[[torch.Tensor, torch.Tensor], float]
) -> None:
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    After every epoch the training and validation ``(loss, accuracy)`` are
    printed and appended to ``train_history`` / ``validation_history`` under the
    keys ``"loss"`` and ``"accuracy"``. Each history must support
    ``history[key].append(...)`` for those keys, e.g. ``collections.defaultdict(list)``.
    """
    for epoch in range(1, epochs + 1):
        training_loss, training_accuracy = train_step(
            model,
            optimizer,
            scheduler,
            train_loader,
            loss_fn,
            accuracy_fn,
            f"Training epoch {epoch}/{epochs}"
        )
        _record_epoch(train_history, "Training", training_loss, training_accuracy)

        validation_loss, validation_accuracy = validation_step(
            model,
            validation_loader,
            loss_fn,
            accuracy_fn,
            f"Validating epoch {epoch}/{epochs}"
        )
        _record_epoch(validation_history, "Validation", validation_loss, validation_accuracy)

In [None]:
train_history = collections.defaultdict(list)
validation_history = collections.defaultdict(list)

def accuracy_fn(preds: torch.Tensor, targets: torch.Tensor) -> float:
    """Count the samples in a batch whose predicted class matches ``targets``.

    The predicted class is the index of the *smallest* model output.
    NOTE(review): this presumably matches an original LeNet-5 output layer that
    emits per-class penalties/distances (lower = better match) — confirm against
    ``original.lenet`` before changing it to argmax.
    """
    predicted_classes = preds.argmin(dim=1)
    n_correct = (predicted_classes == targets).sum()
    return n_correct.item()

train(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    train_history=train_history,
    validation_loader=validation_loader,
    validation_history=validation_history,
    epochs=epochs,
    loss_fn=LeNet5.loss,
    accuracy_fn=accuracy_fn,
)

In [None]:
def plot_metric(histories: dict[str, dict[str, list[float]]], metric: str):
    """Plot one metric per epoch for several named histories on a single new figure.

    Args:
        histories: Maps a legend label to a history dict whose ``metric`` entry
            holds one value per epoch.
        metric: Key to plot from each history (e.g. ``"loss"`` or ``"accuracy"``).
    """
    _, ax = plt.subplots()

    for label, history in histories.items():
        values = history[metric]
        epoch_numbers = range(1, len(values) + 1)
        ax.plot(epoch_numbers, values, ".-", label=label.capitalize())

    # Epochs are whole numbers, so only place x ticks at integers.
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_ylabel(metric.capitalize())
    ax.set_xlabel("Epochs")
    ax.legend()

    plt.show()


In [None]:
plot_metric(histories={"Training": train_history, "Validation": validation_history}, metric="loss")

In [None]:
plot_metric(histories={"Training": train_history, "Validation": validation_history}, metric="accuracy")

## Testing

In [None]:
# Evaluate the original model on the held-out test set. Relies on `LeNet5` and
# `accuracy_fn` still referring to the Original section's definitions.
test_loss, test_accuracy = validation_step(
    model=model,
    validation_loader=test_loader,
    loss_fn=LeNet5.loss,
    accuracy_fn=accuracy_fn,
    tqdm_description="Testing",
)
print(f"Testing Loss: {test_loss:.2f}, Testing accuracy: {test_accuracy:.2%}")

# Modern LeNet-5
## Training

In [None]:
from modern.lenet import LeNet5

# Model: build the modern network (this shadows the Original section's `LeNet5`),
# move it to the configured device, and print a per-layer summary for a single
# 1×32×32 (padded MNIST) input.
model = LeNet5()
model = model.to(device)
torchsummary.summary(model, input_size=(1, 32, 32))

In [None]:
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)

In [None]:
train_history = collections.defaultdict(list)
validation_history = collections.defaultdict(list)

# Summed (not averaged) over each batch, because the train/validation steps
# divide the accumulated loss by the dataset size.
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")

def accuracy_fn(preds: torch.Tensor, targets: torch.Tensor) -> float:
    """Count the samples in a batch whose highest-scoring class matches ``targets``."""
    predicted_classes = preds.argmax(dim=1)
    n_correct = (predicted_classes == targets).sum()
    return n_correct.item()

train(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    train_history=train_history,
    validation_loader=validation_loader,
    validation_history=validation_history,
    epochs=epochs,
    loss_fn=loss_fn,
    accuracy_fn=accuracy_fn,
)

In [None]:
plot_metric(histories={"Training": train_history, "Validation": validation_history}, metric="loss")

In [None]:
plot_metric(histories={"Training": train_history, "Validation": validation_history}, metric="accuracy")

In [None]:
# Evaluate the modern model on the held-out test set.
test_loss, test_accuracy = validation_step(
    model=model,
    validation_loader=test_loader,
    loss_fn=loss_fn,
    accuracy_fn=accuracy_fn,
    tqdm_description="Testing",
)
print(f"Testing Loss: {test_loss:.2f}, Testing accuracy: {test_accuracy:.2%}")