In [1]:
import torch
import numpy as np
from matplotlib.pyplot import cm 
import matplotlib.pyplot as plt 

import nbimporter
from q4_imagenet_finetune_pascal import PretrainedResNet
from q2_caffenet_pascal import CaffeNet
from torch.utils.data import DataLoader
from sklearn.neighbors import NearestNeighbors
from PIL import Image
import os
from voc_dataset import VOCDataset

In [2]:
def generate_rand(index_list, annotation_list, num_samples=3):
    """Randomly pick dataset indices whose positive classes are mutually disjoint.

    Indices are drawn uniformly with ``np.random.randint`` until ``num_samples``
    of them have been accepted. A draw is accepted only if none of its positive
    classes (entries equal to 1 in ``annotation_list[idx][0]``) was already used
    by a previously accepted draw.

    Args:
        index_list: Sequence of dataset items; only its length is used.
        annotation_list: Per-item annotations; ``annotation_list[idx][0]`` is
            assumed to be a 0/1 class-presence vector (TODO confirm against
            the dataset's ``anno_list`` layout).
        num_samples: How many indices to return (default 3).

    Returns:
        List of ``num_samples`` accepted indices, in the order they were drawn.

    Note:
        Loops forever if no further class-disjoint item can be drawn (e.g. the
        data has fewer than ``num_samples`` disjoint class sets).
    """
    taken_classes = set()
    req_idx = []
    while len(req_idx) < num_samples:
        curr_random = np.random.randint(0, len(index_list))
        curr_classes = set(np.where(annotation_list[curr_random][0] == 1)[0].tolist())
        # Check every class before reserving any. A rejected draw must not
        # block classes that no accepted sample actually uses.
        if curr_classes & taken_classes:
            continue
        taken_classes |= curr_classes
        req_idx.append(curr_random)
    return req_idx

In [3]:
def resnet_features(x):
    """Run ``x`` through the fine-tuned ResNet backbone and return its activations.

    A fresh ``PretrainedResNet`` is built and loaded from
    ``checkpoint-resnet18_pretrained-epoch10.pth`` on every call. The first nine
    child modules of ``model.resnet.resnet`` are then applied in registration
    order. For a torchvision ResNet-18 these are presumably conv1 through
    avgpool; TODO confirm.

    Args:
        x: Input image batch tensor; it is moved to the selected device.

    Returns:
        The activation tensor after the ninth child module, on the model's
        device, computed without an autograd graph.
    """
    torch.cuda.empty_cache()
    # is_available must be *called*. The bare function object is always
    # truthy, which forced 'cuda' even on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    state = torch.load('checkpoint-resnet18_pretrained-epoch10.pth', map_location=device)
    model = PretrainedResNet()
    model.to(device)
    model.load_state_dict(state)
    model.eval()
    x = x.to(device)
    # This is pure feature extraction: no_grad avoids keeping activations for backprop.
    with torch.no_grad():
        for i, module in enumerate(model.resnet.resnet._modules.values()):
            x = module(x)
            if i == 8:
                break
    return x

In [4]:
def caffenet_features(x, pool5=True):
    """Extract intermediate CaffeNet activations for a batch.

    A fresh ``CaffeNet`` is built and loaded from
    ``checkpoint-caffenet-epoch50.pth`` on every call.

    Args:
        x: Input image batch tensor; it is moved to the selected device.
        pool5: If True, return ``model.forward_analysis_pool5(x)``; otherwise
            return ``model.forward_analysis_fc7(x)``.

    Returns:
        The tensor produced by the chosen analysis method, on the model's
        device, computed without an autograd graph.
    """
    torch.cuda.empty_cache()
    # is_available must be *called*. The bare function object is always truthy.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    state = torch.load('checkpoint-caffenet-epoch50.pth', map_location=device)
    model = CaffeNet()
    model.to(device)
    model.load_state_dict(state)
    model.eval()
    x = x.to(device)
    with torch.no_grad():
        if pool5:
            return model.forward_analysis_pool5(x)
        return model.forward_analysis_fc7(x)

In [5]:
def calculate_features(test_loader, batch_size=256):
    """Extract ResNet and CaffeNet pool5 features for every item of a dataset.

    Args:
        test_loader: A dataset (despite the name) whose items unpack as
            ``(image, label, weight)``. It is wrapped in an unshuffled
            ``DataLoader``, so feature row ``k`` corresponds to dataset item ``k``.
        batch_size: Images per forward pass (default 256).

    Returns:
        ``(resnet_feats, caffenet_feats)``: 2-D arrays with one row per item.
        The stacking seeds assume widths of 512 and 6400 (TODO confirm these
        match the checkpoints). An empty dataset yields ``(0, 512)`` and
        ``(0, 6400)`` float32 arrays.
    """
    test_data_loader = DataLoader(test_loader, batch_size, False)
    resnet_chunks, caffenet_chunks = [], []

    # Collect per-batch chunks and stack once at the end. Stacking inside the
    # loop would copy the growing array every batch.
    for data, _, _ in test_data_loader:
        n = data.shape[0]
        resnet_chunks.append(resnet_features(data).reshape(n, -1).cpu().detach().numpy())
        # Flatten per row so a non-2-D pool5 tensor still stacks correctly.
        caffenet_chunks.append(caffenet_features(data).reshape(n, -1).cpu().detach().numpy())

    total_resnet_features = np.vstack([np.zeros((0, 512), dtype=np.float32), *resnet_chunks])
    total_caffenet_features = np.vstack([np.zeros((0, 6400), dtype=np.float32), *caffenet_chunks])
    return total_resnet_features, total_caffenet_features

In [6]:
def find_neighbors(total_features, req_three_idx, n_neighbors=5):
    """Find the nearest neighbours (within ``total_features``) of selected rows.

    The backbone does not matter: this works for any feature matrix, such as
    ResNet or CaffeNet features.

    Args:
        total_features: 2-D array, one feature vector per dataset item.
        req_three_idx: Iterable of row indices to query.
        n_neighbors: Neighbours to return per query (default 5). Because each
            query point is itself in the fitted set, the first column is
            normally the query's own index.

    Returns:
        Integer array of shape ``(len(req_three_idx), n_neighbors)`` holding
        neighbour row indices, nearest first.
    """
    query_idx = np.asarray(list(req_three_idx), dtype=np.intp)
    if query_idx.size == 0:
        return np.zeros((0, n_neighbors), dtype=np.int16)
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(total_features)
    # A single batched query gives the same rows as querying one index at a time.
    _, indices = nbrs.kneighbors(np.asarray(total_features)[query_idx])
    return indices

In [7]:
def show_image(resnet_neighbor_indices, test_loader):
    """Display each query image followed by its retrieved neighbours.

    Args:
        resnet_neighbor_indices: 2-D array of dataset indices, one row per query.
            ``row[0]`` is shown as the original image. It is presumably the
            query itself, since a point is its own nearest neighbour. The
            remaining entries are shown as neighbours 1, 2, ...
        test_loader: Dataset exposing ``index_list``, ``img_dir``, ``anno_list``
            and ``CLASS_NAMES``.
    """
    for row in resnet_neighbor_indices:
        print("\n --- NEW CLASS ---")
        for rank, val in enumerate(row):
            idx = int(val)
            # Print each header before its own image, so no header is left
            # dangling after the last image.
            if rank == 0:
                print("Original Image")
            else:
                print("Neighbour: ", rank)

            class_ids = np.where(test_loader.anno_list[idx][0] == 1)[0]
            print("Image classes are")
            for class_id in class_ids:
                print(test_loader.CLASS_NAMES[class_id])

            fpath = os.path.join(test_loader.img_dir, test_loader.index_list[idx] + '.jpg')
            # Close the file handle once the half-size copy has been made.
            with Image.open(fpath) as img:
                small = img.resize((img.size[0] // 2, img.size[1] // 2))
            display(small)


In [8]:
def find_caffenet_features(test_loader, num_samples=1000, batch_size=250):
    """Collect CaffeNet fc7 features and labels for a random subset of a dataset.

    Args:
        test_loader: Dataset (despite the name) whose items unpack as
            ``(image, label, weight)``. It is shuffled, so the subset differs
            between runs unless the torch RNG is seeded.
        num_samples: Maximum number of rows to return (default 1000).
        batch_size: Images per forward pass (default 250).

    Returns:
        ``(features, labels)``: fc7 features stacked onto a float32 ``(0, 4096)``
        seed, and label rows stacked onto a float64 ``(0, 20)`` seed. Both
        have ``min(num_samples, len(test_loader))`` rows.
    """
    test_data_loader = DataLoader(test_loader, batch_size, True)
    feature_chunks, label_chunks = [], []
    collected = 0

    for data, label, _ in test_data_loader:
        feature_chunks.append(caffenet_features(data, False).cpu().detach().numpy())
        label_chunks.append(np.asarray(label))
        collected += data.shape[0]
        if collected >= num_samples:
            break

    # Truncate in case batch_size does not divide num_samples evenly.
    total_caffenet_features = np.vstack(
        [np.zeros((0, 4096), dtype=np.float32), *feature_chunks])[:num_samples]
    total_label = np.vstack([np.empty((0, 20)), *label_chunks])[:num_samples]
    return total_caffenet_features, total_label

In [9]:
def find_mean_label(label, test_loader, color):
    """Summarise a multi-hot label row as (count, names, blended colour).

    Args:
        label: 1-D multi-hot vector; positions equal to 1 are positive classes.
        test_loader: Object exposing ``CLASS_NAMES`` (indexable by class id).
        color: Per-class colour table (e.g. RGBA rows from ``cm.rainbow``);
            row ``k`` is the colour of class ``k``.

    Returns:
        ``(n, names, mean_color)``. ``n`` is the number of positive classes.
        ``names`` joins their names with ``", "``. ``mean_color`` is the
        element-wise mean of their colour rows. A row with no positive class
        gets an all-zero colour instead of NaN.
    """
    class_ids = np.flatnonzero(np.asarray(label) == 1)
    color = np.asarray(color, dtype=float)
    label_names = ", ".join(test_loader.CLASS_NAMES[k] for k in class_ids)
    if class_ids.size == 0:
        # Avoid 0/0 → NaN, which matplotlib cannot use as a colour.
        return 0, label_names, np.zeros(color.shape[-1])
    return int(class_ids.size), label_names, color[class_ids].mean(axis=0)

In [10]:
def plot_features(tsne_projection, label, test_loader):
    """Scatter a 2-D feature projection, coloured by (blended) class colour.

    Each point is coloured with the mean colour of its positive classes, as
    computed by ``find_mean_label``. Single-class points get one legend entry
    per class name. Multi-label points are drawn without legend entries.

    Args:
        tsne_projection: ``(N, 2)`` array of projected coordinates.
        label: ``(N, C)`` multi-hot label rows aligned with the projection.
        test_loader: Object exposing ``CLASS_NAMES``.
    """
    fig, ax = plt.subplots(1)
    fig.set_size_inches(18.5, 10.5, forward=True)
    color = cm.rainbow(np.linspace(0, 1, len(test_loader.CLASS_NAMES)))
    points = np.asarray(tsne_projection)

    # Group by class *name*. Deduplicating by the float mean of the colour
    # could merge two different classes whose colours share a mean.
    single_class = {}  # class name -> (point indices, colour)
    multi_idx, multi_colors = [], []
    for i in range(points.shape[0]):
        number_of_classes, label_names, mean_color = find_mean_label(label[i], test_loader, color)
        if number_of_classes == 1:
            single_class.setdefault(label_names, ([], mean_color))[0].append(i)
        else:
            multi_idx.append(i)
            multi_colors.append(mean_color)

    # One scatter call per group instead of one per point.
    if multi_idx:
        ax.scatter(points[multi_idx, 0], points[multi_idx, 1], color=np.array(multi_colors))
    for name, (idx, class_color) in single_class.items():
        ax.scatter(points[idx, 0], points[idx, 1], color=class_color, label=name)

    ax.set_title("2-D feature projection coloured by class")
    ax.set_xlabel("component 1")
    ax.set_ylabel("component 2")
    if single_class:
        ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)

In [11]:
def plot_metrics(report, resnet=True):
    """Scatter one (recall, precision) point per entry of ``report``.

    Args:
        report: Indexable container iterated for its keys. ``report[key][1]``
            is plotted as recall and ``report[key][2]`` as precision; this
            layout is assumed, so TODO confirm it against whatever builds
            ``report``.
        resnet: Selects the plot title (ResNet if True, CaffeNet otherwise).
    """
    fig, ax1 = plt.subplots(1)
    keys = list(report)
    # Keep the original 20-colour palette, but grow it when report has more
    # entries; indexing past 20 used to raise IndexError.
    color = cm.rainbow(np.linspace(0, 1, max(20, len(keys))))
    ax1.set_xlabel("recall")
    ax1.set_ylabel("precision")
    if resnet:
        ax1.set_title("precision vs recall for resnet")
    else:
        ax1.set_title("precision vs recall for caffenet")
    for i, key in enumerate(keys):
        ax1.scatter(report[key][1], report[key][2], color=color[i], label=key)
    if keys:
        ax1.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)

In [12]:
def test_metrics(test_dataset, model_name='resnet', batch_size=128, threshold=0.5):
    """Predict multi-label outputs for a dataset and return them with the labels.

    Args:
        test_dataset: Dataset whose items unpack as ``(image, label, weight)``;
            iterated in order (no shuffling).
        model_name: ``'resnet'`` loads ``PretrainedResNet`` from
            ``checkpoint-resnet18_pretrained-epoch10.pth``. Any other value
            loads ``CaffeNet`` from ``checkpoint-caffenet-epoch50.pth``.
        batch_size: Images per forward pass (default 128).
        threshold: A class is predicted (1) when ``sigmoid(logit) >= threshold``
            (default 0.5).

    Returns:
        ``(total_label, total_pred)``: ground-truth label rows and 0/1
        prediction rows, aligned by dataset order. Each is stacked onto an
        int16 ``(0, 20)`` seed.
    """
    torch.cuda.empty_cache()
    # is_available must be *called*. The bare function object is always truthy.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if model_name == 'resnet':
        checkpoint, model = 'checkpoint-resnet18_pretrained-epoch10.pth', PretrainedResNet()
    else:
        checkpoint, model = 'checkpoint-caffenet-epoch50.pth', CaffeNet()
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    state = torch.load(checkpoint, map_location=device)
    model.to(device)
    model.load_state_dict(state)
    model.eval()

    label_chunks, pred_chunks = [], []
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # Inference only: no_grad stops autograd state from accumulating across batches.
    with torch.no_grad():
        for data, label, _ in test_dataloader:
            logits = model(data.to(device)).cpu()
            probs = torch.sigmoid(logits).numpy()
            pred_chunks.append(np.where(probs >= threshold, 1, 0))
            label_chunks.append(label.numpy())

    total_label = np.vstack([np.zeros((0, 20), dtype=np.int16), *label_chunks])
    total_pred = np.vstack([np.zeros((0, 20), dtype=np.int16), *pred_chunks])
    return total_label, total_pred