# In [None]:
# -*- coding: utf-8 -*-
"""
SHAP dependence line plots (1x3) with robust filtering and bin-wise export.

Pipeline (English steps):
1) Load npz containing: exp_values (n_samples x n_features),
   shap_values_all (n_samples x n_features), feature_names (list/array).
2) Select target features by basename (e.g., 'Srad.tif'), map to column index.
3) Clean values: remove huge fill values, keep finite values only.
4) Robust outlier trimming: keep [1st, 99th] percentiles of X, apply same mask to SHAP.
5) Quantile binning (default 50 bins). For each valid bin, compute:
   - x statistics: the bin center (mean of x inside the bin, used for plotting)
     and the LOWER BIN EDGE (both are written to the CSV)
   - SHAP median and SHAP std
   - bin sample count
6) Plot the median line with a shaded band (SHADE_MODE: 'std' = ±std of the
   bin; any other value, e.g. 'fixed', = ±FIXED_HALF_WIDTH).
7) Save one CSV per feature. **First column is the LOWER BIN EDGE**; the
   remaining columns are bin center, SHAP median, SHAP std and bin count.
8) Save a single 1x3 figure.

Notes:
- Font fallback is added in case 'SimSun' or 'Times New Roman' are not present.
- Colors are kept (since you set them explicitly). If you want strictly default
  matplotlib colors, remove the color arguments to ax.plot/fill_between.
"""

import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.font_manager as fm
from matplotlib.offsetbox import AnchoredOffsetbox, TextArea, HPacker

# -----------------------------
# 0) Configuration (edit here)
# -----------------------------
# All user-tunable settings live in this section; the rest of the script only
# reads these names.
SAVE_DIR   = r"/path/to/SPEI"                 # directory containing the .npz
NPZ_NAME   = "shap_dependence_data.npz"       # archive file name inside SAVE_DIR
OUT_DIR    = os.path.join(SAVE_DIR, "dependence_figs")  # CSVs and the figure are written here
os.makedirs(OUT_DIR, exist_ok=True)           # side effect: created as soon as the script runs

# The features to plot, identified by the basename of their entry in
# feature_names; one subplot is drawn per entry, in this order.
TARGET_FILES = [
    "Srad.tif",          # shortwave radiation
    "mean_NDVI_.tif",    # NDVI (example; adjust scale if needed)
    "SPEI_STImin.tif",   # compound heat-drought index (example)
]

# X-axis labels as (Chinese name, unit) pairs, keyed by basename. Features
# without an entry fall back to their basename with no unit.
# NOTE(review): the "mean_NDVI_.tif" entry carries a leaf-index style label and
# an m^-2 unit although the key says NDVI -- confirm which variable this is.
LABEL_MAP = {
    "Srad.tif"        : ("太阳辐射", "(W/m$^{2}$)"),
    "mean_NDVI_.tif"  : ("叶面指数", "(m$^{-2}$)"),
    "SPEI_STImin.tif" : ("复合高温热浪事件", "(%)"),
}

# Multiplicative re-scaling applied to X before fill removal, trimming and
# binning; features without an entry use 1.0.
SCALE_FACTORS = {
    "Srad.tif"        : 1.0,
    "mean_NDVI_.tif"  : 10.0,  # example; set to 1/10000 if stored as ×10000
    "SPEI_STImin.tif" : 1.0,
}

# Binning & filtering
NUM_BINS        = 50   # number of quantile bins requested (duplicate edges are merged)
MIN_SAMPLES     = 20   # per-bin sample-count threshold for keeping a bin
MIN_SHAP_COUNT  = 10   # minimum finite SHAP samples per bin
FILL_ABS_THRESH = 1e10 # remove absolute values beyond this as invalid fills
TRIM_LOW_Q      = 1.0  # lower trimming percentile of X
TRIM_HIGH_Q     = 99.0 # upper trimming percentile of X

# Shaded band mode (compared case-insensitively):
#   "std"            -> band is median ± per-bin SHAP std
#   any other value  -> band is median ± FIXED_HALF_WIDTH (e.g. "fixed")
SHADE_MODE      = "fixed"
FIXED_HALF_WIDTH = 1   # half-width of the band, in SHAP units, when not in "std" mode

# -----------------------------
# 1) Matplotlib global styling
# -----------------------------
# Applied in one call: thick axes frame, ASCII minus signs, and a serif font
# chain that starts with the preferred face and ends with safe fallbacks.
mpl.rcParams.update({
    "axes.linewidth": 3,
    "axes.unicode_minus": False,
    "font.family": ["serif"],
    "font.serif": ["Times New Roman", "DejaVu Serif", "STIXGeneral"],
    "mathtext.fontset": "stix",
})

# Try to resolve specific fonts; fallback gracefully
def _safe_find_font(name, fallback=None):
    """Return ``FontProperties`` for the installed font *name*, else *fallback*.

    Parameters
    ----------
    name : str
        Font family name to look up (e.g. ``"SimSun"``).
    fallback : matplotlib.font_manager.FontProperties or None
        Returned unchanged when *name* cannot be resolved to an existing file.

    Notes
    -----
    ``findfont`` is called with ``fallback_to_default=False``. With ``True`` it
    always returns the path of Matplotlib's default font, so a missing font
    could never be detected and *fallback* was never used.
    """
    try:
        path = fm.findfont(name, fallback_to_default=False)
    except Exception:
        # Deliberately best-effort: a failed lookup (ValueError when the font
        # is missing, or any font-cache problem) must not abort the script.
        return fallback

    if path and os.path.exists(path):
        return fm.FontProperties(fname=path)
    return fallback


# Fonts used for the Chinese label text and for the unit text; each falls
# back to default FontProperties (i.e. the rcParams font chain) when missing.
font_simsun = _safe_find_font("SimSun", fallback=fm.FontProperties())
font_times  = _safe_find_font("Times New Roman", fallback=fm.FontProperties())

# -----------------------------
# 2) Load npz
# -----------------------------
npz_path = os.path.join(SAVE_DIR, NPZ_NAME)

# NOTE(security): allow_pickle=True lets NumPy unpickle object arrays stored
# in the archive, and unpickling can execute arbitrary code. Only load
# archives that you produced yourself or otherwise trust.
# The context manager closes the underlying zip file once the arrays have
# been read into memory.
with np.load(npz_path, allow_pickle=True) as data:
    # Expected keys (with basic checks)
    for key in ["exp_values", "shap_values_all", "feature_names"]:
        if key not in data:
            raise KeyError(f"Missing key in npz: '{key}'")

    exp_values      = np.asarray(data["exp_values"])          # (n_samples, n_features)
    shap_values_all = np.asarray(data["shap_values_all"])     # (n_samples, n_features)
    feature_names   = [str(x) for x in data["feature_names"]]

# The per-feature code pairs exp_values[:, j] with shap_values_all[:, j]
# element-wise, so both must be 2-D arrays of identical shape.
if exp_values.ndim != 2 or exp_values.shape != shap_values_all.shape:
    raise ValueError(
        "exp_values and shap_values_all must be 2-D arrays of the same shape; "
        f"got {exp_values.shape} and {shap_values_all.shape}"
    )

# A name/column count mismatch makes the name -> column mapping suspect, but
# features whose index is in range can still be plotted, so only warn.
if len(feature_names) != exp_values.shape[1]:
    print(
        f"[WARN] feature_names has {len(feature_names)} entries but the data "
        f"has {exp_values.shape[1]} columns; check the column mapping."
    )

# Map basename -> column index
basename_to_idx = {os.path.basename(fn): i for i, fn in enumerate(feature_names)}

# -----------------------------
# 3) Plot (one row, one panel per target; 1 x 3 by default)
# -----------------------------
def _skip_axis(ax, message):
    """Hide *ax* and print *message*; used when a feature cannot be plotted."""
    ax.set_visible(False)
    print(message)


def _bin_shap_stats(x, shap_x, bins):
    """Summarise SHAP values inside the bins delimited by *bins*.

    Parameters
    ----------
    x, shap_x : 1-D float arrays of equal length
        Feature values and the matching SHAP values.
    bins : 1-D array of increasing bin edges (at least two)

    Returns
    -------
    numpy.ndarray of shape (n_kept_bins, 5)
        Columns: lower bin edge, mean of x in the bin, SHAP median, SHAP std,
        sample count. Bins with fewer than MIN_SAMPLES samples or fewer than
        MIN_SHAP_COUNT finite SHAP values are dropped.
    """
    # digitize is 1-based; values equal to the last edge land one past the
    # final bin, so clip them back into it.
    inds = np.clip(np.digitize(x, bins, right=False) - 1, 0, len(bins) - 2)

    rows = []
    for k in range(len(bins) - 1):
        in_bin = inds == k
        n_in_bin = int(in_bin.sum())
        shap_bin = shap_x[in_bin]
        # MIN_SAMPLES is a minimum, so a bin with exactly that many samples is kept
        if n_in_bin < MIN_SAMPLES or np.isfinite(shap_bin).sum() < MIN_SHAP_COUNT:
            continue
        rows.append((
            bins[k],
            np.nanmean(x[in_bin]),
            np.nanmedian(shap_bin),
            np.nanstd(shap_bin),
            n_in_bin,
        ))
    return np.array(rows, dtype=float).reshape(-1, 5)


# One panel per target feature (18 x 5.2 inches for the default three).
# squeeze=False keeps `axes` 2-D so it can be flattened for any panel count.
n_panels = max(len(TARGET_FILES), 1)
fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5.2), squeeze=False)
axes_list = list(axes.ravel())

for ax, fname in zip(axes_list, TARGET_FILES):
    if fname not in basename_to_idx:
        _skip_axis(ax, f"[WARN] Target not found in feature_names: {fname}")
        continue

    col_idx = basename_to_idx[fname]
    # astype(float) copies, so the NaN masking below never modifies the loaded arrays
    x = exp_values[:, col_idx].astype(float) * SCALE_FACTORS.get(fname, 1.0)
    shap_x = shap_values_all[:, col_idx].astype(float)

    # Remove huge fill values
    x[np.abs(x) > FILL_ABS_THRESH] = np.nan
    shap_x[np.abs(shap_x) > FILL_ABS_THRESH] = np.nan

    # Keep only samples where both X and SHAP are finite
    finite = np.isfinite(x) & np.isfinite(shap_x)
    if finite.sum() < 2:
        _skip_axis(ax, f"[INFO] Too few valid samples after fill removal: {fname}")
        continue
    x, shap_x = x[finite], shap_x[finite]

    # Robust trim on X only (1%–99% by default), same mask applied to SHAP
    try:
        lo, hi = np.nanpercentile(x, [TRIM_LOW_Q, TRIM_HIGH_Q])
    except (ValueError, TypeError, IndexError) as e:
        _skip_axis(ax, f"[ERROR] Percentile trimming failed for {fname}: {e}")
        continue

    if not np.isfinite(lo) or not np.isfinite(hi) or lo >= hi:
        _skip_axis(ax, f"[INFO] Degenerate trimming bounds for {fname}.")
        continue

    # x and shap_x are already finite here, so only the range test is needed
    keep = (x >= lo) & (x <= hi)
    x, shap_x = x[keep], shap_x[keep]

    if x.size < 2:
        _skip_axis(ax, f"[INFO] Too few samples after trimming: {fname}")
        continue

    # Quantile bin edges; np.unique merges duplicate edges caused by ties
    try:
        bins = np.unique(np.nanpercentile(x, np.linspace(0, 100, NUM_BINS + 1)))
    except (ValueError, TypeError, IndexError) as e:
        _skip_axis(ax, f"[ERROR] Binning failed for {fname}: {e}")
        continue

    if bins.size < 2:
        _skip_axis(ax, f"[INFO] Invalid bins (nearly constant): {fname}")
        continue

    stats = _bin_shap_stats(x, shap_x, bins)
    if stats.shape[0] == 0:
        _skip_axis(ax, f"[INFO] No valid bins after thresholds: {fname}")
        continue

    bin_centers, shap_medians, shap_stds = stats[:, 1], stats[:, 2], stats[:, 3]

    # Save per-feature CSV (first column = LOWER BIN EDGE). The count column
    # gets an integer format; the float columns keep savetxt's default format.
    out_csv = os.path.join(OUT_DIR, f"{fname}_shap_dependence_line_p1-99.csv")
    np.savetxt(
        out_csv,
        stats,
        delimiter=",",
        header="bin_lower,bin_center,shap_median,shap_std,bin_count",
        comments="",
        fmt=["%.18e", "%.18e", "%.18e", "%.18e", "%d"],
    )

    # Plot median line + shaded band
    ax.plot(bin_centers, shap_medians, color="#035040", linewidth=4)

    # 'std' -> ± per-bin std; anything else -> ± FIXED_HALF_WIDTH
    half_width = shap_stds if SHADE_MODE.lower() == "std" else FIXED_HALF_WIDTH
    ax.fill_between(
        bin_centers,
        shap_medians - half_width,
        shap_medians + half_width,
        color="#58aca1",
        alpha=0.20,
    )

    # Style
    for spine in ax.spines.values():
        spine.set_linewidth(3)
    ax.tick_params(axis="x", labelsize=16)
    ax.tick_params(axis="y", labelsize=16)

    # Y label (left subplot only)
    if ax is axes_list[0]:
        ax.set_ylabel("SHAP值", fontsize=18, fontproperties=font_simsun)
    else:
        ax.set_ylabel("")

    # X label: Chinese name and unit are separate text boxes so that each can
    # use its own font, packed side by side below the axes.
    zh, unit = LABEL_MAP.get(fname, (fname, ""))
    if unit:
        txt_var  = TextArea(zh,   textprops={"fontsize":18, "fontproperties":font_simsun})
        txt_unit = TextArea(unit, textprops={"fontsize":14, "fontproperties":font_times})
        label_box = HPacker(children=[txt_var, txt_unit], align="center", pad=0, sep=3)
        anchored_box = AnchoredOffsetbox(
            loc="lower center", child=label_box, pad=0., frameon=False,
            bbox_to_anchor=(0.5, -0.28), bbox_transform=ax.transAxes, borderpad=0.
        )
        ax.add_artist(anchored_box)
    else:
        ax.set_xlabel(zh, fontsize=18, fontproperties=font_simsun, labelpad=12)

    ax.grid(False)
    ax.set_title("")

# Final layout pass, then write the figure to OUT_DIR and display it
plt.tight_layout()
plt.subplots_adjust(hspace=0.4, wspace=0.24)

figure_name = "dependence_1x3_Srad_LAI_STImin_p1-99_cleanfill.tif"
out_fig = os.path.join(OUT_DIR, figure_name)
plt.savefig(out_fig, bbox_inches="tight", dpi=300)
plt.show()

# Completion message, including where the figure was written
summary = (
    "\n✅ Finished 1×3 dependence plots with fill removal + 1–99% trimming.\n"
    f"Saved figure: {out_fig}"
)
print(summary)
