# In[None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
SHAP analysis (best-of-100) with combined overlay figure
--------------------------------------------------------
- Reads: data/训练模型数据(最终).xlsx and data/XGB(ADF)SHAP分析.xlsx
- Picks the single best run by Test_R2 (best-of-100; matches the manuscript)
- Re-trains the model with the same random_state on a 70/30 split
- Computes SHAP (TreeExplainer)
- Exports three figures:
  1) figures/shap_summary_dot.png  (beeswarm)
  2) figures/shap_summary_bar.png  (mean |SHAP| bar)
  3) figures/SHAP标准.svg          (overlay: beeswarm on bottom axis + mean bar on top axis)

Notes:
- The overlay uses twiny() for a top x-axis. Draw the dot plot first, then draw the
  bar plot on the shared y-ordering, lowering bar alpha so dots are not obscured.
"""

import os
import numpy as np
import pandas as pd
from xgboost import XGBRegressor
from sklearn.model_selection import train_test_split
import shap
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rcParams

# -----------------------------
# Input / output locations
# -----------------------------
# Both Excel inputs live under data/; every figure is written to figures/,
# which is created up front so the savefig calls below cannot fail on a
# missing directory.
_DATA_DIR = "data"
DATA_XLSX, RES_XLSX = (
    os.path.join(_DATA_DIR, file_name)
    for file_name in ("Train_data.xlsx", "HHO_XGB_results.xlsx")
)
FIG_DIR = "figures"
os.makedirs(FIG_DIR, exist_ok=True)

# -----------------------------
# Model inputs (column names in DATA_XLSX) and regression target
# -----------------------------
FEATURES = "SSSI SSP DEM MRRTF NDVI MRVBF RD".split()
TARGET = "thickness"

# -----------------------------
# Figure style (journal-like)
# -----------------------------
def set_figure_style(layout="single"):
    """Apply a journal-like Matplotlib style by updating ``mpl.rcParams``.

    Parameters
    ----------
    layout : str
        Target layout: ``'single'``, ``'double'`` or ``'three-quarter'``.
        It selects the base font size, tick-label size, legend font size and
        the axes/tick line width. All layouts use Times New Roman, bold axis
        labels/titles, black 3-pt major ticks and mathtext axis formatting.

    Raises
    ------
    ValueError
        If ``layout`` is not one of the three supported names.
    """
    # layout name -> (font size, tick-label size, legend font size, line width)
    presets = (
        ("single", (8, 7, 7, 0.6)),
        ("double", (9, 8, 8, 0.8)),
        ("three-quarter", (9, 8, 8, 0.7)),
    )
    chosen = next((sizes for name, sizes in presets if layout == name), None)
    if chosen is None:
        raise ValueError("layout must be 'single', 'double', or 'three-quarter'")
    base, ticks, legend, lw = chosen

    style = {
        "font.family": "Times New Roman",
        "font.size": base,
        "axes.labelweight": "bold",
        "axes.titlesize": base,
        "axes.titleweight": "bold",
        "axes.linewidth": lw,
        "legend.fontsize": legend,
        "axes.formatter.use_mathtext": True,
        "axes.formatter.limits": (-3, 4),
    }
    # x and y ticks are styled identically.
    for axis in ("x", "y"):
        style[f"{axis}tick.labelsize"] = ticks
        style[f"{axis}tick.color"] = "black"
        style[f"{axis}tick.major.size"] = 3
        style[f"{axis}tick.major.width"] = lw
    mpl.rcParams.update(style)

set_figure_style("single")

# -----------------------------
# 1) Load data & pick the best run
# -----------------------------
# NOTE(review): dropna() removes rows with a missing value in ANY column, not
# only FEATURES/TARGET. Keep this identical to the original training pipeline,
# otherwise the seeded 70/30 split below no longer reproduces the recorded run.
df  = pd.read_excel(DATA_XLSX).dropna()
res = pd.read_excel(RES_XLSX)
# Restrict to the SSSI group when the results table is grouped; otherwise use all rows.
res_sssi = res[res["Group"] == "SSSI"] if "Group" in res.columns else res
# Runs without a recorded Test_R2 cannot be ranked, so they are never selected.
res_sssi = res_sssi.dropna(subset=["Test_R2"])

if res_sssi.empty:
    raise ValueError(f"No SSSI rows with a Test_R2 value found in {RES_XLSX}.")

# Best-of-N: the run with the highest held-out R2. idxmax() picks the first
# such run on ties, so the choice is deterministic.
best = res_sssi.loc[res_sssi["Test_R2"].idxmax()]
iter_seed = int(best["Iteration"])
# Rebuild the best run's estimator; its Iteration number doubles as the seed.
params = dict(
    n_estimators=int(best["n_estimators"]),
    max_depth=int(best["max_depth"]),
    learning_rate=float(best["learning_rate"]),
    objective="reg:squarederror",
    verbosity=0,
    random_state=iter_seed,
)

# -----------------------------
# 2) Train/test split (70/30, as in the manuscript)
# -----------------------------
X = df[FEATURES]
y = df[TARGET]
# The best run's seed drives both the split and the model, to reproduce that run.
X_tr, X_te, y_tr, y_te = train_test_split(X, y, train_size=0.7, random_state=iter_seed)

# Fit the best run's model
model = XGBRegressor(**params).fit(X_tr, y_tr)

# Sanity check: the re-trained model should reproduce the recorded Test_R2.
# A mismatch means the data, preprocessing or library versions differ from the
# original run, and the SHAP figures below would describe a different model.
R2_TOL = 1e-3  # assumed slack for rounding of Test_R2 in the results workbook
test_r2 = model.score(X_te, y_te)  # R^2 on the held-out 30%
recorded_r2 = float(best["Test_R2"])
print(f"[INFO] Re-trained Test R2 = {test_r2:.4f} (recorded: {recorded_r2:.4f})")
if abs(test_r2 - recorded_r2) > R2_TOL:
    print(
        "[WARN] Re-trained Test R2 differs from the recorded value; "
        "the SHAP figures may not correspond to the reported best run."
    )

# -----------------------------
# 3) SHAP (TreeExplainer)
#    To keep the ordering consistent across plots, compute SHAP on full X
# -----------------------------
explainer = shap.TreeExplainer(model)
# Compatibility with different SHAP versions: prefer the shap_values() method
# and fall back to calling the explainer, which returns an Explanation object.
# The fallback is reported rather than taken silently.
try:
    shap_values = explainer.shap_values(X)
except Exception as err:
    print(f"[WARN] explainer.shap_values(X) failed ({err!r}); falling back to explainer(X).")
    shap_values = explainer(X).values

# Normalize to a plain array so every plot below receives the same type, and
# fail early if the result is not one SHAP value per (sample, feature) cell
# (e.g. a list-per-output result would surface here, not inside the plotting code).
shap_values = np.asarray(shap_values)
if shap_values.shape != X.shape:
    raise ValueError(
        f"Expected SHAP values of shape {X.shape}, got {shap_values.shape}."
    )

# -----------------------------
# 4) Export separate dot & bar plots (for clean journal layouts)
# -----------------------------
def _export_summary_plot(plot_type, filename, **plot_kwargs):
    """Draw one SHAP summary plot on a fresh figure and save it as a 600-dpi PNG in FIG_DIR."""
    plt.figure()
    shap.summary_plot(
        shap_values,
        X,
        feature_names=X.columns,
        plot_type=plot_type,
        show=False,
        **plot_kwargs,
    )
    plt.tight_layout()
    plt.savefig(os.path.join(FIG_DIR, filename), dpi=600, bbox_inches="tight")
    plt.close()


_export_summary_plot("dot", "shap_summary_dot.png", color_bar=True)
_export_summary_plot("bar", "shap_summary_bar.png")

# -----------------------------
# 5) Overlay figure (beeswarm + top-axis mean bar): SHAP.svg
# -----------------------------
# Key points:
# - Draw the dot plot (with colorbar) first, then call twiny() and draw the bar plot
#   using the same y-ordering. Make bars semi-transparent so dots remain visible.
# - shap.summary_plot() resizes the current figure unless plot_size=None is passed
#   (the default "auto" forces an 8-inch-wide figure), so both calls below pass
#   plot_size=None to keep the size requested in plt.subplots().
# -----------------------------
OVERLAY_FIGSIZE = (3.67, 2.56)     # inches (width, height)
AXES_RECT = [0.15, 0.2, 0.7, 0.7]  # main drawing area: left, bottom, width, height

fig, ax1 = plt.subplots(figsize=OVERLAY_FIGSIZE, dpi=1200)

# Beeswarm on the main axis (summary_plot draws on the current axes)
shap.summary_plot(
    shap_values, X,
    feature_names=X.columns,
    plot_type="dot",
    show=False,
    color_bar=True,
    plot_size=None,
)
ax1 = plt.gca()
ax1.set_position(AXES_RECT)

# Top axis + bar plot. twiny() shares the y-axis and makes the new axes current,
# so the bar plot below lands on ax2.
ax2 = ax1.twiny()
# Reuse SHAP's bar plot to keep the exact same ordering & ticks
shap.summary_plot(
    shap_values, X,
    feature_names=X.columns,
    plot_type="bar",
    show=False,
    plot_size=None,
)
ax2.set_position(AXES_RECT)  # keep the same drawing area

# Lower bar opacity so the beeswarm underneath stays visible
for bar in ax2.patches:
    bar.set_alpha(0.25)

# Labels & ticks
# NOTE(review): 12/10 pt text is large for a 3.67-inch-wide figure; reduce it if
# the saved overlay looks cramped.
ax1.set_xlabel('Shapley Value Contribution (Bee Swarm)', fontsize=12)
ax1.set_ylabel('Features', fontsize=12)
ax2.set_xlabel('Mean |SHAP| (Feature Importance)', fontsize=12)
ax2.xaxis.set_visible(True)
ax2.xaxis.set_ticks_position('top')
ax2.xaxis.set_label_position('top')
ax2.tick_params(axis='x', which='both', labelsize=10)

# Optional fixed x-range for the beeswarm axis (uncomment if you want it)
# ax1.set_xticks(np.arange(-1, 4, 0.5))
# ax1.set_xlim(-1, 3)

# NOTE(review): tight_layout() can re-position axes that belong to a subplot
# grid and thereby override the set_position() calls above; check the saved
# layout if the manual AXES_RECT matters.
plt.tight_layout()
fig.savefig(os.path.join(FIG_DIR, "SHAP.svg"), dpi=600, bbox_inches='tight')
plt.close(fig)

print("[DONE] SHAP: dot/bar/overlay saved to the figures/ directory.")
