# In [None]:
#Fig.7#
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

import geopandas as gpd
from shapely.geometry import box, Point

import cartopy.crs as ccrs
from scipy.stats import wilcoxon

# ===================== Font: Times New Roman (Windows Fonts) =====================
import matplotlib as mpl
from matplotlib import font_manager as fm

# Windows system font folder, reached through the /mnt/c mount (presumably
# WSL); searched for Times New Roman files by setup_times_new_roman.
FONT_DIR = Path("/mnt/c") / "Windows" / "Fonts"

def setup_times_new_roman(font_dir: Path):
    """Register Times New Roman from *font_dir* and make it matplotlib's default font.

    The standard Windows file names (times.ttf, timesbd.ttf, timesi.ttf,
    timesbi.ttf) are tried first. If none exist, any .ttf/.otf/.ttc file whose
    name contains "times" is used instead; name and extension are matched
    case-insensitively.

    Every file found is registered with matplotlib's font manager. The family
    name read from the regular face becomes the global default font (also used
    for mathtext rm/it/bf).

    Parameters
    ----------
    font_dir : Path or str
        Directory holding the font files.

    Returns
    -------
    tuple
        ``(FontProperties, str)``: properties built from the regular face
        file, and the family name read from that file.

    Raises
    ------
    FileNotFoundError
        If *font_dir* is not an existing directory.
    RuntimeError
        If no matching font files are found.
    """
    font_dir = Path(font_dir)  # also accept plain string paths
    if not font_dir.is_dir():
        raise FileNotFoundError(f"FONT_DIR not found: {font_dir}")

    # Standard Windows file names: regular, bold, italic, bold italic.
    candidates = [
        font_dir / fname
        for fname in ("times.ttf", "timesbd.ttf", "timesi.ttf", "timesbi.ttf")
    ]
    existing = [p for p in candidates if p.exists()]

    if not existing:
        # Fallback: any font file whose name contains "times". Path.glob is
        # case-sensitive on POSIX, so e.g. "TIMES.TTF" would be missed by
        # "*.ttf"; compare lower-cased suffixes instead. Sorting makes the
        # choice of existing[0] (the "regular" face below) deterministic.
        font_exts = {".ttf", ".otf", ".ttc"}
        existing = sorted(
            p for p in font_dir.iterdir()
            if p.suffix.lower() in font_exts and "times" in p.name.lower()
        )

    if not existing:
        raise RuntimeError(f"No Times New Roman font files found under: {font_dir}")

    for fp in existing:
        try:
            fm.fontManager.addfont(str(fp))
        except Exception as e:
            # Best effort: one unreadable file should not abort the setup.
            print(f"[WARN] failed to add font {fp}: {e}")

    regular = candidates[0] if candidates[0].exists() else existing[0]
    tnr_prop = fm.FontProperties(fname=str(regular))
    tnr_name = tnr_prop.get_name()

    mpl.rcParams["font.family"] = tnr_name
    mpl.rcParams["font.sans-serif"] = [tnr_name]
    mpl.rcParams["axes.unicode_minus"] = False  # ASCII hyphen for minus signs

    # Render mathtext with the same family so formulas match the text.
    mpl.rcParams["mathtext.fontset"] = "custom"
    mpl.rcParams["mathtext.rm"] = tnr_name
    mpl.rcParams["mathtext.it"] = f"{tnr_name}:italic"
    mpl.rcParams["mathtext.bf"] = f"{tnr_name}:bold"

    print(f"[INFO] Using font: {tnr_name}")
    print(f"[INFO] Registered font files: {[p.name for p in existing]}")
    return tnr_prop, tnr_name

# Register Times New Roman and make it the global default font; the returned
# FontProperties is also passed explicitly to map labels and legend text.
TNR_PROP, TNR_NAME = setup_times_new_roman(font_dir=FONT_DIR)

# Separate font for the ⬆/⬇ glyphs, presumably because the Times New Roman
# files lack them (TODO confirm).
ARROW_PROP = fm.FontProperties(family=["DejaVu Sans"])

# ===================== 1) Config =====================
# Per-basin validation metrics, one CSV per model (file names suggest lead
# time 0). load_two_models reads a `basin_id` column from each, and
# region_stats_wilcoxon needs one column per name in METRICS_FOR_TEST.
MODEL0_CSV = "/mnt/d/desktop/paper_data/01/model_data/case0/val_metrics_lead0.csv"
MODEL2_CSV = "/mnt/d/desktop/paper_data/01/model_data/case2/val_metrics_lead0.csv"
MODEL3_CSV = "/mnt/d/desktop/paper_data/01/model_data/case3/val_metrics_lead0.csv"

# Display names used in legend labels, log messages and output file names.
MODEL0_NAME = "Att-LSTM"
MODEL2_NAME = "MID-CNN-LSTM"
MODEL3_NAME = "HIGH-CNN-LSTM"

# Each entry is (name_A, csv_A, name_B, csv_B). Every pair is tested as
# "B vs A", i.e. on delta = metric_B - metric_A per basin.
COMPARE_PAIRS = [
    (MODEL0_NAME, MODEL0_CSV, MODEL2_NAME, MODEL2_CSV),
    (MODEL0_NAME, MODEL0_CSV, MODEL3_NAME, MODEL3_CSV),
    (MODEL2_NAME, MODEL2_CSV, MODEL3_NAME, MODEL3_CSV),
]

# Metrics compared for every pair (columns "<metric>_A" / "<metric>_B" after merging).
METRICS_FOR_TEST = ["NSE", "KGE"]

# CAMELS gauge table; read_camels_topo uses its gauge_id, gauge_lat and gauge_lon columns.
CAMELS_TOPO = "/mnt/d/desktop/paper_data/01/static_attri/camels_topo.txt"
# Directory holding one "<name>.shp" per entry of SUBREGION_NAMES.
REGION_SHPS_DIR = "/mnt/d/desktop/paper_data/01/shp_outputs"

# Subregions to load; missing shapefiles are reported and skipped.
SUBREGION_NAMES = [
    "NorthAtlantic", "MidAtlantic", "Appalachia", "Southeast", "DeepSouth",
    "Prairie", "GreatLakes", "NPlains", "NRockies", "SRockies",
    "CPlains", "SPlains", "Southwest", "GreatBasin", "PacificSW", "PacificNW",
    "Mezquital"
]

# USA outline shapefile; clipped to US_EXTENT to obtain the CONUS background.
USA_SHP = "/mnt/d/desktop/paper_data/01/gadm41_USA_shp/gadm41_USA_0.shp"
US_EXTENT = [-125, -66, 24, 50]   # lon_min, lon_max, lat_min, lat_max

# Gauge ids are zero-padded to this width so that ids parsed as integers
# (leading zeros lost) still match across files.
BASIN_ZFILL = 8

# Output directory for the stats CSVs and map PNGs; created at import time.
OUT_DIR = "/mnt/d/desktop/paper_data/01/paper_output/region_wk_maps2"
os.makedirs(OUT_DIR, exist_ok=True)

# Significance level for the two-sided Wilcoxon signed-rank test.
ALPHA = 0.1

# Region fill colours: red = B significantly greater than A, blue = B
# significantly smaller, yellow = not significant, gray = no result.
COLOR_RED = "#f26d6d"
COLOR_BLUE = "#7fb6d6"
COLOR_YELLOW = "#fff7cc"
COLOR_GRAY = "#9e9e9e"

# ===================== Font sizes (points) =====================
MAP_LABEL_FONTSIZE = 10      # region name labels on the map
MAP_ARROW_FONTSIZE = 10      # "(⬆)" / "(⬇)" median-direction markers on the map
LEGEND_FONTSIZE = 10         # legend label text
LEGEND_ARROW_FONTSIZE = 10   # arrow glyphs used as legend handles

# Default resolution for new figures and saved images.
plt.rcParams.update({
    "figure.dpi": 300,
    "savefig.dpi": 300,
})

# Per-region (dx, dy) offset, in points, of the region name label on the map;
# regions not listed use (0, 0). The "(⬆)"/"(⬇)" marker drawn below the
# label is NOT shifted by these offsets.
LABEL_OFFSET = {
    # "RegionName": (dx, dy),
    "PacificNW":   (0,  0),
    "PacificSW":   (0,  0),
    "NorthAtlantic": (0,  0),
    "MidAtlantic":   (0,  -0),
    "GreatLakes":  (0,   0),
    "Prairie":     (0,  -0),
    "NPlains":     (0,   0),
    "NRockies":    (0,   0),
    "SRockies":    (0,  -0),
    "CPlains":     (0,   0),
    "SPlains":     (0,  -0),
    "Southeast":   (0, 10),
    "DeepSouth":   (0,   0),
    "Appalachia":  (0,  0),
    "GreatBasin":  (0,   0),
    "Southwest":   (0,  0),
    # "Mezquital": (0,0)  -- Mezquital is not labelled on the map.
}

# ===================== 2) CAMELS topo: gauge coordinates =====================
def read_camels_topo(path: str, zfill=None) -> pd.DataFrame:
    """Read CAMELS gauge coordinates and return zero-padded basin ids with lat/lon.

    Parameters
    ----------
    path : str
        Path to the CAMELS topo table. Fields may be separated by ``;``,
        ``,`` or whitespace; whitespace around ``;``/``,`` is tolerated.
    zfill : int, optional
        Width to zero-pad gauge ids to. Defaults to the module-level
        ``BASIN_ZFILL``.

    Returns
    -------
    pandas.DataFrame
        Columns ``basin_id_str`` (zero-padded string), ``lat`` and ``lon``
        (from ``gauge_lat`` / ``gauge_lon``).
    """
    width = BASIN_ZFILL if zfill is None else zfill
    # The first alternative swallows whitespace around a ';' or ',' so that
    # "a; b" splits into two fields instead of three (with an empty middle one).
    topo = pd.read_csv(
        path,
        sep=r"\s*[,;]\s*|\s+",
        engine="python",
        dtype={"gauge_id": str},  # keep ids as text; never parse them as numbers
    )
    topo["basin_id_str"] = topo["gauge_id"].astype(str).str.strip().str.zfill(width)
    return topo[["basin_id_str", "gauge_lat", "gauge_lon"]].rename(
        columns={"gauge_lat": "lat", "gauge_lon": "lon"}
    )

# ===================== 3) Load and merge two models' metrics =====================
def load_two_models(a_csv: str, b_csv: str, zfill=None) -> pd.DataFrame:
    """Inner-join the per-basin metric tables of two models on basin id.

    Parameters
    ----------
    a_csv, b_csv : str
        CSV files of model A and model B; each needs a ``basin_id`` column.
    zfill : int, optional
        Width to zero-pad basin ids to. Defaults to the module-level
        ``BASIN_ZFILL``.

    Returns
    -------
    pandas.DataFrame
        Only basins present in both files, keyed by ``basin_id_str``. Columns
        shared by both inputs get the suffixes ``_A`` / ``_B``
        (e.g. ``NSE_A``, ``NSE_B``).
    """
    width = BASIN_ZFILL if zfill is None else zfill

    def _read(csv_path):
        # Read ids as text: a numeric parse would drop leading zeros, and a
        # single blank id would make the column float ("1013500.0") so that
        # keys never match across files.
        frame = pd.read_csv(csv_path, dtype={"basin_id": str})
        frame["basin_id_str"] = frame["basin_id"].astype(str).str.strip().str.zfill(width)
        return frame

    return _read(a_csv).merge(
        _read(b_csv), on="basin_id_str", suffixes=("_A", "_B"), how="inner"
    )

# ===================== 4) CONUS geometry =====================
def load_conus_geometry(usa_shp: str, extent):
    """Return the USA outline dissolved into one geometry and clipped to *extent*.

    Parameters
    ----------
    usa_shp : str
        Path to the outline shapefile.
    extent : sequence of float
        ``[lon_min, lon_max, lat_min, lat_max]`` in degrees (same order as
        ``US_EXTENT``).

    Returns
    -------
    shapely geometry
        Union of all features in EPSG:4326, intersected with the extent box.
    """
    usa = gpd.read_file(usa_shp)
    # A file without CRS metadata is assumed to already be lon/lat.
    usa = usa.set_crs("EPSG:4326") if usa.crs is None else usa.to_crs("EPSG:4326")
    geoms = usa.geometry
    # union_all() replaces the deprecated unary_union (geopandas >= 1.0);
    # fall back for older geopandas releases.
    usa_geom = geoms.union_all() if hasattr(geoms, "union_all") else geoms.unary_union
    lon_min, lon_max, lat_min, lat_max = extent[0], extent[1], extent[2], extent[3]
    bbox = box(lon_min, lat_min, lon_max, lat_max)
    return usa_geom.intersection(bbox)

# ===================== 5) Subregion =====================
def load_subregions(shp_dir: str, names: list, conus_geom):
    """Load one polygon per named subregion, clipped to the CONUS outline.

    For each name, reads ``<shp_dir>/<name>.shp`` and puts it in EPSG:4326,
    assigning that CRS when the file has none. All its features are dissolved
    into one geometry, which is then intersected with *conus_geom*. Missing
    files and regions whose clipped geometry is empty are reported and skipped.

    Parameters
    ----------
    shp_dir : str
        Directory containing the per-region shapefiles.
    names : list of str
        Region names, also used as shapefile stems.
    conus_geom : shapely geometry
        Lon/lat clipping geometry.

    Returns
    -------
    GeoDataFrame
        Columns ``region`` and ``geometry`` (EPSG:4326), in the order of
        *names*. It has these columns even when no region is kept.
    """
    shp_dir = Path(shp_dir)
    regions, geoms = [], []
    for name in names:
        shp = shp_dir / f"{name}.shp"
        if not shp.exists():
            print(f"[WARN] Missing subregion shp: {shp}")
            continue

        gdf = gpd.read_file(shp)
        gdf = gdf.set_crs("EPSG:4326") if gdf.crs is None else gdf.to_crs("EPSG:4326")

        parts = gdf.geometry
        # union_all() replaces the deprecated unary_union (geopandas >= 1.0).
        dissolved = parts.union_all() if hasattr(parts, "union_all") else parts.unary_union
        geom = dissolved.intersection(conus_geom)
        if geom.is_empty:
            print(f"[INFO] {name}: intersection with CONUS is empty, skip.")
            continue

        regions.append(name)
        geoms.append(geom)

    # Name the geometry column explicitly. Building from an empty list of
    # dicts would produce a frame without a geometry column, which cannot
    # take a CRS.
    return gpd.GeoDataFrame(
        {"region": regions, "geometry": geoms}, geometry="geometry", crs="EPSG:4326"
    )

# ===================== 6) Assign basins to subregions =====================
def assign_basins_to_regions(df_with_lonlat: pd.DataFrame, sub_gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
    """Tag every basin with the subregion whose polygon contains its gauge.

    Parameters
    ----------
    df_with_lonlat : DataFrame
        Basin rows with ``lon`` and ``lat`` columns (degrees); the input is
        copied, not modified.
    sub_gdf : GeoDataFrame
        Subregion polygons with a ``region`` column, in EPSG:4326.

    Returns
    -------
    GeoDataFrame
        The basin rows with point geometries, a ``region`` column and the
        join index added by ``gpd.sjoin``. This is a left join with
        ``predicate="within"``, so basins inside no subregion keep a missing
        ``region``.
    """
    gauge_points = [
        Point(lon, lat)
        for lon, lat in zip(df_with_lonlat["lon"], df_with_lonlat["lat"])
    ]
    basins = gpd.GeoDataFrame(df_with_lonlat.copy(), geometry=gauge_points, crs="EPSG:4326")
    polygons = sub_gdf[["region", "geometry"]]
    return gpd.sjoin(basins, polygons, how="left", predicate="within")

# ===================== 7) Per-region statistics + Wilcoxon signed-rank test =====================
def region_stats_wilcoxon(joined: pd.DataFrame, metric: str, alpha: float) -> pd.DataFrame:
    """Compare model B against model A per region with a Wilcoxon signed-rank test.

    Rows whose ``region`` is missing are ignored. Within each region, basins
    where both ``{metric}_A`` and ``{metric}_B`` are finite are paired, and
    ``delta = B - A`` is analysed.

    Parameters
    ----------
    joined : DataFrame
        Basin rows with a ``region`` column and ``{metric}_A`` /
        ``{metric}_B`` columns. The GeoDataFrame from
        ``assign_basins_to_regions`` works as well.
    metric : str
        Metric name prefix, e.g. ``"NSE"``.
    alpha : float
        Significance level for the two-sided test.

    Returns
    -------
    DataFrame
        One row per region, sorted by region, with these columns:

        - ``region``
        - ``n``: number of paired basins
        - ``prob_B_gt_A`` / ``prob_B_lt_A``: share of basins with delta > 0 / < 0
        - ``median_delta``
        - ``pvalue``
        - ``sig``
        - ``direction``: ``"B_gt_A"`` / ``"B_lt_A"`` when p < alpha and the
          median delta is non-zero; ``"NS"`` otherwise, including a
          significant p with a zero median; ``"NA"`` when no test was run
          (no pairs, fewer than 3 non-zero deltas, or a SciPy error).

        Empty input yields an empty frame with these columns.

    Raises
    ------
    ValueError
        If either metric column is missing.
    """
    colA = f"{metric}_A"
    colB = f"{metric}_B"
    if colA not in joined.columns or colB not in joined.columns:
        raise ValueError(f"Missing metric columns: {colA} / {colB}. Check your CSV columns.")

    columns = [
        "region", "n", "prob_B_gt_A", "prob_B_lt_A",
        "median_delta", "pvalue", "sig", "direction",
    ]
    out_rows = []
    for region, g in joined.groupby("region", dropna=True):
        xA = g[colA].astype(float).to_numpy()
        xB = g[colB].astype(float).to_numpy()
        paired = np.isfinite(xA) & np.isfinite(xB)
        delta = xB[paired] - xA[paired]
        n = int(delta.size)

        if n == 0:
            out_rows.append({
                "region": region, "n": 0,
                "prob_B_gt_A": np.nan, "prob_B_lt_A": np.nan,
                "median_delta": np.nan, "pvalue": np.nan,
                "sig": False, "direction": "NA",
            })
            continue

        prob_gt = float(np.mean(delta > 0))
        prob_lt = float(np.mean(delta < 0))
        med = float(np.median(delta))  # delta holds only finite values

        pval = np.nan
        sig = False
        direction = "NA"

        # zero_method="wilcox" discards zero deltas; require at least 3
        # non-zero deltas before running the test.
        if np.count_nonzero(delta) >= 3:
            try:
                # Do not pass `mode=`: SciPy 1.9 renamed it to `method` and
                # later releases removed it, so passing it raises TypeError.
                # The default ("auto") is what was requested anyway.
                result = wilcoxon(delta, zero_method="wilcox", alternative="two-sided")
                pval = float(result.pvalue)
            except Exception as exc:
                # Best effort per region, but never silently.
                print(f"[WARN] Wilcoxon failed for region {region} ({metric}): {exc}")
                pval = np.nan
            else:
                sig = bool(pval < alpha and np.isfinite(med) and med != 0)
                if sig:
                    direction = "B_gt_A" if med > 0 else "B_lt_A"
                else:
                    direction = "NS"

        out_rows.append({
            "region": region,
            "n": n,
            "prob_B_gt_A": prob_gt,
            "prob_B_lt_A": prob_lt,
            "median_delta": med if np.isfinite(med) else np.nan,
            "pvalue": pval if np.isfinite(pval) else np.nan,
            "sig": sig,
            "direction": direction,
        })

    # Explicit columns keep sort_values("region") valid when no region exists.
    return (
        pd.DataFrame(out_rows, columns=columns)
        .sort_values("region")
        .reset_index(drop=True)
    )

# ===================== 8) Plot =====================
from matplotlib.patches import Patch
from matplotlib.text import Text
from matplotlib.legend_handler import HandlerBase

class HandlerText(HandlerBase):
    """Legend handler that draws a ``Text`` handle as a centred glyph.

    Lets a plain ``matplotlib.text.Text`` (e.g. an arrow character) serve as a
    legend marker. The handle's text, colour, weight, size and font properties
    are copied onto a new Text placed in the middle of the handle box.
    """

    def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height,
                       fontsize, trans):
        # The legend's font size is used only if the handle reports a falsy size.
        size = orig_handle.get_fontsize() or fontsize
        centre_x = xdescent + 0.5 * width
        centre_y = ydescent + 0.5 * height

        glyph = Text(
            x=centre_x,
            y=centre_y,
            text=orig_handle.get_text(),
            color=orig_handle.get_color(),
            fontweight=orig_handle.get_fontweight(),
            fontsize=size,
            ha="center",
            va="center",
        )

        # Carry over the handle's font (e.g. DejaVu Sans for arrow glyphs).
        handle_font = orig_handle.get_fontproperties()
        if handle_font is not None:
            glyph.set_fontproperties(handle_font)

        glyph.set_transform(trans)
        return [glyph]

def _region_facecolor(row) -> str:
    """Return the fill colour for one merged region row of plot_region_map.

    Gray when the region has no stats, no basins or no test result ("NA").
    Yellow for a non-significant result ("NS"). Red / blue for a significant
    B > A / B < A.
    """
    n = row.get("n")
    if pd.isna(n) or int(n) == 0 or row.get("direction") == "NA":
        return COLOR_GRAY
    return {
        "NS": COLOR_YELLOW,
        "B_gt_A": COLOR_RED,
        "B_lt_A": COLOR_BLUE,
    }.get(row.get("direction"), COLOR_GRAY)


def _median_arrow(med):
    """Return (glyph, colour) for the sign of a median delta; (None, None) if 0 or NaN."""
    if np.isfinite(med):
        if med > 0:
            return "⬆", "#F90105"
        if med < 0:
            return "⬇", "#0099FF"
    return None, None


def _annotate_lonlat(ax, text, lon, lat, offset, *, fontsize, color, zorder, fontproperties):
    """Draw bold, top-anchored, centred *text* at (lon, lat), shifted by *offset* points."""
    ax.annotate(
        text=text,
        xy=(lon, lat),
        xycoords=ccrs.PlateCarree(),
        xytext=offset,
        textcoords="offset points",
        ha="center", va="top",
        fontsize=fontsize,
        fontweight="bold",
        color=color,
        zorder=zorder,
        fontproperties=fontproperties,
    )


def plot_region_map(sub_gdf, conus_geom, stats_df, metric, nameA, nameB, alpha, out_dir):
    """Draw a CONUS map of per-region test results and save it as a PNG.

    Each subregion is filled according to its test outcome (see
    ``_region_facecolor``) and labelled with its name. When the median delta
    is non-zero, a coloured "(⬆)" or "(⬇)" marker is drawn below the label.
    The "Mezquital" region is filled but not labelled.

    Parameters
    ----------
    sub_gdf : GeoDataFrame
        Subregion polygons with a ``region`` column (lon/lat geometries).
    conus_geom : shapely geometry
        CONUS outline in lon/lat, drawn as the white background.
    stats_df : DataFrame
        Output of ``region_stats_wilcoxon``; left-merged onto *sub_gdf* by
        ``region``.
    metric, nameA, nameB : str
        Used in the legend labels and in the output file name.
    alpha : float
        Significance level shown in the legend labels.
    out_dir : str
        Directory that receives ``region_wk_{metric}_{nameB}_vs_{nameA}.png``.
    """
    # Lambert Conformal projection centred on CONUS.
    proj = ccrs.LambertConformal(
        central_longitude=-96,
        central_latitude=39,
        standard_parallels=(33, 45),
    )

    m = sub_gdf.merge(stats_df, on="region", how="left")
    # A comprehension (unlike DataFrame.apply(axis=1)) also works when m is empty.
    m["facecolor"] = [_region_facecolor(row) for _, row in m.iterrows()]

    fig = plt.figure(figsize=(7, 5.4), dpi=300)
    ax = fig.add_axes([0.02, 0.04, 0.96, 0.92], projection=proj)
    ax.set_extent(US_EXTENT, crs=ccrs.PlateCarree())  # extent is given in lon/lat

    # Background CONUS outline, then the filled subregions on top of it.
    ax.add_geometries([conus_geom], crs=ccrs.PlateCarree(),
                      facecolor="white", edgecolor="black", linewidth=1.2, zorder=1)
    for _, r in m.iterrows():
        ax.add_geometries([r.geometry], crs=ccrs.PlateCarree(),
                          facecolor=r["facecolor"], edgecolor="black",
                          linewidth=1.0, zorder=2)

    # The "(", arrow, ")" marker sits marker_yoff points below the anchor
    # point, with the parentheses paren_dx points to either side. It does
    # not follow the label's LABEL_OFFSET.
    marker_yoff = -10
    paren_dx = 6
    for _, r in m.iterrows():
        reg = str(r["region"])
        if reg.strip() == "Mezquital":
            continue

        pt = r.geometry.representative_point()  # point inside the polygon, lon/lat
        x, y = pt.x, pt.y

        _annotate_lonlat(
            ax, reg, x, y, LABEL_OFFSET.get(reg, (0, 0)),
            fontsize=MAP_LABEL_FONTSIZE, color="black", zorder=5,
            fontproperties=TNR_PROP,
        )

        arrow, arrow_color = _median_arrow(r.get("median_delta", np.nan))
        if arrow is None:
            continue
        for glyph, gx, glyph_prop in (
            ("(", -paren_dx, TNR_PROP),
            (arrow, 0, ARROW_PROP),  # arrow glyphs come from ARROW_PROP's font
            (")", paren_dx, TNR_PROP),
        ):
            _annotate_lonlat(
                ax, glyph, x, y, (gx, marker_yoff),
                fontsize=MAP_ARROW_FONTSIZE, color=arrow_color, zorder=4,
                fontproperties=glyph_prop,
            )

    # ---------------- legend (inside map) ----------------
    legend_prop = fm.FontProperties(fname=TNR_PROP.get_file(), size=LEGEND_FONTSIZE)
    up_text = Text(
        0, 0, "⬆",
        color="#F90105",
        fontweight="bold",
        fontproperties=ARROW_PROP,
        fontsize=LEGEND_ARROW_FONTSIZE,
    )
    down_text = Text(
        0, 0, "⬇",
        color="#0099FF",
        fontweight="bold",
        fontproperties=ARROW_PROP,
        fontsize=LEGEND_ARROW_FONTSIZE,
    )

    legend_handles = [
        up_text,
        down_text,
        Patch(facecolor=COLOR_GRAY, edgecolor="black"),
        Patch(facecolor=COLOR_YELLOW, edgecolor="black"),
        Patch(facecolor=COLOR_BLUE, edgecolor="black"),
        Patch(facecolor=COLOR_RED, edgecolor="black"),
    ]
    legend_labels = [
        f"Median(Δ{metric}) > 0",
        f"Median(Δ{metric}) < 0",
        "N/A",
        f"Not significant (p≥{alpha})",
        f"{nameB} < {nameA} (p<{alpha})",
        f"{nameB} > {nameA} (p<{alpha})",
    ]

    leg = ax.legend(
        handles=legend_handles,
        labels=legend_labels,
        handler_map={Text: HandlerText()},  # draw Text handles as glyphs
        loc="lower left",
        bbox_to_anchor=(-0.05, -0.05),
        frameon=False,
        prop=legend_prop,
        borderaxespad=0.8,
    )
    # Force Times New Roman on every label. set_fontproperties copies the
    # properties, so sharing one object is safe.
    for t in leg.get_texts():
        t.set_fontproperties(legend_prop)

    out_png = os.path.join(out_dir, f"region_wk_{metric}_{nameB}_vs_{nameA}.png")
    ax.set_axis_off()
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved map: {out_png}")

# ===================== 9) main =====================
def main():
    """Run every model pair in COMPARE_PAIRS, writing per-region stats CSVs and maps.

    Gauge coordinates, the CONUS outline and the subregion polygons are
    loaded once and shared by all pairs and metrics.
    """
    topo = read_camels_topo(CAMELS_TOPO)
    conus_geom = load_conus_geometry(USA_SHP, US_EXTENT)
    sub_gdf = load_subregions(REGION_SHPS_DIR, SUBREGION_NAMES, conus_geom)

    for pair in COMPARE_PAIRS:
        name_a, csv_a, name_b, csv_b = pair

        # Keep only basins present in both model CSVs and in the topo table.
        basins = load_two_models(csv_a, csv_b).merge(topo, on="basin_id_str", how="inner")
        print(f"[INFO] {name_b} vs {name_a}: matched basins with topo: {len(basins)}")

        joined = assign_basins_to_regions(basins, sub_gdf)

        for metric in METRICS_FOR_TEST:
            stats_df = region_stats_wilcoxon(joined, metric, ALPHA)

            csv_path = os.path.join(OUT_DIR, f"region_stats_wk_{metric}_{name_b}_vs_{name_a}.csv")
            stats_df.to_csv(csv_path, index=False)
            print(f"Saved stats CSV: {csv_path}")

            plot_region_map(sub_gdf, conus_geom, stats_df, metric, name_a, name_b, ALPHA, OUT_DIR)

    print("[DONE]")


if __name__ == "__main__":
    main()
