In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from tqdm import tqdm
# Fixed seed so that any stochastic step (e.g. shuffled DataLoaders) is reproducible.
SEED = 4701
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location lets a checkpoint that was saved on a GPU be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file (arbitrary code execution) - only load trusted checkpoints.
# NOTE(review): on torch>=2.6 `weights_only` defaults to True, which rejects a pickled
# nn.Module; pass weights_only=False there if this load fails.
model_ft = torch.load('model', map_location=device)
model_ft.to(device)
# The model is only used for inference below: disable dropout and make BatchNorm
# use its running statistics instead of per-batch statistics.
model_ft.eval()



In [None]:
# Training split, converted to tensors only (no normalisation), wrapped in a
# shuffled DataLoader; its per-channel statistics are then computed by
# `calc_mean_std` (presumably used later to normalise / un-normalise images).
# NOTE(review): `data_dir`, `batch_size` and `calc_mean_std` are not defined in
# any cell of this notebook - TODO define them here so Restart & Run All works.
train_dataset = datasets.ImageFolder(
    root=os.path.join(data_dir, 'train'),
    transform=transforms.Compose([transforms.ToTensor()]),
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

mean, std = calc_mean_std(train_dataloader)

In [None]:
def softmax(x, axis=0):
    """Numerically stable softmax of ``x`` along ``axis``.

    Softmax is invariant to adding a constant along ``axis``, so the maximum
    is subtracted before exponentiating. This keeps ``np.exp`` from
    overflowing to ``inf`` (and the result from becoming ``nan``) for large
    scores such as unnormalised network logits.

    Adapted from
    https://stackoverflow.com/questions/34968722/how-to-implement-the-softmax-function-in-python

    Args:
        x: array-like of scores.
        axis: axis to normalise over; defaults to 0, as before.

    Returns:
        ndarray with the same shape as ``x`` whose entries along ``axis`` are
        non-negative and sum to 1.
    """
    x = np.asarray(x)
    # keepdims=True so the reductions broadcast back against x for any rank.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)

In [None]:
def k_mismatches(model, dataloader, k, predicate=(lambda x, y: x != y)):
    """Collect up to ``k`` samples for which ``predicate(prediction, label)`` holds.

    Runs ``model`` over ``dataloader`` without gradient tracking and stops
    as soon as ``k`` hits have been found. With the default predicate the
    hits are misclassified samples. The caller is responsible for putting
    ``model`` in eval mode.

    Args:
        model: callable returning class scores with classes on dim 1.
        dataloader: iterable of ``(data, label)`` batches.
        k: number of hits to collect; if it is never reached, all hits are returned.
        predicate: called with the predicted and true class as 0-dim CPU tensors.

    Returns:
        List of ``(predicted, label, sample, probs)`` tuples. ``predicted`` and
        ``label`` are Python ints, ``sample`` is the input as a NumPy array, and
        ``probs`` is the softmax over the class dimension as a NumPy array.
        Relies on the notebook-level ``device``.
    """
    hits = []
    # Inference only: skipping the autograd graph saves memory and time.
    with torch.no_grad():
        for data, label in dataloader:
            output = model(data.to(device))
            # Stable softmax, computed once per batch rather than per sample.
            probs = torch.softmax(output, dim=1).cpu().numpy()
            _, prediction = torch.max(output, 1)
            # One device->host transfer per batch instead of one sync per .item().
            prediction = prediction.cpu()
            label = label.cpu()
            for i, (p, l) in enumerate(zip(prediction, label)):
                if predicate(p, l):
                    # `data` is the original host batch, so no extra transfer is needed.
                    hits.append((p.item(), l.item(), data[i].cpu().numpy(), probs[i]))
                    if len(hits) == k:
                        return hits
    return hits

In [None]:
def unnormalize(img, mean, std):
    """Undo per-channel normalisation of a channels-last image.

    Computes ``img * std + mean`` channel by channel. This is presumably the
    inverse of a ``transforms.Normalize(mean, std)`` applied upstream.

    Args:
        img: array whose LAST axis is the channel axis (at most 3 channels).
        mean: indexable per-channel means; only the first ``C`` entries are used.
        std: indexable per-channel standard deviations; same indexing as ``mean``.

    Returns:
        A new NumPy array with the same shape and dtype as ``img``. The input
        is not modified.

    Raises:
        ValueError: if ``img`` has more than 3 channels.
    """
    restored = np.array(img, copy=True)
    n_channels = restored.shape[-1]
    if n_channels > 3:
        raise ValueError('expected at most 3 channels, got {}'.format(n_channels))
    # Broadcast the statistics over the trailing channel axis. In-place ops keep
    # img's dtype (e.g. float32 stays float32).
    restored *= np.asarray(std[:n_channels], dtype=float)
    restored += np.asarray(mean[:n_channels], dtype=float)
    return restored

In [None]:
def display_k_mismatches(k, mismatches, save_path='figs/4mismatches.png', ncols=2):
    """Plot the first ``k`` entries of ``mismatches`` in a grid, save and show it.

    Each entry is read as ``(predicted, actual, image_chw, probs)``, matching
    the tuples built by ``k_mismatches``. Each panel is titled with the
    predicted and the true class and the model's confidence in each. Relies
    on the notebook-level ``classes``, ``mean``, ``std`` and ``unnormalize``.

    Args:
        k: number of entries to show; clipped to ``len(mismatches)``.
        mismatches: sequence of mismatch tuples as described above.
        save_path: where to save the figure (parent directories are created);
            ``None`` skips saving.
        ncols: number of panels per row.
    """
    n_shown = min(k, len(mismatches))
    nrows = max(1, math.ceil(n_shown / ncols))
    # squeeze=False always yields a 2-D axes array, even for a single row/column.
    fig, axs = plt.subplots(nrows, ncols, squeeze=False)
    fig.set_size_inches(5 * ncols, 5 * nrows)
    for idx, ax in enumerate(axs.flat):
        if idx >= n_shown:
            ax.axis('off')  # hide panels left over in the last row
            continue
        predicted, actual, image_chw, probs = mismatches[idx][:4]
        ax.set_title('Predicted: {} ({:.1f}% confidence)\nActually: {} ({:.1f}% confidence)'.format(
            classes[predicted],
            probs[predicted] * 100,
            classes[actual],
            probs[actual] * 100,
        ))
        # Stored sample is channels-first; imshow expects channels-last.
        ax.imshow(unnormalize(image_chw.transpose(1, 2, 0), mean, std))
    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path)
    # plt.show() works on inline and GUI backends; fig.show() warns on inline.
    plt.show()


# NOTE(review): `k` and `mismatches` are not defined in any cell of this notebook
# (presumably `mismatches = k_mismatches(model_ft, <validation loader>, k)`);
# TODO add that cell so the notebook survives Restart & Run All.
display_k_mismatches(k, mismatches)