In [1]:
import os
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import io
from torchvision.models import resnet18, ResNet18_Weights, resnet50, resnet101, ResNet101_Weights, ResNet50_Weights
from tqdm import tqdm  # For progress bars
import time
import os
import random
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import cv2  # For resizing and processing the Grad-CAM heatmap

In [2]:


# User-specified range of images to process
# start_index = 0  # Change as per user input
# end_index = 3   # Change as per user input 25229
# image_range = range(start_index, end_index)

# Paths
model_path = "models/ResNet50/model.pth"
dataset_path = "Streetview_Image_Dataset/processed/"
output_dir = "grad-cam"
csv_path = "models/ResNet50/test_predictions.csv"  #Test predictions csv


# Load pre-trained ResNet model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet50(pretrained=False)
num_features = model.fc.in_features

# Temporarily set the fc layer to match the checkpoint dimensions
model.fc = torch.nn.Linear(num_features, 13)  # Match the checkpoint's number of output classes
model.load_state_dict(torch.load(model_path, map_location=device))

# Replace the fc layer for the current task (4 classes)
model.fc = torch.nn.Linear(num_features, 4)
model.to(device)
model.eval()

# Data transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Select images based on the specified range

df = pd.read_csv(csv_path)
selected_images = df["image_name"].tolist()

#Code used for debugging to make just some files. 
#image_files = [f for f in os.listdir(dataset_path) if f.endswith('.png')]
#sorted_image_files = sorted(image_files, key=lambda x: int(os.path.splitext(x)[0]))  # Sort by numeric order
#selected_images = [f"{i}.png" for i in image_range if f"{i}.png" in sorted_image_files]

# Grad-CAM setup
target_layers = [model.layer4[-1]]  # Last convolutional layer
os.makedirs(output_dir, exist_ok=True)

# Process each image and generate Grad-CAM visualizations
for idx, image_file in enumerate(selected_images):
    image_path = os.path.join(dataset_path, image_file)
    image = Image.open(image_path).convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)  # Prepare input tensor

    # Specify target class (optional; defaults to predicted class if None)
    targets = None  # Alternatively, specify class index with ClassifierOutputTarget(class_index)

    # Construct Grad-CAM object and generate heatmap
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        grayscale_cam = grayscale_cam[0, :]  # Extract heatmap for the first image in the batch

        # Resize Grad-CAM heatmap to match the original image dimensions
        grayscale_cam_resized = cv2.resize(grayscale_cam, (image.width, image.height))

        # Threshold the heatmap to isolate the ROI
        threshold_value = 0.5  # Adjust threshold as needed
        _, binary_heatmap = cv2.threshold(grayscale_cam_resized, threshold_value, 1, cv2.THRESH_BINARY)

        # Convert binary heatmap to uint8 for contour detection
        binary_heatmap_uint8 = (binary_heatmap * 255).astype(np.uint8)

        # Find contours in the thresholded heatmap
        contours, _ = cv2.findContours(binary_heatmap_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Draw contours and extract bounding boxes for each detected ROI
        rois = []
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)  # Bounding box around the ROI
            rois.append((x, y, w, h))  # Store ROI coordinates

        # Overlay heatmap on the original image
        image_np = np.array(image) / 255.0  # Normalize image to range [0, 1]
        visualization = show_cam_on_image(image_np, grayscale_cam_resized, use_rgb=True)

        # Save the Grad-CAM visualization
        output_path = os.path.join(output_dir, f"gradcam_{os.path.splitext(image_file)[0]}.jpeg")
        Image.fromarray(visualization).save(output_path)
        print(f"Saved Grad-CAM for {image_file} to {output_path}")

        # Save each ROI as a cropped image with non-red parts masked as black
        roi_dir = os.path.join(output_dir, "rois")
        os.makedirs(roi_dir, exist_ok=True)
        for roi_idx, (x, y, w, h) in enumerate(rois):
            # Crop the original image to the ROI
            roi_image = image.crop((x, y, x + w, y + h))
            roi_np = np.array(roi_image)  # Convert to numpy array

            # Crop the heatmap to match the ROI dimensions
            heatmap_roi = grayscale_cam_resized[y:y + h, x:x + w]
            heatmap_binary = (heatmap_roi >= threshold_value).astype(np.uint8)

            # Create a mask with the binary heatmap
            mask = np.repeat(heatmap_binary[:, :, np.newaxis], 3, axis=2)  # Repeat for 3 channels (RGB)

            # Apply the mask to the ROI (black out non-red regions)
            masked_roi = roi_np * mask

            # Convert the masked ROI back to an image
            masked_roi_image = Image.fromarray(masked_roi.astype(np.uint8))

            # Save the masked ROI image
            roi_output_path = os.path.join(roi_dir, f"roi_{os.path.splitext(image_file)[0]}_{roi_idx}.jpeg")
            masked_roi_image.save(roi_output_path)
            print(f"Saved masked ROI for {image_file} to {roi_output_path}")

print(f"Grad-CAM visualizations and ROIs saved in {output_dir}.")

  model.load_state_dict(torch.load(model_path, map_location=device))


Saved Grad-CAM for 14015.png to grad-cam\gradcam_14015.jpeg
Saved masked ROI for 14015.png to grad-cam\rois\roi_14015_0.jpeg
Saved Grad-CAM for 25144.png to grad-cam\gradcam_25144.jpeg
Saved masked ROI for 25144.png to grad-cam\rois\roi_25144_0.jpeg
Saved Grad-CAM for 5520.png to grad-cam\gradcam_5520.jpeg
Saved masked ROI for 5520.png to grad-cam\rois\roi_5520_0.jpeg
Saved masked ROI for 5520.png to grad-cam\rois\roi_5520_1.jpeg
Saved masked ROI for 5520.png to grad-cam\rois\roi_5520_2.jpeg
Saved masked ROI for 5520.png to grad-cam\rois\roi_5520_3.jpeg
Saved masked ROI for 5520.png to grad-cam\rois\roi_5520_4.jpeg
Saved Grad-CAM for 23265.png to grad-cam\gradcam_23265.jpeg
Saved masked ROI for 23265.png to grad-cam\rois\roi_23265_0.jpeg
Saved masked ROI for 23265.png to grad-cam\rois\roi_23265_1.jpeg
Saved masked ROI for 23265.png to grad-cam\rois\roi_23265_2.jpeg
Saved Grad-CAM for 15098.png to grad-cam\gradcam_15098.jpeg
Saved masked ROI for 15098.png to grad-cam\rois\roi_15098_0.jp

In [3]:
import os

# Specify the directory path
directory_path = "grad-cam/rois"

# Count the number of files in the directory
file_count = len([f for f in os.listdir(directory_path) if os.path.isfile(os.path.join(directory_path, f))])

print(f"There are {file_count} files in the directory '{directory_path}'.")


There are 13384 files in the directory 'grad-cam/rois'.
