# In [None]:
# -*- coding: utf-8 -*-
"""
Random Forest – aridity-zone-wise relative importance

Pipeline:
1) Load the dependent raster (masked to a valid range).
2) Load the aridity classification raster and align it to the dependent raster.
3) Load all explanatory rasters, reproject/match to the dependent raster, mask to valid pixels,
   and clip each to [1st, 99th] percentile to reduce outlier effects.
4) Train a RandomForestRegressor to compute normalized feature importances:
   - overall (all valid pixels)
   - per aridity zone (Arid, Semi-Arid, Sub-humid, Humid)
5) Save a CSV with overall and zone-wise importances.
"""

import os
import numpy as np
import pandas as pd
import rioxarray as rxr
from sklearn.ensemble import RandomForestRegressor
from joblib import Parallel, delayed

# --------------------------------
# 0) File inputs (placeholders)
# --------------------------------
# Dependent raster, e.g., SPEI-triggered SOS delay threshold map.
dep_file = r"/path/to/dependent_variable.tif"
# Zone raster, e.g., aridity classification.
classification_file = r"/path/to/aridity_zone_map.tif"

# Explanatory rasters (categories used in this study; replace with your files).
# The list order is the order in which the predictors are stacked into the
# feature matrix and in which feature names are derived further down.
explanatory_files = [
    r"/path/to/AI.tif",                           # Climate: AI (aridity index)
    r"/path/to/VOD_resilience.tif",               # Resilience: VOD-based resilience proxy
    r"/path/to/Srad.tif",                         # Climate: Srad (shortwave radiation)
    r"/path/to/VPD.tif",                          # Climate: VPD
    r"/path/to/Wind.tif",                         # Climate: Wind
    r"/path/to/Annual_PPT.tif",                   # Climate: Annual PPT
    r"/path/to/Vegetation_species.tif",           # Biodiversity: vegetation species richness
    r"/path/to/NDVI_EOS_POS_Difference_LAI.tif",  # Vegetation: NDVI/LAI composite metric
    r"/path/to/rplant_proxy.tif",                 # Plant hydraulics: rplant (root/plant hydraulic proxy)
    r"/path/to/SPEI_STImin.tif",                  # Climate: drought/temperature indices (SPEI, STImin)
    r"/path/to/mean_NDVI.tif",                    # Vegetation: mean NDVI
    r"/path/to/SOC_N.tif",                        # Soil & nutrients: SOC_N
]

# Aridity-zone raster code -> human-readable label (example values).
# Each label becomes one column of the output table.
aridity_map = {2: "Arid", 3: "Semi-Arid", 4: "Sub-humid", 5: "Humid"}

# --------------------------------
# 1) Load dependent variable
# --------------------------------
# Read the dependent raster and drop its length-1 dimensions.
dep = rxr.open_rasterio(dep_file).squeeze()

# Pixels outside the valid range become NaN
# (example range: (-10, 0]; adjust to your study).
dep = dep.where((dep <= 0) & (dep > -10))

# ------------------------------------------------
# 2) Load & align aridity classification layer
# ------------------------------------------------
# Read the zone codes and reproject them to match the dependent raster, so
# the two layers can be compared pixel by pixel.
aridity = (
    rxr.open_rasterio(classification_file)
    .squeeze()
    .rio.reproject_match(dep)
)

# ----------------------------------------------------------
# 3) Load, align, mask, and robustly clip explanatory layers
# ----------------------------------------------------------
def load_align_clip(fp, ref_da):
    """Load one explanatory raster, align it to ``ref_da`` and clip outliers.

    Steps:
      1. Open ``fp`` with its declared nodata value converted to NaN.
      2. Reproject/resample it to match ``ref_da``.
      3. Set every pixel where ``ref_da`` is null to NaN.
      4. Clip the remaining finite values to their [1st, 99th] percentile.

    Parameters
    ----------
    fp : str
        Path of the raster to load.
    ref_da : xarray.DataArray
        Reference raster defining the target grid and the valid-pixel mask.

    Returns
    -------
    xarray.DataArray
        The aligned, masked and clipped raster. Clipping is skipped when no
        finite value remains or when the two percentiles do not satisfy
        ``lo < hi`` (e.g. a constant layer).
    """
    # masked=True turns the file's nodata sentinel (e.g. -9999) into NaN.
    # Without it the sentinel is a finite number: it would enter the
    # percentile computation below and then be clamped to the lower bound,
    # i.e. silently turned into a plausible-looking data value.
    x = rxr.open_rasterio(fp, masked=True).squeeze()
    x = x.rio.reproject_match(ref_da)

    # Keep only pixels where the reference raster has a value.
    x = x.where(ref_da.notnull())

    # Robust percentile clip computed on finite values only.
    vals = x.values
    finite_vals = vals[np.isfinite(vals)]
    if finite_vals.size:
        lo, hi = np.percentile(finite_vals, [1, 99])
        if np.isfinite(lo) and np.isfinite(hi) and lo < hi:
            x = x.clip(min=lo, max=hi)
    return x

# Feature name = file name without directory and extension, in list order.
feature_names = [
    os.path.splitext(os.path.basename(raster_path))[0]
    for raster_path in explanatory_files
]

# Load every explanatory raster through joblib (n_jobs=-1); the returned list
# follows the order of explanatory_files, so it lines up with feature_names.
exps = Parallel(n_jobs=-1)(
    delayed(load_align_clip)(raster_path, dep)
    for raster_path in explanatory_files
)

# ------------------------------------------------
# 4) Utility: normalized RF feature importances
# ------------------------------------------------
def rf_relative_importance(dep_da, exp_list, rnd=42, *, n_estimators=500,
                           min_samples=100, n_jobs=-1):
    """Return normalized Random Forest feature importances.

    The dependent array and every explanatory array are flattened, samples
    with a non-finite value in the target or in any predictor are dropped,
    and a ``RandomForestRegressor`` is fitted on the remainder.

    Parameters
    ----------
    dep_da : object exposing ``.values``
        Dependent variable (e.g. an ``xarray.DataArray``).
    exp_list : sequence of objects exposing ``.values``
        Explanatory variables; each must flatten to the same length as
        ``dep_da``.
    rnd : int, default 42
        ``random_state`` passed to the forest.
    n_estimators : int, default 500
        Number of trees in the forest.
    min_samples : int, default 100
        Minimum number of valid samples required to fit a model.
    n_jobs : int, default -1
        ``n_jobs`` passed to the forest.

    Returns
    -------
    numpy.ndarray
        One float per entry of ``exp_list``, in the same order, summing to 1.
        All-NaN when fewer than ``min_samples`` valid samples remain or when
        the raw importances sum to zero. Empty when ``exp_list`` is empty.
    """
    n_features = len(exp_list)
    if n_features == 0:
        # np.stack raises on an empty sequence; with no predictors there is
        # nothing to rank, so return an empty result instead.
        return np.empty(0, dtype=float)

    dep_vals = dep_da.values.ravel()
    X = np.stack([e.values.ravel() for e in exp_list], axis=1)

    # A sample is usable only if the target and all predictors are finite.
    mask = np.isfinite(dep_vals) & np.all(np.isfinite(X), axis=1)
    Xm, ym = X[mask], dep_vals[mask]

    # Avoid unstable fits on very small samples.
    if ym.size < min_samples:
        return np.full(n_features, np.nan, dtype=float)

    rf = RandomForestRegressor(
        n_estimators=n_estimators,
        n_jobs=n_jobs,
        random_state=rnd,
    )
    rf.fit(Xm, ym)

    imp = np.asarray(rf.feature_importances_, dtype=float)
    total = imp.sum()
    # Guard against an all-zero importance vector before normalizing.
    if total > 0:
        return imp / total
    return np.full_like(imp, np.nan, dtype=float)

# -------------------------------
# 5) Overall relative importance
# -------------------------------
# Importances fitted on every valid pixel, regardless of aridity zone.
overall = rf_relative_importance(dep, exps)

# -----------------------------------------
# 6) Zone-wise relative importance (aridity)
# -----------------------------------------
# One table column per zone: restrict the dependent and explanatory layers to
# the pixels carrying that zone code, then refit.
result = {"Feature": feature_names, "Overall": overall}
for code, name in aridity_map.items():
    zmask = aridity == code
    dep_z = dep.where(zmask)
    exps_z = [e.where(zmask) for e in exps]
    result[name] = rf_relative_importance(dep_z, exps_z)

# ---------------
# 7) Save to CSV
# ---------------
df = pd.DataFrame(result)

# (Optional) sort features by Overall importance for readability.
# Sorting is cosmetic, so a failure must not block the export; it is reported
# instead of being swallowed silently.
try:
    df = df.sort_values(by="Overall", ascending=False)
except (KeyError, TypeError, ValueError) as err:
    print(f"Warning: could not sort by 'Overall' ({err}); saving unsorted.")

out_csv = r"./RF_Importance_AridityZones.csv"
# utf-8-sig writes a BOM at the start of the file.
df.to_csv(out_csv, index=False, encoding="utf-8-sig")
print(f"✅ Saved overall and aridity-zone importances to: {out_csv}")
