In [1]:
import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
import os
from os.path import isfile, isdir
import sys
import glob
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import random
from typing import Union
import gc
import copy
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.path import Path
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D
import seaborn as sns
import cv2
from PIL import Image
import configparser
import time

import quantus
from captum.attr import *

# import config_vae_inference_local as config
import vae_models, bayesian_models
import vae_utils

In [2]:
class Custom_Dataset(torch.utils.data.dataset.Dataset):
    def __init__(self, image_dir, device, normalize):
        self.dataset = glob.glob(f"{os.path.join(image_dir,'*')}")
        self.device = device
        self.normalize = normalize
        # print(f"Test instances: {len(self.dataset)}")
        # print(self.dataset[0])

    def __getitem__(self, index):
        image = Image.open(self.dataset[index])
        image_float_np = np.float32(image) / 255

        # input_tensor = preprocess_image(image_float_np, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        input_tensor = self.get_transform(self.normalize)(image_float_np)
        plot_tensor = self.get_transform(False)(image_float_np)
        input_tensor = input_tensor.to(self.device)
        plot_tensor = plot_tensor.to(self.device)
        return [input_tensor,plot_tensor], 0

    def __len__(self):
        return len(self.dataset)

    def get_transform(self, normalize):
        transforms = []
        transforms.append(T.ToTensor())
        # transforms.append(T.ConvertImageDtype(torch.float))
        if normalize:
            transforms.append(T.Normalize(mean=[0.4500, 0.4373, 0.4494], std=[0.2248, 0.2249, 0.2280]))
        return T.Compose(transforms)

    # def get_transform2(self):
    #     transforms = []
    #     transforms.append(T.ToTensor())
    #     return T.Compose(transforms)

In [3]:
def explainer_wrapper(**kwargs):
    """Wrapper for explainer functions."""
    if kwargs["method"] == "Saliency":
        return saliency_explainer(**kwargs)
    elif kwargs["method"] == "IntegratedGradients":
        return intgrad_explainer(**kwargs)
    elif kwargs["method"] == "FusionGrad":
        return fusiongrad_explainer(**kwargs)
    elif kwargs["method"] == "GradientShap":
        return gradshap_explainer(**kwargs)
    else:
        raise ValueError("Pick an explaination function that exists.")


def saliency_explainer(
    model, inputs, targets, abs=False, normalise=False, *args, **kwargs
) -> np.array:
    """Wrapper aorund captum's Saliency implementation."""

    gc.collect()
    torch.cuda.empty_cache()

    # Set model in evaluate mode.
    model.to(kwargs.get("device", None))
    model.eval()

    if not isinstance(inputs, torch.Tensor):
        inputs = (
            torch.Tensor(inputs)
            .reshape(
                -1,
                kwargs.get("nr_channels", 3),
                kwargs.get("img_size", 28),
                kwargs.get("img_size", 28),
            )
            .to(kwargs.get("device", None))
        )
    if not isinstance(targets, torch.Tensor):
        targets = (
            torch.as_tensor(targets).long().to(kwargs.get("device", None))
        )  # inputs = inputs.reshape(-1, 3, 28, 28)

    assert (
        len(np.shape(inputs)) == 4
    ), "Inputs should be shaped (nr_samples, nr_channels, img_size, img_size) e.g., (1, 3, 28, 28)."

    explanation = (
        Saliency(model)
        .attribute(inputs, targets, abs=abs)
        .sum(axis=1)
        .reshape(-1, kwargs.get("img_size", 28), kwargs.get("img_size", 28))
        .cpu()
        .data
    )

    gc.collect()
    torch.cuda.empty_cache()

    if normalise:
        explanation = quantus.normalise_func.normalise_by_negative(explanation)

    if isinstance(explanation, torch.Tensor):
        if explanation.requires_grad:
            return explanation.cpu().detach().numpy()
        return explanation.cpu().numpy()

    return explanation


def intgrad_explainer(
    model, inputs, targets, abs=False, normalise=False, *args, **kwargs
) -> np.array:
    """Wrapper aorund captum's Integrated Gradients implementation."""

    gc.collect()
    torch.cuda.empty_cache()

    # Set model in evaluate mode.
    model.to(kwargs.get("device", None))
    model.eval()

    if not isinstance(inputs, torch.Tensor):
        inputs = (
            torch.Tensor(inputs)
            .reshape(
                -1,
                kwargs.get("nr_channels", 3),
                kwargs.get("img_size", 28),
                kwargs.get("img_size", 28),
            )
            .to(kwargs.get("device", None))
        )
    if not isinstance(targets, torch.Tensor):
        targets = torch.as_tensor(targets).long().to(kwargs.get("device", None))

    assert (
        len(np.shape(inputs)) == 4
    ), "Inputs should be shaped (nr_samples, nr_channels, img_size, img_size) e.g., (1, 3, 28, 28)."

    explanation = (
        IntegratedGradients(model)
        .attribute(
            inputs=inputs,
            target=targets,
            # baselines=torch.zeros_like(inputs),
            baselines=torch.rand_like(inputs),
            # baselines = torch.ones_like(inputs),
            n_steps=10,
            method="riemann_trapezoid",
        )
        .sum(axis=1)
        .reshape(-1, kwargs.get("img_size", 28), kwargs.get("img_size", 28))
        .cpu()
        .data
    )

    gc.collect()
    torch.cuda.empty_cache()

    if normalise:
        explanation = quantus.normalise_func.normalise_by_negative(explanation)

    if isinstance(explanation, torch.Tensor):
        if explanation.requires_grad:
            return explanation.cpu().detach().numpy()
        return explanation.cpu().numpy()

    return explanation


def gradshap_explainer(
    model, inputs, targets, abs=False, normalise=False, *args, **kwargs
) -> np.array:
    """Wrapper aorund captum's GradShap implementation."""

    gc.collect()
    torch.cuda.empty_cache()

    # Set model in evaluate mode.
    model.to(kwargs.get("device", None))
    model.eval()

    if not isinstance(inputs, torch.Tensor):
        inputs = (
            torch.Tensor(inputs)
            .reshape(
                -1,
                kwargs.get("nr_channels", 3),
                kwargs.get("img_size", 28),
                kwargs.get("img_size", 28),
            )
            .to(kwargs.get("device", None))
        )

    if not isinstance(targets, torch.Tensor):
        targets = torch.as_tensor(targets).long().to(kwargs.get("device", None))

    assert (
        len(np.shape(inputs)) == 4
    ), "Inputs should be shaped (nr_samples, nr_channels, img_size, img_size) e.g., (1, 3, 28, 28)."

    baselines = torch.zeros_like(inputs).to(kwargs.get("device", None))
    explanation = (
        GradientShap(model)
        .attribute(inputs=inputs, target=targets, baselines=baselines)
        .sum(axis=1)
        .reshape(-1, kwargs.get("img_size", 28), kwargs.get("img_size", 28))
        .cpu()
        .data
    )

    gc.collect()
    torch.cuda.empty_cache()

    if normalise:
        explanation = quantus.normalise_func.normalise_by_negative(explanation)

    if isinstance(explanation, torch.Tensor):
        if explanation.requires_grad:
            return explanation.cpu().detach().numpy()
        return explanation.cpu().numpy()

    return explanation

def fusiongrad_explainer(
    model, inputs, targets, abs=False, normalise=False, *args, **kwargs
) -> np.array:
    """Wrapper aorund captum's FusionGrad implementation."""

    std = kwargs.get("std", 0.5)
    mean = kwargs.get("mean", 1.0)
    n = kwargs.get("n", 10)
    m = kwargs.get("m", 10)
    sg_std = kwargs.get("sg_std", 0.5)
    sg_mean = kwargs.get("sg_mean", 0.0)
    posterior_mean = kwargs.get("posterior_mean", None)
    noise_type = kwargs.get("noise_type", "multiplicative")
    clip = kwargs.get("clip", False)

    def _sample(model, posterior_mean, std, distribution=None, noise_type="multiplicative"):
        """Implmentation to sample a model."""

        # Load model params.
        model.load_state_dict(posterior_mean)

        # If std is not zero, loop over each layer and add Gaussian noise.
        if not std == 0.0:
            with torch.no_grad():
                for layer in model.parameters():
                    if noise_type == "additive":
                        layer.add_(distribution.sample(layer.size()).to(layer.device))
                    elif noise_type == "multiplicative":
                        layer.mul_(distribution.sample(layer.size()).to(layer.device))
                    else:
                        print(
                            "Set NoiseGrad attribute 'noise_type' to either 'additive' or 'multiplicative' (str)."
                        )

        return model


    # Creates a normal (also called Gaussian) distribution.
    distribution = torch.distributions.normal.Normal(
        loc=torch.as_tensor(mean, dtype=torch.float),
        scale=torch.as_tensor(std, dtype=torch.float),
    )

    # Set model in evaluate mode.
    model.to(kwargs.get("device", None))
    model.eval()

    if not isinstance(inputs, torch.Tensor):
        inputs = (
            torch.Tensor(inputs)
            .reshape(
                -1,
                kwargs.get("nr_channels", 3),
                kwargs.get("img_size", 28),
                kwargs.get("img_size", 28),
            )
            .to(kwargs.get("device", None))
        )
    if not isinstance(targets, torch.Tensor):
        targets = torch.as_tensor(targets).long().to(kwargs.get("device", None))

    assert (
        len(np.shape(inputs)) == 4
    ), "Inputs should be shaped (nr_samples, nr_channels, img_size, img_size) e.g., (1, 3, 28, 28)."

    if inputs.shape[0] > 1:
        explanation = torch.zeros(
            (
                n,
                m,
                inputs.shape[0],
                kwargs.get("img_size", 28),
                kwargs.get("img_size", 28),
            )
        )
    else:
        explanation = torch.zeros(
            (n, m, kwargs.get("img_size", 28), kwargs.get("img_size", 28))
        )

    for i in range(n):
        model = _sample(
            model=model,
            posterior_mean=posterior_mean,
            std=std,
            distribution=distribution,
            noise_type=noise_type,
        )
        for j in range(m):
            inputs_noisy = inputs + torch.randn_like(inputs) * sg_std + sg_mean
            if clip:
                inputs_noisy = torch.clip(inputs_noisy, min=0.0, max=1.0)

            explanation[i][j] = (
                Saliency(model)
                .attribute(inputs_noisy, targets, abs=abs)
                .sum(axis=1)
                .reshape(-1, kwargs.get("img_size", 28), kwargs.get("img_size", 28))
                .cpu()
                .data
            )

    explanation = explanation.mean(axis=(0, 1))

    gc.collect()
    torch.cuda.empty_cache()

    if normalise:
        explanation = quantus.normalise_func.normalise_by_negative(explanation)

    if isinstance(explanation, torch.Tensor):
        if explanation.requires_grad:
            return explanation.cpu().detach().numpy()
        return explanation.cpu().numpy()

    return explanation

In [4]:
# Produce explanations and empty cache to to survive memory-wise.
def get_xai_explanations(x_batch, y_batch, model, device, model_type, bnn_reps=10):
    # Saliency.
    gc.collect()
    torch.cuda.empty_cache()
    res = [saliency_explainer(model=model.to(device),
                                        inputs=x_batch,
                                        targets=y_batch,
                                        **{"device": device},
                                        ) for _ in range(bnn_reps)]
    a_batch_saliency = np.mean(res, axis=0)
    a_batch_saliency_std = np.std(res, axis=0)
    # GradShap.
    gc.collect()
    torch.cuda.empty_cache()
    res = [gradshap_explainer(model=model.to(device), 
                                        inputs=x_batch,
                                        targets=y_batch, 
                                        **{"device": device},
                                        ) for _ in range(bnn_reps)]
    a_batch_gradshap = np.mean(res, axis=0)
    a_batch_gradshap_std = np.std(res, axis=0)
    # Integrated Gradients.
    gc.collect()
    torch.cuda.empty_cache()
    res = [intgrad_explainer(model=model.to(device),
                                        inputs=x_batch,
                                        targets=y_batch,
                                        **{"device": device},
                                        ) for _ in range(bnn_reps)]
    a_batch_intgrad = np.mean(res, axis=0)
    a_batch_intgrad_std = np.std(res, axis=0)

    # FusionGrad
    gc.collect()
    torch.cuda.empty_cache()
    posterior_mean = copy.deepcopy(model.to(device).state_dict())
    res = [fusiongrad_explainer(model=model.to(device), 
                                            inputs=x_batch, 
                                            targets=y_batch, 
                                            **{"posterior_mean": posterior_mean, "mean": 1.0, "std": 0.5, 
                                                "sg_mean": 0.0, "sg_std": 0.5, "n": 25, "m": 25, 
                                                "noise_type": "multiplicative", "device": device}) for _ in range(bnn_reps)]

    a_batch_fusiongrad = np.mean(res, axis=0)
    a_batch_fusiongrad_std = np.std(res, axis=0)
    # Save explanations to file.
    if model_type == "BNN":
        explanations = {
            "mean(SAL)": a_batch_saliency,
            "std(SAL)": a_batch_saliency_std,
            "mean(GS)": a_batch_gradshap,
            "std(GS)": a_batch_gradshap_std,
            "mean(IG)": a_batch_intgrad,
            "std(IG)": a_batch_intgrad_std,
            "mean(FG)": a_batch_fusiongrad,
            "std(FG)": a_batch_fusiongrad_std
        }
    else:
        explanations = {
            "SAL": a_batch_saliency,
            "GS": a_batch_gradshap,
            "IG": a_batch_intgrad,
            "FG": a_batch_fusiongrad
        }

    # explanations_std = {
    #     "Saliency_std": a_batch_saliency_std,
    #     "GradientShap_std": a_batch_gradshap_std,
    #     "IntegratedGradients_std": a_batch_intgrad_std,
    #     # "FusionGrad": a_batch_fusiongrad
    # }

    return explanations


def get_xai_methods_and_metrics(explanations, num_classes, subset_size=14, perturb_baseline="black", avg_sensitivity_samples=10, disable_warnings=True):
    xai_methods = list(explanations.keys())

    metrics = {
    "Robustness": quantus.AvgSensitivity(
        nr_samples=avg_sensitivity_samples,
        lower_bound=0.2,
        norm_numerator=quantus.norm_func.fro_norm,
        norm_denominator=quantus.norm_func.fro_norm,
        perturb_func=quantus.perturb_func.uniform_noise,
        similarity_func=quantus.similarity_func.difference,
        abs=False,
        normalise=False,
        aggregate_func=np.mean,
        return_aggregate=True,
        disable_warnings=disable_warnings,
    ),
    "Faithfulness": quantus.FaithfulnessCorrelation(
        nr_runs=10,
        subset_size=subset_size,
        perturb_baseline=perturb_baseline,
        perturb_func=quantus.perturb_func.baseline_replacement_by_indices,
        similarity_func=quantus.similarity_func.correlation_pearson,
        abs=False,
        normalise=False,
        aggregate_func=np.mean,
        return_aggregate=True,
        disable_warnings=disable_warnings,
    ),
    "Complexity": quantus.Sparseness(
        abs=True,
        normalise=False,
        aggregate_func=np.mean,
        return_aggregate=True,
        disable_warnings=disable_warnings,
    ),
    "Randomisation": quantus.RandomLogit(
        num_classes=num_classes,
        similarity_func=quantus.similarity_func.ssim,
        abs=True,
        normalise=False,
        aggregate_func=np.mean,
        return_aggregate=True,
        disable_warnings=disable_warnings,
    )}
    # )},
    # "Sufficiency": quantus.Sufficiency(
    #     threshold=0.6,
    #     return_aggregate=False,
    #     disable_warnings=disable_warnings
    # ),
    # "Consistency": quantus.Consistency(
    #     discretise_func=quantus.discretise_func.top_n_sign,
    #     return_aggregate=False,
    #     disable_warnings=disable_warnings)
    # }
    

    return xai_methods, metrics
    


In [5]:
def plot_explanations(x_batch, y_batch, explanations, probabilities, num_classes, output_dir, epoch):
    # i_max = 1
    # for idx,x in enumerate(x_batch.cpu().numpy()):
    for idx,x in enumerate(x_batch):
        fig, axes = plt.subplots(nrows=1, ncols=2+len(explanations), figsize=(15, 4))
        # axes[0].imshow(np.moveaxis(quantus.normalise_func.denormalise(x_batch[index].cpu().numpy(), mean=np.array([0.485, 0.456, 0.406]), std=np.array([0.229, 0.224, 0.225])), 0, -1), vmin=0.0, vmax=1.0)
        # img
        axes[0].imshow(np.moveaxis(x, 0, -1), vmin=0.0, vmax=1.0)
        axes[0].title.set_text(f"class {y_batch[idx].item()}")
        axes[0].axis("off")
        # probability bar 
        axes[1].bar(np.arange(num_classes), np.array([probabilities[idx][0],probabilities[idx][1]]), color='red')
        axes[1].set_xticks(np.arange(num_classes))
        axes[1].set_ylim([0, 1])
        # explanations
        for i, (k, v) in enumerate(explanations.items()):
            axes[i+2].imshow(quantus.normalise_func.normalise_by_negative(explanations[k][idx].reshape(28, 28)), cmap="seismic", vmin=-1.0, vmax=1.0)
            axes[i+2].title.set_text(f"{k}")
            axes[i+2].axis("off")

        plot_loc = os.path.join(output_dir,f"prob_{epoch}_{idx}_explanation.jpg")
        plt.savefig(plot_loc)
        plt.close('all')


def plot_explanations_with_sums(x_batch, y_batch, explanations, probabilities, num_classes, output_dir, epoch, num_forward_passes, model_type="BNN", scale_factor=7):
    y_indices_rowsums = np.linspace(0,x_batch[0].shape[1]-1,x_batch[0].shape[1]).astype(int)
    plt.rcParams["figure.autolayout"] = True
    for idx,x in enumerate(x_batch):
        if model_type == "BNN":
            fig, axes = plt.subplots(nrows=1, ncols=2+len(explanations), figsize=(16, 2), constrained_layout = True) #, gridspec_kw={'height_ratios': [12 if i != 1 else 4 for i in range(2+len(explanations))]})
        else:
            fig, axes = plt.subplots(nrows=1, ncols=1+int(len(explanations)), figsize=(6, 2), constrained_layout = True)

        # if model_type == "BNN":
        #     fig.suptitle("Explanations for BNN", fontsize=16)
        # else:
        #     fig.suptitle("Explanations for LeNet", fontsize=16)

        # axes[0].imshow(np.moveaxis(quantus.normalise_func.denormalise(x_batch[index].cpu().numpy(), mean=np.array([0.485, 0.456, 0.406]), std=np.array([0.229, 0.224, 0.225])), 0, -1), vmin=0.0, vmax=1.0)
        # img
        #axes[0].imshow(cv2.resize(np.moveaxis(x, 0, -1), (12,12)).astype(np.uint8), vmin=0.0, vmax=1.0)
        #  cv2.resize(np.moveaxis(x_batch[0], 0, -1), (20,20))
        fig.tight_layout()
        axes[0].imshow(np.moveaxis(x, 0, -1), vmin=0.0, vmax=1.0)
        axes[0].title.set_text(f"class {y_batch[idx].item()}")
        axes[0].axis("off")
        # probability bar
        if model_type == "BNN":
            axes[1].bar(np.arange(num_classes), np.array([probabilities[idx][0],probabilities[idx][1]]), color='red')
            axes[1].set_xticks(np.arange(num_classes))
            axes[1].set_ylim([0, 1])
            axes[1].set_title(f"{num_forward_passes} draws")
            # sns.barplot(np.arange(num_classes),np.array([probabilities[idx][0],probabilities[idx][1]]), ax=axes[1])
            i = 2
        else:
            i = 1

        for k in explanations.keys():
            # skip std plots for lenet (fixed weights)
            # if model_type != "BNN" and j % 2 != 0:
            #     continue
            # print(f"explanations[k][idx]: {k}")
            explanation = explanations[k][idx].reshape(28, 28)
            
            csum = np.sum(np.clip(explanation,a_min=0,a_max=None), 0)
            rsum = np.sum(np.clip(explanation,a_min=0,a_max=None), 1)
            csum_scaled = (csum-min(csum))/(max(csum)-min(csum))*scale_factor
            rsum_scaled = (rsum-min(rsum))/(max(rsum)-min(rsum))*scale_factor

            axes[i].imshow(quantus.normalise_func.normalise_by_negative(explanation), cmap="seismic", vmin=-1.0, vmax=1.0)
            axes[i].plot(-abs(csum_scaled), color='blue') #, marker='o') #, mfc='orange')
            axes[i].plot(-abs(rsum_scaled), y_indices_rowsums, color='blue') #, marker='o')#, mfc='orange')
            # axes[i].plot(csums, color='blue') #, marker='o') #, mfc='orange')
            # axes[i].plot(rsums, y_indices_rowsums, color='blue') #, marker='o')#, mfc='orange')
            axes[i].title.set_text(f"{k}")
            axes[i].axis("off")
            i += 1

        plot_loc = os.path.join(output_dir,f"prob_{epoch}_{idx}_explanation.jpg")
        plt.savefig(plot_loc)
        plt.close('all')

In [6]:
def get_bnn_probabilities(image_tensor, class_model, device, num_forward_passes=10):
    # print(image_tensor.size()[0])
    probs = []
    plot_list = []
    for tensor in image_tensor:
        tensor_unsqueezed = torch.unsqueeze(tensor,0)
        with torch.no_grad():
            preds = [np.argmax(class_model(tensor_unsqueezed).cpu().numpy()) for _ in range(num_forward_passes)]

        mxt_prob = np.mean(preds)
        dot_prob = 1 - mxt_prob

        probs.append((dot_prob,mxt_prob))

        if dot_prob > 0.05 and dot_prob < 0.95:
            plot_list.append(True)
        else:
            plot_list.append(False)
    return probs, plot_list

# def plot_bnn_probabilities(x_batch, probabilities, output_dir, num_classes, num_forward_passes, epoch, plot_list):
#         transform = T.Compose([T.ToPILImage()])
        
#         for i,tensor in enumerate(x_batch):
#             if not plot_list[i]:
#                 # print("continue")
#                 continue
#             # print("plottable")
            
#             # print(probabilities[i][0])
#             plot_loc = os.path.join(output_dir,f"prob_{epoch}_{i}.jpg")
#             # image = np.array(transform(tensor))
#             image = transform(tensor)
#             # print(type(image))
#             plt.figure()
#             fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(num_classes, 2),
#                                         gridspec_kw={'width_ratios': [3, 3]})
            
#             # Show the image and the true label
#             # ax1.imshow(image[..., 0], cmap='gray')
#             ax1.imshow(image) # , cmap='gray')
#             ax1.axis('off')
        
#             bar = ax2.bar(np.arange(num_classes), np.array([probabilities[i][0],probabilities[i][1]]), color='red')
#             # bar[int(true_label)].set_color('green')
#             ax2.set_xticks(np.arange(num_classes))
#             ax2.set_ylim([0, 1])
#             ax2.set_title(f'Probabilities after {num_forward_passes} draws')
#             # plt.savefig(os.path.join(plot_loc,"test_img_{}.png".format(i)))
#             plt.savefig(plot_loc)
#             plt.close('all')

# def plot_xai_eval(x_batch, y_batch, a_batch_saliency, a_batch_intgrad, output_dir, epoch, plot_list):
#     # fig, axes = plt.subplots(nrows=nr_images, ncols=3, figsize=(7.5, int(nr_images*3)))
#     for i in range(x_batch.size()[0]):
#         if not plot_list[i]:
#             # print("continue")
#             continue
#         # print("plottable")
#         plot_loc = os.path.join(output_dir,f"prob_{epoch}_{i}_xai.jpg")

#         fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(7.5, int(3)))
#         axes[0].imshow(rgb2gray(np.reshape(x_batch[i].cpu().numpy(),(28, 28, 3))), cmap="gray") # .astype(np.uint8))# , vmin=0.0, vmax=1.0)#, cmap="gray")
#         axes[0].title.set_text(f"Marker type {y_batch[i].item()}")
#         axes[0].axis("off")
#         axes[1].imshow(a_batch_saliency[i], cmap="seismic")
#         axes[1].title.set_text(f"Saliency")
#         axes[1].axis("off")
#         axes[2].imshow(a_batch_intgrad[i], cmap="seismic")
#         axes[2].title.set_text(f"Integrated Gradients")
#         axes[2].axis("off")
#         plt.tight_layout()
#         plt.savefig(plot_loc)
#         plt.close('all')

# def rgb2gray(rgb):
#     return np.dot(rgb[...,:3], [0.2989, 0.5870, 0.1140])

In [7]:
# explain = explanations["Saliency"][0]

# # csum = np.sum(explanation, 0)
# # rsum = np.sum(explanation, 1)
# # csum_scaled = (csum-min(csum))/(max(csum)-min(csum))*scale_factor
# # rsum_scaled = (rsum-min(rsum))/(max(rsum)-min(rsum))*scale_factor
# a = np.clip(explain,a_min=0,a_max=None)
# a.shape

In [8]:
def get_xai_quantification_results(x_batch, y_batch, class_model, device, xai_methods, xai_metrics, bnn_reps=10):
    results = {method : {} for method in xai_methods}

    for i,method in enumerate(xai_methods):
        # print(f"i: {i}   method: {method}")
        if (i % 2 != 0):
            print(f"skipping method: {method}")
            continue
        for metric, metric_func in xai_metrics.items():

            print(f"Evaluating {metric} of {method} method.")
            gc.collect()
            torch.cuda.empty_cache()

            # Get scores and append results.
            scores = [metric_func(
                model=class_model,
                x_batch=x_batch,
                y_batch=y_batch,
                a_batch=None,
                device=device,
                explain_func=explainer_wrapper,
                explain_func_kwargs={
                    "method": method,
                    "posterior_mean": copy.deepcopy(
                        class_model.to(device).state_dict()
                    ),
                    "mean": 1.0,
                    "std": 0.5,
                    "sg_mean": 0.0,
                    "sg_std": 0.5,
                    "n": 25,
                    "m": 25,
                    "noise_type": "multiplicative",
                    "device": device,
                },
            ) for _ in range(1)]
            # print(f"scores: {scores}")
            # print(f"np.mean(scores,axis=0): {np.mean(scores,axis=0)}")
            results[method][metric] = np.mean(scores,axis=0)

            # Empty cache.
            gc.collect()
            torch.cuda.empty_cache()
    
    return results

def xai_results_postprocess(xai_methods, xai_metrics, results, output_dir, epoch):
    results_agg = {}
    for i,method in enumerate(xai_methods):
        if i % 2 != 0:
            continue
        results_agg[method] = {}
        for metric, metric_func in xai_metrics.items():
            results_agg[method][metric] = np.mean(results[method][metric])

    df = pd.DataFrame.from_dict(results_agg)
    df = df.T.abs()
    df.to_csv(os.path.join(output_dir,f"df_{epoch}.csv"))
    
    # Take inverse ranking for Robustness, since lower is better.
    df_normalised = df.loc[:, df.columns != 'Robustness'].apply(lambda x: x / x.max())
    df_normalised["Robustness"] = df["Robustness"].min()/df["Robustness"].values
    df_normalised_rank = df_normalised.rank()
    df_normalised_rank.to_csv(os.path.join(output_dir,f"df_normalised_rank_{epoch}.csv"))
    return df, df_normalised_rank

In [9]:
# import quantus

# quantus.AVAILABLE_PERTURBATION_FUNCTIONS
# quantus.AVAILABLE_SIMILARITY_FUNCTIONS

In [10]:
"""
    metrics = {
    "Robustness": quantus.AvgSensitivity(
        nr_samples=avg_sensitivity_samples,
        lower_bound=0.2,
        norm_numerator=quantus.norm_func.fro_norm,
        norm_denominator=quantus.norm_func.fro_norm,
        perturb_func=quantus.perturb_func.uniform_noise,
        similarity_func=quantus.similarity_func.difference,
        abs=False,
        normalise=False,
        aggregate_func=np.mean,
        return_aggregate=True,
        disable_warnings=disable_warnings,
    ),
    "Faithfulness": quantus.FaithfulnessCorrelation(
        nr_runs=10,
        subset_size=subset_size,
        perturb_baseline=perturb_baseline,
        perturb_func=quantus.perturb_func.baseline_replacement_by_indices,
        similarity_func=quantus.similarity_func.correlation_pearson,
        abs=False,
        normalise=False,
        aggregate_func=np.mean,
        return_aggregate=True,
        disable_warnings=disable_warnings,
    ),
    "Complexity": quantus.Sparseness(
        abs=True,
        normalise=False,
        aggregate_func=np.mean,
        return_aggregate=True,
        disable_warnings=disable_warnings,
    ),
    "Randomisation": quantus.RandomLogit(
        num_classes=num_classes,
        similarity_func=quantus.similarity_func.ssim,
        abs=True,
        normalise=False,
        aggregate_func=np.mean,
        return_aggregate=True,
        disable_warnings=disable_warnings,
    )}
    # )},
    # "Sufficiency": quantus.Sufficiency(
    #     threshold=0.6,
    #     return_aggregate=False,
    #     disable_warnings=disable_warnings
    # ),
    # "Consistency": quantus.Consistency(
    #     discretise_func=quantus.discretise_func.top_n_sign,
    #     return_aggregate=False,
    #     disable_warnings=disable_warnings)
    # }
"""

# for i,method in enumerate(xai_methods):
#     if (i % 2 != 0):
#         print(f"skipping method: {method}")
#         continue
#     # for metric, metric_func in xai_metrics.items():
#     print(f"method: {method}")
#     print(f"metric_func: {metric_func}")        
#     scores = [metric_func(
#                     model=class_model,
#                     x_batch=x_batch,
#                     y_batch=y_batch,
#                     a_batch=None,
#                     device=device,
#                     explain_func=explainer_wrapper,
#                     explain_func_kwargs={
#                         "method": method,
#                         "posterior_mean": copy.deepcopy(
#                             class_model.to(device).state_dict()
#                         ),
#                         "mean": 1.0,
#                         "std": 0.5,
#                         "sg_mean": 0.0,
#                         "sg_std": 0.5,
#                         "n": 25,
#                         "m": 25,
#                         "noise_type": "multiplicative",
#                         "device": device,
#                     },
#                 ) for _ in range(1)]

# "Complexity": quantus.Sparseness(
#         abs=True,
#         normalise=False,
#         aggregate_func=np.mean,
#         return_aggregate=True,
#         disable_warnings=disable_warnings,
#     )


# range: 0-ln(2)? https://arxiv.org/pdf/2005.00631.pdf section 5

# def get_complexity_results(model, x_batch, y_batch, explanations, device, output_dir, epoch, complex_score_result, bnn_reps=1,
#                             save_results_to_dataframe=True, disable_warnings=True):
#     print("Generating complexity results")
    
#     result = complex_score_result

#     for i, (method, attr) in enumerate(explanations.items()):
#         if i % 2 != 0:
#             # print("skipping method: {method}")
#             continue
#         metric_func = quantus.Sparseness(abs=True,
#                                             normalise=False,
#                                             aggregate_func=np.mean,
#                                             return_aggregate=True,
#                                             disable_warnings=disable_warnings,
#                                         )
#         # print(f"metric: {metric}")
#         # score = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=attr, device=device)
#         score = np.mean([metric_func(
#             model=model,
#             x_batch=x_batch,
#             y_batch=y_batch,
#             a_batch=None,
#             device=device,
#             explain_func=explainer_wrapper,
#             explain_func_kwargs={
#                 "method": method,
#                 "posterior_mean": copy.deepcopy(
#                     model.to(device).state_dict()
#                 ),
#                 "mean": 1.0,
#                 "std": 0.5,
#                 "sg_mean": 0.0,
#                 "sg_std": 0.5,
#                 "n": 25,
#                 "m": 25,
#                 "noise_type": "multiplicative",
#                 "device": device,
#             },
#         ) for _ in range(bnn_reps)], axis=0)
#         # print(f"score: {score}")
#         result["Complexity score"].append(score)
#         result["Method"].append(method)
#         result["Epoch"].append(epoch)

#     return result

def get_complexity_results(model, x_batch, y_batch, explanations, device, output_dir, epoch, bnn_reps=1,
                            save_results_to_dataframe=True, disable_warnings=True):
    print("Generating complexity results")
    
    result = {
    "Score": [],
    "Method": [],
    "Index": []
    }
    
    # print(f"type(x_batch): {type(x_batch)}   x_batch.shape: {x_batch.shape}")
    for image_idx,x in enumerate(x_batch):   
        x_batch_tmp = np.expand_dims(x, axis=0)
        y_batch_tmp = np.expand_dims(y_batch[image_idx], axis=0)
        
        for i, (method, attr) in enumerate(explanations.items()):
            if method.split("_")[-1]=="std":
                    # print(f"skipping method: {method}")
                    continue
            a_batch_tmp = np.expand_dims(attr[image_idx], axis=0)
            metric_func = quantus.Sparseness(abs=True,
                                                normalise=False,
                                                aggregate_func=np.mean,
                                                return_aggregate=True,
                                                disable_warnings=disable_warnings,
                                            )
            score = np.mean([metric_func(
                model=model,
                x_batch=x_batch_tmp,
                y_batch=y_batch_tmp,
                a_batch=a_batch_tmp,
                device=device,
                explain_func=explainer_wrapper,
                explain_func_kwargs={
                    "method": method,
                    "posterior_mean": copy.deepcopy(
                        model.to(device).state_dict()
                    ),
                    "mean": 1.0,
                    "std": 0.5,
                    "sg_mean": 0.0,
                    "sg_std": 0.5,
                    "n": 25,
                    "m": 25,
                    "noise_type": "multiplicative",
                    "device": device,
                },
            ) for _ in range(1)], axis=0)
            print(f"score: {score}")
            result["Score"].append(np.mean(score))
            result["Method"].append(method)
            result["Index"].append(image_idx)
            
    df = pd.DataFrame(result)
    df_grouped = df.copy()

    # Group by the ranking.
    df_grouped["Rank"] = df_grouped.groupby(['Index'])["Score"].rank(ascending=True)

    # Smaller adjustments.
    df_grouped = df_grouped.loc[:, ~df_grouped.columns.str.contains('^Unnamed')]
    df_grouped.columns = map(lambda x: str(x).capitalize(), df_grouped.columns)

    df_view = df_grouped.groupby(["Method"])["Rank"].value_counts(normalize=True).mul(100).reset_index(name='Percentage').round(2)
    series_avg = df_grouped.groupby(['Method'])["Score"].mean()
    
    # Reorder the methods for plotting purporses.
    df_view_ordered = pd.DataFrame(columns=["Method", "Rank", "Percentage"])
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'Saliency']], ignore_index=True)
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'GradientShap']], ignore_index=True)
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'IntegratedGradients']], ignore_index=True)
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'FusionGrad']], ignore_index=True)

    if save_results_to_dataframe:
        df.to_csv(os.path.join(output_dir,f"df_complexity_{epoch}.csv"))
        df_grouped.to_csv(os.path.join(output_dir,f"df_complexity__grouped_{epoch}.csv"))
        df_view_ordered.to_csv(os.path.join(output_dir,f"df_complexity__ordered_{epoch}.csv"))
        series_avg.to_frame().to_csv(os.path.join(output_dir,f"df_complexity__avgs_{epoch}.csv"))

    return df,df_grouped,df_view_ordered

def get_randomisation_results(model, x_batch, y_batch, explanations, device, output_dir, epoch, bnn_reps=1, num_classes=2, seeds=[21,42,100],#,1000,5000],
                            # sim_funcs={"difference": quantus.similarity_func.difference, "abs_difference": quantus.similarity_func.abs_difference},
                            sim_funcs={"ssim": quantus.similarity_func.ssim,},
                            save_results_to_dataframe=True, disable_warnings=True, metric="RandomLogit"):
    print("Generating randomisation results")
    
    seeds = seeds
    sim_funcs = sim_funcs

    result = {
        "Score": [],
        "Method": [],
        "Similarity function": [],
        "Seed": []
    }

    for seed in seeds:
        for i, (method, attr) in enumerate(explanations.items()):
            if method.split("_")[-1]=="std":
                    # print(f"skipping method: {method}")
                    continue
            for sim, sim_func in sim_funcs.items():
                if metric == "RandomLogit":
                    metric_func = quantus.RandomLogit(
                                            num_classes=num_classes,
                                            similarity_func=sim_func,
                                            seed = seed,
                                            abs=True,
                                            normalise=True,
                                            aggregate_func=np.mean,
                                            return_aggregate=True,
                                            disable_warnings=disable_warnings,
                                        )
                # metric_func = quantus.metrics.randomisation.model_parameter_randomisation.ModelParameterRandomisation
                 # metric_func = quantus.metrics.randomisation.model_parameter_randomisation.ModelParameterRandomisation
                else:
                    metric_func = quantus.ModelParameterRandomisation(
                        similarity_func = sim_func,
                        seed = seed,
                        abs = True,
                        normalise = True,
                        aggregate_func= np.mean,
                        return_aggregate= True,
                        return_sample_correlation = True,
                        disable_warnings=disable_warnings,
                    )
                # print(f"metric: {metric}")
                # score = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=attr, device=device)
                score = np.mean(np.mean([metric_func(
                    model=model,
                    x_batch=x_batch,
                    y_batch=y_batch,
                    a_batch=attr,
                    device=device,
                    explain_func=explainer_wrapper,
                    explain_func_kwargs={
                        "method": method,
                        "posterior_mean": copy.deepcopy(
                           model.to(device).state_dict()
                        ),
                        "mean": 1.0,
                        "std": 0.5,
                        "sg_mean": 0.0,
                        "sg_std": 0.5,
                        "n": 25,
                        "m": 25,
                        "noise_type": "multiplicative",
                        "device": device,
                    },
                ) for _ in range(1)], axis=0))
                print(f"score: {score}")
                result["Score"].append(score)
                result["Method"].append(method)
                result["Similarity function"].append(sim)
                result["Seed"].append(seed)
                
    df = pd.DataFrame(result)
    df_grouped = df.copy()

    # Group by the ranking.
    df_grouped["Rank"] = df_grouped.groupby(['Seed', 'Similarity function'])["Score"].rank(ascending=False)

    # Smaller adjustments.
    df_grouped = df_grouped.loc[:, ~df_grouped.columns.str.contains('^Unnamed')]
    df_grouped.columns = map(lambda x: str(x).capitalize(), df_grouped.columns)

    df_view = df_grouped.groupby(["Method"])["Rank"].value_counts(normalize=True).mul(100).reset_index(name='Percentage').round(2)
    series_avg = df_grouped.groupby(['Method'])["Score"].mean()
    
    # Reorder the methods for plotting purporses.
    df_view_ordered = pd.DataFrame(columns=["Method", "Rank", "Percentage"])
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'Saliency']], ignore_index=True)
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'GradientShap']], ignore_index=True)
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'IntegratedGradients']], ignore_index=True)
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'FusionGrad']], ignore_index=True)

    if save_results_to_dataframe:
        df.to_csv(os.path.join(output_dir,f"df_randomisation_{epoch}.csv"))
        df_grouped.to_csv(os.path.join(output_dir,f"df_randomisation__grouped_{epoch}.csv"))
        df_view_ordered.to_csv(os.path.join(output_dir,f"df_randomisation__ordered_{epoch}.csv"))
        series_avg.to_frame().to_csv(os.path.join(output_dir,f"df_randomisation__avgs_{epoch}.csv"))
    return df,df_grouped,df_view_ordered




def get_robustness_results(model, x_batch, y_batch, explanations, device, output_dir, epoch, bnn_reps=1, avg_sensitivity_samples=10, lower_bounds=np.linspace(0.1,0.2,2),
                            # sim_funcs={"difference": quantus.similarity_func.difference, "abs_difference": quantus.similarity_func.abs_difference},
                            sim_funcs={"difference": quantus.similarity_func.difference},
                            save_results_to_dataframe=True, disable_warnings=True, metric="AvgSensitivity"):
    print("Generating robustness results")
    
    lower_bounds = lower_bounds
    sim_funcs = sim_funcs

    result = {
        "Score": [],
        "Method": [],
        "Similarity function": [],
        "Lower bound": []
    }
    scores = []
    
    for lb in lower_bounds:
        for i, (method, attr) in enumerate(explanations.items()):
            if method.split("_")[-1]=="std":
                    #  print(f"skipping method: {method}")
                    continue
            for sim, sim_func in sim_funcs.items():
                if metric == "AvgSensitivity":
                    metric_func = quantus.AvgSensitivity(nr_samples=avg_sensitivity_samples,
                                                    lower_bound=lb,
                                                    norm_numerator=quantus.norm_func.fro_norm,
                                                    norm_denominator=quantus.norm_func.fro_norm,
                                                    perturb_func=quantus.perturb_func.uniform_noise,
                                                    similarity_func=sim_func,
                                                    abs=False,
                                                    normalise=False,
                                                    aggregate_func=np.mean,
                                                    return_aggregate=True,
                                                    disable_warnings=disable_warnings,
                                                )
                else:
                    metric_func = quantus.LocalLipschitzEstimate(nr_samples=5,
                                        perturb_std=0.1,
                                        perturb_mean=0.0,
                                        norm_numerator=quantus.norm_func.fro_norm,
                                        norm_denominator=quantus.norm_func.fro_norm,
                                        perturb_func=quantus.perturb_func.uniform_noise,
                                        # similarity_func=quantus.similarity_func.difference,
                                        abs=False,
                                        normalise=False,
                                        aggregate_func=np.mean,
                                        return_aggregate=True,
                                        disable_warnings=True
                                        )
                # print(f"metric: {metric}")
                # score = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=attr, device=device)
                score = np.mean([metric_func(
                    model=model,
                    x_batch=x_batch,
                    y_batch=y_batch,
                    a_batch=attr,
                    device=device,
                    explain_func=explainer_wrapper,
                    explain_func_kwargs={
                        "method": method,
                        "posterior_mean": copy.deepcopy(
                            model.to(device).state_dict()
                        ),
                        "mean": 1.0,
                        "std": 0.5,
                        "sg_mean": 0.0,
                        "sg_std": 0.5,
                        "n": 25,
                        "m": 25,
                        "noise_type": "multiplicative",
                        "device": device,
                    },
                ) for _ in range(1)], axis=0)
                print(f"score: {score}")
                score = np.mean(score)
                print(f"score2: {score}")
                result["Score"].append(score)
                result["Method"].append(method)
                result["Similarity function"].append(sim)
                result["Lower bound"].append(lb)
                
    df = pd.DataFrame(result)
    df_grouped = df.copy()

    # Group by the ranking.
    df_grouped["Rank"] = df_grouped.groupby(['Lower bound', 'Similarity function'])["Score"].rank()

    # Smaller adjustments.
    df_grouped = df_grouped.loc[:, ~df_grouped.columns.str.contains('^Unnamed')]
    df_grouped.columns = map(lambda x: str(x).capitalize(), df_grouped.columns)

    df_view = df_grouped.groupby(["Method"])["Rank"].value_counts(normalize=True).mul(100).reset_index(name='Percentage').round(2)
    series_avg = df_grouped.groupby(['Method'])["Score"].mean()

    # Reorder the methods for plotting purporses.
    df_view_ordered = pd.DataFrame(columns=["Method", "Rank", "Percentage"])
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'Saliency']], ignore_index=True)
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'GradientShap']], ignore_index=True)
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'IntegratedGradients']], ignore_index=True)
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'FusionGrad']], ignore_index=True)

    if save_results_to_dataframe:
        df.to_csv(os.path.join(output_dir,f"df_robustness_{epoch}.csv"))
        df_grouped.to_csv(os.path.join(output_dir,f"df_robustness__grouped_{epoch}.csv"))
        df_view_ordered.to_csv(os.path.join(output_dir,f"df_robustness__ordered_{epoch}.csv"))
        series_avg.to_frame().to_csv(os.path.join(output_dir,f"df_robustness__avgs_{epoch}.csv"))

    return df,df_grouped,df_view_ordered
    
    
# def get_sensitivity_results(model, x_batch, y_batch, explanations, device, output_dir, epoch, baseline_strategies=["mean", "uniform"], 
#                             subset_sizes=np.array([7,14,28]), 
#                             sim_funcs={"pearson": quantus.similarity_func.correlation_pearson, "spearman": quantus.similarity_func.correlation_spearman},
#                             save_results_to_dataframe=True):
    
#     baseline_strategies = baseline_strategies
#     subset_sizes = subset_sizes
#     sim_funcs = sim_funcs

#     result = {
#         "Faithfulness score": [],
#         "Method": [],
#         "Similarity function": [],
#         "Baseline strategy": [],
#         "Subset size": [],
#     }


#     for i, (method, attr) in enumerate(explanations.items()):
#         if i % 2 != 0:
#             # print("skipping method: {method}")
#             continue
#         metric = quantus.FaithfulnessCorrelation(abs=True,
#                                                 normalise=True,
#                                                 return_aggregate=True,
#                                                 disable_warnings=True,
#                                                 aggregate_func=np.mean,
#                                                 normalise_func=quantus.normalise_func.normalise_by_negative,
#                                                 nr_runs=10,
#                                                 perturb_baseline="mean",
#                                                 perturb_func=quantus.perturb_func.baseline_replacement_by_indices,
#                                                 similarity_func=quantus.similarity_func.correlation_pearson,
#                                                 subset_size=7)
        
#         # print(f"metric: {metric}")
#         score = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=attr, device=device)
#         result["Method"].append(method)
#         result["Faithfulness score"].append(score[0])
#         result["Similarity function"].append("pearson")

    # df = pd.DataFrame(result)
    # df_grouped = df.copy()

    # # Group by the ranking.
    # df_grouped["Rank"] = df_grouped.groupby(['Baseline strategy', 'Subset size', 'Similarity function'])["Faithfulness score"].rank()

    # # Smaller adjustments.
    # df_grouped = df_grouped.loc[:, ~df_grouped.columns.str.contains('^Unnamed')]
    # df_grouped.columns = map(lambda x: str(x).capitalize(), df_grouped.columns)

    # df_view = df_grouped.groupby(["Method"])["Rank"].value_counts(normalize=True).mul(100).reset_index(name='Percentage').round(2)

    # # Reorder the methods for plotting purporses.
    # df_view_ordered = pd.DataFrame(columns=["Method", "Rank", "Percentage"])
    # df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'Saliency']], ignore_index=True)
    # df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'GradientShap']], ignore_index=True)
    # df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'IntegratedGradients']], ignore_index=True)
    # # df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'FusionGrad']], ignore_index=True)

    # if save_results_to_dataframe:
    #     df.to_csv(os.path.join(output_dir,f"df_sensitivity_{epoch}.csv"))
    #     df_grouped.to_csv(os.path.join(output_dir,f"df_sensitivity_grouped_{epoch}.csv"))
    #     df_view_ordered.to_csv(os.path.join(output_dir,f"df_sensitivity_ordered_{epoch}.csv"))

    # return df,df_grouped,df_view_ordered



def get_faithfulness_results(model, x_batch, y_batch, explanations, device, output_dir, epoch, bnn_reps=1, baseline_strategies=["mean","uniform"], 
                            subset_sizes=np.array([3,7,12,18,22]), 
                            sim_funcs={"pearson": quantus.similarity_func.correlation_pearson, "spearman": quantus.similarity_func.correlation_spearman},
                            save_results_to_dataframe=True, metric="fc"):
    print("Generating faithfulness results")
    
    baseline_strategies = baseline_strategies
    subset_sizes = subset_sizes
    sim_funcs = sim_funcs

    result = {
        "Score": [],
        "Method": [],
        "Similarity function": [],
        "Baseline strategy": [],
        "Subset size": [],
    }
    
    
        
        
        

    for b in baseline_strategies:
        for s in subset_sizes:
            for i, (method, attr) in enumerate(explanations.items()):
                # if i % 2 != 0:
                if method.split("_")[-1]=="std":
                    # print(f"skipping method: {method}")
                    continue
                for sim, sim_func in sim_funcs.items():
                    print(f"bl: {b}   s: {s}")
                    if metric.lower() == "road":
                        metric_func = quantus.ROAD(abs=True,
                                            normalise=True,
                                            return_aggregate=True,
                                            disable_warnings=True,
                                            aggregate_func=np.mean,
                                            normalise_func=quantus.normalise_func.normalise_by_negative,
                                            perturb_baseline=b,)
                        
                        # print(f"metric: {metric}")
                        score = metric_func(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=attr, device=device)
                        # print(f"score: {score}")
                        result["Method"].append(method)
                        result["Baseline strategy"].append(b.capitalize())
                        result["Subset size"].append(s)
                        result["Score"].append(score)
                        result["Similarity function"].append(sim)
                    else:
                        metric_func = quantus.FaithfulnessCorrelation(abs=True,
                                                                normalise=True,
                                                                return_aggregate=True,
                                                                disable_warnings=True,
                                                                aggregate_func=np.mean,
                                                                normalise_func=quantus.normalise_func.normalise_by_negative,
                                                                nr_runs=10,
                                                                perturb_baseline=b,
                                                                perturb_func=quantus.perturb_func.baseline_replacement_by_indices,
                                                                similarity_func=sim_func,
                                                                subset_size=s)
                    
                        # print(f"metric: {metric}")
                        score = metric_func(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=attr, device=device)
                        # print(f"score: {score}")
                        result["Method"].append(method)
                        result["Baseline strategy"].append(b.capitalize())
                        result["Subset size"].append(s)
                        result["Score"].append(score[0])
                        result["Similarity function"].append(sim)

    return result

    df = pd.DataFrame(result)
    df_grouped = df.copy()

    # Group by the ranking.
    df_grouped["Rank"] = df_grouped.groupby(['Baseline strategy', 'Subset size', 'Similarity function'])["Score"].rank()

    # Smaller adjustments.
    df_grouped = df_grouped.loc[:, ~df_grouped.columns.str.contains('^Unnamed')]
    df_grouped.columns = map(lambda x: str(x).capitalize(), df_grouped.columns)

    df_view = df_grouped.groupby(["Method"])["Rank"].value_counts(normalize=True).mul(100).reset_index(name='Percentage').round(2)
    series_avg = df_grouped.groupby(['Method'])["Score"].mean()

    # Reorder the methods for plotting purporses.
    df_view_ordered = pd.DataFrame(columns=["Method", "Rank", "Percentage"])
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'Saliency']], ignore_index=True)
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'GradientShap']], ignore_index=True)
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'IntegratedGradients']], ignore_index=True)
    df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'FusionGrad']], ignore_index=True)

    if save_results_to_dataframe:
        df.to_csv(os.path.join(output_dir,f"df_faithfulness_{epoch}.csv"))
        df_grouped.to_csv(os.path.join(output_dir,f"df_faithfulness_grouped_{epoch}.csv"))
        df_view_ordered.to_csv(os.path.join(output_dir,f"df_faithfulness_ordered_{epoch}.csv"))
        series_avg.to_frame().to_csv(os.path.join(output_dir,f"df_faithfulness_avgs_{epoch}.csv"))

    return df,df_grouped,df_view_ordered


# def create_complexity_dfs(result,output_dir,save_results_to_dataframe=True):
#     df = pd.DataFrame(result)
#     df_grouped = df.copy()

#     # Group by the ranking.
#     df_grouped["Rank"] = df_grouped.groupby(['Epoch'])["Complexity score"].rank(ascending=False)

#     # Smaller adjustments.
#     df_grouped = df_grouped.loc[:, ~df_grouped.columns.str.contains('^Unnamed')]
#     df_grouped.columns = map(lambda x: str(x).capitalize(), df_grouped.columns)

#     df_view = df_grouped.groupby(["Method"])["Rank"].value_counts(normalize=True).mul(100).reset_index(name='Percentage').round(2)

#     # Reorder the methods for plotting purporses.
#     df_view_ordered = pd.DataFrame(columns=["Method", "Rank", "Percentage"])
#     df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'Saliency']], ignore_index=True)
#     df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'GradientShap']], ignore_index=True)
#     df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'IntegratedGradients']], ignore_index=True)
#     # df_view_ordered = df_view_ordered.append([df_view.loc[df_view["Method"] == 'FusionGrad']], ignore_index=True)

#     if save_results_to_dataframe:
#         df.to_csv(os.path.join(output_dir,f"df_complexity.csv"))
#         df_grouped.to_csv(os.path.join(output_dir,f"df_complexity__grouped.csv"))
#         df_view_ordered.to_csv(os.path.join(output_dir,f"df_complexity__ordered.csv"))
    
#     return df, df_grouped, df_view_ordered


def plot_ranking_results(df_view_ordered,output_dir,epoch,metric_name):
        
    plot_loc = os.path.join(output_dir,f"prob_{epoch}_{metric_name}.jpg")
    
    fig, ax = plt.subplots(figsize=(6.5,5))
    ax = sns.histplot(x='Method', hue='Rank', weights='Percentage', multiple='stack', data=df_view_ordered, shrink=0.6, palette="colorblind", legend=False)
    ax.spines["right"].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.tick_params(axis='both', which='major', labelsize=16)
    ax.set_ylabel('Frequency of rank', fontsize=15)
    ax.set_xlabel('')
    ax.set_xticklabels(["SAL", "GS", "IG", "FG"])
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=2, fancybox=True, shadow=False, labels=['1st', "2nd", "3rd", "4th"][::-1])
    # plt.axvline(x=3.5, ymax=0.95, color='black', linestyle='-')
    plt.tight_layout()
    plt.savefig(plot_loc)
    plt.close("all")
        
def multi_plot_ranking_results(df_views_ordered,output_dir,epoch,metric_names):
    # df_views_ordered_loc = df_views_ordered.copy()
    # print(f"before len(df_views_ordered_loc): {len(df_views_ordered_loc)}")
    # if not complexity_result is None:
    #     print(f"not none")
    #     df_views_ordered_loc.append(create_complexity_dfs(complexity_result, output_dir)[2])
        
    plot_loc = os.path.join(output_dir,f"prob_{epoch}_all_metrics.jpg")
    
    # print(f"after len(df_views_ordered_loc): {len(df_views_ordered_loc)}")
    
    for m in metric_names:
        print(f"m: {m}")
    # fig, ax = plt.subplots(figsize=(6.5,5))
    fig, axes = plt.subplots(ncols=int(len(df_views_ordered)), figsize=(14, 6))
    # fig.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=3, fancybox=True, shadow=False, labels=['1st', "2nd", "3rd"][::-1])
    fig.suptitle("Ranking of XAI functions on different metrics")
    for i,df_view_ordered in enumerate(df_views_ordered):
        # print(f"{metric_names[i]}")
        sns.histplot(ax=axes[i], x='Method', hue='Rank', weights='Percentage', multiple='stack', data=df_view_ordered, shrink=0.6, palette="colorblind", legend=False)
        axes[i].spines["right"].set_visible(False)
        axes[i].spines['top'].set_visible(False)
        axes[i].tick_params(axis='both', which='major', labelsize=16)
        axes[i].set_ylabel('Frequency of rank', fontsize=15)
        axes[i].set_xlabel('')
        axes[i].set_xticklabels(["SAL", "GS", "IG", "FG"])
        axes[i].legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=3, fancybox=True, shadow=False, labels=['1st', "2nd", "3rd", "4th"][::-1])
        axes[i].set_title(f"{metric_names[i]}")
        # axes[i].legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=3, fancybox=True, shadow=False, labels=['1st', "2nd", "3rd"][::-1])
        # axes[i].axvline(x=3, ymax=0.95, color='black', linestyle='-')
        
    plt.tight_layout()
    plt.savefig(plot_loc)
    plt.close("all")        
        
        
    # for idx,x in enumerate(x_batch):
    #     if model_type == "BNN":
    #         fig, axes = plt.subplots(nrows=1, ncols=2+len(explanations), figsize=(12, 4)) #, gridspec_kw={'height_ratios': [12 if i != 1 else 4 for i in range(2+len(explanations))]})
    #     else:
    #         fig, axes = plt.subplots(nrows=1, ncols=1+int(len(explanations)/2), figsize=(6, 4))

    #     if model_type == "BNN":
    #         fig.suptitle("Explanations for BNN", fontsize=16)
    #     else:
    #         fig.suptitle("Explanations for LeNet", fontsize=16)

    #     # axes[0].imshow(np.moveaxis(quantus.normalise_func.denormalise(x_batch[index].cpu().numpy(), mean=np.array([0.485, 0.456, 0.406]), std=np.array([0.229, 0.224, 0.225])), 0, -1), vmin=0.0, vmax=1.0)
    #     # img
    #     #axes[0].imshow(cv2.resize(np.moveaxis(x, 0, -1), (12,12)).astype(np.uint8), vmin=0.0, vmax=1.0)
    #     #  cv2.resize(np.moveaxis(x_batch[0], 0, -1), (20,20))
    #     axes[0].imshow(np.moveaxis(x, 0, -1), vmin=0.0, vmax=1.0)
    #     axes[0].title.set_text(f"class {y_batch[idx].item()}")
    #     axes[0].axis("off")
    #     # probability bar
    #     if model_type == "BNN":
    #         axes[1].bar(np.arange(num_classes), np.array([probabilities[idx][0],probabilities[idx][1]]), color='red')
    #         axes[1].set_xticks(np.arange(num_classes))
    #         axes[1].set_ylim([0, 1])
    #         axes[1].set_title(f"{num_forward_passes} draws")
    #         # sns.barplot(np.arange(num_classes),np.array([probabilities[idx][0],probabilities[idx][1]]), ax=axes[1])
    #         i = 2
    #     else:
    #         i = 1

    #     for j, (k, v) in enumerate(explanations.items()):
    #         # skip std plots for lenet (fixed weights)
    #         if model_type != "BNN" and j % 2 != 0:
    #             continue
    #         explanation = explanations[k][idx].reshape(28, 28)
            
    #         csums = np.sum(explanation, 0)
    #         rsums = np.sum(explanation, 1)

    #         axes[i].imshow(quantus.normalise_func.normalise_by_negative(explanation), cmap="seismic", vmin=-1.0, vmax=1.0)
    #         axes[i].plot(-abs(csums), color='blue') #, marker='o') #, mfc='orange')
    #         axes[i].plot(-abs(rsums), y_indices_rowsums, color='blue') #, marker='o')#, mfc='orange')
    #         # axes[i].plot(csums, color='blue') #, marker='o') #, mfc='orange')
    #         # axes[i].plot(rsums, y_indices_rowsums, color='blue') #, marker='o')#, mfc='orange')
    #         axes[i].title.set_text(f"{k}")
    #         axes[i].axis("off")
    #         i += 1

    #     plot_loc = os.path.join(output_dir,f"prob_{epoch}_{idx}_explanation.jpg")
    #     plt.savefig(plot_loc)
    #     plt.close('all')
        

# def plot_sensitivity_results(df_view_ordered,output_dir,epoch):
#     plot_loc = os.path.join(output_dir,f"prob_{epoch}_sensitivity.jpg")
    
#     fig, ax = plt.subplots(figsize=(6.5,5))
#     ax = sns.histplot(x='Method', hue='Rank', weights='Percentage', multiple='stack', data=df_view_ordered, shrink=0.6, palette="colorblind", legend=False)
#     ax.spines["right"].set_visible(False)
#     ax.spines['top'].set_visible(False)
#     ax.tick_params(axis='both', which='major', labelsize=16)
#     ax.set_ylabel('Frequency of rank', fontsize=15)
#     ax.set_xlabel('')
#     ax.set_xticklabels(["SAL", "GS", "IG"])
#     plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=3, fancybox=True, shadow=False, labels=['1st', "2nd", "3rd"][::-1])
#     plt.axvline(x=3.5, ymax=0.95, color='black', linestyle='-')
#     plt.tight_layout()
#     plt.savefig(plot_loc)
#     plt.close("all")

In [11]:
# multi_plot_ranking_results(df_views_ordered,bnn_probs_out_path,i,metric_names,complexity_result)

In [12]:
def parse_config(filename:str = "config_vae_inference_local.cfg"):
    global NUM_CLASSES, TEST_IMGS_ROOT_DIR, OUTPUT_DIR, MODEL_TYPE 
    global BATCH_SIZE, CHECKPOINT_LOC_CLASSIFICIATION_MODEL, NUM_BNN_FORWARD_PASSES, CLASSIFIER_INPUT_DIMS
    global CUDA_GPU_INDEX, NORMALIZE

    config = configparser.ConfigParser()
    config.read(filename)

    NUM_CLASSES = int(config["DEFAULT"]["NUM_CLASSES"])

    TEST_IMGS_ROOT_DIR = os.path.join(config["DEFAULT"]["TEST_DIR_IMGS"])
    OUTPUT_DIR = os.path.join(config["DEFAULT"]["OUTPUT_DIR"])

    CHECKPOINT_LOC_CLASSIFICIATION_MODEL = os.path.join(config["DEFAULT"]["CHECKPOINT_LOC_CLASSIFICIATION_MODEL"])
    
    l = list(map(str.lower, CHECKPOINT_LOC_CLASSIFICIATION_MODEL.split("/")))
    
    if "bnn" in l:
        MODEL_TYPE = "BNN"
    elif "lenet" in l:
        MODEL_TYPE  = "LeNet"
    else:
        print("model type has to be set via correct path (path has either to contain 'lenet' or 'bnn'). Exiting...")
        sys.exit()

    NUM_BNN_FORWARD_PASSES = int(config["DEFAULT"]["NUM_BNN_FORWARD_PASSES"])
    CLASSIFIER_INPUT_DIMS = int(config["DEFAULT"]["CLASSIFIER_INPUT_DIMS"])
    BATCH_SIZE = int(config["DEFAULT"]["BATCH_SIZE"])
    
    CUDA_GPU_INDEX = int(config["CUDA"]["CUDA_GPU_INDEX"])
    NORMALIZE = bool(int(config["DEFAULT"]["NORMALIZE"]))

In [13]:
parse_config("config_vae_inference_local.cfg")

if not os.path.exists(os.path.join(CHECKPOINT_LOC_CLASSIFICIATION_MODEL)):
   sys.exit("no valid CHECKPOINT_LOC_CLASSIFICIATION_MODEL path") 

device = torch.device(f"cuda:{CUDA_GPU_INDEX}" if torch.cuda.is_available() else "cpu")
device = "cpu"

if MODEL_TYPE == "BNN":
    class_model = bayesian_models.BayesianModel3BatchNormActivation(CLASSIFIER_INPUT_DIMS,NUM_CLASSES)
else:
    class_model = bayesian_models.LeNet(CLASSIFIER_INPUT_DIMS,NUM_CLASSES)
    NUM_BNN_FORWARD_PASSES = 1
    # print(f"NUM_BNN_FPASSES: {NUM_BNN_FORWARD_PASSES}")

class_model.load_state_dict(torch.load(CHECKPOINT_LOC_CLASSIFICIATION_MODEL,map_location=torch.device('cpu')))
class_model.eval()

# bnn_reps = config.NUM_BNN_FORWARD_PASSES

ds = Custom_Dataset(TEST_IMGS_ROOT_DIR, device, NORMALIZE)
dl = DataLoader(ds,batch_size=BATCH_SIZE, shuffle=False)

vae_out_path = os.path.join(os.path.join(OUTPUT_DIR),"vae")
bnn_probs_out_path = os.path.join(os.path.join(OUTPUT_DIR),"bnn")
df_out_path = os.path.join(bnn_probs_out_path,"dfs")
# df_out_path = os.path.join(OUTPUT_DIR)

vae_utils.make_folders(os.path.join(OUTPUT_DIR), "bnn", "vae", os.path.join("bnn/dfs"))

# complexity_result = {
#     "Complexity score": [],
#     "Method": [],
#     "Epoch": []
# }


# metric_names = ["Faithfulness", "Robustness", "Randomisation", "Complexity"]
# metric_names = ["Faithfulness", "Randomisation", "Complexity"]

for i, ([x_batch,x_batch_plot],l) in enumerate(dl):
    # if i == 0:
    #     continue
    if i > 0:
        break
    
    df_views_ordered = []
    time_dict = {}
    start_time = time.time()

    sec_time = time.time()
    print("Predicting results")
    probs, plot_list = get_bnn_probabilities(x_batch, class_model, device, NUM_BNN_FORWARD_PASSES)
    # print(f"probs: {probs}")
    time_dict["bnn_probs"] = time.time() - sec_time
    print(f"time for predictions: {time.time()-sec_time}")
    
    x_batch = x_batch.to(device)
    x_batch.requires_grad = True
    y_batch = torch.from_numpy(np.array(np.argmax(probs,axis=1),dtype=np.int64)).to(device)

    # vae_utils.save_tensor_to_image_known_path(x_batch,NUM_SAMPLES,i,vae_out_path)
    sec_time = time.time()
    print("Generating XAI explanations")
    explanations = get_xai_explanations(x_batch, y_batch, class_model, device, MODEL_TYPE, NUM_BNN_FORWARD_PASSES)
    time_dict["explanations"] = time.time() - sec_time
    print(f"time for explanations generation: {time.time()-sec_time}")

    x_batch, x_batch_plot, y_batch = x_batch.detach().cpu().numpy(), x_batch_plot.detach().cpu().numpy(), y_batch.detach().cpu().numpy()
    # # plot_explanations(x_batch.detach(), y_batch.detach(), explanations, probs, config.NUM_CLASSES, bnn_probs_out_path, i)
    plot_explanations_with_sums(x_batch_plot, y_batch, explanations, probs, NUM_CLASSES, bnn_probs_out_path, i, NUM_BNN_FORWARD_PASSES, MODEL_TYPE)
    
    # xai_methods, xai_metrics = get_xai_methods_and_metrics(explanations, NUM_CLASSES)
    # # # res = get_xai_quantification_results(x_batch, y_batch, class_model, device, xai_methods, xai_metrics, NUM_BNN_FORWARD_PASSES)
    # # # df, df_normalised_rank = xai_results_postprocess(xai_methods, xai_metrics, res, bnn_probs_out_path, i)
    # # df, df_grouped, df_view_ordered = get_sensitivity_results(class_model, x_batch, y_batch, explanations, device, bnn_probs_out_path, i, baseline_strategies=["mean","uniform","black","white"])
    # # # plot_ranking_results(df_view_ordered,bnn_probs_out_path,i,"faithfulness")
    # # df, df_grouped, df_view_ordered = get_robustness_results(class_model, x_batch, y_batch, explanations, device, bnn_probs_out_path, i)
    # # # # df, df_grouped, df_view_ordered = get_randomisation_results(class_model, x_batch, y_batch, explanations, device, bnn_probs_out_path, i)
        # # plot_ranking_results(df_view_ordered,bnn_probs_out_path,i,"faithfulness")

    # df_views_ordered.append(get_sensitivity_results(class_model, x_batch, y_batch, explanations, device, df_out_path, i, bnn_reps = NUM_BNN_FORWARD_PASSES, baseline_strategies=["mean","uniform","black","white"])[2])

    # df_views_ordered.append(get_robustness_results(class_model, x_batch, y_batch, explanations, device, df_out_path, i, bnn_reps = NUM_BNN_FORWARD_PASSES)[2])
    # df_views_ordered.append(get_randomisation_results(class_model, x_batch, y_batch, explanations, device, df_out_path, i, bnn_reps = NUM_BNN_FORWARD_PASSES)[2])
    # # complexity_result = get_complexity_results(class_model, x_batch, y_batch, explanations, device, bnn_probs_out_path, i, complexity_result)
    # df_views_ordered.append(get_complexity_results(class_model, x_batch, y_batch, explanations, device, df_out_path, i, bnn_reps = NUM_BNN_FORWARD_PASSES)[2])
    # plot_ranking_results(df_views_ordered[-1],bnn_probs_out_path,i,"randomisation")#,complexity_result=complexity_result)
    
    # methods = {"Faithfulness": get_faithfulness_results(class_model, x_batch, y_batch, explanations, device, df_out_path, i, bnn_reps = NUM_BNN_FORWARD_PASSES, metric="road"),
    #            "Robustness": get_robustness_results(class_model, x_batch, y_batch, explanations, device, df_out_path, i, bnn_reps = NUM_BNN_FORWARD_PASSES)}#,#, metric="bla"), not for both (fro_norm)
    #         #    "Randomisation": get_randomisation_results(class_model, x_batch, y_batch, explanations, device, df_out_path, i, bnn_reps = NUM_BNN_FORWARD_PASSES),#, metric="bla"), not for bnn (deepcopy)
    #         #    "Complexity": get_complexity_results(class_model, x_batch, y_batch, explanations, device, df_out_path, i, bnn_reps = NUM_BNN_FORWARD_PASSES)}
    
    # print("Starting evaluation process...")
    
    # res = get_faithfulness_results(class_model, x_batch, y_batch, explanations, device, df_out_path, i, bnn_reps = NUM_BNN_FORWARD_PASSES, metric="road")
    
    # for method, fct in methods.items():
    #     method_time = time.time()
    #     df, df_grouped, df_view_ordered = fct
    #     df_views_ordered.append(df_view_ordered)
    #     plot_ranking_results(df_views_ordered[-1],bnn_probs_out_path,i,method)
    #     print(f"time for {method.lower()} result generation: {time.time()-method_time}")
    #     time_dict[method] = time.time()-method_time
    
    # time_dict["all"] = time.time()-start_time
    # print(f"time for all metrics: {time_dict['all']}")
    # pd.DataFrame(pd.Series(time_dict,name="time")).to_csv(os.path.join(df_out_path,f"df_times_{i}.csv"))
    
    # multi_plot_ranking_results(df_views_ordered, bnn_probs_out_path, i, list(methods.keys()))
    
    # # multi_plot_ranking_results(df_views_ordered, bnn_probs_out_path, i, [k.title() for k in time_dict.keys() if k != "all"])
    
    # print(f"batch {i} done")

Predicting results


KeyboardInterrupt: 

In [None]:
# time for randomisation result generation: 0.1786808967590332
# time for complexity result generation: 0.16939043998718262
# time for all metrics: 0.3481442928314209

In [None]:
explanations["FG"]

array([[[ 0.17624553, -0.9960372 ,  0.0209351 , ...,  0.24069881,
          0.32969385, -0.02676467],
        [-1.5382316 ,  0.75300705, -2.7950165 , ..., -0.1263105 ,
         -0.10576378, -1.0180641 ],
        [ 0.18729898, -1.7753942 ,  0.443183  , ...,  0.7980205 ,
         -0.23002842, -0.44936904],
        ...,
        [-0.38557163,  0.30142114,  0.0569355 , ..., -1.6081734 ,
          0.2721547 , -0.16181341],
        [-0.22954102,  0.54201096,  0.09387596, ...,  0.02369696,
         -0.26130134,  0.0616801 ],
        [-0.14371544, -0.17132768, -0.04146149, ..., -0.17579845,
         -0.32051405,  0.22982499]],

       [[ 0.2774817 ,  0.39340746, -0.7173558 , ..., -0.01491425,
          0.19197494, -0.28781974],
        [-0.03366408, -0.5304768 ,  1.0151659 , ...,  1.0315105 ,
          1.101684  ,  0.6832551 ],
        [ 0.5846871 ,  0.20212722, -0.4523396 , ...,  0.26995775,
          1.262933  ,  0.66842467],
        ...,
        [ 0.5710393 , -0.27555895,  1.5714474 , ..., -

In [None]:
res["Score"][0]

NameError: name 'res' is not defined