In [13]:
'''
Author        : Aditya Jain
Date started  : 2nd February, 2022
About         : This notebook computes and saves the per-class ImageNet validation accuracy of a pre-trained ResNet50 model
'''

'\nAuthor        : Aditya Jain\nDate started  : 2nd February, 2022\nAbout         : This notebook computes and saves the per-class ImageNet validation accuracy of a pre-trained ResNet50 model\n'

In [22]:
import torchvision
from torchvision import models
from torchvision import datasets
from torchvision import transforms, utils
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, DataLoader

import torch
from torch import nn
import torch.optim as optim
from torchsummary import summary
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import time
import datetime
import sys
sys.path.append("..")

from data.imagenetvaldataset import ImagenetValDataset
from data.hardexdataset import HardExDataset

#### Loading pre-trained ResNet50 Model

In [15]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# ResNet50 with pre-trained weights, moved onto the selected device.
model = models.resnet50(pretrained=True)
model = model.to(device)

cuda


In [16]:
# ---- evaluation settings ----
image_resize = 224      # side length (pixels) of the square image fed to the model
batch_size   = 1

# ---- validation data locations ----
val_root_dir     = '/network/datasets/imagenet.var/imagenet_torchvision/val/'
val_label_list   = '/home/mila/a/aditya.jain/mothAI/selfsupervision/data/validation_imagenet_labels.csv'
val_convert_list = '/home/mila/a/aditya.jain/mothAI/selfsupervision/data/imagenet_modified_labels.csv'

# ---- preprocessing: resize the image to 224x224, convert to tensor, normalise each channel ----
resize_step    = transforms.Resize((image_resize, image_resize))
normalize_step = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transformer    = transforms.Compose([resize_step, transforms.ToTensor(), normalize_step])

# ---- dataset and loader over the validation images ----
val_data = ImagenetValDataset(
    val_root_dir,
    val_label_list,
    val_convert_list,
    transformer,
)
val_dataloader = DataLoader(val_data, batch_size=batch_size)

In [17]:
# Per-class tallies, filled in by the evaluation loop.
class_acc = dict()

# Every entry in this directory is treated as a class that has hard examples.
data_loc     = '/home/mila/a/aditya.jain/scratch/selfsupervise_data/hard_examples/'
hard_classes = os.listdir(data_loc)

# Class names, one per line; the list position is used as the integer class label.
with open("/home/mila/a/aditya.jain/mothAI/selfsupervision/data/imagenet_classes.txt", "r") as class_file:
    categories = [line.strip() for line in class_file]

In [18]:
# Evaluate the model on the validation set and tally, for every true class,
# how many samples were seen and how many were predicted correctly (top-1).
#
# Result: class_acc[class_name] = {'total_correct': int, 'total_samples': int}
#
# NOTE(review): tallies are keyed by class *name*; if `categories` contains the
# same name for two different labels, their counts are merged — confirm the
# class list has unique names.

class_acc = {}   # start from scratch so re-running this cell does not double count
model.eval()     # switch layers such as batch-norm/dropout to inference behaviour

with torch.no_grad():   # evaluation only: skip building the autograd graph
    for image_batch, label_batch in val_dataloader:

        image_batch, label_batch = image_batch.to(device), label_batch.to(device)
        prediction               = model(image_batch)
        _, index                 = torch.topk(prediction, 1)   # top-1 class per sample

        # One (true, predicted) integer label per sample in the batch. The first
        # label element of each sample is used, matching the previous `[0][0]`
        # indexing for a batch of one while no longer dropping the rest of a
        # larger batch.
        corr_labels = label_batch.reshape(label_batch.shape[0], -1)[:, 0].tolist()
        pred_labels = index[:, 0].tolist()

        for corr_label, pred_label in zip(corr_labels, pred_labels):
            corr_class_n = categories[corr_label]             # name of the true class

            stats = class_acc.setdefault(corr_class_n, {'total_correct': 0, 'total_samples': 0})
            stats['total_samples'] += 1                       # every sample counts once
            if corr_label == pred_label:
                stats['total_correct'] += 1
        

In [23]:
# Split the per-class accuracy into two tables and write each one to CSV.
hard_ex_data    = []     # classes for which we have hard examples
nonhard_ex_data = []     # classes for which we don't have hard examples

for class_name, counts in class_acc.items():
    # accuracy in percent, rounded to two decimals
    accuracy = round((counts['total_correct']/counts['total_samples'])*100, 2)
    row      = [categories.index(class_name), class_name, accuracy]

    if class_name in hard_classes:
        hard_ex_data.append(row)
    else:
        nonhard_ex_data.append(row)

save_dir     = '/home/mila/a/aditya.jain/mothAI/selfsupervision/data/'
column_names = ['PyTorch_ID', 'Name_ID', 'Accuracy']

data_df = pd.DataFrame(hard_ex_data, columns=column_names)
data_df.to_csv(save_dir + 'pretrain_valacc_hardex.csv', index=False)

data_df = pd.DataFrame(nonhard_ex_data, columns=column_names)
data_df.to_csv(save_dir + 'pretrain_valacc_nonhardex.csv', index=False)