In [2]:
import pickle
from pathlib import Path
import matplotlib.pyplot as plt

# --- Runs to compare: label -> path of that run's pickled validation results.
# Paths are relative to the notebook's working directory.
runs = {
    "MSWEP": Path("../MSWEP/runs/mswep_precip_seq_270_30_epochs_seq_270_hidden_256_dropout_04_fb_05_seed111_2701_212959/validation/model_epoch030/validation_results.p"),
    "LS": Path("./runs/ls_MSWEP_new_precip_correction_seq_270_30_epochs_seq_270_hidden_256_dropout_04_fb_05_seed111_0202_172417/validation/model_epoch030/validation_results.p"),
    "LOCI": Path("./runs/loci_MSWEP_new_precip_correction_seq_270_30_epochs_seq_270_hidden_256_dropout_04_fb_05_seed111_0202_171744/validation/model_epoch030/validation_results.p"),
    "PT": Path("./runs/pt_MSWEP_new_precip_correction_seq_270_30_epochs_seq_270_hidden_256_dropout_04_fb_05_seed111_0202_165759/validation/model_epoch030/validation_results.p"),
    "DM": Path("./runs/dm_MSWEP_new_precip_correction_seq_270_30_epochs_seq_270_hidden_256_dropout_04_fb_05_seed111_0202_164748/validation/model_epoch030/validation_results.p"),
}

# Root folder that collects every hydrograph figure.
hydrographs_dir = Path("./hydrographs")
hydrographs_dir.mkdir(parents=True, exist_ok=True)

# Fail fast, naming every run whose results file is absent, before any figure work.
missing = [run_label for run_label, result_path in runs.items() if not result_path.exists()]
if missing:
    raise FileNotFoundError(f"Missing validation_results.p for: {missing}")

# Sub-folder for this particular multi-run comparison.
run_out_dir = hydrographs_dir.joinpath("multiple_mswep")
run_out_dir.mkdir(parents=True, exist_ok=True)
print(f"Saving hydrographs to: {run_out_dir}")

# --- Load all results once: label -> pickled validation results, keyed by catchment.
# NOTE(security): pickle.load can execute arbitrary code. Only point `runs` at
# result files produced by your own trusted training runs.
all_data = {}
for label, fp in runs.items():
    with open(fp, "rb") as f:
        all_data[label] = pickle.load(f)
    print(f"[{label}] Loaded {len(all_data[label])} catchments")

# --- Use common catchments only (so you don't crash if one run lacks a basin).
catchment_sets = {label: set(d.keys()) for label, d in all_data.items()}
# set.intersection() with no arguments raises TypeError, so guard the empty case.
common_catchments = set.intersection(*catchment_sets.values()) if catchment_sets else set()
print(f"Common catchments across all runs: {len(common_catchments)}")

# Report basins excluded because at least one run lacks them, instead of
# dropping them silently from the comparison.
for label, keys in catchment_sets.items():
    dropped = keys - common_catchments
    if dropped:
        print(f"[{label}] Skipping {len(dropped)} catchment(s) not present in every run: {sorted(dropped)}")

# Observed flow is drawn once per catchment, taken from the first run loaded.
# This assumes every run carries the same observations (TODO: confirm).
# Hoisted: the reference run does not change between catchments.
first_label = next(iter(all_data))

for catchment in sorted(common_catchments):
    fig, ax = plt.subplots(figsize=(12, 5))
    try:
        # --- Plot OBS only once (from the first run)
        ds0 = all_data[first_label][catchment]["1D"]["xr"]
        time = ds0["date"].values
        obs = ds0["QObs_mm_d_obs"].values.flatten()
        ax.plot(time, obs, label="Observed", linewidth=1, alpha=0.8)

        # --- Plot SIM for each run against that run's *own* dates. Reusing the
        # reference run's axis would silently shift, or fail on, runs whose
        # validation period or length differs.
        for label, data in all_data.items():
            result = data[catchment]["1D"]
            ds = result["xr"]
            nse = float(result["NSE"])
            sim_time = ds["date"].values
            # NOTE(review): flatten() assumes one value per date. Confirm that the
            # xarray has no extra (e.g. multi-step) dimension.
            sim = ds["QObs_mm_d_sim"].values.flatten()
            ax.plot(sim_time, sim, alpha=0.6, linewidth=0.7, label=f"{label} (NSE={nse:.3f})")

        ax.set_xlabel("Date")
        ax.set_ylabel("Streamflow (mm/day)")
        ax.set_title(f"Validation - {catchment}")
        ax.legend(ncol=2)
        ax.grid(True, alpha=0.35)
        fig.tight_layout()

        out_path = run_out_dir / f"{catchment}.png"
        fig.savefig(out_path, dpi=200)
    finally:
        # Always release the figure, even if one run's data is malformed, so a
        # long loop does not accumulate open figures.
        plt.close(fig)


Saving hydrographs to: hydrographs/multiple_mswep
[MSWEP] Loaded 11 catchments
[LS] Loaded 11 catchments
[LOCI] Loaded 11 catchments
[PT] Loaded 11 catchments
[DM] Loaded 11 catchments
Common catchments across all runs: 11
