In [1]:
import os

os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

import matplotlib
matplotlib.use("Agg")


In [2]:
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import joblib
import shap
import matplotlib.pyplot as plt

from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import train_test_split

os.makedirs("../figures", exist_ok=True)
os.makedirs("../models", exist_ok=True)

print("Imports loaded")
print("shap:", shap.__version__)

Imports loaded
shap: 0.49.1


In [4]:
df = pd.read_csv("../data/stroke_clean.csv")
print("Loaded:", df.shape)

if "id" in df.columns:
    df = df.drop(columns=["id"])
    

df.head()

Loaded: (5110, 13)


Unnamed: 0,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke,age_group
0,Male,67.0,0,1,Yes,Private,Urban,228.69,36.6,formerly smoked,1,Senior
1,Female,61.0,0,0,Yes,Self-employed,Rural,202.21,28.1,never smoked,1,Adult
2,Male,80.0,0,1,Yes,Private,Rural,105.92,32.5,never smoked,1,Senior
3,Female,49.0,0,0,Yes,Private,Urban,171.23,34.4,smokes,1,Adult
4,Female,79.0,1,0,Yes,Self-employed,Rural,174.12,24.0,never smoked,1,Senior


In [5]:
TARGET = "stroke"
X = df.drop(columns=[TARGET])
y = df[TARGET]

X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    stratify=y,
    random_state=42
)

print("Train:", X_train.shape, "Test:", X_test.shape)


Train: (4088, 11) Test: (1022, 11)


In [6]:
#load catboost

model_cb = CatBoostClassifier()
model_cb.load_model("../models/catboost_baseline.cbm")
print("CatBoost model loaded")

cb_feature_names = joblib.load("../models/catboost_feature_names.pkl")
cb_cat_cols = joblib.load("../models/catboost_categorical_cols.pkl")

# Align (important)
X_train_cb = X_train.reindex(columns=cb_feature_names, fill_value=np.nan)
X_test_cb = X_test.reindex(columns=cb_feature_names, fill_value=np.nan)

cb_cat_idx = [cb_feature_names.index(c) for c in cb_cat_cols if c in cb_feature_names]

# Fill categorical NaNs as strings for Pool safety
for c in cb_cat_cols:
    if c in X_train_cb.columns:
        X_train_cb[c] = X_train_cb[c].fillna("Unknown").astype(str)
    if c in X_test_cb.columns:
        X_test_cb[c] = X_test_cb[c].fillna("Unknown").astype(str)

# Fill numeric NaNs
for c in X_train_cb.columns:
    if c not in cb_cat_cols:
        X_train_cb[c] = pd.to_numeric(X_train_cb[c], errors="coerce").fillna(0)
        X_test_cb[c] = pd.to_numeric(X_test_cb[c], errors="coerce").fillna(0)

print("Data aligned for SHAP/CatBoost")


CatBoost model loaded
Data aligned for SHAP/CatBoost


In [7]:
#  Create SHAP explainer (CatBoost-safe)

explainer = shap.TreeExplainer(model_cb, feature_perturbation="tree_path_dependent")

# SHAP values for test set
shap_values = explainer.shap_values(X_test_cb)


if isinstance(shap_values, list):
    shap_values = shap_values[1]

print("SHAP values computed:", np.array(shap_values).shape)

SHAP values computed: (1022, 11)


In [8]:
# â€” Global SHAP Summary (save image)

plt.figure()
shap.summary_plot(shap_values, X_test_cb, show=False)
out_path = "../figures/shap_summary_catboost.png"
plt.savefig(out_path, dpi=300, bbox_inches="tight")
plt.close()
print("Saved:", out_path)

Saved: ../figures/shap_summary_catboost.png


In [9]:
#Global Feature Importance Bar (save image)

plt.figure()
shap.summary_plot(shap_values, X_test_cb, plot_type="bar", show=False)
out_path = "../figures/shap_feature_importance_bar_catboost.png"
plt.savefig(out_path, dpi=300, bbox_inches="tight")
plt.close()
print("Saved:", out_path)


Saved: ../figures/shap_feature_importance_bar_catboost.png


In [10]:
mean_abs = np.abs(shap_values).mean(axis=0)
shap_imp_df = pd.DataFrame({
    "feature": X_test_cb.columns,
    "mean_abs_shap": mean_abs
}).sort_values("mean_abs_shap", ascending=False)

csv_path = "../figures/shap_importance_catboost.csv"
shap_imp_df.to_csv(csv_path, index=False)
print("Saved:", csv_path)

shap_imp_df.head(15)

Saved: ../figures/shap_importance_catboost.csv


Unnamed: 0,feature,mean_abs_shap
1,age,1.541575
8,bmi,0.448491
10,age_group,0.343986
0,gender,0.162857
7,avg_glucose_level,0.155506
9,smoking_status,0.113021
5,work_type,0.101877
6,Residence_type,0.068945
4,ever_married,0.061587
2,hypertension,0.058777


In [11]:
idx = 10
row = X_test_cb.iloc[idx:idx+1]

# shap.Explanation object (newer shap versions)
base_value = explainer.expected_value
if isinstance(base_value, (list, np.ndarray)):
    # class-1 base value if list
    base_value = base_value[1] if len(np.array(base_value).shape) > 0 else base_value

exp = shap.Explanation(
    values=shap_values[idx],
    base_values=base_value,
    data=row.iloc[0].values,
    feature_names=X_test_cb.columns
)

plt.figure()
shap.plots.waterfall(exp, max_display=12, show=False)
out_path = f"../figures/shap_waterfall_catboost_{idx}.png"
plt.savefig(out_path, dpi=300, bbox_inches="tight")
plt.close()
print("Saved:", out_path)

Saved: ../figures/shap_waterfall_catboost_10.png


In [12]:
meta = {
    "local_index_used": idx,
    "note": "Index refers to X_test after stratified split random_state=42"
}
joblib.dump(meta, "../figures/shap_local_meta.pkl")
print("Saved: ../figures/shap_local_meta.pkl")
meta

Saved: ../figures/shap_local_meta.pkl


{'local_index_used': 10,
 'note': 'Index refers to X_test after stratified split random_state=42'}

In [13]:
print("Figures:", os.listdir("../figures"))

Figures: ['.ipynb_checkpoints', 'catboost_feature_importance.csv', 'confusion_matrix_catboost.csv', 'confusion_matrix_hybrid.csv', 'confusion_matrix_mlp.csv', 'hybrid_final_metrics.csv', 'hybrid_threshold_results.csv', 'hybrid_weight_results.csv', 'mlp_metrics.csv', 'shap_bar.png', 'shap_feature_importance_bar_catboost.png', 'shap_importance_catboost.csv', 'shap_local_meta.pkl', 'shap_summary.png', 'shap_summary_catboost.png', 'shap_waterfall_0.png', 'shap_waterfall_10.png', 'shap_waterfall_50.png', 'shap_waterfall_catboost_10.png']
