In [1]:
import os
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
from utils import *
import torch
from torch import nn
from torchvision import models, transforms
import pickle
import pandas as pd
from tqdm import tqdm
tqdm.pandas()

In [2]:
# Split ImageNet-LT classes into head / medium ("median") / tail groups by how many
# samples each class has in label.csv. 'Unnamed: 0' is the unnamed first column that
# DataFrame.to_csv writes for the index; presumably it holds the class id — TODO confirm.
FEW_SHOT_MAX = 20    # label_count <= this -> tail
MANY_SHOT_MIN = 100  # label_count >= this -> head; strictly between the two -> median
# NOTE(review): the upstream `shot_acc` helper (many_shot_thr=100, low_shot_thr=20)
# appears to use strict `>` / `<`, so classes with exactly 20 or 100 samples would be
# medium-shot there but tail/head here. Confirm which split is intended.

df = pd.read_csv("/home/lizhaochen/fyp/fyp-long-tail-recognition/analysis/label.csv")
class_ids = df['Unnamed: 0']
counts = df.label_count
tail = class_ids[counts <= FEW_SHOT_MAX].tolist()
median = class_ids[(counts > FEW_SHOT_MAX) & (counts < MANY_SHOT_MIN)].tolist()
head = class_ids[counts >= MANY_SHOT_MIN].tolist()

In [3]:
# Build the backbone and the classifier from their standalone definition files.
feat_dim = 2048  # feature width handed from the backbone to the classifier

feat_model_path = './models/ResNet50Feature.py'
feat_model = source_import(feat_model_path).create_model()

cls_model_path = './models/DotProductClassifier.py'
cls_model = source_import(cls_model_path).create_model(feat_dim=feat_dim)

Loading Scratch ResNet 50 Feature Model.
No Pretrained Weights For Feature Model.
Loading Dot Product Classifier.
Random initialized classifier weights.


In [4]:
# Project checkout root. Set the FYP_ROOT environment variable to run elsewhere
# instead of editing the hardcoded default below.
PROJECT_ROOT = os.environ.get('FYP_ROOT', '/home/lizhaochen/fyp/fyp-long-tail-recognition')
EXP_NAME = 'ImageNet_LT_90_coslrres50'  # experiment whose checkpoint is visualised
model_path = os.path.join(PROJECT_ROOT, 'logs', 'ImageNet_LT', 'stage1', EXP_NAME, EXP_NAME + '.pth')

In [5]:
def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a *leading* ``prefix`` removed from each key.

    The checkpoint keys presumably carry a ``'module.'`` prefix from an
    ``nn.DataParallel`` wrapper. Only a leading occurrence is stripped, so a key that
    merely contains the substring (e.g. ``'submodule.weight'``) is left intact.
    """
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}


# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
model_state = checkpoint['state_dict_best']
feat_model.load_state_dict(_strip_module_prefix(model_state['feat_model']))
cls_model.load_state_dict(_strip_module_prefix(model_state['classifier']))

<All keys matched successfully>

In [6]:
def _open_rgb(path):
    """Open the image at ``path`` and return an RGB copy, closing the file handle.

    ``Image.open`` is lazy and keeps the file open; ``convert`` fully decodes into a
    new image, so the returned object no longer needs the underlying file.
    """
    with Image.open(path) as im:
        return im.convert('RGB')


# Image path -> normalized float tensor of shape (1, 3, 224, 224) for the feature model.
read_tensor = transforms.Compose([
    _open_rgb,
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    lambda x: torch.unsqueeze(x, 0),  # add a leading batch dimension
])

In [None]:
# Geometry-only preprocessing: the same resize + center crop as read_tensor, but the
# result stays a PIL image so it can be displayed under the saliency overlay.
crop_img = transforms.Compose(
    [transforms.Resize(size=256), transforms.CenterCrop(size=224)]
)

In [7]:
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) axis into a single one.

    Uses ``Tensor.view``, so the result shares storage with the input and the input
    must be view-compatible (e.g. contiguous).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)

In [8]:
# Switch both models to inference mode (eval() returns the module itself). The
# Sequential wrappers below reuse the same child modules, so they inherit this mode.
feat_model = feat_model.eval()
cls_model = cls_model.eval()

# Split the network for Grad-CAM. feat_model_final drops the last two children of
# feat_model, so its output still has spatial dims. cls_model_final re-applies the
# second-to-last child (presumably the global pooling layer — TODO confirm against
# ResNet50Feature), flattens, then runs the classifier's layers.
_feat_layers = list(feat_model.children())
feat_model_final = nn.Sequential(*_feat_layers[:-2])
cls_model_final = nn.Sequential(*_feat_layers[-2:-1], Flatten(), *cls_model.children())

In [9]:
_imagenet_class_dict = None  # lazily loaded index -> label-string table


def get_class_name(c):
    """Return the short class name (text before the first comma) for class index ``c``.

    ``c`` may be anything ``int()`` accepts, e.g. a Python number or a one-element
    tensor. The label table from ``imagenet_dict.py`` is loaded on the first call and
    then cached. The plotting loop calls this several times per image, so
    re-importing the file on every lookup is avoided.
    """
    global _imagenet_class_dict
    if _imagenet_class_dict is None:
        _imagenet_class_dict = source_import("imagenet_dict.py").class_dict
    return _imagenet_class_dict[int(c)].split(',')[0]

In [10]:
def GradCAM(img, c, features_fn, classifier_fn):
    """Compute a Grad-CAM saliency map for class ``c`` on a single image.

    Args:
        img: input passed unchanged to ``features_fn``.
        c: class index into the classifier output (int or scalar tensor).
        features_fn: maps ``img`` to a 4-D activation tensor (batch, channels, H, W).
            The batch must hold exactly one image, because the activations are
            viewed as (channels, H*W) below.
        classifier_fn: maps those activations to scores indexed as ``scores[0, c]``.

    Returns:
        numpy array of shape (H, W). It is the gradient-weighted sum of the
        activation channels, with negative values clipped to zero.
    """
    activations = features_fn(img)
    _, n_channels, height, width = activations.size()
    target_score = classifier_fn(activations)[0, c]
    (act_grad,) = torch.autograd.grad(target_score, activations)
    # One weight per channel: the spatially averaged gradient of the class score.
    channel_weights = act_grad[0].mean(-1).mean(-1)
    weighted = channel_weights @ activations.view(n_channels, height * width)
    heatmap = weighted.view(height, width).detach().cpu().numpy()
    return np.maximum(heatmap, 0)

In [22]:
MAX_PLOTS = 500        # stop after this many figures have been saved
TOP_K = 3              # number of top predicted classes visualised per image
IMAGES_PER_CLASS = 50  # assumes image_pths is ordered by class, 50 per class — TODO confirm


def _save_gradcam_figure(path, img_tensor, probs, classes, out_file):
    """Save a 1xK figure overlaying Grad-CAM maps for each (prob, class) pair.

    The background is the image at ``path`` after ``crop_img``. The figure is closed
    after saving so that many iterations do not accumulate open figures.
    """
    # Open once per image (not once per subplot) and close the file handle. Convert to
    # RGB like read_tensor does; otherwise grayscale images get colormapped and CMYK
    # images are read as RGBA by imshow.
    with Image.open(path) as raw:
        img = crop_img(raw.convert('RGB'))
    fig, axes = plt.subplots(1, len(classes), figsize=(15, 5), squeeze=False)
    for ax, p, c in zip(axes[0], probs, classes):
        sal = GradCAM(img_tensor, c, feat_model_final, cls_model_final)
        # Image.LINEAR was removed in Pillow 10; BILINEAR is the same filter.
        sal = Image.fromarray(sal).resize(img.size, resample=Image.BILINEAR)
        ax.set_title('{}: {:.1f}%'.format(get_class_name(c), 100 * float(p)))
        ax.axis('off')
        ax.imshow(img)
        ax.imshow(np.array(sal), alpha=0.5, cmap='jet')
    fig.savefig(out_file)
    plt.close(fig)


# The path / prediction dumps sit next to the checkpoint and share its file stem.
# NOTE: pickle.load can execute arbitrary code — only load files you produced.
exp_prefix = os.path.splitext(model_path)[0]
with open(exp_prefix + "_path_test.pkl", "rb") as f:
    image_pths = list(pickle.load(f))
with open(exp_prefix + "_preds_softmax.pkl", "rb") as f:
    preds = list(pickle.load(f))  # not used in this cell

plot_count = 0
correct = 0  # plotted images whose top-1 prediction matches the label
for idx, path in tqdm(enumerate(image_pths), total=len(image_pths)):
    label = idx // IMAGES_PER_CLASS
    if label not in median:
        continue
    class_name = get_class_name(label)
    plot_count += 1
    img_tensor = read_tensor(path)
    # Only the ranking is needed here; GradCAM builds its own autograd graph.
    with torch.no_grad():
        logits = cls_model(feat_model(img_tensor)[0])
    pp, cc = torch.topk(nn.Softmax(dim=1)(logits), TOP_K)
    # Compare class indices, not names: short names are not unique in the usual
    # ImageNet label list (e.g. "crane"), so a name match can be a false positive.
    correct_flag = int(cc[0, 0]) == label
    correct += correct_flag

    dirname = "gradcam_images/{}_{}_{}".format(class_name, label, 'median')
    os.makedirs(dirname, exist_ok=True)  # also creates gradcam_images/ if missing
    _save_gradcam_figure(path, img_tensor, pp[0], cc[0],
                         os.path.join(dirname, "{}_{}.jpeg".format(idx, str(correct_flag))))
    if plot_count == MAX_PLOTS:
        break
    

1049it [03:33,  4.91it/s]
