In [None]:
import torch
import os
import numpy as np
import torchvision
import torch.nn.functional as F
from matplotlib import colors
from tqdm import tqdm
from torchvision import transforms
from PIL import Image
from torch import nn
from dataset import CUB
from utils import *

# Pick the GPU only when CUDA is actually usable. `is_available` has to be
# *called*: the bare function object is always truthy, which would select
# 'cuda' even on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Root of the downloaded and unzipped CUB dataset -- change this to your own path.
# The notebook also requires a model fine-tuned on CUB; use train.py to produce it.
data_root = '..'
# Architecture to load; the model cell handles 'vgg' and 'alexnet'.
model_name = 'vgg'
# Side length (in pixels) of the square network input.
image_size = 224
# Forwarded to get_samples when selecting the test images to visualise.
threshold = 0.1
# Fix the PyTorch RNG seed so repeated runs are reproducible.
torch.manual_seed(0)

In [None]:
# Build the architecture selected by `model_name` without pretrained weights;
# the fine-tuned CUB weights are loaded from disk further down.
if model_name == 'vgg':
    model = torchvision.models.vgg16_bn(pretrained = False)
elif model_name == 'alexnet':
    model = torchvision.models.alexnet(pretrained = False)
else:
    # Fail early with a clear message instead of a NameError on `model` below.
    raise ValueError(f"Unsupported model_name {model_name!r}; expected 'vgg' or 'alexnet'")

# Swap the last classifier layer for one with 200 outputs so the module's
# parameter shapes match the checkpoint that is loaded next.
model.classifier[6] = nn.Linear(4096, 200)

# map_location lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine (and places the tensors directly on `device` otherwise).
model.load_state_dict(torch.load(f'model/{model_name}_CUB.pth', map_location=device))
model = model.to(device)

# Inference mode: disables dropout and uses the stored batch-norm statistics.
model.eval()
testset = CUB(data_root, normalization=True, train_test='test')


In [None]:
def _show_and_save(background, path, saliency=None):
    """Display `background` (optionally with a saliency overlay) and save it to `path`.

    background: H x W x C array accepted by plt.imshow.
    path:       destination file for plt.savefig.
    saliency:   optional tensor; when given it is drawn half-transparent on top
                of the background with a diverging colormap centred on zero.
    """
    plt.imshow(background)
    if saliency is not None:
        plt.imshow(saliency.detach().cpu().squeeze().numpy(), cmap='bwr', alpha=0.5,
                   norm=colors.CenteredNorm(0))
    plt.axis('off')
    # Remove every margin so the saved file contains only the image itself.
    plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)
    plt.savefig(path, bbox_inches='tight', pad_inches=0)
    plt.show()
    # Release the figure in case the active backend did not do so on show();
    # this loop produces many figures.
    plt.close()


# Find samples where the third class prediction has a high relative probability
samples = get_samples(testset,model,threshold,max_no_samples=50,no_targets=3)
os.makedirs("./comparison_max_mean", exist_ok=True)

# Explanation variants compared for every target class, in display order.
# Each name is passed to get_saliency and also used as the file-name prefix.
explanations = ['original', 'weighted', 'mean', 'max']

for (n, y1, y2, y3) in tqdm(samples):
    # Fetch the sample once and reuse both the image and its label.
    sample = testset[n]
    image = sample[0].view(1, 3, image_size, image_size).to(device)
    label = sample[1]

    # Class probabilities and the three most probable classes (t1 = top class).
    with torch.no_grad():
        prob = torch.softmax(model(image),dim=1)[0]
        argsort=prob.argsort()
        t1 = argsort[-1]
        t2 = argsort[-2]
        t3 = argsort[-3]

    folder_path = "./comparison_max_mean/%s" % (str(n) + "_true_" + str(label))
    os.makedirs(folder_path, exist_ok=True)

    # Denormalised image in H x W x C layout, shared by every figure of this sample.
    # NOTE(review): assumes denorm returns a new tensor and leaves `image` untouched -- confirm.
    background = denorm(image).detach().cpu().numpy().squeeze().transpose((1,2,0))

    # Save input image
    _show_and_save(background, folder_path + "/input.png")

    print("Showing image:\t%d" %(n))
    print("True label:\t%d" %(label))
    print("Top class:\t%d,\tconfidence = %.3f\nSecond class:\t%d,\tconfidence = %.3f\nThird class:\t%d,\tconfidence = %.3f"\
              %(t1, prob[t1],t2, prob[t2],t3, prob[t3]))

    for mode in ['GC']:
        for t, t_string in [(t1, "t1"), (t2, "t2"), (t3, "t3")]:
            # Compute all variants before plotting, in the same order as they are shown.
            saliency_maps = {
                explanation: get_saliency(model, image, t, mode=mode, explanation=explanation)
                for explanation in explanations
            }

            for explanation in explanations:
                print(explanation.capitalize())
                path = ("%s/%s_%s_%s_%d_%.3f.png" % (folder_path, explanation, mode, t_string, t, prob[t]))
                _show_and_save(background, path, saliency=saliency_maps[explanation])

            # Differences between 'original' and 'mean' are very hard to spot by eye.
            # To inspect them, overlay
            #     (saliency_maps['original'] - saliency_maps['mean']).abs()
            # on the background, e.g. with cmap='jet'.