In [None]:
# Requires SHAP: if it isn't installed in this kernel, run `pip install shap` first.

import shap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Default figure size and background grid for every plot in this notebook.
plt.rcParams.update({
    "figure.figsize": (9, 6),
    "axes.grid": True,
})

#Helpers
def _pick_estimator(m, idx=0):
    """Return the estimator for target idx if model is MultiOutput; else return m."""
    return m.estimators_[idx] if hasattr(m, "estimators_") else m

def _feat_names(X):
    return list(X.columns) if isinstance(X, pd.DataFrame) else feature_columns

# --- Choose which target to explain ---
TARGET_IDX = 0
TARGET_NAME = (
    target_columns[TARGET_IDX]
    if ("target_columns" in globals() and len(target_columns) > TARGET_IDX)
    else f"target_{TARGET_IDX}"
)
# TARGET_NAME is embedded in output file names. Replace characters that are
# not allowed in paths (e.g. "/" in "PM2.5/PM10") so savefig/to_csv don't fail.
FILE_TAG = "".join("_" if ch in '/\\:*?"<>|' else ch for ch in str(TARGET_NAME))

# --- Subsample the test set for quicker SHAP (column order is preserved) ---
# A NumPy test set is first wrapped as a DataFrame for readable feature labels.
# Both input types then get the same seeded random sample; taking the first
# rows instead would bias the explanation if the test data are ordered.
X_pool = (
    X_test
    if isinstance(X_test, pd.DataFrame)
    else pd.DataFrame(X_test, columns=feature_columns)
)
X_slice = X_pool.sample(min(300, len(X_pool)), random_state=42)
FEATURE_NAMES = _feat_names(X_slice)

# --- Build the explainer for the selected target's model and compute SHAP values ---
rf_est = _pick_estimator(model, TARGET_IDX)
explainer = shap.TreeExplainer(rf_est)
raw_shap = explainer.shap_values(X_slice)

# A single estimator that predicts several outputs yields one attribution
# matrix per output. Depending on the SHAP version, this is either a list of
# (n_samples, n_features) arrays or one (n_samples, n_features, n_outputs)
# array. Reduce it to the selected target's 2-D matrix so that the plots and
# the table below always receive (n_samples, n_features).
# NOTE(review): assumes the extra axis indexes regression targets; for a
# classifier it would index classes instead, so confirm for your model.
if isinstance(raw_shap, list):
    out_idx = TARGET_IDX if len(raw_shap) > 1 else 0
    shap_values = np.asarray(raw_shap[out_idx])
elif np.ndim(raw_shap) == 3:
    out_idx = TARGET_IDX if raw_shap.shape[2] > 1 else 0
    shap_values = raw_shap[:, :, out_idx]
else:
    shap_values = raw_shap

# --- Figure 1: global importance (bar) ---
shap.summary_plot(
    shap_values,
    X_slice,
    feature_names=FEATURE_NAMES,
    plot_type="bar",
    show=False,
)
plt.title(f"Global Feature Importance (SHAP) – {TARGET_NAME}")
plt.tight_layout()
plt.savefig(f"shap_global_bar__{FILE_TAG}.png", dpi=160, bbox_inches="tight")
plt.show()

# --- Figure 2: impact + direction (beeswarm) ---
shap.summary_plot(
    shap_values,
    X_slice,
    feature_names=FEATURE_NAMES,
    show=False,
)
plt.title(f"SHAP Summary (Impact & Direction) – {TARGET_NAME}")
plt.tight_layout()
plt.savefig(f"shap_beeswarm__{FILE_TAG}.png", dpi=160, bbox_inches="tight")
plt.show()

# --- Optional: top-10 table of mean |SHAP| per feature for the write-up ---
# shap_values is (n_samples, n_features) here, so averaging over axis 0
# yields one importance value per feature.
abs_mean = np.abs(shap_values).mean(axis=0)
top10 = (
    pd.DataFrame({"feature": FEATURE_NAMES, "mean_abs_SHAP": abs_mean})
    .sort_values("mean_abs_SHAP", ascending=False)
    .head(10)
    .reset_index(drop=True)
)
display(top10)
top10.to_csv(f"shap_top10__{FILE_TAG}.csv", index=False)
