## 0）Setup and Styling

In [1]:
import os, json, numpy as np, pandas as pd, matplotlib.pyplot as plt, matplotlib as mpl
from pathlib import Path

# Directories
SAVE_DIR = "./runs_demo"
FIG_DIR  = os.path.join(SAVE_DIR, "stage6_explain")
DATA_PATH = "./data/global_clean.csv"   # For rebuilding X_df/feat_cols
Path(FIG_DIR).mkdir(parents=True, exist_ok=True)

# Paper-ready style
mpl.rcParams.update({
    "figure.dpi": 160, "savefig.dpi": 320,
    "font.size": 11, "axes.titlesize": 13,
    "axes.labelsize": 11, "xtick.labelsize": 10, "ytick.labelsize": 10,
    "axes.spines.top": False, "axes.spines.right": False,
    "axes.grid": False, "legend.frameon": False,
    "pdf.fonttype": 42, "ps.fonttype": 42,   
})

# Colors
COL_XGB, COL_LGB = "#4E79A7", "#F28E2B"
COL_LINE = "#9aa0a6"

def safe_savefig(fig, out_base):
    for ext in ("png","pdf","svg"):
        fig.savefig(f"{out_base}.{ext}", bbox_inches="tight", dpi=320)

print("[Init] SAVE_DIR:", SAVE_DIR, "| FIG_DIR:", FIG_DIR)

[Init] SAVE_DIR: ./runs_demo | FIG_DIR: ./runs_demo/stage6_explain1


## 1) Rebuild X_df / feat_cols / df_meta (from CSV)

In [2]:
drop_cols_hard = [
    "Unnamed: 0","ij_grid","i_grid","j_grid","x_proj","y_proj","survey_id","date","rgi_id",
    "consensus_ice_thickness","millan_ice_thickness","itslive_v","hugonnet_dhdt",
    "glacier_length","glacier_area_km2","glacier_oggm_volume",
    "glacier_min_elev","glacier_max_elev","glacier_median_elev","glacier_outline_year",
    "lin_mb_above_z","oggm_mb_above_z","thickness_uncertainty",
]
TARGET_COL = "thickness"

df_raw = pd.read_csv(DATA_PATH, low_memory=False)
df = df_raw.drop(columns=[c for c in drop_cols_hard if c in df_raw.columns], errors="ignore").copy()
assert TARGET_COL in df.columns, f"Target '{TARGET_COL}' missing."

RGI_COL   = "RGI" if "RGI" in df.columns else ("region" if "region" in df.columns else None)
ID_COL    = "glacier_id" if "glacier_id" in df.columns else None
NAME_COL  = "glacier_name" if "glacier_name" in df.columns else None

exclude_cols = {TARGET_COL, "lat", "lon", "year", "date"}
for c in (RGI_COL, ID_COL, NAME_COL):
    if c: exclude_cols.add(c)

num_df = df.select_dtypes(include=[np.number])

std0_cols = [c for c in num_df.columns if num_df[c].nunique(dropna=True) <= 1]
num_df = num_df.drop(columns=std0_cols, errors="ignore")

feat_cols = [c for c in num_df.columns if c not in exclude_cols]
X_df = df[feat_cols].astype("float32").replace([np.inf, -np.inf], np.nan)
y_sr = df[TARGET_COL].astype("float32")

row_mask = ~(X_df.isna().any(axis=1) | y_sr.isna())
X_df, y_sr, df = X_df.loc[row_mask], y_sr.loc[row_mask], df.loc[row_mask].reset_index(drop=True)

meta_path = os.path.join(SAVE_DIR, "df_meta.csv")
if os.path.exists(meta_path):
    df_meta = pd.read_csv(meta_path)
    if len(df_meta) != len(df):
        df_meta = df[[c for c in ["glacier_name","RGI","glacier_id"] if c in df.columns]].copy()
else:
    df_meta = df[[c for c in ["glacier_name","RGI","glacier_id"] if c in df.columns]].copy()

print("[Data] rows =", len(df), "| n_features =", len(feat_cols))
print("[Data] feat preview:", feat_cols[:12])
print("[Meta] cols:", list(df_meta.columns))

[Data] rows = 284558 | n_features = 18
[Data] feat preview: ['latitude', 'longitude', 'topo', 'topo_smoothed', 'slope', 'slope_factor', 'aspect', 'dis_from_border', 'catchment_area', 'millan_v', 'glacier_cenlon', 'glacier_cenlat']
[Meta] cols: ['RGI']


## 2) Utilities for importance

In [3]:
def load_imp_pct(csv_path):
    """Read (feature, importance) → Return (feature, pct), where pct sums to 100."""
    df = pd.read_csv(csv_path)
    df.columns = [c.lower().strip() for c in df.columns]
    s = df["importance"].sum()
    df["pct"] = (df["importance"] / s * 100.0) if s > 0 else 0.0
    return df[["feature","pct"]]

## 3) Importance

In [4]:
TOPK = 15
paths = {
    "builtin_xgb": os.path.join(FIG_DIR, "imp_builtin_xgboost.csv"),
    "builtin_lgb": os.path.join(FIG_DIR, "imp_builtin_lightgbm.csv"),
    "perm_xgb":    os.path.join(FIG_DIR, "imp_perm_xgboost.csv"),
    "perm_lgb":    os.path.join(FIG_DIR, "imp_perm_lightgbm.csv"),
}

# Read and convert to percentage
bx = load_imp_pct(paths["builtin_xgb"])
bl = load_imp_pct(paths["builtin_lgb"])
px = load_imp_pct(paths["perm_xgb"])
pl = load_imp_pct(paths["perm_lgb"])

def _wrap(s, width=20):
    s = str(s).replace("_", "·")
    if len(s) <= width: return s
    out, line = [], ""
    for ch in s:
        line += ch
        if len(line) >= width:
            out.append(line); line = ""
    if line: out.append(line)
    return "\n".join(out)

def plot_panel(ax, df, color, title, tag, topk=15):
    """Each panel independently: sorted descending, white background, no grid, independent y-axis categories & limits."""
    d = (df.sort_values("pct", ascending=False)
           .head(topk)
           .reset_index(drop=True))
    feats = d["feature"].tolist()
    vals  = d["pct"].to_numpy()

    # White background + clean axes
    ax.set_facecolor("white")
    for side in ("top","right"):
        ax.spines[side].set_visible(False)
    ax.grid(False)

    # Independent y-axis: each subplot sets yticks/labels and ylim individually
    y = np.arange(len(feats))
    bars = ax.barh(y, vals, color=color, height=0.64, align="center")
    ax.set_yticks(y)
    ax.set_yticklabels([_wrap(f) for f in feats], fontsize=10)
    ax.set_ylim(-0.5, len(feats)-0.5)
    ax.invert_yaxis()

    # x-axis (adjust to each subplot’s range)
    xmax = 1.06 * vals.max() if len(vals) else 1.0
    ax.set_xlim(0, xmax)
    ax.set_xlabel("Importance (%)")
    ax.set_title(f"{tag}. {title}", loc="center")

    # Add value annotations at the end of bars
    pad = 0.6 if xmax <= 10 else 0.8
    for b, v in zip(bars, vals):
        x = min(v + pad, xmax - pad*0.4)
        ax.text(x, b.get_y() + b.get_height()/2, f"{v:.2f}",
                ha="left", va="center", fontsize=9)

# 2x2 subplot without shared axes
fig_h = max(7.0, 0.45 * TOPK * 2)
fig, axes = plt.subplots(2, 2, figsize=(14.2, fig_h))
fig.set_facecolor("white")

plot_panel(axes[0,0], bx, COL_XGB, "Built-in Importance (XGB)", "a", TOPK)
plot_panel(axes[0,1], bl, COL_LGB, "Built-in Importance (LGB)", "b", TOPK)
plot_panel(axes[1,0], px, COL_XGB, "Permutation Importance (XGB)", "c", TOPK)
plot_panel(axes[1,1], pl, COL_LGB, "Permutation Importance (LGB)", "d", TOPK)

plt.tight_layout()
out_base = os.path.join(FIG_DIR, f"fig_importance_2x2_top{TOPK}_no_share_axes")
safe_savefig(fig, out_base)
plt.close(fig)
print("[Saved]", out_base)

[Saved] ./runs_demo/stage6_explain1/fig_importance_2x2_top15_no_share_axes


## 4)Importance Comparison

In [5]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# Colors
COL_XGB  = globals().get("COL_XGB",  "#4E79A7")
COL_LGB  = globals().get("COL_LGB",  "#F28E2B")
COL_LINE = globals().get("COL_LINE", "#bfc4cc")

# Canvas and axis length settings
FIG_W         = 16.0
TITLE_FSIZE   = 14
X_PAD_RATIO   = 0.12

def prep_diff(xgb_csv, lgb_csv, topk=15):
    a = load_imp_pct(xgb_csv).rename(columns={"pct": "XGB"})
    b = load_imp_pct(lgb_csv).rename(columns={"pct": "LGB"})
    m = a.merge(b, on="feature", how="outer").fillna(0.0)
    m["diff"] = (m["XGB"] - m["LGB"]).abs()
    return (m.sort_values("diff", ascending=False)
             .head(topk)
             .reset_index(drop=True))

TOPK = 15
built_df = prep_diff(paths["builtin_xgb"], paths["builtin_lgb"], TOPK)
perm_df  = prep_diff(paths["perm_xgb"],    paths["perm_lgb"],    TOPK)

def _format_axis(ax):
    ax.set_facecolor("white")
    for s in ("top","right"):
        ax.spines[s].set_visible(False)
    ax.grid(False)
    ax.xaxis.set_major_locator(mpl.ticker.MaxNLocator(nbins=7))
    ax.xaxis.set_major_formatter(mpl.ticker.FormatStrFormatter("%.0f"))

def draw_simple_dumbbell(ax, df, title, show_legend=False):
    _format_axis(ax)
    n = len(df); y = np.arange(n)

    lgb = df["LGB"].to_numpy()
    xgb = df["XGB"].to_numpy()

    # Gray line connections
    left  = np.minimum(lgb, xgb)
    right = np.maximum(lgb, xgb)
    ax.hlines(y, left, right, color=COL_LINE, linewidth=2.0, zorder=1)

    # Dots
    p_lgb = ax.scatter(lgb, y, s=42, color=COL_LGB, marker="o", zorder=2, label="LGB")
    p_xgb = ax.scatter(xgb, y, s=42, color=COL_XGB, marker="o", zorder=3, label="XGB")

    # y-axis
    ax.set_yticks(y)
    ax.set_yticklabels(df["feature"], fontsize=10)
    ax.invert_yaxis()

    # Extended x-axis (add X_PAD_RATIO padding to max value)
    xmax = float(np.max(np.r_[lgb, xgb])) if n else 1.0
    ax.set_xlim(0, xmax * (1.0 + X_PAD_RATIO))
    ax.set_xlabel("Importance (%)")

    # Centered title
    ax.set_title(title, loc="center", fontsize=TITLE_FSIZE, pad=6)

    if show_legend:
        ax.legend(loc="lower right", frameon=False, ncols=2, fontsize=9)

# 2×1 layout (vertical), no shared axes
fig_h = max(6.0, 0.45 * TOPK) * 2 + 0.8
fig, axes = plt.subplots(2, 1, figsize=(FIG_W, fig_h), sharex=False, sharey=False)
fig.set_facecolor("white")

draw_simple_dumbbell(axes[0], built_df, "a. Built-in Importance (sorted by abs diff)", show_legend=False)
draw_simple_dumbbell(axes[1], perm_df,  "b. Permutation Importance (sorted by abs diff)", show_legend=True)

plt.tight_layout()
out_base = os.path.join(FIG_DIR, "fig_dumbbell_delta_2x1_centered_wide")
safe_savefig(fig, out_base)
plt.close(fig)
print("[Saved]", out_base + ".png / .pdf / .svg")

[Saved] ./runs_demo/stage6_explain1/fig_dumbbell_delta_2x1_centered_wide.png / .pdf / .svg


## 5) SHAP Beeswarm

In [6]:
# Requires stage6_explain/xgb_shap_subset.npy & lgb_shap_subset.npy
sv_xgb = np.load(os.path.join(FIG_DIR, "xgb_shap_subset.npy"))
sv_lgb = np.load(os.path.join(FIG_DIR, "lgb_shap_subset.npy"))

# Ensure feature names align with column order
assert sv_xgb.shape[1] == sv_lgb.shape[1] == len(feat_cols), "SHAP shape != feature count"

# Sorting rule: average of mean(|SHAP|) from both models
imp_x = np.abs(sv_xgb).mean(0)
imp_l = np.abs(sv_lgb).mean(0)
order = np.argsort(-(imp_x + imp_l)/2)

X_sub = pd.DataFrame(X_df.to_numpy()[:sv_xgb.shape[0], :], columns=feat_cols)

import shap
fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)

plt.sca(axes[0])
shap.summary_plot(sv_xgb[:, order], X_sub.iloc[:, order], show=False, plot_size=None)
axes[0].set_title("a. SHAP Beeswarm (XGB; shared order & x-range)")
xlim_shared = axes[0].get_xlim()

plt.sca(axes[1])
shap.summary_plot(sv_lgb[:, order], X_sub.iloc[:, order], show=False, plot_size=None)
axes[1].set_title("b. SHAP Beeswarm (LGB; shared order & x-range)")
axes[1].set_xlim(xlim_shared)

plt.tight_layout()
out_base = os.path.join(FIG_DIR, "fig_shap_beeswarm_shared")
safe_savefig(fig, out_base)
plt.close(fig)
print("[Saved]", out_base)

[Saved] ./runs_demo/stage6_explain1/fig_shap_beeswarm_shared


## 6) SHAP Dependence

In [7]:
# Select Top-3 (same sorting as in Cell 5)
top3_idx = order[:3]
top3_names = [feat_cols[i] for i in top3_idx]

fig, axes = plt.subplots(2, 3, figsize=(16, 8))
cmap = plt.get_cmap("viridis")
norm = plt.Normalize(X_sub[top3_names[2]].min(), X_sub[top3_names[2]].max())  # Use 3rd feature for shared colorbar

for j, fid in enumerate(top3_idx):
    # XGB
    ax = axes[0, j]
    sc = ax.scatter(X_sub.iloc[:, fid], sv_xgb[:, fid], c=X_sub.iloc[:, top3_idx[2]], s=8, cmap=cmap, norm=norm)
    ax.set_xlabel(feat_cols[fid]); ax.set_ylabel(f"SHAP({feat_cols[fid]})")
    ax.set_title(f"a{j+1}. XGB — {feat_cols[fid]}")
    # LGB
    ax = axes[1, j]
    ax.scatter(X_sub.iloc[:, fid], sv_lgb[:, fid], c=X_sub.iloc[:, top3_idx[2]], s=8, cmap=cmap, norm=norm)
    ax.set_xlabel(feat_cols[fid]); ax.set_ylabel(f"SHAP({feat_cols[fid]})")
    ax.set_title(f"b{j+1}. LGB — {feat_cols[fid]}")

# Shared colorbar
cax = fig.add_axes([0.92, 0.15, 0.015, 0.7])
cb = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax)
cb.set_label(f"{feat_cols[top3_idx[2]]} (shared)")

plt.tight_layout(rect=[0,0,0.9,1])
out_base = os.path.join(FIG_DIR, "fig_shap_dependence_top3_sharedcbar")
safe_savefig(fig, out_base)
plt.close(fig)
print("[Saved]", out_base)

  plt.tight_layout(rect=[0,0,0.9,1])


[Saved] ./runs_demo/stage6_explain1/fig_shap_dependence_top3_sharedcbar


## 7) Local SHAP — Force-only

In [8]:
import os, re, json, warnings
import numpy as np
import pandas as pd
import shap
import matplotlib as mpl
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")

# ---------- Paths ----------
SAVE_DIR  = "./runs_demo"
FIG_DIR   = os.path.join(SAVE_DIR, "stage6_explain1")   # Directory for current cell's outputs
BLEND_DIR = os.path.join(SAVE_DIR, "stage4_best_blend") # Load pre-trained models if available
os.makedirs(FIG_DIR, exist_ok=True)

# ---------- Expect S1 in memory ----------
# Relies on variables df / X_df / y_sr already set in S1
assert "df" in globals() and "X_df" in globals() and "y_sr" in globals(), \
    "Please run S1 first to ensure df / X_df / y_sr are loaded into memory."

# ---------- Read Metadata and Select Target Sample ----------
meta_path = os.path.join(SAVE_DIR, "df_meta.csv")
if os.path.exists(meta_path):
    df_meta = pd.read_csv(meta_path)
    if 'index' in df_meta.columns:
        df_meta = df_meta.drop(columns=['index'])
    if len(df_meta) != len(X_df):
        df_meta = df[[c for c in ["glacier_name","RGI","glacier_id"] if c in df.columns]].reset_index(drop=True)
else:
    df_meta = df[[c for c in ["glacier_name","RGI","glacier_id"] if c in df.columns]].reset_index(drop=True)

ID_COL = "glacier_name" if "glacier_name" in df_meta.columns else ("RGI" if "RGI" in df_meta.columns else "glacier_id")

TARGET_KEY_OVERRIDE = None  # You can also specify a specific glacier name or region here
target_key = df_meta[ID_COL].iloc[0] if TARGET_KEY_OVERRIDE is None else TARGET_KEY_OVERRIDE
row_idx = int(df_meta.index[df_meta[ID_COL] == target_key][0])

row_df  = X_df.iloc[row_idx:row_idx+1]
row_np  = row_df.values[0]
feat_names = row_df.columns.tolist()
baseline_value = float(np.nanmean(y_sr))

# ---------- Read/Calculate SHAP (Robust) ----------
def _load_model_try(path):
    try:
        import joblib
        return joblib.load(path) if os.path.exists(path) else None
    except Exception as e:
        print("[WARN] Model loading failed:", e); return None

def _explain_tree(model, row_df):
    try:
        exp = shap.TreeExplainer(model)(row_df)
        base = float(np.asarray(exp.base_values).reshape(-1)[0])
        vals = np.asarray(exp.values).reshape(-1)
        dat  = np.asarray(exp.data).reshape(-1)
        return shap.Explanation(values=vals, base_values=base, data=dat,
                                feature_names=row_df.columns.tolist())
    except Exception as e:
        print("[WARN] SHAP calculation failed:", e); return None

sv_xgb = None; sv_lgb = None
mx = _load_model_try(os.path.join(BLEND_DIR, "xgb_final.pkl"))
ml = _load_model_try(os.path.join(BLEND_DIR, "lgb_final.pkl"))
if mx is not None: sv_xgb = _explain_tree(mx, row_df)
if ml is not None: sv_lgb = _explain_tree(ml, row_df)

# Fallback to saved .npy (first search default directory, then in current directory)
def _try_load_npy():
    for d in (os.path.join(SAVE_DIR, "stage6_explain"), FIG_DIR):
        npy_x = os.path.join(d, "xgb_shap_subset.npy")
        npy_l = os.path.join(d, "lgb_shap_subset.npy")
        if os.path.exists(npy_x) and os.path.exists(npy_l):
            arr_x = np.load(npy_x); arr_l = np.load(npy_l)
            if row_idx < arr_x.shape[0] and row_idx < arr_l.shape[0]:
                return arr_x[row_idx], arr_l[row_idx]
    return None, None

if sv_xgb is None or sv_lgb is None:
    axv, alv = _try_load_npy()
    if (sv_xgb is None) and (axv is not None):
        sv_xgb = shap.Explanation(values=axv, base_values=baseline_value,
                                  data=row_np, feature_names=feat_names)
    if (sv_lgb is None) and (alv is not None):
        sv_lgb = shap.Explanation(values=alv, base_values=baseline_value,
                                  data=row_np, feature_names=feat_names)

# Last fallback: zero vector
if sv_xgb is None: sv_xgb = shap.Explanation(np.zeros_like(row_np), baseline_value, row_np, feat_names)
if sv_lgb is None: sv_lgb = shap.Explanation(np.zeros_like(row_np), baseline_value, row_np, feat_names)

# ---------- Rename for Simplicity (Avoid Long Names) ----------
RENAME = {
    "dis_from_border":"dist_border", "catchment_area":"catch_area",
    "topo_smoothed":"topo_smooth", "slope_factor":"slope_f",
    "glacier_cenlon":"g_cenlon", "glacier_cenlat":"g_cenlat",
    "glacier_aar":"g_aar", "glacier_reference_mb":"ref_mb"
}
def _short(n): return RENAME.get(n, n).replace("__","_")
def _ren(sv): return shap.Explanation(values=sv.values, base_values=sv.base_values,
                                      data=sv.data, feature_names=[_short(n) for n in sv.feature_names])
sv_xgb, sv_lgb = _ren(sv_xgb), _ren(sv_lgb)

# ---------- Plot: No Overlap Force (Ultra Wide Canvas + Move Text Out of Way) ----------
mpl.rcParams.update({"font.family":"DejaVu Sans","font.size":11, "figure.dpi":150})

def _safe_save(fig, base):
    for ext in ("png","pdf","svg"):
        fig.savefig(f"{base}.{ext}", bbox_inches="tight", dpi=320)

def save_force_clean(sv, out_base, top_k=12, figsize=(19.0, 3.1)):
    """
    Force plot (ultimate version to prevent overlap):
      1) Keep only Top-K feature names
      2) Remove SHAP native 'higher'/'lower'/'f(x)'/'base value'
      3) Replot these texts above the figure to prevent overlap with colorbar
    """
    order = np.argsort(-np.abs(sv.values))[:top_k]
    vals  = np.asarray(sv.values)[order]
    names = [sv.feature_names[i] for i in order]
    base  = float(np.asarray(sv.base_values).reshape(-1)[0])
    fx    = base + float(vals.sum())

    fig = shap.plots.force(base, vals, features=None, feature_names=names,
                           matplotlib=True, show=False)
    fig.set_size_inches(*figsize)
    ax = fig.axes[0]

    # Top padding: move content down and leave enough space for the text at the top
    fig.subplots_adjust(left=0.05, right=0.995, bottom=0.32, top=0.72)

    # Remove overlapping native SHAP text
    for t in list(ax.texts):
        s = t.get_text().strip()
        if s in ("higher", "lower", "f(x)", "base value"):
            t.remove(); continue
        ss = s.replace(",", "")
        if re.fullmatch(r"-?\d+(\.\d+)?", ss) and t.get_fontsize() >= 14:
            t.remove()

    # Re-draw these texts outside the figure to avoid overlap with colorbar
    fig.text(0.20, 0.93, "higher", color="#ff2c6d", ha="center", va="bottom")
    fig.text(0.50, 0.91, "f(x)",    color="#787878", ha="center", va="bottom")
    fig.text(0.80, 0.93, "lower",  color="#1f77b4", ha="center", va="bottom")
    fig.text(0.50, 0.865, f"{fx:.2f}", ha="center", va="bottom",
             fontsize=18, fontweight="bold", color="black")

    fig.canvas.draw()
    _safe_save(fig, out_base)
    plt.close(fig)

# Save individual images
fx_path = os.path.join(FIG_DIR, "force_xgb_ultra")
fl_path = os.path.join(FIG_DIR, "force_lgb_ultra")
save_force_clean(sv_xgb, fx_path, top_k=12, figsize=(19.0, 3.1))
save_force_clean(sv_lgb, fl_path, top_k=12, figsize=(19.0, 3.1))

# ---------- Stack the Two Force Plots (Vertical Arrangement) ----------
import matplotlib.image as mpimg
safe_name = re.sub(r"[^A-Za-z0-9_.-]+","_", str(target_key))
stack_base = os.path.join(FIG_DIR, f"fig_local_force_stack_{ID_COL}_{safe_name}")

fig, axs = plt.subplots(2, 1, figsize=(19, 7.4))
axs[0].imshow(mpimg.imread(fx_path + ".png")); axs[0].set_axis_off(); axs[0].set_title("a. XGB Force (wide)")
axs[1].imshow(mpimg.imread(fl_path + ".png")); axs[1].set_axis_off(); axs[1].set_title("c. LGB Force (wide)")
fig.suptitle(f"Local explanation — {ID_COL}: {target_key}", y=0.98, fontsize=14)
plt.tight_layout(rect=[0,0,1,0.965])
_safe_save(fig, stack_base)
plt.close(fig)

print("[Saved] XGB force  ->", fx_path + ".png")
print("[Saved] LGB force  ->", fl_path + ".png")
print("[Saved] Force-wide stack ->", stack_base + ".png")

[Saved] XGB force  -> ./runs_demo/stage6_explain1/force_xgb_ultra.png
[Saved] LGB force  -> ./runs_demo/stage6_explain1/force_lgb_ultra.png
[Saved] Force-wide stack -> ./runs_demo/stage6_explain1/fig_local_force_stack_RGI_RGI01.png


## 8) Local SHAP — Waterfall

In [9]:
import os, re, numpy as np, shap, matplotlib as mpl, matplotlib.pyplot as plt, matplotlib.image as mpimg

SAVE_DIR = "./runs_demo"
FIG_DIR  = os.path.join(SAVE_DIR, "stage6_explain1")
os.makedirs(FIG_DIR, exist_ok=True)

# Reuse sv_xgb, sv_lgb, ID_COL, and target_key from Cell 7a (Run Cell 7a if these variables are not in memory)

# Uncomment the following block if you want this cell to run independently:
# try:
#     sv_xgb, sv_lgb
# except NameError:
#     raise RuntimeError("Please run Cell 7a first to generate sv_xgb / sv_lgb / ID_COL / target_key")

# ---------- Shared x-axis Waterfall Plot ----------
mpl.rcParams.update({"font.family":"DejaVu Sans","font.size":11})

WATERFALL_FIGSIZE = (7.6, 5.2)
WATERFALL_LEFT, WATERFALL_RIGHT, WATERFALL_TOP, WATERFALL_BOTTOM = 0.56, 0.98, 0.90, 0.18
TOPK_WATERFALL = 10

def _safe_save(fig, base):
    for ext in ("png","pdf","svg"):
        fig.savefig(f"{base}.{ext}", bbox_inches="tight", dpi=320)

def save_waterfall_shared(sv, out_base, shared_xlim=None, max_display=10):
    shap.plots.waterfall(sv, max_display=max_display, show=False)
    fig = plt.gcf(); ax = plt.gca()
    fig.set_size_inches(*WATERFALL_FIGSIZE)
    fig.subplots_adjust(left=WATERFALL_LEFT, right=WATERFALL_RIGHT,
                        top=WATERFALL_TOP, bottom=WATERFALL_BOTTOM)
    ax.xaxis.set_major_locator(mpl.ticker.MaxNLocator(nbins=6))
    ax.xaxis.set_major_formatter(mpl.ticker.FuncFormatter(lambda x, pos: f"{x:.0f}"))
    if shared_xlim is not None:
        ax.set_xlim(shared_xlim)
    fig.canvas.draw(); _safe_save(fig, out_base)
    xlim = ax.get_xlim(); plt.close(fig); return xlim

# Generate temporary plots to get xlim, then use the common range for both models
tmp_x = os.path.join(FIG_DIR, "_wf_tmp_xgb")
tmp_l = os.path.join(FIG_DIR, "_wf_tmp_lgb")
xlim_x = save_waterfall_shared(sv_xgb, tmp_x, max_display=TOPK_WATERFALL)
xlim_l = save_waterfall_shared(sv_lgb, tmp_l, max_display=TOPK_WATERFALL)
shared_xlim = (min(xlim_x[0], xlim_l[0]), max(xlim_x[1], xlim_l[1]))

wx_path = os.path.join(FIG_DIR, "waterfall_xgb_shared")
wl_path = os.path.join(FIG_DIR, "waterfall_lgb_shared")
save_waterfall_shared(sv_xgb, wx_path, shared_xlim=shared_xlim, max_display=TOPK_WATERFALL)
save_waterfall_shared(sv_lgb, wl_path, shared_xlim=shared_xlim, max_display=TOPK_WATERFALL)

# ---------- Compose: Side-by-Side (1x2) ----------
safe_name = re.sub(r"[^A-Za-z0-9_.-]+","_", str(target_key))
row_base = os.path.join(FIG_DIR, f"fig_local_waterfall_row_{ID_COL}_{safe_name}")

fig, axs = plt.subplots(1, 2, figsize=(15.6, 5.6))
axs[0].imshow(mpimg.imread(wx_path + ".png")); axs[0].set_axis_off(); axs[0].set_title("b. XGB Waterfall")
axs[1].imshow(mpimg.imread(wl_path + ".png")); axs[1].set_axis_off(); axs[1].set_title("d. LGB Waterfall")
fig.suptitle(f"Local explanation — {ID_COL}: {target_key}", y=0.98, fontsize=14)
plt.tight_layout(rect=[0,0,1,0.965]); _safe_save(fig, row_base); plt.close(fig)

print("[Saved] Waterfall-row ->", row_base + ".png")

[Saved] Waterfall-row -> ./runs_demo/stage6_explain1/fig_local_waterfall_row_RGI_RGI01.png
