# In [None]:
import torch
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import random
import os
import cv2

# Grad-CAM implementation
class GradCAM:
    """Compute Grad-CAM heatmaps for a single layer of a PyTorch model.

    Construction registers two hooks on ``target_layer``:

    * a forward hook that stores the layer's output (``self.activations``);
    * a backward hook that stores the gradient w.r.t. that output
      (``self.gradients``).

    Typical use: run a forward pass, call ``backward()`` on the scalar score of
    interest, then call :meth:`generate_heatmap`. Call :meth:`remove_hooks`
    when finished so the model is left unmodified.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None  # grad w.r.t. target_layer's output; set by the backward hook
        self.activations = None  # target_layer's output; set by the forward hook
        self.hook_handles = []  # handles returned by register_*_hook, used by remove_hooks()

        self.hook_handles.append(target_layer.register_forward_hook(self.save_activations))
        # register_backward_hook is deprecated: its grad_output is only reliable
        # for modules whose forward is a single autograd op. The "full" variant
        # always reports gradients w.r.t. the module's actual outputs.
        self.hook_handles.append(target_layer.register_full_backward_hook(self.save_gradients))

    def save_activations(self, module, input, output):
        """Forward hook: remember the target layer's output."""
        self.activations = output

    def save_gradients(self, module, grad_input, grad_output):
        """Backward hook: remember the gradient w.r.t. the target layer's (first) output."""
        self.gradients = grad_output[0]

    def generate_heatmap(self, class_idx):
        """Build a Grad-CAM heatmap from the most recent forward/backward pass.

        Args:
            class_idx: Index of the class whose score was backpropagated. It is
                not used in the computation, because the heatmap reflects whatever
                scalar the caller called ``backward()`` on. It is kept for API
                compatibility.

        Returns:
            A non-negative NumPy array scaled so its maximum is 1. It is all zeros
            when no location has positive evidence. The gradients/activations are
            indexed as 4-D (N, C, H, W); ``squeeze()`` drops singleton dims, so a
            batch of one yields shape (H, W).

        Raises:
            RuntimeError: If no forward pass and ``backward()`` have been recorded.
        """
        if self.activations is None or self.gradients is None:
            raise RuntimeError("run a forward pass and backward() before generate_heatmap()")
        with torch.no_grad():
            # Per-channel weights: gradients averaged over the spatial dims.
            weights = torch.mean(self.gradients, dim=(2, 3))
            heatmap = torch.sum(weights[:, :, None, None] * self.activations, dim=1).squeeze()
        # ReLU: keep only locations that increase the target score.
        heatmap = np.maximum(heatmap.detach().cpu().numpy(), 0)
        peak = heatmap.max()
        # Skip the division when everything was clipped to 0; otherwise 0/0 -> NaN.
        if peak > 0:
            heatmap /= peak
        return heatmap

    def remove_hooks(self):
        """Detach all hooks this instance registered on the target layer."""
        for handle in self.hook_handles:
            handle.remove()

# Build a ResNet-50 with a 2-way classification head. The architecture starts
# with random weights; the fine-tuned weights come from the checkpoint below.
model = models.resnet50(pretrained=False)
num_classes = 2
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load("bee_wasp.pth", map_location=torch.device('cpu'))
# torch.load only reads the tensors; they must be copied into the model
# explicitly, otherwise inference runs on random weights.
# NOTE(review): assumes the file is a plain state_dict whose keys match this
# architecture (strict loading raises on any mismatch) — confirm how it was saved.
model.load_state_dict(state_dict)
model.eval()

# Input pipeline for the classifier:
# 1. resize to a fixed 224x224;
# 2. convert the PIL image to a float tensor;
# 3. normalise each channel with the standard ImageNet mean/std values.
# NOTE(review): this presumably matches the preprocessing used when
# bee_wasp.pth was trained — confirm.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Randomly pick up to 5 images from the bee folder to visualise.
folder = "kaggle_bee_vs_wasp/bee1"
# Only consider common image extensions so stray files (e.g. .DS_Store) never
# reach PIL. Sort first so that random.seed() makes the selection reproducible.
image_files = sorted(
    name for name in os.listdir(folder)
    if name.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tif", ".tiff", ".webp"))
)
# random.sample raises ValueError when asked for more items than exist.
images = random.sample(image_files, min(5, len(image_files)))

# Grad-CAM target: the last conv of the final bottleneck block in layer4.
target_layer = model.layer4[2].conv3
grad_cam = GradCAM(model, target_layer)

try:
    for img_name in images:
        img_path = os.path.join(folder, img_name)
        # The context manager closes the file handle Image.open keeps open for lazy loading.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        input_tensor = preprocess(img).unsqueeze(0)

        output = model(input_tensor)
        class_idx = output.argmax(dim=1).item()
        # NOTE(review): assumes label 0 = bee and 1 = wasp in the training label
        # order — confirm against how bee_wasp.pth was trained.
        class_name = "bee" if class_idx == 0 else "wasp"

        # Backpropagate the predicted class score so the hooks capture its gradients.
        model.zero_grad()
        target = output[0, class_idx]
        target.backward()

        heatmap = grad_cam.generate_heatmap(class_idx)
        # Replace any NaNs (e.g. from a 0/0 normalisation) before resizing and casting to uint8.
        heatmap = np.nan_to_num(heatmap)
        heatmap_resized = cv2.resize(heatmap, (224, 224))

        # Rescale to [0, 1]. Skip the division for a constant map to avoid 0/0.
        heatmap_normalized = heatmap_resized - heatmap_resized.min()
        peak = heatmap_normalized.max()
        if peak > 0:
            heatmap_normalized /= peak

        heatmap_colored = cv2.applyColorMap((heatmap_normalized * 255).astype(np.uint8), cv2.COLORMAP_JET)
        # OpenCV colormaps are BGR; matplotlib expects RGB.
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

        img_np = np.array(img.resize((224, 224)))
        # Blend 60% image with 40% coloured heatmap.
        overlay = (0.6 * img_np / 255.0 + 0.4 * heatmap_colored / 255.0)
        overlay = (overlay * 255).astype(np.uint8)

        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.title(f"Prediction: {class_name}")
        plt.imshow(img_np)
        plt.axis("off")

        plt.subplot(1, 2, 2)
        plt.imshow(overlay)
        plt.axis("off")
        plt.show()
finally:
    # Always detach the hooks, even if loading or plotting an image fails.
    grad_cam.remove_hooks()