In [None]:
import os
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import io
from torchvision.models import resnet18, ResNet18_Weights, resnet50, resnet101, ResNet101_Weights, ResNet50_Weights
from tqdm import tqdm  # For progress bars
import time
import os
import random
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet18
import cv2  # For resizing and processing the Grad-CAM heatmap

In [None]:
# Paths — set the GEOGUESSER_DIR environment variable to relocate the project
# instead of editing the hard-coded default below.
project_dir = os.environ.get("GEOGUESSER_DIR", "/Users/paolocadei/Desktop/GeolocationGuesserAI")
model_path = os.path.join(project_dir, "model.pth")
dataset_path = os.path.join(project_dir, "Streetview_Image_Dataset")
output_dir = os.path.join(project_dir, "grad-cam")

# Tunables
NUM_SAMPLES = 100      # number of images to explain (capped at the number of images found)
RANDOM_SEED = 42       # makes the image sample reproducible across runs
CAM_THRESHOLD = 0.5    # Grad-CAM activation above which a pixel is considered part of an ROI
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")  # everything else (e.g. .DS_Store) is ignored

# Load the trained ResNet-18 checkpoint.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = torch.load(model_path, map_location=device)

model = resnet18(weights=None)  # architecture only; all weights come from the checkpoint
num_features = model.fc.in_features
# Size the classification head from the checkpoint so the *trained* head is kept.
# Swapping in a new nn.Linear after loading would leave a randomly initialised head,
# and Grad-CAM would then explain an untrained classifier.
num_classes = state_dict["fc.weight"].shape[0]
model.fc = torch.nn.Linear(num_features, num_classes)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Data transformations (ImageNet normalisation statistics)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Reproducibly sample up to NUM_SAMPLES image files. Sorting first removes the
# dependence on os.listdir's arbitrary order, so the seed fully determines the sample.
image_files = sorted(
    f for f in os.listdir(dataset_path)
    if os.path.isfile(os.path.join(dataset_path, f)) and f.lower().endswith(IMAGE_EXTENSIONS)
)
random_images = random.Random(RANDOM_SEED).sample(image_files, min(NUM_SAMPLES, len(image_files)))

# Grad-CAM setup
target_layers = [model.layer4[-1]]  # last residual block of ResNet-18
roi_dir = os.path.join(output_dir, "rois")
os.makedirs(roi_dir, exist_ok=True)  # also creates output_dir

# Build the Grad-CAM object once for all images; its hooks are released when the `with` exits.
with GradCAM(model=model, target_layers=target_layers) as cam:
    for idx, image_file in enumerate(random_images):
        image_path = os.path.join(dataset_path, image_file)
        try:
            with Image.open(image_path) as img:  # closes the file handle after decoding
                image = img.convert("RGB")
        except OSError as exc:  # corrupt/unreadable image (PIL's UnidentifiedImageError is an OSError)
            print(f"Skipping {image_file}: {exc}")
            continue
        input_tensor = transform(image).unsqueeze(0).to(device)  # batch of one

        # targets=None explains the predicted class; pass [ClassifierOutputTarget(k)] to explain class k.
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]

        # The CAM is at network-input resolution; resize it so ROI boxes are in original-image pixels.
        grayscale_cam_resized = cv2.resize(grayscale_cam, (image.width, image.height))

        # Threshold the CAM and find the external contours of the high-activation regions.
        _, binary_heatmap = cv2.threshold(grayscale_cam_resized, CAM_THRESHOLD, 1, cv2.THRESH_BINARY)
        binary_heatmap_uint8 = (binary_heatmap * 255).astype(np.uint8)
        contours, _ = cv2.findContours(binary_heatmap_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        rois = [cv2.boundingRect(contour) for contour in contours]  # (x, y, w, h) per region

        # Overlay the untouched CAM on the image, then draw the ROI boxes on the RGB overlay.
        # (Drawing into the float CAM itself would write 255 into a [0, 1] map and corrupt the colormap.)
        image_np = np.asarray(image, dtype=np.float32) / 255.0  # show_cam_on_image expects [0, 1]
        visualization = show_cam_on_image(image_np, grayscale_cam_resized, use_rgb=True)
        for x, y, w, h in rois:
            cv2.rectangle(visualization, (x, y), (x + w, y + h), (255, 0, 0), 2)  # red, since use_rgb=True

        # Save the Grad-CAM visualization
        output_path = os.path.join(output_dir, f"gradcam_{idx}.jpg")
        Image.fromarray(visualization).save(output_path)
        print(f"Saved Grad-CAM for {image_file} to {output_path}")

        # Save each ROI as a crop of the original (unnormalised) image.
        for roi_idx, (x, y, w, h) in enumerate(rois):
            roi_output_path = os.path.join(roi_dir, f"roi_{idx}_{roi_idx}.jpg")
            image.crop((x, y, x + w, y + h)).save(roi_output_path)
            print(f"Saved ROI for {image_file} to {roi_output_path}")

print(f"Grad-CAM visualizations and ROIs saved in {output_dir}.")