In [1]:
import json
import os
import urllib.request
from io import BytesIO

import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image
from torchvision import models
from torchvision import transforms

from pytorch_grad_cam import GradCAM, ScoreCAM, AblationCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget


In [2]:
# Choose the compute device in priority order: Apple MPS first, then CUDA,
# and the CPU when neither accelerator is available.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {device}")


Using device: mps


In [3]:

# Primary network: ImageNet-pretrained ResNet-50, moved to the selected device
# and switched to inference mode.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
model = model.to(device)
model.eval()

# Second, independent copy that stays on the CPU; it exists as a workaround for
# an issue with the MPS backend when running GradCAM.
model_cpu = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2).cpu()
model_cpu.eval()


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [4]:

# URLs of the sample images to explain. Each entry points at a JPEG in the
# EliSchwartz/imagenet-sample-images GitHub repository; the `?raw=true` suffix
# requests the file itself rather than the GitHub HTML page.
images_url = [
    "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n02098286_West_Highland_white_terrier.JPEG?raw=true",
    "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n02018207_American_coot.JPEG?raw=true",
    "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n04037443_racer.JPEG?raw=true",
    "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n02007558_flamingo.JPEG?raw=true",
    "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n01608432_kite.JPEG?raw=true",
    "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n01443537_goldfish.JPEG?raw=true",
    "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n01491361_tiger_shark.JPEG?raw=true",
    "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n01616318_vulture.JPEG?raw=true",
    "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n01677366_common_iguana.JPEG?raw=true",
    "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n07747607_orange.JPEG?raw=true"
]

# Path prefix for saved figures. It is concatenated directly with the file
# name when saving, so the trailing slash is significant.
output_path = "output_gradcam/"

In [5]:
def get_image_from_url(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled connection would block the notebook forever.

    Returns:
        The downloaded image, converted to "RGB" mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection failure or timeout.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on an error status instead of handing an error page to PIL,
    # which would surface as a confusing "cannot identify image file" error.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")

def load_imagenet_classes(
    url="https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
    timeout=30,
):
    """Fetch a class-name list (one name per line) and return it as strings.

    Args:
        url: Location of the UTF-8 text file to read. Defaults to the
            `imagenet_classes.txt` file in the pytorch/hub repository.
        timeout: Seconds to wait on the connection before giving up.

    Returns:
        A list with one whitespace-stripped string per line of the file, in
        file order.

    Raises:
        urllib.error.URLError: If the resource cannot be opened.
    """
    # The context manager closes the response even if decoding raises; the
    # original left the connection open.
    with urllib.request.urlopen(url, timeout=timeout) as response:
        raw_lines = response.readlines()
    return [line.decode("utf-8").strip() for line in raw_lines]

# Download the class names once; the list is indexed by predicted class id
# when results are reported.
imagenet_classes = load_imagenet_classes()

In [6]:
# Model-input pipeline: resize to 224x224, convert to a float tensor, then
# standardise each channel with the ImageNet mean/std. The pretrained
# torchvision ResNet-50 weights expect normalised inputs; feeding raw [0, 1]
# pixels shifts the input distribution away from what the network was trained on.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Run `image` through the module-level `transform` and add a batch axis.

    Returns the transformed tensor with a new leading dimension of size 1.
    """
    tensor = transform(image)
    batched = tensor.unsqueeze(0)
    return batched



In [7]:


def normalize_image_for_cam(image, target_size=(224, 224)):
    """Prepare a PIL image as the background for a CAM overlay.

    The image is resized to `target_size` with Lanczos resampling, converted
    to a float32 NumPy array and divided by 255.

    Args:
        image: PIL image to convert.
        target_size: Size passed to `image.resize`; defaults to (224, 224).

    Returns:
        float32 array holding the resized pixel values scaled by 1/255.
    """
    resized = image.resize(target_size, Image.LANCZOS)
    pixels = np.array(resized, dtype=np.float32)
    return pixels / 255.0

def apply_cam(cam, input_tensor, pred_class, original_image):
    """Run one CAM method and overlay its heatmap on the image.

    Args:
        cam: CAM object, invoked as ``cam(input_tensor=..., targets=...)``.
        input_tensor: Batched model input. Gradient tracking is switched on
            in place if it is not already enabled.
        pred_class: Class index the explanation is computed for.
        original_image: PIL image used as the background of the overlay.

    Returns:
        The overlay produced by ``show_cam_on_image`` for the first batch
        entry. This is best-effort: if any step fails, the error is printed
        and ``np.array(original_image)`` is returned instead, so a caller
        looping over many images keeps going.
    """
    try:
        targets = [ClassifierOutputTarget(pred_class)]

        if not input_tensor.requires_grad:
            input_tensor.requires_grad_(True)

        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)

        # Check if CAM was computed successfully
        if grayscale_cam is None or len(grayscale_cam) == 0:
            raise ValueError("CAM computation failed - no gradients computed")

        # Only the first batch entry is visualised.
        heatmap = grayscale_cam[0, :]

        # Size the background from the heatmap rather than assuming 224x224,
        # so the overlay still lines up when the input tensor has another
        # resolution. PIL expects (width, height); array shape is (rows, cols).
        cam_height, cam_width = heatmap.shape[:2]
        background = normalize_image_for_cam(
            original_image, target_size=(cam_width, cam_height)
        )

        return show_cam_on_image(background, heatmap, use_rgb=True)

    except Exception as e:
        print(f"Error in apply_cam: {e}")
        # Return original image if CAM fails
        return np.array(original_image)

# Initialize all CAMs
def get_cam_objects(model, target_layer):
    """Build one CAM object per supported method for a single target layer.

    Args:
        model: Network the CAM objects are attached to.
        target_layer: Layer whose activations are explained; each CAM object
            receives it wrapped in a one-element list.

    Returns:
        Dict mapping 'GradCAM', 'ScoreCAM' and 'AblationCAM' (in that order)
        to the corresponding CAM instance.
    """
    cam_classes = (
        ('GradCAM', GradCAM),
        ('ScoreCAM', ScoreCAM),
        ('AblationCAM', AblationCAM),
    )
    return {
        label: cam_cls(model=model, target_layers=[target_layer])
        for label, cam_cls in cam_classes
    }



In [8]:
# Build the CAM objects on the CPU copy of the network. The tensors handed to
# them below live on the CPU, and model_cpu was created specifically to avoid
# the MPS-backend issue; attaching the CAMs to `model` left model_cpu unused
# and paired CPU inputs with an accelerator-resident model.
target_layer = model_cpu.layer4[-1].conv3
cams = get_cam_objects(model_cpu, target_layer)

# savefig does not create missing directories, so make sure the target exists.
os.makedirs(output_path, exist_ok=True)

# NOTE(review): processing starts at index 5, presumably to resume an
# interrupted run - set to 0 to regenerate every image from a fresh kernel.
idx = 5
for i, url in enumerate(images_url[idx:], start=idx):
    print(f"Processing image {i+1}/{len(images_url)}")
    image = get_image_from_url(url)

    # Fast prediction with the model on the selected device; no gradients needed.
    input_tensor = preprocess_image(image).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        pred_class = output.argmax().item()

    print(f"Image {i}: Predicted class index: {pred_class} - {imagenet_classes[pred_class]}")

    # Separate CPU tensor with gradient tracking for the CAM computation.
    input_tensor_cpu = preprocess_image(image).to('cpu')
    input_tensor_cpu.requires_grad_()

    for name, cam in cams.items():
        print(f"  Generating {name}...")
        cam_image = apply_cam(cam, input_tensor_cpu, pred_class, image)

        # Explicit figure/axes so each plot is created, saved and released
        # without relying on pyplot's implicit "current figure".
        fig, ax = plt.subplots()
        ax.imshow(cam_image)
        ax.set_title(f'{name} - {i}. Predicted class: {imagenet_classes[pred_class]}')
        ax.axis('off')
        fig.savefig(f'{output_path}{i}_{name}.png', bbox_inches='tight', dpi=150)
        plt.close(fig)

Processing image 6/10
Image 5: Predicted class index: 1 - goldfish
  Generating GradCAM...
  Generating ScoreCAM...


100%|██████████| 128/128 [00:13<00:00,  9.61it/s]


  Generating AblationCAM...


100%|██████████| 64/64 [00:14<00:00,  4.50it/s]


Processing image 7/10
Image 6: Predicted class index: 3 - tiger shark
  Generating GradCAM...
  Generating ScoreCAM...


100%|██████████| 128/128 [00:13<00:00,  9.58it/s]


  Generating AblationCAM...


100%|██████████| 64/64 [00:13<00:00,  4.62it/s]


Processing image 8/10
Image 7: Predicted class index: 23 - vulture
  Generating GradCAM...
  Generating ScoreCAM...


100%|██████████| 128/128 [00:13<00:00,  9.62it/s]


  Generating AblationCAM...


100%|██████████| 64/64 [00:13<00:00,  4.61it/s]


Processing image 9/10
Image 8: Predicted class index: 39 - common iguana
  Generating GradCAM...
  Generating ScoreCAM...


100%|██████████| 128/128 [00:13<00:00,  9.52it/s]


  Generating AblationCAM...


100%|██████████| 64/64 [00:14<00:00,  4.54it/s]


Processing image 10/10
Image 9: Predicted class index: 950 - orange
  Generating GradCAM...
  Generating ScoreCAM...


100%|██████████| 128/128 [00:13<00:00,  9.57it/s]


  Generating AblationCAM...


100%|██████████| 64/64 [00:14<00:00,  4.37it/s]
