In [None]:
# Standard library first. TF_CPP_MIN_LOG_LEVEL has to be in the environment
# *before* TensorFlow is imported, otherwise TensorFlow's native (C++) start-up
# messages are emitted anyway; "3" asks it to report errors only.
import logging
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import pickle

import numpy as np
import pandas as pd
import shap
import tensorflow as tf
from sklearn import metrics
from tensorflow.keras.models import load_model

# Suppress Python-side log records at WARNING level and below for the rest of the session.
logging.disable(logging.WARNING)

The results shown here were obtained by running the `SHAPForPrediction_tf1.py` script.

SHAP documentation available at [SHAP](https://github.com/slundberg/shap) GitHub repository.

Datasets and SHAP results available at [ExplainedDecisions](https://osf.io/wgk8e/?view_only=8aec18499ed8457cb296032545963542) public repository.

- Check the TensorFlow and SHAP versions for compatibility

In [None]:
# Display the installed TensorFlow and SHAP versions as a (tensorflow, shap) pair.
(tf.__version__, shap.__version__)

- Load the dataset file ('file'), the training and test indexes file ('file_to_open') and the model file ('file_model')
- Select the decision horizon Thor ('look_forward')
- Select the sampling time ('step') linked to sequence length Tseq

In [None]:
# Example configuration. The file names below indicate an ANN trained on Expert
# data, with the dataset sampled at step 2 and a decision horizon of 16
# (look_forward = 16). Keep look_forward consistent with the "thor" suffix of
# the train/test index file when switching to another model.
# NOTE(review): `look_back`, which is needed when the sequences are built, is not
# defined anywhere in this notebook; it has to be set here to the value used in training.

# Pickled dataset; the handle is consumed (and closed) by the cell that loads it.
file = open("./Datasets/DatasetFileMultiClassPred_BothHerders_WrtGoal_Extended_step2", "rb")

look_forward = 16

# step = ""

model_id = '26022022'
file_id = '001'

directory = './checkpoint/FinalModels/'
# Fixed: the path literal was missing its closing quote, which made this cell a syntax error.
file_to_open = open(directory + model_id + '/TrainTestSets_Expert_step2_thor16', "rb")
file_model = directory + model_id + '/' + model_id + file_id

- Load the dataset and select the columns listed in 'Labels'

In [None]:
# Columns kept from the full dataset: two identifier columns first, then the
# input features grouped by kind, and finally the target column 'Label'.
Labels = [
    # identifiers
    "Herder_id", "Trial_id",
    # herder -> target relative distances and angles
    'h_t0 rel dist', 'h_t1 rel dist', 'h_t2 rel dist', 'h_t3 rel dist',
    'h_t0 rel angle', 'h_t1 rel angle', 'h_t2 rel angle', 'h_t3 rel angle',
    # distances from the goal
    'h_goal rel dist', 't0_goal rel dist', 't1_goal rel dist', 't2_goal rel dist', 't3_goal rel dist',
    # radial velocities
    'h vel_r', 't0 vel_r', 't1 vel_r', 't2 vel_r', 't3 vel_r',
    # radial accelerations
    'h acc_r', 't0 acc_r', 't1 acc_r', 't2 acc_r', 't3 acc_r',
    # angles with respect to the goal
    'h_goal_th', 't0_goal_th', 't1_goal_th', 't2_goal_th', 't3_goal_th',
    # directions of motion
    'h_dir_motion', 't0_dir_motion', 't1_dir_motion', 't2_dir_motion', 't3_dir_motion',
    # second herder (h1)
    'h_h1 rel dist', 'h_h1 rel angle', 'h1_goal rel dist', 'h1 vel_r', 'h1 acc_r',
    'h1_goal_th', 'h1_dir_motion',
    'h1_t0 rel dist', 'h1_t1 rel dist', 'h1_t2 rel dist', 'h1_t3 rel dist',
    'h1_t0 rel angle', 'h1_t1 rel angle', 'h1_t2 rel angle', 'h1_t3 rel angle',
    # target output
    'Label',
]

In [None]:
# Load the pickled dataset from the handle opened in the configuration cell.
# SECURITY: pickle.load can execute arbitrary code, so only load files from a trusted source.
try:
    Dataset_full_df = pickle.load(file)
finally:
    # Release the handle even if unpickling fails.
    file.close()

# Keep only the identifier, feature and target columns listed in Labels.
Dataset_df = Dataset_full_df[Labels]

# Three of the selected columns are not input features:
# 'Herder_id', 'Trial_id' and the target column 'Label'.
n_features = len(Dataset_df.columns) - 3
print("there are ", n_features," features!")

# Plain array view of the selected columns (identifiers in columns 0 and 1).
Dataset = Dataset_df.values

- Create the sequences of features and target outputs from the dataset

In [None]:
# Build the feature sequences, their targets and their per-sequence labels,
# one (herder, trial) pair at a time.
# NOTE(review): `uf` (the module providing create_dataset) is never imported and
# `look_back` is never defined in this notebook; both must exist before this cell runs.
sequences, sequences_labels, targets = [], [], []

# Identifiers are assumed to be 0-based and contiguous, hence max + 1 -- TODO confirm.
herders_tot = int(max(Dataset[:,0])) + 1
trial_tot = int(max(Dataset[:,1])) + 1

for herder_id in range(herders_tot):
    for trial_id in range(trial_tot):
        # Rows of this herder/trial, without the two identifier columns.
        in_trial = (Dataset_df["Herder_id"]==herder_id) & (Dataset_df["Trial_id"]==trial_id)
        Dtst = Dataset_df[in_trial].values[:,2:]
        seq, tar, seq_lbl = uf.create_dataset(Dtst, look_back, look_forward)
        # extend() appends in place; the previous `a = a + b` re-copied the whole
        # accumulated list on every iteration (quadratic in the number of samples).
        # Assumes create_dataset returns lists, as the original concatenation did.
        sequences.extend(seq)
        targets.extend(tar)
        sequences_labels.extend(seq_lbl)

sequences_array = np.array(sequences)
targets_array = np.array(targets)
sequences_labels_array = np.array(sequences_labels)

- Select from the total available samples the ones used for training and test 

In [None]:
# Read the train/test index file opened in the configuration cell; `indexes_data`
# was previously used without ever being loaded.
# NOTE(review): assumes the index file is a pickle, like the dataset file -- confirm.
# SECURITY: pickle.load can execute arbitrary code, so only load files from a trusted source.
try:
    indexes_data = pickle.load(file_to_open)
finally:
    file_to_open.close()

# The index file holds, in order: the sample-selection index, then the train and
# test indexes into the selected samples.
type_index = indexes_data[0]
train_index = indexes_data[1]
test_index = indexes_data[2]

X_senior, y_senior, Z_senior = sequences_array[type_index], targets_array[type_index], sequences_labels_array[type_index]
X_test = X_senior[test_index]
y_test = y_senior[test_index]
Z_test = Z_senior[test_index]
# Fixed: y_train was referenced below but never defined; it mirrors y_test.
y_train = y_senior[train_index]

# One-hot encoding of the target classes.
dummies_test = pd.get_dummies(y_test)

targets_labels_array = uf.checkSamplesType(Z_test)
dummies_train = pd.get_dummies(y_train)

# Compute performance metrics of the ANN

- Select the test set

In [None]:
# Disabled aliases: the cells below use X_test and dummies_test directly,
# so nothing needs to be assigned here.
# test_set = X_test
# test_set_target = dummies_test.values

- Load the trained ANN

In [None]:
# Restore the trained Keras model saved at the path built in the configuration cell.
model = load_model(filepath=file_model)

- Run the trained ANN on the test set and compute its performance metrics

In [None]:
# Class scores predicted by the ANN for every test sample.
test_preds = model.predict(X_test)

# Turn scores / one-hot targets into class indexes.
# NOTE(review): the index of a dummies_test column equals the class id only if
# every class occurs in y_test -- confirm for the chosen test set.
predicted_classes = np.argmax(test_preds,axis=1)
expected_classes = np.argmax(dummies_test.values,axis=1)
correct = metrics.accuracy_score(expected_classes,predicted_classes)

print("------ Accuracy: %.2f%%" % (correct*100))

# Per-class precision, recall, F1 and support (one array entry per class).
precision_recall_f1 = metrics.precision_recall_fscore_support(expected_classes,predicted_classes)

# Macro averages over however many classes were scored; the class count used to
# be hard-coded as 5, which raised IndexError with fewer classes and silently
# ignored any class beyond the fifth.
n_classes = len(precision_recall_f1[0])
precision = precision_recall_f1[0].sum()
recall = precision_recall_f1[1].sum()
f1 = precision_recall_f1[2].sum()

print("Macro-Precision: %.2f%%" % (precision*100 / n_classes))
print("-- Macro-Recall: %.2f%%" % (recall*100 / n_classes))
print("------ Macro-F1: %.2f%%" % (f1*100 / n_classes))

kappascore = metrics.cohen_kappa_score(expected_classes, predicted_classes)
print("---- KappaScore: %.2f%%" % (kappascore*100))

# Confusion matrix normalised over the true classes (each row sums to 1).
confusionMatrix = metrics.confusion_matrix(expected_classes, predicted_classes, normalize='true')
metrics.ConfusionMatrixDisplay(confusionMatrix).plot()

# Load and print the SHAP values' file 

In [None]:
# The SHAP results are stored next to the model, under the model path plus a '_ShapVal' suffix.
file_name_shap = f"{file_model}{'_ShapVal'}"

# SECURITY: pickle.load can execute arbitrary code, so only load files from a trusted source.
with open(file_name_shap,'rb') as file:
    shap_values = pickle.load(file)

# Only the first element of the pickled object is used from here on
# (presumably the per-class SHAP arrays -- confirm against the generating script).
shap_values = shap_values[0]

In [None]:
# shap.initjs()   # uncomment this line to display SHAP plots 

- For each class, compute the mean absolute SHAP value associated with each input feature and display the top 10 features (see [SHAP issue #632](https://github.com/slundberg/shap/issues/632))

In [None]:
# Collect, for each of the 5 classes, the names of the 10 features with the
# largest aggregated SHAP magnitude.
feat_list = []

for class_id in range(5):
    column = 'feature_importance_vals_class'+str(class_id)

    # Mean absolute SHAP value over axis 0; sum(vals) then adds up the leading
    # axis of that mean, leaving one value per feature name.
    vals = np.abs(shap_values[class_id]).mean(0)
    pairs = list(zip(Labels[2:-1], sum(vals)))

    # Rank the features from most to least important for this class.
    feature_importance = pd.DataFrame(pairs, columns=['var', column]).sort_values(by=[column], ascending=False)

    top_features = feature_importance.values[:10,0]
    feat_list.append(top_features)
    print('\n class ', class_id)
    print(top_features,"\n")
