In [None]:
%matplotlib inline
import collections
from typing import Callable

import matplotlib.pyplot as plt
import torch
import torchmetrics
import torchsummary
import torchvision
import tqdm.notebook

from model.googlenet import GoogLeNet
from utils.augments import AddGaussianNoise, Clip
from utils.dataset import get_loader, sample_first
from utils.metrics import plot_metric, pretty_print_metrics
from utils.train_validation import train, validate_one_epoch

# Constants

In [None]:
# Filesystem root handed to the torchvision dataset constructors below
dataset_location: str = "../data"

# Compute device: CUDA when it is available, otherwise the CPU
device: torch.device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# MNIST

In [None]:
# Preprocessing applied to every MNIST sample, in this order:
# tensor conversion -> normalisation -> upscale to 224x224 -> single channel tiled 3x.
# NOTE(review): the preview cell undoes the normalisation with the same 0.3015;
# the commonly quoted MNIST std is 0.3081 — confirm which value is intended.
mnist_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=0.1307, std=0.3015),
        torchvision.transforms.Resize(size=(224, 224), antialias=True),
        torchvision.transforms.Lambda(lambda grey: grey.repeat(3, 1, 1)),
    ]
)

In [None]:
# Fraction of the training archive kept for training; the remainder is validation
train_validation_split = 0.7

# Training archive (train=True is the constructor default)
mnist_train_validation_data = torchvision.datasets.MNIST(
    root=dataset_location, transform=mnist_transform, download=True
)

# Random partition of the training archive into training and validation subsets
mnist_train_data, mnist_validation_data = torch.utils.data.random_split(
    dataset=mnist_train_validation_data,
    lengths=[train_validation_split, 1 - train_validation_split],
)

# Separate test archive, preprocessed with the same transform
mnist_test_data = torchvision.datasets.MNIST(
    root=dataset_location, train=False, transform=mnist_transform, download=True
)

# Number of target classes, taken from the dataset's class list
num_classes = len(mnist_train_validation_data.classes)

In [None]:
batch_size: int = 128

# One loader per split, built in train -> validation -> test order
mnist_train_loader, mnist_validation_loader, mnist_test_loader = (
    get_loader(split, batch_size)
    for split in (mnist_train_data, mnist_validation_data, mnist_test_data)
)

In [None]:
# Pull one (image, label) pair from the training loader via the project helper
image, label = sample_first(
    mnist_train_loader,
    mnist_train_validation_data.classes,
)
print(f"Class: {label}")

# Channels go last for matplotlib, the normalisation is undone with the same
# constants used in mnist_transform, and the result is clamped into [0, 1]
image = (image.permute(1, 2, 0) * 0.3015 + 0.1307).clamp(0, 1)

plt.imshow(image);

In [None]:
# Build the network for the current number of classes and move it to the device
mnist_model = GoogLeNet(num_classes)
mnist_model = mnist_model.to(device)

# Layer-by-layer summary for a 3x224x224 input
torchsummary.summary(mnist_model, input_size=(3, 224, 224))

In [None]:
# Training schedule
mnist_epochs = 20
auxiliary_loss_weight = 0.3

# Cross-entropy summed (not averaged) over the batch
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")

# Adam over all model parameters with a learning rate of 1e-4
optimizer = torch.optim.Adam(mnist_model.parameters(), lr=1e-4)

In [None]:
# History containers for training and validation; unseen keys start as empty lists
train_history, validation_history = (
    collections.defaultdict(list) for _ in range(2)
)

# Micro-averaged accuracy plus per-class (average=None) precision, recall and F1
train_metrics = torchmetrics.MetricCollection(
    {
        "accuracy": torchmetrics.classification.MulticlassAccuracy(
            num_classes=num_classes, average="micro"
        ),
        "precision": torchmetrics.classification.MulticlassPrecision(
            num_classes=num_classes, average=None
        ),
        "recall": torchmetrics.classification.MulticlassRecall(
            num_classes=num_classes, average=None
        ),
        "f1 score": torchmetrics.classification.MulticlassF1Score(
            num_classes=num_classes, average=None
        ),
    }
)
train_metrics = train_metrics.to(device)
validation_metrics = train_metrics.clone()

# Arguments stay positional: the signature of utils.train_validation.train is
# not visible from this notebook.
train(
    mnist_model,
    optimizer,
    None,  # NOTE(review): presumably an optional scheduler slot — confirm
    mnist_train_loader,
    train_history,
    mnist_validation_loader,
    validation_history,
    mnist_epochs,
    loss_fn,
    mnist_train_validation_data.classes,
    train_metrics,
    validation_metrics,
    device,
    auxiliary_loss_weight,
)

In [None]:
# Loss curves: training against validation
plot_metric(
    {"Training": train_history, "Validation": validation_history},
    metric="loss",
)

In [None]:
# Accuracy curves: training against validation
plot_metric(
    {"Training": train_history, "Validation": validation_history},
    metric="accuracy",
)

In [None]:
# Reuse the training metric definitions for the test split.
test_metrics = train_metrics.clone()
# clone() copies the collection together with whatever state it has accumulated,
# so clear it explicitly: the test metrics must start empty regardless of what
# the training loop left behind in train_metrics.
test_metrics.reset()

# Single evaluation pass over the test loader; the return value is kept as the
# test loss and test_metrics is passed in so it can be read afterwards.
test_loss = validate_one_epoch(
    mnist_model,
    mnist_test_loader,
    loss_fn,
    num_classes,
    test_metrics,
    device,
    "Testing"
)

In [None]:
# Start from the test loss, then add every computed metric moved to the CPU;
# each value is wrapped in a one-element list
test_history = {"loss": [test_loss]}
test_history.update(
    {name: [value.to("cpu")] for name, value in test_metrics.compute().items()}
)

pretty_print_metrics(test_history, mnist_train_validation_data.classes)

# CIFAR-100

In [None]:
# Deterministic preprocessing: tensor conversion, per-channel normalisation,
# upscale to 224x224
cifar_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.5071, 0.4865, 0.4409],
            std=[0.2009, 0.1984, 0.2023],
        ),
        torchvision.transforms.Resize(size=(224, 224), antialias=True),
    ]
)

# The same preprocessing followed by random augmentations: random resized crop
# back to 224, random horizontal flip, then the project's AddGaussianNoise(0, 0.01)
cifar_transforms_with_augmentations = torchvision.transforms.Compose(
    [
        cifar_transforms,
        torchvision.transforms.RandomResizedCrop(size=224, antialias=True),
        torchvision.transforms.RandomHorizontalFlip(),
        AddGaussianNoise(0, 0.01),
    ]
)

In [None]:
# Fraction of the training archive kept for training; the remainder is validation
train_validation_split = 0.7

# Un-augmented view of the CIFAR-100 training archive. The validation subset is
# drawn from this object, so its transform must stay augmentation-free.
cifar_train_validation_data = torchvision.datasets.CIFAR100(
    dataset_location,
    transform=cifar_transforms,
    download=True
)
# Second view of the same archive that applies the augmentations. A separate
# dataset object is required: both subsets returned by random_split wrap the
# dataset they were split from, so assigning `.dataset.transform` on the training
# subset would also switch the validation subset to the augmented pipeline.
cifar_augmented_data = torchvision.datasets.CIFAR100(
    dataset_location,
    transform=cifar_transforms_with_augmentations,
    download=True
)

# Draw the random partition once on the un-augmented view ...
cifar_plain_train_data, cifar_validation_data = torch.utils.data.random_split(
    cifar_train_validation_data,
    [train_validation_split, 1 - train_validation_split]
)
# ... and re-use the training indices on the augmented view, so training and
# validation stay disjoint while only training samples are augmented.
cifar_train_data = torch.utils.data.Subset(
    cifar_augmented_data,
    cifar_plain_train_data.indices
)

# Separate test archive, preprocessed without augmentations
cifar_test_data = torchvision.datasets.CIFAR100(
    dataset_location,
    False,
    transform=cifar_transforms,
    download=True
)

# Number of target classes, taken from the dataset's class list
num_classes = len(cifar_train_validation_data.classes)

In [None]:
batch_size: int = 64

# One loader per split, built in train -> validation -> test order
cifar_train_loader, cifar_validation_loader, cifar_test_loader = (
    get_loader(split, batch_size)
    for split in (cifar_train_data, cifar_validation_data, cifar_test_data)
)

In [None]:
# Pull one (image, label) pair from the training loader via the project helper
image, label = sample_first(cifar_train_loader, cifar_train_validation_data.classes)
print(f"Class: {label}")

# Per-channel statistics, identical to the ones used in cifar_transforms
channel_std = torch.tensor([0.2009, 0.1984, 0.2023])
channel_mean = torch.tensor([0.5071, 0.4865, 0.4409])

# Channels go last so the statistics broadcast over the last axis; the
# normalisation is undone and the result clamped into [0, 1] for display
image = (image.permute(1, 2, 0) * channel_std + channel_mean).clamp(0, 1)

plt.imshow(image);

In [None]:
# Build the network for the current number of classes and move it to the device
cifar_model = GoogLeNet(num_classes)
cifar_model = cifar_model.to(device)

# Layer-by-layer summary for a 3x224x224 input
torchsummary.summary(cifar_model, input_size=(3, 224, 224))

In [None]:
# Training schedule
cifar_epochs = 100
auxiliary_loss_weight = 0.3

# Cross-entropy summed (not averaged) over the batch
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")

# Adam over all model parameters: learning rate 1e-4, weight decay 0.01
optimizer = torch.optim.Adam(
    cifar_model.parameters(),
    lr=1e-4,
    weight_decay=0.01,
)

In [None]:
# History containers for training and validation; unseen keys start as empty lists
train_history, validation_history = (
    collections.defaultdict(list) for _ in range(2)
)

# Micro-averaged accuracy plus per-class (average=None) precision, recall and F1
train_metrics = torchmetrics.MetricCollection(
    {
        "accuracy": torchmetrics.classification.MulticlassAccuracy(
            num_classes=num_classes, average="micro"
        ),
        "precision": torchmetrics.classification.MulticlassPrecision(
            num_classes=num_classes, average=None
        ),
        "recall": torchmetrics.classification.MulticlassRecall(
            num_classes=num_classes, average=None
        ),
        "f1 score": torchmetrics.classification.MulticlassF1Score(
            num_classes=num_classes, average=None
        ),
    }
)
train_metrics = train_metrics.to(device)
validation_metrics = train_metrics.clone()

# Arguments stay positional: the signature of utils.train_validation.train is
# not visible from this notebook.
train(
    cifar_model,
    optimizer,
    None,  # NOTE(review): presumably an optional scheduler slot — confirm
    cifar_train_loader,
    train_history,
    cifar_validation_loader,
    validation_history,
    cifar_epochs,
    loss_fn,
    cifar_train_validation_data.classes,
    train_metrics,
    validation_metrics,
    device,
    auxiliary_loss_weight,
)

In [None]:
# Loss curves: training against validation
plot_metric(
    {"Training": train_history, "Validation": validation_history},
    metric="loss",
)

In [None]:
# Accuracy curves: training against validation
plot_metric(
    {"Training": train_history, "Validation": validation_history},
    metric="accuracy",
)

In [None]:
# Reuse the training metric definitions for the test split.
test_metrics = train_metrics.clone()
# clone() copies the collection together with whatever state it has accumulated,
# so clear it explicitly: the test metrics must start empty regardless of what
# the training loop left behind in train_metrics.
test_metrics.reset()

# Single evaluation pass over the test loader; the return value is kept as the
# test loss and test_metrics is passed in so it can be read afterwards.
test_loss = validate_one_epoch(
    cifar_model,
    cifar_test_loader,
    loss_fn,
    num_classes,
    test_metrics,
    device,
    "Testing"
)

In [None]:
# Start from the test loss, then add every computed metric moved to the CPU;
# each value is wrapped in a one-element list
test_history = {"loss": [test_loss]}
test_history.update(
    {name: [value.to("cpu")] for name, value in test_metrics.compute().items()}
)

pretty_print_metrics(test_history, cifar_train_validation_data.classes)