Dataset

In [1]:
import os
from PIL import Image
import json
from torchvision.models import resnet50
import warnings
warnings.filterwarnings('ignore')
from codecarbon import track_emissions
from torchvision import transforms
from pytorch_grad_cam import (
    GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
    AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
    LayerCAM, FullGrad, GradCAMElementWise
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from typing import List, Callable, Optional
import numpy as np
import cv2
import torch
from tqdm import tqdm

# Function to load images
def load_images_from_directory(root_path: str) -> list:
    """Collect every image found in a class-per-folder directory tree.

    Each immediate sub-directory of ``root_path`` is treated as one label;
    files directly inside it whose name ends in .png/.jpg/.jpeg
    (case-insensitive) are opened with PIL. Non-directory entries at the top
    level and files with other extensions are skipped.

    Args:
        root_path: Directory whose sub-directories are named after labels.

    Returns:
        A list of ``(image, label, file_name)`` tuples, ordered by label and
        then by file name so that the result is identical on every run.

    Raises:
        OSError: If ``root_path`` cannot be listed or an image cannot be opened.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')
    dataset = []
    # os.listdir() returns entries in arbitrary order; sorting keeps batch
    # composition reproducible across runs and machines.
    for label in sorted(os.listdir(root_path)):
        label_path = os.path.join(root_path, label)
        if not os.path.isdir(label_path):
            continue
        for image_file in sorted(os.listdir(label_path)):
            if not image_file.lower().endswith(image_extensions):
                continue
            image_path = os.path.join(label_path, image_file)
            # NOTE: Image.open() is lazy - pixel data is only decoded on first
            # use, and the underlying file stays open until then. Kept lazy on
            # purpose so the whole dataset is not decoded into memory up front.
            img = Image.open(image_path)
            dataset.append((img, label, image_file))
    return dataset

# Project root. The default is the original workstation path; set the
# XAI_PROJECT_ROOT environment variable to run the notebook elsewhere
# without editing this cell.
current_dir = os.environ.get("XAI_PROJECT_ROOT", "/home/workstation/code/XAImethods/hf_cam_dev")
dataset_path = f"{current_dir}/ImageNet-Mini/images"
dataset = load_images_from_directory(dataset_path)

# Load ImageNet class index (explicit encoding so the result does not depend
# on the platform's default locale)
with open(f"{current_dir}/ImageNet-Mini/imagenet_class_index.json", "r", encoding="utf-8") as f:
    imagenet_class_index = json.load(f)

# Pick the compute device: use the GPU when CUDA is available, else the CPU,
# and report the choice.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print({"cuda": "Using CUDA!", "cpu": "Using CPU!"}[device.type])

# ResNet Model Wrapper
class ResNetWrapper(torch.nn.Module):
    """Pass-through ``nn.Module`` that holds a model and delegates to it.

    The wrapped network is stored as the ``model`` attribute; ``forward``
    returns whatever the wrapped network returns for the same input.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        output = self.model(x)
        return output
    
# Invert the class index: first element of each entry (matched against the
# dataset folder names) -> (class index string, second element / description).
label_to_index_description = dict(
    (entry[0], (class_idx, entry[1]))
    for class_idx, entry in imagenet_class_index.items()
)


Using CUDA!


Functions

In [2]:
# Initialize the model and target layer.
# nn.Module instances start in training mode; eval() makes BatchNorm use its
# stored running statistics so single-image inference through `model` or
# `model_wrapper` is deterministic and matches the pretrained behaviour.
model = resnet50(pretrained=True).to(device).eval()
model_wrapper = ResNetWrapper(model).to(device).eval()
# Layer whose activations/gradients are used to build the CAM
target_layer_gradcam = model.layer4[-1].conv3

# Helper functions
def run_grad_cam_on_image(model: torch.nn.Module,
                          target_layer: torch.nn.Module,
                          targets_for_gradcam: List[Callable],
                          input_tensor: torch.Tensor,
                          input_image: Image.Image,
                          reshape_transform: Optional[Callable] = None,
                          method: Callable = GradCAM):
    """Run a CAM method for one image and render one overlay per target.

    The single ``input_tensor`` is repeated once per target so that all
    targets are evaluated in one batched CAM call.

    Args:
        model: Network to explain.
        target_layer: Layer handed to the CAM method as its only target layer.
        targets_for_gradcam: One CAM target per explanation to produce.
        input_tensor: Preprocessed image tensor without a batch dimension
            (assumed (C, H, W) - a batch axis is prepended here).
        input_image: Image used as the background of the overlay. It is
            converted to float and divided by 255, so 8-bit pixel values are
            expected. If its height/width differ from the CAM's, it is resized
            to the CAM size before blending.
        reshape_transform: Optional transform forwarded to the CAM method.
        method: CAM class to instantiate (used as a context manager).

    Returns:
        A tuple ``(visualization, grayscale_cams)`` where ``visualization`` is
        all overlays, each shrunk to half size, stacked horizontally, and
        ``grayscale_cams`` is the list of raw per-target CAM maps.
    """
    with method(model=model,
                target_layers=[target_layer],
                reshape_transform=reshape_transform) as cam:
        repeated_tensor = input_tensor[None, :].repeat(len(targets_for_gradcam), 1, 1, 1)
        batch_results = cam(input_tensor=repeated_tensor,
                            targets=targets_for_gradcam)

        # Converted once; every target is blended over the same background.
        base_image = np.float32(input_image) / 255
        results = []
        grayscale_cams = []
        for grayscale_cam in batch_results:
            cam_h, cam_w = grayscale_cam.shape[:2]
            if base_image.shape[:2] != (cam_h, cam_w):
                # The overlay is an element-wise blend, so the background must
                # have the CAM's spatial size. cv2.resize takes (width, height).
                base_image = cv2.resize(base_image, (cam_w, cam_h))
            visualization = show_cam_on_image(base_image,
                                              grayscale_cam,
                                              use_rgb=True)
            # Halve both dimensions to keep the stacked strip compact.
            visualization = cv2.resize(visualization,
                                       (visualization.shape[1] // 2, visualization.shape[0] // 2))
            results.append(visualization)
            grayscale_cams.append(grayscale_cam)
        return np.hstack(results), grayscale_cams

def print_top_categories(model, img_tensor, top_k=5):
    """Print the ``top_k`` highest-scoring classes for a single image.

    Args:
        model: Callable returning per-class scores for a batched input.
        img_tensor: Image tensor without a batch dimension; one is added here.
        top_k: Number of classes to print, best first. Values <= 0 print
            nothing.

    Class names are looked up in the module-level ``imagenet_class_index``
    (second element of the entry keyed by the class index as a string).
    """
    if top_k <= 0:
        # A slice of [-0:] would select *every* class, so bail out explicitly.
        return
    # Only the scores are needed, so skip building the autograd graph.
    with torch.no_grad():
        logits = model(img_tensor.unsqueeze(0))
    scores = logits.cpu()[0, :].detach().numpy()
    # Ascending argsort reversed -> best class first; then keep the first top_k.
    indices = scores.argsort()[::-1][:top_k]
    for i in indices:
        print(f"Predicted class {i}: {imagenet_class_index[str(i)][1]}")

def get_top_k_targets(model, input_tensor, k=5):
    """Build CAM targets for the ``k`` classes the model scores highest.

    A batch dimension is added to ``input_tensor`` before the forward pass;
    the class indices are taken in descending score order and each is wrapped
    in a ``ClassifierOutputTarget``.
    """
    batch = input_tensor.unsqueeze(0)
    scores = model(batch)[0]
    ranked = scores.argsort(descending=True)
    best_indices = ranked[:k].cpu().numpy()
    return list(map(ClassifierOutputTarget, best_indices))

# Prepare for the main loop
BATCH_SIZE = 100
# Ceiling division: a final partial batch still counts as one batch.
num_batches = (len(dataset) + BATCH_SIZE - 1) // BATCH_SIZE
# Results are grouped per model and per CAM method name.
save_dir = f"{current_dir}/results/resnet50/{GradCAM.__name__}"

# exist_ok=True replaces the check-then-create pattern, which can fail if the
# directory appears between the check and the creation.
os.makedirs(save_dir, exist_ok=True)


In [3]:
def ensure_rgb(img):
    """Return ``img`` unchanged if its mode is 'RGB', else an RGB conversion of it."""
    return img if img.mode == 'RGB' else img.convert('RGB')

# Per-channel statistics that torchvision's pretrained ImageNet weights expect
# inputs to be normalised with.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# The model, target layer and preprocessing do not change between batches, so
# build them once instead of reloading the weights on every batch.
model = resnet50(pretrained=True).to(device).eval()
target_layer_gradcam = model.layer4[-1].conv3  # Last convolutional layer of ResNet50
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# For every image: compute the Grad-CAM for its true class, record the top-5
# predictions, and write everything into one directory per image.
for batch_num in tqdm(range(num_batches)):
    start_idx = batch_num * BATCH_SIZE
    end_idx = min((batch_num + 1) * BATCH_SIZE, len(dataset))

    for idx in range(start_idx, end_idx):
        img, label, filename = dataset[idx]
        try:
            torch.cuda.empty_cache()
            img = ensure_rgb(img)
            img_tensor = transform(img).to(device)

            # Map label to ImageNet index
            index_description = label_to_index_description.get(label)
            if index_description is None:
                print(f"Warning: Label '{label}' not found in the JSON file!")
                continue

            index_str, description = index_description
            true_class_index = int(index_str)
            # Explain the ground-truth class, not the predicted one.
            dynamic_targets_for_gradcam = [ClassifierOutputTarget(true_class_index)]

            gradcam_result, grayscale_cams = run_grad_cam_on_image(
                model=model,
                target_layer=target_layer_gradcam,
                targets_for_gradcam=dynamic_targets_for_gradcam,
                input_tensor=img_tensor,
                input_image=img,
                reshape_transform=None  # No reshape required for ResNet50
            )

            # Only the scores are needed here, so skip the autograd graph.
            with torch.no_grad():
                logits = model(img_tensor.unsqueeze(0))
            top_indices = logits[0].argsort(descending=True)[:5].cpu().tolist()
            predictions = {
                class_idx: {
                    "score": logits[0][class_idx].item(),
                    "label": imagenet_class_index[str(class_idx)][1],
                }
                for class_idx in top_indices
            }

            # One output directory per image, named after the file stem.
            img_dir = os.path.join(save_dir, filename.rsplit('.', 1)[0])
            os.makedirs(img_dir, exist_ok=True)

            true_label_file = os.path.join(img_dir, 'true_label.txt')
            with open(true_label_file, 'w') as f:
                f.write(str(label))

            img_name = os.path.join(img_dir, "original.jpg")
            gradcam_name = os.path.join(img_dir, "gradcam.jpg")
            grayscale_name = os.path.join(img_dir, "grayscale.jpg")
            grayscale_npy_name = os.path.join(img_dir, "grayscale.npy")
            scores_name = os.path.join(img_dir, "scores.npy")
            info_name = os.path.join(img_dir, "info.txt")

            img.save(img_name)
            Image.fromarray(gradcam_result).save(gradcam_name)
            # The CAM is scaled by 255 for an 8-bit preview; the .npy keeps
            # the unscaled values.
            Image.fromarray((grayscale_cams[0] * 255).astype(np.uint8)).save(grayscale_name)
            np.save(grayscale_npy_name, grayscale_cams[0])

            scores = [entry["score"] for entry in predictions.values()]
            np.save(scores_name, scores)

            with open(info_name, 'w') as f:
                for class_idx, entry in predictions.items():
                    # Distinct names so the true `label` is not overwritten.
                    predicted_label = entry["label"]
                    score = entry["score"]
                    f.write(f"Class {class_idx} ({predicted_label}): {score:.2f}\n")
        except Exception as e:
            # Best effort: report the failing file and continue with the rest.
            print(f"Error processing {filename}: {str(e)}")

print("Grad-CAM processing completed.")


100%|██████████| 39/39 [02:54<00:00,  4.48s/it]

Grad-CAM processing completed.



