# In [None]:
# Fig.8#
"""
SpearmanCorr(w_channel, skill) bar plot (with FDR correction)
- skill = (CaseB - CaseA) / (1 - CaseA)
- Uses channel weights from Excel: channel_weights_by_basin_mid_only.xlsx
- Uses metrics from CSV: val_metrics_lead0.csv
- Plots only selected 8 channels
- Significance stars (based on the FDR-adjusted p-value, p_fdr):
    p_fdr < 0.05          => **
    0.05 <= p_fdr < 0.10  => *
- Outputs:
    bar_corr_<METRIC>_<CASEB>_vs_<CASEA>.png
    corr_table_<METRIC>_<CASEB>_vs_<CASEA>.csv
"""

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.stats import spearmanr

# ===================== Font: Times New Roman (Windows Fonts) =====================
import matplotlib as mpl
from matplotlib import font_manager as fm

FONT_DIR = Path("/mnt/c/Windows/Fonts")  # directory searched for the Times New Roman font files (times*.ttf)

def setup_times_new_roman(font_dir: Path):
    """Register Times New Roman with matplotlib and make it the default font.

    The four standard file names (regular, bold, italic, bold-italic) are
    looked up in ``font_dir`` first. If none of them exists, any
    ``.ttf``/``.otf``/``.ttc`` file whose name contains "times" is used
    instead.

    Parameters
    ----------
    font_dir : Path
        Directory that holds the font files.

    Returns
    -------
    tuple
        ``(FontProperties, family_name)`` built from the regular face
        (``times.ttf`` when present, otherwise the first file found).

    Raises
    ------
    FileNotFoundError
        If ``font_dir`` does not exist.
    RuntimeError
        If no matching font file is found in ``font_dir``.
    """
    if not font_dir.exists():
        raise FileNotFoundError(f"FONT_DIR not found: {font_dir}")

    regular_file = font_dir / "times.ttf"
    standard_files = [
        regular_file,
        font_dir / "timesbd.ttf",
        font_dir / "timesi.ttf",
        font_dir / "timesbi.ttf",
    ]
    font_files = [f for f in standard_files if f.exists()]

    if not font_files:
        # None of the standard names exist: accept anything that looks like
        # a Times font, scanning one extension after the other.
        font_files = [
            f
            for pattern in ("*.ttf", "*.otf", "*.ttc")
            for f in font_dir.glob(pattern)
            if "times" in f.name.lower()
        ]

    if not font_files:
        raise RuntimeError(f"No Times New Roman font files found under: {font_dir}")

    # Registration is best-effort: a file that cannot be added is reported
    # and skipped so the remaining faces are still registered.
    for font_file in font_files:
        try:
            fm.fontManager.addfont(str(font_file))
        except Exception as e:
            print(f"[WARN] failed to add font {font_file}: {e}")

    primary = regular_file if regular_file.exists() else font_files[0]
    tnr_prop = fm.FontProperties(fname=str(primary))
    tnr_name = tnr_prop.get_name()

    rc_overrides = {
        "font.family": tnr_name,
        "font.sans-serif": [tnr_name],
        "axes.unicode_minus": False,
        # Route mathtext through the same family so formulas match the text.
        "mathtext.fontset": "custom",
        "mathtext.rm": tnr_name,
        "mathtext.it": f"{tnr_name}:italic",
        "mathtext.bf": f"{tnr_name}:bold",
    }
    for key, value in rc_overrides.items():
        mpl.rcParams[key] = value

    print(f"[INFO] Using font: {tnr_name}")
    print(f"[INFO] Registered font files: {[p.name for p in font_files]}")
    return tnr_prop, tnr_name

# Configure the font at import time; this raises if FONT_DIR or the font files are missing.
TNR_PROP, TNR_NAME = setup_times_new_roman(FONT_DIR)

# ===================== 1) CONFIG =====================

# ---- Models: case name -> per-basin metrics CSV (add more cases here) ----
MODEL_CSVS = {
    "Case0": "/mnt/d/desktop/paper_data/01/model_data/case0/val_metrics_lead0.csv",
    "Case2": "/mnt/d/desktop/paper_data/01/model_data/case2/val_metrics_lead0.csv",
    "Case3": "/mnt/d/desktop/paper_data/01/model_data/case3/val_metrics_lead0.csv",
}

# Channel weights Excel file (usually in the model folder).
WEIGHTS_XLSX = "/mnt/d/desktop/paper_data/01/model_data/case2/channel_weights_by_basin_mid_only.xlsx"
# NOTE: the same weights file is used for every comparison below. If Case3 has
# its own weights file and you want to use that instead, change the path above.

# Comparisons to run, each given as (CaseB, CaseA); skill = (CaseB - CaseA) / (1 - CaseA).
# Both names must be keys of MODEL_CSVS.
COMPARISONS = [
    ("Case2", "Case0"),
    ("Case3", "Case0"),

]

# Metric columns to analyze; each must exist in every compared CSV.
METRICS = ["KGE"]

# Basin ids are left-padded with zeros to this width before merging.
BASIN_ZFILL = 8

# Full set of 15 channel-weight columns that the weights file must contain.
CHANNEL_COLS_ALL = [
    "w_dem", "w_slope", "w_tree", "w_ai", "w_et0",
    "w_pre_mean", "w_pre_sea", "w_clay", "w_sand", "w_silt",
    "w_soil_depth", "w_soil_depth2", "w_porosity", "w_water", "w_conductivity"
]

# Subset of 8 channels that is plotted; the list order is the bar order.
# The FDR correction is applied across exactly these channels.
PLOT_CHANNELS = [
    "w_et0",
    "w_pre_sea",
    "w_pre_mean",
    "w_ai",
    "w_soil_depth2",
    "w_clay",
    "w_porosity",
    "w_conductivity",
]

# Threshold on the adjusted p-value for the boolean "sig_fdr_0p05" column.
FDR_ALPHA = 0.05

# Star thresholds on the adjusted p-value (p_fdr): "**" below STAR_P05, "*" below STAR_P10.
STAR_P05 = 0.05
STAR_P10 = 0.10

# Output directory (created at import time if it does not exist).
OUT_DIR = Path("/mnt/d/desktop/paper_data/01/paper_output/channel_corr_skill")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Plot style
FIGSIZE = (7, 6)  # passed to plt.subplots(figsize=...)
DPI = 300
POS_COLOR = "#c0392b"  # positive rho
NEG_COLOR = "#2b6cb0"  # negative rho


# Global matplotlib defaults used by every figure created below.
plt.rcParams.update({
    "font.family": TNR_NAME,
    "font.size": 24,
    "figure.dpi": DPI,
})

# ===================== 2) HELPERS =====================

def norm_basin_id(s: pd.Series, zfill: int = 8) -> pd.Series:
    """Return basin ids as zero-padded strings of at least ``zfill`` characters.

    Values are converted to text, surrounding whitespace is stripped, and a
    trailing ``.0`` (left over when ids were parsed as floats) is removed
    before padding.
    """
    as_text = s.astype(str)
    trimmed = as_text.str.strip()
    without_float_suffix = trimmed.str.replace(r"\.0$", "", regex=True)
    return without_float_suffix.str.zfill(zfill)

def find_basin_id_column(df: pd.DataFrame) -> str:
    """Return the label of the column that holds the basin identifier.

    An exact ``"basin_id"`` column wins. Otherwise the first column (in
    column order) is returned whose lower-cased name either contains both
    "basin" and "id" or is one of ``gauge_id``, ``gage_id``, ``site_id``,
    ``id``.

    Column labels are compared through ``str()`` so that non-string labels
    (e.g. integer headers read from a spreadsheet) are skipped instead of
    raising ``AttributeError``.

    Raises
    ------
    ValueError
        If no column matches.
    """
    if "basin_id" in df.columns:
        return "basin_id"

    known_aliases = ("gauge_id", "gage_id", "site_id", "id")
    for col in df.columns:
        lowered = str(col).lower()
        # NOTE(review): this is a substring test, so a name such as
        # "basin_width" also matches ("width" contains "id").
        if ("basin" in lowered and "id" in lowered) or lowered in known_aliases:
            return col
    raise ValueError(f"Cannot find basin id column. Columns={list(df.columns)}")

def fdr_bh(pvals: np.ndarray) -> np.ndarray:
    """
    Benjamini–Hochberg FDR adjustment of a set of p-values.

    Non-finite entries (NaN/inf) are excluded from the correction, do not
    count towards the number of tests, and come back as NaN at their
    original positions.

    Input: pvals array (nan allowed)
    Output: float array of the same shape with adjusted p-values in [0, 1]
    """
    raw = np.asarray(pvals, dtype=float)
    adjusted = np.full_like(raw, np.nan, dtype=float)

    finite = np.isfinite(raw)
    n_tests = int(np.count_nonzero(finite))
    if n_tests == 0:
        return adjusted

    finite_positions = np.nonzero(finite)[0]
    valid = raw[finite]
    ascending = np.argsort(valid)

    # p_(i) * m / i for the i-th smallest p-value.
    ranks = np.arange(1, n_tests + 1)
    scaled = valid[ascending] * n_tests / ranks

    # Running minimum from the largest p-value downwards, so the adjusted
    # values stay non-decreasing in the raw p-value.
    monotone = np.minimum.accumulate(scaled[::-1])[::-1]

    adjusted[finite_positions[ascending]] = np.clip(monotone, 0.0, 1.0)
    return adjusted

def star_by_p(p: float) -> str:
    """Return the significance marker for a p-value.

    ``"**"`` when ``p < STAR_P05``, ``"*"`` when ``p < STAR_P10``, and an
    empty string otherwise or when ``p`` is not finite.
    """
    if not np.isfinite(p):
        return ""
    # Thresholds are checked from strictest to loosest; the first hit wins.
    for threshold, marker in ((STAR_P05, "**"), (STAR_P10, "*")):
        if p < threshold:
            return marker
    return ""

def safe_spearman(x: np.ndarray, y: np.ndarray):
    """Spearman correlation on the pairs where both values are finite.

    Returns ``(rho, p, n)`` where ``n`` is the number of finite pairs.
    ``rho`` and ``p`` are NaN when fewer than 10 pairs remain or when either
    variable is constant over the remaining pairs.
    """
    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)

    both_finite = np.isfinite(xs) & np.isfinite(ys)
    xs, ys = xs[both_finite], ys[both_finite]
    n_pairs = int(xs.size)

    # The size test comes first so the std checks never see too-small input.
    if n_pairs < 10 or np.nanstd(xs) == 0 or np.nanstd(ys) == 0:
        return np.nan, np.nan, n_pairs

    rho, p_value = spearmanr(xs, ys)
    return float(rho), float(p_value), n_pairs

# ===================== 3) IO =====================

def load_channel_weights(weights_xlsx: str) -> pd.DataFrame:
    """Load the per-basin channel weights from an Excel file.

    The basin id column is auto-detected, and a normalized string key
    ``basin_id_str`` is derived from it.

    Returns
    -------
    pd.DataFrame
        Columns ``basin_id_str`` followed by every name in
        ``CHANNEL_COLS_ALL``.

    Raises
    ------
    FileNotFoundError
        If the file does not exist.
    ValueError
        If no basin id column is found or a channel column is missing.
    """
    xlsx_path = Path(weights_xlsx)
    if not xlsx_path.exists():
        raise FileNotFoundError(f"Missing weights file: {xlsx_path}")

    table = pd.read_excel(xlsx_path)
    table = table.rename(columns={find_basin_id_column(table): "basin_id"})
    table["basin_id_str"] = norm_basin_id(table["basin_id"], BASIN_ZFILL)

    absent = [name for name in CHANNEL_COLS_ALL if name not in table.columns]
    if absent:
        raise ValueError(f"Weights file missing columns: {absent}\nAvailable={list(table.columns)}")

    # Keep only the merge key and the weight columns.
    return table[["basin_id_str", *CHANNEL_COLS_ALL]].copy()

def load_metric_csv(path: str) -> pd.DataFrame:
    """Load a per-basin metrics CSV and attach a normalized basin key.

    The detected basin id column is renamed to ``basin_id`` and a
    zero-padded string version is stored in ``basin_id_str``. All other
    columns are returned unchanged.

    Raises
    ------
    FileNotFoundError
        If the file does not exist.
    ValueError
        If no basin id column is found.
    """
    csv_path = Path(path)
    if not csv_path.exists():
        raise FileNotFoundError(f"Missing metrics csv: {csv_path}")

    metrics = pd.read_csv(csv_path)
    metrics = metrics.rename(columns={find_basin_id_column(metrics): "basin_id"})
    metrics["basin_id_str"] = norm_basin_id(metrics["basin_id"], BASIN_ZFILL)
    return metrics

# ===================== 4) CORE =====================

def build_skill_table(dfA: pd.DataFrame, dfB: pd.DataFrame, metric: str) -> pd.DataFrame:
    """Compute the per-basin skill of CaseB relative to CaseA for one metric.

    ``skill = (B - A) / (1 - A)``, evaluated on basins present in both
    frames (inner join on ``basin_id_str``). Where ``A == 1`` the
    denominator is zero and the skill is NaN.

    Returns
    -------
    pd.DataFrame
        Columns ``basin_id_str`` and ``skill``.

    Raises
    ------
    ValueError
        If ``metric`` is not a column of ``dfA`` or ``dfB``.
    """
    for label, frame in (("CaseA", dfA), ("CaseB", dfB)):
        if metric not in frame.columns:
            raise ValueError(f"{metric} not in {label} columns: {list(frame.columns)}")

    key_and_metric = ["basin_id_str", metric]
    paired = pd.merge(
        dfA[key_and_metric],
        dfB[key_and_metric],
        on="basin_id_str",
        how="inner",
        suffixes=("_A", "_B"),
    )

    baseline = paired[f"{metric}_A"].astype(float)
    candidate = paired[f"{metric}_B"].astype(float)

    # Headroom left to a perfect score of 1; zero headroom becomes NaN so
    # the division below yields NaN instead of +/-inf.
    headroom = (1.0 - baseline).replace(0, np.nan)
    paired["skill"] = (candidate - baseline) / headroom
    return paired[["basin_id_str", "skill"]].copy()

def compute_corr_table(weights_df: pd.DataFrame, skill_df: pd.DataFrame, channels_to_plot: list) -> pd.DataFrame:
    """Correlate each selected channel weight with the skill score.

    The two frames are inner-joined on ``basin_id_str``. Each channel gets
    one row with ``rho``, ``p_raw`` and ``n`` from ``safe_spearman``. The
    raw p-values are then BH-adjusted across the channels in
    ``channels_to_plot`` only.

    Returns
    -------
    pd.DataFrame
        Columns ``channel`` (ordered categorical following
        ``channels_to_plot``), ``rho``, ``p_raw``, ``n``, ``p_fdr`` and
        ``sig_fdr_0p05`` (``p_fdr < FDR_ALPHA``), sorted by ``channel``.
    """
    joined = weights_df.merge(skill_df, on="basin_id_str", how="inner")
    skill_values = joined["skill"].values

    field_names = ("channel", "rho", "p_raw", "n")
    records = [
        dict(zip(field_names, (name, *safe_spearman(joined[name].values, skill_values))))
        for name in channels_to_plot
    ]
    table = pd.DataFrame(records)

    table["p_fdr"] = fdr_bh(table["p_raw"].values)
    table["sig_fdr_0p05"] = table["p_fdr"] < FDR_ALPHA

    # An ordered categorical makes the sort follow the caller's channel order.
    table["channel"] = pd.Categorical(table["channel"], categories=channels_to_plot, ordered=True)
    return table.sort_values("channel").reset_index(drop=True)

def plot_corr_bar(df_corr: pd.DataFrame, title: str, out_png: Path):
    """Draw the per-channel Spearman rho bar chart and save it as a PNG.

    Parameters
    ----------
    df_corr : pd.DataFrame
        One row per bar, with columns ``channel``, ``rho`` and ``p_fdr``.
        Rows are drawn in the given order.
    title : str
        Currently not drawn (the ``set_title`` call is disabled); kept for
        interface compatibility.
    out_png : Path
        Destination of the saved figure.

    Notes
    -----
    The y axis is fixed to [-0.3, 0.3], so larger ``|rho|`` values are cut
    off by the axes limits. The figure is not closed here.
    """
    import matplotlib.transforms as mtransforms

    positions = np.arange(len(df_corr))
    rho_values = df_corr["rho"].values
    bar_colors = [POS_COLOR if value >= 0 else NEG_COLOR for value in rho_values]

    fig, ax = plt.subplots(figsize=FIGSIZE, dpi=DPI)
    ax.axhline(0, color="black", linewidth=1.2)
    ax.bar(positions, rho_values, color=bar_colors, width=0.65, edgecolor="black", linewidth=1.0)

    # Tick labels: drop the leading "w_" and show underscores as spaces.
    tick_text = [
        str(name).replace("w_", "", 1).replace("_", " ")
        for name in df_corr["channel"].tolist()
    ]
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_text, rotation=45, ha="right")

    # ---- shift x tick LABELS only (ticks stay) ----
    label_shift_pts = 14
    shift = mtransforms.ScaledTranslation(label_shift_pts / 72., 0, fig.dpi_scale_trans)
    for tick_label in ax.get_xticklabels():
        tick_label.set_transform(tick_label.get_transform() + shift)

    ax.set_ylabel(
        r"Spearman correlation" "\n" r"coefficient ($\mathit{\rho}$)",
        rotation=90,
        ha="center",
        va="center"
    )
    ax.yaxis.set_label_coords(-0.2, 0.5)

    # ax.set_title(title, pad=10)

    ax.set_ylim(-0.3, 0.3)
    ax.set_yticks(np.arange(-0.3, 0.31, 0.1))
    for guide_level in (0.1, -0.1):
        ax.axhline(guide_level, color="gray", linestyle="--", linewidth=0.8)

    # Offset of the stars from the bar tips: 4% of the rho range, where the
    # range always includes 0.
    rho_with_zero = np.r_[0, rho_values]
    rho_span = (np.nanmax(rho_with_zero) - np.nanmin(rho_with_zero)) + 1e-12
    pad = 0.04 * rho_span

    for idx, (rho, p_adj) in enumerate(zip(df_corr["rho"], df_corr["p_fdr"])):
        marker = star_by_p(p_adj)
        if not marker:
            continue
        # Stars sit above positive bars and below negative ones.
        if rho >= 0:
            star_y, anchor = rho + pad, "bottom"
        else:
            star_y, anchor = rho - pad, "top"
        ax.text(idx, star_y, marker, ha="center", va=anchor, fontsize=18, fontweight="bold")

    ax.set_xlim(-0.6, len(df_corr) - 0.4)
    fig.tight_layout()
    fig.savefig(out_png, dpi=DPI, bbox_inches="tight")
    print(f"[SAVED] {out_png}")

# ===================== 5) MAIN =====================

def main():
    """Run every configured comparison and write one figure and one table each.

    For each ``(CaseB, CaseA)`` pair in ``COMPARISONS`` and each metric in
    ``METRICS``:

    1. restrict both metric tables and the weights to their common basins,
    2. compute ``skill = (CaseB - CaseA) / (1 - CaseA)``,
    3. correlate the ``PLOT_CHANNELS`` weights with the skill,
    4. write ``corr_table_<METRIC>_<CASEB>_vs_<CASEA>.csv`` and
       ``bar_corr_<METRIC>_<CASEB>_vs_<CASEA>.png`` into ``OUT_DIR``.

    Raises
    ------
    ValueError
        If a plotted channel is missing from the weights, or a comparison
        names a case that is not in ``MODEL_CSVS``.
    """
    weights_df = load_channel_weights(WEIGHTS_XLSX)

    miss = [c for c in PLOT_CHANNELS if c not in weights_df.columns]
    if miss:
        raise ValueError(f"PLOT_CHANNELS not found in weights: {miss}")

    model_dfs = {}
    for name, p in MODEL_CSVS.items():
        model_dfs[name] = load_metric_csv(p)
        print(f"[INFO] Loaded {name}: rows={len(model_dfs[name])}")

    for caseB, caseA in COMPARISONS:
        if caseA not in model_dfs or caseB not in model_dfs:
            raise ValueError(f"Missing case in MODEL_CSVS: {caseB} or {caseA}")

        dfA = model_dfs[caseA]
        dfB = model_dfs[caseB]

        # Basins present in both metric tables and in the weights table.
        common = set(dfA["basin_id_str"]).intersection(dfB["basin_id_str"]).intersection(weights_df["basin_id_str"])
        print(f"[INFO] Common basins for {caseB} vs {caseA}: {len(common)}")

        dfA2 = dfA[dfA["basin_id_str"].isin(common)].copy()
        dfB2 = dfB[dfB["basin_id_str"].isin(common)].copy()
        w2 = weights_df[weights_df["basin_id_str"].isin(common)].copy()

        for metric in METRICS:
            skill_df = build_skill_table(dfA2, dfB2, metric)
            corr_df = compute_corr_table(w2, skill_df, PLOT_CHANNELS)

            # Persist the numbers behind the figure (rho, raw and adjusted
            # p-values, sample size), as listed in the module docstring.
            out_csv = OUT_DIR / f"corr_table_{metric}_{caseB}_vs_{caseA}.csv"
            corr_df.to_csv(out_csv, index=False)
            print(f"[SAVED] {out_csv}")

            title = f"SpearmanCorr(w_channel, skill) | skill=({caseB}-{caseA})/(1-{caseA}) | Metric={metric} | Stars: p_fdr<0.05(**), <0.1(*)"
            out_png = OUT_DIR / f"bar_corr_{metric}_{caseB}_vs_{caseA}.png"
            plot_corr_bar(corr_df, title, out_png)

    print("[DONE]")

# Run the full analysis only when this file is executed as a script.
if __name__ == "__main__":
    main()
