In [None]:
import os
import argparse
import torch
import torch.nn as nn
import pickle
import time
import hshap
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader

# Pin the notebook to a single GPU unless the caller already chose one
# (e.g. `CUDA_VISIBLE_DEVICES=0 jupyter lab`). This must run before CUDA is
# first initialised; torch initialises CUDA lazily, so setting it after
# `import torch` but before any CUDA call is sufficient.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "8")
# With one visible GPU it is always addressed as "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Experiment directory layout, relative to this notebook's working directory.
root_dir = "../../../"
experiment_dir = os.path.join(root_dir, "experiments", "BBBC041")
data_dir = os.path.join(experiment_dir, "data")
trophozoite_dir = os.path.join(data_dir, "trophozoite")
explanation_dir = os.path.join(experiment_dir, "explanations")

# Reproducibility: seed torch's RNGs and force deterministic cuDNN kernels.
SEED = 0
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Inference-only notebook: disable autograd globally for all later cells.
torch.set_grad_enabled(False)

In [None]:
# Build a ResNet-18 with a 2-class head and load the fine-tuned weights.
model = torch.hub.load("pytorch/vision:v0.10.0", "resnet18", pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code -- only load trusted files (consider weights_only=True on
# torch >= 1.13).
model.load_state_dict(
    torch.load(
        os.path.join(experiment_dir, "pretrained_model", "model.pt"),
        map_location=device,
    )
)
model = model.to(device)
model.eval()

# Warm-up forward pass at full image resolution (1200x1600), presumably to
# initialise CUDA/cuDNN and confirm a full-size image fits in memory.
warmup_input = torch.randn(1, 3, 1200, 1600, device=device)
model(warmup_input)  # not the last expression, so IPython does not cache the output
# Drop the last reference first: empty_cache can only release blocks that
# no live tensor occupies.
del warmup_input
torch.cuda.empty_cache()

In [None]:
# Standard ImageNet per-channel mean/std used to normalise model inputs.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
# Inverse of Normalize(mean, std) for display:
# (y - (-mean / std)) / (1 / std) == y * std + mean.
unnorm = transforms.Normalize(-mean / std, 1 / std)
dataset = ImageFolder(os.path.join(trophozoite_dir, "val"), transform)
# File stems in dataset order. splitext strips only the final extension, so
# stems that themselves contain dots are not truncated.
image_names = [
    os.path.splitext(os.path.basename(path))[0] for path, _ in dataset.samples
]
dataloader = DataLoader(dataset, batch_size=1, num_workers=4, shuffle=False)

In [None]:
# Load the stored reference tensor and wrap the trained model in an hshap
# Explainer that uses it as the `background` (presumably the baseline that
# replaces masked-out regions -- confirm against hshap docs).
reference_path = os.path.join(explanation_dir, "reference.pt")
# NOTE(review): torch.load unpickles the file; only load trusted references.
ref = torch.load(reference_path, map_location=device)
hexp = hshap.src.Explainer(model=model, background=ref)
print("Initialized hshap")

In [None]:
# Explain a single validation image: show it, then run hshap for label 1.
IMAGE_INDEX = 116  # position in dataset.samples
# Index the dataset directly instead of iterating the DataLoader up to
# IMAGE_INDEX. The loop decoded every preceding image and silently skipped
# the explanation when the split was too small; this raises IndexError.
image, _ = dataset[IMAGE_INDEX]
print(f"Explaining image {IMAGE_INDEX}: {image_names[IMAGE_INDEX]}")

# Undo the normalisation for display (C, H, W) -> (H, W, C). Clamp because
# the float round trip can land just outside imshow's valid [0, 1] range.
plt.imshow(torch.permute(unnorm(image), (1, 2, 0)).clamp(0, 1))
plt.show()

# Add the batch dimension the DataLoader used to supply: (1, C, H, W).
input = image.unsqueeze(0).to(device)

# NOTE(review): parameter meanings below are hshap-specific; confirm against
# the hshap Explainer.explain documentation before changing them.
explanation = hexp.explain(
    input,
    label=1,
    s=800,
    threshold_mode="absolute",
    threshold=0,
    softmax_activation=True,
    logit_threshold=0.50,
    batch_size=2,
    binary_map=True,
    return_shaplit=True,
)