In [None]:
%matplotlib inline
import collections

import matplotlib.pyplot as plt
import torch
import torchsummary
import torchvision
import tqdm.notebook

from traditional.lenet import LeNet5
from traditional.manual_scheduler import ManualLRScheduler

# Constants

In [None]:
# Data
dataset_location: str = "../data"
batch_size: int = 256
# Fraction of the MNIST training split used for training; the remainder becomes the validation set.
train_validation_split: float = 0.7

# Torch
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed: int = 0
# Seed torch's global RNG so the random train/validation split and DataLoader shuffling
# (and any torch-based weight initialisation) are reproducible on Restart & Run All.
torch.manual_seed(seed)

# Training
epochs: int = 20

# Load data
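Load the MNIST dataset from torchvision. The transform pads each 28×28 digit by 2 pixels on every side to the 32×32 input size used by LeNet-5, converts it to a tensor, and normalises pixel values from [0, 1] to [-1, 1].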

In [None]:
# Pad 2px on every side (28x28 MNIST -> 32x32), convert to a [0, 1] tensor,
# then shift/scale with mean=0.5, std=0.5 so pixel values land in [-1, 1].
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Pad(padding=2),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=0.5, std=0.5),
    ]
)

In [None]:
# MNIST training split, randomly divided into training and validation subsets.
train_validation_data = torchvision.datasets.MNIST(
    root=dataset_location, train=True, transform=transform, download=True
)
train_data, validation_data = torch.utils.data.random_split(
    train_validation_data,
    lengths=[train_validation_split, 1 - train_validation_split],
)
# Held-out MNIST test split.
test_data = torchvision.datasets.MNIST(
    root=dataset_location, train=False, transform=transform, download=True
)

# Number of target classes, taken from the dataset's class names.
num_classes = len(train_validation_data.classes)

In [None]:
def get_loader(dataset: torch.utils.data.Dataset, shuffle: bool = True) -> torch.utils.data.DataLoader:
    """Wrap ``dataset`` in a DataLoader that uses the notebook-wide ``batch_size``.

    Args:
        dataset: Dataset to iterate over.
        shuffle: Reshuffle the samples every epoch. Only needed for training; the
            summed loss/accuracy metrics used for evaluation do not depend on order.

    Returns:
        A ``DataLoader`` over ``dataset``.
    """
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

train_loader = get_loader(train_data)
# Evaluation metrics are order-independent, so skip shuffling for these loaders.
validation_loader = get_loader(validation_data, shuffle=False)
test_loader = get_loader(test_data, shuffle=False)

In [None]:
def get_sample() -> tuple[torch.Tensor, str]:
    """Return the first image of a fresh training batch together with its class name.

    ``train_loader`` shuffles, so each call usually yields a different sample. The
    leading dimension of the image is squeezed away (assumes single-channel
    ``(1, H, W)`` images) so it can be passed straight to ``plt.imshow``.
    """
    images, targets = next(iter(train_loader))
    return images[0].squeeze(0), train_validation_data.classes[targets[0].item()]

image, label = get_sample()
print(f"Class: {label}")
# The digit is single-channel: render it in greyscale instead of the default colour map.
plt.imshow(image, cmap="gray");

# Training

In [None]:
# Instantiate LeNet-5 and move its parameters to the configured device.
model = LeNet5()
model = model.to(device)
# Layer-by-layer summary for one single-channel 32x32 image (28x28 MNIST plus 2px padding).
torchsummary.summary(model, (1, 32, 32))

In [None]:
# Learning-rate schedule handed to ManualLRScheduler. Presumably learning_rates[i] is used
# for counts[i] epochs and the last rate is held afterwards (there is one fewer count than
# rates). NOTE(review): confirm against ManualLRScheduler.
learning_rates: list[float] = [5e-4, 2e-4, 1e-4, 5e-5, 1e-5]
counts: list[int] = [2, 3, 3, 4]
manual_lr_scheduler = ManualLRScheduler(learning_rates, counts)

# Adam starts at the first scheduled rate; LambdaLR queries manual_lr_scheduler.step.
# NOTE(review): LambdaLR sets lr = initial_lr * lr_lambda(epoch), i.e. it expects a
# multiplicative factor. Confirm ManualLRScheduler.step returns a factor relative to
# learning_rates[0] rather than an absolute learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rates[0])
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=manual_lr_scheduler.step)

In [None]:
def train_step(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: "torch.optim.lr_scheduler.LRScheduler",
    train_loader: torch.utils.data.DataLoader,
    tqdm_description: str = ""
) -> tuple[float, float]:
    """Train ``model`` for one epoch, then advance the learning-rate scheduler once.

    Args:
        model: Network to train; put into training mode before the epoch starts.
        optimizer: Optimizer that updates ``model``'s parameters once per batch.
        scheduler: Learning-rate scheduler, stepped once at the end of the epoch.
        train_loader: Loader yielding ``(data, targets)`` batches with integer class targets.
        tqdm_description: Label shown on the progress bar.

    Returns:
        ``(loss, accuracy)``: the sum of per-batch losses and the number of correct
        predictions, each divided by the number of samples in ``train_loader.dataset``.
    """
    # Set training mode explicitly: an evaluation pass (model.eval()) between epochs would
    # otherwise leave mode-dependent layers (dropout, batch-norm, ...) in eval behaviour.
    model.train()
    training_loss = 0.0
    training_correct = 0
    for data, targets in tqdm.tqdm(train_loader, desc=tqdm_description, ncols=100):
        data = data.to(device)
        # LeNet5.loss takes one-hot targets; assumes targets are class indices in [0, num_classes).
        y = torch.nn.functional.one_hot(targets, num_classes)

        # Forward pass. Outputs are moved to the CPU, presumably so the loss and accuracy
        # are computed on the same device as `targets`/`y`, which stay where the loader produced them.
        optimizer.zero_grad()
        y_pred = model(data).to("cpu")
        loss = LeNet5.loss(y_pred, y)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Metrics. The predicted class is the output with the *smallest* value (argmin).
        # NOTE(review): presumably LeNet5 emits distance-like scores (lower = better); confirm.
        training_loss += loss.item()
        training_correct += (torch.argmin(y_pred, dim=1) == targets).sum().item()
    scheduler.step()
    num_samples = len(train_loader.dataset)
    return training_loss / num_samples, training_correct / num_samples


In [None]:
def _validation_step(
    model: torch.nn.Module,
    validation_loader: torch.utils.data.DataLoader,
    tqdm_description: str = ""
) -> tuple[float, float]:
    """Evaluate ``model`` on ``validation_loader`` without updating its parameters.

    Uses the same metrics as ``train_step`` (summed loss via ``LeNet5.loss`` and argmin
    accuracy, both divided by the dataset size), so the training and validation curves
    are directly comparable. The model's previous train/eval mode is restored afterwards.

    Returns:
        ``(loss, accuracy)`` over ``validation_loader.dataset``.
    """
    was_training = model.training
    model.eval()
    validation_loss = 0.0
    validation_correct = 0
    try:
        with torch.no_grad():
            for data, targets in tqdm.tqdm(validation_loader, desc=tqdm_description, ncols=100):
                data = data.to(device)
                y = torch.nn.functional.one_hot(targets, num_classes)
                y_pred = model(data).to("cpu")
                validation_loss += LeNet5.loss(y_pred, y).item()
                validation_correct += (torch.argmin(y_pred, dim=1) == targets).sum().item()
    finally:
        # Restore the caller's mode so the next training epoch is unaffected.
        model.train(was_training)
    num_samples = len(validation_loader.dataset)
    return validation_loss / num_samples, validation_correct / num_samples


def train(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: "torch.optim.lr_scheduler.LRScheduler",
    train_loader: torch.utils.data.DataLoader,
    train_history: dict[str, list[float]],
    validation_loader: torch.utils.data.DataLoader,
    validation_history: dict[str, list[float]],
    epochs: int
) -> None:
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Each epoch's loss and accuracy are printed. They are appended in place to
    ``train_history["loss"]``/``train_history["accuracy"]`` and to
    ``validation_history["loss"]``/``validation_history["accuracy"]``.
    """
    for epoch in range(1, epochs + 1):
        training_loss, training_accuracy = train_step(
            model,
            optimizer,
            scheduler,
            train_loader,
            f"Training epoch {epoch}/{epochs}"
        )
        print(f"Loss: {training_loss}, Accuracy: {100 * training_accuracy:.2f}%", flush=True)
        train_history["loss"].append(training_loss)
        train_history["accuracy"].append(training_accuracy)

        validation_loss, validation_accuracy = _validation_step(
            model,
            validation_loader,
            f"Validating epoch {epoch}/{epochs}"
        )
        print(
            f"Validation loss: {validation_loss}, Validation accuracy: {100 * validation_accuracy:.2f}%",
            flush=True
        )
        validation_history["loss"].append(validation_loss)
        validation_history["accuracy"].append(validation_accuracy)

In [None]:
# Per-epoch metric histories; the plotting cells below read their "loss" and "accuracy" lists.
train_history = collections.defaultdict(list)
validation_history = collections.defaultdict(list)

train(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    train_history=train_history,
    validation_loader=validation_loader,
    validation_history=validation_history,
    epochs=epochs,
)

In [None]:
# Loss per epoch. Each curve's x-axis comes from its own length, so the cell still works
# if training stopped early or no validation metrics were recorded.
fig, ax = plt.subplots()
for history, curve_label in ((train_history, "Training"), (validation_history, "Validation")):
    if history["loss"]:
        ax.plot(range(1, len(history["loss"]) + 1), history["loss"], label=curve_label)

ax.set_title("Loss per epoch")
ax.set_ylabel("Loss")
ax.set_xlabel("epochs")
if ax.lines:  # avoid matplotlib's "no artists with labels" warning on an empty plot
    ax.legend()
plt.show()

In [None]:
# Accuracy per epoch. Each curve's x-axis comes from its own length, so the cell still works
# if training stopped early or no validation metrics were recorded.
fig, ax = plt.subplots()
for history, curve_label in ((train_history, "Training"), (validation_history, "Validation")):
    if history["accuracy"]:
        ax.plot(range(1, len(history["accuracy"]) + 1), history["accuracy"], label=curve_label)

ax.set_title("Accuracy per epoch")
ax.set_ylabel("Accuracy")
ax.set_xlabel("epochs")
if ax.lines:  # avoid matplotlib's "no artists with labels" warning on an empty plot
    ax.legend()
plt.show()