# ------------------------------------------------------------
# Explains CATE predictions from a CausalForestDML using a fast tree surrogate (GradientBoostingRegressor) + SHAP.
# ------------------------------------------------------------

In [1]:
import joblib
import numpy as np
import pandas as pd
import shap
from sklearn.ensemble import GradientBoostingRegressor
import matplotlib.pyplot as plt
import os

In [3]:
# 1) Load model and test data
model = joblib.load("model/causal_forest_dml_model.pkl")
test = pd.read_csv("data/test_data.csv")

# 2) Split to features X (same columns used in training), outcome Y, treatment T
X_cols = [c for c in test.columns if c not in ("Y", "T")]
X_test = test[X_cols].copy()

# 3) Get CATE predictions from the fitted causal forest (target for surrogate)
cate_pred = model.effect(X_test.values).squeeze()

# 4) Fit a tree surrogate to learn f_surrogate(X) ~ cate_pred
surrogate = GradientBoostingRegressor(random_state=0)
surrogate.fit(X_test, cate_pred)

# 5) SHAP on the surrogate
explainer = shap.TreeExplainer(surrogate)
shap_values = explainer.shap_values(X_test)

# 6) Summary plots
os.makedirs("plots", exist_ok=True)

plt.figure()
shap.summary_plot(shap_values, X_test, feature_names=X_cols, show=False)
plt.tight_layout()
plt.savefig("plots/shap_summary_beeswarm_surrogate.png", dpi=160)
plt.close()

plt.figure()
shap.summary_plot(shap_values, X_test, feature_names=X_cols, plot_type="bar", show=False)
plt.tight_layout()
plt.savefig("plots/shap_summary_bar_surrogate.png", dpi=160)
plt.close()

# 7) dependence plot for a top feature
top_feature = X_cols[np.argsort(np.abs(shap_values).mean(0))[::-1][0]]
plt.figure()
shap.dependence_plot(top_feature, shap_values, X_test, show=False)
plt.tight_layout()
plt.savefig(f"plots/shap_dependence_{top_feature}.png", dpi=160)
plt.close()

print("Saved SHAP plots:\n - plots/shap_summary_beeswarm_surrogate.png\n - plots/shap_summary_bar_surrogate.png\n - plots/shap_dependence_<top_feature>.png")


Saved SHAP plots:
 - plots/shap_summary_beeswarm_surrogate.png
 - plots/shap_summary_bar_surrogate.png
 - plots/shap_dependence_<top_feature>.png


<Figure size 640x480 with 0 Axes>