In [1]:
from pathlib import Path
import re, json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
MULTI_PREFIX = Path("example/output/schism_base.multi_gru2_SINGLE_EXP_TEST_Trial1")
OUT_DIR = MULTI_PREFIX.parent / (MULTI_PREFIX.name + "_postproc")
FIG_DIR = OUT_DIR / "figs"
OUT_DIR.mkdir(parents=True, exist_ok=True)
FIG_DIR.mkdir(parents=True, exist_ok=True)

DT_COLS = ["datetime", "case"]      
METRICS = ("mae", "rmse", "nse", "pearson_r")
MAX_TS_CASES_TO_PLOT = 7
PRINT_ROWS = 25

In [3]:
pred_glob = sorted(MULTI_PREFIX.parent.glob(MULTI_PREFIX.name + "_xvalid_*.csv"))
tag_re = re.compile(r"_xvalid_(.+)\.csv$")
TAGS = []
PRED_PATHS = {}
REF_PATHS  = {}

for p in pred_glob:
    m = tag_re.search(p.name)
    
    if not m:
        continue
        
    tag = m.group(1)
    TAGS.append(tag)
    PRED_PATHS[tag] = p
    REF_PATHS[tag] = MULTI_PREFIX.parent / f"{MULTI_PREFIX.name}_xvalid_ref_out_{tag}_unscaled.csv"

In [4]:
print(f"Found {len(TAGS)} tags:", TAGS)
missing_ref = [t for t in TAGS if not REF_PATHS[t].exists()]
if missing_ref:
    print("WARNING: missing reference CSVs for tags:", missing_ref)

MASTER_SUMMARY = None
cand = sorted(MULTI_PREFIX.parent.glob("gridsearch_*_master_results.csv"))
if cand:
    MASTER_SUMMARY = cand[-1]
    print("Master summary:", MASTER_SUMMARY)
else:
    print("Master summary not found in", MULTI_PREFIX.parent)

Found 0 tags: []
Master summary not found in example/output


##### DATA LOADING

In [5]:
def load_ms_pair(prefix: Path, tag: str) -> pd.DataFrame:
    """Load ref+pred for one tag; return merged long df with _pred columns."""
    ref_csv  = REF_PATHS[tag]
    pred_csv = PRED_PATHS[tag]
    if not (ref_csv.exists() and pred_csv.exists()):
        raise FileNotFoundError(f"Missing files for tag={tag}\n  {ref_csv}\n  {pred_csv}")

    df_ref  = pd.read_csv(ref_csv,  parse_dates=["datetime"])
    df_pred = pd.read_csv(pred_csv, parse_dates=["datetime"])
    df = pd.merge(df_ref, df_pred, on=["datetime", "case"], suffixes=("", "_pred"))

    try:
        df["case"] = df["case"].astype(int)
    except Exception:
        pass
    return df.sort_values(["case", "datetime"]).reset_index(drop=True)

MERGED = {tag: load_ms_pair(MULTI_PREFIX, tag) for tag in TAGS}
sample_cols = [c for c in MERGED[TAGS[0]].columns if c not in DT_COLS and not c.endswith("_pred")]
STATIONS = sample_cols
print(f"Detected {len(STATIONS)} stations.")
print(STATIONS[:PRINT_ROWS])

IndexError: list index out of range

In [6]:
def _metrics_vector(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    """Return MAE/RMSE/NSE/Pearson r for vectors (ignores NaNs)."""
    mask = (~np.isnan(y_true)) & (~np.isnan(y_pred))
    if mask.sum() < 2:
        return {"mae": np.nan, "rmse": np.nan, "nse": np.nan, "pearson_r": np.nan}
    yt, yp = y_true[mask], y_pred[mask]
    mae  = float(np.mean(np.abs(yt - yp)))
    rmse = float(np.sqrt(np.mean((yt - yp) ** 2)))
    denom = float(np.sum((yt - np.mean(yt)) ** 2))
    nse  = float(1.0 - np.sum((yt - yp) ** 2) / denom) if denom > 0 else np.nan
    r    = float(np.corrcoef(yt, yp)[0, 1]) if len(yt) > 1 else np.nan
    return {"mae": mae, "rmse": rmse, "nse": nse, "pearson_r": r}

def compute_station_metrics(df: pd.DataFrame, station: str) -> dict:
    """Compute metrics for a single station over ALL cases and times."""
    return _metrics_vector(df[station].to_numpy(), df[f"{station}_pred"].to_numpy())

def compute_metrics_table(merged_by_tag: dict, stations: list) -> pd.DataFrame:
    """Return tidy table: [tag, station, mae, rmse, nse, pearson_r, n_points]."""
    rows = []
    for tag, df in merged_by_tag.items():
        for st in stations:
            if st not in df.columns or f"{st}_pred" not in df.columns:
                continue
            m = compute_station_metrics(df, st)
            rows.append({
                "tag": tag,
                "station": st,
                **{k: round(v, 6) if pd.notnull(v) else np.nan for k, v in m.items()},
                "n_points": int((~df[st].isna() & ~df[f"{st}_pred"].isna()).sum())})
    out = pd.DataFrame(rows)
    return out.sort_values(["tag", "station"]).reset_index(drop=True)

METRICS_DF = compute_metrics_table(MERGED, STATIONS)
display(METRICS_DF.head(PRINT_ROWS))
METRICS_DF.to_csv(OUT_DIR / "metrics_by_station_tag.csv", index=False)
print("Saved:", OUT_DIR / "metrics_by_station_tag.csv")

NameError: name 'STATIONS' is not defined

In [7]:
MEAN_BY_TAG = (METRICS_DF.groupby("tag")[list(METRICS)].mean(numeric_only=True).sort_values("nse", ascending=False).reset_index())
display(MEAN_BY_TAG)
NSE_PIVOT = METRICS_DF.pivot(index="station", columns="tag", values="nse")
MAE_PIVOT = METRICS_DF.pivot(index="station", columns="tag", values="mae")

if "base" in NSE_PIVOT.columns:
    NSE_DELTA = NSE_PIVOT.apply(lambda col: col - NSE_PIVOT["base"])
    NSE_DELTA = NSE_DELTA.drop(columns=["base"], errors="ignore")
    display(NSE_DELTA.sort_values(by=NSE_DELTA.columns[0] if len(NSE_DELTA.columns) else "base",
                                  ascending=False).head(PRINT_ROWS))
    NSE_PIVOT.to_csv(OUT_DIR / "nse_pivot_station_by_tag.csv")
    NSE_DELTA.to_csv(OUT_DIR / "nse_delta_vs_base.csv")
    print("Saved pivot & delta tables in", OUT_DIR)

NameError: name 'METRICS_DF' is not defined

In [8]:

if "MASTER_SUMMARY" in globals() and MASTER_SUMMARY:
    ms = pd.read_csv(MASTER_SUMMARY)
    display(ms.sort_values(ms.columns[ms.columns.str.contains("nse", case=False)].tolist(), ascending=False).head(10))

##### PLOTTING

In [9]:
def plot_timeseries(tag: str, station: str, case: int = 1, save: bool = False):
    df = MERGED[tag]
    sub = df[df["case"] == case]
    if sub.empty:
        print(f"No rows for tag={tag}, case={case}")
        return
    if station not in sub.columns or f"{station}_pred" not in sub.columns:
        print(f"Station {station} not found in tag={tag}")
        return
    plt.figure(figsize=(9, 3.5))
    plt.plot(sub["datetime"], sub[station], label="Ref")
    plt.plot(sub["datetime"], sub[f"{station}_pred"], label="ANN")
    plt.title(f"Time series — tag={tag}, station={station}, case={case}")
    plt.xlabel("datetime"); plt.ylabel(station); plt.legend(); plt.tight_layout()
    if save:
        fname = FIG_DIR / f"ts_{tag}_{station}_case{case}.png"
        plt.savefig(fname, dpi=150)
        print("Saved:", fname)
    plt.show()

def plot_parity(tag: str, station: str, sample: int = None, save: bool = False):
    df = MERGED[tag]
    if station not in df.columns or f"{station}_pred" not in df.columns:
        print(f"Station {station} not found in tag={tag}")
        return
    x = df[station].to_numpy()
    y = df[f"{station}_pred"].to_numpy()
    mask = (~np.isnan(x)) & (~np.isnan(y))
    x, y = x[mask], y[mask]
    if sample and sample < len(x):
        idx = np.random.default_rng(42).choice(len(x), sample, replace=False)
        x, y = x[idx], y[idx]
    lim = [np.nanmin([x.min(), y.min()]), np.nanmax([x.max(), y.max()])]
    plt.figure(figsize=(4.8, 4.8))
    plt.scatter(x, y, s=8, alpha=0.35)
    plt.plot(lim, lim)  # 1:1
    plt.title(f"Parity — tag={tag}, station={station}")
    plt.xlabel("Reference"); plt.ylabel("Prediction"); plt.tight_layout()
    if save:
        fname = FIG_DIR / f"parity_{tag}_{station}.png"
        plt.savefig(fname, dpi=150)
        print("Saved:", fname)
    plt.show()

In [10]:
def plot_metric_bar(metric: str = "nse", top_k: int = 20, scenario: str = "base", save: bool = False):
    """Top-K stations by a given metric for a chosen scenario."""
    assert metric in METRICS, f"metric must be one of {METRICS}"
    sub = METRICS_DF[METRICS_DF["tag"] == scenario].sort_values(metric, ascending=False).head(top_k)
    plt.figure(figsize=(10, 0.35 * len(sub) + 1.5))
    plt.barh(sub["station"][::-1], sub[metric][::-1])
    plt.title(f"Top {top_k} stations by {metric.upper()} — {scenario}")
    plt.xlabel(metric.upper()); plt.tight_layout()
    if save:
        fname = FIG_DIR / f"bar_{metric}_{scenario}_top{top_k}.png"
        plt.savefig(fname, dpi=150); print("Saved:", fname)
    plt.show()

def plot_parallel_lines(metric: str = "nse", stations: list = None, scenarios: list = None, sort_by: str = "base", top_k: int = 25, save: bool = False):

    assert metric in METRICS, f"metric must be one of {METRICS}"
    pivot = METRICS_DF.pivot(index="station", columns="tag", values=metric)

    if scenarios is None:
        cols = list(pivot.columns)
        if "base" in cols:
            cols = ["base"] + [c for c in cols if c != "base"]
        scenarios = cols

    if stations is None:
        rank_col = sort_by if sort_by in pivot.columns else scenarios[0]
        stations = pivot[rank_col].sort_values(ascending=False).head(top_k).index.tolist()
    x = np.arange(len(scenarios))
    plt.figure(figsize=(10, 0.35 * len(stations) + 1.5))
    
    for st in stations:
        y = pivot.loc[st, scenarios].to_numpy(dtype=float)
        plt.plot(x, y, alpha=0.6, linewidth=1.5)
    plt.xticks(x, scenarios)
    plt.title(f"Parallel {metric.upper()} across scenarios (top {len(stations)} by {sort_by})")
    plt.ylabel(metric.upper()); plt.grid(axis="y", alpha=0.2); plt.tight_layout()
    if save:
        fname = FIG_DIR / f"parallel_{metric}_top{len(stations)}_sortby_{sort_by}.png"
        plt.savefig(fname, dpi=150); print("Saved:", fname)
    plt.show()

In [11]:
plot_metric_bar(metric="nse", top_k=20, scenario="base", save=True)

if len(TAGS) > 1:
    plot_metric_bar(metric="nse", top_k=20, scenario=TAGS[1], save=True)

NameError: name 'METRICS_DF' is not defined

In [12]:
plot_parallel_lines(metric="nse", top_k=25, sort_by="base", save=True)

NameError: name 'METRICS_DF' is not defined

In [13]:
plot_timeseries(tag="base",   station=STATIONS[0], case=1, save=True)
if "suisun" in TAGS:
    plot_timeseries(tag="suisun", station=STATIONS[0], case=1, save=True)

NameError: name 'STATIONS' is not defined

In [14]:
plot_parity(tag="base", station=STATIONS[0], sample=5000, save=True)

NameError: name 'STATIONS' is not defined