In [None]:
import os
import random
import warnings
import numpy as np
import pandas as pd
import matplotlib.gridspec as gridspec
import torch
import matplotlib.pyplot as plt
import sys
from sklearn.metrics import balanced_accuracy_score

# Seed every random number generator used in this notebook with the same value
# so repeated runs draw the same random numbers.
SEED = 123
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
sys.path.insert(1, '../../models/')
from CoreSNN import *
sys.path.insert(1, '../')
from ExplanationCreation import *
from ExplanationEvaluation import *

In [None]:
# Suppress warnings whose message starts with the text below
# (presumably emitted by balanced_accuracy_score — verify).
warnings.filterwarnings(
    action='ignore',
    message='y_pred contains classes not in y_true',
)

In [None]:
# Locations of the synthetic dataset and of the test-set selection used for the explanations.
# NOTE(review): load_obj comes from the star imports above and presumably unpickles the
# file — only point it at trusted .pkl files.
SYN_DATA_FILE = '../../data/synthetic/syn_data.pkl'
EXPL_TESTSET_FILE = '../../data/synthetic/expl_syn_testset.pkl'

dataset = load_obj(SYN_DATA_FILE)
testset_t = load_obj(EXPL_TESTSET_FILE)

# Labels of the test split, restricted along the second axis to the selected positions.
y_true = dataset['y_test'][:, testset_t]

# Output-completeness

## Read result data

In [None]:
# Show the current working directory: the result files read below are located relative to it.
os.getcwd()

In [None]:
# Tested epsilon values per explanation type ('s', 'ns', 'sam').
# A list left empty makes get_oc_scores fall back to get_epsilons(dataset, expl_type).
epsilons = {'s': [], 'ns':[], 'sam':[]} # Fill tested epsilon values here

def get_oc_scores(nb_layer, expl_type, epsilons):
    """Compute the output-completeness scores of one model for one explanation type.

    For each of the first eight epsilon values the result file
    ``<cwd>/output_completeness/<expl_type>/<nb_layer>L_oc_epsilon<epsilon>.pkl``
    is loaded with ``load_obj``. Every file must hold a 3-tuple whose second and
    third entries are two prediction sequences (presumably the predictions on the
    original and on the perturbed input — confirm against the code that writes the
    files). The score per epsilon is ``balanced_accuracy_score`` of the third entry
    measured against the second.

    Args:
        nb_layer: Model identifier used as prefix of the result file name.
        expl_type: Explanation type; names the sub-directory that is read.
        epsilons: Sequence of tested epsilon values. When it is ``None`` or empty,
            ``get_epsilons(dataset, expl_type)`` is used instead.

    Returns:
        List with eight balanced-accuracy scores, ordered like ``epsilons``.

    Raises:
        IndexError: If fewer than eight epsilon values are available.
    """
    n_epsilons = 8  # the analysis and the plots below expect exactly eight values

    # Honour the values passed by the caller; only fall back when none were given.
    if epsilons is None or len(epsilons) == 0:
        epsilons = get_epsilons(dataset, expl_type)

    # os.path.join keeps the paths valid on every operating system.
    result_dir = os.path.join(os.getcwd(), 'output_completeness', expl_type)

    scores = []
    for idx in range(n_epsilons):
        filename = '{}L_oc_epsilon{}.pkl'.format(nb_layer, epsilons[idx])
        _, y_pred, y_pred_p = load_obj(os.path.join(result_dir, filename))
        scores.append(balanced_accuracy_score(y_pred, y_pred_p))
    return scores

In [None]:
# Output-completeness scores per explanation type; each list receives one score list per model.
oc_s, oc_ns, oc_sam = [], [], []
oc_by_type = {'s': oc_s, 'ns': oc_ns, 'sam': oc_sam}

# NOTE(review): range(3) requests result files prefixed 0L, 1L and 2L, while the figure
# below titles the panels SNN-1L to SNN-3L — confirm the file naming.
for nb_layer in range(3):
    for expl_type, collected in oc_by_type.items():
        collected.append(get_oc_scores(nb_layer, expl_type, epsilons[expl_type]))


## Visualize results

In [None]:
# One panel per model: output-completeness score over the eight tested epsilon values,
# with one curve per explanation type.
fig = plt.figure(tight_layout=True, frameon=False, figsize=(15,5),dpi=200)
gs = gridspec.GridSpec(1,3)

epsilon_tick_labels = [0, '5%', '10%', '15%', '20%', '25%', '50%', '75%']

oc_axes = []
for model_idx in range(3):
    ax = fig.add_subplot(gs[0, model_idx])
    ax.plot(oc_s[model_idx])
    ax.plot(oc_ns[model_idx])
    ax.plot(oc_sam[model_idx], linestyle='dotted')
    ax.set_ylim(ymin=0, ymax=1)
    # Same y-label on every panel (the middle panel used to carry a double space).
    ax.set_ylabel('Output-completeness score', fontdict={'fontsize': 16})
    ax.set_xticks(range(len(epsilon_tick_labels)))
    ax.set_xticklabels(epsilon_tick_labels)
    ax.set_xlabel('Epsilon', fontdict={'fontsize': 16})
    ax.set_title('Output-completeness for explanations of SNN-{}L'.format(model_idx + 1), pad=10)
    # Legend entries follow the plotting order above.
    ax.legend(['TSA-S', 'TSA-NS', 'SAM'], prop={'size':13})
    oc_axes.append(ax)

# Keep the per-panel names available for later tweaks.
ax1, ax2, ax3 = oc_axes


# Correctness

## Reading the data

In [None]:
def _load_perturbed_preds(tsa_variant, modelname):
    """Load ``<cwd>/correctness/syn/<tsa_variant>/y_preds_perturbed_<modelname>.pkl``.

    The path is assembled with os.path.join so it resolves on every operating system.
    """
    filename = 'y_preds_perturbed_{}.pkl'.format(modelname)
    return load_obj(os.path.join(os.getcwd(), 'correctness', 'syn', tsa_variant, filename))


# Results for the 's' explanations.
y_preds_p_one_s = _load_perturbed_preds('s', 'onelayer')
y_preds_p_two_s = _load_perturbed_preds('s', 'twolayer')
y_preds_p_three_s = _load_perturbed_preds('s', 'threelayer')

# Results for the 'ns2' explanations.
y_preds_p_one_ns2 = _load_perturbed_preds('ns2', 'onelayer')
y_preds_p_two_ns2 = _load_perturbed_preds('ns2', 'twolayer')
y_preds_p_three_ns2 = _load_perturbed_preds('ns2', 'threelayer')


In [None]:
# Post-processing: some explanations are empty, which leaves the matching prediction
# list empty as well. Give every empty list a single entry, taken from y_true for that
# sample, so that there is a value to compare against later on.
for y_preds_p in (y_preds_p_one_s, y_preds_p_two_s, y_preds_p_three_s,
                  y_preds_p_one_ns2, y_preds_p_two_ns2, y_preds_p_three_ns2):
    for i, y_pred in enumerate(y_preds_p):
        if len(y_pred) == 0:
            y_pred.append(y_true[:, i][0])

## Normalization and combination of the results per model

In [None]:
def get_perf_curve_yhat(y_preds_p):
    """Balanced accuracy per step, measured against each sample's first prediction.

    Args:
        y_preds_p: One non-empty prediction sequence per sample. The first entry
            of each sequence serves as the reference label of that sample.

    Returns:
        List with one balanced-accuracy value per step, as long as the longest
        sequence. Sequences that are shorter than the current step contribute
        their last entry.
    """
    reference = [trajectory[0] for trajectory in y_preds_p]
    n_steps = max(len(trajectory) for trajectory in y_preds_p)

    curve = []
    for step in range(n_steps):
        at_step = []
        for trajectory in y_preds_p:
            # Pad exhausted sequences with their final prediction.
            at_step.append(trajectory[step] if step < len(trajectory) else trajectory[-1])
        curve.append(balanced_accuracy_score(reference, at_step))
    return curve

def get_perf_curve_ytrue(y_preds_p):
    """Balanced accuracy per step, measured against the labels in ``y_true[0]``.

    Args:
        y_preds_p: One non-empty prediction sequence per sample, in the same
            order as the entries of the module-level ``y_true[0]``.

    Returns:
        List with one balanced-accuracy value per step, as long as the longest
        sequence. Sequences that are shorter than the current step contribute
        their last entry.
    """
    longest = max([len(seq) for seq in y_preds_p])

    def prediction_at(seq, step):
        # Pad exhausted sequences with their final prediction.
        return seq[step] if step < len(seq) else seq[-1]

    return [
        balanced_accuracy_score(y_true[0], [prediction_at(seq, step) for seq in y_preds_p])
        for step in range(longest)
    ]

In [None]:
def _perf_curves(y_preds_p):
    """Return the pair (curve vs. first prediction, curve vs. y_true) for one result set."""
    return get_perf_curve_yhat(y_preds_p), get_perf_curve_ytrue(y_preds_p)


# OneLayerSNN
perf_one_yhat_s, perf_one_ytrue_s = _perf_curves(y_preds_p_one_s)
perf_one_yhat_ns2, perf_one_ytrue_ns2 = _perf_curves(y_preds_p_one_ns2)

# TwoLayerSNN
perf_two_yhat_s, perf_two_ytrue_s = _perf_curves(y_preds_p_two_s)
perf_two_yhat_ns2, perf_two_ytrue_ns2 = _perf_curves(y_preds_p_two_ns2)

# ThreeLayerSNN
perf_three_yhat_s, perf_three_ytrue_s = _perf_curves(y_preds_p_three_s)
perf_three_yhat_ns2, perf_three_ytrue_ns2 = _perf_curves(y_preds_p_three_ns2)


In [None]:
# NOTE(review): `metrics` and `conf_interval` are not imported explicitly in this
# notebook; presumably they are provided by the star imports at the top — verify.
def _ess_from_curve(perf_curve):
    """Return (area under the curve over steps 0..len-1, that area divided by len(perf_curve))."""
    steps = range(len(perf_curve))
    area = metrics.auc(steps, perf_curve)
    return area, area / len(perf_curve)


ess_yhat_one_s, norm_ess_yhat_one_s = _ess_from_curve(perf_one_yhat_s)
ess_yhat_one_ns2, norm_ess_yhat_one_ns2 = _ess_from_curve(perf_one_yhat_ns2)

ess_yhat_two_s, norm_ess_yhat_two_s = _ess_from_curve(perf_two_yhat_s)
ess_yhat_two_ns2, norm_ess_yhat_two_ns2 = _ess_from_curve(perf_two_yhat_ns2)

ess_yhat_three_s, norm_ess_yhat_three_s = _ess_from_curve(perf_three_yhat_s)
ess_yhat_three_ns2, norm_ess_yhat_three_ns2 = _ess_from_curve(perf_three_yhat_ns2)

# Print every normalized score together with conf_interval(score, 100).
for description, norm_ess in (
    ('ESS for TSA-S Explanations from OneLayerSNN: ', norm_ess_yhat_one_s),
    ('ESS for TSA-NS2 Explanations from OneLayerSNN: ', norm_ess_yhat_one_ns2),
    ('ESS for TSA-S Explanations from TwoLayerSNN: ', norm_ess_yhat_two_s),
    ('ESS for TSA-NS2 Explanations from TwoLayerSNN: ', norm_ess_yhat_two_ns2),
    ('ESS for TSA-S Explanations from ThreeLayerSNN: ', norm_ess_yhat_three_s),
    ('ESS for TSA-NS2 Explanations from ThreeLayerSNN: ', norm_ess_yhat_three_ns2),
):
    print(description, norm_ess, ' +- ', conf_interval(norm_ess, 100))


In [None]:
# conf_interval((norm_ess_yhat_one_ns2 + norm_ess_yhat_two_ns2 + norm_ess_yhat_three_ns2)/3, 180)

In [None]:
# One panel per model: balanced accuracy (measured against the first prediction of each
# sample) over the number of flipped segments, one curve per explanation variant.
fig = plt.figure(tight_layout=True, dpi=150, frameon=False, figsize=(15,5))
gs = gridspec.GridSpec(1,3)

color_s = (194/256, 154/256, 177/256)
color_ns2 = (219/256, 154/256, 143/256)

yhat_panels = [
    (perf_one_yhat_s, perf_one_yhat_ns2,
     'Balanced accuracy of OneLayerSNN\n with flipped feature segments'),
    (perf_two_yhat_s, perf_two_yhat_ns2,
     'Balanced accuracy of TwoLayerSNN \n with flipped feature segments'),
    (perf_three_yhat_s, perf_three_yhat_ns2,
     'Balanced accuracy of ThreeLayerSNN \n with flipped feature segments'),
]

yhat_axes = []
for column, (curve_s, curve_ns2, panel_title) in enumerate(yhat_panels):
    ax = fig.add_subplot(gs[0, column])
    ax.plot(curve_s, color=color_s)
    ax.plot(curve_ns2, color=color_ns2)
    ax.set_ylim(0,1)
    ax.set_ylabel('Balanced accuracy', fontdict={'size': 16})
    ax.set_xlabel('Number of flipped segments', fontdict={'size': 16})
    ax.legend(['TSA-S', 'TSA-NS2'], prop={'size':13})
    ax.set_title(panel_title)
    yhat_axes.append(ax)

ax1, ax2, ax3 = yhat_axes



In [None]:
# One panel per model: balanced accuracy measured against the ground-truth labels
# (y_true) over the number of flipped segments, one curve per explanation variant.
fig = plt.figure(tight_layout=True, dpi=150, frameon=False, figsize=(15,5))
gs = gridspec.GridSpec(1,3)

ytrue_panels = [
    ('OneLayerSNN', perf_one_ytrue_s, perf_one_ytrue_ns2),
    ('TwoLayerSNN', perf_two_ytrue_s, perf_two_ytrue_ns2),
    ('ThreeLayerSNN', perf_three_ytrue_s, perf_three_ytrue_ns2),
]

ytrue_axes = []
for column, (model_name, curve_s, curve_ns2) in enumerate(ytrue_panels):
    ax = fig.add_subplot(gs[0, column])
    ax.plot(curve_s, color=(194/256, 154/256, 177/256))
    ax.plot(curve_ns2, color=(219/256, 154/256, 143/256))
    ax.set_ylim(0,1)
    ax.set_ylabel('Balanced accuracy', fontdict={'size': 16})
    ax.set_xlabel('Number of flipped segments', fontdict={'size': 16})
    # Same variant names as in the figure measured against the model prediction: the curves
    # are the 's' and 'ns2' results (this legend used to read 'NCS' / 'TSA-NS').
    ax.legend(['TSA-S', 'TSA-NS2'], prop={'size':13})
    # Titles let the figure stand alone and tell it apart from the y_hat figure.
    ax.set_title('Balanced accuracy of {} w.r.t. ground truth\n with flipped feature segments'.format(model_name))
    ytrue_axes.append(ax)

ax1, ax2, ax3 = ytrue_axes


# Sensitivity

### Read results

In [None]:
def get_sens_score(modelname, tsa_variant):
    """Load the stored max-sensitivity result of one model and explanation variant.

    Reads ``<cwd>/continuity/syn/<tsa_variant>/max_sensitivity_<modelname>.pkl``.
    The path is assembled with os.path.join so it resolves on every operating system.

    Args:
        modelname: Model identifier used as suffix of the file name.
        tsa_variant: Explanation variant; names the sub-directory that is read.

    Returns:
        Whatever ``load_obj`` returns for that file.
    """
    filename = 'max_sensitivity_{}.pkl'.format(modelname)
    score = load_obj(os.path.join(os.getcwd(), 'continuity', 'syn', tsa_variant, filename))
    return score

In [None]:
# Max-sensitivity results for the models 'one', 'two' and 'three', first for the
# 's' variant and then for the 'ns2' variant.
max_sensitivity_one_s, max_sensitivity_two_s, max_sensitivity_three_s = (
    get_sens_score(modelname, 's') for modelname in ('one', 'two', 'three')
)

max_sensitivity_one_ns2, max_sensitivity_two_ns2, max_sensitivity_three_ns2 = (
    get_sens_score(modelname, 'ns2') for modelname in ('one', 'two', 'three')
)

### Analyze

In [None]:
# Collect the results per variant, ordered by model (one, two, three).
sensitivities_s = [
    max_sensitivity_one_s,
    max_sensitivity_two_s,
    max_sensitivity_three_s,
]
sensitivities_ns2 = [
    max_sensitivity_one_ns2,
    max_sensitivity_two_ns2,
    max_sensitivity_three_ns2,
]

# Built with one row per variant, then flipped: column 0 holds the 's' values,
# column 1 the 'ns2' values (assuming each result is a scalar — TODO confirm).
sensitivity_by_variant = pd.DataFrame([sensitivities_s, sensitivities_ns2])
df_sensitivity = sensitivity_by_variant.T

In [None]:
# Display the sensitivity table: column 0 comes from the 's' results, column 1 from the 'ns2' results.
df_sensitivity

# Compactness

### Load explanations. Compactness is then the sum of absolute attribution values per explanation, summed over all explanations and divided by 100.

In [None]:
def load_explanations(explanation_type, model):
    """Load the stored explanations of one model and explanation type.

    Reads ``<cwd>/expl_<model>_syn_nocw_<explanation_type>.pkl``. The path is
    assembled with os.path.join so it resolves on every operating system.

    Args:
        explanation_type: Explanation type used as suffix of the file name.
        model: Model identifier used inside the file name.

    Returns:
        Whatever ``load_obj`` returns for that file.
    """
    filename = 'expl_{}_syn_nocw_{}.pkl'.format(model, explanation_type)
    return load_obj(os.path.join(os.getcwd(), filename))

In [None]:
def compute_compactness(model, explanation_type):
    """Sum the absolute attribution values of all explanations and divide by 100.

    Args:
        model: Model identifier passed on to ``load_explanations``.
        explanation_type: Explanation type passed on to ``load_explanations``.

    Returns:
        Total of ``|explanation[0]|`` over every loaded explanation, divided by 100.
    """
    explanations = {**load_explanations(explanation_type, model)}
    # Only the first element of every stored entry is used.
    total = sum(torch.sum(torch.abs(entry[0])) for entry in explanations.values())
    # NOTE(review): 100 is presumably the number of explanations — confirm.
    return total / 100

In [None]:
models = ['one', 'two', 'three']
explanation_types = ['ns2', 's']

# One compactness value per model, collected separately for each explanation type.
compactness_ns2, compactness_s = [], []
compactness_by_type = {'s': compactness_s, 'ns2': compactness_ns2}

for model in models:
    for variant in ('s', 'ns2'):
        compactness_by_type[variant].append(compute_compactness(model, variant))


In [None]:
# Compactness of the 'ns2' explanations, one entry per model in the order of `models`.
compactness_ns2

In [None]:
# Compactness of the 's' explanations, one entry per model in the order of `models`.
compactness_s

In [None]:
def compute_sample_std(model, explanation_type):
    """Sample standard deviation of the per-explanation absolute attribution sums.

    For every loaded explanation the sum of ``|explanation[0]|`` is taken. The
    function returns the standard deviation of these sums with Bessel's
    correction (division by ``n - 1``), where ``n`` is the number of loaded
    explanations.

    The previous version returned the variance (the square root was missing),
    divided by a hard-coded 179 and took its mean from a hard-coded division by
    100. The mean and the degrees of freedom now both follow from the number of
    explanations actually loaded.

    Args:
        model: Model identifier passed on to ``load_explanations``.
        explanation_type: Explanation type passed on to ``load_explanations``.

    Returns:
        Zero-dimensional tensor holding the sample standard deviation.

    Raises:
        ValueError: If fewer than two explanations are available.
    """
    explanations = load_explanations(explanation_type, model)
    per_explanation = [torch.sum(torch.abs(explanations[key][0])) for key in explanations.keys()]

    n = len(per_explanation)
    if n < 2:
        raise ValueError('need at least two explanations to compute a sample standard deviation')

    values = torch.stack(per_explanation)
    if not values.is_floating_point():
        # mean() is not defined for integer tensors
        values = values.float()

    x_bar = values.mean()
    variance = torch.sum((values - x_bar) ** 2) / (n - 1)
    return torch.sqrt(variance)

def compute_95_ci(s, n):
    """Return ``1.97 * s / sqrt(n)``, the half-width of a 95% confidence interval.

    Args:
        s: Spread of the sample (used as the standard deviation in the formula).
        n: Number of samples.

    Returns:
        The half-width of the interval, in the same type as ``s``.
    """
    # 1.97 is the t statistic for 180 degrees of freedom at the 95% level.
    t_statistic = 1.97
    standard_error = s / (n ** 0.5)
    return t_statistic * standard_error

In [None]:
# Print the result of compute_95_ci (with n = 100) for every model / explanation type pair.
model_type_pairs = [
    (model_name, type_name)
    for model_name in models
    for type_name in explanation_types
]
for model, explanation_type in model_type_pairs:
    s = compute_sample_std(model, explanation_type)
    ci = compute_95_ci(s, 100)
    print('CI of {}, {}:{}'.format(model, explanation_type, ci))