# In [None]:

## Fig.5
# -*- coding: utf-8 -*-
"""
Draw multi-model box plots grouped by sub-region (one subplot per metric, all sharing the x-axis labels).
- Sub-regions: one CSV file per region inside a single folder (17 files recommended); each CSV contains a "basin_id" column (or an equivalent ID column that is recognized automatically)
- Models: one "val_metrics_lead0.csv" file per model, each containing "basin_id" plus the metric columns
"""

import os
import glob
import numpy as np
import pandas as pd

from pathlib import Path
import matplotlib as mpl
from matplotlib import font_manager as fm
import matplotlib.transforms as mtransforms



# Directory searched for the Times New Roman font files; it is passed to
# setup_times_new_roman() at import time.
# NOTE(review): the path looks like the Windows font folder as mounted under WSL — confirm.
FONT_DIR = Path("/mnt/c/Windows/Fonts")

def setup_times_new_roman(font_dir: Path):
    """Register Times New Roman from ``font_dir`` and make it matplotlib's default font.

    The four usual Windows file names (``times.ttf``, ``timesbd.ttf``,
    ``timesi.ttf``, ``timesbi.ttf``) are tried first. If none of them exists,
    every ``.ttf``/``.otf``/``.ttc`` file whose name contains "times" is used
    instead; this fallback ignores letter case and is sorted so that the
    chosen "regular" face is deterministic.

    Args:
        font_dir: Directory that holds the font files.

    Returns:
        ``(font_properties, font_name)`` for the face used as the regular font.

    Raises:
        FileNotFoundError: ``font_dir`` does not exist.
        RuntimeError: no matching font file was found in ``font_dir``.
    """
    if not font_dir.exists():
        raise FileNotFoundError(f"FONT_DIR not found: {font_dir}")

    candidates = [
        font_dir / "times.ttf",    # Regular
        font_dir / "timesbd.ttf",  # Bold
        font_dir / "timesi.ttf",   # Italic
        font_dir / "timesbi.ttf",  # Bold Italic
    ]
    existing = [p for p in candidates if p.exists()]

    if not existing:
        # Fallback: any font file with "times" in its name. Suffix and name are
        # compared case-insensitively so that e.g. "TIMES.TTF" is also found.
        font_suffixes = {".ttf", ".otf", ".ttc"}
        entries = font_dir.iterdir() if font_dir.is_dir() else ()
        existing = sorted(
            p for p in entries
            if p.suffix.lower() in font_suffixes and "times" in p.name.lower()
        )

    if not existing:
        raise RuntimeError(f"No Times New Roman font files found under: {font_dir}")

    # Registration is best-effort: one unreadable file must not abort the setup.
    for fp in existing:
        try:
            fm.fontManager.addfont(str(fp))
        except Exception as e:
            print(f"[WARN] failed to add font {fp}: {e}")

    regular = candidates[0] if candidates[0].exists() else existing[0]
    tnr_prop = fm.FontProperties(fname=str(regular))
    tnr_name = tnr_prop.get_name()  # usually "Times New Roman"

    # Global rcParams.
    mpl.rcParams["font.family"] = tnr_name
    mpl.rcParams["font.sans-serif"] = [tnr_name]  # avoid fallback
    mpl.rcParams["axes.unicode_minus"] = False

    # Math text (optional, but keeps formulas in the same typeface).
    mpl.rcParams["mathtext.fontset"] = "custom"
    mpl.rcParams["mathtext.rm"] = tnr_name
    mpl.rcParams["mathtext.it"] = f"{tnr_name}:italic"
    mpl.rcParams["mathtext.bf"] = f"{tnr_name}:bold"

    print(f"[INFO] Using font: {tnr_name}")
    print(f"[INFO] Registered font files: {[p.name for p in existing]}")
    return tnr_prop, tnr_name

# Runs at import time, before matplotlib.pyplot is imported below: registers the
# font and keeps the FontProperties object / family name for the plotting helpers.
TNR_PROP, TNR_NAME = setup_times_new_roman(FONT_DIR)


import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator


# ===================== 1. Config =====================

# Folder with one CSV per sub-region; each file stem becomes a region name.
REGION_DIR = "/mnt/d/desktop/paper_data/01/bukovsky_subregions_csv_17"

# Per-model metric tables, paired by position with MODEL_NAMES.
MODEL_PATHS = [
    "/mnt/d/desktop/paper_data/01/model_data/case0/val_metrics_lead0.csv",
    "/mnt/d/desktop/paper_data/01/model_data/case1/val_metrics_lead0.csv",
    "/mnt/d/desktop/paper_data/01/model_data/case2/val_metrics_lead0.csv",
    "/mnt/d/desktop/paper_data/01/model_data/case3/val_metrics_lead0.csv",
]

# Display names used for the legend and the summary tables.
MODEL_NAMES = ["Attr-LSTM", "RasterMean-LSTM", "MID-CNN-LSTM", "HIGH-CNN-LSTM"]

# Metric columns to draw, one subplot per entry.
METRICS_TO_PLOT = ["NSE", "KGE", "FHV_0.10", "PBIAS"]

# Basin ids are left-padded with zeros to this width before matching.
BASIN_ZFILL = 8

OUT_PNG = "/mnt/d/desktop/paper_data/01/paper_output/regional_grouped_boxplots.png"
OUT_PDF = "/mnt/d/desktop/paper_data/01/paper_output/regional_grouped_boxplots.pdf"

# Figure geometry: the figure size is (FIG_W, ROW_H * number of metrics).
FIG_W = 14
ROW_H = 2.8
BOX_WIDTH = 0.12
SHOW_FLIERS = False

# Font sizes requested for ticks, axis labels, titles and the legend.
TICK_FONTSIZE = 36
XTICK_FONTSIZE = 36
YTICK_FONTSIZE = 36
AXLABEL_FONTSIZE = 36
TITLE_FONTSIZE = 14
LEGEND_FONTSIZE = 36

# Region names are split on this delimiter; a change in the leading part
# starts a new group (see compute_group_boundaries_auto).
GROUP_DELIM = "_"

REGION_GROUPS = {}

# Optional per-metric y-axis limits and major tick spacing.
YAXIS_CTRL = {
    "NSE":   {"ylim": (-0.4, 1.0), "ystep": 0.2},
    "KGE":   {"ylim": (-0.4, 1.0), "ystep": 0.2},
    "FHV_0.10": {"ylim": (-100, 75), "ystep": 25},
    "PBIAS": {"ylim": (-125, 75), "ystep": 25},
}

# Style of the vertical separator lines between region groups. This used to be
# assigned twice in this section; only the later assignment ever took effect,
# so that value is the one kept here.
SEP_LINE_KW = dict(color="black", linestyle="--", linewidth=1.0, alpha=0.6)

# Metric column name -> text shown on the axis (mathtext allowed).
METRIC_DISPLAY = {"FHV_0.10": r"FHV$_{0.10}$"}


def metric_label(m: str) -> str:
    """Return the display label for metric ``m``, or ``m`` itself if none is registered."""
    try:
        return METRIC_DISPLAY[m]
    except KeyError:
        return m

# Global matplotlib defaults: base font size and the resolution used for both
# on-screen figures and saved files. This runs after setup_times_new_roman(),
# so TNR_PROP was created before "font.size" was changed here.
plt.rcParams.update({
    "font.size": 20,        
    "figure.dpi": 300,
    "savefig.dpi": 300,
})


# ===================== 2. Utility function =====================

def _standardize_basin_id(df: pd.DataFrame) -> pd.DataFrame:
    """Convert the "basin_id" to a string and pad it with zeros to the left to generate "basin_id_str"."""
    if "basin_id" not in df.columns:
        candidates = [c for c in df.columns if c.lower() in ["basin_id", "gauge_id", "gage_id", "site_id", "id"]]
        if not candidates:
            raise ValueError(f"The "basin_id" column cannot be found. The existing columns are:{list(df.columns)}")
        df = df.rename(columns={candidates[0]: "basin_id"})

    df["basin_id_str"] = df["basin_id"].astype(str).str.zfill(BASIN_ZFILL)
    return df


def load_regions(region_dir: str):
    """Read every ``*.csv`` in ``region_dir`` as one region's list of basins.

    The region name is the file name without its extension; files are
    processed in sorted path order. Files that yield no basin id are skipped
    with a warning.

    Returns:
        ``(region_names, region_basins)``: the list of kept region names, and
        a dict mapping each name to its set of zero-padded basin id strings.

    Raises:
        FileNotFoundError: ``region_dir`` contains no CSV file.
        ValueError: no CSV file yielded a basin id.
    """
    paths = sorted(glob.glob(os.path.join(region_dir, "*.csv")))
    if not paths:
        raise FileNotFoundError(f"No CSV files were found in the REGION_DIR directory:{region_dir}")

    region_names = []
    region_basins = {}

    for p in paths:
        name = os.path.splitext(os.path.basename(p))[0]
        df = _standardize_basin_id(pd.read_csv(p))

        # Missing ids must be filtered on the raw column: once converted to
        # str they are ordinary text and dropna() no longer removes them.
        has_id = df["basin_id"].notna()
        basins = set(df.loc[has_id, "basin_id_str"].astype(str).tolist())
        if not basins:
            print(f"[WARNING] The region {name} does not have a valid basin_id, thus skipping:{p}")
            continue

        region_names.append(name)
        region_basins[name] = basins

    if not region_names:
        raise ValueError("None of the regional CSV files contains a valid basin_id.")

    return region_names, region_basins


def load_models(model_paths, model_names):
    """Load one metrics CSV per model and standardize its basin id column.

    Args:
        model_paths: CSV paths, one per model.
        model_names: Display names, paired by position with ``model_paths``.

    Returns:
        Dict mapping each model name to its DataFrame (with ``basin_id_str``).

    Raises:
        ValueError: the two lists differ in length. ``zip`` would otherwise
            silently drop the unmatched entries.
    """
    if len(model_paths) != len(model_names):
        raise ValueError(
            f"model_paths ({len(model_paths)}) and model_names ({len(model_names)}) "
            "must have the same length."
        )

    models = {}
    for path, name in zip(model_paths, model_names):
        df = _standardize_basin_id(pd.read_csv(path))
        models[name] = df
        print(f"Model {name}: {len(df)} basins loaded.")
    return models


def intersect_common_basins(models: dict, region_basins: dict):
    """Return the basin ids present in every model and in at least one region.

    Args:
        models: Model name -> DataFrame with a ``basin_id_str`` column.
        region_basins: Region name -> set of basin id strings.

    Returns:
        A set of basin id strings; empty when either mapping is empty.
    """
    # set.intersection() needs at least one argument, so handle "no models" here.
    if not models or not region_basins:
        return set()

    model_sets = [set(df["basin_id_str"].tolist()) for df in models.values()]
    in_all_models = set.intersection(*model_sets)

    in_any_region = set().union(*region_basins.values())
    return in_all_models & in_any_region


def build_metric_arrays(region_names, region_basins, models, common_basins, metric):
    """Collect, per model and per region, the non-NaN values of ``metric``.

    Args:
        region_names: Regions in plotting order.
        region_basins: Region name -> set of basin id strings.
        models: Model name -> DataFrame with ``basin_id_str`` and metric columns.
        common_basins: Basin ids to keep (others are ignored).
        metric: Name of the metric column.

    Returns:
        Dict mapping each model name to a list of 1-D float arrays, one per
        entry of ``region_names``. Values are ordered by basin id.

    Raises:
        ValueError: a model table has no ``metric`` column.
    """
    # Region membership is the same for every model: resolve it once. Sorting
    # gives the value arrays a stable order (set iteration order is not).
    basins_per_region = [
        sorted(region_basins[r].intersection(common_basins)) for r in region_names
    ]

    data_by_model = {}
    for mname, df in models.items():
        if metric not in df.columns:
            raise ValueError(f"The CSV file of model {mname} does not contain the column {metric}. Available columns: {list(df.columns)}")

        keep = df["basin_id_str"].isin(common_basins)
        # If a basin id is repeated, the last row wins (dict semantics).
        value_of = dict(zip(df.loc[keep, "basin_id_str"], df.loc[keep, metric]))

        region_arrays = []
        for basins in basins_per_region:
            vals = np.array([value_of.get(b, np.nan) for b in basins], dtype=float)
            region_arrays.append(vals[~np.isnan(vals)])

        data_by_model[mname] = region_arrays

    return data_by_model


def boxplot_five_number(vals: np.ndarray):
    """Return ``(lower_whisker, q1, median, q3, upper_whisker)`` for ``vals``.

    NaNs are ignored. Quartiles come from ``np.percentile``. Each whisker is
    the most extreme value still within 1.5 * IQR of the box, and it is never
    allowed to lie inside the box: when the only values within the fence fall
    between the quartiles, the whisker collapses onto Q1 / Q3. This keeps the
    table consistent with the whiskers matplotlib draws.

    Returns:
        A 5-tuple of floats; all NaN when ``vals`` has no non-NaN value.
    """
    x = np.asarray(vals, dtype=float)
    x = x[~np.isnan(x)]
    if x.size == 0:
        return (np.nan, np.nan, np.nan, np.nan, np.nan)

    q1 = np.percentile(x, 25)
    med = np.percentile(x, 50)
    q3 = np.percentile(x, 75)
    reach = 1.5 * (q3 - q1)

    within_low = x[x >= q1 - reach]
    within_high = x[x <= q3 + reach]

    # Interpolated quartiles can fall between data points, so the nearest value
    # inside the fence may sit inside the box; clamp the whisker to the box edge.
    lw = min(np.min(within_low), q1) if within_low.size else q1
    uw = max(np.max(within_high), q3) if within_high.size else q3

    return (lw, q1, med, q3, uw)


def build_median_table(region_names, metric_to_data, model_names):
    """Build a DataFrame of per-region medians with one column per model.

    The first column, ``Region``, holds ``region_names``; each following
    column holds the medians of that model's arrays in ``metric_to_data``.
    NaNs are ignored, and an array with no non-NaN value yields NaN.
    """
    def _median_or_nan(values):
        finite = np.asarray(values, dtype=float)
        finite = finite[~np.isnan(finite)]
        return float(np.median(finite)) if finite.size != 0 else np.nan

    columns = {"Region": region_names}
    for model in model_names:
        columns[model] = [_median_or_nan(values) for values in metric_to_data[model]]
    return pd.DataFrame(columns)


def apply_yaxis_control(ax, metric, yaxis_ctrl):
    """Apply the y-limits and major tick spacing configured for ``metric``.

    ``yaxis_ctrl`` maps metric names to dicts with optional ``"ylim"`` and
    ``"ystep"`` entries. A metric without an entry, or an entry whose value is
    missing or ``None``, leaves the corresponding axis setting untouched.
    """
    try:
        settings = yaxis_ctrl[metric]
    except KeyError:
        return

    limits = settings.get("ylim")
    if limits is not None:
        ax.set_ylim(limits)

    spacing = settings.get("ystep")
    if spacing is not None:
        ax.yaxis.set_major_locator(MultipleLocator(spacing))


def plot_median_by_region(region_names, med_df, metric, ax, model_names,
                          line_width=2.0, marker="o"):
    """Draw one median-vs-region line per model on ``ax``.

    ``med_df`` must have one column per entry of ``model_names``, with one row
    per region. Regions are placed at integer x positions and labelled with
    ``region_names``; y-axis limits and tick spacing come from ``YAXIS_CTRL``
    when the metric has an entry there.
    """
    n_regions = len(region_names)
    positions = np.arange(n_regions)

    for model in model_names:
        medians = med_df[model].to_numpy(dtype=float)
        ax.plot(positions, medians, marker=marker, linewidth=line_width, label=model)

    ax.set_xlim(-0.5, n_regions - 0.5)
    ax.set_xticks(positions)
    ax.set_xticklabels(region_names, rotation=45, ha="right", fontsize=XTICK_FONTSIZE)
    ax.tick_params(axis="y", labelsize=YTICK_FONTSIZE)
    ax.set_ylabel(f"Median {metric}", fontsize=AXLABEL_FONTSIZE)

    # apply_yaxis_control() returns without changes for metrics missing from YAXIS_CTRL.
    apply_yaxis_control(ax, metric, YAXIS_CTRL)

    ax.grid(True, axis="y", alpha=0.2)

    # Tick labels first, then the two axis labels, all switched to the registered font.
    texts = [*ax.get_xticklabels(), *ax.get_yticklabels(), ax.yaxis.label, ax.xaxis.label]
    for text in texts:
        text.set_fontproperties(TNR_PROP)


def summarize_boxplot_tables(region_names, metric_to_data, metric, model_names, decimals=3,
                             out_dir="/mnt/d/desktop/paper_data/01/paper_output"):
    """Write and print the box-plot statistics of one metric.

    One row is produced per (region, model) pair with the five-number summary
    from ``boxplot_five_number`` and the number of non-NaN values ``N``. The
    table is saved as ``boxplot_summary_<metric>.csv`` in ``out_dir`` (created
    if missing) and printed with ``decimals`` decimal places.

    Args:
        region_names: Regions, in the order used by ``metric_to_data``.
        metric_to_data: Model name -> list of value arrays, one per region.
        metric: Metric name, used in the file name and the printed header.
        model_names: Models to include, in output order.
        decimals: Decimal places used when printing floats.
        out_dir: Directory that receives the CSV file.

    Returns:
        The summary DataFrame.
    """
    rows = []
    for r_i, rname in enumerate(region_names):
        for mname in model_names:
            vals = np.asarray(metric_to_data[mname][r_i], dtype=float)
            lw, q1, med, q3, uw = boxplot_five_number(vals)
            rows.append({
                "Region": rname,
                "Model": mname,
                "LowerWhisker": lw,
                "Q1": q1,
                "Median": med,
                "Q3": q3,
                "UpperWhisker": uw,
                # Count exactly the values the statistics are based on.
                "N": int(np.sum(~np.isnan(vals))),
            })

    show_cols = ["Region", "Model", "LowerWhisker", "Q1", "Median", "Q3", "UpperWhisker", "N"]
    # Passing columns= keeps the layout stable even when there are no rows.
    df_out = pd.DataFrame(rows, columns=show_cols)

    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    out_csv = os.path.join(out_dir, f"boxplot_summary_{metric}.csv")
    df_out.to_csv(out_csv, index=False, encoding="utf-8-sig")
    print(f"Saved table: {os.path.abspath(out_csv)}")

    with pd.option_context(
        "display.max_rows", None,
        "display.max_columns", None,
        "display.width", 220,
        "display.float_format", lambda x: f"{x:.{decimals}f}"
    ):
        print("\n" + "=" * 120)
        print(f"[Boxplot summary] Metric = {metric}")
        print("=" * 120 + "\n")
        # The display options above only matter while the table is printed.
        print(df_out)

    return df_out


def compute_group_boundaries_auto(region_names, delim="_"):
    """Return the x positions where the group prefix changes between neighbours.

    A region's prefix is the text before the first ``delim`` (the whole name
    if ``delim`` does not occur). For every index ``i`` whose prefix differs
    from the previous region's, ``i - 0.5`` is reported, i.e. the midpoint
    between the two integer positions.
    """
    prefixes = [name.split(delim)[0] for name in region_names]
    neighbours = zip(prefixes, prefixes[1:])
    return [
        idx - 0.5
        for idx, (previous, current) in enumerate(neighbours, start=1)
        if current != previous
    ]


def plot_grouped_boxplots(region_names, metric_to_data, metric, ax, model_names,
                          box_width=0.12, show_fliers=True,
                          boundaries_x=None, sep_line_kw=None,
                          yaxis_ctrl=None):
    """Draw side-by-side box plots (one box per model) for every region on ``ax``.

    Args:
        region_names: Regions in x order; region ``k`` is centred on x = ``k``.
        metric_to_data: Model name -> list of value arrays, one per region.
        metric: Metric name; its display label becomes the y-axis label.
        ax: Target matplotlib Axes.
        model_names: Models to draw, left to right within each region.
        box_width: Width of a single box; neighbouring boxes are 1.8 widths apart.
        show_fliers: Whether outlier points are drawn.
        boundaries_x: Optional x positions of vertical group separators.
        sep_line_kw: Keyword arguments for the separator lines.
        yaxis_ctrl: Optional per-metric y-axis settings (see ``apply_yaxis_control``).

    Returns:
        ``(handles, labels)`` for a legend: one box patch and name per model.
    """
    n_regions = len(region_names)
    n_models = len(model_names)

    base_pos = np.arange(n_regions)
    step = box_width * 1.8
    # Centre the cluster of boxes on each region's integer position.
    offsets = (np.arange(n_models) - (n_models - 1) / 2.0) * step

    # Look the colour cycle up once and wrap by its real length, so cycles
    # with fewer than ten colours cannot raise IndexError.
    palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    handles = []
    labels = []

    for i, mname in enumerate(model_names):
        data_list = metric_to_data[mname]
        positions = base_pos + offsets[i]

        bp = ax.boxplot(
            data_list,
            positions=positions,
            widths=box_width,
            patch_artist=True,
            showfliers=show_fliers,
            manage_ticks=False
        )

        color = palette[i % len(palette)]
        for box in bp["boxes"]:
            box.set_facecolor(color)
            box.set_alpha(0.85)
        for k in ["whiskers", "caps", "medians"]:
            for item in bp[k]:
                item.set_color("black")
                item.set_linewidth(1.0)

        # Keep handles and labels aligned even if no box was drawn.
        if bp["boxes"]:
            handles.append(bp["boxes"][0])
            labels.append(mname)

    ax.set_xlim(-0.8, n_regions - 0.2)
    ax.set_xticks(base_pos)
    ax.set_xticklabels(region_names, rotation=45, ha="right", fontsize=XTICK_FONTSIZE)

    # Shift the rotated x tick labels to the right by a fixed number of points.
    dx_pt = 14
    offset = mtransforms.ScaledTranslation(dx_pt / 72., 0, ax.figure.dpi_scale_trans)
    for lab in ax.get_xticklabels():
        lab.set_transform(lab.get_transform() + offset)

    ax.tick_params(axis="x", labelsize=XTICK_FONTSIZE)
    ax.tick_params(axis="y", labelsize=YTICK_FONTSIZE)

    ax.set_ylabel(metric_label(metric), fontsize=AXLABEL_FONTSIZE)

    if boundaries_x:
        for x in boundaries_x:
            ax.axvline(x=x, **(sep_line_kw or {}))

    if yaxis_ctrl:
        apply_yaxis_control(ax, metric, yaxis_ctrl)

    # NOTE(review): set_fontproperties() swaps in the whole FontProperties
    # object, which may also replace the font sizes requested above with
    # TNR_PROP's own size — confirm against the rendered figure.
    for lab in ax.get_xticklabels() + ax.get_yticklabels():
        lab.set_fontproperties(TNR_PROP)
    ax.yaxis.label.set_fontproperties(TNR_PROP)
    ax.xaxis.label.set_fontproperties(TNR_PROP)

    return handles, labels


# ===================== 3. Main  =====================

def main():
    """Load regions and model metrics, draw one box-plot row per metric, save the PNG.

    Steps: read the region CSVs and model tables, keep the basins shared by
    all models and at least one region, then, for every metric in
    ``METRICS_TO_PLOT``, write/print its summary table and draw its grouped
    box plots on its own subplot. A shared legend is placed above the figure,
    which is then saved to ``OUT_PNG`` (its directory is created if missing).

    Raises:
        ValueError: ``METRICS_TO_PLOT`` is empty, or no common basin exists.
    """
    if not METRICS_TO_PLOT:
        raise ValueError("METRICS_TO_PLOT is empty; there is nothing to plot.")

    region_names, region_basins = load_regions(REGION_DIR)
    print(f"Regions loaded: {len(region_names)} -> {region_names}")

    models = load_models(MODEL_PATHS, MODEL_NAMES)

    common_basins = intersect_common_basins(models, region_basins)
    print(f"Common basins across all models & regions: {len(common_basins)}")
    if len(common_basins) == 0:
        raise ValueError("No common basins found. Please check if basin_id in region CSVs and model CSVs are consistent or if the matching rules are correct.")

    boundaries_x = compute_group_boundaries_auto(region_names, delim=GROUP_DELIM)
    print("region_names =", region_names)
    print("boundaries_x =", boundaries_x)

    n_metrics = len(METRICS_TO_PLOT)
    fig, axes = plt.subplots(
        n_metrics, 1, sharex=True,
        figsize=(FIG_W, ROW_H * n_metrics),
        constrained_layout=True
    )
    if n_metrics == 1:
        # plt.subplots returns a bare Axes for a 1x1 grid.
        axes = [axes]

    legend_handles = None
    legend_labels = None

    for i, metric in enumerate(METRICS_TO_PLOT):
        metric_to_data = build_metric_arrays(region_names, region_basins, models, common_basins, metric)

        summarize_boxplot_tables(region_names, metric_to_data, metric, MODEL_NAMES, decimals=3)

        handles, labels = plot_grouped_boxplots(
            region_names=region_names,
            metric_to_data=metric_to_data,
            metric=metric,
            ax=axes[i],
            model_names=MODEL_NAMES,
            box_width=BOX_WIDTH,
            show_fliers=SHOW_FLIERS,
            boundaries_x=boundaries_x,
            sep_line_kw=SEP_LINE_KW,
            yaxis_ctrl=YAXIS_CTRL
        )

        # Every subplot shows the same models, so the first one supplies the legend.
        if i == 0:
            legend_handles, legend_labels = handles, labels

        # Only the bottom subplot keeps its x tick labels.
        if i < n_metrics - 1:
            axes[i].tick_params(axis="x", labelbottom=False)

    axes[-1].set_xlabel("Region", fontsize=AXLABEL_FONTSIZE)
    axes[-1].xaxis.label.set_fontproperties(TNR_PROP)

    leg = fig.legend(
        legend_handles, legend_labels,
        loc="upper center",
        bbox_to_anchor=(0.55, 1.06),
        ncol=len(MODEL_NAMES),
        frameon=False,
        fontsize=LEGEND_FONTSIZE,
        handlelength=1.4,
        handletextpad=0.35,
        columnspacing=0.9,
        labelspacing=0.3,
        borderaxespad=0.1
    )

    # Nudge the legend texts upwards by a few points relative to their handles.
    dy_pt = 4
    text_offset = mtransforms.ScaledTranslation(0, dy_pt / 72., fig.dpi_scale_trans)
    for t in leg.get_texts():
        t.set_fontproperties(TNR_PROP)
        t.set_transform(t.get_transform() + text_offset)

    # NOTE(review): the figure was created with constrained_layout=True; how
    # subplots_adjust() interacts with that depends on the matplotlib
    # version — confirm this call still has the intended effect.
    fig.subplots_adjust(top=0.90)

    out_dir = os.path.dirname(OUT_PNG)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig.savefig(OUT_PNG, dpi=300, bbox_inches="tight")
    print(f"Saved: {os.path.abspath(OUT_PNG)}")
    # PDF export is disabled; call fig.savefig(OUT_PDF) here to enable it.
    # Nothing is reported as saved unless it has actually been written.

# Run the full pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
