In [None]:
from torchvision import datasets
from torchvision.transforms import transforms
from torch.utils.data import random_split
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18

import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np
from tqdm import tqdm

from sklearn.metrics import ConfusionMatrixDisplay
import warnings 
warnings.simplefilter(action='ignore', category=FutureWarning) 


## Configuration & Helper Functions

In [None]:
# Hyperparameters
valid_size = 0.2      # fraction of the training set held out for validation
batch_size = 10
lr = 0.001
n_epochs = 100

# Parameter Definition
num_workers = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
# Per-channel normalization statistics used by both transform pipelines.
# NOTE(review): these look like the common ImageNet statistics, not CIFAR-10's own.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# Seed every RNG library the notebook imports so a fresh run is reproducible.
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Define denormalizer. Normalize computes (x - m') / s'; with m' = -mean/std and
# s' = 1/std this equals x * std + mean, i.e. the inverse of Normalize(mean, std).
# Vectorized tensor ops avoid reusing the global names `mean`/`std` as loop variables.
de_mean = (-mean / std).tolist()
de_std = (1 / std).tolist()
denormalizer = transforms.Normalize(mean=de_mean, std=de_std)

def spilt_train_valid(train_dataset, valid_set_size, generator=None):
    """Randomly split a dataset into ``(train_subset, valid_subset)``.

    Args:
        train_dataset: any dataset supporting ``len()`` and integer indexing.
        valid_set_size: fraction in [0, 1] of the samples that go to the
            validation subset; the resulting count is rounded down.
        generator: optional ``torch.Generator`` for a reproducible split.
            When None, PyTorch's global RNG is used (the original behavior).

    Returns:
        A ``(train_subset, valid_subset)`` pair as produced by ``random_split``.

    Raises:
        ValueError: if ``valid_set_size`` is not a fraction in [0, 1].
    """
    if not 0 <= valid_set_size <= 1:
        raise ValueError(
            'valid_set_size must be a fraction in [0, 1], got {!r}'.format(valid_set_size))
    n_valid = int(valid_set_size * len(train_dataset))
    n_train = len(train_dataset) - n_valid
    if generator is None:
        return random_split(train_dataset, [n_train, n_valid])
    return random_split(train_dataset, [n_train, n_valid], generator=generator)


# Correctly spelled alias; the original (misspelled) name stays for existing callers.
split_train_valid = spilt_train_valid

def load_cifar10(is_train, transform, root='data', download=True):
    """Build the torchvision CIFAR-10 dataset for one split.

    Args:
        is_train: True for the training split, False for the test split.
        transform: transform applied to every sample image.
        root: dataset directory (previously hard-coded to ``'data'``).
        download: passed straight through to torchvision's ``download`` flag
            (previously always True).

    Returns:
        A ``datasets.CIFAR10`` instance.
    """
    return datasets.CIFAR10(root=root, train=is_train,
                            transform=transform, download=download)

def confusion_matrix(preds, labels, conf_matrix):
    """Accumulate (true, predicted) pairs into ``conf_matrix`` in place.

    Rows index the TRUE class and columns the PREDICTED class. This is the
    layout sklearn's ``ConfusionMatrixDisplay`` draws as "True label" (y) vs
    "Predicted label" (x), and it makes row-normalization give per-class recall.

    Args:
        preds: predicted class indices (1-D tensor, array or sequence of ints).
        labels: true class indices, same length as ``preds``.
        conf_matrix: 2-D counts matrix, updated in place.

    Returns:
        ``conf_matrix`` itself, for convenient reassignment.
    """
    # Plain Python ints index the matrix regardless of which device the inputs live on.
    pred_list = preds.tolist() if hasattr(preds, 'tolist') else list(preds)
    label_list = labels.tolist() if hasattr(labels, 'tolist') else list(labels)
    for pred, true in zip(pred_list, label_list):
        conf_matrix[true, pred] += 1
    return conf_matrix

def plot_confusion_matrix(conf_matrix, class_names, normalize=False, ax_display=True, title='Confusion Matrix'):
    """Draw a confusion matrix with sklearn's ``ConfusionMatrixDisplay``.

    Args:
        conf_matrix: 2-D numpy array of counts; rows are drawn as "True label"
            and columns as "Predicted label".
        class_names: tick labels for both axes.
        normalize: if True, divide each row by its sum and print two decimals;
            rows that sum to zero are shown as 0 instead of NaN.
        ax_display: if False, the axes' spines are colored white (hidden).
        title: figure title.
    """
    _, ax = plt.subplots(figsize=(8, 6))
    if normalize:
        conf_matrix = conf_matrix.astype('float')
        row_sums = conf_matrix.sum(axis=1, keepdims=True)
        # Guard against empty rows: 0/0 would otherwise yield NaN cells.
        conf_matrix = np.divide(conf_matrix, row_sums,
                                out=np.zeros_like(conf_matrix), where=row_sums != 0)
        values_format = '.2f'
    else:
        conf_matrix = conf_matrix.astype('int')
        values_format = 'd'
    disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=class_names)
    disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation=45, values_format=values_format)
    if ax_display is False:
        left, right = ax.get_xlim()
        ax.spines['left'].set_position(('data', left))
        ax.spines['right'].set_position(('data', right))
        for edge_i in ['top', 'bottom', 'right', 'left']:
            ax.spines[edge_i].set_edgecolor("white")
    ax.set_title(title, fontsize=14)
    ax.set_ylabel('True label', fontsize=14)
    ax.set_xlabel('Predicted label', fontsize=14)
    # Lay out only after all labels exist, so none of them get clipped.
    plt.tight_layout()
    plt.show()

def plot_class_samples(samples, preds, values, title='Examples'):
    """Show one example image per class in a grid with 5 columns.

    Entry ``idx`` of every argument belongs to true class ``class_names[idx]``.

    Args:
        samples: per-class normalized image tensors shaped (C, H, W). Entries
            that are not tensors (no example found for that class) are drawn
            as an empty "N/A" panel instead of raising.
        preds: predicted class index for each sample.
        values: confidence value printed under each image.
        title: figure title.
    """
    n_panels = len(samples)
    n_cols = 5
    n_rows = (n_panels + n_cols - 1) // n_cols
    plt.figure(figsize=(12, 6))
    for idx in range(n_panels):
        plt.subplot(n_rows, n_cols, idx + 1)
        plt.xticks([])
        plt.yticks([])
        sample = samples[idx]
        if not isinstance(sample, torch.Tensor):
            plt.title('True: {:s} (N/A)'.format(class_names[idx]), fontdict={'size': 10})
            continue
        # Undo normalization and clamp: small overshoot outside [0, 1] would
        # otherwise make imshow warn and clip.
        img = denormalizer(sample).clamp(0, 1).cpu().numpy()
        plt.imshow(np.transpose(img, (1, 2, 0)))
        plt.title('True: {:s} (Pred: {:s})'.format(class_names[idx], class_names[preds[idx]]),
                  fontdict={'size': 10},
                  color=("green" if preds[idx] == idx else "red"))
        plt.xlabel('Confidence value: {:.2f}'.format(values[idx]), fontdict={'size': 10})
    # suptitle applies its own default size over a fontdict, so pass fontsize directly.
    plt.suptitle(title, fontsize=14)
    plt.tight_layout()
    plt.show()

## Data Augmentation

In [None]:
# Deterministic pipeline shared by every split: image -> tensor -> normalized tensor.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Training prepends random horizontal flips and random grayscale conversion
# to exactly the same tensor + normalization steps.
train_transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(), transforms.RandomGrayscale()]
    + list(test_transform.transforms)
)

## Data Loading

In [None]:
# Load train, validation and test datasets as iterators.
# The training split is loaded twice: with augmentation (used for training) and
# with the deterministic test transform, so that the held-out validation images
# are evaluated without random flips / grayscale.
train_ds_aug = load_cifar10(is_train=True, transform=train_transform)
train_ds_eval = load_cifar10(is_train=True, transform=test_transform)
test_ds = load_cifar10(is_train=False, transform=test_transform)
train_ds, valid_split = spilt_train_valid(train_ds_aug, valid_size)
# Same validation indices as the random split, served through the un-augmented copy.
valid_ds = torch.utils.data.Subset(train_ds_eval, valid_split.indices)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size,
                                           shuffle=True, num_workers=num_workers)
# Evaluation does not depend on order, so validation is not shuffled.
valid_loader = torch.utils.data.DataLoader(valid_ds, batch_size=batch_size,
                                           shuffle=False, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size,
                                          shuffle=False, num_workers=num_workers)

## Network Definition

In [None]:
# resnet18() defaults to 1000 outputs. CIFAR-10 needs one logit per class;
# with 1000 outputs, argmax can return an index >= len(class_names), which
# breaks class_names[...] lookups and the square confusion matrix later on.
model = resnet18(num_classes=len(class_names))
model = model.to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()

## Training

In [None]:
# np.Inf was removed in NumPy 2.0; the builtin float('inf') works everywhere.
valid_loss_min = float('inf')
train_epochs_loss = []
valid_epochs_loss = []
train_epochs_acc = []
valid_epochs_acc = []


def _run_epoch(loader, train):
    """Run `model` once over `loader`.

    When `train` is True the model is in training mode and the optimizer steps
    after every batch; otherwise it runs in eval mode with gradients disabled.

    Returns:
        (mean loss per sample, accuracy in percent) as Python floats, both
        averaged over ``len(loader.dataset)``.
    """
    model.train(train)
    loss_sum = 0.0
    n_correct = 0
    with torch.set_grad_enabled(train):
        for images, labels in tqdm(loader):
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = loss_fn(logits, labels)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            loss_sum += loss.item() * images.size(0)
            # Tensor reduction instead of Python's builtin sum(), which iterates
            # element by element.
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
    n_samples = len(loader.dataset)
    return loss_sum / n_samples, 100 * n_correct / n_samples


for epoch in range(1, n_epochs + 1):
    print('======================== Epoch: {} ========================'.format(epoch))
    # ========================= train model =====================
    train_loss, train_acc = _run_epoch(train_loader, train=True)
    train_epochs_loss.append(train_loss)
    train_epochs_acc.append(train_acc)
    print('Train loss: {:.3f}, Train acc: {:.1f}%'.format(train_loss, train_acc))

    # ========================= valid model =====================
    valid_loss, valid_acc = _run_epoch(valid_loader, train=False)
    valid_epochs_loss.append(valid_loss)
    valid_epochs_acc.append(valid_acc)
    print('Valid loss: {:.3f}, Valid acc: {:.1f}%'.format(valid_loss, valid_acc))

    # ========================= save model =====================
    # Checkpoint whenever the validation loss has not increased.
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.3f} --> {:.3f}).'.format(valid_loss_min, valid_loss))
        torch.save(model.state_dict(), 'model_cifar.pt')
        valid_loss_min = valid_loss

    # ========================= plot ==========================
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))
    ax_acc.plot(train_epochs_acc, '-o', label="train_acc")
    ax_acc.plot(valid_epochs_acc, '-o', label="valid_acc")
    ax_acc.set_title("Accuracy Graph by Epoch")
    ax_acc.legend()
    ax_loss.plot(train_epochs_loss, '-o', label="train_loss")
    ax_loss.plot(valid_epochs_loss, '-o', label="valid_loss")
    ax_loss.set_title("Loss Graph by Epoch")
    ax_loss.legend()
    plt.show()

## Testing

In [None]:
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load('model_cifar.pt', map_location=device))
num_classes = len(class_names)
cnf_matrix = torch.zeros(num_classes, num_classes)
test_loss_sum = 0.0
test_acc_num = 0
# store the number of correctly classified samples for each class
# and their total number
class_correct = [0] * num_classes
class_total = [0] * num_classes
# Store one well-classified image per class and its confidence value
# (None until an example for that class has been found)
correct_classfied_flags = [False] * num_classes
correct_classfied_samples = [None] * num_classes
correct_classfied_values = [0] * num_classes
correct_classfied_preds = [0] * num_classes
# Store one misclassified image per class and its confidence value
miss_classfied_flags = [False] * num_classes
miss_classfied_samples = [None] * num_classes
miss_classfied_values = [0] * num_classes
miss_classfied_preds = [0] * num_classes

model.eval()
for images, labels in tqdm(test_loader):
    images, labels = images.to(device), labels.to(device)
    with torch.no_grad():
        logits = model(images)
        loss = loss_fn(logits, labels)
        # Confidence = highest softmax probability of each sample.
        confidences = nn.functional.softmax(logits, dim=1).max(dim=1).values
    test_loss_sum += loss.item() * images.size(0)
    pred = logits.argmax(dim=1)
    correct_tensor = pred == labels
    # Iterate over the actual batch length: the final batch is smaller than
    # batch_size whenever the dataset size is not a multiple of it.
    for idx in range(labels.size(0)):
        label = labels[idx].item()   # true class = slot the example is stored in
        is_correct = correct_tensor[idx].item()
        class_correct[label] += int(is_correct)
        class_total[label] += 1
        if is_correct and not correct_classfied_flags[label]:
            correct_classfied_samples[label] = images[idx]
            correct_classfied_values[label] = confidences[idx].item()
            correct_classfied_preds[label] = pred[idx].item()
            correct_classfied_flags[label] = True
        elif not is_correct and not miss_classfied_flags[label]:
            miss_classfied_samples[label] = images[idx]
            miss_classfied_values[label] = confidences[idx].item()
            miss_classfied_preds[label] = pred[idx].item()
            miss_classfied_flags[label] = True
    test_acc_num += correct_tensor.sum().item()
    # CPU copies so the indices match the CPU-resident cnf_matrix.
    cnf_matrix = confusion_matrix(pred.cpu(), labels.cpu(), cnf_matrix)

test_loss = test_loss_sum / len(test_loader.dataset)
test_acc = 100 * test_acc_num / len(test_loader.dataset)
print('Test Loss: {:.3f}, Test Acc: {:.1f}%'.format(test_loss, test_acc))
print('Test Accuracy by Class:')
for i in range(num_classes):
    if class_total[i] > 0:
        print('{:8s}\t {:.1f}% ({:d}/{:d})'.format(class_names[i], 100 * class_correct[i] / class_total[i], class_correct[i], class_total[i]))
    else:
        print('Test Accuracy of %5s: N/A (no test examples)' % (class_names[i]))

In [None]:
plot_class_samples(correct_classfied_samples, correct_classfied_preds, correct_classfied_values, title='Well-classified Examples')

In [None]:
plot_class_samples(miss_classfied_samples, miss_classfied_preds, miss_classfied_values, title='Misclassified Examples')

In [None]:
# Raw counts (no row normalization), axis spines hidden.
plot_confusion_matrix(
    conf_matrix=cnf_matrix.numpy(),
    class_names=class_names,
    title='Confusion Matrix',
    ax_display=False,
    normalize=False,
)