# In [None]:  (Jupyter cell prompt left over from the notebook export)
# -*- coding: utf-8 -*-
"""
XGBoost + RFE (backward) + forward selection, with SHAP reporting.
Steps:
1) Load dependent raster and explanatory rasters, align, mask to valid pixels.
2) Robustly clip each explanatory raster to [1st, 99th] percentiles.
3) Train an initial XGBoost model; report baseline R^2/RMSE.
4) RFE (backward elimination): iteratively remove the least-important feature
   until only FINAL_FEATURE_COUNT features remain; record a full importance ranking.
5) Forward selection: start from top-k, add features only if ΔR^2 >= R2_THRESHOLD.
6) Train final model on the selected set, save metrics & SHAP-based importance.
"""

import os
import numpy as np
import pandas as pd
import rioxarray as rxr
import xgboost as xgb
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from joblib import Parallel, delayed
import shap
import pickle
import matplotlib.pyplot as plt

# ---------------------------
# Config
# ---------------------------
# Number of joblib workers used when loading rasters in parallel.
NUM_CPUS = 32
# Forward selection keeps a candidate only if R^2 improves by at least this much.
R2_THRESHOLD = 0.01
# RFE stops at this many features; the same count seeds forward selection.
FINAL_FEATURE_COUNT = 3
# Fraction of the samples held out as the test split.
TEST_SIZE = 0.30
# Seed shared by the train/test split and the XGBoost regressor.
RANDOM_STATE = 42
# Lower / upper percentiles used for robust clipping of explanatory rasters.
PERC_LOW = 1
PERC_HIGH = 99
# Minimum number of valid samples required before a model is trained.
MIN_VALID_SAMPLES = 100

# ---------------------------
# Utilities
# ---------------------------
def robust_percentiles(arr, low=1, high=99):
    """Compute the (low, high) percentile pair over the finite entries of *arr*.

    NaN and +/-inf entries are ignored. Returns ``None`` when no finite entry
    exists, or when the two percentiles are not finite or not strictly
    increasing (for example a constant input). Otherwise returns a tuple of
    two Python floats ``(lo, hi)``.
    """
    data = np.asarray(arr)
    usable = data[np.isfinite(data)]
    if usable.size == 0:
        return None

    lower = np.nanpercentile(usable, low)
    upper = np.nanpercentile(usable, high)
    if not np.isfinite(lower) or not np.isfinite(upper):
        return None
    # A zero-width (or inverted) range cannot be used as clipping bounds.
    if not lower < upper:
        return None
    return float(lower), float(upper)

def load_and_process_data(file, dep_data):
    """Load one explanatory raster, aligned and masked to the dependent raster.

    Steps:
      1. Open ``file`` with its declared nodata value converted to NaN, drop
         singleton dimensions, and cast to float32.
      2. Reproject/resample onto the grid of ``dep_data``.
      3. Set every pixel to NaN where ``dep_data`` is null.
      4. Clip the remaining values to the [PERC_LOW, PERC_HIGH] percentile
         range; clipping is skipped when no usable range exists.

    Args:
        file: path of the explanatory raster.
        dep_data: dependent-variable DataArray that defines the target grid
            and the valid-pixel mask.

    Returns:
        The processed DataArray on the grid of ``dep_data``.
    """
    # masked=True turns the raster's nodata value into NaN. Without it a fill
    # value such as -9999 is a finite number: it would distort the percentile
    # bounds, be clipped up to the lower bound, and then be kept as a valid
    # sample by the finite-value filter downstream.
    da = rxr.open_rasterio(file, masked=True).squeeze().astype("float32")
    da = da.rio.reproject_match(dep_data)
    da = da.where(~dep_data.isnull())
    bounds = robust_percentiles(da.values, PERC_LOW, PERC_HIGH)
    if bounds is not None:
        lo, hi = bounds
        da = da.clip(min=lo, max=hi)
    return da

def prepare_dataset(explanatory_files, dep_data, n_jobs=NUM_CPUS):
    """Build the regression table: target vector y and feature matrix X.

    Every explanatory raster is loaded through ``load_and_process_data`` using
    ``n_jobs`` joblib workers, flattened, and stored as one column of X in the
    order given by ``explanatory_files``. Rows where y or any feature is
    non-finite are dropped.

    Returns:
        (y, X): float32 arrays holding only the valid rows.
    """
    loader = delayed(load_and_process_data)
    layers = Parallel(n_jobs=n_jobs)(
        loader(path, dep_data) for path in explanatory_files
    )

    target = dep_data.values.flatten().astype("float32")
    columns = [layer.values.flatten().astype("float32") for layer in layers]
    # One row per raster after stacking; transpose to get one column per raster.
    features = np.vstack(columns).T

    # Keep a pixel only if the target and every feature are finite.
    keep = np.isfinite(target) & np.isfinite(features).all(axis=1)
    return target[keep], features[keep]

def train_and_evaluate(X_train, X_test, y_train, y_test):
    """Train an XGBoost regressor and score it on the held-out split.

    Returns:
        (model, r2, rmse): the fitted regressor, and the R^2 and RMSE of its
        predictions on ``X_test`` as Python floats.
    """
    hyperparams = dict(
        n_estimators=500,
        max_depth=None,
        learning_rate=0.1,
        subsample=0.8,
        colsample_bytree=0.8,
        random_state=RANDOM_STATE,
        n_jobs=1,            # single-threaded fit (author's note: avoid conflict with joblib)
        tree_method="hist",  # histogram-based tree construction
    )
    regressor = xgb.XGBRegressor(**hyperparams)
    regressor.fit(X_train, y_train)

    predictions = regressor.predict(X_test)
    score = float(r2_score(y_test, predictions))
    error = float(np.sqrt(mean_squared_error(y_test, predictions)))
    return regressor, score, error

# ---------------------------
# Feature selection
# ---------------------------
def recursive_feature_elimination(explanatory_files, dep_data):
    """
    Backward elimination down to FINAL_FEATURE_COUNT.

    At each step a model is fitted on the remaining features and the feature
    with the smallest ``feature_importances_`` value is removed. Elimination
    is aborted when fewer than MIN_VALID_SAMPLES valid samples are available.

    Returns:
      survivors: list[str]  features left at the end, most->least important.
                 Normally FINAL_FEATURE_COUNT items; more if elimination was
                 aborted, fewer if fewer inputs were given. When no ranking
                 fit is possible they keep their input order.
      full_ranking: list[str]  (most->least important overall)
    """
    remaining = list(explanatory_files)  # work on a copy; the caller's list is untouched
    removed_in_order = []                # features in removal order (least important first)
    aborted = False

    print("\n[RFE] Start backward elimination …")
    while len(remaining) > FINAL_FEATURE_COUNT:
        y, X = prepare_dataset(remaining, dep_data)
        if y.size < MIN_VALID_SAMPLES:
            print("[RFE] Too few valid samples; aborting RFE.")
            aborted = True
            break

        X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE)
        model, r2, rmse = train_and_evaluate(X_tr, X_te, y_tr, y_te)
        importances = model.feature_importances_

        # Identify least important feature in the current set
        least_idx = int(np.argmin(importances))
        least_feat = remaining[least_idx]
        least_imp = float(importances[least_idx])
        print(f"[RFE] Remove: {os.path.basename(least_feat)} (importance={least_imp:.6f})  | R²={r2:.4f}, RMSE={rmse:.4f}")

        removed_in_order.append(least_feat)
        remaining.pop(least_idx)

    # Build a full ranking:
    # - Features removed earlier had the lowest importance → put them at the bottom.
    # - Survivors are considered the most important; sort survivors by a final fit to get an order.
    survivors_sorted = list(remaining)
    # After an abort the dataset is already known to be too small, so fitting
    # on it again would be wasted work at best and a crash at worst.
    if len(remaining) > 1 and not aborted:
        y, X = prepare_dataset(remaining, dep_data)
        if y.size >= MIN_VALID_SAMPLES:
            X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE)
            model, _, _ = train_and_evaluate(X_tr, X_te, y_tr, y_te)
            surv_imps = model.feature_importances_
            # Sort by descending importance; ties fall back to the file name.
            survivors_sorted = [f for _, f in sorted(zip(-surv_imps, remaining))]
        else:
            print("[RFE] Too few valid samples to rank survivors; keeping input order.")

    # Most->least important: survivors (sorted by importance) followed by removed (reverse removal order)
    full_ranking = survivors_sorted + removed_in_order[::-1]
    print(f"[RFE] Done. Survivors: {[os.path.basename(f) for f in survivors_sorted]}")
    return survivors_sorted, full_ranking

def forward_feature_selection(feature_ranking, dep_data):
    """
    Greedy forward selection over a most->least important feature ranking.

    The first FINAL_FEATURE_COUNT features form the seed set. Each later
    feature is tried in ranking order and kept only when it raises the test
    R² by at least R2_THRESHOLD over the current best. If the ranking has no
    more than FINAL_FEATURE_COUNT entries, the whole ranking is scored and
    returned unchanged.

    Returns (selected_features, best_r2).
    """
    def score(features):
        # Load the given subset, split it, fit a model, and return the test R².
        y, X = prepare_dataset(features, dep_data)
        X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE)
        _, subset_r2, _ = train_and_evaluate(X_tr, X_te, y_tr, y_te)
        return subset_r2

    chosen = feature_ranking[:FINAL_FEATURE_COUNT].copy()
    top_r2 = score(chosen)
    if len(feature_ranking) <= FINAL_FEATURE_COUNT:
        # Nothing beyond the seed to try.
        return chosen, top_r2

    print(f"\n[Forward] Seed (top-{FINAL_FEATURE_COUNT}) R² = {top_r2:.4f}")

    for cand in feature_ranking[FINAL_FEATURE_COUNT:]:
        label = os.path.basename(cand)
        trial_r2 = score(chosen + [cand])
        gain = trial_r2 - top_r2
        print(f"[Forward] Try +{label}: R²={trial_r2:.4f} (Δ={gain:.4f})")

        if gain >= R2_THRESHOLD:
            chosen.append(cand)
            top_r2 = trial_r2
            print(f"[Forward] Kept: {label}  | New best R²={top_r2:.4f}")
        else:
            print(f"[Forward] Discarded: {label} (Δ<{R2_THRESHOLD})")

    return chosen, top_r2

# ---------------------------
# Main
# ---------------------------
def main():
    """Run the pipeline: baseline fit, RFE, forward selection, final fit, SHAP.

    Writes these files to the current working directory:
      - initial_xgb_results.pkl  (test targets, predictions, R², RMSE of the baseline)
      - final_xgb_results.pkl    (same for the final model)
      - final_shap_feature_importances.csv / .png

    Raises:
        RuntimeError: if fewer than MIN_VALID_SAMPLES valid samples are
            available for the initial model.
    """
    # --- Load dependent variable (edit path) ---
    dep_file = r"I:\path\to\dependent.tif"
    dep_data = rxr.open_rasterio(dep_file).squeeze()
    # Example mask: keep values in (-10, 0]; everything else becomes NaN.
    dep_data = dep_data.where((dep_data > -10) & (dep_data <= 0))

    # --- Initial explanatory rasters (edit list) ---
    explanatory_files = [
        r"I:\path\to\ai_v3_yra_.tif",
        r"I:\path\to\KNDVI_alpha.tif",
        r"I:\path\to\Srad.tif",
        r"I:\path\to\VPD.tif",
        r"I:\path\to\Wind.tif",
        r"I:\path\to\Annual_ppt_1982_2021.tif",
        r"I:\path\to\Vegetation_species.tif",
        r"I:\path\to\wc2.1_2.5m_elev.tif",
        r"I:\path\to\mean_NDVI_.tif",
        r"I:\path\to\Tmax_Pfreq.tif",
        r"I:\path\to\SPEI_1982_2021.tif",
        r"I:\path\to\NDVI_EOS_POS_Difference_LAI.tif",
        r"I:\path\to\rplant_proxy.tif",
    ]

    # --- Initial model ---
    y, X = prepare_dataset(explanatory_files, dep_data)
    if y.size < MIN_VALID_SAMPLES:
        raise RuntimeError("Too few valid samples after masking/clipping. Check rasters & alignment.")
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE)
    init_model, init_r2, init_rmse = train_and_evaluate(X_tr, X_te, y_tr, y_te)
    print(f"\n[Init] R²={init_r2:.4f}, RMSE={init_rmse:.4f}  (n={y.size})")

    # Save initial predictions
    with open("initial_xgb_results.pkl", "wb") as f:
        pickle.dump(
            {"y_test": y_te, "y_pred": init_model.predict(X_te), "r2": init_r2, "rmse": init_rmse},
            f
        )

    # --- RFE to FINAL_FEATURE_COUNT, plus full ranking ---
    # Only the full ranking is needed below; the survivors are its leading entries.
    _, full_ranking = recursive_feature_elimination(explanatory_files, dep_data)

    # Forward selection walks the full ranking (most->least important): the top
    # FINAL_FEATURE_COUNT features seed the set, the rest are tried as candidates.
    selected_features, _ = forward_feature_selection(full_ranking, dep_data)

    # --- Final model on selected features ---
    yF, XF = prepare_dataset(selected_features, dep_data)
    X_trF, X_teF, y_trF, y_teF = train_test_split(XF, yF, test_size=TEST_SIZE, random_state=RANDOM_STATE)
    final_model, final_r2, final_rmse = train_and_evaluate(X_trF, X_teF, y_trF, y_teF)
    print(f"\n[Final] Features={len(selected_features)}  R²={final_r2:.4f}, RMSE={final_rmse:.4f}")

    with open("final_xgb_results.pkl", "wb") as f:
        pickle.dump(
            {"y_test": y_teF, "y_pred": final_model.predict(X_teF), "r2": final_r2, "rmse": final_rmse},
            f
        )

    # --- SHAP on final model (test set) ---
    # For tree models in SHAP v0.41+, you can use: shap.Explainer(final_model)
    explainer = shap.TreeExplainer(final_model)
    shap_values = explainer.shap_values(X_teF)  # array (n_test, n_features)

    # Columns of XF follow the order of selected_features, so the mean |SHAP|
    # per column lines up with the feature names below.
    shap_mean_abs = np.abs(shap_values).mean(axis=0)
    shap_df = pd.DataFrame({
        "Feature": [os.path.basename(f) for f in selected_features],
        "Mean_ABS_SHAP": shap_mean_abs
    }).sort_values(by="Mean_ABS_SHAP", ascending=False)

    shap_df.to_csv("final_shap_feature_importances.csv", index=False, encoding="utf-8-sig")

    plt.figure(figsize=(10, 6))
    plt.barh(shap_df["Feature"], shap_df["Mean_ABS_SHAP"])
    plt.xlabel("Mean |SHAP| (importance)")
    plt.ylabel("Feature")
    plt.title("SHAP Feature Importance (Final Model)")
    plt.gca().invert_yaxis()  # most important feature at the top
    plt.tight_layout()
    plt.savefig("final_shap_feature_importances.png", dpi=200)
    plt.show()

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
