## Evaluation

In [1]:
import polars as pl
from pathlib import Path
from evalutils import split_by_timestep, metrics_to_latex, compute_metrics_by_horizon
from metrics import compute_metrics
import numpy as np
import matplotlib.pyplot as plt
from typing import Union

metrics = ["NSE",
           "MSE",
           "Alpha-NSE",
           "Beta-NSE",
           "Pearson $r$",
           "KGE",
           "FHV",
           "FLV",
           "FMS",
           "Peak-Timing"]

models = ["lstm", "autoformer", "patchtst"]
mq_loss = compute_metrics_by_horizon(Path("../results/mq"), metrics, horizon_days=3)
mse_loss = compute_metrics_by_horizon(Path("../results/mse/"), metrics, horizon_days=3)

In [5]:
futr_mse_loss = compute_metrics_by_horizon(Path("../results/futr_exog/"), metrics, horizon_days=3)
futr_mse_loss

{'day1': {'lstm': shape: (50, 11)
  ┌──────────┬──────────┬──────────┬───────────┬───┬────────────┬────────────┬───────────┬───────────┐
  │ basin    ┆ NSE      ┆ MSE      ┆ Alpha-NSE ┆ … ┆ FHV        ┆ FLV        ┆ FMS       ┆ Peak-Timi │
  │ ---      ┆ ---      ┆ ---      ┆ ---       ┆   ┆ ---        ┆ ---        ┆ ---       ┆ ng        │
  │ str      ┆ f64      ┆ f64      ┆ f64       ┆   ┆ f64        ┆ f64        ┆ f64       ┆ ---       │
  │          ┆          ┆          ┆           ┆   ┆            ┆            ┆           ┆ f64       │
  ╞══════════╪══════════╪══════════╪═══════════╪═══╪════════════╪════════════╪═══════════╪═══════════╡
  │ DE214770 ┆ 0.305279 ┆ 0.991302 ┆ 0.447588  ┆ … ┆ -39.684165 ┆ -54.444513 ┆ 6.896623  ┆ 1.163934  │
  │ DE111070 ┆ 0.470807 ┆ 0.136314 ┆ 0.747419  ┆ … ┆ -28.384418 ┆ -54.975917 ┆ 0.332872  ┆ 1.166667  │
  │ DE213150 ┆ 0.518744 ┆ 0.651337 ┆ 0.616061  ┆ … ┆ -37.88351  ┆ -58.716314 ┆ -8.620922 ┆ 1.370968  │
  │ DEE10430 ┆ 0.820306 ┆ 0.10512  ┆ 0.

In [10]:
lstm_mse_w_futr = compute_metrics_by_horizon(Path("../results/futr_exog/"), metrics, horizon_days=3)
lstm_mse_wo_futr = {h:{"lstm":models["lstm"]} for h, models in mse_loss.items()}

In [None]:
unified = {}
for day in ["day1","day2","day3"]:
    unified[day] = {
        "lstm_w_futr": lstm_mse_w_futr[day]["lstm"].select(["basin", "NSE"]),
        "lstm_wo_futr": lstm_mse_wo_futr[day]["lstm"].select(["basin", "NSE"]),
    }

In [14]:
comparisons = []
for day in ["day1", "day2", "day3"]:
    comparisons.append(
        (day, "lstm_w_futr", "lstm_wo_futr", f"LSTM con vs sin futuros ({day})")
    )

In [18]:
result = paired_wilcoxon_polars(
    data=unified,
    metric_col="NSE",
    comparisons=comparisons,
    #apply_bias_correction=True  # o False si no quieres corrección tipo Hedges
)

In [19]:
result

horizon,comparison,W,p_raw,effect_size_rank_biserial,p_adjusted,significant
str,str,f64,f64,f64,f64,bool
"""day1""","""LSTM con vs sin futuros (day1)""",198.0,8e-06,-0.581451,8e-06,True
"""day2""","""LSTM con vs sin futuros (day2)""",27.0,2.24e-12,1.39834,2.24e-12,True
"""day3""","""LSTM con vs sin futuros (day3)""",3.0,8.8818e-15,2.228696,8.8818e-15,True


In [None]:
horizons = ["day1", "day2", "day3"]
models = ["lstm"]
loss_vs_loss = build_loss_vs_loss(horizons, models, ["MQLoss", "MSELoss"])
model_pairs = [("lstm", "patchtst"), ("lstm", "tcn"), ("tcn", "patchtst")]
model_vs_model = build_model_vs_model(horizons, ["MQLoss", "MSELoss"], model_pairs)
all_comps = loss_vs_loss + model_vs_model

In [None]:
result_table = paired_wilcoxon_polars(unified, metric_col="NSE", comparisons=all_comps)

In [None]:
mse_loss

{'day1': {'tcn': shape: (50, 11)
  ┌──────────┬──────────┬───────────┬───────────┬───┬────────────┬───────────┬───────────┬───────────┐
  │ basin    ┆ NSE      ┆ MSE       ┆ Alpha-NSE ┆ … ┆ FHV        ┆ FLV       ┆ FMS       ┆ Peak-Timi │
  │ ---      ┆ ---      ┆ ---       ┆ ---       ┆   ┆ ---        ┆ ---       ┆ ---       ┆ ng        │
  │ str      ┆ f64      ┆ f64       ┆ f64       ┆   ┆ f64        ┆ f64       ┆ f64       ┆ ---       │
  │          ┆          ┆           ┆           ┆   ┆            ┆           ┆           ┆ f64       │
  ╞══════════╪══════════╪═══════════╪═══════════╪═══╪════════════╪═══════════╪═══════════╪═══════════╡
  │ DE110350 ┆ 0.501984 ┆ 0.674414  ┆ 0.668844  ┆ … ┆ -32.645105 ┆ -96.71403 ┆ -24.92953 ┆ 1.101695  │
  │          ┆          ┆           ┆           ┆   ┆            ┆ 9         ┆ 2         ┆           │
  │ DE213970 ┆ 0.768474 ┆ 0.564371  ┆ 0.679934  ┆ … ┆ -31.970054 ┆ -293.0360 ┆ -25.87329 ┆ 1.180328  │
  │          ┆          ┆           ┆   

In [5]:
futr_mse_loss

{'day1': {'lstm': shape: (50, 11)
  ┌──────────┬──────────┬───────────┬───────────┬───┬────────────┬───────────┬───────────┬───────────┐
  │ basin    ┆ NSE      ┆ MSE       ┆ Alpha-NSE ┆ … ┆ FHV        ┆ FLV       ┆ FMS       ┆ Peak-Timi │
  │ ---      ┆ ---      ┆ ---       ┆ ---       ┆   ┆ ---        ┆ ---       ┆ ---       ┆ ng        │
  │ str      ┆ f64      ┆ f64       ┆ f64       ┆   ┆ f64        ┆ f64       ┆ f64       ┆ ---       │
  │          ┆          ┆           ┆           ┆   ┆            ┆           ┆           ┆ f64       │
  ╞══════════╪══════════╪═══════════╪═══════════╪═══╪════════════╪═══════════╪═══════════╪═══════════╡
  │ DE212710 ┆ 0.599974 ┆ 0.932745  ┆ 0.706727  ┆ … ┆ -34.343937 ┆ -1839.109 ┆ 2.144275  ┆ 1.05      │
  │          ┆          ┆           ┆           ┆   ┆            ┆ 899       ┆           ┆           │
  │ DE212300 ┆ 0.296382 ┆ 11.507786 ┆ 0.492634  ┆ … ┆ -51.793173 ┆ -711.1848 ┆ -7.838327 ┆ 1.37931   │
  │          ┆          ┆           ┆  

In [None]:
# \begin{table}[ht]
# \centering
# \caption{caption}
# \label{tab:lstm_horizon}
# \begin{tabular}{lccc}
# \toprule
# \textbf{LSTM} & \textbf{Time} & \textbf{Mean} & \textbf{Median} \\
# \midrule
# \multirow{3}{*}{Without future inputs}
#  & Day 1 & 0.43 & 0.76 \\
#  & Day 2 & 0.41 & 0.73 \\
#  & Day 3 & 0.38 & 0.70 \\
# \midrule
# \multirow{3}{*}{With future inputs}
#  & Day 1 & 0.56* & 0.83* \\
#  & Day 2 & 0.53* & 0.80* \\
#  & Day 3 & 0.50* & 0.77* \\
# \bottomrule
# \end{tabular}
# \end{table}

def metrics_to_latex(
        df_loss_mqloss: pl.DataFrame,
        df_loss_mse: pl.DataFrame, 
        metrics_labels: List[str],
        model_labels: List[str],
        aggregation: str,
        file_name: str) -> None:
    """
    Generate a LaTeX table of performance metrics for multiple models across different forecast horizons and write the table to a .tex file.

    This function iterates over specified metrics and timesteps, computes the mean and standard deviation
    for each model and metric, and formats the results into a LaTeX table.

    Parameters
    ----------
        df (pl.DataFrame): A Polars DataFrame containing the df evaluation results.
            It should be structured such that df[horizon][model][metric] returns a series
            of metric values for the given forecast horizon and model.
        metrics_labels (List[str]): A list of metric names to include in the table (e.g., ["NSE", "KGE"]).
        model_labels (List[str]): A list of model identifiers matching the DataFrame keys
            (e.g., ["LSTM", "TCN", "PatchTST"]).
        file_name (str): The name of the output file where the LaTeX table will be saved.

    Returns
    -------
        None: Writes a LaTeX-formatted table to "file_name.tex".
    """
    
    timesteps = list(df_loss_mqloss.keys())
    timesteps_labels = {f"day{h}": f"Day {h}" for h in range(1, len(timesteps) + 1)}

    # Initialize lines for LaTeX table structure
    lines = [
        r"\begin{table}[ht]",
        r"\centering",
        r"\caption{caption}",
        r"\label{tab:table}",
        r"\begin{tabular}{ll*{4}{cc}}",
        r"\toprule",
        r"\multirow{2}{*}{\textbf{Metric}} & \multirow{2}{*}{\textbf{Time}}",
        r"& \multicolumn{2}{c}{\textbf{LSTM}}",
        r"& \multicolumn{2}{c}{\textbf{TCN}}",
        r"& \multicolumn{2}{c}{\textbf{PatchTST}}\\",
        r"\cmidrule(lr){3-4} \cmidrule(lr){5-6} \cmidrule(lr){7-8} \cmidrule(lr){9-10}",
        r"& & \textbf{MQLoss} & \textbf{MSELoss}",
        r"& \textbf{MQLoss} & \textbf{MSELoss}",
        r"& \textbf{MQLoss} & \textbf{MSELoss}\\"
        r"\midrule"
    ]

    # Populate table rows with mean and std for each metric at each timestep
    for metric_label in metrics_labels:
        # Add multirow entry for current metric
        lines.append(rf"\multirow{{3}}{{*}}{{{metric_label}\tnote{{{chr(97 + metrics_labels.index(metric_label))}}}}}")
        for t in timesteps:
            # Compute mean and std for each model
            if aggregation == "mean":
                row_metrics = {
                    "mse": { model: df_loss_mse[t][model.lower()][metric_label].mean() for model in model_labels },
                    "mqloss": { model: df_loss_mqloss[t][model.lower()][metric_label].mean() for model in model_labels },
                    # "std":   { model: df[t][model.lower()][metric_label].std()  for model in model_labels }
                }
            else:
                row_metrics = {
                    "mse": { model: df_loss_mse[t][model.lower()][metric_label].median() for model in model_labels },
                    "mqloss": { model: df_loss_mqloss[t][model.lower()][metric_label].median() for model in model_labels },
                }
            
            # Format cells as 'mean (std)'
            cells = [fr"{row_metrics['mqloss'][model]:.2f} & {row_metrics['mse'][model]:.2f}" for model in model_labels]

            # Append the formatted row to the table
            lines.append(rf" & {timesteps_labels[t]} & {' & '.join(cells)} \\")
        # Add a midrule between metrics, except after the last one
        if metric_label != metrics_labels[-1]:
            lines.append(r"\midrule")

    # Close the table and add footnotes explaining metrics
    lines.extend([
        r"\bottomrule", 
        r"\end{tabular}",
        # r"\begin{tablenotes}",
        # r"\footnotesize",
        # r"(a) Nash–Sutcliffe efficiency: ($-\infty$,\,1]; values closer to one are desirable.",
        # r"(b) $\alpha$-NSE decomposition: (0,\,\,$\infty$); values closer to one are desirable.",
        # r"(c) $\beta$-NSE decomposition: ($-\infty$,\,\,$\infty$); values closer to zero are desirable.",
        # r"(d) KGE (top 2\% peak flow bias): ($-\infty$,\,\,$\infty$); values closer to zero are desirable.",
        # r"(e) Bias of FDC midsegment slope: ($-\infty$,\,\,$\infty$); values closer to zero are desirable.",
        # r"(f) Bias en el caudal bajo al 30\%: ($-\infty$,\,\,$\infty$); values closer to zero are desirable.",
        # r"(g) FMS mid‐segment slope bias: ($-\infty$,\,\,$\infty$); values closer to zero are desirable.",
        # r"\end{tablenotes}",
        r"\end{threeparttable}",
        r"\end{table*}"
    ])

    # Write the assembled lines to the LaTeX file
    with open(f"tables/{file_name}_{aggregation}.tex", "w") as f:
        f.write("\n".join(lines))




In [12]:
import polars as pl
from scipy.stats import wilcoxon
from typing import Dict

def compare_wilcoxon(
    mqloss_df: Dict[str, Dict[str, pl.DataFrame]],
    mse_loss_df: Dict[str, Dict[str, pl.DataFrame]]
) -> None:
    """
    Interactively compare paired NSE series using the Wilcoxon signed-rank test.

    Parameters:
    -----------
    mqloss_df : Dict[str, Dict[str, pl.DataFrame]]
        Dictionary mapping horizon keys ("day1","day2","day3") to dictionaries
        that map model names ("lstm","autoformer","patchtst") to Polars DataFrames.
        Each DataFrame should have columns "basin" and "NSE".
    mse_loss_df : Dict[str, Dict[str, pl.DataFrame]]
        Similar structure, containing the metrics for MSE loss.
    """
    # 1) Select horizon
    while True:
        h = input("Select time horizon (1, 2, or 3): ").strip()
        if h in {"1", "2", "3"}:
            horizon = f"day{h}"
            break
        print("Invalid option. Please choose 1, 2, or 3.")

    # 2) Select comparison type
    while True:
        choice = input("What do you want to compare? (1) Models, (2) Loss functions: ").strip()
        if choice in {"1", "2"}:
            break
        print("Invalid option. Please enter 1 or 2.")

    if choice == "1":
        # Compare two models under the same loss dataset
        models = {"1": "lstm", "2": "autoformer", "3": "patchtst"}
        while True:
            m1 = input("Select first model: (1) LSTM, (2) Autoformer, (3) PatchTST: ").strip()
            if m1 in models:
                break
            print("Invalid option.")
        while True:
            m2 = input("Select second model (different): (1) LSTM, (2) Autoformer, (3) PatchTST: ").strip()
            if m2 in models and m2 != m1:
                break
            print("Invalid or same as first model.")
        model1, model2 = models[m1], models[m2]

        losses = {"1": ("MQLoss", mqloss_df), "2": ("MSELoss", mse_loss_df)}
        while True:
            lopt = input("Which loss function dataset to use? (1) MQLoss, (2) MSELoss: ").strip()
            if lopt in losses:
                loss_name, df_source = losses[lopt]
                break
            print("Invalid option. Please enter 1 or 2.")

        print(f"\nComparing models '{model1}' vs '{model2}' using {loss_name} for horizon '{horizon}':")

        df1 = df_source[horizon][model1]
        df2 = df_source[horizon][model2]
        df_joined = df1.join(df2, on="basin", suffix=f"_{model2}")

        series1 = df_joined["NSE"]
        series2 = df_joined[f"NSE_{model2}"]
        stat, p_value = wilcoxon(series1.to_list(), series2.to_list())
        print(f"Wilcoxon W statistic: {stat}, p-value: {p_value:.6f}")

    else:
        # Compare two loss functions for the same model
        models = {"1": "lstm", "2": "autoformer", "3": "patchtst"}
        while True:
            m = input("Select model: (1) LSTM, (2) Autoformer, (3) PatchTST: ").strip()
            if m in models:
                model = models[m]
                break
            print("Invalid option.")
        print(f"\nComparing loss functions for model '{model}' at horizon '{horizon}':")

        df_mq = mqloss_df[horizon][model]
        df_mse = mse_loss_df[horizon][model]
        df_joined = df_mq.join(df_mse, on="basin", suffix="_mse")

        series_mq = df_joined["NSE"]
        series_mse = df_joined["NSE_mse"]
        stat, p_value = wilcoxon(series_mq.to_list(), series_mse.to_list())
        print(f"Wilcoxon W statistic: {stat}, p-value: {p_value:.6f}")


# compare_wilcoxon(mqloss, mse_loss)

In [20]:
# from scipy.stats import wilcoxon

# a = mqloss["day1"]["lstm"]
# b = mse_loss["day1"]["lstm"]
# c= mqloss["day1"]["autoformer"]

# # join by basin

# # df = a.join(b, on="basin", suffix="_mqloss")
# df = a.join(c, on="basin", suffix="_mqloss")


# _, p_value = wilcoxon(df["NSE"], df["NSE_mqloss"])
# print(f"Wilcoxon test p-value: {p_value}")

In [None]:
import polars as pl
import numpy as np
import pandas as pd
from scipy.stats import wilcoxon
from statsmodels.stats.multitest import multipletests

# def rank_biserial_effect_size(x, y):
#     """
#     Rank-biserial correlation for paired samples x, y.
#     """
#     diff = np.array(x) - np.array(y)
#     nonzero = diff != 0
#     diffs = diff[nonzero]
#     if len(diffs) == 0:
#         return 0.0
#     abs_ranks = pd.Series(np.abs(diffs)).rank(method="average")
#     signs = np.sign(diffs)
#     W_pos = (abs_ranks * (signs > 0)).sum()
#     W_neg = (abs_ranks * (signs < 0)).sum()
#     n = len(diffs)
#     # rank-biserial: (W_pos - W_neg) / (n*(n+1)/2)
#     return (W_pos - W_neg) / (n * (n + 1) / 2)


def cohen_d_paired(x, y, bias_correction=True):
    """
    Cohen's d for paired samples: mean(diff) / std(diff).
    If bias_correction=True applies small-sample correction (Hedges-like).
    """
    diffs = np.array(x) - np.array(y)
    # exclude zero-variance case
    if len(diffs) < 2:
        return 0.0
    mean_diff = np.mean(diffs)
    sd_diff = np.std(diffs, ddof=1)  # sample std
    if sd_diff == 0:
        return 0.0
    dz = mean_diff / sd_diff
    if not bias_correction:
        return dz
    n = len(diffs)
    J = 1 - (3 / (4 * n - 1))
    return dz * J  # corrected (Hedges-like)


def paired_wilcoxon_polars(
    data: dict,  # e.g., {"day1": {"lstm_MQLoss": pl.DataFrame, ...}, ...}
    metric_col: str = "NSE",
    comparisons: list[tuple] = None
) -> pl.DataFrame:
    """
    Runs Wilcoxon paired tests over specified comparisons.

    comparisons: list of (horizon, key1, key2, description)
    """
    records = []
    for horizon, key1, key2, desc in comparisons:
        df1 = data[horizon].get(key1)
        df2 = data[horizon].get(key2)
        if df1 is None or df2 is None:
            continue
        # join on basin
        joined = df1.join(df2, on="basin", how="inner", suffix=f"_{key2}")
        col1 = metric_col if f"{metric_col}" in joined.columns else None
        # disambiguate if both have same name: suffix handling
        name1 = metric_col if f"{metric_col}" in df1.columns else f"{metric_col}"
        name2 = metric_col if f"{metric_col}" in df2.columns else f"{metric_col}"
        # After join, columns are like NSE and NSE_<key2> if names clash
        if f"{metric_col}_{key2}" in joined.columns:
            s1 = joined[metric_col].to_list()
            s2 = joined[f"{metric_col}_{key2}"].to_list()
        else:
            # assume no suffix (unlikely) or both had same name but not distinguished
            s1 = joined[metric_col].to_list()
            s2 = joined[metric_col].to_list()  # will lead to zero diff
        # filter out NaNs
        a = []
        b = []
        for v1, v2 in zip(s1, s2):
            if v1 is None or v2 is None or (isinstance(v1, float) and np.isnan(v1)) or (isinstance(v2, float) and np.isnan(v2)):
                continue
            a.append(v1)
            b.append(v2)
        if len(a) == 0:
            continue
        try:
            stat, p_raw = wilcoxon(a, b, zero_method="wilcox", alternative="two-sided", correction=False)
        except ValueError:
            stat, p_raw = np.nan, 1.0
        # effect = rank_biserial_effect_size(a, b)
        effect = cohen_d_paired(a, b, bias_correction=True)
        records.append({
            "horizon": horizon,
            "comparison": desc,
            "W": stat,
            "p_raw": p_raw,
            "effect_size_rank_biserial": effect
        })

    if not records:
        return pl.DataFrame([])

    df = pd.DataFrame(records)
    # ajuste de p-values dentro de cada horizonte (Holm)
    df["p_adjusted"] = df.groupby("horizon")["p_raw"].transform(lambda ps: multipletests(ps, method="holm")[1])
    df["significant"] = df["p_adjusted"] < 0.05
    return pl.from_pandas(df)

# Helpers para construir comparaciones:
def build_loss_vs_loss(horizons, models, losses):
    comps = []
    for h in horizons:
        for model in models:
            k1 = f"{model}_{losses[0]}"
            k2 = f"{model}_{losses[1]}"
            desc = f"{model} {losses[0]} vs {losses[1]}"
            comps.append((h, k1, k2, desc))
    return comps

def build_model_vs_model(horizons, loss_types, model_pairs):
    comps = []
    for h in horizons:
        for loss in loss_types:
            for m1, m2 in model_pairs:
                k1 = f"{m1}_{loss}"
                k2 = f"{m2}_{loss}"
                desc = f"{m1} vs {m2} ({loss})"
                comps.append((h, k1, k2, desc))
    return comps

In [17]:
unified = {}
for day in ["day1","day2","day3"]:
    unified[day] = {
        "lstm_MQLoss": mq_loss[day]["lstm"].select(["basin", "NSE"]),
        "lstm_MSELoss": mse_loss[day]["lstm"].select(["basin", "NSE"]),

        "tcn_MQLoss": mq_loss[day]["tcn"].select(["basin", "NSE"]),
        "tcn_MSELoss": mse_loss[day]["tcn"].select(["basin", "NSE"]),
        
        "patchtst_MQLoss": mq_loss[day]["patchtst"].select(["basin", "NSE"]),
        "patchtst_MSELoss": mse_loss[day]["patchtst"].select(["basin", "NSE"])
    }

In [18]:
horizons = ["day1", "day2", "day3"]
models = ["lstm", "tcn", "patchtst"]
loss_vs_loss = build_loss_vs_loss(horizons, models, ["MQLoss", "MSELoss"])
model_pairs = [("lstm", "patchtst"), ("lstm", "tcn"), ("tcn", "patchtst")]
model_vs_model = build_model_vs_model(horizons, ["MQLoss", "MSELoss"], model_pairs)
all_comps = loss_vs_loss + model_vs_model

In [13]:
all_comps

[('day1', 'lstm_MQLoss', 'lstm_MSELoss', 'lstm MQLoss vs MSELoss'),
 ('day1', 'tcn_MQLoss', 'tcn_MSELoss', 'tcn MQLoss vs MSELoss'),
 ('day1', 'patchtst_MQLoss', 'patchtst_MSELoss', 'patchtst MQLoss vs MSELoss'),
 ('day2', 'lstm_MQLoss', 'lstm_MSELoss', 'lstm MQLoss vs MSELoss'),
 ('day2', 'tcn_MQLoss', 'tcn_MSELoss', 'tcn MQLoss vs MSELoss'),
 ('day2', 'patchtst_MQLoss', 'patchtst_MSELoss', 'patchtst MQLoss vs MSELoss'),
 ('day3', 'lstm_MQLoss', 'lstm_MSELoss', 'lstm MQLoss vs MSELoss'),
 ('day3', 'tcn_MQLoss', 'tcn_MSELoss', 'tcn MQLoss vs MSELoss'),
 ('day3', 'patchtst_MQLoss', 'patchtst_MSELoss', 'patchtst MQLoss vs MSELoss'),
 ('day1', 'lstm_MQLoss', 'patchtst_MQLoss', 'lstm vs patchtst (MQLoss)'),
 ('day1', 'lstm_MQLoss', 'tcn_MQLoss', 'lstm vs tcn (MQLoss)'),
 ('day1', 'tcn_MQLoss', 'patchtst_MQLoss', 'tcn vs patchtst (MQLoss)'),
 ('day1', 'lstm_MSELoss', 'patchtst_MSELoss', 'lstm vs patchtst (MSELoss)'),
 ('day1', 'lstm_MSELoss', 'tcn_MSELoss', 'lstm vs tcn (MSELoss)'),
 ('day1

In [19]:
result_table = paired_wilcoxon_polars(unified, metric_col="NSE", comparisons=all_comps)

In [22]:
result_table

horizon,comparison,W,p_raw,effect_size_rank_biserial,p_adjusted,significant
str,str,f64,f64,f64,f64,bool
"""day1""","""lstm MQLoss vs MSELoss""",266.0,0.000212,-0.437916,0.000847,true
"""day1""","""tcn MQLoss vs MSELoss""",136.0,1.5079e-7,0.891446,0.000001,true
"""day1""","""patchtst MQLoss vs MSELoss""",369.0,0.008849,-0.370837,0.026547,true
"""day2""","""lstm MQLoss vs MSELoss""",242.0,0.000072,0.523628,0.000286,true
"""day2""","""tcn MQLoss vs MSELoss""",2.0,5.3291e-15,1.971188,4.2633e-14,true
…,…,…,…,…,…,…
"""day3""","""lstm vs tcn (MQLoss)""",537.0,0.337403,-0.248688,0.337403,false
"""day3""","""tcn vs patchtst (MQLoss)""",13.0,1.5632e-13,-1.390118,1.0942e-12,true
"""day3""","""lstm vs patchtst (MSELoss)""",2.0,5.3291e-15,-1.105015,4.2633e-14,true
"""day3""","""lstm vs tcn (MSELoss)""",220.0,0.000024,-0.632967,0.000073,true


In [10]:
unified

{'day1': {'lstm_MQLoss': shape: (50, 2)
  ┌──────────┬──────────┐
  │ basin    ┆ NSE      │
  │ ---      ┆ ---      │
  │ str      ┆ f64      │
  ╞══════════╪══════════╡
  │ DE110030 ┆ 0.889457 │
  │ DE110140 ┆ 0.701975 │
  │ DE212710 ┆ 0.576581 │
  │ DEG10310 ┆ 0.664652 │
  │ DE210840 ┆ 0.59076  │
  │ …        ┆ …        │
  │ DE710530 ┆ 0.879158 │
  │ DEA10760 ┆ 0.768857 │
  │ DEA10160 ┆ 0.762835 │
  │ DE215120 ┆ 0.641265 │
  │ DE213800 ┆ 0.651985 │
  └──────────┴──────────┘,
  'lstm_MSELoss': shape: (50, 2)
  ┌──────────┬──────────┐
  │ basin    ┆ NSE      │
  │ ---      ┆ ---      │
  │ str      ┆ f64      │
  ╞══════════╪══════════╡
  │ DE213140 ┆ 0.527372 │
  │ DEG10310 ┆ 0.761529 │
  │ DE212300 ┆ 0.310522 │
  │ DE210880 ┆ 0.726349 │
  │ DE214160 ┆ 0.71175  │
  │ …        ┆ …        │
  │ DEA12000 ┆ 0.720442 │
  │ DEG10100 ┆ 0.751329 │
  │ DE110140 ┆ 0.68163  │
  │ DE215050 ┆ 0.852214 │
  │ DEA10000 ┆ 0.764179 │
  └──────────┴──────────┘,
  'tcn_MQLoss': shape: (50, 2)
  ┌───────

In [None]:
horizons = ["day1", "day2", "day3"]
models = ["lstm", "tcn", "patchtst"]
loss_vs_loss = build_loss_vs_loss(horizons, models, ["MQLoss", "MSELoss"])
model_pairs = [("lstm", "patchtst"), ("lstm", "tcn"), ("tcn", "patchtst")]
model_vs_model = build_model_vs_model(horizons, ["MQLoss", "MSELoss"], model_pairs)
all_comps = loss_vs_loss + model_vs_model

result_table = paired_wilcoxon_polars(unified, metric_col="NSE", comparisons=all_comps)

In [None]:
# import tqdm

# model_specs = {
#     'ealstm_MSE': {
#         'model': 'ealstm',
#         'loss': 'MSELoss'
#     },
#     'ealstm_NSE': {
#         'model': 'ealstm',
#         'loss': 'NSELoss'
#     },
#     'lstm_MSE': {
#         'model': 'lstm',
#         'loss': 'MSELoss'
#     },
#     'lstm_NSE': {
#         'model': 'lstm',
#         'loss': 'NSELoss'
#     },
#     'lstm_no_static_MSE': {
#         'model': 'lstm_no_static',
#         'loss': 'MSELoss'
#     },
#     'lstm_no_static_NSE': {
#         'model': 'lstm_no_static',
#         'loss': 'NSELoss'
#     }
# }

# model_draw_style = {
#     'ealstm_NSE': {
#         'ensemble_color': '#1b9e77',
#         'single_color': '#b3e2cd',
#         'linestyle': '-',
#         'marker': 's',
#         'label': 'EA-LSTM NSE'
#     },
#     'ealstm_MSE': {
#         'ensemble_color': '#1b9e77',
#         'single_color': '#b3e2cd',
#         'linestyle': '--',
#         'marker': 's',
#         'label': 'EA-LSTM MSE'
#     },
#     'lstm_NSE': {
#         'ensemble_color': '#d95f02',
#         'single_color': '#fdcdac',
#         'linestyle': '-',
#         'marker': 's',
#         'label': 'LSTM NSE'
#     },
#     'lstm_MSE': {
#         'ensemble_color': '#d95f02',
#         'single_color': '#fdcdac',
#         'linestyle': '--',
#         'marker': 's',
#         'label': 'LSTM MSE'
#     },
#     'lstm_no_static_MSE': {
#         'ensemble_color': '#7570b3',
#         'single_color': '#cbd5e8',
#         'linestyle': '--',
#         'marker': '^',
#         'label': 'LSTM (no static inputs) MSE'
#     },
#     'lstm_no_static_NSE': {
#         'ensemble_color': '#7570b3',
#         'single_color': '#cbd5e8',
#         'linestyle': '-',
#         'marker': '^',
#         'label': 'LSTM (no static inputs) NSE'
#     },
#     'SAC_SMA': {
#         'color': '#e7298a',
#         'linestyle': '-.',
#         'marker': None,
#         'label': 'SAC-SMA'
#     },
#     'VIC_basin': {
#         'color': '#66a61e',
#         'linestyle': '-.',
#         'marker': None,
#         'label': 'VIC (basin-wise calibrated)'
#     },
#     'VIC_conus': {
#         'color': '#66a61e',
#         'linestyle': '-.',
#         'marker': None,
#         'label': 'VIC (CONUS-wide calibrated)'
#     },
#     'mHm_basin': {
#         'color': '#e6ab02',
#         'linestyle': '-.',
#         'marker': None,
#         'label': 'mHm (basin-wise calibrated)'
#     },
#     'mHm_conus': {
#         'color': '#e6ab02',
#         'linestyle': '-.',
#         'marker': None,
#         'label': 'mHm (CONUS-wide calibrated)'
#     },
#     'HBV_lb': {
#         'color': '#a6761d',
#         'linestyle': '-.',
#         'marker': 'x',
#         'label': 'HBV lower bound (n=1000 uncalibrated)'
#     },
#     'HBV_ub': {
#         'color': '#a6761d',
#         'linestyle': '-.',
#         'marker': None,
#         'label': 'HBV upper bound (n=100 calibrated)'
#     },
#     'q_sim_fuse_900': {
#         'color': '#666666',
#         'linestyle': '-.',
#         'marker': None,
#         'label': 'FUSE (900)'
#     },
#     'q_sim_fuse_902': {
#         'color': '#666666',
#         'linestyle': '-.',
#         'marker': '.',
#         'label': 'FUSE (902)'
#     },
#     'q_sim_fuse_904': {
#         'color': '#666666',
#         'linestyle': '-.',
#         'marker': 'd',
#         'label': 'FUSE (904)'
#     }
# }

In [None]:
# from pyfonts import load_google_font
# from matplotlib.font_manager import fontManager
# from matplotlib import rcParams
# from plotutils import ecdf

# font = load_google_font("Courier Prime", weight="regular", italic=False)
# fontManager.addfont(str(font.get_file()))
# rcParams.update(
#     {
#         "font.family": font.get_name(),
#         "font.style": font.get_style(),
#         "font.weight": font.get_weight(),
#         "font.size": font.get_size(),
#         "font.stretch": font.get_stretch(),
#         "font.variant": font.get_variant(),
#         "axes.titlesize": 20,
#         "axes.labelsize": 20,
#         "xtick.labelsize": 20,
#         "ytick.labelsize": 20,
#         "legend.fontsize": 18,
#         "figure.titlesize": 20
#     })



# model_draw_style = {
#     'lstm': {
#         'mse_color': '#1b9e77',
#         'mq_color': '#b3e2cd',
#         'mse_linestyle': '-',
#         'mq_linestyle': '--',
#         'mse_label': 'LSTM MSE',
#         'mq_label': 'LSTM MQ'
#     },
#     'tcn': {
#         'mse_color': '#7570b3',
#         'mq_color': '#cbd5e8',
#         'mse_linestyle': '-',
#         'mq_linestyle': '--',
#         'mse_label': 'TCN MSE',
#         'mq_label': 'TCN MQ'
#     },
#     'patchtst': {
#         'mse_color': '#d95f02',
#         'mq_color': '#fdcdac',
#         'mse_linestyle': '-',
#         'mq_linestyle': '--',
#         'mse_label': 'PatchTST MSE',
#         'mq_label': 'PatchTST MQ'
#     }
# }


# mse_metrics_filtered_by_day = mse_loss[f"day{day}"] 
# mq_metrics_filtered_by_day = mqloss[f"day{day}"]

# fig, ax = plt.subplots(figsize=figsize)

# ax.set_aspect('equal', adjustable='datalim') # Lock the square shape

# # Major grid:
# ax.grid(True, which='major', linestyle='-', linewidth=0.75, alpha=0.2)

# # Minor ticks and grid:
# ax.minorticks_on()
# ax.grid(True, which='minor', linestyle='-', linewidth=0.25, alpha=0.10)

# ax.set_axisbelow(True) # Ensure grid is below data

# for model_type, df in mse_metrics_filtered_by_day.items():
#     bin_, cdf_ = ecdf(df['NSE'].to_numpy())
#     ax.plot(bin_,
#             cdf_,
#             label=f"{model_draw_style[model_type]['mse_label']}",
#             color=model_draw_style[model_type]["mse_color"], 
#             marker="s", 
#             markevery=5, 
#             linestyle=model_draw_style[model_type]['mse_linestyle']
#             )

# for model_type, df in mq_metrics_filtered_by_day.items():
#     # ensemble
#     # values = list(models['ensemble'].values())
#     bin_, cdf_ = ecdf(df['NSE'].to_numpy())
#     ax.plot(bin_,
#             cdf_,
#             label=f"{model_draw_style[model_type]['mq_label']}",
#             color=model_draw_style[model_type]["mq_color"], 
#             marker="s", 
#             markevery=5, 
#             linestyle=model_draw_style[model_type]['mq_linestyle']
#             )

# handles, labels = ax.get_legend_handles_labels() # get all legend items
# desired_order = [1, 4, 2, 5, 0, 3]  # change the order of legend elements

# ax.legend(
#     [handles[i] for i in desired_order],
#     [labels[i] for i in desired_order],
#     loc = 'upper center',
#     bbox_to_anchor = (0.5, 1.11),
#     ncol = 3,
#     frameon = False
# ) 

# ax.set_xlabel(metric)
# ax.set_ylabel('CDF')

# fig.savefig(f"{metric}_day{day}.pdf", bbox_inches='tight', dpi=300, format='pdf')

In [None]:
# import datetime
# import pandas as pd
# import locale

# from matplotlib.ticker import NullLocator
# from matplotlib.dates import DateFormatter, MonthLocator


# timeseries_dir = "../../data/CAMELS_DE/timeseries"
# ts_path = Path(timeseries_dir)
# start_date = datetime.datetime.strptime("2017-01-01", "%Y-%m-%d")
# end_date = datetime.datetime.strptime("2017-9-30", "%Y-%m-%d")

# BASIN = "DE210310"
# frames: list[pl.DataFrame] = []

# csv_path = ts_path / f"CAMELS_DE_hydromet_timeseries_{BASIN}.csv"

# df_ts = pl.read_csv(
#     csv_path,
#     schema_overrides={
#             "date": pl.Date,
#             "discharge_spec_obs": pl.Float64
#     },
#     infer_schema=10_000
# )

# if "date" in df_ts.columns:
#     df_ts = df_ts.with_columns(pl.col("date").cast(pl.Date, strict=False))
#     df_ts = df_ts.filter(
#     pl.col("date").is_between(start_date, end_date, closed="both")
# )

# df_ts = df_ts.select(
#     pl.col("date"),
#     pl.col("precipitation_mean").alias("prep")
# )

# d = split_by_timestep(Path(f"../results/mq/{model}.parquet"))
# day1_basin = (
#     d["day1"]
#     .filter(
#     (pl.col("basins") == "DE210310")
#     & (pl.col("ds").is_between(start_date, end_date))
#     )
#     .drop("cutoff", "basins")
#     .rename({"ds": "date"})
#     .join(
#         df_ts,
#         on="date",
#         how="left")
# )


# pdf = day1_basin.to_pandas()
# pdf['date'] = pd.to_datetime(pdf['date'])
# pdf = pdf.set_index('date')
# pdf = pdf.sort_index()

# fig, ax1 = plt.subplots(figsize=(18,10))

# # Major grid:
# ax1.grid(True, which='major', linestyle='-', linewidth=0.75, alpha=0.2)

# # Minor ticks and grid:
# ax1.minorticks_on()
# ax1.grid(True, which='minor', linestyle='-', linewidth=0.25, alpha=0.1)

# ax1.set_axisbelow(True) # Ensure grid is below data

# ax1.plot(pdf.index, pdf['sim'],
#              label='sim', 
#              zorder=2, 
#              color="#ff7f0e", 
#              linewidth=2,
#              linestyle='-')
# ax1.plot(pdf.index, 
#          pdf['obs'],
#          label='obs', 
#          zorder=1, 
#          color="#9467bd", 
#          linewidth=2)

# ax1.fill_between(pdf.index,
#                  pdf['sim-lo-90'],
#                  pdf['sim-hi-90'],
#                  alpha=0.3,
#                  label='C.I 90%',
#                  color="#bdbdbd",
#                  zorder=3)

# ax1.set_ylabel('Streamflow')
# ax1.set_ylim(0, 70)


# locale.setlocale(locale.LC_ALL, 'en_US.UTF-8')  # Set locale for date formatting
# ax1.xaxis.set_major_locator(MonthLocator(interval=2))
# ax1.xaxis.set_major_formatter(DateFormatter('%b %Y'))
# ax1.xaxis.set_minor_locator(NullLocator())

# handles, labels = ax1.get_legend_handles_labels()
# desired_order = [1, 0, 2]

# ax1.legend(
#     [handles[i] for i in desired_order],
#     [labels[i] for i in desired_order],
#     loc = 'upper center',
#     bbox_to_anchor = (0.5, 1.08),
#     ncol = 4,
#     frameon = False
# ) 


# ax2 = ax1.twinx()
# ax2.bar(pdf.index, pdf['prep'],
#         width=1, alpha=0.8,
#         color="#1f78b4")
# ax2.set_ylabel('Precipitation')
# ax2.set_ylim(0,180)
# ax2.invert_yaxis()


# plt.title('Basin DE210310', loc="left")

# fig.savefig("DE210310.pdf", bbox_inches='tight', dpi=300, format='pdf')