In [1]:
# NUMPY IMPORTS
import numpy as np
import numpy.ma as ma

# PYTORCH IMPORTS
import torch
import torch.nn as nn
from torchvision import models, datasets, transforms

# SYS IMPORTS
import sys
import re
import os
import time
from PIL import Image
import hshap
from hshap.utils import Net
import matplotlib.pyplot as plt

os.environ["CUDA_VISIBLE_DEVICES"] = "9"

# DEFINE DEVICE
_device = "cuda:0"
device = torch.device(_device)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.cuda.empty_cache()
print("Current device is {}".format(device))

# LOAD PRE-TRAINED INCEPTION-V3 MODEL
torch.manual_seed(0)
model = Net()
weight_path = "model2.pth"
model.load_state_dict(torch.load(weight_path, map_location=device)) 
model.to(device)
model.eval()
print("Loaded pretrained model")

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # transforms.Normalize([0.7206, 0.7204, 0.7651], [0.2305, 0.2384, 0.1706])
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

data_dir = "/export/gaon1/data/jteneggi/data/synthetic/datasets/"
train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform)
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True, num_workers=0)
_iter = iter(dataloader)
X, _ = next(_iter)
ref = X.detach().mean(0).to(device)
# ref = torch.ones(3, 100, 120).to(device)
ref_output = model(ref.unsqueeze(0))
ref_logits = torch.nn.Softmax(dim=1)(ref_output)
print(ref_logits)
print("Loaded reference")

exp_mapper = ["hexp/absolute_0", "hexp/relative_50", "hexp/relative_60", "hexp/relative_70", "hexp/relative_80", "hexp/relative_90", "gradexp", "deepexp", "partexp", "gradcam", "gradcampp", "naive"]

A = 100*120
exp_x = np.linspace(np.log10(1/A), 0, 5)
relative_perturbation_sizes = np.sort(10 ** (exp_x))
perturbation_sizes = np.round(A * relative_perturbation_sizes)
perturbation_sizes = np.array(perturbation_sizes, dtype="int")
# print(perturbation_sizes)
perturbations_L = len(perturbation_sizes)

true_positives = np.load("true_positives.npz", allow_pickle=True)
images = true_positives.item()["1"]
L = len(images)

for j, exp_name in enumerate(exp_mapper):
    exp_logits = torch.zeros((L, perturbations_L)).to(device)
    explanation_dir = os.path.join("true_positive_explanations", exp_name)
    for i, image_path in enumerate(images[:15]):
        fig = plt.figure(figsize=(15, 4))
        fig.suptitle(exp_name)
        image_name = os.path.basename(image_path)
        image = transform(Image.open(image_path)).to(device).detach()
        if exp_name == "naive":
            explanation = torch.rand(image.size(1), image.size(2), device=torch.device("cpu")) + .5
        else:
            explanation = np.load(os.path.join(explanation_dir, "%s.npy" % image_name))
        activation_threshold = 0
        salient_points = np.where(explanation > activation_threshold)
        salient_rows = salient_points[0]
        salient_columns = salient_points[1]
        scores = explanation[salient_points]
        ranks = np.argsort(scores)
        L = len(scores)
        # print(L)

        masked_perturbations = ma.masked_greater(perturbation_sizes, L)
        valid_perturbations = masked_perturbations.compressed()
        m = len(valid_perturbations)

        axes = fig.subplots(1, m)
    
        _input = image.unsqueeze(0)    
        perturbed_batch = _input.repeat(m, 1, 1, 1)
        for k, perturbation_size in enumerate(valid_perturbations):
            # perturbation_L = round(perturbation_size * L)
            # print(perturbation_size)
            perturbed_ids = ranks[-perturbation_size:]
            perturbed_rows = salient_rows[perturbed_ids]
            perturbed_columns = salient_columns[perturbed_ids]
            perturbed_batch[k, :, perturbed_rows, perturbed_columns] = ref[:, perturbed_rows, perturbed_columns]        
            axes[k].imshow(perturbed_batch[k].permute(1, 2, 0).cpu())
        
        with torch.no_grad():
            outputs = model(perturbed_batch)
            logits = torch.nn.Softmax(dim=1)(outputs)[:, 1]
            logits = torch.log10(logits).cpu().numpy()
            torch.cuda.empty_cache()
            # print(logits)
    
        for l in np.arange(m):
            axes[l].set_title(r"$\log p_1 = %.4f$" % logits[l])
        # plt.plot(valid_perturbations, logits, '-.')
        plt.savefig(os.path.join("LOR" , "example_figures", "%s_%d.eps" % (exp_name.replace("/", "_"), i)))
        plt.close()

Current device is cuda:0
Loaded pretrained model


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


tensor([[9.9933e-01, 6.6733e-04]], device='cuda:0', grad_fn=<SoftmaxBackward>)
Loaded reference


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i