In [None]:
# Imports

# Environment Configuration
import os
import random
import logging
import warnings
# Limit threads for reproducibility
# These variables are assigned before numpy / scikit-learn are imported further
# down in this cell (presumably so the numerical backends read them when they
# are first loaded — TODO confirm for the installed BLAS/OpenMP builds).
os.environ["OMP_NUM_THREADS"]       = "1"
os.environ["MKL_NUM_THREADS"]       = "1"
os.environ["NUMEXPR_NUM_THREADS"]   = "1"
# Ray logging / dashboard switches, assigned before `import ray` below.
os.environ["RAY_LOG_TO_STDERR"]     = "1"
os.environ["RAY_DISABLE_LOG_MONITOR"] = "1"
os.environ["RAY_DASHBOARD_DISABLE"]   = "1"

# Suppress warnings
# NOTE(review): the last call below ignores every warning from every module, so
# the two targeted filters before it are redundant while it is present. It also
# hides deprecation notices from pandas / scikit-learn in later cells.
warnings.filterwarnings("ignore", category=UserWarning, module="ray")
warnings.filterwarnings("ignore", category=FutureWarning, module=".*SAMME.R.*")
warnings.filterwarnings("ignore")

# Silence Ray logs
# Only records at ERROR level or above are emitted by the "ray" logger.
logging.getLogger("ray").setLevel(logging.ERROR)

# Core Libraries
import time
import pickle
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import shap

# Statistics
from scipy.stats import chi2_contingency

# Scikit-learn
from sklearn.preprocessing import OneHotEncoder, RobustScaler
from sklearn.model_selection import (
    LeaveOneGroupOut,
    StratifiedGroupKFold,
    cross_val_score,
)
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.metrics import make_scorer, f1_score


# Imbalanced-learn
from imblearn.ensemble import BalancedRandomForestClassifier as BRFC

# Ray / Hyperparameter Search
import ray
from ray import tune
from ray.tune import CheckpointConfig
from ray.tune.search import ConcurrencyLimiter
from ray.tune.search.optuna import OptunaSearch
from optuna.samplers import TPESampler
from ray.tune import with_parameters
from ray.train import report


# Record the Ray version in the notebook output for provenance.
print("Ray version:", ray.__version__)

# Reproducibility
# Seeds NumPy's legacy global RNG and Python's `random` module.
# NOTE(review): later cells define their own lowercase `seed = 123` and pass it
# to the splitters / models; SEED itself is not referenced again in this file.
SEED = 123
np.random.seed(SEED)
random.seed(SEED)


In [None]:
# Hyperparameter Tuning + Feature Selection (Balanced Random Forest)
def BRFC_Tune(X_train, y_train, seed_state, splitter,
              num_samples=100, n_startup_trials=20,
              subject_id=None):
    """Tune a Balanced Random Forest with Ray Tune + Optuna (TPE) and refit it.

    Each trial is scored by cross-validated ROC AUC on ``splitter``; the value
    being maximised is ``mean(AUC) - standard_error(AUC)``, which penalises
    configurations whose fold scores are unstable.

    Parameters
    ----------
    X_train : pandas.DataFrame
        Training features (``.columns`` is read to build the feature list).
    y_train : array-like
        Training labels.
    seed_state : int
        Seed used for the Optuna sampler and for every forest that is built.
    splitter : CV splitter or iterable of (train_idx, test_idx)
        Passed unchanged as ``cv`` to ``cross_val_score``.
    num_samples : int, default 100
        Number of Ray Tune trials.
    n_startup_trials : int, default 20
        Number of random trials before TPE starts modelling.
    subject_id : optional
        When given (including ``0``), Ray results are written to a
        ``losocv_subject_<id>`` sub-folder of the base storage directory.

    Returns
    -------
    best_model : BalancedRandomForestClassifier
        Forest refit on all of ``X_train`` with the best configuration.
    selected_features : list[str]
        All column names of ``X_train``. No feature selection is performed
        here; the list is returned so callers can index test data consistently.
    """
    base_storage_dir = r"C:\Users\akishta03\Documents\SS_ML\Ray_Temp"
    # `is not None` (rather than truthiness) so that subject id 0 is honoured.
    if subject_id is not None:
        storage_dir = os.path.join(base_storage_dir, f"losocv_subject_{subject_id}")
    else:
        storage_dir = base_storage_dir

    if not ray.is_initialized():
        ray.init(
            num_cpus=20,
            ignore_reinit_error=True,
            log_to_driver=False,
            include_dashboard=False,
            local_mode=False
        )

    def build_model(params):
        # Single construction site so tuning trials and the final refit can
        # never disagree on the fixed (non-tuned) settings.
        return BRFC(
            n_estimators=params["n_estimators"],
            max_depth=params["max_depth"],
            min_samples_split=params["min_samples_split"],
            min_samples_leaf=params["min_samples_leaf"],
            max_features='log2',
            sampling_strategy="auto",
            replacement=True,
            random_state=seed_state,
            n_jobs=-1
        )

    def trainable(config, X, y, cv):
        # NOTE(review): both the forest and cross_val_score request n_jobs=-1;
        # kept as in the original, but this nests two levels of parallelism.
        scores = cross_val_score(build_model(config), X, y, cv=cv,
                                 scoring="roc_auc", n_jobs=-1)
        k = scores.size
        mean_auc = float(np.mean(scores))
        # The sample standard deviation is undefined for a single fold (it would
        # be NaN and poison the search), so no penalty is applied in that case.
        if k > 1:
            se_auc = float(np.std(scores, ddof=1) / np.sqrt(k))
        else:
            se_auc = 0.0
        score = float(mean_auc - se_auc)

        report({"mean_AUC": mean_auc, "score": score})

    # Parameter grid
    param_grid = {
        "n_estimators": tune.choice([50, 100, 150, 200, 300]),
        "max_depth": tune.choice([None, 5, 10, 15]),
        "min_samples_split": tune.choice([2, 4, 6, 8]),
        "min_samples_leaf": tune.choice([1, 2, 3, 4])
    }

    # Optuna hyperparameter search; trials run one at a time so the seeded TPE
    # sampler sees results in a fixed order.
    optuna_sampler = TPESampler(seed=seed_state, n_startup_trials=n_startup_trials)
    search_algo = ConcurrencyLimiter(
        OptunaSearch(sampler=optuna_sampler, metric="score", mode="max"),
        max_concurrent=1
    )

    trainable_fn = with_parameters(trainable, X=X_train, y=y_train, cv=splitter)

    analysis = tune.run(
        trainable_fn,
        config=param_grid,
        num_samples=num_samples,
        search_alg=search_algo,
        resources_per_trial={"cpu": 20},
        reuse_actors=True,
        storage_path=storage_dir,
        metric="score",
        mode="max",
        name="brfc_tune",
        trial_dirname_creator=lambda trial: f"{trial.trial_id[:8]}",
        checkpoint_config=CheckpointConfig(checkpoint_frequency=0)
    )

    best_config = analysis.get_best_config(metric="score", mode="max")

    # Final model fit with best hyperparameters
    best_model = build_model(best_config)
    best_model.fit(X_train, y_train)
    selected_features = X_train.columns.tolist()

    return best_model, selected_features


In [None]:
# Rapid correlation-based reduction of the feature space: uses Cramér's V for categorical data and Pearson r for numerical data. Threshold is 0.9. Correlations are derived from the training data and applied to the test data.
def corr_feat_elim(cat_cols, dataset, dataset2):
    """Remove redundant features, deciding on `dataset` and applying to both frames.

    Categorical columns are compared pairwise with bias-corrected Cramér's V and
    numeric columns with Pearson r. For every pair whose association exceeds
    0.9 the *later* column (in column order) is dropped and the earlier one is
    kept. All decisions are made on `dataset` only; the same columns are then
    removed from `dataset2`.

    Parameters
    ----------
    cat_cols : list[str]
        Names of the categorical columns present in both frames.
    dataset : pandas.DataFrame
        Frame the associations are estimated on (e.g. the training split).
    dataset2 : pandas.DataFrame
        Frame that receives the same column removals (e.g. the test split).

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        Reduced copies of `dataset` and `dataset2`, categorical columns first,
        with a fresh 0..n-1 index and any column containing "Unnamed" removed.
    """
    threshold = 0.9
    cat_cols = list(cat_cols)
    # Bookkeeping columns are neither categorical nor numeric features.
    excluded = {'SubjectID', 'model_set', 'test_set'}
    numeric_cols = [col for col in dataset.columns
                    if col not in cat_cols and col not in excluded]

    # CATEGORICAL ANALYSIS

    def cramers_v(x, y):
        """Bias-corrected Cramér's V of two categorical series (0 if degenerate)."""
        confusion_matrix = pd.crosstab(x, y)

        # A table with fewer than two rows or columns carries no association.
        if confusion_matrix.empty or confusion_matrix.shape[0] < 2 or confusion_matrix.shape[1] < 2:
            return 0

        chi2 = chi2_contingency(confusion_matrix)[0]
        n = confusion_matrix.sum().sum()
        if n <= 1:  # Avoid invalid sample size
            return 0

        phi2 = chi2 / n
        r, k = confusion_matrix.shape
        phi2corr = max(0, phi2 - ((k - 1) * (r - 1)) / (n - 1))
        rcorr = r - ((r - 1) ** 2) / (n - 1)
        kcorr = k - ((k - 1) ** 2) / (n - 1)

        denominator = min((kcorr - 1), (rcorr - 1))
        if denominator <= 0:  # Avoid invalid denominator
            return 0
        return np.sqrt(phi2corr / denominator)

    # Visit each unordered pair exactly once and drop only its second member.
    # Visiting both (A, B) and (B, A) would mark *both* columns for removal and
    # discard the information entirely (e.g. the two one-hot columns of a
    # binary variable).
    to_drop_cat = set()
    for pos, col1 in enumerate(cat_cols):
        for col2 in cat_cols[pos + 1:]:
            if col1 != col2 and cramers_v(dataset[col1], dataset[col2]) > threshold:
                to_drop_cat.add(col2)

    dataset_cat = dataset[cat_cols].drop(columns=list(to_drop_cat))
    dataset2_cat = dataset2[cat_cols].drop(columns=list(to_drop_cat))

    # NUMERIC ANALYSIS

    # Coerce to numeric so that stray strings become NaN instead of breaking corr().
    dataset_num = dataset[numeric_cols].apply(pd.to_numeric, errors='coerce')
    dataset2_num = dataset2[numeric_cols].apply(pd.to_numeric, errors='coerce')
    # pandas computes pairwise-complete Pearson correlations (NaNs are skipped per pair).
    corr_matrix = dataset_num.corr()

    # Strictly-upper-triangular mask: entry (i, j) with i < j is True when the
    # pair is highly correlated; column j is then the one to drop.
    high_corr = np.triu(np.abs(corr_matrix.values) > threshold, k=1)
    to_drop_num = set(corr_matrix.columns[np.nonzero(high_corr.any(axis=0))[0]])

    dataset_num = dataset_num.drop(columns=list(to_drop_num))
    dataset2_num = dataset2_num.drop(columns=list(to_drop_num))

    # Combine the reduced categorical and numeric data
    full_dataset = pd.concat([dataset_cat, dataset_num], axis=1).reset_index(drop=True)
    full_dataset2 = pd.concat([dataset2_cat, dataset2_num], axis=1).reset_index(drop=True)
    # NOTE: earlier revisions called `.astype(np.float32)` here but discarded the
    # result, so callers have always received the original dtypes. That observable
    # behaviour is kept; convert explicitly at the call site if float32 is needed.

    # Remove index artefacts such as "Unnamed: 0" left over from CSV round-trips.
    columns_to_drop = full_dataset.filter(like="Unnamed").columns
    columns_to_drop2 = full_dataset2.filter(like="Unnamed").columns

    full_dataset = full_dataset.drop(columns=columns_to_drop)
    full_dataset2 = full_dataset2.drop(columns=columns_to_drop2)

    return full_dataset, full_dataset2

In [None]:
"""
Data Set Configuration:
--------------------------------------------------------------------
"all"       → Demographics, Pre + (PrePostDiff) Gait Symmetry, and MEPs  
"demo"      → Demographics only  
"demoMEPs"  → Demographics and MEPs  
"MEPs"      → MEPs only  
"gait"      → Pre + (PrePostDiff) Gait Symmetry  
"preGait"   → Pre only Gait Symmetry  
"preOnly"   → Demographics, Pre only Gait Symmetry, and MEPs
"""

# ---------------------------------------------------------------------
# Data Input Configuration dictionary
# ---------------------------------------------------------------------
# Every dataset nickname maps to: the CSV location ("path"), the columns that
# are treated as categorical ("cat_cols") and the columns removed before
# modelling ("drop_cols").
_INPUT_DIR = r"Y:\LabMembers\Ameen\VScode Scripts\Spinal Stim Aim 1 Codes\ML Labels and Inputs"

_DEMO_CAT_COLS = [
    "Sex", "Laterality", "Type of Stroke", "Community Orthotic Type", "Orthotic During Session",
]
_MEP_CAT_COLS = [
    "Lowest RMT Location", "Lowest RMT Paretic", "Lowest RMT Flexor",
    "Last RMT Location", "Last RMT Paretic", "Last RMT Flexor",
]
_MEP_DROP_COLS = ['Lowest RMT Muscle', 'Lowest RMT Muscles', 'Last RMT Muscle']


def _config_entry(file_name, cat_cols=(), drop_cols=()):
    """Build one DATA_CONFIG record; the lists are copied so entries never alias."""
    return {
        "path": _INPUT_DIR + "\\" + file_name,
        "cat_cols": list(cat_cols),
        "drop_cols": list(drop_cols),
    }


DATA_CONFIG = {
    "all": _config_entry(
        "Model_Input_AllFeat_AllData.csv", _DEMO_CAT_COLS + _MEP_CAT_COLS, _MEP_DROP_COLS
    ),
    "demo": _config_entry("Model_Input_AllFeat_Demo.csv", _DEMO_CAT_COLS),
    "demoMEPs": _config_entry(
        "Model_Input_AllFeat_Demo_MEPs.csv", _DEMO_CAT_COLS + _MEP_CAT_COLS, _MEP_DROP_COLS
    ),
    "MEPs": _config_entry("Model_Input_AllFeat_MEPs.csv", _MEP_CAT_COLS, _MEP_DROP_COLS),
    "gait": _config_entry("Model_Input_AllFeat_Gait.csv"),
    "preGait": _config_entry("Model_Input_AllFeat_Pre_Gait.csv"),
    "preOnly": _config_entry(
        "Model_Input_AllFeat_PreOnlyData.csv", _DEMO_CAT_COLS + _MEP_CAT_COLS, _MEP_DROP_COLS
    ),
}


# Data Loader function
def load_dataset(name: str):
    """
    Load a dataset by its nickname in DATA_CONFIG.

    Parameters
    ----------
    name : str
        One of the keys of DATA_CONFIG (e.g. "preOnly").

    Returns
    -------
    df        : pandas.DataFrame  # contents of the configured CSV
    cat_cols  : list[str]         # categorical columns (a copy of the config list)
    drop_cols : list[str]         # columns to drop (a copy of the config list)
    name      : str               # the nickname, echoed back for labelling outputs

    Raises
    ------
    KeyError
        If `name` is not a configured dataset; the message lists valid names.
    FileNotFoundError
        If the configured CSV path does not exist.
    """
    try:
        cfg = DATA_CONFIG[name]
    except KeyError:
        raise KeyError(
            f"Unknown dataset {name!r}; available: {sorted(DATA_CONFIG)}"
        ) from None

    csv_path = Path(cfg["path"])
    if not csv_path.exists():
        raise FileNotFoundError(csv_path)

    df = pd.read_csv(csv_path)
    # Copies are returned so callers cannot mutate the shared configuration.
    return df, list(cfg["cat_cols"]), list(cfg["drop_cols"]), name



In [None]:
"""
Data Set Configuration:
--------------------------------------------------------------------
"all"       → Demographics, Pre + (PrePostDiff) Gait Symmetry, and MEPs  
"demo"      → Demographics only  
"demoMEPs"  → Demographics and MEPs  
"MEPs"      → MEPs only  
"gait"      → Pre + (PrePostDiff) Gait Symmetry  
"preGait"   → Pre only Gait Symmetry  
"preOnly"   → Demographics, Pre only Gait Symmetry, and MEPs
"""

# FINAL MODEL USES "preOnly" DATASET
# Load input data
input_df, cat_cols, drop_cols, dataSet = load_dataset("preOnly")
print(f"Loaded {input_df.shape[0]:,} rows, {input_df.shape[1]} columns.")
print("Categorical cols :", cat_cols)
print("Drop cols        :", drop_cols)

# Load outcome labels
outcome_df = pd.read_csv(r"Y:\LabMembers\Ameen\VScode Scripts\Spinal Stim Aim 1 Codes\ML Labels and Inputs\ML_Labels.csv")

# Numeric features: everything that is neither categorical nor explicitly dropped
input_df_num = input_df.drop(columns=cat_cols + drop_cols)

# One-Hot Encoding
# Datasets such as "gait" / "preGait" have no categorical columns; the encoder
# cannot be fitted on a zero-column frame, so that case yields an empty block.
if cat_cols:
    ohe = OneHotEncoder(drop=None, sparse_output=False, handle_unknown='ignore')
    ohe.fit(input_df[cat_cols])
    cat_encoded = ohe.transform(input_df[cat_cols])
    cat_encoded_df = pd.DataFrame(
        cat_encoded, columns=ohe.get_feature_names_out(cat_cols), index=input_df.index
    )
else:
    ohe = None
    cat_encoded_df = pd.DataFrame(index=input_df.index)

# Combine with numeric and outcome data
merged_df_ohe = pd.concat([input_df_num, cat_encoded_df], axis=1)
merged_df_ohe = merged_df_ohe.merge(outcome_df, on="Subject", how="inner")

# Report missing values BEFORE imputing them; checking afterwards would always
# print an empty summary and hide how much data the sentinel replaces.
nan_summary = merged_df_ohe.isna().sum()
print("NaN count per column:\n", nan_summary[nan_summary > 0])

# Replace NaNs with -1 for all columns
# NOTE(review): -1 acts as a "missing" sentinel and is later scaled together
# with real values — confirm this is intended for every numeric feature.
merged_df_ohe = merged_df_ohe.fillna(-1)

# Preview the one-hot encoded dataframe
merged_df_ohe.head()


In [None]:
# Initializing the key parameters for the pipeline
seed = 123 # model and pipeline reproducibility
loso = LeaveOneGroupOut() # defining the splitter for LOSOCV
results_df = pd.DataFrame()  # storing outputs for later analysis / plotting. Need to reinitialize before every runthrough
outcome_vec = ['Frequency Label','Intensity Label'] # defining what to iterate through in the for-loop

# Categorical features are exactly the columns produced by the one-hot encoder.
# Matching by substring (e.g. "Sex" in column_name) would also capture any
# numeric column whose name happens to contain a categorical column's name.
ohe_feature_names = set(cat_encoded_df.columns)
cat_features = [c for c in merged_df_ohe.columns if c in ohe_feature_names] # defining cat_features for feature selection

In [None]:
# Initialize status log path
status_log_path = r"C:\Users\akishta03\Documents\SS_ML\status_log.txt"

# Optional: clear old logs if rerunning whole script
with open(status_log_path, "w") as f:
    f.write("Run Status Log\n")
    f.write("====================\n")

# Balanced Random Forest + Optuna tuning, LOSOCV model pipeline
# Model results file name, updated with date, and data set type
model_name   = "brfc"  # "easyensemble" or "brfc"
output_label = f"{model_name}_{dataSet}_{datetime.now().strftime('%Y%m%d')}"

# Per-fold outputs are collected in lists and concatenated once after the loops.
# This keeps the cell idempotent (re-running it no longer appends to results
# left over from a previous run) and avoids quadratic DataFrame growth.
fold_results = []    # one prediction frame per held-out subject and outcome
fold_features = []   # one Series of feature names per held-out subject and outcome

for outcome in outcome_vec:
    print(outcome)
    iter_df = merged_df_ohe.copy()

    for train_idx, test_idx in loso.split(iter_df, groups=iter_df["Subject"]):
        start_time = time.time()

        train_df = iter_df.iloc[train_idx].reset_index(drop=True)
        test_df  = iter_df.iloc[test_idx].reset_index(drop=True)

        subject_test  = test_df["Subject"]
        held_out_subject = subject_test.unique()[0]
        subject_trial = test_df["Trial"]
        subject_trial_diff = test_df["Trial_Diff"] if "Trial_Diff" in test_df.columns else None

        # Inner splits are built while "Subject" is still present so that every
        # subject's trials stay within a single inner fold.
        kfold = StratifiedGroupKFold(n_splits=3, shuffle=True, random_state=seed) # defining the splitter for k-fold CV
        kfold_splits = list(
            kfold.split(train_df, train_df[outcome], groups=train_df["Subject"])
        )

        train_df = train_df.drop(columns=["Subject"])
        test_df  = test_df.drop(columns=["Subject"])

        # NOTE(review): only "Subject" and the "*Label*" columns are removed, so
        # "Trial" (and "Trial_Diff" when present) remain model inputs — confirm
        # that this is intended.
        X_train = train_df.drop(columns=[c for c in train_df.columns if "Label" in c])
        X_test  = test_df.drop(columns=[c for c in test_df.columns if "Label" in c])
        y_train = train_df[outcome]
        y_test  = test_df[outcome]

        # Correlation-based elimination (fitted on train, applied to test)
        X_train, X_test = corr_feat_elim(cat_features, X_train, X_test)

        # Scale numeric data only (scaler fitted on the training fold)
        numeric_cols = [c for c in X_train.columns if c not in cat_features]
        scaler = RobustScaler()
        X_train[numeric_cols] = scaler.fit_transform(X_train[numeric_cols])
        X_test[numeric_cols]  = scaler.transform(X_test[numeric_cols])

        # Train Balanced Random Forest with hyperparameter search in Optuna
        print("Starting tune for BRFC...")
        final_model, selected_features = BRFC_Tune(X_train, y_train, seed, kfold_splits)

        # Save selected features for this LOSOCV iteration
        fold_features.append(
            pd.Series(selected_features, name=f"{outcome}_Subj{held_out_subject}")
        )

        # Evaluate on held-out subject
        X_test_sel = X_test[selected_features]
        y_pred     = final_model.predict(X_test_sel)
        y_prob     = final_model.predict_proba(X_test_sel)[:, 1]

        elapsed = time.time() - start_time
        status_msg = f"{output_label} | Outcome: {outcome} | Subject {held_out_subject} | Time {elapsed:.2f}s"
        print(status_msg)

        # Write status message to log
        with open(status_log_path, "a") as f:
            f.write(status_msg + "\n")

        # Record test results
        out_df = pd.DataFrame({
            "Subject": subject_test,
            "Trial":   subject_trial,
            "y_true":  y_test.values.ravel(),
            "y_pred":  y_pred,
            "y_prob":  y_prob,
        })
        if subject_trial_diff is not None:
            out_df["Trial_Diff"] = subject_trial_diff
        out_df["Outcome"] = outcome
        out_df["Model"]   = model_name.upper()

        fold_results.append(out_df)

# Assemble the collected outputs (empty frames if nothing was produced)
results_df = pd.concat(fold_results, ignore_index=True) if fold_results else pd.DataFrame()
# One column per LOSOCV iteration; shorter feature lists are padded with NaN
losocv_features = pd.concat(fold_features, axis=1) if fold_features else pd.DataFrame()

# Save results with version increase if file exists
base_path   = r'C:\Users\akishta03\Documents\SS_ML\ML Results'
output_file = f"{base_path}\\{output_label}_Final_100125_Frequency_dropFirst.csv"
ver = 1
while os.path.exists(output_file):
    output_file = f"{base_path}\\{output_label}_v{ver}.csv"
    ver += 1

results_df.to_csv(output_file, index=False)
# --- Save LOSOCV feature names to CSV ---
feat_output_file = f"{base_path}\\{output_label}_LOSOCV_Features.csv"
losocv_features.to_csv(feat_output_file, index=False)
print(f"Saved LOSOCV feature names to {feat_output_file}")


In [None]:
# FULL SUBJECT TRAINING + SHAP ANALYSIS
# ============================================================
# Train on ALL subjects + Tree SHAP (Top-10)
# Save SHAP arrays + global CSV
# Rename features for plots using FeatureNameLabelMapping.xlsx
# ============================================================

# Output locations
model_name    = "BRFC"
fig_base_path = r"Y:\LabMembers\Ameen\VScode Scripts\Spinal Stim Aim 1 Codes\ML Results\SHAP"
os.makedirs(fig_base_path, exist_ok=True)

# Cache file names deliberately carry no date tag so a later run can reload them
shap_label = f"{model_name}_{dataSet}_ALLSUBJ"

# Matplotlib defaults shared by every figure produced below
_shap_rc_overrides = {
    "figure.dpi": 300,
    "savefig.dpi": 300,
    "figure.figsize": (7.0, 5.0),
    "font.size": 12,
    "axes.labelsize": 12,
    "axes.titlesize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "axes.grid": False,
}
plt.rcParams.update(_shap_rc_overrides)

RANDOM_STATE = seed
top_k = 10

# x-axis captions: a generic first line, plus a direction legend per outcome
default_xlabel = "SHAP value (impact on model output)"
x_axis_labels = {
    "Frequency Label": default_xlabel + "\n← 50 Hz                               30 Hz →",
    "Intensity Label": default_xlabel + "\n← RMT Intensity               TOL Intensity →",
}

# Optional spreadsheet translating raw feature names into display names.
# Without it the mapping stays empty and raw names are shown unchanged.
mapping_path = os.path.join(fig_base_path, "FeatureNameLabelMapping.xlsx")
if os.path.exists(mapping_path):
    mdf = pd.read_excel(mapping_path)
    raw_names = mdf["Feature Name"].astype(str)
    display_names = mdf["Rename"].astype(str)
    feature_map = dict(zip(raw_names, display_names))
else:
    feature_map = {}

def rename_list(names):
    """Translate names through `feature_map`; unmapped names pass through as strings."""
    renamed = []
    for raw in names:
        key = str(raw)
        renamed.append(feature_map.get(key, key))
    return renamed


# Loop over outcomes: for each label, train (or reload) an all-subject model,
# compute (or reload) Tree SHAP values, then write a global-importance CSV,
# a bar chart and a beeswarm plot for the most influential features.
for outcome in outcome_vec:
    print(f"\n=== Train on ALL subjects + SHAP :: Outcome = {outcome} ===")
    xlabel_text = x_axis_labels.get(outcome, default_xlabel)

    # Build the all-subject design matrix
    all_df = merged_df_ohe.copy()
    groups = all_df["Subject"].copy()
    kfold = StratifiedGroupKFold(n_splits=3, shuffle=True, random_state=RANDOM_STATE)
    kfold_splits_all = list(kfold.split(all_df, all_df[outcome], groups=groups))

    all_df = all_df.drop(columns=["Subject"])
    X_all = all_df.drop(columns=[c for c in all_df.columns if "Label" in c])
    y_all = all_df[outcome].astype(int).values

    X_all, _ = corr_feat_elim(cat_features, X_all, X_all)
    numeric_cols = [c for c in X_all.columns if c not in cat_features]
    scaler = RobustScaler()
    # Column replacement (as in the LOSOCV cell) rather than `.loc[:, cols] = ...`:
    # in-place setting of float results into integer columns relies on silent
    # upcasting, which newer pandas versions deprecate.
    X_all[numeric_cols] = scaler.fit_transform(X_all[numeric_cols])

    # Load cached model or train
    model_cache_path = os.path.join(
        fig_base_path, f"{model_name}_{dataSet}_ALLSUBJ_{outcome}_final_model.pkl"
    )
    if os.path.exists(model_cache_path):
        # SECURITY: unpickling executes arbitrary code; only load cache files
        # that were written by this notebook on a trusted share.
        with open(model_cache_path, "rb") as f:
            data = pickle.load(f)
        final_model = data["model"]
        selected_features = data["features"]
        print(f"Loaded cached model for {outcome}")
    else:
        final_model, selected_features = BRFC_Tune(
            X_all, y_all, RANDOM_STATE, kfold_splits_all
        )
        with open(model_cache_path, "wb") as f:
            pickle.dump({"model": final_model, "features": selected_features}, f)
        print(f"Trained and cached model for {outcome}")

    # Slice to selected features
    X_sel_all = X_all[selected_features].copy()
    if not hasattr(final_model, "feature_names_in_"):
        final_model.feature_names_in_ = np.array(selected_features, dtype=object)

    # Compute or Load SHAP
    shap_vals_path = os.path.join(fig_base_path, f"{shap_label}_{outcome}_SHAP_values.npz")

    if os.path.exists(shap_vals_path):
        print(f"Loading cached SHAP values for {outcome}")
        # SECURITY: allow_pickle=True is needed for the object-dtype feature
        # names and carries the same trust requirement as the model cache.
        data = np.load(shap_vals_path, allow_pickle=True)
        shap_vals     = np.array(data["shap_vals"])
        feature_names = data["feature_names"]
        X_values      = data["X_values"]

        # Binary classification stored per class: keep the positive class.
        # (After np.array(...) the value is always an ndarray, so no list case.)
        if shap_vals.ndim == 3 and shap_vals.shape[-1] == 2:
            shap_vals = shap_vals[:, :, 1]

        # Restore the exact feature matrix the cached SHAP values refer to
        X_sel_all = pd.DataFrame(X_values, columns=feature_names)
        selected_features = list(feature_names)
    else:
        print(f"Computing SHAP values for {outcome} (TreeExplainer)")
        explainer = shap.TreeExplainer(
            model=final_model,
            data=X_sel_all,
            model_output="probability",         
            feature_perturbation="interventional",
            feature_names=X_sel_all.columns
        )

        shap_vals = explainer.shap_values(X_sel_all)

        # Binary classification: depending on the shap version the result is a
        # list [class0, class1] or an array with a trailing class axis.
        if isinstance(shap_vals, list) and len(shap_vals) == 2:
            shap_vals = shap_vals[1]
        elif shap_vals.ndim == 3 and shap_vals.shape[-1] == 2:
            shap_vals = shap_vals[:, :, 1]

        shap_vals = np.array(shap_vals)
        print("Final shap_vals shape:", shap_vals.shape)

        # Save SHAP arrays
        np.savez_compressed(
            shap_vals_path,
            shap_vals=shap_vals,
            feature_names=np.array(X_sel_all.columns, dtype=object),
            X_values=X_sel_all.values
        )
        print(f"Saved SHAP values to {shap_vals_path}")

    # Global importance: mean |SHAP| per feature, ignoring features with no effect
    mean_abs = np.abs(shap_vals).mean(axis=0)
    nonzero_idx = np.where(mean_abs > 0)[0]

    # Top `top_k` features by importance (slicing copes with fewer than top_k)
    order = np.argsort(mean_abs[nonzero_idx])[::-1][:top_k]
    top_idx = np.ravel(nonzero_idx[order])  # flatten to 1D
    top_feats = X_sel_all.columns[top_idx].tolist()
    display_top_feats = rename_list(top_feats)

    shap_global = pd.DataFrame({
        "feature": X_sel_all.columns,
        "mean_abs_shap": mean_abs.ravel()
    }).sort_values("mean_abs_shap", ascending=False).reset_index(drop=True)

    shap_global["display_name"] = rename_list(shap_global["feature"])
    table_path = os.path.join(fig_base_path, f"{shap_label}_{outcome}_SHAP_global.csv")
    shap_global.to_csv(table_path, index=False)

    # Plots
    n_display = len(top_idx)
    if n_display == 0:
        # Every SHAP value is zero: there is nothing meaningful to draw.
        print(f"No features with non-zero SHAP values for {outcome}; plots skipped.")
        print(f"Saved: {table_path}")
        continue

    # Bar
    plt.figure()
    plt.barh(range(n_display), mean_abs[top_idx], align="center")
    plt.yticks(range(n_display), display_top_feats)
    plt.gca().invert_yaxis()
    plt.xlabel(default_xlabel)
    plt.title(f"{model_name} — {outcome} — Mean |SHAP| (Top {n_display})")
    plt.tight_layout()
    bar_path = os.path.join(fig_base_path, f"{shap_label}_{outcome}_BAR_top{n_display}.png")
    plt.savefig(bar_path, bbox_inches="tight", facecolor="white")
    plt.close()

    # Beeswarm
    X_eval_top = X_sel_all[top_feats].copy()
    X_eval_top.columns = display_top_feats
    shap_vals_top = shap_vals[:, top_idx]

    plt.figure()
    shap.summary_plot(
        shap_vals_top,
        X_eval_top,
        show=False,
        max_display=n_display
    )
    # The colorbar is taken to be the last axes of the figure
    # (NOTE(review): relies on shap's current figure layout — confirm on upgrade).
    cbar = plt.gcf().axes[-1]

    # Replace labels
    cbar.set_yticklabels(["Low / False", "High / True"])

    # Force symmetric x-axis with 0 centered
    ax = plt.gca()
    xlim = ax.get_xlim()
    xmax = max(abs(xlim[0]), abs(xlim[1]))
    ax.set_xlim(-xmax, xmax)

    plt.xlabel(xlabel_text)
    plt.title(f"{outcome} — SHAP Beeswarm (Top {n_display})")
    plt.tight_layout()
    beeswarm_path = os.path.join(fig_base_path, f"{shap_label}_{outcome}_BEESWARM_top{n_display}.png")
    plt.savefig(beeswarm_path, bbox_inches="tight", facecolor="white")
    plt.close()

    print(f"Saved: {bar_path}\n       {beeswarm_path}\n       {table_path}")
