In [None]:
import torch
import torch.nn.functional as F
from torchvision import models, transforms
import numpy as np
import cv2
from PIL import Image
from fastai.vision.all import *

In [None]:
class GradCAM:
    """Grad-CAM heatmap generator for a fast.ai ``Learner``.

    The target layer's output and the gradient flowing back into that output
    are captured with PyTorch hooks. The hooks are installed only for the
    duration of each ``__call__`` and removed afterwards, so creating many
    ``GradCAM`` objects for the same layer does not accumulate hooks on the
    model.
    """

    def __init__(self, learn, target_layer):
        """
        Args:
            learn: fast.ai Learner object
            target_layer: The layer to generate Grad-CAM for (e.g., learn.model[0][-1])
        """
        self.model = learn.model
        self.learn = learn
        self.target_layer = target_layer
        # Filled in by the hooks during __call__; kept afterwards (detached)
        # so the most recent values can be inspected.
        self.gradients = None
        self.activations = None

    def save_activation(self, module, input, output):
        """Forward hook: store the target layer's output."""
        self.activations = output

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: store the gradient w.r.t. the target layer's output."""
        self.gradients = grad_output[0]

    def __call__(self, img, class_idx=None):
        """
        Generate Grad-CAM using fast.ai's preprocessing pipeline

        Args:
            img: PIL Image or file path
            class_idx: Class index for which to generate Grad-CAM. May be a
                      plain int or a one-element tensor.
                      If None, uses the model's highest scoring class

        Returns:
            cam: Grad-CAM heatmap as a NumPy array, min-max scaled to [0, 1]
                 and upsampled to the spatial size of the model input
            output: Model's output predictions (detached from the graph)

        Raises:
            RuntimeError: if the hooks on ``target_layer`` never fired, i.e.
                the layer did not take part in the forward/backward pass.
        """
        # Use fast.ai's preprocessing pipeline
        if isinstance(img, (str, Path)):
            img = PILImage.create(img)

        # Validation-time item and batch transforms from the learner
        valid_dl = self.learn.dls.valid
        x = valid_dl.after_item(img)
        # Convert to TensorImage if it isn't already
        if not isinstance(x, TensorImage):
            x = TensorImage(x)

        # Add the batch dimension and move to the device *before* the batch
        # transforms run, so any transform state that lives on dls.device
        # sees an input on the same device.
        x = x.unsqueeze(0).to(self.learn.dls.device)
        x = valid_dl.after_batch(x)

        # Drop values left over from a previous call so a hook that fails to
        # fire is detected instead of silently reusing stale tensors.
        self.activations = None
        self.gradients = None

        handles = []
        try:
            handles.append(self.target_layer.register_forward_hook(self.save_activation))
            handles.append(self.target_layer.register_full_backward_hook(self.save_gradient))

            # Forward pass
            self.model.eval()
            with torch.enable_grad():
                output = self.model(x)

            if class_idx is None:
                class_idx = output.argmax(dim=1)
            # Accept ints as well as one-element tensors
            class_idx = int(class_idx)

            print(f"using class index: {class_idx}")

            # Zero gradients
            self.model.zero_grad()

            # Backward pass for the specified class score only
            output[0, class_idx].backward()
        finally:
            # Always uninstall the hooks, even if the forward/backward failed
            for handle in handles:
                handle.remove()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not fire: target_layer was not used in the "
                "forward pass or received no gradient."
            )

        # Keep only detached copies so the autograd graph can be freed
        self.gradients = self.gradients.detach()
        self.activations = self.activations.detach()

        # Global average pooling of gradients -> one weight per channel
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)

        # Weighted combination of activation maps, keeping positive evidence only
        cam = F.relu(torch.sum(weights * self.activations, dim=1, keepdim=True))

        # Upsample to the model input size, then min-max normalize
        cam = F.interpolate(cam, size=x.shape[2:], mode='bilinear', align_corners=True)
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-7)  # epsilon avoids 0/0 for a flat map

        return cam.squeeze().cpu().numpy(), output.detach()

def visualize_gradcam(learn, image_path, target_layer, class_idx=None, alpha=0.3):
    """
    Generate and visualize Grad-CAM for an image using a fast.ai learner

    Args:
        learn: fast.ai Learner object
        image_path: Path to input image
        target_layer: Target layer for Grad-CAM
        class_idx: Target class index (optional)
        alpha: Weight of the heatmap in the overlay, in [0, 1]; the original
               image gets weight ``1 - alpha``. Defaults to 0.3.

    Returns:
        original_array: Original input image as an RGB uint8 array (H, W, 3)
        heatmap: Grad-CAM heatmap (JET colormap, RGB) resized to the original
                 image size
        result: Heatmap overlaid on original image (uint8)

    Raises:
        ValueError: if ``alpha`` is outside [0, 1].
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")

    # Load original image; the context manager closes the file handle once the
    # pixel data has been copied into the array.
    with Image.open(image_path) as image_file:
        original_array = np.array(image_file.convert('RGB'))

    # Generate Grad-CAM (the model output is not needed here)
    grad_cam = GradCAM(learn, target_layer)
    cam, _ = grad_cam(image_path, class_idx)

    # Convert heatmap to RGB (OpenCV colormaps are produced in BGR order)
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Resize heatmap to match original image size; cv2.resize takes (width, height).
    # NOTE(review): this stretches the model-input-sized map over the whole
    # image, which assumes the learner's item transforms do not crop — confirm.
    height, width = original_array.shape[:2]
    heatmap = cv2.resize(heatmap, (width, height))

    # Combine heatmap with original image
    result = heatmap * alpha + original_array * (1.0 - alpha)
    result = result.astype(np.uint8)

    return original_array, heatmap, result

In [None]:
def show_gradcam(learn, img_path, class_idx=None, target_layer=None):
    """
    Show Grad-CAM results for a fast.ai learner

    Displays three panels side by side: the original image, the Grad-CAM
    heatmap, and the heatmap overlaid on the image.

    Args:
        learn: fast.ai Learner object
        img_path: Path to image
        class_idx: Optional class index (defaults to predicted class)
        target_layer: Optional layer to compute Grad-CAM on. Defaults to
                      ``learn.model[0][-1][-1]``, which assumes a sequential
                      model with a ResNet-like backbone at index 0.
    """
    if target_layer is None:
        # For ResNet-like architectures, use the last layer of the last block
        target_layer = learn.model[0][-1][-1]  # Assumes the model is sequential with backbone at index 0

    original, heatmap, result = visualize_gradcam(
        learn,
        img_path,
        target_layer,
        class_idx
    )

    # Plot the three images in one row with matplotlib
    panels = [
        ('Original', original),
        ('Heatmap', heatmap),
        ('Overlay', result),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (title, img) in zip(axes, panels):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Load the exported learner used by all Grad-CAM cells below.
# NOTE(review): load_learner unpickles the file, which can run arbitrary code —
# only load exports from a trusted source.
learn = load_learner("resnet34_split_0.pkl")

In [None]:
# Grad-CAM panels for the 'progav.png' sample, using the predicted class.
show_gradcam(learn, img_path='progav.png')

In [None]:
# Grad-CAM panels for the 'hakim.png' sample, using the predicted class.
show_gradcam(learn, img_path='hakim.png')

In [None]:
# Grad-CAM panels for the 'certas.png' sample, using the predicted class.
show_gradcam(learn, img_path='certas.png')

In [None]:
# Grad-CAM panels for the 'sophysa.png' sample, using the predicted class.
show_gradcam(learn, img_path='sophysa.png')