# Computing The Aggregated SHAP Value
In our paper, we grouped the features extracted from the EEG recordings according to different criteria, e.g. the frequency band a feature was extracted from. In this notebook, we demonstrate how to compute the aggregated SHAP value for one such grouping.
This allows us to create rank orders based on feature group importance, which is a crucial step for our proposed ShapAgreement metric.

In [None]:
import pickle
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

In [None]:
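def load_object(fname):
    """Load the object stored in `<fname>.pickle`."""
    try:
        with open(fname + ".pickle", "rb") as f:
            return pickle.load(f)
    except Exception as ex:
        print("Error during unpickling object (possibly unsupported):", ex)
        # Re-raise instead of silently returning None, which would fail later with a less clear error
        raise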

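The notebook only reads pickles, but the files follow a simple convention: one pickled object per file, stored as `<name>.pickle`, with the extension appended by the helper. The sketch below pairs a hypothetical `save_object` (not part of the provided code) with a copy of `load_object` (without the error message, so the snippet runs on its own) and does a round trip in a temporary directory. The dict layout mirrors the training-set file loaded below; the values and feature names are made up.

```python
import os
import pickle
import tempfile


def save_object(obj, fname):
    """Pickle `obj` to `<fname>.pickle` (hypothetical counterpart to load_object)."""
    with open(fname + ".pickle", "wb") as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_object(fname):
    """Load the object stored in `<fname>.pickle` (mirrors the notebook helper)."""
    with open(fname + ".pickle", "rb") as f:
        return pickle.load(f)


# Same layout as the training-set file: a dict with 'x' and 'x_names' (values are made up)
example_obj = {"x": [[0.1, 0.2], [0.3, 0.4]], "x_names": ["feature_a", "feature_b"]}

with tempfile.TemporaryDirectory() as tmp_dir:
    base = os.path.join(tmp_dir, "training_set")  # no extension: ".pickle" is appended
    save_object(example_obj, base)
    file_exists = os.path.exists(base + ".pickle")
    restored = load_object(base)

print(restored["x_names"])
```

Passing the path without the extension, as the notebook does with `"data/example_training_set/training_set"`, is what makes the two helpers symmetric.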
### Loading The Training Set
First, we load the provided sample training set (set `12-all` from the paper) together with the feature names. Each name is constructed as:

`resting state` _ `frequency band` _ `extraction method` _ `channel name`

- `resting state`: eyes open (EO) or eyes closed (EC)
- `frequency band`: the frequency band the feature was extracted from (delta, theta, alpha, beta, whole_spectrum)
- `extraction method`: the name of the method used to compute the feature
- `channel name`: the name/label of the EEG channel
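Because both the band (`whole_spectrum`) and the extraction method may themselves contain underscores, a plain `name.split("_")` is ambiguous. The sketch below is one way to recover the four components, under the assumption (not stated in the data description) that the resting state and the channel name never contain underscores. The function `parse_feature_name` and the example names are hypothetical; the band spelling in `x_names` may differ, so both `whole_spectrum` and `whole_spec` are accepted.

```python
# Band tokens as listed above; the names in x_names may spell the last one "whole_spec"
FREQ_BANDS = ["delta", "theta", "alpha", "beta", "whole_spectrum", "whole_spec"]


def parse_feature_name(name):
    """Split '<state>_<band>_<method>_<channel>' into its four parts.

    Assumes the resting state and the channel name contain no underscores, so
    underscores can only occur inside the band (e.g. whole_spectrum) or the method.
    """
    state, rest = name.split("_", 1)
    for band in FREQ_BANDS:
        if rest.startswith(band + "_"):
            # The channel is the last underscore-separated token; everything before it is the method
            method, channel = rest[len(band) + 1:].rsplit("_", 1)
            return {"state": state, "band": band, "method": method, "channel": channel}
    raise ValueError(f"no known frequency band in feature name {name!r}")


# Hypothetical feature names, for illustration only
print(parse_feature_name("EO_alpha_psd_Fz"))
print(parse_feature_name("EC_whole_spectrum_sample_entropy_O1"))
```

Matching the band as a prefix of the remainder (followed by `_`) keeps `whole_spec` from matching inside `whole_spectrum`.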

In [None]:
data = load_object("data/example_training_set/training_set")
x = data['x']
x_names = data['x_names']

### Loading The SHAP Values
For demonstration purposes, we provide the SHAP values for all our models trained on set `12-all`. Each *shap_values.pickle* file is a dict containing the SHAP values under the key `shap_values` and the fold indices under the key `fold`, where `fold[0]` holds the training indices and `fold[1]` the test indices. We use the fold indices to select the samples for which the SHAP values were computed (`x_test`).

In [None]:
# Corresponds to the folder name in the models folder
MODEL = "XGBoost"

shap_data = load_object(f"models/{MODEL}/shap_values")
fold = shap_data['fold']
shap_values = shap_data['shap_values']

x_train = [x[i] for i in fold[0]]
x_test = [x[i] for i in fold[1]]
x_train_df = pd.DataFrame(x_train, columns=x_names)
x_test_df = pd.DataFrame(x_test, columns=x_names)
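Since the SHAP values are matched to samples purely by position, it is worth checking that they line up with `x_test` before aggregating. The sketch below uses small synthetic arrays (hypothetical stand-ins with a `_demo` suffix so they do not overwrite the notebook's variables) and assumes `shap_values` is a 2-D NumPy array of shape `[test samples] x [features]`, which is how the aggregation step indexes it.

```python
import numpy as np

# Hypothetical stand-ins with a "_demo" suffix so the notebook's variables are not overwritten
rng = np.random.default_rng(0)
x_demo = rng.normal(size=(6, 4))              # 6 samples x 4 features
x_names_demo = ["f0", "f1", "f2", "f3"]
fold_demo = ([0, 1, 2, 3], [4, 5])           # (train indices, test indices)
shap_values_demo = rng.normal(size=(len(fold_demo[1]), len(x_names_demo)))

# Integer-array indexing returns the test rows in the order given by fold_demo[1]
x_test_demo = x_demo[np.asarray(fold_demo[1])]

# Row i of the SHAP matrix explains row i of x_test; columns follow x_names
assert shap_values_demo.shape == (x_test_demo.shape[0], len(x_names_demo))
print(x_test_demo.shape, shap_values_demo.shape)
```

On the notebook's own data, the equivalent check would be `np.asarray(shap_values).shape == (len(x_test), len(x_names))`, again assuming `shap_values` is a plain array.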

### Grouping Features
Next, we define a function that groups the features by frequency band. It takes the list of all feature names (`x_names`) and returns the band labels that occur together with the column indices of the features belonging to each band.

In [None]:
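freq_bands = ['delta', 'theta', 'alpha', 'beta', 'whole_spec']

def group_freq_bands_shap(label_arr):
    """Group feature indices by frequency band.

    Returns the band labels that occur in `label_arr` and, for each of them,
    the list of column indices of the features extracted from that band.
    """
    feature_groups_fb = []
    n_labels_fb = []

    for fb in freq_bands:
        # A feature belongs to band `fb` if its name contains the band name; for the
        # whole-spectrum group we additionally exclude names that mention another band
        feature_group_idx = [i for i, name in enumerate(label_arr) if fb in name and
                             (fb != 'whole_spec' or not any(ofb in name for ofb in freq_bands[:-1]))]

        # Only keep bands that actually occur in the feature set
        if feature_group_idx:
            feature_groups_fb.append(feature_group_idx)
            n_labels_fb.append(fb)

    return n_labels_fb, feature_groups_fb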

n_labels, feature_groups = group_freq_bands_shap(x_names)
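The same pattern works for the other grouping criteria mentioned in the introduction, e.g. grouping by resting state. The sketch below is a hypothetical generalization, `group_by_token`, that matches whole underscore-separated tokens instead of substrings, which avoids accidental matches of a short token inside a longer word. The trade-off is that multi-token values such as `whole_spec` cannot be matched this way, which is one reason the band-specific function above uses substring matching. The feature names are made up.

```python
def group_by_token(feature_names, tokens):
    """Group feature indices by which of `tokens` appears as an underscore-separated part of the name."""
    labels, groups = [], []
    for token in tokens:
        idx = [i for i, name in enumerate(feature_names) if token in name.split("_")]
        if idx:  # skip tokens that match no feature, like group_freq_bands_shap does
            labels.append(token)
            groups.append(idx)
    return labels, groups


# Hypothetical feature names following the '<state>_<band>_<method>_<channel>' scheme
names_demo = ["EO_alpha_psd_Fz", "EC_alpha_psd_Fz", "EO_beta_psd_Cz", "EC_theta_psd_Cz"]
print(group_by_token(names_demo, ["EO", "EC"]))
# (['EO', 'EC'], [[0, 2], [1, 3]])
```

The returned `(labels, groups)` pair has the same shape as the output of `group_freq_bands_shap`, so it can be fed into the aggregation step unchanged.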

### SHAP Value Aggregation
We now compute the aggregated SHAP value of each feature group by summing, for each sample, the SHAP values of all features in that group. We then take the mean absolute value of these per-sample group sums across all test samples.
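Written as a formula (our own notation for illustration, not taken from the paper), the aggregated importance of group $g$ is

$$
\bar{\phi}_g = \frac{1}{N} \sum_{n=1}^{N} \left| \sum_{j \in G_g} \phi_{n,j} \right|
$$

where $\phi_{n,j}$ is the SHAP value of feature $j$ for test sample $n$, $G_g$ is the set of feature indices in group $g$ (one entry of `feature_groups`), and $N$ is the number of test samples (`len(x_test)`).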

In [None]:
# Sum the SHAP values of all features within each group, per sample
# Resulting shape: [samples] x [groups]
grouped_shap_values = np.zeros((len(x_test), len(n_labels)))
for i, group in enumerate(feature_groups):
    grouped_shap_values[:, i] = np.sum(shap_values[:, group], axis=1)

# Mean absolute aggregated SHAP value per group (averaged over the samples)
mean_abs_shap = np.mean(np.abs(grouped_shap_values), axis=0)
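The order of the two operations matters. Summing first lets positive and negative contributions of features in the same group cancel for a sample, so each group is scored by its net effect on that sample's prediction (by SHAP's additivity, this sum is the group's combined contribution). The toy sketch below, with made-up values, contrasts this with an alternative that this notebook does not use: averaging each feature's absolute SHAP value and then summing over the group.

```python
import numpy as np

# Hypothetical SHAP values for 3 samples x 2 features that form a single group
shap_demo = np.array([[0.5, -0.5],
                      [0.3, 0.1],
                      [-0.2, -0.2]])
group_demo = [0, 1]

# Aggregation used in this notebook: sum within the group first, then take |.| and average
sum_then_abs = np.mean(np.abs(shap_demo[:, group_demo].sum(axis=1)))

# Alternative (not used here): average each feature's |SHAP| first, then sum over the group
abs_then_sum = np.sum(np.mean(np.abs(shap_demo[:, group_demo]), axis=0))

print(round(sum_then_abs, 4), round(abs_then_sum, 4))
```

In the first sample the two contributions cancel exactly, so the notebook's aggregation yields the smaller score; the alternative would credit the group for effects that offset each other.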

### Group Rank Order Based On Shap Values
Finally, we sort the groups by their aggregated SHAP value to obtain the group rank order and plot the result.

In [None]:
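# Map each frequency band to a LaTeX symbol for the plot (omega = whole spectrum); this keeps
# labels aligned with mean_abs_shap even if a band is missing from the feature set
band_symbols = {'delta': r'$\mathbf{\delta}$', 'theta': r'$\mathbf{\theta}$', 'alpha': r'$\mathbf{\alpha}$',
                'beta': r'$\mathbf{\beta}$', 'whole_spec': r'$\mathbf{\omega}$'}
plot_labels = [band_symbols[fb] for fb in n_labels]

# Sort ascending so that barh draws the most important group at the top
vals, labels = zip(*sorted(zip(mean_abs_shap, plot_labels)))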

plt.barh(labels, vals)
plt.title(r"$\mathbf{c)}$ Frequency Band Importance", weight='bold')
plt.xlabel('grouped SHAP value', weight='bold')
plt.tight_layout()
plt.show()
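The plot shows the rank order visually, but comparing rank orders between models is easier with an explicit list. This notebook does not show which rank-order representation ShapAgreement expects, so the sketch below is only an assumption: a hypothetical helper `group_rank_order` that returns the group labels from most to least important, using a stable sort so that ties keep their input order. The importance values are made up.

```python
import numpy as np


def group_rank_order(mean_abs_shap, group_labels):
    """Return group labels sorted from most to least important (ties keep input order)."""
    order = np.argsort(-np.asarray(mean_abs_shap), kind="stable")
    return [group_labels[i] for i in order]


# Hypothetical importances for the five frequency-band groups
mean_abs_demo = [0.12, 0.05, 0.30, 0.08, 0.21]
bands_demo = ["delta", "theta", "alpha", "beta", "whole_spec"]
print(group_rank_order(mean_abs_demo, bands_demo))
# ['alpha', 'whole_spec', 'delta', 'beta', 'theta']
```

With the notebook's variables, `group_rank_order(mean_abs_shap, group_freq_bands_shap(x_names)[0])` would give the band ranking for the selected model.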