In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import pickle
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm

# Expose only physical GPU 8 to this process; the variable is set before any
# CUDA query (the is_available() call below) so torch sees it as cuda:0.
os.environ["CUDA_VISIBLE_DEVICES"] = "8"
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Directory layout of the BBBC041 experiment, relative to the notebook.
root_dir = "../../../"
experiment_dir = os.path.join(root_dir, "experiments", "BBBC041")
data_dir, explanation_dir = (
    os.path.join(experiment_dir, sub_dir) for sub_dir in ("data", "explanations")
)
trophozoite_dir = os.path.join(data_dir, "trophozoite")
# Explanations produced by the "hexp" explainer with the 800 / absolute_0 settings.
explainer_dir = os.path.join(explanation_dir, "hexp", "800", "absolute_0")

In [None]:
# Inference-only setup: deterministic cuDNN kernels, no autotuning.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ResNet-18 backbone with a 2-way classification head. The ImageNet weights are
# not downloaded (pretrained=False): load_state_dict below runs with its default
# strict=True and replaces every parameter and buffer with the fine-tuned
# checkpoint, so pretrained weights would be discarded anyway.
model = torch.hub.load("pytorch/vision:v0.10.0", "resnet18", pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
model = model.to(device)
model.load_state_dict(
    torch.load(
        os.path.join(experiment_dir, "pretrained_model", "model.pt"),
        map_location=device,
    )
)
model.eval()
# Autograd is disabled globally for the rest of the notebook (evaluation only).
torch.set_grad_enabled(False)

In [None]:
# Standard ImageNet channel statistics (the values torchvision's pretrained
# models expect).
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
# Inverse of the normalisation above: (x + mean / std) * std == x * std + mean.
unnorm = transforms.Normalize(-mean / std, 1 / std)

val_dir = os.path.join(trophozoite_dir, "val")
dataset = ImageFolder(val_dir, transform)
# shuffle=False with batch_size=1 keeps batch i aligned with dataset.samples[i],
# which the evaluation loop relies on when it reads image_names[i].
dataloader = DataLoader(dataset, batch_size=1, num_workers=4, shuffle=False)
# File name up to the first "." for every sample, in dataset order.
image_names = [
    os.path.basename(sample_path).split(".")[0]
    for sample_path, _class_idx in dataset.samples
]

In [None]:
quadrant_df = pd.read_pickle(os.path.join(data_dir, "quadrant_labels"))
flipped_images = np.load(os.path.join(data_dir, "flipped_images.npy"))
# Images listed here are excluded from the evaluation. A Python set gives O(1)
# lookups and avoids numpy's element-wise `in` on (possibly empty) arrays.
flipped_set = set(flipped_images.ravel().tolist())

# Part of the explanation file name: `{image_name}_{logit_threshold}.pkl`.
logit_threshold = 0.5
# One weight per column of the (4, 8) p-value grid; the weights sum to 1.
w = torch.tensor([1 / 4, 1 / 12, 1 / 12, 1 / 12, 1 / 12, 1 / 12, 1 / 12, 1 / 4])
# Clipping margin that keeps tan((1/2 - pp) * pi) finite.
eps = 1e-04


def _quadrant_pvalues(p, w, eps=1e-04):
    """Collapse an explanation's (4, 8) grid of values into one score per row.

    Each row corresponds to one of the image's four quadrants: the row scores are
    compared element-wise with the image's ``quadrant_labels`` in the loop below.

    Args:
        p: iterable of ``(j, jp)`` pairs read from the explanation pickle; ``j`` is
            the row index and ``jp`` an iterable of ``(C, pp)`` pairs whose ``pp``
            is written to column ``k`` (its position in ``jp``). ``pp`` is
            presumably a p-value in [0, 1] -- TODO confirm against the explainer.
        w: length-8 weight tensor over the columns.
        eps: ``pp`` is clipped into ``[eps, 1 - eps]`` before ``tan``.

    Returns:
        ``(P, P_cauchy)``, two length-4 tensors:
        ``P[j] = 2 * sum_k w[k] * pp[j, k]`` and
        ``P_cauchy[j] = 1/2 - arctan(sum_k w[k] * tan((1/2 - pp[j, k]) * pi)) / pi``.
    """
    P = torch.ones((4, 8))
    # NOTE(review): unfilled cells default to 1 in both grids. In P that means a
    # value of 1, but in P_cauchy it is a Cauchy statistic of 1 (i.e. 0.25 after
    # the back-transform), so missing entries are treated differently by the two
    # combinations -- confirm that every (j, k) is always populated.
    P_cauchy = torch.ones((4, 8))
    for j, jp in p:
        for k, (_C, pp) in enumerate(jp):
            P[j, k] = pp
            P_cauchy[j, k] = np.tan((1 / 2 - np.clip(pp, eps, 1 - eps)) * np.pi)
    P = torch.matmul(P, 2 * w)
    P_cauchy = 1 / 2 - torch.arctan(torch.matmul(P_cauchy, w)) / torch.pi
    return P, P_cauchy


def _flag_important(scores, threshold):
    """Return a 0/1 list with 1 wherever ``scores < threshold``.

    When all scores are equal, every entry is flagged (all ones).
    """
    if scores.unique().numel() == 1:
        return [1] * scores.numel()
    return (scores < threshold).long().tolist()


PP = 0  # number of images predicted as class 1
TP = 0  # number of images whose ground-truth label is 1 (actual positives)
perfect_retrieval = 0
quadrant_target = []
quadrant_importance = []
quadrant_importance_cauchy = []
records = []
for i, (image, label) in enumerate(tqdm(dataloader)):
    # The dataloader is unshuffled with batch_size=1, so batch i is sample i.
    image_name = image_names[i]
    if image_name in flipped_set:
        continue
    quadrant_labels = quadrant_df.at[image_name, "quadrant_labels"]

    output = model(image.to(device))
    prediction = output.argmax(dim=1).item()

    # Accumulate plain ints (not tensors) so the final summary prints cleanly.
    PP += prediction
    TP += label.item()

    if prediction == 1:
        # pickle.load can execute arbitrary code: only load trusted, locally
        # generated explanation files.
        with open(
            os.path.join(explainer_dir, f"{image_name}_{logit_threshold}.pkl"), "rb"
        ) as f:
            explanation = pickle.load(f)

        _, _, shaplit_map, p = explanation
        P, P_cauchy = _quadrant_pvalues(p, w, eps)

        # Flag quadrants scoring below the 70th percentile of this image's own
        # scores. (This must use `threshold`, not `threshold_cauchy`: the latter
        # is only assigned below, so it would be undefined on the first positive
        # image and stale from the previous image afterwards.)
        threshold = torch.quantile(P, 0.7)
        importance = _flag_important(P, threshold)

        # The Cauchy combination uses a fixed 0.05 cut-off; threshold_cauchy is
        # only reported in the diagnostic print below.
        threshold_cauchy = torch.quantile(P_cauchy, 0.7)
        importance_cauchy = _flag_important(P_cauchy, 0.05)
        print(
            quadrant_labels,
            P,
            P_cauchy,
            threshold,
            threshold_cauchy,
            importance,
            importance_cauchy,
        )

    else:
        # Images not predicted as class 1 get no important quadrants.
        importance = 4 * [0]
        importance_cauchy = 4 * [0]
    records.append(
        {
            "image_name": image_name,
            "importance": importance,
            "importance_cauchy": importance_cauchy,
        }
    )

    quadrant_target.extend(quadrant_labels)
    quadrant_importance.extend(importance)
    quadrant_importance_cauchy.extend(importance_cauchy)

df = pd.DataFrame(records)
df.to_pickle(os.path.join(explainer_dir, "quadrant_importance.pkl"))

# Quadrant-level retrieval metrics, treating label 1 ("important") as positive.
precision, recall, f1, _ = precision_recall_fscore_support(
    quadrant_target,
    quadrant_importance,
    pos_label=1,
    average="binary",
)
precision_cauchy, recall_cauchy, f1_cauchy, _ = precision_recall_fscore_support(
    quadrant_target,
    quadrant_importance_cauchy,
    pos_label=1,
    average="binary",
)
print(
    f"TP: {TP}",
    f"PP: {PP}",
    f"precision: {precision:.2f}",
    f"recall: {recall:.2f}",
    f"f1: {f1:.2f}",
    f"precision_cauchy: {precision_cauchy:.2f}",
    f"recall_cauchy: {recall_cauchy:.2f}",
    f"f1_cauchy: {f1_cauchy:.2f}",
)