In [None]:
import copy
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image

In [None]:
data = torchvision.datasets.CIFAR10('~/data', train=True, download=True)
print(data.data.shape)
# Per-channel pixel statistics on the raw 0-255 scale, reducing (N, H, W, 3) -> (3,).
# Consumed by Cifar10() below, which rescales them to match ToTensor's 0-1 output.
mu, std = (
    data.data.mean(axis=(0, 1, 2)),
    data.data.std(axis=(0, 1, 2)),
)

In [None]:
n_valid = 2000  # images held out from the end of the training split for validation


def Cifar10(train, trforms=None, root='~/data'):
    """Return the CIFAR-10 train or test split with tensor conversion and normalization.

    Every image goes through ``ToTensor`` followed by per-channel normalization using the
    module-level ``mu`` / ``std`` statistics (computed on raw 0-255 pixels).

    Args:
        train: load the training split if True, otherwise the test split.
        trforms: optional transform (e.g. data augmentation) applied to each image
            *before* tensor conversion and normalization.
        root: directory the dataset is stored in / downloaded to.

    Returns:
        A ``torchvision.datasets.CIFAR10`` instance.
    """
    tfms_norm = transforms.Compose([
        transforms.ToTensor(),
        # ToTensor already maps 0-255 to 0-1, so divide mu and std by 255 as well.
        transforms.Normalize(mu / 255, std / 255),
    ])
    tf = tfms_norm if trforms is None else transforms.Compose([trforms, tfms_norm])
    return torchvision.datasets.CIFAR10(root=root, train=train, download=True, transform=tf)

loader_kwargs = {'batch_size': 128, 'num_workers': 4}

# Validation set: the last `n_valid` images of the training split, without augmentation.
# Slice from an explicit start index rather than [-n_valid:], which would select the
# whole split when n_valid == 0.
data_valid = Cifar10(train=True)
valid_start = len(data_valid.data) - n_valid
data_valid.data = data_valid.data[valid_start:]
data_valid.targets = data_valid.targets[valid_start:]
loader_valid = torch.utils.data.DataLoader(data_valid, **loader_kwargs)

# Test set: every image of the CIFAR-10 test split.
data_test = Cifar10(train=False)
loader_test = torch.utils.data.DataLoader(data_test, **loader_kwargs)


def get_loader_train(augmentations=lambda x: x):
    """Build a shuffled DataLoader over the training images not held out for validation.

    Args:
        augmentations: transform applied to each image before tensor conversion and
            normalization. Defaults to the identity (no augmentation).

    Returns:
        A shuffling DataLoader over all but the last ``n_valid`` images of the CIFAR-10
        training split, using the shared ``loader_kwargs``.
    """
    data_train = Cifar10(train=True, trforms=augmentations)
    # Explicit end index instead of [:-n_valid], which is empty when n_valid == 0.
    n_train = len(data_train.data) - n_valid
    data_train.data = data_train.data[:n_train]
    data_train.targets = data_train.targets[:n_train]
    return torch.utils.data.DataLoader(data_train, **loader_kwargs, shuffle=True)

In [None]:
def one_epoch(model, data_loader, opt=None):
    """Run `model` over every batch of `data_loader` exactly once.

    If `opt` is given, the model is put in training mode and updated with one optimizer
    step per batch; otherwise it is put in eval mode and no gradients are tracked.

    Args:
        model: module whose (first) parameter's device the batches are moved to.
        data_loader: iterable of ``(x, y)`` batches with integer class labels ``y``.
        opt: optional optimizer over ``model``'s parameters.

    Returns:
        ``(mean_loss, accuracy)``: cross-entropy loss and accuracy, both averaged per
        sample, so a smaller final batch is not over-weighted.

    Raises:
        ValueError: if `data_loader` yields no samples.
    """
    device = next(model.parameters()).device
    is_training = opt is not None
    model.train(is_training)  # train(False) is exactly model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    for x, y in data_loader:
        x, y = x.to(device), y.to(device)
        # Compute the loss inside the grad context too, so graph construction never
        # depends on the caller's global grad mode.
        with torch.set_grad_enabled(is_training):
            logits = model(x)
            loss = F.cross_entropy(logits, y)

        if is_training:
            opt.zero_grad()
            loss.backward()
            opt.step()

        batch_size = len(x)
        # cross_entropy returns the batch mean; re-weight by batch size for a per-sample mean.
        loss_sum += loss.item() * batch_size
        total += batch_size
        correct += (logits.argmax(dim=1) == y).sum().item()

    if total == 0:
        raise ValueError('data_loader yielded no samples')
    return loss_sum / total, correct / total


def train(model, loader_train, loader_valid, lr=1e-3, max_epochs=30, weight_decay=0., patience=3,
          restore_best=True):
    """Train `model` with Adam, early-stopping on validation accuracy.

    Each epoch runs one optimization pass over `loader_train` followed by an evaluation
    pass over `loader_valid`. Training stops after `max_epochs` epochs, or once more than
    `patience` epochs have passed since the epoch with the best validation accuracy.

    Args:
        model: module to train in place.
        loader_train, loader_valid: iterables of ``(x, y)`` batches.
        lr, weight_decay: Adam hyperparameters.
        max_epochs: upper bound on the number of epochs.
        patience: early-stopping patience, as described above.
        restore_best: if True (default), copy the weights from the best validation epoch
            back into `model` before returning, so later evaluation uses the model early
            stopping selected rather than the last one trained.

    Returns:
        ``(train_losses, train_accuracies, valid_losses, valid_accuracies)``, one entry per
        completed epoch.
    """
    train_losses, train_accuracies = [], []
    valid_losses, valid_accuracies = [], []

    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_valid_accuracy = 0
    best_valid_accuracy_epoch = 0
    best_state = None

    progress = tqdm(range(max_epochs))
    for epoch in progress:
        train_loss, train_acc = one_epoch(model, loader_train, opt)
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)

        valid_loss, valid_acc = one_epoch(model, loader_valid)
        valid_losses.append(valid_loss)
        valid_accuracies.append(valid_acc)

        progress.set_description(f'train_acc: {train_acc:.2f}, valid_acc: {valid_acc:.2f}')

        # Epoch 0 always counts as best so there is a snapshot to restore even if the
        # validation accuracy never exceeds 0.
        if epoch == 0 or valid_acc > best_valid_accuracy:
            best_valid_accuracy = valid_acc
            best_valid_accuracy_epoch = epoch
            if restore_best:
                # state_dict() returns references to the live tensors, so snapshot a deep copy.
                best_state = copy.deepcopy(model.state_dict())

        if epoch > best_valid_accuracy_epoch + patience:
            break

    if best_state is not None:
        model.load_state_dict(best_state)
    progress.set_description(f'best valid acc: {best_valid_accuracy:.2f}')

    return train_losses, train_accuracies, valid_losses, valid_accuracies


def plot_history(train_losses, train_accuracies, valid_losses, valid_accuracies, loss_ylim=(0, 2)):
    """Plot per-epoch loss and accuracy curves for the training and validation sets.

    Args:
        train_losses, valid_losses: per-epoch mean losses.
        train_accuracies, valid_accuracies: per-epoch accuracies (fractions in [0, 1]).
        loss_ylim: ``(bottom, top)`` limits of the loss panel, or None to autoscale.
            The default upper limit of 2 clips losses above it, e.g. an untrained
            10-class classifier (around ln(10), about 2.3).
    """
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(7, 3))

    panels = [
        (ax_loss, 'loss', train_losses, valid_losses, loss_ylim),
        (ax_acc, 'accuracy', train_accuracies, valid_accuracies, (0, 1.05)),
    ]
    for ax, ylabel, train_curve, valid_curve, ylim in panels:
        ax.plot(train_curve, label='train')
        ax.plot(valid_curve, label='valid')
        ax.set_xlabel('epoch')
        ax.set_ylabel(ylabel)
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.legend()
        ax.grid()

    fig.tight_layout()
    plt.show()

In [None]:
# ImageNet-pretrained ResNet-18 backbone.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in favour of
# `weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1`; kept here for compatibility
# with older versions. Confirm against the installed torchvision.
resnet = torchvision.models.resnet18(pretrained=True)
# The head after the conv layers is a single linear layer mapping the pooled features to
# 1k-class logits. Replace it with a fresh head producing 10-class logits, reading the
# input width from the existing layer instead of hard-coding 512.
resnet.fc = nn.Linear(resnet.fc.in_features, 10)

# The backbone was trained on images with a resolution of 224, so upscale the CIFAR-10
# images first. nn.Upsample(mode='bilinear', align_corners=True) is the non-deprecated
# equivalent of nn.UpsamplingBilinear2d.
model = nn.Sequential(
    nn.Upsample(size=(224, 224), mode='bilinear', align_corners=True),
    resnet,
)

In [None]:
# Use the GPU when one is available; one_epoch() moves each batch to the model's device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# NOTE(review): `affine_hflip` was not defined in any cell of this notebook, so this cell
# raised NameError after "Restart & Run All". Judging by its name it was a random affine
# transform plus a horizontal flip. The parameters below are an assumption; replace them
# with the original definition if it can be recovered.
affine_hflip = transforms.Compose([
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
    transforms.RandomHorizontalFlip(),
])

loader_train = get_loader_train(affine_hflip)
plot_history(*train(model, loader_train, loader_valid, lr=1e-4))

In [None]:
# Evaluate once on the held-out test split (no optimizer => eval mode, no gradients).
test_loss, test_acc = one_epoch(model, loader_test)
print(f'{test_acc * 100:.1f} % test accuracy')