# Notebook cell In [3]:
import torch
import torch.nn.functional as F
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os

# --- Configuration ----------------------------------------------------------
# Dataset root handed to ImageFolder further down in this script.
DATA_DIR = '/root/Aerial_Landscapes'
# Checkpoint (a state dict) that is loaded into the model below.
MODEL_PATH = './efficientnet_b0.pth'
# Folder that receives the Grad-CAM overlays; created if it does not exist.
SAVE_DIR = './gradcam_outputs'
os.makedirs(SAVE_DIR, exist_ok=True)
# Output width of the replacement classifier layer; must match the checkpoint.
NUM_CLASSES = 15
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Model ------------------------------------------------------------------
# weights=None: load_state_dict (strict by default) overwrites every parameter
# and buffer, so downloading the ImageNet weights first would be wasted work.
model = models.efficientnet_b0(weights=None)
model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_CLASSES)
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs, and avoids executing arbitrary pickled code.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
)
model = model.to(DEVICE)
# Inference mode for dropout / batch-norm layers.
model.eval()

# Most recent tensors captured at the Grad-CAM target layer. They are written
# by the two hooks below and read by generate_gradcam.
gradients = None
activations = None


def save_gradients_hook(module, grad_input, grad_output):
    """Backward hook: remember the gradient w.r.t. the hooked layer's output.

    Returns None so the gradients flowing through the layer are left unchanged.
    """
    global gradients
    gradients = grad_output[0].detach()


def save_activations_hook(module, input, output):
    """Forward hook: remember the hooked layer's output feature maps.

    The tensor is detached so this global does not keep the autograd graph of
    the forward pass alive; it still shares storage with the real output.
    """
    global activations
    activations = output.detach()


# Last block of the feature extractor is used as the Grad-CAM target layer.
target_layer = model.features[-1]
target_layer.register_forward_hook(save_activations_hook)
# register_backward_hook is deprecated for modules made of several autograd
# operations; the "full" variant is the supported replacement.
target_layer.register_full_backward_hook(save_gradients_hook)

# Preprocessing pipeline passed to ImageFolder below: resize every image to
# 224x224, then convert it to a tensor. No mean/std normalisation is applied.
transform = transforms.Compose(
    [transforms.Resize(size=(224, 224)), transforms.ToTensor()]
)

def generate_gradcam(image_tensor, class_idx):
    """Compute a Grad-CAM heatmap for one image and one class.

    Runs a forward and backward pass through the module-level ``model`` and
    combines the activations and gradients that the registered hooks store in
    the module-level ``activations`` / ``gradients`` globals.

    Args:
        image_tensor: A single un-batched image tensor; a batch dimension is
            added here and the tensor is moved to ``DEVICE``.
        class_idx: Index of the output logit to explain.

    Returns:
        A 2-D float NumPy array with the spatial size of the hooked layer's
        output. Values are non-negative and scaled so the maximum is 1; when
        no location contributes positively the map is all zeros.

    Raises:
        RuntimeError: If the hooks did not record activations and gradients
            during this call.
    """
    global gradients, activations

    # Clear the previous call's captures so stale tensors are never reused.
    gradients = None
    activations = None

    batch = image_tensor.unsqueeze(0).to(DEVICE)
    model.zero_grad()
    # Gradients are required even if the caller is inside torch.no_grad().
    with torch.enable_grad():
        logits = model(batch)
        logits[0, class_idx].backward()

    if gradients is None or activations is None:
        raise RuntimeError(
            "Grad-CAM hooks did not fire; register the forward and backward "
            "hooks on the target layer before calling generate_gradcam."
        )

    # One weight per channel: the gradient averaged over batch and space.
    channel_weights = gradients.detach().mean(dim=(0, 2, 3))
    feature_maps = activations.detach()[0]
    # Broadcasting creates a new tensor, so the captured activation tensor is
    # not modified in place.
    weighted_maps = feature_maps * channel_weights[:, None, None]

    heatmap = weighted_maps.mean(dim=0).cpu().numpy()
    heatmap = np.maximum(heatmap, 0)

    # Avoid 0/0 -> NaN when nothing survives the ReLU.
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak
    return heatmap

# Build the dataset and visit images one at a time in random order.
dataset = ImageFolder(DATA_DIR, transform=transform)
class_names = dataset.classes
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

# Save Grad-CAM overlays for the first 15 randomly drawn images.
for count, (img, label) in enumerate(loader):
    if count >= 15:
        break
    img = img.to(DEVICE)

    # Only the predicted class is needed here, so skip gradient tracking;
    # generate_gradcam runs its own forward/backward pass.
    with torch.no_grad():
        pred_class = model(img).argmax(dim=1).item()
    heatmap = generate_gradcam(img[0], pred_class)
    # An all-zero map may come back as NaN; treat that as "no activation"
    # before the uint8 cast below.
    heatmap = np.nan_to_num(heatmap, nan=0.0)

    # Convert to numpy (H, W, C) and stretch to [0, 1] for display.
    img_np = img[0].permute(1, 2, 0).cpu().numpy()
    value_range = img_np.max() - img_np.min()
    if value_range > 0:
        # A constant image would divide by zero, so it is left as loaded.
        img_np = (img_np - img_np.min()) / value_range

    # cv2.resize expects (width, height); match the image's spatial size.
    height, width = img_np.shape[:2]
    heatmap_resized = cv2.resize(heatmap, (width, height))
    heatmap_bgr = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
    # applyColorMap returns BGR, while img_np and plt.imsave use RGB. Without
    # this conversion the hottest regions would be drawn in blue.
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
    superimposed_img = np.clip(heatmap_rgb / 255 * 0.4 + img_np, 0, 1)

    # Save image
    out_path = os.path.join(SAVE_DIR, f"gradcam_{count}_pred-{class_names[pred_class]}_true-{class_names[label.item()]}.png")
    plt.imsave(out_path, superimposed_img)

print("✅ Grad-CAM visualization complete. Saved to:", SAVE_DIR)

# Captured notebook output:
#   model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
#
#
# ✅ Grad-CAM visualization complete. Saved to: ./gradcam_outputs
