In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torchvision

from resnet18 import ResNet18
from trainer import Trainer

### Utility Functions

In [None]:
def get_device():
    """Pick the best available torch device.

    Preference order is CUDA, then Apple MPS (only if it is both available and
    built into this torch install), then CPU. A one-line message naming the
    chosen backend is printed.

    Returns:
        torch.device: ``cuda``, ``mps`` or ``cpu``.
    """
    if torch.cuda.is_available():
        backend, label = "cuda", "CUDA"
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        backend, label = "mps", "MPS"
    else:
        backend, label = "cpu", "CPU"

    print(f"Using {label}...")
    return torch.device(backend)

def imshow(img, mean=0.5, std=0.5):
    """Display a normalized CIFAR-10 image tensor with matplotlib.

    Args:
        img: image tensor in (C, H, W) layout that was normalized as
            ``(x - mean) / std`` (e.g. a single image or a ``make_grid`` result).
        mean: per-channel mean used by ``Normalize``; a scalar or a sequence of
            length C. Defaults to 0.5, matching the notebook's transform.
        std: per-channel std used by ``Normalize``; a scalar or a sequence of
            length C. Defaults to 0.5, matching the notebook's transform.
    """
    # Detach, move to CPU and use float so .numpy() works for GPU / grad-tracking tensors.
    img = img.detach().cpu().float()
    # Shape the stats as (C, 1, 1) (or (1, 1, 1) for scalars) so they broadcast over H and W.
    mean = torch.as_tensor(mean, dtype=torch.float32).reshape(-1, 1, 1)
    std = torch.as_tensor(std, dtype=torch.float32).reshape(-1, 1, 1)
    # Undo Normalize; clamp because float RGB input to plt.imshow must lie in [0, 1].
    img = (img * std + mean).clamp(0.0, 1.0)
    # matplotlib expects channels last: (C, H, W) -> (H, W, C).
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    plt.show()

In [None]:
# setup device: choose the compute device once; later cells pass `device` to each Trainer
device = get_device()

### Load CIFAR-10 dataset

In [None]:
# Data / loader configuration: tweak here rather than inside the calls below.
BATCH_SIZE = 256
NUM_WORKERS = 2
VAL_FRACTION = 0.2  # fraction of the official training split held out for validation
SEED = 42
DATA_ROOT = "./data"

# Seed the global RNG so new-layer weight init and DataLoader shuffling repeat across runs.
torch.manual_seed(SEED)

# ToTensor scales pixels to [0, 1]; Normalize with 0.5/0.5 maps them to [-1, 1] per channel.
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
full_train_set = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)

# Split the official training set 80-20 into train/validation. A dedicated seeded
# generator makes the split identical on every Restart & Run All.
train_set, val_set = torch.utils.data.random_split(
    full_train_set,
    [1 - VAL_FRACTION, VAL_FRACTION],
    generator=torch.Generator().manual_seed(SEED),
)
train_dataloader = torch.utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
# Evaluation data needs no shuffling; a fixed order keeps runs comparable.
val_dataloader = torch.utils.data.DataLoader(
    val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

test_set = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)
test_dataloader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

# Human-readable names; index i is the name for integer label i.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# Inspect a single (image, label) pair straight from the training split.
image, label = train_set[0]

print(f"Input Tensor: {image.shape}")
print(f"Label: {label} ({classes[label]})")

# Pull one shuffled training batch and show its first few images with class names.
images, labels = next(iter(train_dataloader))
n_show = min(4, len(labels))  # guard against a batch smaller than 4

imshow(torchvision.utils.make_grid(images[:n_show]))
print(' '.join(f'{classes[idx]:5s}' for idx in labels[:n_show].tolist()))

### Training Existing ResNet-18 Model
We fine-tune torchvision's ImageNet-pretrained ResNet-18, with its final layer replaced by a new 10-class linear layer, and call it the baseline model.

In [None]:
# Start from torchvision's ResNet-18 with ImageNet-1k pretrained weights.
baseline_net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)

# Replace the ImageNet classifier with a fresh layer sized for CIFAR-10. Read the
# input width from the existing layer instead of hard-coding 512, and the output
# width from `classes` instead of hard-coding 10.
baseline_net.fc = nn.Linear(in_features=baseline_net.fc.in_features, out_features=len(classes))
# NOTE(review): these weights were trained with ImageNet mean/std normalization,
# while `transform` uses 0.5/0.5. Consider matching it when fine-tuning.
baseline_net

In [None]:
# Fine-tune the baseline ResNet-18 with the project's Trainer on the train/val split.
baseline_trainer = Trainer(
    baseline_net,
    model_name="baseline_pytorch_resnet18",
    batch_size=BATCH_SIZE,
    device=device,
)
baseline_trainer.train(train_dataloader, val_dataloader)

In [None]:
# Evaluate the fine-tuned baseline on the CIFAR-10 test split, then show the Trainer's metric plots.
baseline_trainer.test(test_dataloader)
baseline_trainer.plot_metrics()

### Training My PyTorch Implementation of ResNet-18

In [None]:
# Build our own ResNet-18 implementation through its `from_pretrained` constructor.
resnet = ResNet18.from_pretrained("resnet18")
# NOTE(review): unlike the baseline, no CIFAR-10 head is attached here. Confirm that
# ResNet18.from_pretrained returns a model with len(classes) outputs.

# Use the same Trainer settings as the baseline so the two runs are comparable.
resnet_trainer = Trainer(resnet, model_name="custom_resnet18", batch_size=BATCH_SIZE, device=device)
resnet_trainer.train(train_dataloader, val_dataloader)

In [None]:
# Evaluate our ResNet-18 on the CIFAR-10 test split, then show the Trainer's metric plots.
resnet_trainer.test(test_dataloader)
resnet_trainer.plot_metrics()

# Acknowledgements

- [PyTorch CIFAR10 Training Tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)
- [A Detailed Introduction to ResNet and Its Implementation in PyTorch](https://medium.com/@freshtechyy/a-detailed-introduction-to-resnet-and-its-implementation-in-pytorch-744b13c8074a) by Huili Yu
- [Let's reproduce GPT-2 (124M)](https://www.youtube.com/watch?v=l8pRSuU81PU) by Andrej Karpathy
- [Helpful conventions for PyTorch model building](https://github.com/FrancescoSaverioZuppichini/Pytorch-how-and-when-to-use-Module-Sequential-ModuleList-and-ModuleDict/blob/master/README.md) by FrancescoSaverioZuppichini  