In [None]:
from datetime import datetime
import random
import os

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torcheval.metrics import Mean, MulticlassAccuracy
import torchvision
from tqdm import tqdm

from models import ResNet32

In [None]:
# Load the CIFAR-10 train and test splits from ./data (download=True is
# passed for both), attaching the ToTensor transform to each.

train_dataset = torchvision.datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
test_dataset = torchvision.datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

In [None]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU.

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Helper functions for visualizing predictions

# Shared tensor-to-PIL converter; plot_classes_preds applies it to each image
# before handing the result to matplotlib's imshow.
to_pil_image = torchvision.transforms.ToPILImage()


def predict_with_probs(model, x):
    """
    Run ``model`` on the batch ``x`` and return its top-1 predictions.

    The model output is treated as logits and turned into class
    probabilities with a softmax over dimension 1. Returns a
    ``(preds, probs)`` pair: for each input, the index of the most probable
    class and the probability assigned to that class.
    """
    class_probs = F.softmax(model(x), dim=1)
    top = class_probs.max(dim=1)
    return top.indices, top.values


def plot_classes_preds(model, x, y, classes, max_images=4):
    """
    Build a matplotlib Figure showing the network's top prediction for the
    images of a batch.

    Each displayed image gets its own subplot, titled with the predicted
    class, the probability assigned to it and the actual label. The title is
    green when the prediction matches the label and red otherwise.
    Uses the "predict_with_probs" function.

    Args:
        model: network called on ``x`` after ``x`` is moved to ``device``.
        x: batch of image tensors, indexed along the first dimension.
        y: tensor of integer labels, one per image in ``x``.
        classes: sequence mapping a class index to its display name.
        max_images: upper bound on the number of images drawn (default 4).
            Only the first ``min(len(x), max_images)`` images are shown.

    Returns:
        The matplotlib Figure, with one subplot per displayed image.

    Raises:
        ValueError: if there is no image to display.
    """
    n_images = min(len(x), max_images)
    if n_images < 1:
        raise ValueError("plot_classes_preds needs at least one image to display")

    with torch.no_grad():
        preds, probs = predict_with_probs(model, x.to(device))
    preds = preds.cpu().numpy()
    probs = probs.cpu().numpy()
    # Plot the images in the batch, along with predicted and true labels.
    # The column count follows the batch (2.5in per image, i.e. 10in for
    # four), and squeeze=False keeps `axs` 2-D even for a single image.
    fig, axs = plt.subplots(1, n_images, figsize=(2.5 * n_images, 3), squeeze=False)
    for i, ax in enumerate(axs[0]):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.imshow(to_pil_image(x[i]))
        ax.set_title(
            "{0}, {1:.1%}\n(actual: {2})".format(
                classes[preds[i]],
                probs[i],
                classes[y[i]]),
            color=("green" if preds[i]==y[i].item() else "red")
        )
    return fig

In [None]:
def train(loader, model, loss_fn, optimizer, epoch, writer, loss_metric, accuracy_metric):
    """
    Train ``model`` for one pass over ``loader`` and log the epoch metrics.

    For every batch the model is run, the loss is back-propagated and the
    optimizer takes a step. Running loss and accuracy are shown in the tqdm
    progress bar. After the pass, both metrics are written to ``writer``
    under "loss/train" and "accuracy/train" and then reset.

    Args:
        loader: iterable yielding ``(x, y)`` batches.
        model: network to train; switched to training mode.
        loss_fn: callable ``loss_fn(pred, y)`` returning the batch loss.
        optimizer: optimizer updating the model parameters.
        epoch: value used as ``global_step`` for the logged scalars.
        writer: TensorBoard writer receiving the scalars.
        loss_metric: running-mean metric for the loss (reset on exit).
        accuracy_metric: accuracy metric (reset on exit).
    """
    model.train()
    with tqdm(loader) as progress:
        for x, y in progress:
            x = x.to(device)
            y = y.to(device)

            pred = model(x)
            loss = loss_fn(pred, y)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Weight each batch loss by its sample count so a smaller final
            # batch does not skew the epoch mean. This assumes loss_fn
            # returns the batch mean, as the CrossEntropyLoss built in
            # training_run does.
            loss_metric.update(loss.detach(), weight=y.shape[0])
            accuracy_metric.update(pred, y)

            progress.set_postfix_str(
                f"loss={loss_metric.compute().item():.4f}, accuracy={accuracy_metric.compute().item():.2%}",
                refresh=False
            )

    writer.add_scalar("loss/train", scalar_value=loss_metric.compute(), global_step=epoch)
    loss_metric.reset()
    writer.add_scalar("accuracy/train", scalar_value=accuracy_metric.compute(), global_step=epoch)
    accuracy_metric.reset()

In [None]:
def test(loader, model, loss_fn, epoch, writer, loss_metric, accuracy_metric):
    """
    Evaluate ``model`` on ``loader`` and log the results for ``epoch``.

    The model is put in eval mode and run over every batch without gradient
    tracking. The resulting loss and accuracy are printed, written to
    ``writer`` under "loss/test" and "accuracy/test", and the metrics are
    reset. Finally, four randomly drawn samples of ``loader.dataset`` are
    rendered with ``plot_classes_preds`` and logged as the "predictions"
    figure.

    Args:
        loader: DataLoader yielding ``(x, y)`` batches. Its ``dataset`` must
            support ``len()``, integer indexing and a ``classes`` attribute.
        model: network to evaluate.
        loss_fn: callable ``loss_fn(pred, y)`` returning the batch loss.
        epoch: value used as ``global_step`` for everything logged.
        writer: TensorBoard writer receiving the scalars and the figure.
        loss_metric: running-mean metric for the loss (reset on exit).
        accuracy_metric: accuracy metric (reset on exit).
    """
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            pred = model(x)
            loss = loss_fn(pred, y)

            # Weight each batch loss by its sample count so a smaller final
            # batch does not skew the reported mean. This assumes loss_fn
            # returns the batch mean, as the CrossEntropyLoss built in
            # training_run does.
            loss_metric.update(loss, weight=y.shape[0])
            accuracy_metric.update(pred, y)

    print(f"test_loss={loss_metric.compute().item():.4f}, test_accuracy={accuracy_metric.compute().item():.2%}")

    writer.add_scalar("loss/test", scalar_value=loss_metric.compute(), global_step=epoch)
    loss_metric.reset()
    writer.add_scalar("accuracy/test", scalar_value=accuracy_metric.compute(), global_step=epoch)
    accuracy_metric.reset()

    # Draw four preview samples (with replacement, so repeats are possible)
    # and log the model's predictions for them as a figure.
    x, y = zip(*random.choices(loader.dataset, k=4))
    x = torch.stack(x)
    y = torch.tensor(y)
    writer.add_figure("predictions", plot_classes_preds(model, x, y, loader.dataset.classes), global_step=epoch)

In [None]:
def training_run(train_dataset, test_dataset, epochs, batch_size=64):
    """
    Train a fresh ResNet32 on ``train_dataset`` and evaluate it on
    ``test_dataset`` after every epoch.

    Metrics and prediction figures are logged to a timestamped TensorBoard
    directory under ``runs/train``.

    Args:
        train_dataset: training dataset; ``classes`` gives the output size.
        test_dataset: dataset evaluated after each training epoch.
        epochs: number of train/test rounds to run.
        batch_size: batch size used by both data loaders (default 64).

    Returns:
        The trained model, on ``device``.
    """
    # Make data loaders. The training loader reshuffles every epoch so SGD
    # does not see the same batches in the same order each time; evaluation
    # order does not matter.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

    # Build model
    model = ResNet32(num_classes=len(train_dataset.classes)).to(device)

    # Set loss function and optimizer
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4, momentum=0.9, nesterov=True)

    # Define metrics
    loss_metric = Mean(device=device)
    accuracy_metric = MulticlassAccuracy(device=device)

    # Make TensorBoard writer
    now = datetime.now().strftime("%b%d_%H-%M-%S")
    log_dir = os.path.join("runs", "train", now)
    writer = SummaryWriter(log_dir=log_dir)

    # Train the model. The writer is closed in `finally` so buffered events
    # reach disk even when training is interrupted or fails.
    try:
        print("Training model")
        for epoch in range(epochs):
            print("-------------------------------")
            print(f"Epoch {epoch}")
            train(train_loader, model, loss_fn, optimizer, epoch, writer, loss_metric, accuracy_metric)
            test(test_loader, model, loss_fn, epoch, writer, loss_metric, accuracy_metric)
        print("-------------------------------")
        print("Done!")
    finally:
        writer.close()

    return model

In [None]:
# Number of train/test rounds to run; training_run returns the trained model.
epochs = 10
model = training_run(train_dataset, test_dataset, epochs)

In [None]:
# Classify the first test sample and compare the result with its label

model.eval()
x, y = next(iter(test_dataset))
with torch.no_grad():
    x = x.to(device)
    # Add a leading batch dimension for the model, then drop it again.
    pred = model(x.unsqueeze(0))[0]
    predicted = test_dataset.classes[pred.argmax(dim=0)]
    actual = test_dataset.classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')