# In [None]:
import torch
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt

class GradCAM:
    """Grad-CAM for a (ViT-style) image classifier.

    A forward hook on one target layer records that layer's output
    (``self.activations``) and attaches a tensor hook that records the gradient
    of the explained class logit with respect to it (``self.gradients``). Both
    lists are reset at the start of every ``generate_cam`` call.
    """

    def __init__(self, model, target_layer=None):
        """
        Args:
            model: classifier whose forward returns an object with a ``.logits``
                tensor of shape (batch, num_classes). When ``target_layer`` is
                None it must expose ``model.vit.encoder.layer`` (the layout of a
                Hugging Face ``ViTForImageClassification``).
            target_layer: optional module to hook. Defaults to the last encoder
                block's ``layernorm_before`` (falling back to its ``attention``
                module when the block has no such attribute).
        """
        self.model = model
        self.feature_extractor = (
            model.vit if target_layer is None else getattr(model, 'vit', None)
        )
        self.model.eval()

        if target_layer is None:
            last_block = self.feature_extractor.encoder.layer[-1]
            # In a Hugging Face ViT classifier only the CLS token reaches the
            # head, and everything after the last attention is per-token, so the
            # gradient w.r.t. the last attention *output* is zero on every patch
            # token (blank CAM). The LayerNorm feeding that attention still
            # influences CLS through attention, so patch gradients survive.
            target_layer = getattr(last_block, 'layernorm_before', None)
            if target_layer is None:
                target_layer = last_block.attention
        self.target_layers = [target_layer]
        self.gradients = []
        self.activations = []
        self._hook_handles = []

        def save_gradient(grad):
            self.gradients.append(grad.detach())

        def save_activation(module, inputs, output):
            # Some modules (e.g. HF attention blocks) return a tuple whose first
            # element is the hidden-state tensor.
            activation = output[0] if isinstance(output, tuple) else output
            if activation.requires_grad:
                # Tensor hook: receives d(class score)/d(activation) on backward.
                activation.register_hook(save_gradient)
            self.activations.append(activation.detach())

        for layer in self.target_layers:
            self._hook_handles.append(layer.register_forward_hook(save_activation))

    def remove_hooks(self):
        """Detach the forward hooks registered in ``__init__``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    @staticmethod
    def _to_feature_map(tensor):
        """Return ``tensor`` laid out as (batch, channels, height, width).

        4-D tensors are returned unchanged. 3-D token tensors
        (batch, tokens, dim) are reshaped onto a square patch grid; a single
        leading token (e.g. CLS) is dropped when that makes the count square.
        """
        if tensor.dim() == 4:
            return tensor
        if tensor.dim() != 3:
            raise ValueError(f"Unsupported activation shape {tuple(tensor.shape)}")
        batch, tokens, dim = tensor.shape
        for skip in (0, 1):
            patches = tokens - skip
            side = int(round(patches ** 0.5))
            if side > 0 and side * side == patches:
                grid = tensor[:, skip:, :].reshape(batch, side, side, dim)
                return grid.permute(0, 3, 1, 2)
        raise ValueError(
            f"Cannot arrange {tokens} tokens into a square patch grid"
        )

    def generate_cam(self, input_image, target_class=None):
        """
        Generate a Grad-CAM heatmap for the input image.

        Args:
            input_image: preprocessed image tensor of shape (C, H, W); a batch
                dimension is added here.
            target_class: class index to explain; if None, the arg-max logit
                is used.

        Returns:
            cam: float32 numpy array of shape (H, W) with values in [0, 1].
            pred_class: the class index the CAM was computed for
                (``target_class`` when given, otherwise the predicted class).

        Raises:
            RuntimeError: if the hooked layer captured no activation/gradient.
            ValueError: if the hooked activation cannot be arranged as a 2-D map.
        """
        # Drop state from previous calls; otherwise [0] would be stale.
        self.gradients.clear()
        self.activations.clear()

        image_tensor = input_image.unsqueeze(0)
        param = next(self.model.parameters(), None)
        if param is not None:
            image_tensor = image_tensor.to(param.device)

        # Gradients are required even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(image_tensor).logits
            pred_class = output.argmax(dim=1).item() if target_class is None else target_class
            self.model.zero_grad()
            output[0, pred_class].backward()

        if not self.activations or not self.gradients:
            raise RuntimeError(
                "Grad-CAM hooks captured no activations/gradients; is the "
                "target layer part of the model's forward pass?"
            )

        activations = self._to_feature_map(self.activations[0])
        gradients = self._to_feature_map(self.gradients[0])

        # Global-average-pool gradients per channel, then weight the channels.
        weights = gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * activations).sum(dim=1)[0]

        # ReLU and min-max normalization (guarded against an all-zero map).
        cam = torch.relu(cam)
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        # Resize to the input's spatial size; `size` is (height, width).
        height, width = input_image.shape[-2:]
        cam = torch.nn.functional.interpolate(
            cam[None, None].float(), size=(height, width),
            mode='bilinear', align_corners=False,
        )[0, 0]

        return cam.cpu().numpy(), pred_class

class RetinalExplainer:
    """Render Grad-CAM overlays for a classifier's prediction on an image file."""

    def __init__(self, classifier_model):
        """
        Args:
            classifier_model: model handed to ``GradCAM``. The explainer also
                reads ``classifier_model.transform`` and calls it on a PIL RGB
                image; NOTE(review): it is assumed to return a (C, H, W) tensor
                — confirm against whatever attaches that attribute.
        """
        self.grad_cam = GradCAM(classifier_model)

    @staticmethod
    def _load_rgb(image_path):
        """Decode ``image_path`` once as an RGB PIL image and close the file."""
        with Image.open(image_path) as img:
            return img.convert('RGB')

    def _overlay(self, image):
        """Return (uint8 BGR overlay sized like the CAM, class index) for a PIL RGB image."""
        preprocess = self.grad_cam.model.transform
        cam, pred_class = self.grad_cam.generate_cam(preprocess(image))

        # CAM -> 8-bit JET heatmap (BGR). nan_to_num/clip keep out-of-range
        # values from wrapping around in the uint8 cast.
        cam_u8 = (np.clip(np.nan_to_num(cam), 0.0, 1.0) * 255).astype(np.uint8)
        heatmap = cv2.applyColorMap(cam_u8, cv2.COLORMAP_JET)

        # Reuse the pixels the model saw (converted to OpenCV's BGR order)
        # instead of re-decoding with cv2.imread, which applies EXIF orientation
        # that PIL does not and could misalign heatmap and image.
        original_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
        original_image = cv2.resize(original_image, (cam.shape[1], cam.shape[0]))

        superimposed_img = cv2.addWeighted(original_image, 0.7, heatmap, 0.3, 0)
        return superimposed_img, pred_class

    def explain(self, image_path, save_path=None):
        """
        Generate and optionally save the Grad-CAM overlay for one image.

        Args:
            image_path: path to the input image.
            save_path: if truthy, the overlay is written there with cv2.imwrite.

        Returns:
            superimposed_img: uint8 BGR overlay (70% image, 30% heatmap), sized
                to match the CAM returned by GradCAM.
            pred_class: class index the CAM was computed for.

        Raises:
            OSError: if ``save_path`` is given and OpenCV fails to write it.
        """
        superimposed_img, pred_class = self._overlay(self._load_rgb(image_path))

        if save_path and not cv2.imwrite(save_path, superimposed_img):
            raise OSError(f"Failed to write visualization to {save_path!r}")

        return superimposed_img, pred_class

    def plot_explanation(self, image_path):
        """
        Plot the original image next to its Grad-CAM overlay.

        Returns:
            The matplotlib Figure; the caller decides whether to show or save it.
        """
        image = self._load_rgb(image_path)
        superimposed_img, pred_class = self._overlay(image)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Left: the full-resolution image exactly as decoded for the model.
        ax1.imshow(np.asarray(image))
        ax1.set_title('Original Image')
        ax1.axis('off')

        # Right: overlay is in OpenCV's BGR order; matplotlib expects RGB.
        ax2.imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
        ax2.set_title(f'Grad-CAM Visualization\nPredicted Class: {pred_class}')
        ax2.axis('off')

        fig.tight_layout()
        return fig

# Usage example: explain one sample image with an already-trained classifier.
if __name__ == "__main__":
    # RetinalClassifier is not defined in this cell; it is assumed to come from
    # elsewhere and to expose the wrapped network as `.model`.
    classifier = RetinalClassifier()
    explainer = RetinalExplainer(classifier.model)

    sample_path = 'path/to/sample/image.png'
    fig = explainer.plot_explanation(sample_path)
    plt.show()