In [None]:
import cv2  # for resizing heatmap if needed
import numpy as np
import torch
import torch.nn.functional as F

class GradCAM:
    """Grad-CAM heatmap generator for a single named layer of a PyTorch model.

    A forward hook on the target layer stores the layer's output and attaches
    a tensor hook to it, so the gradient of the chosen class score with
    respect to that output is captured during ``backward()``.
    """

    def __init__(self, model, target_layer_name):
        """
        model: your PyTorch model (e.g. your resnet-based model)
        target_layer_name: string name of the layer for grad-cam, exactly as
                           reported by ``model.named_modules()``
                           (e.g. 'layer4' in ResNet).

        Raises:
            ValueError: if no module with that name exists in ``model``.
        """
        self.model = model
        self.target_layer = None
        self.gradients = None
        self.activations = None

        for name, module in self.model.named_modules():
            if name == target_layer_name:
                self.target_layer = module
                break

        # Fail here rather than with an obscure error inside generate().
        if self.target_layer is None:
            raise ValueError(
                f"Layer {target_layer_name!r} not found in model.named_modules()"
            )

        # Only a forward hook is registered on the module. The gradient is
        # captured by a tensor hook attached in save_activation, which avoids
        # the deprecated Module.register_backward_hook.
        self._forward_handle = self.target_layer.register_forward_hook(
            self.save_activation
        )

    def remove_hooks(self):
        """Detach the forward hook from the target layer (safe to call twice)."""
        if self._forward_handle is not None:
            self._forward_handle.remove()
            self._forward_handle = None

    def save_activation(self, module, input, output):
        """Forward hook: store the layer output and arrange gradient capture."""
        self.activations = output.detach()
        # A tensor hook can only be registered on a tensor that requires grad
        # (it will not when the forward pass ran without grad tracking).
        if output.requires_grad:
            output.register_hook(
                lambda grad: self.save_gradient(module, None, (grad,))
            )

    def save_gradient(self, module, grad_input, grad_output):
        """Store the gradient of the score wrt. the target layer's output.

        grad_output is a tuple whose first element is that gradient;
        grad_input is not used.
        """
        self.gradients = grad_output[0].detach()

    def generate(self, input_tensor, class_idx=None):
        """
        Generates Grad-CAM for the given input_tensor (1 x C x H x W).

        If class_idx is None, we take the predicted class with max logit.
        Only the first sample of the batch is used.

        Returns: a 2D numpy heatmap with values in [0, 1], upsampled to the
        spatial size (H, W) of input_tensor.

        Raises:
            RuntimeError: if the target layer did not run during the forward
                pass or no gradient reached it.

        Side effects: puts the model in eval mode and zeroes its gradients.
        """
        self.model.eval()

        # Drop results of a previous call so stale tensors are never reused.
        self.activations = None
        self.gradients = None

        # enable_grad() makes this work even inside a torch.no_grad() block.
        with torch.enable_grad():
            # 1) Forward pass
            output = self.model(input_tensor)

            # If class_idx isn't specified, use max logit
            if class_idx is None:
                class_idx = torch.argmax(output, dim=1).item()

            # 2) Zero grads, do backward on the chosen class score
            self.model.zero_grad()
            score = output[0, class_idx]
            score.backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks captured nothing: the target layer was not "
                "executed or its output does not require grad"
            )

        # 3) Channel weights: mean gradient over batch and spatial dims
        # self.gradients shape => [batch=1, channels, h, w]
        weights = self.gradients.mean(dim=(0, 2, 3))  # shape [channels]

        # 4) Weighted combination of forward activations, in one vectorised
        # sum that stays on the activations' device (no CPU accumulator).
        activation = self.activations[0]  # [channels, h, w]
        gradcam = (weights[:, None, None] * activation).sum(dim=0)

        # 5) ReLU
        gradcam = F.relu(gradcam)

        # 6) Normalize => heatmap in [0, 1]; epsilon guards an all-zero map
        gradcam = gradcam - gradcam.min()
        gradcam = gradcam / (gradcam.max() + 1e-8)

        # 7) Upsample to the input's spatial size with bilinear interpolation.
        # Done in torch so the optional cv2 package is not required.
        _, _, H, W = input_tensor.shape
        gradcam = F.interpolate(
            gradcam[None, None], size=(H, W), mode="bilinear", align_corners=False
        )[0, 0]

        return gradcam.cpu().numpy()


ModuleNotFoundError: No module named 'cv2'

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

# Compute a Grad-CAM heatmap for one MIP image and show it in three panels:
# the original image, the heatmap alone, and a colour overlay of the two.

# 1) Instantiate GradCAM with your final model, specifying the target layer
model.eval()
gradcam = GradCAM(model, target_layer_name='features.7')
# e.g. if your last conv block is 'features.7' in your custom model

# 2) Suppose we have some MIP image as 'mip_tensor'
#    shape => (1, 1, 310, 310), normalized, etc.

# 3) Generate the heatmap (2D array scaled to [0, 1] at the input's size)
heatmap = gradcam.generate(mip_tensor)

# 4) Visualize
plt.figure(figsize=(12, 4))

# Original MIP in numpy form; detach() so this also works when the tensor
# is tracking gradients.
original_mip = mip_tensor.detach().squeeze().cpu().numpy()  # e.g. shape (310, 310)

ax1 = plt.subplot(1, 3, 1)
ax1.imshow(original_mip, cmap='gray', vmin=1, vmax=10)
ax1.set_title("Original MIP")

# Heatmap only
ax2 = plt.subplot(1, 3, 2)
ax2.imshow(heatmap, cmap='jet')
ax2.set_title("Grad-CAM Heatmap")

# Overlay: alpha-blend the colour-mapped heatmap onto the grayscale image
alpha = 0.5
jetmap = plt.get_cmap('jet')(heatmap)[:, :, :3]  # RGBA -> RGB

# Scale the image by its maximum before stacking it into three channels.
# The division is applied to the array itself (not to a Python list of
# arrays), and a zero maximum is skipped to avoid a NaN overlay.
mip_peak = original_mip.max()
scaled_mip = original_mip / mip_peak if mip_peak != 0 else original_mip
gray_rgb = np.dstack([scaled_mip] * 3)

# imshow expects float RGB in [0, 1]; clip explicitly instead of relying on
# matplotlib's clipping of out-of-range values.
overlay = np.clip(alpha * jetmap + (1 - alpha) * gray_rgb, 0.0, 1.0)

ax3 = plt.subplot(1, 3, 3)
ax3.imshow(overlay)
ax3.set_title("Overlay")


In [None]:
# Crop the overlay to the bounding box of the strongest Grad-CAM responses
# (heatmap > 0.7), padded by a small margin, and display the crop.
mask = (heatmap > 0.7)
rows = np.any(mask, axis=1)
cols = np.any(mask, axis=0)
if mask.any():
    # First and last row/column index that contain a hot pixel (inclusive)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    # expand a bit
    margin = 5
    rmin = max(rmin - margin, 0)
    cmin = max(cmin - margin, 0)
    # rmax/cmax are inclusive indices but slice ends are exclusive, hence +1;
    # without it the bottom/right margin is one pixel short.
    rmax = min(rmax + margin + 1, heatmap.shape[0])
    cmax = min(cmax + margin + 1, heatmap.shape[1])

    # Crop the overlay
    zoomed_overlay = overlay[rmin:rmax, cmin:cmax, :]

    plt.figure(figsize=(5,5))
    plt.imshow(zoomed_overlay)
    plt.title("Zoomed region of high Grad-CAM importance")
    plt.show()
else:
    # No pixel exceeds the threshold (e.g. an all-zero heatmap). Fall back to
    # the full overlay so code that reads zoomed_overlay still finds it.
    zoomed_overlay = overlay
    print("No heatmap values above 0.7; using the full overlay as the zoomed region.")


In [None]:
# Summary figure: the four Grad-CAM views side by side in a single row.
fig, axes = plt.subplots(1, 4, figsize=(16,4))

# (title, image, extra imshow keyword arguments) for each panel, left to right
panels = [
    ("Original MIP", original_mip, dict(cmap='gray', vmin=1, vmax=10)),
    ("Heatmap", heatmap, dict(cmap='jet')),
    ("Overlay", overlay, {}),
    ("Zoomed Region", zoomed_overlay, {}),
]

for ax, (panel_title, panel_image, imshow_kwargs) in zip(axes, panels):
    ax.imshow(panel_image, **imshow_kwargs)
    ax.set_title(panel_title)

# Hide ticks and frames on every panel
for ax in axes:
    ax.axis('off')

plt.suptitle("Grad-CAM Visualization", fontsize=16)
plt.tight_layout()
plt.show()
