In [2]:
import os
import torch
from torchvision.datasets import ImageFolder
from torchvision import transforms, models
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import torch.nn.functional as F

DATA_DIR = '/root/Aerial_Landscapes'
MODEL_PATH = './efficientnet_b0.pth'
SAVE_DIR = './error_gradcam_outputs'
BATCH_SIZE = 32
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 15
TOP_N_ERRORS = 20

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

dataset = ImageFolder(DATA_DIR, transform=transform)
class_names = dataset.classes
val_size = int(0.2 * len(dataset))
_, val_dataset = torch.utils.data.random_split(dataset, [len(dataset) - val_size, val_size])
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = models.efficientnet_b0(weights=None)
model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_CLASSES)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()

gradients = []

def save_gradient(module, grad_in, grad_out):
    gradients.append(grad_out[0])

final_conv = model.features[-1]
final_conv.register_full_backward_hook(save_gradient)

def generate_gradcam(image_tensor, class_idx):
    model.zero_grad()
    output = model(image_tensor.unsqueeze(0).to(DEVICE))
    loss = output[0, class_idx]
    loss.backward()

    grad = gradients[-1].cpu().data.numpy()[0]
    fmap = activations[-1].cpu().data.numpy()[0]

    weights = np.mean(grad, axis=(1, 2))
    cam = np.zeros(fmap.shape[1:], dtype=np.float32)

    for i, w in enumerate(weights):
        cam += w * fmap[i]

    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, (224, 224))
    cam = cam - np.min(cam)
    cam = cam / (np.max(cam) + 1e-8)
    return cam

def superimpose_cam(img_tensor, cam):
    img = img_tensor.permute(1, 2, 0).cpu().numpy()
    img = (img - img.min()) / (img.max() - img.min())
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = heatmap / 255.0
    superimposed = heatmap * 0.4 + img
    return superimposed

os.makedirs(SAVE_DIR, exist_ok=True)

error_images = []
predictions = []
true_labels = []

activations = []
def forward_hook(module, input, output):
    activations.append(output)

final_conv.register_forward_hook(forward_hook)

with torch.no_grad():
    for images, labels in val_loader:
        outputs = model(images.to(DEVICE))
        _, preds = outputs.max(1)
        for i in range(len(labels)):
            if preds[i].item() != labels[i].item():
                error_images.append(images[i])
                predictions.append(preds[i].item())
                true_labels.append(labels[i].item())

for idx in range(min(TOP_N_ERRORS, len(error_images))):
    img = error_images[idx]
    pred_class = predictions[idx]
    true_class = true_labels[idx]

    gradients.clear()
    activations.clear()
    _ = model(img.unsqueeze(0).to(DEVICE))
    cam = generate_gradcam(img.to(DEVICE), pred_class)
    vis = superimpose_cam(img, cam)
    vis = np.clip(vis, 0, 1)
    out_path = os.path.join(SAVE_DIR, f"error_{idx}_pred-{class_names[pred_class]}_true-{class_names[true_class]}.png")
    plt.imsave(out_path, vis)

print(f"Done. Error samples with Grad-CAM saved to {SAVE_DIR}")

  model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))


Done. Error samples with Grad-CAM saved to ./error_gradcam_outputs
