In [2]:
import torch
from torch.autograd import Function
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import numpy as np



# Load and preprocess an image
def load_image(image_path):
    img = Image.open(image_path).convert('RGB')
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    img_tensor = preprocess(img)
    img_tensor = img_tensor.unsqueeze(0)
    return img_tensor

# Function to perform Grad-CAM
class GradCam:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None

        # Register a hook to capture gradients during backward pass
        self.hook = self.register_hooks()

    def register_hooks(self):
        def hook_fn(module, grad_in, grad_out):
            self.gradients = grad_out[0]

        target_layer = self.model._modules.get(self.target_layer)
        hook = target_layer.register_forward_hook(hook_fn)
        return hook

    def remove_hooks(self):
        self.hook.remove()

    def forward(self, x):
        return self.model(x)

    def backward(self, output):
        self.model.zero_grad()
        output.backward()

    def generate_heatmap(self, input_tensor, class_idx):
        self.model.zero_grad()

        # Perform forward and backward pass
        output = self.forward(input_tensor)
        target = output[0][class_idx]
        self.backward(target)

        # Calculate the importance weights (gradients)
        weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)

        # Get the activations from the target layer
        target_layer_output = self.hook.output[0]

        # Weighted sum of activations to get the Grad-CAM heatmap
        grad_cam = torch.sum(weights * target_layer_output, dim=1, keepdim=True)
        grad_cam = F.relu(grad_cam)

        # Resize the heatmap to match the input image size
        grad_cam = F.interpolate(grad_cam, size=(input_tensor.shape[2], input_tensor.shape[3]), mode='bilinear', align_corners=False)

    




In [3]:
# Load a pre-trained ResNet model
model = models.resnet18(pretrained=False)
model.eval()

# get the image from the dataloader
img, _ = load_image()

# get the most likely prediction of the model
pred = model(img).argmax(dim=1)

NameError: name 'dataloader' is not defined