In [1]:
import torch
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import numpy as np
import cv2

In [2]:
# 1. Load model
model = models.resnet50(pretrained=True)
model.eval()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [3]:
# 2. Load image
image_path = r'C:\Users\Rishabh\Documents\3d-hcct\image.png'
img = Image.open(image_path).convert('RGB')
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
input_tensor = transform(img).unsqueeze(0)

In [4]:
input_tensor.shape

torch.Size([1, 3, 224, 224])

In [5]:
# Hook to get gradients and feature maps
gradients = []
activations = []

def backward_hook(module, grad_input, grad_output):
    print('grad_input:- ',len(grad_input), grad_input[0].shape)
    print('grad_output:- ',len(grad_output), grad_output[0].shape)
    gradients.append(grad_output[0])

def forward_hook(module, input, output):
    activations.append(output)

In [15]:
gradients[0].shape, activations[0].shape, len(gradients)

(torch.Size([1, 2048, 7, 7]), torch.Size([1, 2048, 7, 7]), 1)

In [7]:
target_layer = model.layer4[2].conv3
target_layer.register_forward_hook(forward_hook)
target_layer.register_backward_hook(backward_hook)

<torch.utils.hooks.RemovableHandle at 0x13474045650>

In [8]:
target_layer

Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)

In [9]:
# 3. Forward pass
output = model(input_tensor)
pred_class = output.argmax(dim=1)

  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


In [10]:
# 4. Backward pass for target class
model.zero_grad()
output[0, pred_class].backward()

grad_input:-  3 torch.Size([1, 512, 7, 7])
grad_output:-  1 torch.Size([1, 2048, 7, 7])


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [11]:
# 5. Get Grad-CAM weights
grads = gradients[0].mean(dim=[2, 3], keepdim=True)
print(grads.shape)
acts = activations[0]
print(acts.shape)
cam = (grads * acts).sum(dim=1).squeeze()
print(cam.shape)
cam = F.relu(cam)

torch.Size([1, 2048, 1, 1])
torch.Size([1, 2048, 7, 7])
torch.Size([7, 7])


In [12]:
# Normalize and resize heatmap
cam = cam.detach().numpy()
cam = cv2.resize(cam, (224, 224))
cam = (cam - cam.min()) / (cam.max() - cam.min())

In [13]:
# Overlay heatmap on image
heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
heatmap = np.float32(heatmap) / 255
final_img = heatmap + np.float32(img.resize((224, 224))) / 255
final_img = final_img / np.max(final_img)

In [14]:
final_img.shape

(224, 224, 3)