In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
import hshap

In [2]:
# select device
# uncomment the following line to specify which GPU you want to use
# os.environ["CUDA_VISIBLE_DEVICES"] = **YOUR_GPU_NUMBER**
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# load pretrained model
torch.manual_seed(0)
model = models.resnet18(pretrained=False)
# change last layer for binary classification
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
# load saved weights
model.load_state_dict(torch.load("model.pth", map_location=device))
# move model to device and set inference mode
model = model.to(device)
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [3]:
# compose transform
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# uncomment this code to compute the baseline
# data_dir = **YOUR_DATA_DIRECTORY_ABSOLUTE_PATH**
# train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform)
# dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True, num_workers=0)
# _iter = iter(dataloader)
# X, _ = next(_iter)
# ref = X.detach().mean(0)
# or simply load the packaged reference
ref = torch.load("reference.pth")
ref = ref.to(device)

# initialize h-Shap explainer
hexp = hshap.src.Explainer(model, ref, min_size=80)

In [4]:
# define thresholding modes
thresholds = ["absolute_0", "relative_70"]

# for each example image
for (dirpath, _, filenames) in os.walk("images"):
    for filename in filenames:
        # load and transform image
        image_path = os.path.join(dirpath, filename)
        image = Image.open(image_path)
        image_t = transform(image)
        # for each threshold mode
        for threshold in thresholds:
            # set thresholding params
            threshold_params = threshold.split("_")
            threshold_mode = threshold_params[0]
            threshold = float(threshold_params[1])
            # explain image
            print(f"Explaining image {filename} with {threshold_mode} thresholding and t={threshold}")
            t0 = time.time()
            explanation, leafs = hexp.explain(image_t, label=1, threshold_mode=threshold_mode, threshold=threshold)
            t = time.time()
            runtime = np.around(t - t0, decimals=6)
            print(f"Done in {runtime}s")

Explaining image 38d1b930-dd97-4248-9160-e9389a1a8dd7.png with absolute thresholding and t=0.0


KeyboardInterrupt: 