# Pipeline for Training and Analysis of 20 2x100 FC models

## Imports and Hyperparameters

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import Subset

from src.cifar.models import *
from src.util import split_train_val, test, train, save_model

import matplotlib.pyplot as plt

import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# this should print 'cuda' if you are assigned a GPU
print(device)

train_batch_size = 100
test_batch_size = 100
n_epochs = 5
learning_rate = 1e-2
seed = 100
input_dim = 32*32*3
out_dim = 10
num_hidden_layers = 2
layer_size = 100
momentum = 0.9
weight_decay_lam = 1e-4

fc_model_params = [(2,100)]*20

complex_models = []

cuda


## Load Data

In [2]:
transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])
train_dataset = torchvision.datasets.CIFAR10('./datasets/', train=True, download=True, transform=transforms)
test_dataset = torchvision.datasets.CIFAR10('./datasets/', train=False, download=True, transform=transforms)

raw_test_data = torchvision.datasets.CIFAR10('./datasets/', train=False, download=True, transform=torchvision.transforms.ToTensor())

# sanity check
print('training data size:{}'.format(len(train_dataset)))
print('test data size:{}'.format(len(test_dataset)))

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
training data size:50000
test data size:10000


In [3]:
train_dataset, val_dataset = split_train_val(train_dataset, valid_ratio=1/6)
print('training data size:{}'.format(len(train_dataset)))
print('validation data size:{}'.format(len(val_dataset)))

training data size:41667
validation data size:8333


## Loaders

In [4]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=train_batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
raw_test_loader = torch.utils.data.DataLoader(raw_test_data, batch_size=test_batch_size, shuffle=False)

# sanity check
print('training data size:{}'.format(len(train_loader.dataset)))
print('validation data size:{}'.format(len(val_loader.dataset)))
print('test data size:{}'.format(len(test_loader.dataset)))

training data size:41667
validation data size:8333
test data size:10000


## Create Networks

In [5]:
fc_models = [FC(input_dim, out_dim, num_hidden_layers, layer_size) for num_hidden_layers, layer_size in fc_model_params]
complex_models = load_pretrained_models(complex_models)

## Train Networks

In [6]:
%%capture
for i, model in enumerate(fc_models):
    model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay_lam)

    print("Training FC model {}".format(i+1))

    for epoch in range(1, n_epochs + 1):
        train(model, train_loader, optimizer, epoch, device)

    test(model, val_loader, device)

    print("Saving FC model: {}".format(model))
    save_model(model, dataset="CIFAR10", filename="FC" + str(i))

for i, model in enumerate(os.listdir("./models/CIFAR10/")):
    if model.endswith(".pth") and "FC" in model:
        fc_models[i].load_state_dict(torch.load("./models/CIFAR10/" + model))
        with torch.no_grad():
            test(fc_models[i].to(device), test_loader, device)

## Missclassification Tracking

Images tracked as tuples (batch_idx, image_idx) and can be accessed from dataset as test_dataset[batch_idx][image_idx]

In [7]:
misses = dict()
complex_list = list(complex_models.values())
for model in fc_models + complex_list:

    model.to(device)
    model.eval()

    with torch.no_grad():
        
        for i, (data, target) in enumerate(test_loader):

            data = data.to(device)
            target = target.to(device)

            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)

            missed = pred.not_equal(target.data.view_as(pred)).view(-1).cpu().numpy()

            for j, miss in enumerate(missed):
                if miss:
                    if (i, j) in misses:
                        misses[(i, j)] += 1
                    else:
                        misses[(i, j)] = 1

print("Missed {}/{}".format(sum(misses.values()), len(test_loader.dataset)))

Missed 4871/10000
Missed 9894/10000
Missed 14886/10000
Missed 19874/10000
Missed 24849/10000
Missed 29735/10000
Missed 34678/10000
Missed 39683/10000
Missed 44589/10000
Missed 49469/10000
Missed 54380/10000
Missed 59297/10000
Missed 64222/10000
Missed 69133/10000
Missed 74166/10000
Missed 79189/10000
Missed 84178/10000
Missed 89145/10000
Missed 94159/10000
Missed 99073/10000


In [8]:
all_miss = {k: v for k, v in misses.items() if v > 0}
significant = {k: v for k, v in misses.items() if v > 19}
print(len(significant))
print(len(all_miss))

1634
8277


In [14]:
y = list(all_miss.values()) + [0]*(len(test_dataset)-len(all_miss))

fig = plt.figure(figsize=(12, 8), facecolor="w")
plt.hist(y, [x-0.5 for x in range(22)], edgecolor="k")
plt.xlabel("Number of Missclassifications")
plt.ylabel("Frequency")
plt.savefig("plots/CIFAR10/im_freq.png")
plt.close()

In [None]:
"""for sample in significant:

    idx = sample[0]*test_batch_size + sample[1]
    im_net = test_loader.dataset[idx][0].view(3, 32, 32)
    im_raw = raw_test_loader.dataset[idx][0].view(3, 32, 32)
    label = test_loader.dataset[idx][1]

    preds = np.array([model(im_net.unsqueeze(0).to(device)).argmax(dim=1, keepdim=True).cpu().numpy() for model in fc_models + complex_list]).flatten()
    fig = plt.figure(figsize=(12, 8), facecolor="w")
    plt.hist(preds, [x-0.5 for x in range(11)], edgecolor="k")
    plt.xlabel("Predicted Label")
    plt.ylabel("Frequency")
    plt.title("True Label: {}".format(label))
    plt.savefig("plots/CIFAR10/images/im_" + str(idx) + "_freq.png")
    plt.close()

    fig = plt.figure()
    plt.imshow(im_raw.squeeze().cpu().numpy().transpose((1,2,0)))
    plt.title("Actual: {}".format(test_loader.dataset.targets[idx]))
    plt.savefig("plots/CIFAR10/images/im_" + str(idx) + ".png")
    plt.close()"""

'for sample in significant:\n\n    idx = sample[0]*test_batch_size + sample[1]\n    im_net = test_loader.dataset[idx][0].view(3, 32, 32)\n    im_raw = raw_test_loader.dataset[idx][0].view(3, 32, 32)\n    label = test_loader.dataset[idx][1]\n\n    preds = np.array([model(im_net.unsqueeze(0).to(device)).argmax(dim=1, keepdim=True).cpu().numpy() for model in fc_models + complex_list]).flatten()\n    fig = plt.figure(figsize=(12, 8), facecolor="w")\n    plt.hist(preds, [x-0.5 for x in range(11)], edgecolor="k")\n    plt.xlabel("Predicted Label")\n    plt.ylabel("Frequency")\n    plt.title("True Label: {}".format(label))\n    plt.savefig("plots/CIFAR/images/im_" + str(idx) + "_freq.png")\n    plt.close()\n\n    fig = plt.figure()\n    plt.imshow(im_raw.squeeze().cpu().numpy().transpose((1,2,0)))\n    plt.title("Actual: {}".format(test_loader.dataset.targets[idx]))\n    plt.savefig("plots/CIFAR/images/im_" + str(idx) + ".png")\n    plt.close()'

# Changes

- add 0 bin on hist
- try cifar10.1
- regularize FCs (remove overfit)
    - weight decay
    - dropout (?)