# ShapAgreement Metric
This notebook demonstrates how to compute the ShapAgreement between models. For demonstration purposes, we provide the SHAP values for all models trained on set `12-all`.

In [None]:
import pickle

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform
from scipy.stats import spearmanr
from sklearn.preprocessing import LabelEncoder

In [None]:
def load_object(fname):
    """Load a pickled object from ``<fname>.pickle``.

    Parameters
    ----------
    fname : str or os.PathLike
        Path of the file *without* the ``.pickle`` suffix; the suffix is
        appended here.

    Returns
    -------
    object or None
        The unpickled object, or ``None`` if opening or unpickling failed.
        Failures are printed, not raised (best-effort loader).

    Notes
    -----
    ``pickle.load`` can execute arbitrary code; only load files from a
    trusted source.
    """
    try:
        # An f-string accepts both str and pathlib.Path; the previous
        # `fname + ".pickle"` raised TypeError for Path objects, which was
        # then swallowed below.
        path = f"{fname}.pickle"
        with open(path, "rb") as f:
            return pickle.load(f)
    except Exception as ex:
        # Best-effort: report the problem and return an explicit None so
        # callers can detect the failure.
        print("Error during unpickling object (Possibly unsupported):", ex)
        return None

### Selecting Models To Compare
In the list `MODEL_LIST` we specify the names of the model folders to include in the ShapAgreement computation. We then load the training data to get the feature names as well as the training samples.

In [None]:
# Names of the model folders whose SHAP values enter the comparison.
# NOTE(review): 'EleasticNet' is presumably spelled to match the folder name
# on disk; verify before "correcting" the spelling.
MODEL_LIST = [
    'MLP',
    'KernelRidge',
    'KNN',
    'BaggedKNN',
    'LassoRegression',
    'EleasticNet',
    'SVRegression',
    'RandomForest',
    'CatBoost',
    'XGBoost',
]

# Unpack the entries of the training set shared by all models.
data = load_object('data/example_training_set/training_set')
x = data['x']
groups = data['group']
y = data['y']
x_names = data['x_names']


### Aggregation Function
Next, we define our feature-grouping function, which takes all feature names, each constructed as:

`resting state` _ `frequency band` _ `extraction method` _ `channel name`

- `resting state`: eyes open (EO) or eyes closed (EC)
- `frequency band`: the frequency band the feature was extracted from (delta, theta, alpha, beta, whole_spectrum)
- `extraction method`: the name of the extraction method used to compute the feature
- `channel name`: the name/label of the channel

and returns the group labels together with the feature groups (lists of feature indices). For this demonstration, we provide a function that groups features by the frequency band they were extracted from.

In [None]:
freq_bands = ['delta', 'theta', 'alpha', 'beta', 'whole_spec']

def group_freq_bands_shap(label_arr):
    """Group feature indices by the frequency band named in each label.

    A label belongs to a band when the band name occurs as a substring of
    the label. A 'whole_spec' label is counted only if it mentions none of
    the other bands. Bands without any matching label are omitted.

    Parameters
    ----------
    label_arr : iterable of str
        Feature names.

    Returns
    -------
    tuple of (list of str, list of list of int)
        The band names that had matches, and for each of them the indices
        of the matching labels in ``label_arr``.
    """
    # Every band except the trailing 'whole_spec' catch-all.
    specific_bands = freq_bands[:-1]
    band_names, index_groups = [], []
    for band in freq_bands:
        members = []
        for idx, feat_name in enumerate(label_arr):
            if band not in feat_name:
                continue
            if band == 'whole_spec' and any(other in feat_name for other in specific_bands):
                continue
            members.append(idx)
        if not members:
            continue
        band_names.append(band)
        index_groups.append(members)
    return band_names, index_groups

### Helper Functions
The following functions process the SHAP values of the models, group the features, aggregate the SHAP values per feature group to rank the groups, and cluster the rank-order correlation scores between all models.

In [None]:
def process_model(model, x, x_names):
    """Rank one model's frequency-band feature groups by mean |SHAP| value.

    Loads ``models/<model>/shap_values`` via ``load_object``. The loaded dict
    must contain ``'fold'`` (whose second element lists the evaluated sample
    indices) and ``'shap_values'`` (one row per evaluated sample, one column
    per feature in ``x_names``).

    Parameters
    ----------
    model : str
        Model folder name.
    x : sequence
        Training samples. Unused; kept for backward compatibility of the
        signature.
    x_names : sequence of str
        Feature names, in the column order of the SHAP values.

    Returns
    -------
    tuple
        ``(model, labels, scored)``. ``labels`` holds the group labels sorted
        by descending mean absolute grouped SHAP value. ``scored`` is the
        matching list of ``(mean_abs_shap, label)`` tuples.

    Raises
    ------
    RuntimeError
        If the SHAP file could not be loaded.
    ValueError
        If the SHAP values are not 2-D with one row per fold sample.
    """
    shap_dict = load_object(f'models/{model}/shap_values')
    if shap_dict is None:
        raise RuntimeError(f"Could not load SHAP values for model {model!r}")
    fold = shap_dict['fold']
    shap_values = np.asarray(shap_dict['shap_values'])

    # Only the number of evaluated samples is needed; the original code
    # built a subset of `x` (and an unused DataFrame) just to take its length.
    n_samples = len(fold[1])
    if shap_values.ndim != 2 or shap_values.shape[0] != n_samples:
        raise ValueError(
            f"SHAP values for {model!r} have shape {shap_values.shape}; "
            f"expected ({n_samples}, n_features)"
        )

    # Aggregate SHAP values by feature groups: sum the per-feature values of
    # every sample within each group.
    n_labels, feature_groups = group_freq_bands_shap(x_names)
    grouped_shap_values = np.zeros((n_samples, len(n_labels)))
    for col, group in enumerate(feature_groups):
        grouped_shap_values[:, col] = shap_values[:, group].sum(axis=1)

    # Mean absolute grouped SHAP value per group, most important first.
    mean_abs_shap = np.abs(grouped_shap_values).mean(axis=0)
    sorted_features = sorted(zip(mean_abs_shap, n_labels), reverse=True)
    return model, [label for _, label in sorted_features], sorted_features

def compute_correlation(ranks):
    """Return the pairwise Spearman correlation matrix of ``ranks``.

    Parameters
    ----------
    ranks : sequence of sequences
        One rank vector per model, all of equal length.

    Returns
    -------
    numpy.ndarray
        Square matrix whose ``[i, j]`` entry is Spearman's rho between
        ``ranks[i]`` and ``ranks[j]``.
    """
    matrix = []
    for row_rank in ranks:
        row = [spearmanr(row_rank, col_rank).correlation for col_rank in ranks]
        matrix.append(row)
    return np.array(matrix)

def hierarchical_clustering(correlation_matrix):
    """Order models by complete-linkage clustering of their correlations.

    Parameters
    ----------
    correlation_matrix : array_like, shape (n, n)
        Symmetric pairwise correlation matrix (e.g. from ``compute_correlation``).

    Returns
    -------
    list of int
        Row/column indices in dendrogram leaf order, so that strongly
        correlated models end up next to each other.
    """
    corr = np.asarray(correlation_matrix, dtype=float)
    n = corr.shape[0]
    if n < 2:
        # Nothing to cluster; linkage needs at least two observations.
        return list(range(n))

    # Correlation distance sqrt(2 * (1 - r)). Clip first, so that r slightly
    # above 1 from floating-point error cannot produce NaN via sqrt(<0).
    distance_matrix = np.sqrt(2.0 * np.clip(1.0 - corr, 0.0, None))
    # Enforce the exact symmetry and zero diagonal of a distance matrix.
    distance_matrix = (distance_matrix + distance_matrix.T) / 2.0
    np.fill_diagonal(distance_matrix, 0.0)

    # hierarchy.linkage treats a 2-D array as observation vectors, not as
    # precomputed distances; pass the condensed form instead.
    condensed = squareform(distance_matrix, checks=False)
    linkage = hierarchy.linkage(condensed, method='complete')
    dendrogram = hierarchy.dendrogram(linkage, no_plot=True)
    return dendrogram['leaves']

def plot_sorted_correlation_matrix(sorted_corr_matrix, labels, fold_n=None):
    """Draw the clustered model-by-model correlation matrix as a heatmap.

    Parameters
    ----------
    sorted_corr_matrix : array_like, shape (n, n)
        Correlation matrix with rows/columns already in clustered order.
    labels : sequence of str
        Model names in the same order as the matrix rows/columns.
    fold_n : optional
        Fold identifier appended to the title. Omitted when ``None``; the
        previous code referenced an undefined global ``fold_n`` and raised
        NameError.
    """
    plt.figure(figsize=(10, 7))
    sns.heatmap(
        sorted_corr_matrix,
        annot=True,
        cmap='coolwarm',
        # Fix the colour scale to the full correlation range so the diverging
        # colormap is centred on 0 and comparable across plots.
        vmin=-1,
        vmax=1,
        xticklabels=labels,
        yticklabels=labels,
        annot_kws={"fontsize": 8},
    )
    title = 'Ranked Ordered Correlation (FeaturesGroupedByFrequencyBand)'
    if fold_n is not None:
        title += f' fold {fold_n}'
    plt.title(title)
    plt.tight_layout()
    plt.show()

### Computing The ShapAgreement
Next we compute the ShapAgreement among our models and plot the result in a clustered matrix.

In [None]:
# Rank the feature groups of every model by mean |SHAP| value.
results = [process_model(model, x, x_names) for model in MODEL_LIST]
model_names = [result[0] for result in results]
feature_ranks = [result[1] for result in results]

# Turn each model's ranked label list into a rank vector over one shared,
# fixed label order: entry j is the importance position (0 = most important)
# of group_labels[j]. Spearman's rho must compare the ranks that models assign
# to the *same* group. The previous code encoded the ranked label sequences
# and correlated them position by position, which measures something else.
group_labels = sorted(feature_ranks[0])
encoded_ranks = []
for model_name, rank in zip(model_names, feature_ranks):
    if sorted(rank) != group_labels:
        raise ValueError(
            f"Model {model_name!r} has feature groups {sorted(rank)}, "
            f"expected {group_labels}"
        )
    position = {label: pos for pos, label in enumerate(rank)}
    encoded_ranks.append([position[label] for label in group_labels])

# Compute correlation matrix and perform hierarchical clustering
corr_matrix = compute_correlation(encoded_ranks)
ordered_indices = hierarchical_clustering(corr_matrix)
sorted_corr_matrix = corr_matrix[np.ix_(ordered_indices, ordered_indices)]
sorted_labels = np.array(model_names)[ordered_indices]

# Plot the correlation matrix
plot_sorted_correlation_matrix(sorted_corr_matrix, sorted_labels)
