In [None]:
import os
import random

import numpy as np
import pandas as pd
import matplotlib.gridspec as gridspec
import torch
import matplotlib.pyplot as plt
import re

# Seed every RNG source used in this notebook (Python `random`, PyTorch, NumPy)
# so that any stochastic step is reproducible across runs.
random.seed(123)
torch.manual_seed(123)
np.random.seed(123)

# The project modules are imported by temporarily changing the working directory
# (presumably because they are only importable from their own folders — TODO
# confirm; a sys.path entry would avoid the chdir). The working directory is left at
# ../tsa/ relative to the notebook's starting directory, and later relative paths
# such as 'data/dataset_max.pkl' resolve against it.
os.chdir('../models/')
from CoreSNN import *
os.chdir('../tsa/')
# Names used later but never imported explicitly in this notebook (load_obj,
# balanced_accuracy_score, metrics, conf_interval) presumably come from these
# star imports — verify.
from ExplanationCreation import *
from ExplanationEvaluation import *

In [None]:
# Load the dataset plus the index arrays selected for the quantitative test on
# splits A and B (presumably time-step indices — TODO confirm), then slice the
# ground-truth labels at those indices along axis 1.
dataset = load_obj('data/dataset_max.pkl')
A_testset_t, B_testset_t = (
    load_obj('data/quantitative_test_t_A_final.pkl'),
    load_obj('data/quantitative_test_t_B_final.pkl'),
)
A_y_true, B_y_true = (
    dataset['y_test_A'][:, A_testset_t],
    dataset['y_test_B'][:, B_testset_t],
)

# Output-completeness

## Read result data

In [None]:
# Change into the directory holding the evaluation results. The argument is a
# placeholder and must be filled in: os.chdir('') raises an error.
os.chdir('') #Add path to results
# Tested epsilon values per explanation type ('s' = TSA-S, 'ns' = TSA-NS,
# 'sam' = SAM, matching the plot legends below). get_oc_scores loads result files
# for these values, and the output-completeness plot labels its x-axis
# 0 / 25% / 50% / 75%, so four values per type are expected.
epsilons = {'s': [], 'ns':[], 'sam':[]} # Fill tested epsilon values here

def get_oc_scores(nb_layer, expl_type, epsilons):
    """Compute output-completeness scores for one model / explanation type.

    For every epsilon, the pickled results for test splits A and B are loaded from
    ``../evaluation/output_completeness/<expl_type>/<nb_layer>L_oc_<split>_epsilon<eps>.pkl``.
    Each file holds a 3-tuple whose first element is ignored. The second element is
    used as the reference predictions and the third as the predictions to score
    (presumably original vs. perturbed-input predictions — TODO confirm against the
    script that writes these files). Splits A and B are pooled and scored with
    ``balanced_accuracy_score``.

    Args:
        nb_layer: layer identifier written verbatim into the file name (``'{nb_layer}L'``).
        expl_type: explanation sub-directory, e.g. ``'s'``, ``'ns'`` or ``'sam'``.
        epsilons: sequence of tested epsilon values; one score is returned per value
            (previously only the first four were used and fewer raised IndexError).

    Returns:
        list of balanced-accuracy scores, in the same order as ``epsilons``.
    """
    # os.path.join instead of hard-coded backslashes so the paths also resolve
    # outside Windows (on Windows the resulting strings are unchanged).
    result_dir = os.path.join('..', 'evaluation', 'output_completeness', expl_type)

    def load_split(split, eps):
        # Returns (reference predictions, predictions to score) for one split/epsilon.
        file_name = '{}L_oc_{}_epsilon{}.pkl'.format(nb_layer, split, eps)
        _, y_pred, y_pred_perturbed = load_obj(os.path.join(result_dir, file_name))
        return y_pred, y_pred_perturbed

    scores = []
    for eps in epsilons:
        y_pred_A, y_pred_p_A = load_split('A', eps)
        y_pred_B, y_pred_p_B = load_split('B', eps)
        scores.append(balanced_accuracy_score([*y_pred_A, *y_pred_B],
                                              [*y_pred_p_A, *y_pred_p_B]))
    return scores

In [None]:
# One list of output-completeness scores per model, for each explanation type.
# NOTE(review): nb_layer runs 0..2 and is written verbatim into the result file
# names ('0L_oc_...', '1L_oc_...', '2L_oc_...'), whereas the correctness results are
# stored as '1L'..'3L' and the plots title these rows SNN-1L..SNN-3L — confirm the
# output-completeness files really use 0-based layer names.
oc_s, oc_ns, oc_sam = [], [], []
score_lists = (('s', oc_s), ('ns', oc_ns), ('sam', oc_sam))

for nb_layer in range(3):
    for expl_type, scores_for_type in score_lists:
        scores_for_type.append(get_oc_scores(nb_layer, expl_type, epsilons[expl_type]))

## Visualize results

In [None]:
fig = plt.figure(tight_layout=True, frameon=False, figsize=(15, 5), dpi=200)
gs = gridspec.GridSpec(1, 3)

# Panel titles, left to right: one panel per row of oc_s / oc_ns / oc_sam.
panel_titles = ('Output-completeness for explanations of SNN-1L',
                'Output-completeness for explanations of SNN-2L',
                'Output-completeness for explanations of SNN-3L')

for col in range(3):
    ax = fig.add_subplot(gs[0, col])
    # Plot order must match the legend labels set at the end of the loop body.
    ax.plot(oc_s[col])
    ax.plot(oc_ns[col])
    ax.plot(oc_sam[col], linestyle='dotted')
    ax.set_ylim(bottom=0, top=1)
    ax.set_ylabel('Output-completeness score', fontdict={'fontsize': 16})
    # x positions 0..3 correspond to the four tested epsilon values.
    ax.set_xticks(range(4))
    ax.set_xticklabels([0, '25%', '50%', '75%'])
    ax.set_xlabel('Epsilon', fontdict={'fontsize': 16})
    ax.set_title(panel_titles[col], pad=10)
    ax.legend(['TSA-S','TSA-NS', 'SAM'], prop={'size':13})

# Correctness

## Reading the data

In [None]:
def _load_perturbed_preds(expl_type, model, split):
    """Load the pickled perturbed-prediction lists for one explanation type, model and split.

    Reads ``../correctness/<expl_type>/y_preds_perturbed_<model>_<split>.pkl``
    relative to the current working directory. The path is assembled with
    os.path.join rather than hard-coded backslash separators, which only act as
    separators on Windows (on Windows the resulting path is unchanged).
    """
    file_name = 'y_preds_perturbed_{}_{}.pkl'.format(model, split)
    return load_obj(os.path.join('..', 'correctness', expl_type, file_name))


y_preds_p_1A_s = _load_perturbed_preds('s', '1L', 'A')
y_preds_p_1B_s = _load_perturbed_preds('s', '1L', 'B')
y_preds_p_2A_s = _load_perturbed_preds('s', '2L', 'A')
y_preds_p_2B_s = _load_perturbed_preds('s', '2L', 'B')
y_preds_p_3A_s = _load_perturbed_preds('s', '3L', 'A')
y_preds_p_3B_s = _load_perturbed_preds('s', '3L', 'B')

y_preds_p_1A_ns = _load_perturbed_preds('ns', '1L', 'A')
y_preds_p_1B_ns = _load_perturbed_preds('ns', '1L', 'B')
y_preds_p_2A_ns = _load_perturbed_preds('ns', '2L', 'A')
y_preds_p_2B_ns = _load_perturbed_preds('ns', '2L', 'B')
y_preds_p_3A_ns = _load_perturbed_preds('ns', '3L', 'A')
y_preds_p_3B_ns = _load_perturbed_preds('ns', '3L', 'B')

y_preds_p_1A_sam = _load_perturbed_preds('sam', '1L', 'A')
y_preds_p_1B_sam = _load_perturbed_preds('sam', '1L', 'B')
y_preds_p_2A_sam = _load_perturbed_preds('sam', '2L', 'A')
y_preds_p_2B_sam = _load_perturbed_preds('sam', '2L', 'B')
y_preds_p_3A_sam = _load_perturbed_preds('sam', '3L', 'A')
y_preds_p_3B_sam = _load_perturbed_preds('sam', '3L', 'B')

In [None]:
# No feature segments were found for split A, t=9 in the explanations extracted from
# ThreeLayerSNN (SNN-3L), so only the original prediction is kept for that sample:
# the curve functions treat element 0 of each list as the original prediction.
# NOTE(review): 37 is presumably that sample's position in the list and 5 its original
# predicted class — both are hard-coded; confirm them and ideally derive them from the
# data instead.
y_preds_p_3A_s[37] = [5]

## Normalization and combination of the results per model

In [None]:
def get_perf_curve_yhat(y_preds_p_A, y_preds_p_B):
    """Balanced accuracy w.r.t. the original predictions as segments are flipped.

    Each element of ``y_preds_p_A`` / ``y_preds_p_B`` is a list of predictions for one
    sample. Element 0 is the reference (original) prediction, and element ``i`` is the
    prediction at step ``i`` (the plots label this axis "Number of flipped segments").
    A sample whose list is shorter than the current step keeps its last prediction.
    Splits A and B are pooled before scoring.

    Args:
        y_preds_p_A: per-sample prediction lists for test split A.
        y_preds_p_B: per-sample prediction lists for test split B.

    Returns:
        list with one ``balanced_accuracy_score`` value per step. Its length is the
        longest prediction list across BOTH splits (empty if there are no samples).
    """
    all_preds = [*y_preds_p_A, *y_preds_p_B]
    y_hat = [pred[0] for pred in all_preds]
    # Curve length is taken over both splits: using split A alone truncated the curve
    # whenever a split-B sample had more steps, and crashed on an empty split A.
    n_steps = max((len(pred) for pred in all_preds), default=0)

    perf = []
    for i in range(n_steps):
        y_pred_p = [pred[i] if i < len(pred) else pred[-1] for pred in all_preds]
        perf.append(balanced_accuracy_score(y_hat, y_pred_p))
    return perf

def get_perf_curve_ytrue(y_preds_p_A, y_preds_p_B, y_true=None):
    """Balanced accuracy w.r.t. the ground truth as segments are flipped.

    Each element of ``y_preds_p_A`` / ``y_preds_p_B`` is a list of predictions for one
    sample; element ``i`` is the prediction at step ``i``. A sample whose list is
    shorter than the current step keeps its last prediction. Splits A and B are pooled
    (A first, then B) before scoring.

    Args:
        y_preds_p_A: per-sample prediction lists for test split A.
        y_preds_p_B: per-sample prediction lists for test split B.
        y_true: reference labels for the pooled A+B samples, in that order. Defaults to
            ``[*A_y_true[0], *B_y_true[0]]`` from the notebook, read at call time.

    Returns:
        list with one ``balanced_accuracy_score`` value per step. Its length is the
        longest prediction list across BOTH splits (empty if there are no samples).
    """
    # Loop-invariant: build the reference labels once instead of once per step.
    if y_true is None:
        y_true = [*A_y_true[0], *B_y_true[0]]
    else:
        y_true = list(y_true)

    all_preds = [*y_preds_p_A, *y_preds_p_B]
    # Curve length is taken over both splits: using split A alone truncated the curve
    # whenever a split-B sample had more steps, and crashed on an empty split A.
    n_steps = max((len(pred) for pred in all_preds), default=0)

    perf = []
    for i in range(n_steps):
        y_pred_p = [pred[i] if i < len(pred) else pred[-1] for pred in all_preds]
        perf.append(balanced_accuracy_score(y_true, y_pred_p))
    return perf

In [None]:
# Degradation curves for every (model, explanation type) pair, scored against the
# model's original predictions (yhat) and against the ground truth (ytrue).
perf_1L_yhat_s, perf_1L_ytrue_s = (
    get_perf_curve_yhat(y_preds_p_1A_s, y_preds_p_1B_s),
    get_perf_curve_ytrue(y_preds_p_1A_s, y_preds_p_1B_s),
)
perf_1L_yhat_ns, perf_1L_ytrue_ns = (
    get_perf_curve_yhat(y_preds_p_1A_ns, y_preds_p_1B_ns),
    get_perf_curve_ytrue(y_preds_p_1A_ns, y_preds_p_1B_ns),
)
perf_1L_yhat_sam, perf_1L_ytrue_sam = (
    get_perf_curve_yhat(y_preds_p_1A_sam, y_preds_p_1B_sam),
    get_perf_curve_ytrue(y_preds_p_1A_sam, y_preds_p_1B_sam),
)

perf_2L_yhat_s, perf_2L_ytrue_s = (
    get_perf_curve_yhat(y_preds_p_2A_s, y_preds_p_2B_s),
    get_perf_curve_ytrue(y_preds_p_2A_s, y_preds_p_2B_s),
)
perf_2L_yhat_ns, perf_2L_ytrue_ns = (
    get_perf_curve_yhat(y_preds_p_2A_ns, y_preds_p_2B_ns),
    get_perf_curve_ytrue(y_preds_p_2A_ns, y_preds_p_2B_ns),
)
perf_2L_yhat_sam, perf_2L_ytrue_sam = (
    get_perf_curve_yhat(y_preds_p_2A_sam, y_preds_p_2B_sam),
    get_perf_curve_ytrue(y_preds_p_2A_sam, y_preds_p_2B_sam),
)

perf_3L_yhat_s, perf_3L_ytrue_s = (
    get_perf_curve_yhat(y_preds_p_3A_s, y_preds_p_3B_s),
    get_perf_curve_ytrue(y_preds_p_3A_s, y_preds_p_3B_s),
)
perf_3L_yhat_ns, perf_3L_ytrue_ns = (
    get_perf_curve_yhat(y_preds_p_3A_ns, y_preds_p_3B_ns),
    get_perf_curve_ytrue(y_preds_p_3A_ns, y_preds_p_3B_ns),
)
perf_3L_yhat_sam, perf_3L_ytrue_sam = (
    get_perf_curve_yhat(y_preds_p_3A_sam, y_preds_p_3B_sam),
    get_perf_curve_ytrue(y_preds_p_3A_sam, y_preds_p_3B_sam),
)


In [None]:
# Sample count passed to conf_interval for every score (previously a bare 180).
# Presumably the number of explained samples — TODO confirm against conf_interval.
N_CI_SAMPLES = 180


def normalized_ess(perf):
    """Area under a degradation curve divided by the number of curve points.

    The curve is integrated with ``metrics.auc`` over x = 0 .. len(perf) - 1.
    NOTE(review): that x-range has width len(perf) - 1, so a curve that stays at 1.0
    scores (len(perf) - 1) / len(perf), not 1.0. Kept as-is to match previously
    reported numbers — confirm which normalisation is intended.
    """
    return metrics.auc(range(len(perf)), perf) / len(perf)


# Keyed by (model, explanation type); insertion order fixes the print order below.
norm_ess_yhat = {
    ('1L', 'TSA-S'): normalized_ess(perf_1L_yhat_s),
    ('1L', 'TSA-NS'): normalized_ess(perf_1L_yhat_ns),
    ('1L', 'SAM'): normalized_ess(perf_1L_yhat_sam),
    ('2L', 'TSA-S'): normalized_ess(perf_2L_yhat_s),
    ('2L', 'TSA-NS'): normalized_ess(perf_2L_yhat_ns),
    ('2L', 'SAM'): normalized_ess(perf_2L_yhat_sam),
    ('3L', 'TSA-S'): normalized_ess(perf_3L_yhat_s),
    ('3L', 'TSA-NS'): normalized_ess(perf_3L_yhat_ns),
    ('3L', 'SAM'): normalized_ess(perf_3L_yhat_sam),
}

for (model, method), score in norm_ess_yhat.items():
    print('ESS for {} Explanations from SNN-{}: '.format(method, model), score,
          ' +- ', conf_interval(score, N_CI_SAMPLES))

In [None]:
fig = plt.figure(tight_layout=True, dpi=150, frameon=False, figsize=(15,5))
gs = gridspec.GridSpec(1,3)

# (model label, TSA-S curve, TSA-NS curve, SAM curve) per panel, left to right.
# All curves are scored against the model's original predictions.
yhat_panels = [
    ('SNN-1L', perf_1L_yhat_s, perf_1L_yhat_ns, perf_1L_yhat_sam),
    ('SNN-2L', perf_2L_yhat_s, perf_2L_yhat_ns, perf_2L_yhat_sam),
    ('SNN-3L', perf_3L_yhat_s, perf_3L_yhat_ns, perf_3L_yhat_sam),
]

for col, (model_label, curve_s, curve_ns, curve_sam) in enumerate(yhat_panels):
    ax = fig.add_subplot(gs[0, col])
    # Plot order must match the legend labels.
    ax.plot(curve_s)
    ax.plot(curve_ns)
    ax.plot(curve_sam, linestyle='dotted')
    ax.set_ylim(0, 1)
    ax.set_ylabel('Balanced accuracy', fontdict={'size': 16})
    ax.set_xlabel('Number of flipped segments', fontdict={'size': 16})
    ax.legend(['TSA-S', 'TSA-NS', 'SAM'], prop={'size':13})
    # One title template for all panels, so every title has the same line layout.
    ax.set_title('Balanced accuracy of {}\n with flipped feature segments\n'
                 ' with regard to original model predictions'.format(model_label))



In [None]:
fig = plt.figure(tight_layout=True, dpi=150, frameon=False, figsize=(15,5))
gs = gridspec.GridSpec(1,3)

# (model label, TSA-S curve, TSA-NS curve, SAM curve) per panel, left to right.
# All curves are scored against the ground-truth labels.
ytrue_panels = [
    ('SNN-1L', perf_1L_ytrue_s, perf_1L_ytrue_ns, perf_1L_ytrue_sam),
    ('SNN-2L', perf_2L_ytrue_s, perf_2L_ytrue_ns, perf_2L_ytrue_sam),
    ('SNN-3L', perf_3L_ytrue_s, perf_3L_ytrue_ns, perf_3L_ytrue_sam),
]

for col, (model_label, curve_s, curve_ns, curve_sam) in enumerate(ytrue_panels):
    ax = fig.add_subplot(gs[0, col])
    # Plot order must match the legend labels.
    ax.plot(curve_s)
    ax.plot(curve_ns)
    ax.plot(curve_sam, linestyle='dotted')
    ax.set_ylim(0, 1)
    ax.set_ylabel('Balanced accuracy', fontdict={'size': 16})
    ax.set_xlabel('Number of flipped segments', fontdict={'size': 16})
    ax.legend(['TSA-S', 'TSA-NS', 'SAM'], prop={'size':13})
    # One title template for all panels, so every title has the same line layout.
    ax.set_title('Balanced accuracy of {}\n with flipped feature segments\n'
                 ' with regard to ground truth'.format(model_label))




# Sensitivity

## Read results

In [None]:
def _max_sensitivity(path_A, path_B):
    """Load the max-sensitivity values stored for test splits A and B; return the larger.

    Assumes both pickles hold comparable scalars — max() of two arrays would raise.
    """
    return max(load_obj(path_A), load_obj(path_B))


# NOTE(review): every path below starts with '/', so it resolves from the filesystem
# root (the current drive's root on Windows), not from the results directory this
# notebook changed into. The other result files in this notebook use relative paths.
# Confirm the intended location.
max_sensitivity_s_one = _max_sensitivity('/sensitivity/tsa-s/max_sensitivity_oneA.pkl',
                                         '/sensitivity/tsa-s/max_sensitivity_oneB.pkl')
max_sensitivity_s_two = _max_sensitivity('/sensitivity/tsa-s/max_sensitivity_twoA.pkl',
                                         '/sensitivity/tsa-s/max_sensitivity_twoB.pkl')
max_sensitivity_s_three = _max_sensitivity('/sensitivity/tsa-s/max_sensitivity_threeA.pkl',
                                           '/sensitivity/tsa-s/max_sensitivity_threeB.pkl')

max_sensitivity_ns_one = _max_sensitivity('/sensitivity/tsa-ns/max_sensitivity_oneA.pkl',
                                          '/sensitivity/tsa-ns/max_sensitivity_oneB.pkl')
max_sensitivity_ns_two = _max_sensitivity('/sensitivity/tsa-ns/max_sensitivity_twoA.pkl',
                                          '/sensitivity/tsa-ns/max_sensitivity_twoB.pkl')
max_sensitivity_ns_three = _max_sensitivity('/sensitivity/tsa-ns/max_sensitivity_threeA.pkl',
                                            '/sensitivity/tsa-ns/max_sensitivity_threeB.pkl')

max_sensitivity_baseline = _max_sensitivity('/sensitivity/max_sensitivity_baseline_A.pkl',
                                            '/sensitivity/max_sensitivity_baseline_B.pkl')

## Analyze

In [None]:
sensitivities_s = [max_sensitivity_s_one, max_sensitivity_s_two, max_sensitivity_s_three]
sensitivities_ns = [max_sensitivity_ns_one, max_sensitivity_ns_two, max_sensitivity_ns_three]
# One row per model, one column per explanation type. Both axes are labelled
# explicitly so the displayed table is self-describing; the previous
# pd.DataFrame([...]).transpose() labelled them only 0, 1, 2.
df_sensitivity = pd.DataFrame(
    {'TSA-S': sensitivities_s, 'TSA-NS': sensitivities_ns},
    index=['SNN-1L', 'SNN-2L', 'SNN-3L'],
)

In [None]:
# Display the maximum-sensitivity table (rows: the three models, columns: TSA-S, TSA-NS).
df_sensitivity

In [None]:
# Display the baseline maximum sensitivity (the larger of the split-A and split-B
# values) for comparison with the per-method sensitivity table.
max_sensitivity_baseline