In [None]:
import torch
import torchvision
import numpy as np
import cv2
from matplotlib import pyplot as plt
from PIL import Image

from torchvision.transforms.transforms import ToPILImage
from torchvision import transforms

In [None]:
# Load ResNet-50 with pretrained ImageNet weights.
# The `pretrained=True` flag is deprecated in torchvision; IMAGENET1K_V1 is the
# checkpoint that flag selected, so the loaded weights stay the same.
model = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
)

# Switch to evaluation mode: layers such as BatchNorm use their inference
# behaviour. This does NOT freeze the weights -- gradients are still computed,
# which the backward pass further down relies on.
model.eval()

print(model)

In [None]:
# Pick the layer to inspect: the final block of the network's `layer4` stage
target_layer = model.layer4[len(model.layer4) - 1]

In [None]:
# Load an image. OpenCV returns an HxWx3 uint8 array in BGR channel order,
# or None (without raising) when the file is missing or unreadable.
IMAGE_PATH = "images/dog/cat.jpg"
pic = cv2.imread(IMAGE_PATH, cv2.IMREAD_COLOR)
if pic is None:
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")

# BGR to RGB (cvtColor returns a new contiguous array; `pic` stays BGR
# because the overlay cell further down still works on it)
rgb = cv2.cvtColor(pic, cv2.COLOR_BGR2RGB)

# Convert to a float tensor in [0, 1] with layout CxHxW, then apply the
# ImageNet mean/std normalisation the pretrained weights were trained with.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Add batch dimension -> 1xCxHxW
img = preprocess(rgb).unsqueeze(0)

In [None]:
def forward_hook(module, input, output):
    """Forward hook: store the output tensor of the hooked module."""
    activation.append(output)


def backward_hook(module, grad_in, grad_out):
    """Backward hook: store the gradient w.r.t. the hooked module's output."""
    grad.append(grad_out[0])


# Create the containers before registering the hooks, and recreate them on
# every run so a re-executed cell never reads tensors from an earlier run.
grad = []
activation = []

# Add hooks to get the tensors.
# `register_full_backward_hook` replaces the deprecated
# `register_backward_hook`, whose reported gradients are not reliable for
# modules that perform more than one operation.
forward_handle = target_layer.register_forward_hook(forward_hook)
backward_handle = target_layer.register_full_backward_hook(backward_hook)

try:
    # forward pass to get the activation
    out = model(img)
    print("TOP 5", torch.topk(out, 5))

    # class for dog
    loss = out[0, 178]

    # class for cat
    #loss = out[0, 285]
    print("LOSS", loss.item())

    # clear the gradients
    model.zero_grad()

    # backward pass to get the gradients
    loss.backward()
finally:
    # Always detach the hooks; otherwise each re-run of this cell stacks
    # another pair of hooks on the layer.
    forward_handle.remove()
    backward_handle.remove()

# get the gradients and activations collected in the hooks as numpy arrays
# (squeeze drops the size-1 batch dimension)
grads = grad[0].detach().cpu().numpy().squeeze()
fmap = activation[0].detach().cpu().numpy().squeeze()

print(fmap.shape)
print(grads.shape)

In [None]:
print("grads.shape", grads.shape)

# Flatten everything after the first axis so each feature map's gradient
# becomes a single row
num_maps = grads.shape[0]
grads_flat = grads.reshape(num_maps, -1)

# One weight per feature map: the mean of that map's gradient values
weights = grads_flat.mean(axis=1)
print("weights.shape", weights.shape)

In [None]:
# Raw class activation map: accumulate every feature map scaled by its weight
cam = np.zeros(grads.shape[1:])
for channel in range(len(weights)):
    cam += weights[channel] * fmap[channel]

In [None]:
# Grad-CAM map: weighted sum of all feature maps
cam = np.zeros(grads.shape[1:])

for i, w in enumerate(weights):
    cam += w * fmap[i, :]

# ReLU is applied once, to the complete sum. Clipping the partial sums inside
# the loop would make the result depend on the order of the channels.
cam = np.maximum(cam, 0)

# Scale to [0, 255] for colour mapping. Skip the division when the map is all
# zeros, which would otherwise produce NaNs (0 / 0).
peak = cam.max()
if peak > 0:
    cam = cam / peak * 255

In [None]:
# Round-trip the loaded picture through PIL to get a plain 3-channel array
print("PIC SHAPE", pic.shape)
pil_pic = ToPILImage()(pic).convert('RGB')
npic = np.array(pil_pic)
print("NPIC SHAPE", npic.shape)

# Upsample the activation map to the picture size (cv2 wants (width, height))
height, width = npic.shape[0], npic.shape[1]
cam = cv2.resize(cam, (width, height))
print("CAM SHAPE", cam.shape)

# Colour-code the map and blend it over the picture (30% picture, 70% heatmap)
heatmap = cv2.applyColorMap(cam.astype(np.uint8), cv2.COLORMAP_JET)
cam_img = npic*0.3 + heatmap*0.7
print(cam_img.shape)

# Reverse the channel axis before handing the blend to PIL for display
blend_flipped = np.uint8(cam_img[:, :, ::-1])
display(ToPILImage()(blend_flipped))

In [None]:
# cv2.applyColorMap produces BGR pixels while matplotlib interprets arrays as
# RGB, so convert first; otherwise red and blue are swapped in the plot.
plt.imshow(cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB))
plt.title("Grad-CAM heatmap")

In [None]:
# Print every class index next to the label stored in the weights' metadata
categories = torchvision.models.ResNet50_Weights.DEFAULT.meta["categories"]
for class_index in range(len(categories)):
    print(class_index, categories[class_index])