# Feature Importance (SHAP)

## Set Up

Import packages and libraries.

In [None]:
import sys

sys.path.append("../")
from src.data_utils import get_data, get_models, get_feature_lists
from src.config import BASE_PATH, SEED
from src.feat_importance import get_shap_single_model
from joblib import delayed, Parallel
from src.nn_model import load_nn_clf
from shutil import rmtree
from sklearn.model_selection import train_test_split
import shap

# Turn off shap's internal progress bar. This sets a private, undocumented flag
# (shap.utils._general._show_progress), so it may stop working after a shap upgrade.
setattr(shap.utils._general, "_show_progress", False)

Set globals: outcome data, trained models, and the feature order used for SHAP output.

In [None]:
# ---- Data: processed splits for each outcome, loaded with get_data("<name>_outcome")
file_dir = BASE_PATH / "data" / "processed"
OUTCOME_DICT = {
    name: get_data(f"{name}_outcome", file_dir)
    for name in ("med", "surg", "mort", "reop", "vte")
}

# ---- Models: base models from get_models() plus the neural network stored as nn.pt
model_dir = BASE_PATH / "models" / "trained"
model_prefix_list = ["lgbm", "lr", "xgb", "stack"]
MODEL_DICT = {}
# Number of columns in X_train, used as the NN input dimension.
# Assumed identical for every outcome (the surg split is used as the reference).
X_shape = OUTCOME_DICT["surg"]["X_train"].shape[1]
for outcome in OUTCOME_DICT:
    outcome_models = get_models(model_prefix_list, outcome, model_dir)
    MODEL_DICT[outcome] = outcome_models
    # Same file as BASE_PATH / "models" / "trained" / outcome / "nn.pt"
    outcome_models["nn"] = load_nn_clf(
        data_path=model_dir / outcome / "nn.pt",
        in_dim=X_shape,
        device="cpu",
    )


# Feature order passed to get_shap_single_model, grouped by category.
# Group order follows dict insertion order; within a group, tuple order is kept.
_FEAT_GROUPS = {
    "Demographics + Comorbidities": (
        "SURGINDICD",
        "AGE",
        "BMI",
        "SEX",
        "ETHNICITY_HISPANIC",
        "RACE",
        "DIABETES",
        "HXCOPD",
        "HXCHF",
        "ASCITES",
        "BLEEDDIS",
        "TRANSFUS",
        "DIALYSIS",
        "HYPERMED",
        "VENTILAT",
        "SMOKE",
        "DISCANCR",
        "RENAFAIL",
        "STEROID",
        "ASACLAS",
        "DYSPNEA",
        "WNDINF",
        "WTLOSS",
    ),
    "Blood Labs": (
        "PRALBUM",
        "PRWBC",
        "PRHCT",
        "PRPLATE",
    ),
    "Intra-Op": (
        "OPTIME",
        "URGENCY",
        "ANESTHES",
        "SURGSPEC",
        "INOUT",
        "OPERYR",
    ),
    "Mastectomy": (
        "SNLBCPT",
        "ALNDCPT",
        "PARTIALCPT",
        "SUBSIMPLECPT",
        "RADICALCPT",
        "MODIFIEDRADICALCPT",
    ),
    "Reconstruction": (
        "IMMEDIATECPT",
        "DELAYEDCPT",
        "TEINSERTIONCPT",
        "TEEXPANDERCPT",
        "FREECPT",
        "LATCPT",
        "SINTRAMCPT",
        "SINTRAMSUPERCPT",
        "BITRAMCPT",
        "AUGPROSIMPCPT",
        "OTHERRECONTECHCPT",
        "REVRECBREASTCPT",
        "FATGRAFTCPT",
        "ADJTISTRANSCPT",
        "MASTOCPT",
        "BREASTREDCPT",
        "NPWTCPT",
    ),
}

# Flatten the groups into one list, normalised to upper case.
FEAT_ORDER = [feat.upper() for group in _FEAT_GROUPS.values() for feat in group]

Check that FEAT_ORDER lists every feature in the data exactly once (columns named `FEATURE_suffix` are collapsed to `FEATURE`).

In [None]:
def _base_feature_name(col):
    """Map a column of X_test to the FEAT_ORDER entry it belongs to.

    Columns without "_" are kept as-is. Columns named "<FEATURE>_<suffix>"
    (presumably dummy-encoded levels) collapse to "<FEATURE>". ETHNICITY_* columns
    are kept whole because FEAT_ORDER lists ETHNICITY_HISPANIC as-is.
    """
    head, sep, _ = col.partition("_")
    if not sep or head == "ETHNICITY":
        return col
    return head


# Only the column names are needed, so no rows are read.
all_cols = {_base_feature_name(col) for col in OUTCOME_DICT["med"]["X_test"].columns}
feat_order_set = set(FEAT_ORDER)
# Set equality alone cannot detect a duplicated entry in FEAT_ORDER.
assert len(feat_order_set) == len(FEAT_ORDER), "FEAT_ORDER contains duplicate entries"
assert feat_order_set == all_cols, (
    f"Missing from FEAT_ORDER: {sorted(all_cols - feat_order_set)}; "
    f"in FEAT_ORDER but not in data: {sorted(feat_order_set - all_cols)}"
)

## RUN SHAP

Build the SHAP jobs (one per outcome and selected model), then run them in parallel.

In [None]:
# ================================> Run configuration
# Models to explain. None explains every model in MODEL_DICT; a set restricts the run.
# Defaults to the stacked model only, matching the previously hard-coded filter.
MODELS_TO_RUN = {"stack"}

# Set True to run the workflow end-to-end on a 25% stratified subset of train/test
# (a quick pipeline check); False uses the full splits.
SUBSAMPLE_FOR_TESTING = False

# Explanation subsets drawn from X_test (~110k rows): key -> (file name, test_size).
EXPLAIN_SUBSETS = {
    "base": ("base.parquet", 0.1),  # ~10k rows: LR/XGB/LGBM explainers are ~fast
    "nn": ("kernel.parquet", 0.05),  # ~5k rows: NN KernelExplainer() is slow
    "stack": ("kernel_stack.parquet", 0.01),  # ~1k rows: stack KernelExplainer() very slow
}

# Per model: (fraction of X_train used as SHAP background, EXPLAIN_SUBSETS key).
# X_train has ~515k patients, so 0.01 -> ~5000 rows and 0.0001 -> ~50 rows.
MODEL_SHAP_CONFIG = {
    "lr": (0.01, "base"),
    "xgb": (0.01, "base"),
    "lgbm": (0.01, "base"),
    "nn": (0.0001, "nn"),
    "stack": (0.0001, "stack"),
}

# Clear old result tables once, before any job is queued (jobs only run in the next
# cell). NOTE(review): this removes the whole directory, which may also hold results
# for models outside MODELS_TO_RUN. Confirm that is intended for partial re-runs.
result_path = BASE_PATH / "results" / "tables" / "SHAP"
if result_path.exists():
    rmtree(result_path)

jobs = []
for outcome_name, outcome_data in OUTCOME_DICT.items():
    save_dir = BASE_PATH / "data" / "SHAP" / outcome_name
    X_train, y_train = outcome_data["X_train"], outcome_data["y_train"]
    X_test, y_test = outcome_data["X_test"], outcome_data["y_test"]
    if SUBSAMPLE_FOR_TESTING:
        X_train, _, y_train, _ = train_test_split(
            X_train, y_train, stratify=y_train, random_state=SEED, train_size=0.25
        )
        _, X_test, _, y_test = train_test_split(
            X_test, y_test, stratify=y_test, random_state=SEED, test_size=0.25
        )

    # ================================> Save stratified subsets of X_test for explanation
    explain_save_path = save_dir / "explain"
    if explain_save_path.exists():
        rmtree(explain_save_path)
    # Also creates save_dir, where the background files are written below.
    explain_save_path.mkdir(exist_ok=True, parents=True)
    explain_paths = {}
    for subset_name, (file_name, test_size) in EXPLAIN_SUBSETS.items():
        _, X_explain, _, _ = train_test_split(
            X_test, y_test, stratify=y_test, test_size=test_size, random_state=SEED, shuffle=True
        )
        explain_paths[subset_name] = explain_save_path / file_name
        X_explain.to_parquet(explain_paths[subset_name])
        print(f"[{outcome_name}] {subset_name} explain size: {len(X_explain)}")

    for model_name, model in MODEL_DICT[outcome_name].items():
        if MODELS_TO_RUN is not None and model_name not in MODELS_TO_RUN:
            continue
        if model_name not in MODEL_SHAP_CONFIG:
            raise ValueError("Model not recognized")
        background_frac, subset_name = MODEL_SHAP_CONFIG[model_name]

        # ================================> Split X_train for background (per model)
        # train_size (not test_size) is kept so the sampled rows match earlier runs.
        X_background, _, _, _ = train_test_split(
            X_train,
            y_train,
            stratify=y_train,
            train_size=background_frac,
            random_state=SEED,
            shuffle=True,
        )
        background_save_path = save_dir / f"{model_name}_background.parquet"
        if background_save_path.exists():
            background_save_path.unlink()
        X_background.to_parquet(background_save_path)
        print(f"{outcome_name}/{model_name} background size: {len(X_background)}")

        # ================================> Queue the SHAP job
        log_path = BASE_PATH / "shap_logs" / model_name / f"{outcome_name}.log"
        if log_path.exists():
            log_path.unlink()
        # Make sure the log directory exists before the worker writes to it.
        log_path.parent.mkdir(parents=True, exist_ok=True)
        jobs.append(
            delayed(get_shap_single_model)(
                model=model,
                model_name=model_name,
                feat_order=FEAT_ORDER,
                outcome_name=outcome_name,
                explanation_path=explain_paths[subset_name],
                background_path=background_save_path,
                log_path=log_path,
                result_path=result_path,
            )
        )

In [None]:
# Run the queued SHAP jobs in loky worker processes: one per job, capped at 25.
# joblib rejects n_jobs=0, so the call is skipped when nothing was queued.
if jobs:
    print("=== Starting jobs on CPU ===")
    Parallel(n_jobs=min(25, len(jobs)), backend="loky")(jobs)
else:
    print("No SHAP jobs queued; nothing to run.")