# SHAP

In [None]:
import os
import shap
import pickle
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def get_results_data(pkl_path):
    """Load and return the object stored in a pickle file.

    Args:
        pkl_path: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted result files.
    # The context manager guarantees the file handle is closed, even on error.
    with open(pkl_path, 'rb') as pkl_file:
        return pickle.load(pkl_file)

def get_plot_single_sample(sample, surv_type, exp_name, fold, max_display=15):
    """Plot the SHAP values of a single sample as a bar chart.

    Args:
        sample: Sample ID; matched against the first element of each entry
            in the stored 'Samples' list.
        surv_type: Name inserted into the 'dss_survival_<surv_type>' results folder.
        exp_name: Experiment sub-folder name.
        fold: Fold number used to build the 'Fold_<fold>' folder name.
        max_display: Maximum number of bars, forwarded to ``shap.plots.bar``.

    Raises:
        ValueError: If ``sample`` is not present in the stored samples.
    """
    # Get shap values of this sample
    result_path = f'../results/dss_survival_{surv_type}/{exp_name}/Fold_{fold}/shap/post_attn_shap.pkl'
    results_dict = get_results_data(result_path)
    shap_values = results_dict['shap values']
    feat_names = results_dict['Feature names']
    print(results_dict['Samples'])
    sample_index = next((i for i, lst in enumerate(results_dict['Samples']) if lst[0] == sample), None)
    if sample_index is None:
        # Indexing a NumPy array with None inserts a new axis instead of failing,
        # which would silently plot the wrong values, so fail loudly here.
        raise ValueError(f"Sample {sample!r} not found in {result_path}")
    print("Sample ID:", sample)
    # Sum over the last axis of this sample's values and keep a leading axis of
    # length 1, so the Explanation holds exactly one row.
    shap_of_one_sample = np.sum(np.expand_dims(shap_values[sample_index], axis=0), axis=2)
    sample_expl = shap.Explanation(values=shap_of_one_sample, feature_names=feat_names)

    # Plot multimodal features
    shap.plots.bar(sample_expl[0], max_display=max_display)
    plt.show()

def get_vals(values):
    """Return the mean absolute value of the per-sample summed SHAP values.

    The input is first summed over axes 1 and 2; the absolute values of those
    sums are then averaged over axis 0.
    """
    per_sample_total = np.sum(values, axis=(1, 2))
    magnitudes = np.abs(per_sample_total)
    return np.mean(magnitudes, axis=0)

def print_res(name, results):
    """Print ``name`` followed by the mean and standard deviation of ``results``.

    Both statistics are rounded to three decimals and shown as ``mean±std``.
    """
    data = np.asarray(results)
    mean, std = (round(stat(data), 3) for stat in (np.mean, np.std))
    print(f"{name}: {mean}±{std}")


def get_results_post_attn(surv_type, exp_name):
    """Collect the post-attention SHAP values of five folds and print summaries.

    For every fold the normalized mean absolute SHAP value is computed for the
    specific/shared split and for each of the four representation slices
    separately; the mean and std over the folds are printed per group.
    """
    base_dir = f'../results/dss_survival_{surv_type}/{exp_name}'

    # One list of per-fold values per printed label (insertion order = print order).
    labels = (
        "Specific",
        "Shared",
        "Specific Zp hh",
        "Specific Zp gg",
        "Shared Zp gh",
        "Shared Zp hg",
    )
    per_fold = {label: [] for label in labels}

    # Loop over folds
    for fold in range(5):
        # Get all SHAP values
        pkl_path = os.path.join(base_dir, f'Fold_{fold}/shap/post_attn_shap.pkl')
        attributions = get_results_data(pkl_path)['shap values']

        # Slices along axis 1 for the four representations
        gg_block = attributions[:, :50, :]
        hg_block = attributions[:, 50:100, :]
        gh_block = attributions[:, 100:116, :]
        hh_block = attributions[:, -16:, :]

        # Normalized mean absolute SHAP values: specific vs. shared
        specific = get_vals(np.concatenate((gg_block, hh_block), axis=1))
        shared = get_vals(attributions[:, 50:116, :])
        pair_total = specific + shared
        per_fold["Specific"].append(specific / pair_total)
        per_fold["Shared"].append(shared / pair_total)

        # Normalized mean absolute SHAP values of all four slices separately
        hh_val = get_vals(hh_block)
        gg_val = get_vals(gg_block)
        hg_val = get_vals(hg_block)
        gh_val = get_vals(gh_block)
        four_total = hh_val + hg_val + gh_val + gg_val
        per_fold["Specific Zp hh"].append(hh_val / four_total)
        per_fold["Specific Zp gg"].append(gg_val / four_total)
        per_fold["Shared Zp hg"].append(hg_val / four_total)
        per_fold["Shared Zp gh"].append(gh_val / four_total)

    # Print results
    for label in labels:
        print_res(label, per_fold[label])


# Get SHAP results BRCA

In [None]:
# Experiment name; it selects the results sub-folder the helper functions read from.
exp_name = 'DIMAFNEW'
# Print the mean±std over the five folds of the normalized SHAP contributions for 'brca'.
get_results_post_attn('brca', exp_name)

In [None]:
# We can also plot the local multimodal feature importance of a single sample.
# The ID is stored as `sample_id` so the built-in `id()` is not shadowed
# for the rest of the session.
sample_id = 'TCGA-A7-A0CJ'
fold = 2
get_plot_single_sample(sample_id, 'brca', exp_name, fold)