# In [2]:
import shap
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import joblib

# Load the fitted pipeline that was saved to disk.
# NOTE: joblib.load unpickles arbitrary Python objects -- only load model files you trust.
full_pipeline = joblib.load("model/grad_boost_model_0923_1534.joblib")  # adjust to your filename


# Timestamp (MMDD_HHMM) embedded in every output file name written below.
timestamp = datetime.now().strftime('%m%d_%H%M')

# 1) Extract fitted preprocessor and model, transform test set
pre = full_pipeline.named_steps['preprocessor']
gbc = full_pipeline.named_steps['classifier']

# Transform X_test with the fitted preprocessor
X_test_trans = pre.transform(X_test)

# The preprocessor may hand back a scipy sparse matrix (e.g. when one-hot
# columns dominate). Row slicing a sparse matrix yields a 1xN matrix rather
# than a 1-D vector, and the SHAP plotting code below indexes the features
# like a dense array, so densify once here.
# NOTE(review): this costs n_rows * n_features memory -- subsample X_test
# first if the transformed test set is very large.
if hasattr(X_test_trans, "toarray"):
    X_test_trans = X_test_trans.toarray()

# Resolve one display name per transformed column.
# Strategy, in order:
#   1. ask the fitted preprocessor directly;
#   2. rebuild the names by hand from its first two transformers
#      (assumes transformer 0 is numeric and transformer 1 is a pipeline
#      with a step called 'onehot' -- TODO confirm against the training code);
#   3. fall back to generic f0, f1, ... names.
try:
    feature_names = pre.get_feature_names_out()
except Exception:
    # Typically an older sklearn without get_feature_names_out on the
    # preprocessor, or a transformer that cannot report its output names.
    try:
        num_names = pre.transformers_[0][2]
        ohe = pre.transformers_[1][1].named_steps['onehot']
        cat_cols = pre.transformers_[1][2]
        # Older sklearn exposes the encoder names via get_feature_names only;
        # calling get_feature_names_out there would just raise again.
        ohe_namer = getattr(ohe, "get_feature_names_out", None) or ohe.get_feature_names
        ohe_names = ohe_namer(cat_cols)
        feature_names = np.array(list(num_names) + list(ohe_names))
    except Exception:
        feature_names = np.array([f"f{i}" for i in range(X_test_trans.shape[1])])

# Names that do not line up with the matrix columns would mislabel every plot,
# so prefer generic names over wrong ones.
if len(feature_names) != X_test_trans.shape[1]:
    print(
        f"[SHAP] Feature-name count ({len(feature_names)}) does not match "
        f"transformed columns ({X_test_trans.shape[1]}); using generic names."
    )
    feature_names = np.array([f"f{i}" for i in range(X_test_trans.shape[1])])

print(f"[SHAP] Using {len(feature_names)} transformed features.")

# 2) Build a SHAP explainer
# TreeExplainer is tried first on the fitted classifier. If anything in that
# path raises, fall back to the model-agnostic shap.Explainer wrapped around
# the classifier's positive-class probability.
# Outputs of this section:
#   shap_values_pos -- 2-D array (n_samples, n_features) for the positive class
#   expected_value  -- scalar base value matching shap_values_pos
explainer = None
shap_values = None

try:
    explainer = shap.TreeExplainer(gbc)
    shap_values = explainer.shap_values(X_test_trans)
    # Depending on the SHAP version and model, shap_values may be:
    #  - a list of per-class arrays          -> take the positive class (index 1)
    #  - a 3-D array (samples, features, k)  -> take output 1 (or 0 if k == 1)
    #  - a 2-D array (samples, features)     -> already a single output
    if isinstance(shap_values, list):
        shap_values_pos = shap_values[1]
    elif np.ndim(shap_values) == 3:
        shap_values_pos = shap_values[:, :, min(1, shap_values.shape[2] - 1)]
    else:
        shap_values_pos = shap_values
    # expected_value can be a scalar, a list, or an ndarray with one entry per
    # output; reduce it to a plain float so the force plots get a scalar base.
    base = np.ravel(explainer.expected_value)
    expected_value = float(base[1] if base.size > 1 else base[0])
except Exception as e:
    print("[SHAP] TreeExplainer failed, falling back to model-agnostic Explainer. Reason:", e)

    # Positive-class probability from the classifier itself; the input is
    # already preprocessed, so the full pipeline must not be used here.
    def f(M):
        return gbc.predict_proba(M)[:, 1]

    background = shap.sample(X_test_trans, 100, random_state=42)  # small background for speed
    explainer = shap.Explainer(f, background)
    # Explain once and reuse the result for both the values and the base value
    # (the first row's base value, as before) instead of re-running the explainer.
    explanation = explainer(X_test_trans)
    shap_values_pos = explanation.values
    expected_value = float(np.ravel(explanation.base_values)[0])

# 3) GLOBAL: summary (beeswarm) and bar plots
# Both figures come from shap.summary_plot on the same inputs and differ only
# in the optional plot_type keyword, so they are rendered in one loop.
print("[SHAP] Rendering global plots...")
beeswarm_path = os.path.join(GRAPHS_DIR, f"shap_beeswarm_{timestamp}.png")
bar_path = os.path.join(GRAPHS_DIR, f"shap_importance_bar_{timestamp}.png")

for out_path, extra_kwargs, saved_msg in (
    (beeswarm_path, {}, f"[SHAP] Saved beeswarm to: {beeswarm_path}"),
    (bar_path, {"plot_type": "bar"}, f"[SHAP] Saved bar plot to: {bar_path}"),
):
    plt.figure()
    # show=False keeps the figure open so it can be written to disk below.
    shap.summary_plot(
        shap_values_pos,
        X_test_trans,
        feature_names=feature_names,
        show=False,
        **extra_kwargs,
    )
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close()
    print(saved_msg)

# 4) LOCAL: force plots for one positive and one negative example
print("[SHAP] Rendering local force plots...")
# Pipeline predictions on the test set. They are kept for inspection only:
# the two examples below are chosen by TRUE label, not by prediction.
y_prob_test = full_pipeline.predict_proba(X_test)[:, 1]
y_pred_test = (y_prob_test >= 0.5).astype(int)

# Pick one default (1) and one non-default (0): the first row of each class.
# np.asarray lets y_test be a pandas Series, a NumPy array, or a plain list.
y_true = np.asarray(y_test).ravel()
pos_rows = np.flatnonzero(y_true == 1)
neg_rows = np.flatnonzero(y_true == 0)

# Row 0 remains the fallback when a class is absent, but say so explicitly,
# because the resulting plot would then be labelled with the wrong class.
if pos_rows.size == 0:
    print("[SHAP] Warning: no label 1 in y_test; using row 0 as the positive example.")
if neg_rows.size == 0:
    print("[SHAP] Warning: no label 0 in y_test; using row 0 as the negative example.")

pos_idx = int(pos_rows[0]) if pos_rows.size else 0
neg_idx = int(neg_rows[0]) if neg_rows.size else 0

# SHAP expects a 1D vector for a single instance
pos_vals = shap_values_pos[pos_idx]
neg_vals = shap_values_pos[neg_idx]

# Write one HTML force plot per selected example.
# Attempt order: legacy shap.force_plot -> shap.plots.force on Explanation
# objects -> plain-text list of the strongest local contributions.

# Loading the JS bundle is best-effort; a failure here is ignored on purpose.
try:
    shap.initjs()
except Exception:
    pass

pos_html = os.path.join(GRAPHS_DIR, f"shap_force_pos_{timestamp}.html")
neg_html = os.path.join(GRAPHS_DIR, f"shap_force_neg_{timestamp}.html")

try:
    # Attempt 1: legacy call taking base value, SHAP row and feature row.
    pos_plot = shap.force_plot(
        expected_value,
        pos_vals,
        X_test_trans[pos_idx,:],
        feature_names=feature_names,
    )
    neg_plot = shap.force_plot(
        expected_value,
        neg_vals,
        X_test_trans[neg_idx,:],
        feature_names=feature_names,
    )
    shap.save_html(pos_html, pos_plot)
    shap.save_html(neg_html, neg_plot)
except Exception as legacy_err:
    # Attempt 2: wrap each row in an Explanation and use shap.plots.force.
    print("[SHAP] Old force_plot API failed, trying new API:", legacy_err)
    try:
        expl_pos = shap.Explanation(
            values=pos_vals,
            base_values=expected_value,
            data=X_test_trans[pos_idx,:],
            feature_names=feature_names,
        )
        expl_neg = shap.Explanation(
            values=neg_vals,
            base_values=expected_value,
            data=X_test_trans[neg_idx,:],
            feature_names=feature_names,
        )
        shap.save_html(pos_html, shap.plots.force(expl_pos))
        shap.save_html(neg_html, shap.plots.force(expl_neg))
    except Exception as modern_err:
        print("[SHAP] New API also failed; as a fallback, we’ll save top-K local contributions textually.", modern_err)

        # Attempt 3: no HTML at all, just print the largest contributions.
        def top_k_local(values, names, k=10):
            """Return (name, value) pairs for the k largest |values|, biggest first."""
            ranked = np.argsort(np.abs(values))[::-1]
            return [(names[j], float(values[j])) for j in ranked[:k]]

        print("Top local contributors (positive case):", top_k_local(pos_vals, feature_names))
        print("Top local contributors (negative case):", top_k_local(neg_vals, feature_names))
        # Nothing was written, so clear the paths to silence the messages below.
        pos_html = None
        neg_html = None

# Report only the files that are expected to exist.
if pos_html:
    print(f"[SHAP] Saved positive local force plot to: {pos_html}")
if neg_html:
    print(f"[SHAP] Saved negative local force plot to: {neg_html}")

# 5) QUICK TEXT SUMMARY FOR SLIDES
# Global importance of a feature = mean absolute SHAP value over all rows;
# list the ten largest, biggest first.
mean_abs = np.mean(np.abs(shap_values_pos), axis=0)
order = np.argsort(mean_abs)[::-1]

top10 = []
for col in order[:10]:
    top10.append((feature_names[col], float(mean_abs[col])))

print("\n[SHAP] Top-10 global drivers by mean(|SHAP|):")
for name, val in top10:
    print(f" - {name}: {val:.6f}")


# Notebook output: NameError: name 'train_xgboost' is not defined