# In [None]:
# Fig. 6
import os  
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from matplotlib.colors import Normalize


import cartopy.crs as ccrs
import geopandas as gpd
from shapely.geometry import box

# ===================== Font: Times New Roman (Windows Fonts) =====================
import matplotlib as mpl
from matplotlib import font_manager as fm

# Windows font directory as seen from WSL; searched for the Times New Roman files.
FONT_DIR = Path("/mnt/c", "Windows", "Fonts")

def setup_times_new_roman(font_dir: Path):
    """
    Register Times New Roman font files from ``font_dir`` with matplotlib and
    make that family the global default for regular text and mathtext.

    Files are looked up in this order:
      1. the standard Windows names times.ttf / timesbd.ttf / timesi.ttf /
         timesbi.ttf (regular, bold, italic, bold italic);
      2. only if none of those exist, any *.ttf / *.otf / *.ttc file whose
         name contains "times" (case-insensitive).

    Returns:
        (FontProperties, str): properties built from times.ttf (or from the
        first file found when times.ttf is absent) and the family name
        matplotlib reports for that file.

    Raises:
        FileNotFoundError: ``font_dir`` does not exist.
        RuntimeError: no matching font file was found.
    """
    if not font_dir.exists():
        raise FileNotFoundError(f"FONT_DIR not found: {font_dir}")

    standard_names = ("times.ttf", "timesbd.ttf", "timesi.ttf", "timesbi.ttf")
    regular_path = font_dir / standard_names[0]
    found = [font_dir / fname for fname in standard_names if (font_dir / fname).exists()]

    if not found:
        # Non-standard install: accept any font file with "times" in its name.
        found = [
            path
            for pattern in ("*.ttf", "*.otf", "*.ttc")
            for path in font_dir.glob(pattern)
            if "times" in path.name.lower()
        ]

    if not found:
        raise RuntimeError(f"No Times New Roman font files found under: {font_dir}")

    for font_file in found:
        try:
            fm.fontManager.addfont(str(font_file))
        except Exception as err:
            # Best effort: one unreadable file should not abort the whole setup.
            print(f"[WARN] failed to add font {font_file}: {err}")

    primary = regular_path if regular_path.exists() else found[0]
    props = fm.FontProperties(fname=str(primary))
    family = props.get_name()  # usually "Times New Roman"

    rc = mpl.rcParams
    rc["font.family"] = family
    rc["font.sans-serif"] = [family]
    rc["axes.unicode_minus"] = False
    # Math text uses the same family for its roman, italic and bold glyphs.
    rc["mathtext.fontset"] = "custom"
    rc["mathtext.rm"] = family
    rc["mathtext.it"] = f"{family}:italic"
    rc["mathtext.bf"] = f"{family}:bold"

    print(f"[INFO] Using font: {family}")
    print(f"[INFO] Registered font files: {[p.name for p in found]}")
    return props, family

# Register the font files and make Times New Roman matplotlib's default family.
# Runs at import time, so FONT_DIR must exist before this module is loaded
# (setup_times_new_roman raises FileNotFoundError otherwise).
# NOTE(review): TNR_PROP is created before the font.size=16 rcParams update
# further down; FontProperties appears to store a size at construction, so
# set_fontproperties(TNR_PROP) may also reset text size — confirm.
TNR_PROP, TNR_NAME = setup_times_new_roman(FONT_DIR)


# ===================== 1) Config =====================


# Root directory shared by every input path in this section: edit this one
# line to relocate the inputs instead of touching each path individually.
DATA_ROOT = "/mnt/d/desktop/paper_data/01"

# Per-basin validation metric tables (val_metrics_lead0.csv) for each model case.
MODEL0_CSV = f"{DATA_ROOT}/model_data/case0/val_metrics_lead0.csv"
MODEL2_CSV = f"{DATA_ROOT}/model_data/case2/val_metrics_lead0.csv"
MODEL3_CSV = f"{DATA_ROOT}/model_data/case3/val_metrics_lead0.csv"

# Display names; used in map titles and output file names.
MODEL0_NAME = "Attr-LSTM"
MODEL2_NAME = "MID-CNN-LSTM"
MODEL3_NAME = "HIGH-CNN-LSTM"

# Each entry is (name_A, csv_A, name_B, csv_B); the skill of B relative to
# the reference A is  skill = (B - A) / (1 - A).
COMPARE_PAIRS = [
    (MODEL0_NAME, MODEL0_CSV, MODEL2_NAME, MODEL2_CSV),  # Case2 vs Case0
    (MODEL0_NAME, MODEL0_CSV, MODEL3_NAME, MODEL3_CSV),  # Case3 vs Case0
    (MODEL2_NAME, MODEL2_CSV, MODEL3_NAME, MODEL3_CSV),  # Case3 vs Case2
]

# Metric columns (present in both CSVs of a pair) to map.
METRICS_TO_PLOT = ["NSE", "KGE"]

# CAMELS topography table; read_camels_topo takes gauge_id / gauge_lat / gauge_lon from it.
CAMELS_TOPO = f"{DATA_ROOT}/static_attri/camels_topo.txt"

# Directory holding one <name>.shp per entry of SUBREGION_NAMES.
REGION_SHPS_DIR = f"{DATA_ROOT}/shp_outputs"

SUBREGION_NAMES = [
    "NorthAtlantic", "MidAtlantic", "Appalachia", "Southeast", "DeepSouth",
    "Prairie", "GreatLakes", "NPlains", "NRockies", "SRockies",
    "CPlains", "SPlains", "Southwest", "GreatBasin", "PacificSW", "PacificNW",
    "Mezquital"
]

# Fixed (vmin, vmax) colour limits per metric, shared by the maps and colour bars.
CBAR_RANGES = {
    "NSE": (-0.4, 0.4),
    "KGE": (-0.4, 0.4),
}
CBAR_CMAP = "RdBu_r"

# USA boundary shapefile; clipped to US_EXTENT to obtain the CONUS outline.
USA_SHP = f"{DATA_ROOT}/gadm41_USA_shp/gadm41_USA_0.shp"

# Clip box as [lon_min, lon_max, lat_min, lat_max] in degrees (EPSG:4326).
US_EXTENT = [-125, -66, 24, 50]

# Basin IDs are zero-padded to this many characters before tables are joined.
BASIN_ZFILL = 8

# Destination for every map and colour bar. It is created up front, at import
# time, so the later savefig calls always have an existing folder.
OUT_DIR = "/mnt/d/desktop/paper_data/01/paper_output"
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

# Marker size passed as ``s=`` to the gauge scatter in plot_maps.
POINT_SIZE = 35

# Global defaults: 16 pt base text, 300 dpi both on screen and when saving.
plt.rcParams["font.size"] = 16
plt.rcParams["figure.dpi"] = 300
plt.rcParams["savefig.dpi"] = 300


# ===================== 2) topo =====================
def read_camels_topo(path: str, zfill_width=None) -> pd.DataFrame:
    """
    Load gauge coordinates from the CAMELS topography table.

    The file may be delimited by whitespace, commas or semicolons.

    Parameters
    ----------
    path : path of the topography file; must contain ``gauge_id``,
        ``gauge_lat`` and ``gauge_lon`` columns.
    zfill_width : number of characters to zero-pad gauge IDs to; defaults to
        ``BASIN_ZFILL``.

    Returns
    -------
    DataFrame with columns ``basin_id_str`` (zero-padded ``gauge_id``),
    ``lat`` and ``lon``.
    """
    width = BASIN_ZFILL if zfill_width is None else zfill_width
    # Read IDs as text. Leading zeros then survive as-is, and a float-typed
    # column cannot turn "1013500" into an unpaddable "1013500.0".
    topo = pd.read_csv(path, sep=r"\s+|,|;", engine="python", dtype={"gauge_id": str})
    topo["basin_id_str"] = topo["gauge_id"].astype(str).str.strip().str.zfill(width)
    return topo[["basin_id_str", "gauge_lat", "gauge_lon"]].rename(
        columns={"gauge_lat": "lat", "gauge_lon": "lon"}
    )


# ===================== 3) skill =====================
def build_skill_df(a_csv: str, b_csv: str, metrics: list, zfill_width=None) -> pd.DataFrame:
    """
    Join two per-basin metric tables and add a relative skill score per metric.

    For each metric ``m`` the new column ``skill_<m>`` is
    ``(m_B - m_A) / (1 - m_A)``. Rows where ``m_A == 1`` get NaN instead of
    a division by zero.

    Parameters
    ----------
    a_csv, b_csv : paths of the reference (A) and compared (B) CSV files;
        each needs a ``basin_id`` column plus the metric columns.
    metrics : metric column names expected in both files.
    zfill_width : number of characters to zero-pad basin IDs to; defaults to
        ``BASIN_ZFILL``.

    Returns
    -------
    Inner join on ``basin_id_str`` (basins present in both files). Columns
    that exist in both inputs carry ``_A`` / ``_B`` suffixes.

    Raises
    ------
    ValueError : a metric is missing from either input.
    """
    width = BASIN_ZFILL if zfill_width is None else zfill_width

    def _load(csv_path):
        # Read IDs as text: one blank cell would otherwise make pandas parse
        # the whole column as float ("1013500.0"), which zfill cannot pad.
        df = pd.read_csv(csv_path, dtype={"basin_id": str})
        df["basin_id_str"] = df["basin_id"].astype(str).str.strip().str.zfill(width)
        return df

    da = _load(a_csv)
    db = _load(b_csv)

    merged = da.merge(db, on="basin_id_str", suffixes=("_A", "_B"), how="inner")
    print(f"[INFO] common basins between A and B: {len(merged)}")

    for m in metrics:
        col_a, col_b = f"{m}_A", f"{m}_B"
        if col_a not in merged.columns or col_b not in merged.columns:
            raise ValueError(f"Missing metric {m} in merged dataframe columns.")
        ref = merged[col_a]
        merged[f"skill_{m}"] = (merged[col_b] - ref) / (1 - ref).replace(0, np.nan)

    return merged


# ===================== 4) US geometry =====================
def load_conus_geometry(usa_shp: str, extent):
    """
    Return the USA outline clipped to a lon/lat box.

    Parameters
    ----------
    usa_shp : path to a polygon shapefile; reprojected to EPSG:4326 first.
    extent : sequence ``[lon_min, lon_max, lat_min, lat_max]`` in degrees.

    Returns
    -------
    shapely geometry: the union of all features intersected with the box.
    """
    usa = gpd.read_file(usa_shp).to_crs("EPSG:4326")
    geoms = usa.geometry
    # GeoPandas >= 1.0 deprecates ``unary_union`` in favour of ``union_all()``.
    # Prefer the new method when it exists; fall back on older releases.
    usa_geom = geoms.union_all() if hasattr(geoms, "union_all") else geoms.unary_union
    # shapely.box takes (minx, miny, maxx, maxy), whereas extent is ordered
    # (lon_min, lon_max, lat_min, lat_max).
    bbox = box(extent[0], extent[2], extent[1], extent[3])
    return usa_geom.intersection(bbox)


# ===================== 5) subregion =====================
def load_subregions(shp_dir: str, names: list, conus_geom):
    """
    Load each ``<shp_dir>/<name>.shp``, dissolve it and clip it to ``conus_geom``.

    Missing files, and regions whose clipped geometry is empty, are skipped
    with a printed message.

    Returns
    -------
    list of ``{"name": name, "geom": geometry}`` dicts, in ``names`` order.
    """
    base = Path(shp_dir)
    out = []
    for name in names:
        shp = base / f"{name}.shp"
        if not shp.exists():
            print(f"[WARN] missing subregion shp: {shp}")
            continue
        geoms = gpd.read_file(shp).to_crs("EPSG:4326").geometry
        # GeoPandas >= 1.0 deprecates ``unary_union``. Use ``union_all()``
        # when available and fall back for older releases.
        dissolved = geoms.union_all() if hasattr(geoms, "union_all") else geoms.unary_union
        geom = dissolved.intersection(conus_geom)
        if geom.is_empty:
            print(f"[INFO] {name}: intersection with CONUS is empty -> skip")
            continue
        out.append({"name": name, "geom": geom})
    print(f"[INFO] loaded subregions: {len(out)}")
    return out


# ===================== 6) plot =====================
def plot_maps(df_skill: pd.DataFrame,
              topo: pd.DataFrame,
              metrics: list,
              conus_geom,
              sub_geoms: list,
              nameA: str,
              nameB: str,
              out_dir: str):
    """
    Draw and save one CONUS scatter map of skill scores per metric.

    Each gauge is placed at its (lon, lat) from ``topo``. It is coloured by
    ``skill_<metric>`` from ``df_skill`` using ``CBAR_CMAP`` with the fixed
    ``CBAR_RANGES[metric]`` limits. The CONUS outline and the subregion
    borders are drawn beneath the points.

    Parameters
    ----------
    df_skill : DataFrame with ``basin_id_str`` and ``skill_<metric>`` columns.
    topo : DataFrame with ``basin_id_str``, ``lat`` and ``lon`` columns.
    metrics : metric names to plot. Metrics with no non-NaN values are
        skipped with a warning.
    conus_geom : shapely geometry in lon/lat; its bounds, padded by 0.3 on
        every side, set the map extent.
    sub_geoms : list of dicts with a ``"geom"`` key; each outline is drawn.
    nameA, nameB : reference and compared model names (title and file name).
    out_dir : directory receiving ``skill_<metric>_<nameB>_minus_<nameA>.png``.

    Raises
    ------
    KeyError : a plotted metric has no entry in ``CBAR_RANGES``.
    """

    def _apply_tnr(text):
        # Give ``text`` the Times New Roman face but keep its current size.
        # set_fontproperties(TNR_PROP) alone would also replace the size with
        # the one stored in TNR_PROP.
        prop = TNR_PROP.copy()
        prop.set_size(text.get_fontsize())
        text.set_fontproperties(prop)

    dfm = df_skill.merge(topo, on="basin_id_str", how="inner")
    print(f"[INFO] matched with topo: {len(dfm)}")

    sign_txt = f"{nameB} w.r.t {nameA}"

    # The map extent depends only on the CONUS outline, so compute it once.
    minx, miny, maxx, maxy = conus_geom.bounds
    pad = 0.3
    extent = [minx - pad, maxx + pad, miny - pad, maxy + pad]

    for metric in metrics:
        col = f"skill_{metric}"
        sub = dfm[["lon", "lat", col]].dropna()
        if sub.empty:
            print(f"[WARN] {metric}: no valid values, skip.")
            continue

        # Validate before creating the figure so a bad key cannot leak it.
        if metric not in CBAR_RANGES:
            raise KeyError(f"CBAR_RANGES missing key: {metric}")
        vmin, vmax = CBAR_RANGES[metric]
        norm = Normalize(vmin=vmin, vmax=vmax)

        fig = plt.figure(figsize=(7, 5.4), dpi=300)
        try:
            ax = fig.add_axes([0.01, 0.03, 0.98, 0.90], projection=ccrs.PlateCarree())
            ax.set_extent(extent, crs=ccrs.PlateCarree())

            # White-filled CONUS base layer.
            ax.add_geometries(
                [conus_geom],
                crs=ccrs.PlateCarree(),
                facecolor="white",
                edgecolor="black",
                linewidth=1.0,
                zorder=1,
            )

            # Subregion borders (outline only) above the base layer.
            for sg in sub_geoms:
                ax.add_geometries(
                    [sg["geom"]],
                    crs=ccrs.PlateCarree(),
                    facecolor="none",
                    edgecolor="black",
                    linewidth=1.3,
                    zorder=2,
                )

            # Gauges on top, coloured by skill with the shared fixed limits.
            ax.scatter(
                sub["lon"], sub["lat"],
                c=sub[col],
                cmap=CBAR_CMAP,
                norm=norm,
                s=POINT_SIZE,
                edgecolor="k",
                linewidth=0.2,
                zorder=3,
            )

            title_obj = ax.set_title(f"{metric} Skill Score: {sign_txt}")
            _apply_tnr(title_obj)
            for lab in ax.get_xticklabels() + ax.get_yticklabels():
                _apply_tnr(lab)

            # No tight_layout(): these axes were placed with add_axes, which
            # tight_layout cannot adjust (it only warns). bbox_inches="tight"
            # trims the saved image instead.
            out = os.path.join(out_dir, f"skill_{metric}_{nameB}_minus_{nameA}.png")
            fig.savefig(out, dpi=300, bbox_inches="tight")
            print("Saved:", out)
            plt.show()
        finally:
            # Release the figure (even if saving failed) so figures do not
            # accumulate across loop iterations.
            plt.close(fig)


def save_colorbar(metric: str, out_dir: str):
    """
    Save a stand-alone vertical colour bar for ``metric`` as
    ``<out_dir>/colorbar_<metric>.png``.

    The bar uses the same ``CBAR_RANGES[metric]`` limits and ``CBAR_CMAP``
    colormap as the skill maps, with extensions at both ends
    (``extend="both"``) and 24 pt Times New Roman tick labels.

    Raises
    ------
    KeyError : ``metric`` has no entry in ``CBAR_RANGES``.
    """
    from matplotlib.cm import ScalarMappable

    if metric not in CBAR_RANGES:
        raise KeyError(f"CBAR_RANGES missing key: {metric}")

    vmin, vmax = CBAR_RANGES[metric]
    norm = Normalize(vmin=vmin, vmax=vmax)
    sm = ScalarMappable(norm=norm, cmap=CBAR_CMAP)
    sm.set_array([])

    tick_labelsize = 24
    out = os.path.join(out_dir, f"colorbar_{metric}.png")

    fig, ax = plt.subplots(figsize=(0.4, 10))
    try:
        cbar = fig.colorbar(sm, cax=ax, extend="both")
        cbar.ax.tick_params(labelsize=tick_labelsize)

        # TNR_PROP carries its own size. Passing it unchanged to
        # set_fontproperties would discard the labelsize set just above, so
        # apply a copy sized to tick_labelsize instead.
        tick_prop = TNR_PROP.copy()
        tick_prop.set_size(tick_labelsize)
        for lab in cbar.ax.get_yticklabels():
            lab.set_fontproperties(tick_prop)

        fig.savefig(out, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"Saved colorbar: {out}")


# ===================== 7) main =====================
def main():
    """Produce every skill map listed in COMPARE_PAIRS, then one colour bar per metric."""
    # Inputs shared by all comparisons are loaded once, up front.
    topo = read_camels_topo(CAMELS_TOPO)
    conus_geom = load_conus_geometry(USA_SHP, US_EXTENT)
    sub_geoms = load_subregions(REGION_SHPS_DIR, SUBREGION_NAMES, conus_geom)

    for pair in COMPARE_PAIRS:
        ref_name, ref_csv, cmp_name, cmp_csv = pair
        plot_maps(
            build_skill_df(ref_csv, cmp_csv, METRICS_TO_PLOT),
            topo,
            METRICS_TO_PLOT,
            conus_geom,
            sub_geoms,
            ref_name,
            cmp_name,
            OUT_DIR,
        )

    # The maps carry no colour bar of their own; each metric gets a separate image.
    for metric in METRICS_TO_PLOT:
        save_colorbar(metric, OUT_DIR)

    print("[DONE]")


if __name__ == "__main__":
    main()
