# In [None]:
import torch
import torch.nn.functional as F
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms, models
from PIL import Image

class GradCAM:
    """Grad-CAM heatmap generator bound to one layer of a model.

    A forward hook on ``target_layer`` records the layer's output and attaches
    a tensor hook to it, so the gradient flowing back into that output is
    captured during ``backward()``.

    Attributes:
        model: The model that is run by :meth:`generate_cam`.
        target_layer: The module whose output is visualised.
        activations: Detached output of ``target_layer`` from the latest
            forward pass (``None`` before the first one).
        gradients: Detached gradient w.r.t. that output from the latest
            backward pass (``None`` until one has run).
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Only a forward hook is installed on the module. The gradient is
        # captured with a tensor hook (see save_activation) instead of the
        # deprecated Module.register_backward_hook.
        # Handles are kept so the hooks can be detached again.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self.save_activation)
        ]

    def remove_hooks(self):
        """Detach every hook this object registered on ``target_layer``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def save_activation(self, module, input, output):
        """Forward hook: store the layer output and arm gradient capture."""
        # Detached so the stored tensor does not keep the autograd graph alive.
        self.activations = output.detach()
        # Outputs produced with autograd disabled cannot carry a hook.
        if output.requires_grad:
            output.register_hook(
                lambda grad: self.save_gradient(module, None, (grad,))
            )

    def save_gradient(self, module, grad_input, grad_output):
        """Store the gradient w.r.t. the layer output (``grad_output[0]``)."""
        self.gradients = grad_output[0].detach()

    def generate_cam(self, input_image, class_idx=None):
        """Compute the Grad-CAM map for the first sample of ``input_image``.

        Args:
            input_image: Batched input passed straight to ``self.model``; the
                model output is indexed as ``(batch, classes)`` and the target
                layer output as ``(batch, channels, height, width)``.
            class_idx: Class whose score is explained. ``None`` selects the
                highest-scoring class of the first sample. An ``int`` or a
                one-element tensor is accepted.

        Returns:
            A 2-D float32 NumPy array with the spatial size of the target
            layer's output. Values are min-max scaled to ``[0, 1]``; a map
            with no positive evidence is returned as all zeros.

        Raises:
            RuntimeError: If the target layer did not run during the forward
                pass or no gradient reached it during the backward pass.
        """
        # Drop state from earlier calls so stale tensors are never reused.
        self.gradients = None
        self.activations = None

        # Gradients are required even when the caller is inside no_grad().
        with torch.enable_grad():
            model_output = self.model(input_image)
            if self.activations is None:
                raise RuntimeError(
                    "Target layer was not executed during the forward pass."
                )

            if class_idx is None:
                class_idx = torch.argmax(model_output[0])
            class_idx = int(class_idx)

            self.model.zero_grad()

            # Scalar score of the first sample, matching the [0] indexing of
            # gradients and activations below.
            class_score = model_output[0, class_idx]
            class_score.backward()

        if self.gradients is None:
            raise RuntimeError(
                "No gradient reached the target layer during backward()."
            )

        gradients = self.gradients[0]  # Remove batch dimension
        activations = self.activations[0]  # Remove batch dimension

        # Channel weights: global average pooling of the gradients.
        weights = gradients.mean(dim=(1, 2))

        # Weighted sum over channels, computed on the tensors' own device.
        cam = (weights[:, None, None] * activations).sum(dim=0).float()

        # Keep only positive contributions.
        cam = F.relu(cam)

        # Min-max normalise; skip the division for an all-zero map, which
        # would otherwise produce NaN.
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        return cam.detach().cpu().numpy()

def preprocess_image(image_path, size=(224, 224)):
    """Load an image file and turn it into a batched, normalised tensor.

    Args:
        image_path: Location of the image, passed to ``PIL.Image.open``.
        size: Target size handed to ``transforms.Resize``.

    Returns:
        A ``(input_tensor, original_image)`` tuple: the transformed image with
        a leading batch dimension added, and the RGB image as a NumPy array
        taken before any resizing.
    """
    # NOTE(review): the mean/std constants look like the usual ImageNet
    # statistics -- confirm they match the model being explained.
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # The context manager releases the file handle right away instead of
    # leaving it to garbage collection; convert() yields an in-memory image
    # that stays usable after the file is closed.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    original_image = np.array(image)
    input_tensor = transform(image).unsqueeze(0)

    return input_tensor, original_image

def overlay_heatmap(image, heatmap, alpha=0.6):
    """Blend a JET-coloured heatmap over an image.

    Args:
        image: Array whose first two axes are (height, width) and which
            broadcasts against a 3-channel RGB map (the demo passes the RGB
            array returned by ``preprocess_image``).
        heatmap: 2-D intensity map, expected in ``[0, 1]``. NaN entries are
            treated as 0 and out-of-range values are clipped.
        alpha: Weight of the heatmap in the blend; the image gets
            ``1 - alpha``.

    Returns:
        The blended picture as a ``uint8`` array.
    """
    # Sanitise before resizing so a single NaN cannot spread through the
    # interpolation.
    heatmap = np.nan_to_num(np.asarray(heatmap, dtype=np.float32), nan=0.0)

    # Resize heatmap to match image size (cv2 expects (width, height)).
    heatmap_resized = cv2.resize(heatmap, (image.shape[1], image.shape[0]))

    # Clip so the uint8 cast below cannot wrap around.
    heatmap_resized = np.clip(heatmap_resized, 0.0, 1.0)

    # Convert heatmap to colour (JET colormap); OpenCV returns BGR.
    heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    # Weighted blend, clipped to the valid pixel range before the cast.
    superimposed = heatmap_colored * alpha + image * (1 - alpha)
    return np.clip(superimposed, 0, 255).astype(np.uint8)

# Example usage with pretrained ResNet
def demo_gradcam():
    """Run Grad-CAM on a single image with a pretrained ResNet-50.

    Shows the original image, the raw heatmap and the overlay side by side,
    then prints the predicted class index and its softmax confidence. If the
    image file is missing, a short hint is printed instead.
    """
    # Pretrained network in inference mode
    model = models.resnet50(pretrained=True)
    model.eval()

    # Grad-CAM on conv3 of the last block in layer4 (the choice for ResNet50)
    gradcam = GradCAM(model, model.layer4[-1].conv3)

    # Replace 'your_image.jpg' with actual image path
    image_path = 'your_image.jpg'
    try:
        input_tensor, original_image = preprocess_image(image_path)

        # Heatmap for the top-scoring class, then its overlay on the image
        cam = gradcam.generate_cam(input_tensor)
        overlay = overlay_heatmap(original_image, cam)

        # One row of three panels: (picture, title, extra imshow arguments)
        _fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (original_image, 'Original Image', {}),
            (cam, 'GradCAM Heatmap', {'cmap': 'jet'}),
            (overlay, 'Overlay', {}),
        )
        for axis, (picture, title, extra) in zip(axes, panels):
            axis.imshow(picture, **extra)
            axis.set_title(title)
            axis.axis('off')

        plt.tight_layout()
        plt.show()

        # Report the prediction from a gradient-free forward pass
        with torch.no_grad():
            output = model(input_tensor)
            pred_class = torch.argmax(output, dim=1).item()
            probabilities = torch.softmax(output, dim=1)
            confidence = probabilities[0, pred_class].item()
            print(f"Predicted class: {pred_class}, Confidence: {confidence:.3f}")

    except FileNotFoundError:
        print("Image file not found. Please provide a valid image path.")
        print("This is a demo code - replace 'your_image.jpg' with actual image path")

# Advanced GradCAM with multiple layers
class MultiLayerGradCAM:
    """Grad-CAM over several layers, using one ``GradCAM`` per layer."""

    def __init__(self, model, target_layers):
        self.model = model
        self.target_layers = target_layers
        # One independent GradCAM instance per requested layer, in order.
        self.gradcams = [GradCAM(model, layer) for layer in target_layers]

    def generate_multi_cam(self, input_image, class_idx=None):
        """Return the list of CAMs, one per target layer, in layer order."""
        return [
            layer_cam.generate_cam(input_image, class_idx)
            for layer_cam in self.gradcams
        ]

# Common issues and solutions:

def troubleshoot_gradcam():
    """Checklist of common GradCAM issues and solutions.

    This function performs no work; it only carries the notes below.

    1. Gradient is None:
       - Keep the model in eval() mode, but do not disable gradients
       - Check that the target layer is the intended one
       - Ensure backward() is called on the correct tensor

    2. CAM shows random patterns:
       - The wrong target layer was selected
       - The model is not properly pretrained
       - Input preprocessing does not match the model

    3. CAM too bright/dark:
       - Adjust the normalization
       - Check the alpha value used for the overlay
       - Verify the image preprocessing

    4. CAM doesn't highlight expected regions:
       - The model might rely on different features than expected
       - Try different target layers
       - Check whether the model is actually confident in its prediction
    """

# For custom models, modify target layer selection:
def get_target_layer_examples():
    """List example target layers for several architectures.

    This function performs no work; it only carries the notes below.

    ResNet: model.layer4[-1].conv3
    VGG: model.features[-1]
    DenseNet: model.features[-1]
    EfficientNet: model.features[-1]

    For a custom model, pick the last convolutional layer that comes before
    global average pooling or the fully connected layers.
    """

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo_gradcam()