In [None]:
# shap_local_targets_xgb.py
from pathlib import Path
import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
import shap
import matplotlib.pyplot as plt

import os

# --- Paths and run configuration ---------------------------------------------
# The project root is taken from the BASE_DIR environment variable. Fail early
# with a readable message instead of the opaque TypeError that Path(None)
# raises when the variable is not set.
_base_dir_env = os.getenv("BASE_DIR")
if _base_dir_env is None:
    raise RuntimeError(
        "Environment variable BASE_DIR is not set; it must point to the "
        "folder that contains 'Train_Final2'."
    )

DIR = Path(_base_dir_env)
BASE_DIR = DIR / "Train_Final2"
OUT_DIR  = BASE_DIR / "ml_runs_pu_ensemble_xgb"
SETTING  = "pos1_bg2"  # sub-folder of OUT_DIR holding the model and predictions

# Input table and the text file listing the feature column names.
DATA_TABLE = BASE_DIR / "grid_500m_ml_table.parquet"
FEATURE_COLS_TXT = BASE_DIR / "feature_cols.txt"

# All explanation artifacts are written here; parents=True also creates any
# missing intermediate directory instead of raising FileNotFoundError.
XAI_DIR = OUT_DIR / "xai_outputs"
XAI_DIR.mkdir(parents=True, exist_ok=True)

TOP_K = 10              # number of highest-scoring rows to explain
SAVE_FORCE_HTML = True  # also save the interactive force plot as HTML


# Load the full feature table.
df = pd.read_parquet(DATA_TABLE)

# Feature names are taken from the first column of the headerless text file.
requested_cols = pd.read_csv(FEATURE_COLS_TXT, header=None)[0].tolist()

# Keep only the listed features that exist in the table and are numeric,
# preserving the order given in the file.
feature_cols = [
    c for c in requested_cols
    if c in df.columns and pd.api.types.is_numeric_dtype(df[c])
]

# Dropping features silently would make the matrix differ from what the model
# was trained on without any trace, so report every skipped name.
_kept_cols = set(feature_cols)
dropped_cols = [c for c in requested_cols if c not in _kept_cols]
if dropped_cols:
    print(
        f"WARNING: {len(dropped_cols)} listed feature(s) are missing from the "
        f"table or non-numeric and were skipped: {dropped_cols}"
    )

# Feature matrix as float32.
X = df[feature_cols].astype(np.float32)


# Artifacts of the selected run live under OUT_DIR / SETTING.
run_dir = OUT_DIR / SETTING

# Fitted model as persisted with joblib. NOTE: joblib.load unpickles the file,
# so only load artifacts from a trusted location.
model = joblib.load(run_dir / "best_model.joblib")
# The underlying booster object is what the contribution prediction uses.
booster = model.get_booster()

# Per-row prediction scores stored next to the model.
pred_mean = np.load(run_dir / "pred_mean.npy")  # shape (N,)


# pred_mean is indexed positionally against df below; a length mismatch would
# either raise an obscure IndexError or silently pick the wrong rows.
if pred_mean.shape != (len(df),):
    raise ValueError(
        f"pred_mean has shape {pred_mean.shape}, expected ({len(df)},) to "
        "match the rows of the data table."
    )

# Positional indices of the TOP_K highest scores, best first. np.argsort puts
# NaN at the end, so NaN scores are removed first; otherwise they would be
# picked as the "top" targets. With no NaN present the result is unchanged.
_order = np.argsort(pred_mean)
_order = _order[~np.isnan(pred_mean[_order])]
top_idx = _order[-TOP_K:][::-1]  # indices of top-K
X_top = X.iloc[top_idx]


# Identification columns are optional: keep whichever ones the table has.
cols_meta = [c for c in ["cell_id", "cell_x", "cell_y"] if c in df.columns]
meta = df.iloc[top_idx][cols_meta].copy() if cols_meta else pd.DataFrame(index=top_idx)
meta["pred_mean"] = pred_mean[top_idx]
# Rows are written in ranking order (highest score first), without the index.
meta.to_csv(XAI_DIR / f"local_targets_top{TOP_K}_meta.csv", index=False)


# Wrap the selected rows in a DMatrix that carries the feature names.
dmat_top = xgb.DMatrix(X_top, feature_names=feature_cols)

# pred_contribs=True makes the booster return per-feature contributions rather
# than predictions: one row per sample, one column per feature, plus a final
# column that XGBoost appends for the bias term.
contrib = booster.predict(dmat_top, pred_contribs=True)

# Split the per-feature contributions from the trailing bias column.
shap_vals, base_vals = contrib[:, :-1], contrib[:, -1]


# Write one waterfall PNG, and optionally one force-plot HTML, per selected
# row. `rank` is the position in the top-K ordering and `row_id` the row's
# positional index in the full table; both go into the file names.
for rank, (row_id, phi, phi0) in enumerate(zip(top_idx, shap_vals, base_vals)):
    row = X_top.iloc[rank]

    explanation = shap.Explanation(
        values=phi,
        base_values=phi0,
        data=row.to_numpy(),
        feature_names=feature_cols,
    )

    # Static waterfall figure; show=False keeps the figure open for saving.
    shap.plots.waterfall(explanation, max_display=20, show=False)
    plt.tight_layout()
    png_path = XAI_DIR / f"target_{rank:02d}_idx{row_id}_waterfall.png"
    plt.savefig(png_path, dpi=300)
    # Close the figure so figures do not accumulate across iterations.
    plt.close()

    if not SAVE_FORCE_HTML:
        continue

    # Interactive force plot, saved as HTML.
    force = shap.force_plot(
        phi0,
        phi,
        row,
        feature_names=feature_cols,
        matplotlib=False,
    )
    html_path = XAI_DIR / f"target_{rank:02d}_idx{row_id}_force.html"
    shap.save_html(str(html_path), force)


# Final status report: output folder, meta CSV name and the kinds of files
# written. A single print with a newline separator gives three output lines.
print(
    f"✅ DONE: Local SHAP saved to: {XAI_DIR}",
    f"- meta CSV: local_targets_top{TOP_K}_meta.csv",
    "- waterfall PNGs + (optional) force HTMLs",
    sep="\n",
)


✅ DONE: Local SHAP saved to: C:\Users\Phong\Desktop\GIS\Project 2\Train_Final2\ml_runs_pu_ensemble_xgb\xai_outputs
- meta CSV: local_targets_top10_meta.csv
- waterfall PNGs + (optional) force HTMLs
