Reference: https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82

Grad-CAM Research paper: https://arxiv.org/pdf/1610.02391.pdf

ImageNet classes: https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/

In [None]:
# Imports
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from torchvision.models import vgg19
from torchvision import transforms
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt

In [None]:
# Download the sample image used throughout this notebook
from torchvision.datasets.utils import download_url

# Alternative sample images (assign one of these to `dataset_url` to try it):
#   "https://www.bestelectricbikes.com/wp-content/uploads/2021/03/Bike_Brake_Repair_Banner_Photo-1-1024x683.jpg"
#   "https://www.zhsydz.com/wp-content/uploads/2020/01/electric-commuter-bike-1000x500.jpg"
dataset_url = "https://cdn.britannica.com/q:60/82/212182-050-50D9F3CE/basketball-LeBron-James-Cleveland-Cavaliers-2018.jpg"

# Target directory; it sits under './data/', the root later handed to ImageFolder
data_dir = './data/test_folder/'

# Fetch the file at `dataset_url` into `data_dir`
download_url(url=dataset_url, root=data_dir)

In [None]:
# Choose the image to analyse from the download directory.
# The listing is sorted so the choice is deterministic, and the last entry is
# used when several files are present.
image_files = sorted(os.listdir(data_dir))
if not image_files:
    raise FileNotFoundError(f'No image found in {data_dir!r}; run the download cell first')
filename = image_files[-1]

# cv2.imread reports an unreadable file by returning None instead of raising,
# so check explicitly before touching `.shape`.
image = cv2.imread(os.path.join(data_dir, filename))
if image is None:
    raise ValueError(f'Could not decode {filename!r} as an image')

print('The image', filename, 'is of shape', image.shape)

In [None]:
# Per-channel mean and standard deviation of the ImageNet dataset,
# used to normalise the input.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# ImageNet-style preprocessing: resize to 224x224, convert to a tensor,
# then normalise each channel.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
transform = transforms.Compose(preprocessing_steps)

# Dataset rooted at './data/'; intended to hold just the one downloaded image
dataset = ImageFolder('./data/', transform=transform)

# Loader that yields that image as a batch of size 1, in order
dataloader = data.DataLoader(dataset=dataset, shuffle=False, batch_size=1)

In [None]:
# Sanity check: pull the first batch and show the image the network will see.
# img1 is the image batch, img2 the matching class-index batch.
img1, img2 = next(iter(dataloader))
dummy = img1[0]

# The tensor was normalised with the ImageNet mean/std in the transform, so its
# values fall outside [0, 1] and imshow would clip them into distorted colours.
# Undo the normalisation on a display-only copy; `dummy` itself is unchanged.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
displayable = (dummy * norm_std + norm_mean).clamp(0, 1)

# imshow expects (H, W, C); the tensor is (C, H, W)
plt.imshow(displayable.permute(1, 2, 0))

In [None]:
# Load the stock pretrained VGG19 and pull out its two stages for inspection
vgg = vgg19(pretrained=True)
features_conv = vgg.features
classifier = vgg.classifier

# Stand-alone 2x2 / stride-2 max-pool (padding, dilation and ceil_mode are
# left at their defaults: 0, 1 and False)
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

# Show the layer listing of the convolutional stage, then of the classifier
print(features_conv)
print(classifier)

In [None]:
class VGG(nn.Module):
    """Pretrained VGG19 split so Grad-CAM can reach the last conv activations.

    The feature extractor is cut after its first 36 layers; the remaining
    max-pool and the classifier are applied separately in ``forward`` so that
    a gradient hook can be attached to the activations in between.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        super().__init__()

        # get the pretrained VGG19 network
        self.vgg = vgg19(pretrained=True)

        # dissect the network to access its last convolutional layer:
        # keep the first 36 layers of the feature stem
        self.features_conv = self.vgg.features[:36]

        # get the max pool of the features stem (applied after the hook point)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

        # get the classifier of the vgg19
        # NOTE(review): the flattened pooled features are fed straight into the
        # classifier, so the input size must match what it expects
        # (presumably 224x224 images - confirm).
        self.classifier = self.vgg.classifier

        # placeholder for the gradients captured by the hook
        self.gradients = None

    def activations_hook(self, grad):
        """Store ``grad``, the gradient flowing back into the hooked activations."""
        self.gradients = grad

    def forward(self, x):
        """Run the network on a batch ``x`` and return the classifier output.

        While gradients are being tracked, a hook is registered on the output
        of ``features_conv`` so that a later ``backward()`` call fills
        ``self.gradients``.
        """
        x = self.features_conv(x)

        # register the hook; a tensor that does not require grad (e.g. inside
        # torch.no_grad()) cannot take a hook, so plain inference skips it
        if x.requires_grad:
            x.register_hook(self.activations_hook)

        # apply the remaining pooling
        x = self.max_pool(x)

        # flatten everything except the batch axis, so any batch size works
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def get_activations_gradient(self):
        """Return the gradients captured by the hook, or None if there are none yet."""
        return self.gradients

    def get_activations(self, x):
        """Return the output of ``features_conv`` for the batch ``x``."""
        return self.features_conv(x)

In [None]:
# Build the Grad-CAM wrapper and put it in evaluation mode
# (eval() hands back the module itself, so the two steps can be chained)
vgg = VGG().eval()

# Take the first batch from the dataloader; only the image part is needed
first_batch = next(iter(dataloader))
img = first_batch[0]

# Forward pass: the model's class scores for the image
pred = vgg(img)

In [None]:
# Label lookup comes from "https://github.com/nottombrown/imagenet-stubs"
import imagenet_stubs
from imagenet_stubs.imagenet_2012_labels import label_to_name

# Index of the highest score, kept as a tensor for later indexing
max_pred_index = pred.argmax()

# Plain-int index and its class name for the report
predicted_class = max_pred_index.item()
predicted_name = label_to_name(predicted_class)
print('Index of predicted class of ImageNet is:', predicted_class, ',', 'Class name:', predicted_name)

# Score the model assigned to that class
print(pred[0][max_pred_index])

In [None]:
# Back-propagate from the score of the predicted class; this is the pass
# during which the hook registered in VGG.forward stores the gradients
top_class_score = pred[:, max_pred_index]
top_class_score.backward()

In [None]:
# Retrieve the gradients the model stored during the backward pass
gradients = vgg.get_activations_gradient()

# Report their dimensions
print(gradients.size())

In [None]:
# Average the gradients over every axis except the channel axis (axis 1),
# leaving one importance weight per channel - expression 1 of the paper
pooled_gradients = gradients.mean(dim=(0, 2, 3))

# Keep the shape as a plain list; its first entry is the number of weights
pooled_gradients_shape = list(pooled_gradients.size())
num_weights = pooled_gradients_shape[0]
print('number of importance weights:', num_weights)

In [None]:
# Re-run the convolutional stem on the image to obtain the activation maps of
# the last convolutional layer, detached from the autograd graph
activations = vgg.get_activations(img).detach()

print('shape of activation map:', activations.size())

In [None]:
# weight the channels by corresponding gradients in one broadcast multiply:
# reshaping the weights to (1, C, 1, 1) lines them up with the channel axis
# (axis 1) of the 4-D activations, replacing a Python loop over channels
activations *= pooled_gradients.reshape(1, -1, 1, 1)

In [None]:
# Collapse the weighted activations into a single map by averaging over the
# channel axis, then drop the size-1 axes that remain
heatmap = activations.mean(dim=1).squeeze()
print(heatmap.size())

In [None]:
# relu on top of the heatmap - from this we get expression 2 of the paper.
# torch.clamp keeps the result a torch.Tensor, which the later steps need
# (torch.max, heatmap.numpy()), instead of routing a tensor through a NumPy ufunc.
heatmap = torch.clamp(heatmap, min=0)

In [None]:
# normalize the heatmap to [0, 1]; an all-zero map (max of 0) is left as is,
# because dividing by zero would fill it with NaNs
heatmap_max = torch.max(heatmap)
if heatmap_max > 0:
    heatmap /= heatmap_max

In [None]:
# Show the heatmap as a matrix plot (size-1 axes removed first)
heatmap_2d = heatmap.squeeze()
plt.matshow(heatmap_2d)

In [None]:
# Re-read the original image from the download directory (`data_dir`).
# NOTE: this rebinds `img` (previously the input tensor) and `heatmap`
# (previously a torch tensor), so the earlier cells must be re-run before
# executing this cell a second time.
img = cv2.imread(os.path.join(data_dir, filename))
if img is None:
    # cv2.imread returns None instead of raising when the file cannot be read
    raise ValueError(f'Could not read image {filename!r} from {data_dir!r}')

# Stretch the heatmap to the image size; cv2.resize takes (width, height)
heatmap = cv2.resize(heatmap.numpy(), (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LINEAR)

# Scale to 0..255 and colour it with the JET colour map
heatmap = np.uint8(255 * heatmap)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

# Blend: 40% heatmap added on top of the original image, then save
superimposed_img = heatmap * 0.4 + img
cv2.imwrite('./Grad-CAM_image.jpg', superimposed_img)

In [None]:
# Show the original image followed by the saved Grad-CAM overlay.
# `display` is imported explicitly so the cell does not rely on the notebook
# injecting it, and the image path uses `data_dir`, the download directory.
from IPython.display import Image, display

display(Image(os.path.join(data_dir, filename)))
display(Image('./Grad-CAM_image.jpg'))

## Operations Example

In [None]:
# Toy walk-through of the heatmap operations on a small tensor of shape
# (1, 6, 5, 2): mean over axis 1 -> ReLU -> normalise -> scale to uint8.
a = torch.tensor([[[[-2.,10],[4,40],[6,60],[8,80],[10,100]],
                   [[-3,30],[6,60],[9,90],[12,120],[15,150]],
                   [[-4,40],[8,80],[12,120],[16,160],[20,200]],
                   [[-5,50],[10,100],[15,150],[20,200],[25,250]],
                   [[-6,60],[12,120],[18,180],[24,240],[30,300]],
                   [[-7,70],[14,140],[21,210],[28,280],[35,350]]]])

# Print the slice at index 0 of axis 1, once per entry along axis 0.
# NOTE(review): the loop index is never used for indexing, so the same slice
# is printed each time - confirm whether iterating axis 1 was intended.
for _ in range(a.shape[0]):
    print(a[:, 0, :, :])
print(a.shape)

# Mean over axis 1; `c` keeps an untouched copy to compare against `b`
b = torch.mean(a, dim=1).squeeze()
c = torch.mean(a, dim=1).squeeze()
print(b)
print(c)

# ReLU with a torch op so the result stays a torch.Tensor
b = torch.clamp(b, min=0)
print(b)

# Normalise by the maximum (positive for this toy data)
b /= torch.max(b)
print(b)
plt.matshow(b.squeeze())

# Scale to 0..255 and convert explicitly to a NumPy uint8 array
d = (255 * b).numpy().astype(np.uint8)
print(d)