In [2]:
# Make the project root importable so that project-local packages (e.g. `src`) resolve.

import sys
from pathlib import Path

# Project root is one level up from notebooks/. When `__file__` is defined it is
# derived from this file's location; otherwise the parent of the current working
# directory is used.
project_root = Path(__file__).resolve().parent.parent if "__file__" in globals() else Path.cwd().parent

# Guard the append so re-running this cell does not pile up duplicate sys.path entries.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

In [3]:
# Import necessary libraries

from src.gradcam import GradCAM
from torchvision import models, transforms
from PIL import Image
import torch
import matplotlib.pyplot as plt
from pathlib import Path

In [None]:
# Load an ImageNet-pretrained ResNet-50 and hook Grad-CAM onto its `layer4` block.

model = models.resnet50(weights="IMAGENET1K_V1")
# Switch to inference mode: a freshly constructed module starts in training mode,
# where BatchNorm uses per-batch statistics and updates its running averages.
# With single-image batches that would distort the predictions being explained.
model.eval()
gradcam = GradCAM(model, target_layer_name="layer4")

In [5]:
# Pick a few sample images and build the preprocessing pipeline

# NOTE(review): this path is resolved against the current working directory, while
# `project_root` is taken to be one level above it — confirm the data really lives here.
image_dir = Path("data/rsna_subset/test_images")

# Sort before slicing: glob order depends on the filesystem, so an unsorted
# `[:3]` could select different samples from one run or machine to the next.
sample_paths = sorted(image_dir.glob("*.png"))[:3]
if not sample_paths:
    # Fail loudly instead of letting the downstream loop silently do nothing.
    raise FileNotFoundError(f"No PNG images found in {image_dir.resolve()}")

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet-pretrained torchvision weights expect inputs normalized with these
    # channel statistics. Assumes GradCAM.generate does not normalize internally — TODO confirm.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create the output directory once up front so saving the overlays cannot fail
# (or silently write nothing) because the folder is missing.
output_dir = Path("reports/week2_gradcam_samples")
output_dir.mkdir(parents=True, exist_ok=True)

for img_path in sample_paths:
    # Load the image as RGB; the context manager releases the file handle promptly.
    with Image.open(img_path) as source_image:
        img = source_image.convert("RGB")
    # Add a leading batch dimension of size 1 for the model.
    tensor = transform(img).unsqueeze(0)

    # Generate the Grad-CAM heatmap for this image
    heatmap = gradcam.generate(tensor)

    # Overlay the heatmap on the original image and save it
    output_path = output_dir / f"{img_path.stem}_overlay.png"
    gradcam.save_overlay(img, heatmap, str(output_path))

    # Display the saved overlay
    with Image.open(output_path) as overlay:
        plt.imshow(overlay)
    plt.title(img_path.name)
    plt.axis("off")
    plt.show()