## Preliminaries

In [None]:
# Load supporting functions.
import sys
sys.path.append('../')
from src import *

import quantus

In [None]:

# Import libraries.
import torch
import os
from tqdm import tqdm
import joblib

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
# Only query the CUDA device name when CUDA is actually in use (it raises otherwise).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print('Using device:', torch.cuda.get_device_name(0))
else:
    print('Using device:', device)

## Run Experiments

In [None]:
# Define all the hyperparameters here.
# Set the parameters marked with TODO to reproduce the paper results.

savepath = "/some/path"  # TODO: path where results (.joblib) and figures (.svg) are saved
dataset_name = "imagenet"
data_path = "/some/path"  # TODO: path to the ImageNet dataset
labelmap_path = "../src/label_map_imagenet.json"
model_names = ["vgg16", "resnet18"]
# TODO: explanation methods to evaluate.
xai_methods = [
    "SmoothGrad",
    "IntegratedGradients",
    "LRP-Eps",
    "LRP-Z+",
    "Guided-Backprop",
    "Gradient",
    "Saliency",
]
nr_test_samples = 1000  # the dataset is truncated to its first nr_test_samples entries
# TODO: careful, this may take a while to run. Set to [1, 20, 50, 300] to reproduce the original paper results.
smprt_nr_samples = [1, 2, 5]  # values of N (noise samples) passed to SmoothMPRT as nr_samples
smprt_noise_magnitude = 0.1  # passed to SmoothMPRT as noise_magnitude
layer_orders = ["bottom_up", "top_down"]  # layer randomisation orders to evaluate
eval_normalise = True  # NOTE(review): not referenced by the cells below, which hard-code normalise=True
batch_size = 32
shuffle = False  # whether the evaluation dataloader shuffles samples

recompute_results = True  # Set this to False to only plot already computed results

linewidth = 2.5  # base line width for all plots

In [None]:
os.makedirs(savepath, exist_ok=True)

# Get Dataset
# Prepare the test-mode transforms for the selected dataset.
transform = get_transforms(dataset_name, mode="test")

# Prepare datasets
print("Preparing datasets...")
dataset = get_dataset(
    dataset_name,
    data_path,
    transform,
    mode="test",
    labelmap_path=labelmap_path,
)

# Keep only the first `nr_test_samples` entries to bound the runtime of the experiments.
print(f"Number of Samples in Dataset: {len(dataset.samples)}")
dataset.samples = dataset.samples[:nr_test_samples]
print(f"Reduced Number of Samples in Dataset: {len(dataset.samples)}")

# Prepare dataloaders; honour the `shuffle` setting from the config cell.
print("Preparing dataloaders...")
loader = get_dataloader(
    dataset_name=dataset_name,
    dataset=dataset,
    batch_size=batch_size,
    shuffle=shuffle,
)

In [None]:
# Generate SmoothMPRT Results.
# For every model and layer-randomisation order this cell writes to `savepath`:
#   * accuracy--{model}--{order}.joblib: dict mapping "orig" and each randomised
#     layer name to the accuracy returned by `eval_accuracy` (as float);
#   * sMPRT--{model}--{order}--{N}--{method}.joblib: dict mapping each key of the
#     metric's result to the scores concatenated over all batches of `loader`.

# Methods whose attribution sign has no meaning: their explanations are compared
# in absolute value ("abs" preprocessing of the metric).
SIGNLESS_XAI_METHODS = ("SmoothGrad", "Saliency")

if recompute_results:

    for model_name in model_names:

        # Prepare model
        model = get_model(model_name, device)
        model.eval()

        XAI_METHOD_KWARGS = setup_xai_methods_zennit(xai_methods, model, device)

        for layer_order in layer_orders:

            print(f"Computing accuracy-scores for model {model_name} with {layer_order} order")

            # Compute Accuracy: original model first, then the model after each
            # randomisation step. Insertion order matters because the plotting
            # cells read the values in this order. The generator is iterated
            # directly instead of being materialised just to count its length.
            accuracy_scores = {}
            _, _, orig_accuracy = eval_accuracy(model, loader, device)
            accuracy_scores["orig"] = float(orig_accuracy)
            for layer_name, random_layer_model in get_random_layer_generator(model, order=layer_order):
                _, _, layer_accuracy = eval_accuracy(random_layer_model, loader, device)
                accuracy_scores[layer_name] = float(layer_accuracy)

            filepath = os.path.join(savepath, f"accuracy--{model_name}--{layer_order}.joblib")
            joblib.dump(accuracy_scores, filepath)

            # Compute sMPRT Scores
            for nr_samples in smprt_nr_samples:

                for xai_method, xai_method_kwargs in XAI_METHOD_KWARGS.items():
                    print(f"Computing sMPRT-scores with N={nr_samples} for model {model_name} with {layer_order} order using {xai_method} explanations")

                    # Only the "abs" preprocessing depends on the explanation method.
                    metric_kwargs = {
                        "abs": xai_method in SIGNLESS_XAI_METHODS,
                        "normalise": True,
                        "normalise_func": quantus.normalise_by_average_second_moment_estimate,
                        "similarity_func": quantus.ssim,
                        "layer_order": layer_order,
                        "nr_samples": nr_samples,
                        "noise_magnitude": smprt_noise_magnitude,
                    }
                    metric = SmoothMPRT(**metric_kwargs)

                    scores = {}
                    for batch, labels in loader:

                        batch_results = metric(
                            model=model,
                            x_batch=batch.numpy(),
                            y_batch=labels.numpy(),
                            a_batch=None,
                            device=device,
                            explain_func=quantus.explain,
                            explain_func_kwargs={"method": xai_method, **xai_method_kwargs},
                        )

                        # Concatenate this batch's scores onto our own lists instead of
                        # extending the sequences returned by the metric in place.
                        # Assumes each value is a sequence of per-sample scores — TODO confirm.
                        for key, batch_scores in batch_results.items():
                            scores.setdefault(key, []).extend(batch_scores)

                    filepath = os.path.join(savepath, f"sMPRT--{model_name}--{layer_order}--{nr_samples}--{xai_method}.joblib")
                    joblib.dump(scores, filepath)


## Plots



In [None]:
# Loading Data.
# Builds a nested dict from the .joblib files in `savepath`:
#   results[model][order]["layer_names"]  -> keys (in order) of an sMPRT result dict
#   results[model][order][N][method]      -> list of the per-key values of that dict
#   results[model][order]["accuracy"]     -> list of the accuracy values, in key order
# N is kept as the string parsed from the file name.
results = {}

# Sorted so that loading (and which file sets "layer_names" last) is deterministic.
for file in sorted(os.listdir(savepath)):
    stem, ext = os.path.splitext(file)
    if ext != ".joblib":
        continue
    parts = stem.split("--")

    if parts[0] == "sMPRT" and len(parts) == 5:
        # sMPRT--{model}--{order}--{N}--{method}.joblib
        _, mod_name, l_order, n_samp, meth_name = parts
        res = joblib.load(os.path.join(savepath, file))
        l_res = results.setdefault(mod_name, {}).setdefault(l_order, {})
        l_res["layer_names"] = list(res.keys())
        l_res.setdefault(n_samp, {})[meth_name] = list(res.values())
    elif parts[0] == "accuracy" and len(parts) == 3:
        # accuracy--{model}--{order}.joblib
        _, mod_name, l_order = parts
        res = joblib.load(os.path.join(savepath, file))
        l_res = results.setdefault(mod_name, {}).setdefault(l_order, {})
        l_res["accuracy"] = list(res.values())

### sMPRT - Line Plots

In [None]:
# For each model and randomisation order, plot the mean sMPRT score (SSIM) after
# each randomisation step: one colour per explanation method (COLOR_MAP) and one
# line style per number of noise samples N (LINESTYLE_MAP).
for mod_name, mod_res in results.items():

    for l_order in HATCH_MAP.keys():

        if l_order not in mod_res:
            continue
        l_res = mod_res[l_order]

        print(f"Plotting: Model {mod_name}, Randomisation Order {l_order}")

        layer_names = l_res["layer_names"]

        fig, ax = plt.subplots(figsize=(8, 4))

        ax.set_xlabel("Layers")
        ax.set_ylabel("SSIM")
        ax.grid(True)
        ax.set_ylim([0.0, 1.1])
        ax.set_yticks([0, 0.5, 1.0])
        ax.set_yticklabels([0, 0.5, 1])

        # xticklabels depend on model: label every `markevery`-th step
        # (4 for ResNets, 2 otherwise) and abbreviate "downsample" to "ds".
        # NOTE(review): the last *tick* is relabelled "final"; it is the last step
        # only when (len(layer_names) - 1) is a multiple of markevery — confirm.
        markevery = 4 if "resnet" in mod_name else 2
        ax.set_xticks(list(range(len(layer_names)))[::markevery])
        xticklabels = [name.replace("downsample", "ds") for name in layer_names[::markevery]]
        xticklabels = xticklabels[:-1] + ["final"]
        ax.set_xticklabels(xticklabels, rotation=45)

        # Track the methods actually drawn so the legend matches this figure.
        plotted_methods = []
        for meth_name in COLOR_MAP.keys():

            for n_samp in LINESTYLE_MAP.keys():

                if n_samp not in l_res:
                    continue
                n_res = l_res[n_samp]
                if meth_name not in n_res:
                    continue

                # One row per randomisation step; average over the remaining axis.
                means = np.mean(np.array(n_res[meth_name]), axis=1)

                ax.plot(
                    list(range(len(means))),
                    means,
                    alpha=0.7,
                    linewidth=linewidth,
                    marker="o",
                    markevery=markevery,
                    linestyle=LINESTYLE_MAP[n_samp],
                    color=COLOR_MAP[meth_name],
                )
                if meth_name not in plotted_methods:
                    plotted_methods.append(meth_name)

        # Make Legend from empty proxy lines: black lines show each N's line style,
        # coloured lines each plotted method's colour.
        for n_samp in LINESTYLE_MAP.keys():
            if n_samp in l_res:
                ax.plot([], [], color="black", alpha=0.7, linewidth=linewidth, label=f"N={n_samp}", linestyle=LINESTYLE_MAP[n_samp])
        for meth_name in plotted_methods:
            ax.plot([], [], color=COLOR_MAP[meth_name], label=f"{meth_name}", alpha=0.7, linewidth=linewidth)
        ax.legend()

        fig.savefig(os.path.join(savepath, f"sMPRT-lines-{mod_name}-{l_order}.svg"))

        plt.show()


### MPRT - randomisation order comparison line plots (including accuracy)

In [None]:
# For each model and randomisation order, plot the MPRT (sMPRT with N=1) score
# per randomisation step as mean +/- std band per explanation method, together
# with the accuracy of the progressively randomised model ("Model" colour).
for mod_name, mod_res in results.items():

    for l_order in HATCH_MAP.keys():

        if l_order not in mod_res:
            continue
        l_res = mod_res[l_order]

        print(f"Plotting: Model {mod_name}, Randomisation Order {l_order}")

        layer_names = l_res["layer_names"]

        fig, ax = plt.subplots(figsize=(8, 4))

        ax.set_xlabel("Layers")
        ax.set_ylabel("SSIM / Accuracy")
        ax.grid(True)
        ax.set_ylim([0.0, 1.1])
        ax.set_yticks([0, 0.5, 1.0])
        ax.set_yticklabels([0, 0.5, 1])

        # xticklabels depend on model: label every `markevery`-th step
        # (4 for ResNets, 2 otherwise) and abbreviate "downsample" to "ds".
        # NOTE(review): the last *tick* is relabelled "final"; it is the last step
        # only when (len(layer_names) - 1) is a multiple of markevery — confirm.
        markevery = 4 if "resnet" in mod_name else 2
        ax.set_xticks(list(range(len(layer_names)))[::markevery])
        xticklabels = [name.replace("downsample", "ds") for name in layer_names[::markevery]]
        xticklabels = xticklabels[:-1] + ["final"]
        ax.set_xticklabels(xticklabels, rotation=45)

        # Plot Methods (N=1 only). Track what is drawn so the legend matches.
        plotted_methods = []
        n_samp = "1"
        n_res = l_res.get(n_samp, {})
        for meth_name in COLOR_MAP.keys():

            if meth_name not in n_res:
                continue

            # One row per randomisation step; statistics over the remaining axis.
            data = np.array(n_res[meth_name])
            color = COLOR_MAP[meth_name]
            means = np.mean(data, axis=1)
            stds = np.std(data, axis=1)
            xs = list(range(len(means)))

            ax.plot(
                xs,
                means,
                alpha=0.7,
                linewidth=linewidth,
                marker="o",
                markevery=markevery,
                linestyle=LINESTYLE_MAP[n_samp],
                color=color,
            )
            ax.fill_between(
                xs,
                means + stds,
                means - stds,
                facecolor=color,
                alpha=0.3,
            )
            plotted_methods.append(meth_name)

        # Plot Accuracy (skipped when no accuracy file was loaded for this order).
        if "accuracy" in l_res:
            data = np.array(l_res["accuracy"])
            ax.plot(
                list(range(len(data))),
                data,
                alpha=0.7,
                linewidth=linewidth,
                marker="o",
                markevery=markevery,
                linestyle=LINESTYLE_MAP["1"],
                color=COLOR_MAP["Model"],
            )
            plotted_methods.append("Model")

        # Make Legend from empty proxy lines, one per plotted series.
        for meth_name in plotted_methods:
            ax.plot([], [], color=COLOR_MAP[meth_name], label=f"{meth_name}", alpha=0.7, linewidth=linewidth)
        ax.legend()

        fig.savefig(os.path.join(savepath, f"sMPRT-layerordercomparison-{mod_name}-{l_order}.svg"))

        plt.show()

### sMPRT - comparison of different N

In [None]:
# `np.trapz` is deprecated in NumPy 2.0 in favour of `np.trapezoid`.
_trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

# For each model and randomisation order, plot the mean area under the sMPRT
# curve (AUC over randomisation steps) against the number of noise samples N,
# one line per explanation method.
for mod_name, mod_res in results.items():

    for l_order in HATCH_MAP.keys():

        if l_order not in mod_res:
            continue
        l_res = mod_res[l_order]

        print(f"Plotting: Model {mod_name}, Randomisation Order {l_order}")

        fig, ax = plt.subplots(figsize=(6, 4))

        ax.set_xlabel("N")
        ax.set_ylabel("AUC")
        ax.grid(True)

        # Every key except "layer_names"/"accuracy" is an N value stored as a string.
        # Sort numerically so each line is drawn left to right (file-listing order
        # and lexicographic order would both zig-zag, e.g. "20" < "5").
        n_keys = sorted((k for k in l_res if k not in ("layer_names", "accuracy")), key=int)
        all_ns = [int(k) for k in n_keys]
        if all_ns and min(all_ns) < max(all_ns):
            ax.set_xlim((min(all_ns), max(all_ns)))

        # Track the methods actually drawn so the legend matches this figure.
        plotted_methods = []
        for meth_name in COLOR_MAP.keys():

            n_samps = []
            means = []
            for n_samp in n_keys:
                n_res = l_res[n_samp]
                if meth_name not in n_res:
                    continue
                # Integrate along axis 0 (randomisation steps): one AUC per remaining entry.
                aucs = _trapezoid(np.array(n_res[meth_name]), axis=0)
                n_samps.append(int(n_samp))
                means.append(np.mean(aucs))

            if not n_samps:
                continue

            ax.plot(
                n_samps,
                means,
                alpha=0.7,
                linewidth=linewidth * 2,
                marker="o",
                color=COLOR_MAP[meth_name],
            )
            plotted_methods.append(meth_name)

        # Make Legend from empty proxy lines, one per plotted method.
        for meth_name in plotted_methods:
            ax.plot([], [], color=COLOR_MAP[meth_name], label=f"{meth_name}", alpha=0.7, linewidth=linewidth * 2)
        ax.legend()

        fig.savefig(os.path.join(savepath, f"sMPRT-Ncomparison-{mod_name}-{l_order}.svg"))

        plt.show()