In [None]:
%load_ext autoreload
%autoreload 2

import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from datetime import datetime

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim

from trainer import Trainer

### Utility Functions

In [None]:
def get_device():
    """Select the best torch device present on this machine.

    Preference order: CUDA, then Apple MPS (only when the backend reports it is
    both available and built), then CPU. The choice is announced on stdout.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        backend, message = "cuda", "Using CUDA..."
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        backend, message = "mps", "Using MPS..."
    else:
        backend, message = "cpu", "Using CPU..."

    print(message)
    return torch.device(backend)

def imshow(img):
    """Display a normalized CIFAR-10 image (or image grid) with matplotlib.

    Undoes the per-channel Normalize(mean=0.5, std=0.5) used when loading the
    dataset (x / 2 + 0.5), converts channel-first to channel-last for matplotlib,
    and shows the figure.

    Args:
        img: image tensor in channel-first (C, H, W) layout, e.g. the output of
            torchvision.utils.make_grid. It may require grad or live on a
            non-CPU device.
    """
    img = img / 2 + 0.5  # unnormalize back to roughly [0, 1]
    # .numpy() raises on tensors that require grad or sit on CUDA/MPS, so move
    # to a plain CPU tensor first. Clamping keeps float rounding from tripping
    # matplotlib's out-of-range clipping warning; the rendered image is the same.
    npimg = img.detach().cpu().clamp(0, 1).numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # (C, H, W) -> (H, W, C)
    plt.show()

In [None]:
# Resolve the compute device once; later cells pass it to each Trainer.
device: torch.device = get_device()

### Load CIFAR-10 dataset

In [None]:
BATCH_SIZE = 4
NUM_WORKERS = 2       # DataLoader worker processes
DATA_ROOT = "./data"  # where CIFAR-10 is downloaded/cached
SEED = 42             # fixes the training shuffle order across Restart & Run All

# Dataset loading follows https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 per RGB
# channel maps them to [-1, 1]. `imshow` inverts this with img / 2 + 0.5.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
# A dedicated seeded generator makes the shuffled batch order reproducible
# without reseeding the global torch RNG.
train_dataloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    generator=torch.Generator().manual_seed(SEED),
)

testset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)
test_dataloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

# Human-readable names, indexed by the integer CIFAR-10 label.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# Inspect one raw sample: indexing the dataset returns a single (image, label) pair.
sample_image, sample_label = trainset[0]

print(f"Input Tensor: {sample_image.shape}")
print(f"Label: {sample_label} ({classes[sample_label]})")

# Draw one shuffled batch from the training loader.
dataiter = iter(train_dataloader)
images, labels = next(dataiter)

# Show the whole batch as a single image grid.
imshow(torchvision.utils.make_grid(images))
# Print the class name of every image actually in the batch. Iterating over
# `labels` instead of range(BATCH_SIZE) stays correct for a short final batch.
print(' '.join(f'{classes[label]:5s}' for label in labels.tolist()))

### Training Existing ResNet-18 Model
We refer to PyTorch's existing (torchvision) implementation as the baseline model.

In [None]:
# Fine-tune torchvision's ImageNet-pretrained ResNet-18 on CIFAR-10.
# `weights=` replaces the deprecated `pretrained=True` flag (torchvision >= 0.13);
# for resnet18, DEFAULT resolves to the same IMAGENET1K_V1 checkpoint.
baseline_net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
# The pretrained head predicts 1000 ImageNet classes; swap in a freshly
# initialised layer sized for the CIFAR-10 label set.
baseline_net.fc = nn.Linear(baseline_net.fc.in_features, len(classes))
baseline_trainer = Trainer(baseline_net, device=device)

### Training My PyTorch Implementation of ResNet-18

In [None]:
from resnet18 import ResNet

# TODO(review): `...` (Ellipsis) is passed as a placeholder argument. Replace it
# with the real constructor arguments of resnet18.ResNet (its signature is not
# visible here). For a like-for-like comparison with the baseline, the model
# presumably needs a CIFAR-10-sized output (len(classes)); confirm.
resnet = ResNet(...)
# Wrap the custom model in the same Trainer used for the baseline, on the same device.
resnet_trainer = Trainer(resnet, device=device)

# Acknowledgements

- [PyTorch CIFAR10 Training Tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)
  