In [1]:
from PIL import Image
import torch
from torchvision.models import resnet50, ResNet50_Weights

import requests

In [2]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
from xai_inference_engine import XAIInferenceEngine

In [4]:
# Model: ImageNet-pretrained ResNet-50.
# A single, explicitly pinned weights enum member drives the network, its
# preprocessing transform and its class names, so the three can never drift apart.
weights = ResNet50_Weights.IMAGENET1K_V2
model = resnet50(weights=weights).to(device)
preprocess = weights.transforms()

# Model config
# Inference mode: layers such as BatchNorm switch to their evaluation behaviour.
model.eval()
# Layer handed to XAIInferenceEngine for the saliency maps: the final convolution
# of the last bottleneck block in layer4 (presumably a Grad-CAM-style target — confirm).
last_conv_layer = model.layer4[2].conv3
# NOTE(review): not used in the cells shown; presumably the number of top classes to report — confirm.
class_count = 5
# Human-readable ImageNet class names; index i presumably corresponds to output logit i.
class_list = weights.meta["categories"]
# Side length in pixels that the input image is resized to before preprocessing.
img_h = 224

# Image Preprocessing
# Download the sample image and cache it on disk. The source is a PNG; PIL detects
# the format from the file contents, so the ".jpg" file name is cosmetic only.
url = "https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-visualizations/master/input_images/cat_dog.png"
r = requests.get(url, allow_redirects=True, timeout=30)
# Fail loudly instead of writing an HTTP error page to disk and choking on it later.
r.raise_for_status()
with open("dog-and-cat-cover.jpg", "wb") as fh:
    fh.write(r.content)

# The model expects 3-channel input; PNGs may carry an alpha channel or a palette,
# so normalise the mode before any resizing/normalisation.
img = Image.open("dog-and-cat-cover.jpg").convert("RGB")
img = img.resize((img_h, img_h), resample=Image.BICUBIC)
# NOTE(review): `preprocess` (from `weights.transforms()`) applies its own resize and
# center crop. If its crop is smaller than its resize, `img_tensor` is not pixel-aligned
# with the img_h x img_h `img` that is also passed to the engine for overlays — confirm.
img_tensor = preprocess(img).to(device)

In [5]:
# Wrap the model in the explainability engine. `last_conv_layer` is the layer the
# engine uses to build saliency maps (assumption about its internals — confirm).
xai_inferencer = XAIInferenceEngine(
    device=device,
    model=model,
    last_conv_layer=last_conv_layer,
)

In [6]:
# One explained forward pass. The engine receives both the resized PIL image
# (`img`, presumably used for the overlay) and the model-ready tensor (`img_tensor`).
(
    preds,
    sorted_pred_indices,
    super_imp_img,
    saliency_maps,
) = xai_inferencer.predict(img_tensor=img_tensor, img=img)

: 

In [None]:
def _summarize(obj):
    """Return a compact, printable description of ``obj`` instead of its full contents.

    Objects exposing ``.shape`` (tensors, arrays) are described by their shape.
    Image-like objects exposing ``.size`` and ``.mode`` but no ``.shape`` (e.g. PIL
    images) are described by size and mode. Other sized containers (e.g. a list of
    maps) are described by type and length. Anything else is described by its type name.
    """
    shape = getattr(obj, "shape", None)
    if shape is not None:
        return tuple(shape)
    if hasattr(obj, "size") and hasattr(obj, "mode"):
        return "image size={} mode={}".format(obj.size, obj.mode)
    if hasattr(obj, "__len__"):
        return "{} of length {}".format(type(obj).__name__, len(obj))
    return type(obj).__name__


print("[INFO]: Displaying Results...")
# Print shapes/summaries rather than dumping whole tensors or images into the output.
print("Predictions shape: {}".format(preds.shape))
# First 10 entries of the sorted indices. NOTE(review): this assumes a 1-D index tensor;
# if it is batched, e.g. (1, N), this slices the batch dimension instead — confirm.
print("Sorted Prediction Indices: {}".format(sorted_pred_indices.cpu().numpy()[:10]))
print("Heatmaps shape: {}".format(_summarize(saliency_maps)))
print("Super Imposed Image: {}".format(_summarize(super_imp_img)))