# SHAP and Permutation Plot Construction

Import Data + Models + Libraries/Packages

In [1]:
import sys
import os
sys.path.append(os.path.abspath('../'))
from src.build_dnn_model import build_nn_model
import shap
import pandas as pd
import joblib
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.inspection import permutation_importance
from src.config import SEED
## Data
# Outcome labels: only the 'ORN' column of each raw split workbook is kept.
y_train = pd.read_excel('../data/raw/split/Raw_y_train.xlsx')['ORN']
y_test = pd.read_excel('../data/raw/split/Raw_y_test.xlsx')['ORN']

# Transformed feature matrices: "ml_*" is passed to the models in base_models,
# "nomo_*" is passed to nomo_clf.
ml_X_train = pd.read_parquet('../data/processed/ml_train_transformed.parquet')
ml_X_test = pd.read_parquet('../data/processed/ml_test_transformed.parquet')
nomo_X_train = pd.read_parquet('../data/processed/nomo_train_transformed.parquet')
nomo_X_test = pd.read_parquet('../data/processed/nomo_test_transformed.parquet')

## Models
# Estimators previously serialized with joblib.
lightgbm_clf = joblib.load('../models/LightGBM.joblib')
svc_clf = joblib.load('../models/SVC.joblib')
knn_clf = joblib.load('../models/KNN.joblib')
dnn_clf = joblib.load('../models/DNN.joblib')
stack_clf = joblib.load('../models/stack.joblib')
nomo_clf = joblib.load('../models/NLR.joblib')

# Display name -> estimator for every model evaluated on the ml_* matrices.
# The insertion order is the order in which the plotting loops visit the models.
base_models = dict(
    LightGBM=lightgbm_clf,
    SVC=svc_clf,
    KNN=knn_clf,
    DNN=dnn_clf,
    Stack=stack_clf,
)

# Generate SHAP Beeswarm and MAV Plots

In [None]:
## Ordered list of features for the SHAP plots and tables.
## One-hot-encoded groups (e.g. JEWER, SITE, PLATE, TXEXPOSURE) are listed
## once, under the name they receive after being combined.
shap_feats = [
    'TXEXPOSURE', 'WOUNDINF', 'SITE', 'PRERT', 'EXPOSURE', 'POSTCT',
    'JEWER', 'N', 'PRECT', 'REOPEN', 'OSTEOTOMY', 'DEFECT_TYPE',
    'DM', 'RECUR', 'POSTRT', 'PLATE', 'ALB', 'ADMISSION',
    'MEDEXPOSURE', 'OPTIME', 'ISCHEMICTIME', 'LENGTH', 'BT', 'BMI',
    'FLAP', 'AGE', 'T', 'HGB', 'EXPOSUREFU', 'SP',
    'STAGE', 'PRIOREX', 'ASA', 'GENDER',
]


In [None]:
# Workbook that receives the per-model mean-absolute-SHAP sheets. Any copy left
# by a previous run is deleted so the sheets can be appended from scratch.
excel_path = '../results/tables/SHAP_MAV_Tables.xlsx'
if os.path.exists(excel_path):
    os.remove(excel_path)

# One row per feature (in shap_feats order); get_shap() adds one column per model.
mean_abs_shap_df_all_models = pd.DataFrame({'Feature': shap_feats})

######## Combine one-hot-encoded #######
def combine_encoded(shap_values, name, mask, return_original=True):
    """Merge a group of one-hot-encoded columns into one SHAP feature.

    NOTE: combine_encoded() function adapted from following repository:
    https://gist.github.com/peterdhansen/ca87cc1bfbc4c092f0872a3bfe3204b2#

    Args:
        shap_values: shap.Explanation whose columns include the encoded group.
        name: Feature name given to the merged column.
        mask: Boolean sequence, one entry per feature, True for the columns
            that belong to the group being merged.
        return_original: When True, also return the Explanation restricted to
            the masked (encoded) columns.

    Returns:
        ``(combined, group_only)`` when ``return_original`` is True, otherwise
        just ``combined``. In ``combined`` the masked columns are removed and a
        single column called ``name`` is appended as the last feature; its SHAP
        value is the sum of the masked columns' SHAP values.
    """
    mask = np.array(mask)
    keep = ~mask
    group_names = np.array(shap_values.feature_names, dtype='object')[mask]

    # Attributes forwarded unchanged to both Explanation objects built below.
    passthrough = dict(
        base_values=shap_values.base_values,
        instance_names=shap_values.instance_names,
        output_names=shap_values.output_names,
        output_indexes=shap_values.output_indexes,
        lower_bounds=shap_values.lower_bounds,
        upper_bounds=shap_values.upper_bounds,
        main_effects=shap_values.main_effects,
        hierarchical_values=shap_values.hierarchical_values,
        clustering=shap_values.clustering,
    )

    # Explanation holding only the encoded columns of this group.
    sv_name = shap.Explanation(
        shap_values.values[:, mask],
        feature_names=list(group_names),
        data=shap_values.data[:, mask],
        display_data=shap_values.display_data,
        **passthrough,
    )

    # Weight each encoded column by its position within the group and sum:
    # the result is used as the merged feature's value and as an index into
    # the group's column names for the display data.
    positions = np.arange(mask.sum())
    new_data = (sv_name.data * positions).sum(axis=1).astype(int)
    svdata = np.concatenate(
        [shap_values.data[:, keep], new_data.reshape(-1, 1)],
        axis=1,
    )

    # Fall back to the raw data when no display data is attached.
    display_source = (
        shap_values.data if shap_values.display_data is None
        else shap_values.display_data
    )
    svdisplay_data = np.concatenate(
        [display_source[:, keep], group_names[new_data].reshape(-1, 1)],
        axis=1,
    )

    # Handle multi-class (3D) vs binary/regression (2D) SHAP arrays
    if shap_values.values.ndim == 3:
        # Sum the encoded features while keeping the trailing class dimension.
        merged = sv_name.values.sum(axis=1, keepdims=True)
        svvalues = np.concatenate(
            [shap_values.values[:, keep, :], merged],
            axis=1,
        )
    else:
        merged = sv_name.values.sum(axis=1)
        svvalues = np.concatenate(
            [shap_values.values[:, keep], merged.reshape(-1, 1)],
            axis=1,
        )

    svfeature_names = list(np.array(shap_values.feature_names)[keep]) + [name]

    sv = shap.Explanation(
        svvalues,
        data=svdata,
        display_data=svdisplay_data,
        feature_names=svfeature_names,
        **passthrough,
    )
    return (sv, sv_name) if return_original else sv
    

def get_vals_to_plot(shap_vals):
    """Reduce a SHAP explanation to the single output that gets plotted.

    Args:
        shap_vals: Object exposing ``.values`` (with a ``.shape``), slicing
            via ``[...]`` and ``.mean(axis=...)`` -- a shap.Explanation in
            this notebook.

    Returns:
        ``shap_vals`` itself when its values are not 3-dimensional
        (LightGBM, SVC, KNN, Stack, LR-Nomogram); otherwise a slice along
        the last axis: output 0 when there is one output (DNN), output 1
        (positive class) when there are two or more, and the mean over the
        last axis as a fallback.
    """
    shape = shap_vals.values.shape
    if len(shape) != 3:
        return shap_vals

    n_outputs = shape[2]
    if n_outputs == 1:
        # Binary classification with a single output column.
        return shap_vals[:, :, 0]
    if n_outputs >= 2:
        # Binary with two outputs or multi-class: use the positive class.
        return shap_vals[:, :, 1]
    return shap_vals.mean(axis=2)  # Fallback

def get_shap(model, model_name, X, shap_feats, show_initial = False, show_sub_plots = False, mav_df = False,
             export_results = False, constant_scale = False, label_MAV = False):
    """Compute SHAP values for one model and draw its beeswarm and mean-|SHAP| bar plots.

    The one-hot-encoded JEWER, SITE, PLATE and TXEXPOSURE columns are merged
    into single features with ``combine_encoded`` before plotting.

    Side effects: adds a ``model_name`` column to the module-level
    ``mean_abs_shap_df_all_models`` and, when ``mav_df`` is set, writes a sheet
    to the workbook named by the module-level ``excel_path`` (looked up at call
    time, so it is whatever value that name currently holds).

    Args:
        model: Fitted estimator. For 'DNN' the underlying network is taken
            from ``model.model_`` (scikeras wrapper).
        model_name: Selects the explainer: 'Stack' / 'Stacked Generalization',
            'KNN', 'SVC', 'DNN', 'LightGBM', or any name containing
            'Nomogram'. Also used for plot titles, sheet and file names.
        X: DataFrame of (one-hot-encoded) features to explain; it is also
            passed to the explainer constructor.
        shap_feats: Feature names, after combining, in the order used for the
            table and the reordered explanation.
        show_initial: Show a beeswarm of the raw (uncombined) SHAP values.
        show_sub_plots: Show one beeswarm per combined one-hot group.
        mav_df: Display the mean-absolute-value table and append it to the
            Excel workbook.
        export_results: Save both figures as PDF under ../results/figures/SHAP/.
        constant_scale: Fix the beeswarm x-axis to [-0.5, 0.5].
        label_MAV: Annotate each bar with its value.

    Returns:
        None. An unrecognised ``model_name`` prints a message and returns early.
    """
    ohe_feat_names = X.columns.tolist()
    # Extra keyword arguments for the explainer call; only the tree explainer needs any.
    explainer_kwargs = {}

    ### Choose correct explainer
    # 'Stack' is the key used in base_models; the longer name is kept for
    # backward compatibility. Without it the stacked model was silently skipped.
    if model_name in ('Stack', 'Stacked Generalization'):
        explainer = shap.PermutationExplainer(model.predict, X)
    elif model_name in ('KNN', 'SVC'):
        explainer = shap.KernelExplainer(model.predict, X)
    elif 'Nomogram' in model_name:
        explainer = shap.LinearExplainer(model, X)
    elif model_name == 'DNN':
        # Different method for the neural network.
        #NOTE: Bc using scikeras wrapper, need to extract underlying model with .model_
        X = X.values
        explainer = shap.DeepExplainer(model.model_, X)
    elif model_name == 'LightGBM':
        # Tree model: the explainer call additionally disables the additivity check.
        explainer = shap.TreeExplainer(model, X)
        explainer_kwargs['check_additivity'] = False
    else:
        print(f'{model_name} does not have an allocated explainer')
        return

    shap_values = explainer(X, **explainer_kwargs)

    ### Put feature names back
    shap_values.feature_names = ohe_feat_names

    ##Show initial plot without combining one-hot-encoded
    if show_initial:
        shap.plots.beeswarm(shap_values, max_display=45, show=False)
        plt.title(f'{model_name} Raw SHAP Plot')
        plt.show()

    #### Combine the one-hot-encoded groups one after another ####
    # (column-name fragment to match, label used in the sub-plot title)
    encoded_groups = (
        ('JEWER', 'Jewer'),
        ('SITE', 'Site'),
        ('PLATE', 'Plate'),
        ('TXEXPOSURE', 'TXEXPOSURE'),
    )
    combined = shap_values
    group_explanations = []
    for group, label in encoded_groups:
        group_mask = [group in n for n in combined.feature_names]
        combined, group_only = combine_encoded(combined, group, group_mask)
        group_explanations.append((label, group_only))
    shap_vals_tx = combined

    ######SHOW SUB_PLOTS########
    if show_sub_plots:
        # One beeswarm per group (the PLATE plot used to be drawn twice).
        for label, group_only in group_explanations:
            shap.plots.beeswarm(get_vals_to_plot(group_only), max_display=20, show=False)
            plt.title(f'{model_name} {label} SHAP Plot')
            plt.show()

    ############Get shap values, mean shap, sum shap########
    #OHE combined feature names
    feature_names = shap_vals_tx.feature_names

    # Determine SHAP values to plot
    shap_vals_to_plot = get_vals_to_plot(shap_vals_tx)

    shap_df = pd.DataFrame(shap_vals_to_plot.values, columns=feature_names)
    ## Reorder shap df to summary order
    shap_df = shap_df[shap_feats]
    indices = [feature_names.index(f) for f in shap_feats]
    shap_vals_to_plot.values = shap_vals_to_plot.values[:, indices]
    # The feature values (which colour the beeswarm points) must be permuted
    # together with the SHAP values; otherwise each feature is coloured by a
    # different feature's data.
    feature_data = shap_vals_to_plot.data
    if feature_data is not None and np.ndim(feature_data) == 2:
        shap_vals_to_plot.data = np.asarray(feature_data)[:, indices]
    feature_display = shap_vals_to_plot.display_data
    if feature_display is not None and np.ndim(feature_display) == 2:
        shap_vals_to_plot.display_data = np.asarray(feature_display)[:, indices]
    shap_vals_to_plot.feature_names = shap_feats

    #Absolute AVG shap vals for each feature
    absolute_mean_shap = shap_df.abs().mean().reset_index()
    absolute_mean_shap.columns = ['Feature', 'Mean Absolute SHAP Value']
    mean_abs_shap_df_all_models[model_name] = absolute_mean_shap['Mean Absolute SHAP Value'].to_list()
    if mav_df:
        ##Display and export MAV tables
        display(absolute_mean_shap)
        # Append to the workbook when it already exists, otherwise create it.
        with pd.ExcelWriter(excel_path, engine='openpyxl', mode='a' if os.path.exists(excel_path) else 'w') as writer:
            absolute_mean_shap.to_excel(writer, sheet_name=model_name, index=True)

    ########## Plot beeswarm plot ############
    shap.plots.beeswarm(shap_vals_to_plot, max_display=len(shap_feats), show=False)
    plt.title(f'{model_name}', fontweight='semibold', fontsize=25)
    #NOTE: Can change limits
    if constant_scale:
        x_min = -0.5
        x_max = 0.5
        plt.xlim(x_min, x_max)
    if export_results:
        plt.savefig(f'../results/figures/SHAP/beeswarm/{model_name}_Beeswarm.pdf', bbox_inches='tight')
    plt.show()

    ########### Plot bar chart #############
    plt.figure(figsize=(8, 10))
    ## Summary order
    bars = plt.barh(absolute_mean_shap['Feature'], absolute_mean_shap['Mean Absolute SHAP Value'], color='skyblue')
    plt.title(f'{model_name}', fontweight='semibold', fontsize=25)
    plt.xlabel('Mean Absolute SHAP Value', fontsize=16)
    plt.ylabel('Feature', fontsize=16)
    plt.tight_layout()
    if label_MAV:
        ##Label bars with MAV values
        for bar in bars:
            width = bar.get_width()
            plt.text(width + 0.001, bar.get_y() + bar.get_height()/2,
                f'{width:.3f}', va='center', fontsize=10)
        # NOTE(review): right=3 is outside the usual 0-1 figure fraction -- confirm intended.
        plt.subplots_adjust(right=3)
    plt.margins(x=0.2)
    if export_results:
        plt.savefig(f'../results/figures/SHAP/MAV/{model_name}_MAV.pdf', bbox_inches='tight')
    plt.show()


Generate plots

In [None]:
# SHAP plots for every model that was trained on the ml_* feature matrix.
for model_name, model in base_models.items():
    get_shap(
        model=model,
        model_name=model_name,
        X=ml_X_test,
        shap_feats=shap_feats,
        export_results=True,
    )

# The nomogram model is explained on its own transformed test matrix.
get_shap(
    model=nomo_clf,
    model_name='Nomogram',
    X=nomo_X_test,
    shap_feats=shap_feats,
    export_results=True,
)

Normalize each model's mean absolute SHAP values to relative importances (each model column sums to 1) and export

In [21]:
##Normalized (used to make heat maps)
# Divide every model column by its own total so each column sums to 1.
# Only the column being normalised is summed: summing the whole frame on
# every iteration also reduced the string 'Feature' column, which is wasted
# work and can fail for string dtypes that do not support sum.
for col in mean_abs_shap_df_all_models.columns:
    if col == 'Feature':
        continue
    model_column = mean_abs_shap_df_all_models[col]
    mean_abs_shap_df_all_models[col] = model_column / model_column.sum()
mean_abs_shap_df_all_models.to_excel('../results/tables/all_models_MAV_relative.xlsx', index = False)

# Permutation Plots

In [19]:
# Workbook that receives one permutation-importance sheet per model. Any copy
# left by a previous run is deleted so the sheets can be appended from scratch.
# NOTE: this rebinds the same module-level name (excel_path) used by the SHAP cells.
excel_path = '../results/tables/all_models_perm_means_stdevs.xlsx'
if os.path.exists(excel_path):
    os.remove(excel_path)

# One row per feature (in shap_feats order); get_perm_ordered() adds one column per model.
perm_df_all_models = pd.DataFrame({'Feature': shap_feats})

def get_perm_ordered(model, model_name, X_test, y_test, ordered_feats, display_df = True, save_plot = True):
    """Compute, tabulate and plot grouped permutation importance (AUROC) for one model.

    Importances of one-hot-encoded columns are summed back onto their original
    feature: the part of the column name before the first underscore, with
    'DEFECT_TYPE' kept whole.

    Side effects: adds a ``model_name`` column to the module-level
    ``perm_df_all_models`` (aligned on its 'Feature' column) and writes a sheet
    to the workbook named by the module-level ``excel_path``.

    Args:
        model: Fitted estimator accepted by sklearn's permutation_importance.
        model_name: Used for the column, sheet, plot title and file name.
        X_test: Feature DataFrame (one-hot-encoded columns allowed).
        y_test: Target values matching ``X_test``.
        ordered_feats: Preferred feature order for the table and the plot.
            Grouped features not listed here are printed and appended at the end.
        display_df: Display the importance table.
        save_plot: Save the bar chart as PDF under ../results/figures/permutation/.

    Returns:
        None.
    """
    result = permutation_importance(
        model, X_test, y_test,
        n_repeats=50,
        random_state=SEED,
        scoring='roc_auc'
    )
    onehot_to_original = {
        col: col.split('_')[0] if col != 'DEFECT_TYPE' else col
        for col in X_test.columns
    }

    # Collect, per original feature, the mean importance and the per-repeat
    # importances of every column that belongs to it.
    grouped_means = defaultdict(list)
    grouped_repeats = defaultdict(list)

    for idx, col in enumerate(X_test.columns):
        orig = onehot_to_original[col]
        grouped_means[orig].append(result.importances_mean[idx])
        grouped_repeats[orig].append(result.importances[idx])  # shape: (n_repeats,)

    # Group mean = sum of the column means. Group std = std over repeats of
    # the per-repeat sum across the group's columns.
    agg_means = {k: np.sum(v) for k, v in grouped_means.items()}
    agg_stds = {k: np.std(np.sum(np.stack(v, axis=0), axis=0)) for k, v in grouped_repeats.items()}

    ##Sort in summary order
    features = [f for f in ordered_feats if f in agg_means]
    left_out_cols = [f for f in agg_means if f not in features]
    ##Check if empty
    if left_out_cols:
        print(left_out_cols)
        features += left_out_cols

    importances = [agg_means[f] for f in features]
    stds = [agg_stds[f] for f in features]

    ##Create DF
    importance_df = pd.DataFrame({
        'Feature': features,
        'Permutation Importance': importances,
        'STDEV': stds
    })
    # Align on feature name rather than position: a positional assignment
    # raises (or misaligns) whenever this model's grouped features differ from
    # the 'Feature' column, e.g. when left_out_cols is non-empty. Features the
    # model does not have become NaN.
    perm_df_all_models[model_name] = perm_df_all_models['Feature'].map(agg_means)
    if display_df:
        display(importance_df)

    # Append to the workbook when it already exists, otherwise create it.
    with pd.ExcelWriter(excel_path, engine='openpyxl', mode='a' if os.path.exists(excel_path) else 'w') as writer:
        importance_df.to_excel(writer, sheet_name=model_name, index=True)

    # Plot
    plt.figure(figsize=(8, 10))
    plt.barh(features, importances, xerr=stds, color='skyblue')
    plt.xlabel("Aggregated Decrease in AUROC score")
    plt.title(f"{model_name} Permutation Importance (AUROC)")
    plt.tight_layout()
    if save_plot:
        plt.savefig(f'../results/figures/permutation/{model_name}_Perm.pdf', bbox_inches='tight')
    plt.show()


Generate Plots

In [None]:
# Permutation importance for every model that was trained on the ml_* feature matrix.
for model_name, model in base_models.items():
    get_perm_ordered(
        model=model,
        model_name=model_name,
        X_test=ml_X_test,
        y_test=y_test,
        ordered_feats=shap_feats,
        save_plot=True,
    )

# The nomogram model is evaluated on its own transformed test matrix.
get_perm_ordered(
    model=nomo_clf,
    model_name='Nomogram',
    X_test=nomo_X_test,
    y_test=y_test,
    ordered_feats=shap_feats,
    save_plot=True,
)

##Just importances (means but no stdev) used to make heat map
perm_df_all_models.to_excel('../results/tables/all_models_perm_means.xlsx', index=False)