In [None]:
%matplotlib inline
import collections

import matplotlib.pyplot as plt
import torch
import torchmetrics
import torchsummary
import torchvision

from pytorch_models.utils.dataset import get_loader, sample_first
from pytorch_models.utils.metrics import plot_metric, pretty_print_metrics
from pytorch_models.utils.train_validation import train, validate_one_epoch

# Constants


In [None]:
# Data
dataset_location: str = "../data"
batch_size: int = 256
train_validation_split: float = 0.7

# Reproducibility
# Seed PyTorch's global RNG up front so the random train/validation split, the weight
# initialisation and any shuffling that draws from it repeat on "Restart & Run All".
seed: int = 42
torch.manual_seed(seed)

# Torch
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training
epochs: int = 20

# Load data

Load the MNIST dataset from torchvision and apply padding and normalisation as part of the transform.


In [None]:
# Preprocessing applied to every image, in order:
#   1. pad 2 pixels on each side (MNIST's 28x28 images become 32x32, the input size
#      passed to the model summaries below),
#   2. convert the PIL image to a float tensor in [0, 1],
#   3. normalise with mean 0.5 and std 0.5, i.e. rescale to [-1, 1].
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Pad(padding=2),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=0.5, std=0.5),
    ]
)

In [None]:
# MNIST training split (downloaded on first use); it is divided into training and
# validation subsets below.
train_validation_data = torchvision.datasets.MNIST(
    root=dataset_location, train=True, transform=transform, download=True
)

# random_split is given fractions (the training share and its complement) rather than
# absolute lengths.
train_data, validation_data = torch.utils.data.random_split(
    train_validation_data,
    [train_validation_split, 1 - train_validation_split],
)

# Held-out MNIST test split, preprocessed with the same transform.
test_data = torchvision.datasets.MNIST(
    root=dataset_location, train=False, transform=transform, download=True
)

num_classes = len(train_validation_data.classes)

In [None]:
# Build one loader per split via get_loader, all with the same batch size.
train_loader, validation_loader, test_loader = (
    get_loader(split, batch_size) for split in (train_data, validation_data, test_data)
)

In [None]:
# Sanity check: fetch one sample from the training loader. `sample_first` is given the
# class names and returns an (image, label) pair — presumably already in a form that
# `imshow` accepts (TODO confirm against pytorch_models.utils.dataset).
image, label = sample_first(train_loader, train_validation_data.classes)
print(f"Class: {label}")
# The trailing semicolon suppresses the AxesImage repr in the cell output.
plt.imshow(image);

# Original

## Training


In [None]:
# NOTE: this binds the name `LeNet5` to the *original* implementation. The "Modern"
# section further down imports a different class under the same name, so the cells in
# this section (which use `LeNet5.loss`) must run before that import, or this cell must
# be re-run first.
from original.lenet import LeNet5

# Model
model = LeNet5().to(device)
# Print a summary of the model for a single-channel 32x32 input (the padded image size).
torchsummary.summary(model, (1, 32, 32))

In [None]:
from original.manual_scheduler import ManualLRScheduler

# Optimizer and scheduler
# Five learning rates with four segment lengths: presumably `counts[k]` is the number of
# epochs spent at `learning_rates[k]`, with the last rate used for all remaining epochs —
# TODO confirm against ManualLRScheduler.
learning_rates: list[float] = [5e-4, 2e-4, 1e-4, 5e-5, 1e-5]
counts: list[int] = [2, 3, 3, 4]

manual_lr_scheduler = ManualLRScheduler(learning_rates, counts)
# The optimizer starts at the first scheduled rate.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rates[0])
# NOTE(review): LambdaLR multiplies the optimizer's initial lr (learning_rates[0]) by the
# value returned from the callable, so `manual_lr_scheduler.step` has to return a factor
# relative to learning_rates[0], not an absolute learning rate — confirm.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, manual_lr_scheduler.step)

In [None]:
class ArgminAccuracy(torchmetrics.Metric):
    """Accuracy where the predicted class is the index of the *smallest* score.

    The prediction for each row of ``preds`` is ``argmin`` along dimension 1, presumably
    because the original LeNet-5 emits a per-class penalty where lower is better
    (confirm against ``original.lenet``). Both counters are registered as metric states
    and are summed when reduced across processes.
    """

    # Running number of correct predictions / of targets seen since the last reset.
    correct: torch.Tensor
    total: torch.Tensor

    def __init__(self, **kwargs):
        """Register the ``correct`` and ``total`` counters, both starting at zero."""
        super().__init__(**kwargs)
        for state_name in ("correct", "total"):
            self.add_state(state_name, default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate one batch.

        Args:
            preds: Per-class scores; the class axis is dimension 1.
            target: Ground-truth class indices, one per row of ``preds``.

        Raises:
            ValueError: If ``preds`` and ``target`` differ in length.
        """
        if len(preds) != len(target):
            raise ValueError("preds and target must have the same length")

        predicted_classes = preds.argmin(dim=1)
        hits = predicted_classes == target
        self.correct += hits.sum()
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        """Return the fraction of correct predictions accumulated so far."""
        return self.correct.float() / self.total

In [None]:
# Histories handed to `train`; missing keys default to an empty list.
train_history, validation_history = (
    collections.defaultdict(list),
    collections.defaultdict(list),
)

# Argmin-based accuracy for the original model. The validation collection is cloned
# before any batch has been seen, so it starts from the same empty state.
train_metrics = torchmetrics.MetricCollection(dict(accuracy=ArgminAccuracy())).to(
    device
)
validation_metrics = train_metrics.clone()

train(
    model,
    optimizer,
    scheduler,
    train_loader,
    train_history,
    validation_loader,
    validation_history,
    epochs,
    LeNet5.loss,  # loss provided by the original LeNet5 class
    train_validation_data.classes,
    train_metrics,
    validation_metrics,
    device,
)

In [None]:
# Plot the "loss" entry of the training and validation histories; the dict keys are the
# labels passed to plot_metric for each history.
plot_metric({"Training": train_history, "Validation": validation_history}, "loss")

In [None]:
# Same plot for the "accuracy" entry (the key used in the metric collection).
plot_metric({"Training": train_history, "Validation": validation_history}, "accuracy")

## Testing


In [None]:
# `clone()` deep-copies the collection, including whatever state `train_metrics` still
# holds from training, so reset the copy explicitly before evaluating on the test set.
test_metrics = train_metrics.clone()
test_metrics.reset()

# NOTE: despite the `cifar_` prefix this is the loss on the MNIST test loader; the name
# is kept unchanged so nothing that refers to it breaks.
cifar_test_loss = validate_one_epoch(
    model, test_loader, LeNet5.loss, num_classes, test_metrics, device, "Testing"
)

# Single-entry history: the test loss plus every computed metric, moved to the CPU.
test_history = {"loss": [cifar_test_loss]} | {
    name: [value.to("cpu")] for name, value in test_metrics.compute().items()
}
pretty_print_metrics(test_history, train_validation_data.classes)

# Modern

## Training


In [None]:
# NOTE: this rebinds `LeNet5` (and `model` below), replacing the original implementation
# imported in the "Original" section. Re-running earlier cells that use `LeNet5.loss`
# after this point would pick up this class instead.
from modern.lenet import LeNet5

# Model
model = LeNet5().to(device)
# Print a summary of the model for a single-channel 32x32 input (the padded image size).
torchsummary.summary(model, (1, 32, 32))
# Cross-entropy summed (not averaged) over the elements of each batch.
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")

In [None]:
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# `train` is passed a scheduler, but the learning rate here is meant to stay at 1e-4.
# ConstantLR's defaults (factor=1/3, total_iters=5) would scale the rate to a third for
# the first five scheduler steps, so pin the factor to 1.0 to make it a genuine no-op.
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)

In [None]:
# Fresh histories for the modern model (this rebinds the names used for the original
# model above); missing keys default to an empty list.
train_history, validation_history = (collections.defaultdict(list) for _ in range(2))

# Micro-averaged multiclass accuracy. The validation collection is cloned before any
# batch has been seen, so it starts from the same empty state.
train_metrics = torchmetrics.MetricCollection(
    dict(
        accuracy=torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes, average="micro"
        )
    )
).to(device)
validation_metrics = train_metrics.clone()

train(
    model,
    optimizer,
    scheduler,
    train_loader,
    train_history,
    validation_loader,
    validation_history,
    epochs,
    loss_fn,  # summed cross-entropy defined with the model
    train_validation_data.classes,
    train_metrics,
    validation_metrics,
    device,
)

In [None]:
# Plot the "loss" entry of the modern model's training and validation histories; the
# dict keys are the labels passed to plot_metric for each history.
plot_metric({"Training": train_history, "Validation": validation_history}, "loss")

In [None]:
# Same plot for the "accuracy" entry (the key used in the metric collection).
plot_metric({"Training": train_history, "Validation": validation_history}, "accuracy")

In [None]:
# `clone()` deep-copies the collection, including whatever state `train_metrics` still
# holds from training, so reset the copy explicitly before evaluating on the test set.
test_metrics = train_metrics.clone()
test_metrics.reset()

# NOTE: despite the `cifar_` prefix this is the loss on the MNIST test loader; the name
# is kept unchanged so nothing that refers to it breaks.
cifar_test_loss = validate_one_epoch(
    model, test_loader, loss_fn, num_classes, test_metrics, device, "Testing"
)

# Single-entry history: the test loss plus every computed metric, moved to the CPU.
test_history = {"loss": [cifar_test_loss]} | {
    name: [value.to("cpu")] for name, value in test_metrics.compute().items()
}
pretty_print_metrics(test_history, train_validation_data.classes)

## Feature Maps


In [None]:
# The model's forward pass is handed `feature_maps` as a second argument and presumably
# appends its intermediate activations to it — TODO confirm against modern.lenet.
feature_maps = []
# Only the first test batch is used; the labels are not needed here.
image_data, _ = next(iter(test_loader))

# Ensure train-only behaviour (e.g. dropout, batch-norm statistics updates) is switched
# off before capturing activations; this is a no-op if the model is already in eval mode.
model.eval()
with torch.inference_mode():
    model(image_data.to(device), feature_maps)

In [None]:
def display_feature_map(feature_map: torch.Tensor, num_cols: int) -> None:
    """Show the maps of one feature-map tensor as a grid of images.

    Args:
        feature_map: Tensor iterated along its first dimension; each item is moved
            to the CPU and drawn with ``imshow``.
        num_cols: Number of columns in the grid; as many rows as needed are created.
    """
    num_maps = len(feature_map)
    # Ceiling division: enough rows to hold every map.
    num_rows = (num_maps + num_cols - 1) // num_cols

    # squeeze=False keeps the axes array two-dimensional even for a single row or
    # column; the default would return a 1-D array (or a bare Axes) in those cases.
    _, axes = plt.subplots(num_rows, num_cols, squeeze=False)
    flat_axes = axes.ravel()  # row-major, i.e. index i maps to (i // num_cols, i % num_cols)

    # Hide every axis first so grid cells left without a map stay blank.
    for ax in flat_axes:
        ax.axis("off")
    for ax, image in zip(flat_axes, feature_map):
        ax.imshow(image.to("cpu"))

    plt.show()

In [None]:
def display_all(
    image_data: torch.Tensor, feature_maps: list[torch.Tensor], i: int
) -> None:
    """Show sample ``i`` of a batch followed by its first two feature-map grids.

    Args:
        image_data: Batch of input images; sample ``i`` is squeezed before plotting.
        feature_maps: Per-layer activations for the same batch; entries 0 and 1 are
            drawn with 3 and 4 columns respectively.
        i: Index of the sample within the batch.
    """
    ax = plt.gca()
    ax.axis("off")
    # Move to the CPU before plotting, as display_feature_map does, so a batch that
    # lives on the GPU can be displayed too.
    plt.imshow(image_data[i].squeeze().to("cpu"))
    display_feature_map(feature_maps[0][i], 3)
    display_feature_map(feature_maps[1][i], 4)

In [None]:
# Show the first image of the captured batch together with its two feature-map grids.
display_all(image_data, feature_maps, 0)