In [None]:
import os
import sys
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm

# Add the repository root to the import path; `dataset` and `utils` below are
# presumably project modules located there (confirm against the repo layout).
root_dir = "../../../"
sys.path.append(root_dir)

from dataset import BBBC041Dataset
from utils import HRT

# Expose only GPU index 8 to this process, then use the first visible CUDA
# device, falling back to the CPU when CUDA is unavailable.
os.environ["CUDA_VISIBLE_DEVICES"] = "8"
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Directory layout of the BBBC041 experiment, all derived from root_dir.
experiment_dir = os.path.join(root_dir, "experiments", "BBBC041")
explanation_dir = os.path.join(experiment_dir, "explanations")
data_dir = os.path.join(experiment_dir, "data")
trophozoite_dir = os.path.join(data_dir, "trophozoite")

In [None]:
# Request deterministic cuDNN behaviour and turn off its benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Build a ResNet-18 with a 2-way classification head.
# ImageNet weights are deliberately NOT requested: the strict load_state_dict
# below overwrites every parameter and buffer (it raises on any missing key),
# so downloading them would only add a network dependency for weights that
# are discarded.
model = torch.hub.load("pytorch/vision:v0.10.0", "resnet18", pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
model = model.to(device)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints that come from a trusted source.
checkpoint_path = os.path.join(experiment_dir, "pretrained_model", "model.pt")
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Inference only: switch to eval mode and disable autograd for the whole
# process (this is global and affects every later cell).
model.eval()
torch.set_grad_enabled(False)

In [None]:
# Per-channel normalization constants (these match the widely used ImageNet
# statistics).
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# Preprocessing applied to every image: convert to a tensor, then normalize.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

# Inverse of the normalization above: (x - (-mean / std)) / (1 / std) == x * std + mean.
unnorm = transforms.Normalize(-mean / std, 1 / std)

# Reference tensor stored next to the data; it is passed straight to the
# dataset constructor (its role is defined by BBBC041Dataset).
ref = torch.load(
    os.path.join(trophozoite_dir, "reference.pt"),
    map_location=device,
)
dataset = BBBC041Dataset(
    os.path.join(trophozoite_dir, "val"),
    ref,
    transform=transform,
)
dataloader = DataLoader(dataset, shuffle=False, batch_size=4, num_workers=4)

# File name up to its first dot, in dataset order (so index i matches sample i).
image_names = [
    os.path.basename(sample[0]).partition(".")[0] for sample in dataset.samples
]

In [None]:
# Count correct predictions over the whole dataset.
# The loop variables are named `images`/`labels` rather than `input`/`label`
# so the Python builtin `input` is not shadowed for the rest of the session.
correct = 0
for images, labels in tqdm(dataloader):
    images = images.to(device)
    labels = labels.to(device)

    predictions = model(images).argmax(dim=1)

    # Accumulated as a 0-dim tensor, so `t` below is a tensor as well.
    correct += torch.sum(predictions == labels)

# Observed test statistic handed to HRT: the model's error rate (1 - accuracy).
t = 1 - correct / len(dataset)
print(f"Test statistic (1 - accuracy): {100*t:.2f}%")

# One HRT call per feature index j; a feature is recorded as important (1)
# when HRT returns a falsy value and as unimportant (0) otherwise.
# NOTE(review): the literal 4 appears both as the number of features tested
# and as HRT's fourth argument; whether these are the same quantity cannot be
# told from here -- confirm against utils.HRT.
importance = [int(not HRT(model, dataset, j, 4, t)) for j in range(4)]
print(f"Importance of features: {importance}")

In [None]:
# NOTE(review): read_pickle unpickles the file, which can execute arbitrary
# code; only read files that come from a trusted source.
quadrant_df = pd.read_pickle(os.path.join(data_dir, "quadrant_labels"))
flipped_images = np.load(os.path.join(data_dir, "flipped_images.npy"))
# Build the skip-list once as a set: membership tests inside the loop are then
# O(1) instead of a full array comparison per image, and an empty array simply
# yields an empty set.
flipped_names = set(flipped_images.ravel().tolist())

# batch_size=1 with shuffle=False keeps the batch index i aligned with
# image_names[i].
dataloader = DataLoader(dataset, batch_size=1, num_workers=4, shuffle=False)

# Running sums kept as plain Python ints so the report below prints numbers
# rather than tensor reprs.
PP = 0  # sum of predicted class indices (count of class-1 predictions for 0/1 classes)
TP = 0  # sum of labels (count of class-1 labels for 0/1 labels; not "true positives")
quadrant_importance = []
quadrant_target = []
for i, (image, label) in enumerate(tqdm(dataloader)):
    image_name = image_names[i]
    if image_name in flipped_names:
        # Images listed in flipped_images.npy are excluded from the evaluation.
        continue

    prediction = model(image.to(device)).argmax(dim=1).item()

    PP += prediction
    TP += label.item()

    # Per-image ground truth, extended element-wise into the flat target list.
    quadrant_labels = quadrant_df.at[image_name, "quadrant_labels"]
    quadrant_target.extend(quadrant_labels)
    # The dataset-level importance vector is used for images predicted as
    # class 1; every other image contributes four zeros.
    quadrant_importance.extend(importance if prediction == 1 else 4 * [0])

precision, recall, f1, _ = precision_recall_fscore_support(
    quadrant_target,
    quadrant_importance,
    pos_label=1,
    average="binary",
)
print(
    f"TP: {TP}",
    f"PP: {PP}",
    f"precision: {precision:.2f}",
    f"recall: {recall:.2f}",
    f"f1: {f1:.2f}",
)