In [9]:
import torch
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import math
import torch.nn as nn
from transformers import ViTForImageClassification
from torchvision import transforms

class GradCAM:
    """Grad-CAM heatmaps for a ViT classifier wrapped by ``RetinalClassifier``.

    A forward hook on ``target_layer`` records that layer's token activations
    and, via a tensor hook on the same output, their gradients during
    ``generate_cam``'s backward pass.

    Args:
        model: Wrapper exposing the ViT classifier as ``model.model`` (as
            ``RetinalClassifier`` does).
        target_layer: Module whose output is explained. The output is expected
            to be a ``[batch, 1 + P, hidden]`` token tensor with the CLS token
            first and ``P`` a perfect square (patch grid). Defaults to the
            ``layernorm_before`` of the last encoder block.
    """

    def __init__(self, model, target_layer=None):
        self.model = model.model
        if target_layer is None:
            # In Hugging Face's ViTForImageClassification the head reads only
            # the CLS token of the final hidden states, so the patch tokens of
            # the last block's *output* receive exactly zero gradient and give
            # a blank map. The input of the last block's self-attention is
            # still mixed into the CLS token, so its patch tokens carry a
            # usable gradient.
            target_layer = self.model.vit.encoder.layer[-1].layernorm_before
        self.target_layer = target_layer
        self.model.eval()

        self.gradients = []
        self.activations = []

        def save_gradient(grad):
            """Tensor hook: store the gradient w.r.t. the target activations."""
            self.gradients.append(grad.detach())

        def save_activation(module, inputs, output):
            """Forward hook: store the activations and hook their gradient."""
            # Under torch.no_grad() (e.g. RetinalClassifier.predict running on
            # the same model) the output does not require grad and
            # register_hook would raise; such passes are irrelevant here.
            if not output.requires_grad:
                return
            self.activations.append(output.detach())
            output.register_hook(save_gradient)

        # Keep the handle so the hook can be detached with remove_hooks().
        self._hook_handle = self.target_layer.register_forward_hook(save_activation)

    def remove_hooks(self):
        """Detach the forward hook from the target layer."""
        self._hook_handle.remove()

    def generate_cam(self, input_image, target_class=None):
        """Compute a Grad-CAM heatmap for one preprocessed image.

        Args:
            input_image: Image tensor shaped ``[C, H, W]``; a batch dimension
                is added here.
            target_class: Class index to explain. Defaults to the model's
                argmax prediction.

        Returns:
            ``(cam, class_index)``. ``cam`` is a float32 NumPy array of shape
            ``(H, W)`` with values in ``[0, 1]``, or ``None`` when the target
            layer received no non-zero patch-token gradient.

        Raises:
            ValueError: If the number of patch tokens is not a perfect square.
        """
        self.gradients = []
        self.activations = []

        image_tensor = input_image.unsqueeze(0).to(self.model.device)

        # Gradients flow in eval mode too; staying in eval keeps dropout off so
        # the map is deterministic.
        self.model.eval()
        self.model.zero_grad()

        # enable_grad makes this work even when called inside torch.no_grad().
        with torch.enable_grad():
            logits = self.model(image_tensor).logits
            pred_class = logits.argmax(dim=1).item() if target_class is None else target_class
            logits[0, pred_class].backward()

        # Only patch-token gradients matter (the CLS token is dropped below),
        # so check those rather than the whole tensor.
        if not self.gradients or not torch.any(self.gradients[0][:, 1:, :] != 0):
            print("🚨 Gradients are still zero! Check target layer selection or backprop.")
            return None, pred_class
        print(f"✅ Non-zero gradients found! {self.gradients[0].shape}")

        # Drop the leading CLS token: [1, 1 + P, D] -> [1, P, D]
        gradients = self.gradients[0][:, 1:, :]
        activations = self.activations[0][:, 1:, :]

        # Channel importance: gradient averaged over batch and tokens -> [D]
        weights = gradients.mean(dim=(0, 1))

        # Weight every channel and average over channels -> [P]
        cam = (activations * weights).mean(dim=2).squeeze(0)

        # Arrange patch scores on their square grid (14x14 for 224px / 16px patches).
        num_patches = cam.numel()
        side = math.isqrt(num_patches)
        if side * side != num_patches:
            raise ValueError(f"Cannot arrange {num_patches} patch tokens into a square grid")
        cam = cam.view(side, side)

        # Keep positive evidence only, then rescale to [0, 1].
        cam = torch.relu(cam)
        cam_min, cam_max = cam.min(), cam.max()
        if cam_max > cam_min:
            cam = (cam - cam_min) / (cam_max - cam_min)
        else:
            # Flat map: no region stands out (also avoids a 0/0 division).
            cam = torch.zeros_like(cam)

        # Bilinear upsampling to the input's spatial size (H, W).
        cam = torch.nn.functional.interpolate(
            cam[None, None],
            size=(input_image.shape[1], input_image.shape[2]),
            mode="bilinear",
            align_corners=False,
        )[0, 0]
        cam = cam.cpu().numpy()

        print(f"✅ Grad-CAM Successfully Generated for Class: {pred_class}")

        return cam, pred_class


class RetinalClassifier:
    """ViT-B/16 classifier with a ``num_classes``-way head and its input pipeline.

    Attributes:
        device: ``cuda`` when available, otherwise ``cpu``.
        model: The ``ViTForImageClassification`` instance, moved to ``device``.
        transform: Resize to 224x224, convert to a tensor, then normalize with
            the ImageNet channel mean/std.
    """

    def __init__(self, num_classes=5):
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        # ignore_mismatched_sizes drops the checkpoint's 1000-class head in
        # favour of a freshly created num_classes head instead of erroring.
        backbone = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224',
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )
        self.model = backbone.to(self.device)

        # NOTE: from_pretrained already built a num_classes head; it is replaced
        # here by a newly initialised nn.Linear of the same shape.
        head = torch.nn.Linear(self.model.config.hidden_size, num_classes)
        self.model.classifier = head.to(self.device)

        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        pipeline = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
        self.transform = transforms.Compose(pipeline)

    def predict(self, image_path):
        """Return the predicted class index (a Python int) for one image file."""
        self.model.eval()
        rgb_image = Image.open(image_path).convert('RGB')
        batch = self.transform(rgb_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(batch).logits
            best_index = torch.max(logits, 1).indices

        return best_index.item()


class RetinalExplainer:
    """Render Grad-CAM explanations for a ``RetinalClassifier``."""

    def __init__(self, classifier_model):
        self.grad_cam = GradCAM(classifier_model)
        self.model = classifier_model

    def explain(self, image_path, save_path=None):
        """Generate a Grad-CAM overlay for one image and optionally save it.

        Args:
            image_path: Path of the image to explain.
            save_path: If truthy, the overlay is written there with ``cv2.imwrite``.

        Returns:
            ``(overlay, pred_class)``. ``overlay`` is a BGR uint8 image at the
            CAM's resolution, blending 70% image with 30% JET heatmap. It is
            ``None`` when Grad-CAM generation failed.
        """
        image = Image.open(image_path).convert('RGB')
        input_tensor = self.model.transform(image).to(self.model.model.device)

        cam, pred_class = self.grad_cam.generate_cam(input_tensor)
        if cam is None:
            return None, pred_class

        # Clip before the uint8 cast so float round-off cannot wrap around.
        heatmap = cv2.applyColorMap(np.uint8(np.clip(255 * cam, 0, 255)), cv2.COLORMAP_JET)

        # Reuse the image already decoded for the model instead of re-reading
        # the file with cv2.imread, which returns None (not an error) on
        # failure. Convert RGB -> BGR to match OpenCV's heatmap channel order.
        original_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        original_image = cv2.resize(original_image, (cam.shape[1], cam.shape[0]))

        superimposed_img = cv2.addWeighted(original_image, 0.7, heatmap, 0.3, 0)

        if save_path:
            cv2.imwrite(save_path, superimposed_img)

        return superimposed_img, pred_class

    def plot_explanation(self, image_path):
        """Plot the original image beside its Grad-CAM overlay.

        Returns:
            The matplotlib Figure with two panels (original image, overlay
            titled with the predicted class), or ``None`` if Grad-CAM
            generation failed.
        """
        superimposed_img, pred_class = self.explain(image_path)
        if superimposed_img is None:
            print("❌ Grad-CAM generation failed due to zero gradients.")
            return None

        # Decode with PIL, the same way as the model input, to get RGB directly.
        original_image = np.array(Image.open(image_path).convert('RGB'))

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        ax1.imshow(original_image)
        ax1.set_title('Original Image')
        ax1.axis('off')

        # The overlay is BGR (OpenCV); matplotlib expects RGB.
        ax2.imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
        ax2.set_title(f'Grad-CAM Visualization\nPredicted Class: {pred_class}')
        ax2.axis('off')

        plt.tight_layout()
        return fig


# ✅ Usage example
if __name__ == "__main__":
    model = RetinalClassifier()

    # Load the checkpoint onto the CPU; load_state_dict copies each tensor
    # into the model's existing parameters on its own device.
    checkpoint = torch.load('best_model.pth', map_location=torch.device('cpu'))
    checkpoint_state = checkpoint['model_state_dict']

    # Keep every checkpoint entry whose name and shape match the model. A
    # compatible fine-tuned classifier head is then restored instead of always
    # being discarded, while mismatching entries (e.g. a head of another size)
    # are skipped and reported rather than silently ignored.
    model_state = model.model.state_dict()
    compatible_state = {
        key: value
        for key, value in checkpoint_state.items()
        if key in model_state and model_state[key].shape == value.shape
    }
    skipped_keys = sorted(set(checkpoint_state) - set(compatible_state))
    load_result = model.model.load_state_dict(compatible_state, strict=False)

    if skipped_keys:
        print(f"Skipped incompatible checkpoint entries: {skipped_keys}")
    backbone_missing = [k for k in load_result.missing_keys if not k.startswith("classifier")]
    if backbone_missing:
        print(f"Warning: {len(backbone_missing)} backbone parameters were not in the checkpoint.")
    if any(k.startswith("classifier") for k in load_result.missing_keys):
        print("Checkpoint loaded successfully, classifier initialized randomly.")
    else:
        print("Checkpoint loaded successfully, including the classifier.")

    explainer = RetinalExplainer(model)

    # Generate explanation for an image
    fig = explainer.plot_explanation('/Users/tanishroy/Desktop/School/Third year/Retinal-Disease-Detection-main/APTOS 2019 Blindness Detection Segmented/train_images/0a4e1a29ffff.png')
    plt.show()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([5]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([5, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


NameError: name 'math' is not defined