In [None]:
%matplotlib inline
import collections
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import torchmetrics
import torchsummary
import torchvision

from pytorch_models.utils.augments import AddGaussianNoise
from pytorch_models.utils.dataset import get_loader, sample_first
from pytorch_models.utils.metrics import plot_metric, pretty_print_metrics
from pytorch_models.utils.train_validation import train, validate_one_epoch

In [None]:
from pytorch_models.CNN.inception_v3 import InceptionV3 as Model
from pytorch_models.CNN.inception_v3 import LabelSmoothing

# Constants


In [None]:
# Root directory handed to the torchvision dataset constructors.
dataset_location: str = "./data"

# Compute device: CUDA when PyTorch reports it as available, CPU otherwise.
device: torch.device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# MNIST


In [None]:
# Seed PyTorch's global RNG before anything in the MNIST experiment draws from it.
torch.manual_seed(0)

# ---- MNIST hyper-parameters ----
# Dataset
# Value passed as `train_size` to the stratified train/validation split.
mnist_train_validation_split: float = 0.7
# Batch size handed to every MNIST loader.
mnist_batch_size: int = 32

# Training
mnist_epochs: int = 10
# Forwarded to `train` as its final positional argument.
mnist_auxiliary_loss_weight: float = 0.3
mnist_learning_rate: float = 0.0001
mnist_label_smoothing_factor: float = 0.0001
# Extra keyword arguments expanded into the Adam constructor.
mnist_optimizer_kwargs: dict[str, Any] = dict(weight_decay=0.01)

In [None]:
# MNIST preprocessing pipeline: convert to a tensor, normalise, upscale to
# 299x299 and finally repeat the tensor three times along its first dimension
# (assumes a single-channel (1, H, W) tensor at that point - TODO confirm) so
# the result lines up with the (3, 299, 299) input the model is summarised with.
# NOTE(review): the commonly quoted MNIST std is 0.3081; 0.3015 is kept because
# the preview cell undoes the normalisation with the same value - confirm.
mnist_transform = torchvision.transforms.Compose(
    transforms=[
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=0.1307, std=0.3015),
        torchvision.transforms.Resize(size=(299, 299), antialias=True),
        torchvision.transforms.Lambda(lambda grey: grey.repeat(3, 1, 1)),
    ]
)

In [None]:
# MNIST training split (torchvision's default `train=True`); it is carved into
# training and validation subsets right after.
mnist_train_validation_data = torchvision.datasets.MNIST(
    root=dataset_location, train=True, transform=mnist_transform, download=True
)

# Shuffled index split, stratified on the dataset targets, with a fixed
# `random_state` so the partition is reproducible.
train_idx, val_idx = train_test_split(
    np.arange(len(mnist_train_validation_data)),
    stratify=mnist_train_validation_data.targets,
    train_size=mnist_train_validation_split,
    shuffle=True,
    random_state=333,
)
mnist_train_data = torch.utils.data.Subset(
    dataset=mnist_train_validation_data, indices=train_idx
)
mnist_validation_data = torch.utils.data.Subset(
    dataset=mnist_train_validation_data, indices=val_idx
)

# Separate test split (`train=False`), same preprocessing.
mnist_test_data = torchvision.datasets.MNIST(
    root=dataset_location, train=False, transform=mnist_transform, download=True
)

# Number of labels, read off the dataset's class list; later cells use it to
# size the model output and the metrics.
num_classes = len(mnist_train_validation_data.classes)

In [None]:
# One loader per MNIST split, all built with the same batch size.
# The generator is consumed in order, so `get_loader` is called for the
# train, validation and test data in exactly that sequence.
mnist_train_loader, mnist_validation_loader, mnist_test_loader = (
    get_loader(split, mnist_batch_size)
    for split in (mnist_train_data, mnist_validation_data, mnist_test_data)
)

In [None]:
# Pull one sample from the training loader for a visual sanity check
# (presumably `sample_first` returns an image tensor and its class name - confirm).
image, label = sample_first(mnist_train_loader, mnist_train_validation_data.classes)
print(f"Class: {label}")

# Convert to a visible image: move the first axis last, undo the
# Normalize(0.1307, 0.3015) step and clip into [0, 1].
image = (image.permute(1, 2, 0) * 0.3015 + 0.1307).clamp(0, 1)

plt.imshow(image);

In [None]:
# Model: build the network for `num_classes` outputs, move it to the selected
# device, then print a layer summary for a (3, 299, 299) input.
mnist_model = Model(num_classes)
mnist_model = mnist_model.to(device)
torchsummary.summary(mnist_model, input_size=(3, 299, 299))

In [None]:
# Base criterion: cross-entropy summed (not averaged) over the batch.
# It is used directly for the test evaluation and wrapped by `LabelSmoothing`
# for training.
mnist_loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
label_smoothed_mnist_loss_fn = LabelSmoothing(mnist_loss_fn, mnist_label_smoothing_factor)

# Optimizer: Adam over all model parameters; extra options (weight decay)
# come from `mnist_optimizer_kwargs`.
optimizer = torch.optim.Adam(
    params=mnist_model.parameters(),
    lr=mnist_learning_rate,
    **mnist_optimizer_kwargs,
)

In [None]:
# Per-run histories; `defaultdict(list)` lets any metric name be appended to
# without prior initialisation.
mnist_train_history = collections.defaultdict(list)
mnist_validation_history = collections.defaultdict(list)

# Metric bundle: micro-averaged accuracy plus un-averaged (`average=None`)
# precision, recall and F1.
mnist_train_metrics = torchmetrics.MetricCollection(
    {
        "accuracy": torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes, average="micro"
        ),
        **{
            metric_name: metric_cls(
                task="multiclass", num_classes=num_classes, average=None
            )
            for metric_name, metric_cls in (
                ("precision", torchmetrics.Precision),
                ("recall", torchmetrics.Recall),
                ("f1 score", torchmetrics.F1Score),
            )
        },
    }
).to(device)
# Validation gets its own copy so the two phases never share metric state.
mnist_validation_metrics = mnist_train_metrics.clone()

# Run the training loop. Arguments are positional because the helper's
# signature is defined in pytorch_models.utils.train_validation.
# The third argument is left as None (presumably an optional scheduler slot - confirm).
train(
    mnist_model,
    optimizer,
    None,
    mnist_train_loader,
    mnist_train_history,
    mnist_validation_loader,
    mnist_validation_history,
    mnist_epochs,
    label_smoothed_mnist_loss_fn,
    mnist_train_validation_data.classes,
    mnist_train_metrics,
    mnist_validation_metrics,
    device,
    mnist_auxiliary_loss_weight,
)

In [None]:
# Plot the recorded "loss" entries of the MNIST training and validation histories.
plot_metric(
    dict(Training=mnist_train_history, Validation=mnist_validation_history),
    metric="loss",
)

In [None]:
# Plot the recorded "accuracy" entries of the MNIST training and validation histories.
plot_metric(
    dict(Training=mnist_train_history, Validation=mnist_validation_history),
    metric="accuracy",
)

In [None]:
mnist_test_metrics = mnist_train_metrics.clone()

cifar_test_loss = validate_one_epoch(
    mnist_model,
    mnist_test_loader,
    mnist_loss_fn,
    num_classes,
    mnist_test_metrics,
    device,
    "Testing",
)

In [None]:
mnist_test_history = {"loss": [cifar_test_loss]} | {
    metric: [history.to("cpu")]
    for metric, history in mnist_test_metrics.compute().items()
}

pretty_print_metrics(mnist_test_history, mnist_train_validation_data.classes)

# CIFAR-100


In [None]:
# Re-seed PyTorch's global RNG for the CIFAR-100 experiment.
torch.manual_seed(333)

# ---- CIFAR-100 hyper-parameters ----
# Dataset
# Value passed as `train_size` to the stratified train/validation split.
cifar_train_validation_split: float = 0.7
# Batch size handed to every CIFAR loader.
cifar_batch_size: int = 32

# Training
cifar_epochs: int = 75
# Forwarded to `train` as its final positional argument.
cifar_auxiliary_loss_weight: float = 0.3
cifar_learning_rate: float = 0.0001
cifar_label_smoothing_factor: float = 0.1
# Extra keyword arguments expanded into the Adam constructor.
cifar_optimizer_kwargs: dict[str, Any] = dict(weight_decay=0.01)

In [None]:
# Deterministic CIFAR preprocessing (used as-is for validation and test):
# tensor conversion, per-channel normalisation, upscale to 299x299.
cifar_transforms = torchvision.transforms.Compose(
    transforms=[
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.5071, 0.4865, 0.4409], std=[0.2009, 0.1984, 0.2023]
        ),
        torchvision.transforms.Resize(size=(299, 299), antialias=True),
    ]
)

# Training-time pipeline: the deterministic preprocessing above followed by
# random augmentations (resized crop, horizontal flip, Gaussian noise).
cifar_transforms_with_augmentations = torchvision.transforms.Compose(
    transforms=[
        cifar_transforms,
        torchvision.transforms.RandomResizedCrop(size=299, antialias=True),
        torchvision.transforms.RandomHorizontalFlip(),
        AddGaussianNoise(0, 0.01),
    ]
)

In [None]:
# Two views of the same CIFAR-100 training split: one with augmentations (for
# training) and one with plain preprocessing (for validation). Both are then
# restricted to disjoint index subsets.
cifar_train_data = torchvision.datasets.CIFAR100(
    dataset_location, transform=cifar_transforms_with_augmentations, download=True
)
cifar_validation_data = torchvision.datasets.CIFAR100(
    dataset_location, transform=cifar_transforms, download=True
)
cifar_classes = cifar_train_data.classes

# The MNIST section bound `num_classes` to the MNIST class count. The CIFAR
# model, metric collections and test evaluation all read this same global, so
# it must be rebound to the CIFAR-100 label count here. Otherwise the model is
# built with too few outputs for CIFAR-100 targets.
num_classes = len(cifar_classes)

# Shuffled index split, stratified on the dataset targets, with a fixed
# `random_state` so the partition is reproducible.
train_idx, val_idx = train_test_split(
    np.arange(len(cifar_train_data)),
    train_size=cifar_train_validation_split,
    random_state=333,
    shuffle=True,
    stratify=cifar_train_data.targets,
)
cifar_train_data = torch.utils.data.Subset(cifar_train_data, train_idx)
cifar_validation_data = torch.utils.data.Subset(cifar_validation_data, val_idx)

# Separate test split (`train=False`), plain preprocessing only.
cifar_test_data = torchvision.datasets.CIFAR100(
    dataset_location, train=False, transform=cifar_transforms, download=True
)

In [None]:
# One loader per CIFAR-100 split, all built with the same batch size.
# The generator is consumed in order, so `get_loader` is called for the
# train, validation and test data in exactly that sequence.
cifar_train_loader, cifar_validation_loader, cifar_test_loader = (
    get_loader(split, cifar_batch_size)
    for split in (cifar_train_data, cifar_validation_data, cifar_test_data)
)

In [None]:
# Pull one sample from the training loader for a visual sanity check
# (presumably `sample_first` returns an image tensor and its class name - confirm).
image, label = sample_first(cifar_train_loader, cifar_classes)
print(f"Class: {label}")

# Convert to a visible image: move the first axis last, undo the per-channel
# normalisation (multiply by std, add mean) and clip into [0, 1].
image = (
    image.permute(1, 2, 0) * torch.tensor([0.2009, 0.1984, 0.2023])
    + torch.tensor([0.5071, 0.4865, 0.4409])
).clamp(0, 1)

plt.imshow(image);

In [None]:
# Model: build the network for `num_classes` outputs, move it to the selected
# device, then print a layer summary for a (3, 299, 299) input.
# NOTE(review): `num_classes` is a notebook-wide global; make sure it holds the
# CIFAR-100 label count when this cell runs.
cifar_model = Model(num_classes)
cifar_model = cifar_model.to(device)
torchsummary.summary(cifar_model, input_size=(3, 299, 299))

In [None]:
# Base criterion: cross-entropy summed (not averaged) over the batch, plus a
# `LabelSmoothing` wrapper built on top of it.
cifar_loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
label_smoothed_cifar_loss_fn = LabelSmoothing(cifar_loss_fn, cifar_label_smoothing_factor)

# Optimizer: Adam over all model parameters; extra options (weight decay)
# come from `cifar_optimizer_kwargs`.
optimizer = torch.optim.Adam(
    params=cifar_model.parameters(),
    lr=cifar_learning_rate,
    **cifar_optimizer_kwargs,
)

In [None]:
# Per-run histories; `defaultdict(list)` lets any metric name be appended to
# without prior initialisation.
cifar_train_history = collections.defaultdict(list)
cifar_validation_history = collections.defaultdict(list)

# Metric bundle: micro-averaged accuracy plus un-averaged (`average=None`)
# precision, recall and F1.
cifar_train_metrics = torchmetrics.MetricCollection(
    {
        "accuracy": torchmetrics.Accuracy(
            "multiclass",
            num_classes=num_classes,
            average="micro",
        ),
        "precision": torchmetrics.Precision(
            "multiclass",
            num_classes=num_classes,
            average=None,
        ),
        "recall": torchmetrics.Recall(
            "multiclass",
            num_classes=num_classes,
            average=None,
        ),
        "f1 score": torchmetrics.F1Score(
            "multiclass",
            num_classes=num_classes,
            average=None,
        ),
    }
).to(device)
# Validation gets its own copy so the two phases never share metric state.
cifar_validation_metrics = cifar_train_metrics.clone()

# Run the training loop with the label-smoothed criterion, matching the MNIST
# run. Previously the plain `cifar_loss_fn` was passed, which left
# `label_smoothed_cifar_loss_fn` and `cifar_label_smoothing_factor` unused.
# The third argument is left as None (presumably an optional scheduler slot - confirm).
train(
    cifar_model,
    optimizer,
    None,
    cifar_train_loader,
    cifar_train_history,
    cifar_validation_loader,
    cifar_validation_history,
    cifar_epochs,
    label_smoothed_cifar_loss_fn,
    cifar_classes,
    cifar_train_metrics,
    cifar_validation_metrics,
    device,
    cifar_auxiliary_loss_weight,
)

In [None]:
# Plot the recorded "loss" entries of the CIFAR-100 training and validation histories.
plot_metric(
    dict(Training=cifar_train_history, Validation=cifar_validation_history),
    metric="loss",
)

In [None]:
# Plot the recorded "accuracy" entries of the CIFAR-100 training and validation histories.
plot_metric(
    dict(Training=cifar_train_history, Validation=cifar_validation_history),
    metric="accuracy",
)

In [None]:
# Dedicated metric collection for the CIFAR-100 test pass, cloned from the
# training collection so it has the same metric definitions.
cifar_test_metrics = cifar_train_metrics.clone()

# Evaluate the trained model on the test loader with the plain (non-smoothed)
# cross-entropy criterion; the returned loss is kept for the report cell.
cifar_test_loss = validate_one_epoch(
    cifar_model, cifar_test_loader, cifar_loss_fn,
    num_classes, cifar_test_metrics, device,
    "Testing",
)

In [None]:
# Assemble a history-shaped dict (each value is a one-element list): the test
# loss first, then every computed metric moved to the CPU.
cifar_test_history = {
    "loss": [cifar_test_loss],
    **{
        metric_name: [metric_value.to("cpu")]
        for metric_name, metric_value in cifar_test_metrics.compute().items()
    },
}

pretty_print_metrics(cifar_test_history, cifar_classes)