In [9]:
import sys
from pathlib import Path
sys.path.append(str(Path().absolute().parent))
import time

In [10]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import torch
from torch.nn import MSELoss
from torch.optim import Adam
import matplotlib.pyplot as plt
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import glob
from pathlib import Path

from src.data_models.camels_ch import CamelsCH, CamelsCHConfig, get_all_gauge_ids
from src.data_models.dataset import HydroDataset
from src.data_models.preprocessing import (
    scale_time_series,
    scale_static_attributes,
    inverse_scale_static_attributes,
    inverse_scale_time_series,
)

from utils.metrics import nash_sutcliffe_efficiency
from src.data_models.datamodule import HydroDataModule

---

# 1. Load and prepare CAMELS-CH data

In [15]:
import os

# Root of the local CAMELS-CH download. Set the CAMELS_CH_ROOT environment variable to point
# elsewhere instead of editing this cell; the default is the original hard-coded location.
CAMELS_CH_ROOT = os.environ.get("CAMELS_CH_ROOT", "/Users/cooper/Desktop/CAMELS-CH").rstrip("/")

camels_config = CamelsCHConfig(
    timeseries_dir=f"{CAMELS_CH_ROOT}/data/timeseries/observation_based/",
    timeseries_pattern="CAMELS_CH_obs_based_*.csv",
    static_attributes_dir=f"{CAMELS_CH_ROOT}/data/static_attributes",
    # Every static-attribute group is switched off: this notebook trains on the
    # discharge series alone (the DataModule below receives static_df=None).
    use_climate=False,
    use_geology=False,
    use_glacier=False,
    use_human_influence=False,
    use_hydrogeology=False,
    use_hydrology=False,
    use_landcover=False,
    use_soil=False,
    use_topographic=False,
)

In [16]:
# Every gauge ID available in the configured CAMELS-CH download (not referenced elsewhere
# in this notebook).
all_gauge_ids = get_all_gauge_ids(camels_config)

# The 205 gauges selected for training, as strings, in their original order. A
# whitespace-separated block is easier to scan and diff than one ID per line.
ids_for_training = """
4005 2312 4011 2110 2104 2070 2299 2500 2139 2105
2307 3019 4010 2461 4004 2475 2488 2463 2477 4006
3033 2067 2265 2270 2099 2112 2106 3032 4007 4013
2304 2300 4017 4003 2102 2063 2117 4002 2473 4016
3023 2498 3009 2471 4014 2303 2263 2276 2262 2289
4015 4001 2458 2371 2417 2403 2167 5032 2205 2239
2210 2199 2364 2370 2416 2366 2414 2372 2602 2170
2011 2159 2617 2603 2415 2161 2607 2613 2029 2203
2202 5009 2612 2174 2160 2606 2410 2412 2374 2176
2610 2604 2016 2200 2215 2605 2349 2387 2436 2378
2185 2634 2152 2608 2219 2033 2609 2635 2351 2437
2386 2347 2409 2151 2179 2019 2232 2018 2024 2030
2150 2187 2434 2352 2346 2420 2418 2342 2356 2430
2034 2020 5014 2009 2155 2141 2343 2419 2369 2433
2355 2157 2143 2181 2426 2432 2368 4024 2327 2469
4018 2125 2119 2086 2079 2290 2247 2252 2078 2044
2087 3004 2468 4025 2497 2481 3006 2126 2132 2091
2085 2251 2053 2084 2319 3007 2494 2480 2490 2309
2321 2282 2269 2268 2283 2056 2122 2308 2491 2485
2493 2487 3014 2478 4009 2450 4021 2256 2243 2135
2109 4020 4008 3015 2486
""".split()


# Load the time series of the selected stations into the CamelsCH container.
camels = CamelsCH(camels_config)
camels.load_stations(ids_for_training)

Loaded time series data for 205 stations


In [13]:
# static = camels.get_static_attributes()
# sc = static.columns

# # for i in range(len(sc)):
# #     print(f"{i}: {sc[i]}")
# static_attributes = [
#     "gauge_id",
#     "area", 
#     "elev_mean",  
#     "slope_mean",  
#     "aridity",  
#     "p_seasonality",  
#     "frac_snow",  
#     "porosity",  
#     "conductivity",  
#     "p_mean",  
#     "geo_porosity",  
# ]
# static = static[static_attributes]
# static

In [None]:
# Reduce the full CAMELS-CH time-series table to three columns: the basin identifier,
# the date and specific discharge (mm/d).
data = camels.get_time_series()[["gauge_id", "date", "discharge_spec(mm/d)"]]

gauge_id,2009,2011,2016,2018,2019,2020,2024,2029,2030,2033,...,4016,4017,4018,4020,4021,4024,4025,5009,5014,5032
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1981-01-01,0.821,0.673,1.69,1.303,1.25,0.888,0.763,1.614,1.593,1.15,...,1.612,,,1.891,,1.588,,0.89,1.046,1.486
1981-01-02,0.818,0.791,1.579,1.307,1.452,0.941,0.938,1.529,1.594,0.693,...,1.565,,,1.756,,1.507,,0.748,1.049,0.994
1981-01-03,0.973,0.788,1.877,1.354,1.199,1.036,0.958,1.575,1.591,0.77,...,2.762,,,2.723,,1.949,,0.982,1.114,1.493
1981-01-04,1.208,1.001,3.394,3.062,1.724,1.006,1.109,1.978,2.043,0.642,...,10.332,,,9.254,,5.24,,13.678,5.312,5.703
1981-01-05,1.498,1.049,2.926,2.549,2.127,1.641,1.606,2.409,2.492,0.904,...,5.442,,,4.81,,2.568,,9.147,8.719,6.656


# 2. Configure preprocessing

In [None]:
# Scaling options passed to HydroDataModule:
# - dynamic features and target use scale_method="per_basin";
# - no feature is log-transformed, and neither is the target;
# - static features (unused here, static_df is None) would use scale_method="global".
preprocessing_config = dict(
    features=dict(scale_method="per_basin", log_transform=[]),
    target=dict(scale_method="per_basin", log_transform=False),
    static_features=dict(scale_method="global"),
)

# 3. Create DataModule

In [None]:
output_length = 1

# Presumably a sliding-window DataModule: `input_length` past steps of discharge are the
# input and the next `output_length` step(s) of the same column are the target (to be
# confirmed against HydroDataModule). No static attributes are used (static_df=None).
target_column = "discharge_spec(mm/d)"
data_module = HydroDataModule(
    time_series_df=data,
    static_df=None,
    group_identifier="gauge_id",
    preprocessing_config=preprocessing_config,
    batch_size=32,
    input_length=30,
    output_length=output_length,
    num_workers=4,
    features=[target_column],
    # static_features=static_attributes[1:],
    target=target_column,
    train_years=15,
    val_years=3,
    min_test_years=6,
)

# data_module.static_df

# 4. Create model and train

In [None]:
from src.models.lstm import LitLSTM
from src.models.ealstm import LitEALSTM
from torch.optim import Adam
from torch.nn import MSELoss

# 5. Evaluate and plot results

In [None]:
# Seed Python, NumPy and torch RNGs (including DataLoader worker processes) so that weight
# initialisation and any data shuffling are reproducible across "Restart & Run All".
SEED = 42
pl.seed_everything(SEED, workers=True)

model = LitLSTM(
    input_size=1,  # a single dynamic feature: discharge_spec(mm/d)
    hidden_size=2,
    num_layers=1,
    output_size=output_length,
    target="discharge_spec(mm/d)",
)

# Alternative entity-aware LSTM; it takes static inputs (input_size_stat), which are
# currently disabled in this notebook.
# model = LitEALSTM(
#     input_size_dyn=3,
#     input_size_stat=len(static_attributes) - 1,
#     hidden_size=64,
#     output_size=5,
#     target="discharge_spec(mm/d)",
# )

# Configure trainer: keep the single best checkpoint by validation loss and stop early
# when validation loss stops improving.
# NOTE(review): with max_epochs=1, EarlyStopping(patience=3) can never trigger.
trainer = pl.Trainer(
    max_epochs=1,
    accelerator="cpu",
    devices=1,
    callbacks=[
        ModelCheckpoint(
            monitor="val_loss",
            dirpath="checkpoints",
            filename="best-checkpoint",
            save_top_k=1,
            mode="min",
        ),
        EarlyStopping(monitor="val_loss", patience=3, mode="min"),
    ],
)

# Train the model
trainer.fit(model, data_module)

In [None]:
quality_report = data_module.quality_report

# Basins listed under the quality report's "excluded_basins" mapping.
excluded_basins = list(quality_report["excluded_basins"].keys())
# A bare name in the middle of a cell is never rendered; show it explicitly.
display(excluded_basins)

# Drop excluded basins from the training list. A set keeps each membership test O(1),
# and `gauge_id` avoids shadowing the builtin `id`.
excluded_set = set(excluded_basins)
ids_for_training = [
    gauge_id for gauge_id in ids_for_training if gauge_id not in excluded_set
]

In [None]:
# Show the retained basin IDs as a Series: pandas truncates long output to a head/tail
# preview and reports the total length, instead of dumping every ID.
pd.Series(ids_for_training, name="gauge_id")

In [None]:
# Evaluate the checkpoint that ModelCheckpoint selected as best by val_loss. Without
# ckpt_path, passing `model` would test whatever weights it holds after the last epoch.
trainer.test(model, data_module, ckpt_path="best")
test_results = model.test_results

In [None]:
# Forecast table and per-horizon aggregate metrics from the test run.
results_df = model.test_results["forecast_df"]
horizon_metrics = model.test_results["horizon_metrics"]

# One NSE value per horizon, in the order the metrics dict yields them.
horizons = list(horizon_metrics.keys())
nse_values = [metrics["NSE"] for metrics in horizon_metrics.values()]

fig, ax = plt.subplots(figsize=(10, 6))
colors = sns.color_palette("Blues", 1)
ax.bar(horizons, nse_values, color=colors)

ax.set_xlabel("Forecast Horizon (Days)")
ax.set_ylabel("Nash-Sutcliffe Efficiency")
ax.set_title("Forecast Skill by Prediction Horizon")
ax.grid(True, linestyle="--", alpha=0.3)
sns.despine(ax=ax)

# Show a tick for every horizon.
ax.set_xticks(horizons)

# Place each value label at its own bar's x position, rather than assuming the
# horizons are exactly 1, 2, ..., N in that order.
for horizon, value in zip(horizons, nse_values):
    ax.text(horizon, value, f"{value:.3f}", ha="center", va="bottom")

fig.tight_layout()
plt.show()

In [None]:
# First ten rows of the per-sample forecast table.
results_df.iloc[:10]

In [None]:
# Per-basin NSE for every forecast horizon, computed from the raw forecast table.
# groupby(sort=False) yields basins in order of first appearance (same as .unique())
# without re-filtering the whole frame once per basin.
basin_metrics = {}
for basin, basin_data in results_df.groupby("basin_id", sort=False):
    basin_nse = []
    # int() so a float-typed horizon column still gives a valid range() bound.
    for horizon in range(1, int(basin_data["horizon"].max()) + 1):
        horizon_data = basin_data[basin_data["horizon"] == horizon]
        basin_nse.append(
            nash_sutcliffe_efficiency(
                horizon_data["observed"].values, horizon_data["prediction"].values
            )
        )
    basin_metrics[basin] = basin_nse

assert basin_metrics, "results_df contains no basins to plot"


def _h1_sort_key(basin):
    """Sort key: descending horizon-1 NSE, with NaN scores placed last."""
    score = basin_metrics[basin][0]
    return (bool(np.isnan(score)), -score)


# Sort basins by NSE at horizon 1 (best first). A plain reverse sort on raw scores has an
# undefined order once any score is NaN, because NaN compares False with everything.
basin_metrics = {
    basin: basin_metrics[basin] for basin in sorted(basin_metrics, key=_h1_sort_key)
}

# A legend with more entries than this would cover the axes, so it is omitted above it.
MAX_LEGEND_ENTRIES = 20
n_basins = len(basin_metrics)
n_horizons = len(next(iter(basin_metrics.values())))
bar_width = 0.8 / n_basins
# Blue shades, skipping the two palest so every bar stays visible.
colors = sns.color_palette("Blues", n_basins + 2)[2:]

fig, ax = plt.subplots(figsize=(12, 6))
for i, (basin, values) in enumerate(basin_metrics.items()):
    x = np.arange(len(values)) + i * bar_width
    ax.bar(x, values, bar_width, label=f"Basin {basin}", color=colors[i])

ax.set_xlabel("Forecast Horizon (Days)", fontsize=12)
ax.set_ylabel("Nash-Sutcliffe Efficiency", fontsize=12)
ax.set_title("Forecast Skill by Basin and Horizon", fontsize=14, pad=20)
ax.grid(True, linestyle="--", alpha=0.3)
if n_basins <= MAX_LEGEND_ENTRIES:
    ax.legend(title="Basin ID", title_fontsize=10, fontsize=10)

# Ticks centred under each group of bars, labelled 1..n_horizons.
ax.set_xticks(np.arange(n_horizons) + bar_width * (n_basins - 1) / 2)
ax.set_xticklabels(np.arange(1, n_horizons + 1))

sns.despine(ax=ax)
fig.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns


def plot_predictions(results_df, n_timesteps=None, n_cols=2):
    """Plot observed vs. predicted values at forecast horizon 1, one panel per basin.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Forecast table with at least the columns ``basin_id``, ``horizon``,
        ``observed`` and ``prediction``.
    n_timesteps : int, optional
        If given (and non-zero), only the last ``n_timesteps`` horizon-1 rows of each
        basin are plotted; the NSE in each panel title uses the same rows.
    n_cols : int, default 2
        Number of subplot columns.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.

    Raises
    ------
    ValueError
        If ``n_cols`` is not positive or there are no horizon-1 rows to plot.
    """
    if n_cols < 1:
        raise ValueError(f"n_cols must be a positive integer, got {n_cols!r}")

    # Only the one-step-ahead forecasts are plotted.
    horizon_1_data = results_df[results_df["horizon"] == 1]

    if n_timesteps:
        # Keep the last n_timesteps rows of each basin.
        horizon_1_data = (
            horizon_1_data.groupby("basin_id").tail(n_timesteps).reset_index(drop=True)
        )

    # sort=False keeps basins in order of first appearance, like .unique().
    basin_groups = list(horizon_1_data.groupby("basin_id", sort=False))
    n_basins = len(basin_groups)
    if n_basins == 0:
        raise ValueError("results_df contains no horizon-1 forecasts to plot")

    n_rows = -(-n_basins // n_cols)  # ceiling division
    # squeeze=False always returns a 2-D array of Axes, even for a single row/column.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(7.5 * n_cols, 5 * n_rows), squeeze=False
    )
    axes = axes.flatten()

    for ax, (basin, basin_data) in zip(axes, basin_groups):
        nse = nash_sutcliffe_efficiency(
            basin_data["observed"].values, basin_data["prediction"].values
        )

        x = np.arange(len(basin_data))
        ax.plot(x, basin_data["observed"], label="Observed", color="#1d4ed8")
        ax.plot(
            x,
            basin_data["prediction"],
            label="Predicted",
            color="#dc2626",
            alpha=0.8,
        )

        ax.set_title(f"Basin {basin} (NSE: {nse:.3f})", fontsize=12)
        ax.set_xlabel("Time Step", fontsize=10)
        ax.set_ylabel("Discharge", fontsize=10)
        ax.grid(True, linestyle="--", alpha=0.3)
        ax.legend(fontsize=9)
        sns.despine(ax=ax)

    # Remove the unused panels left over in the last row.
    for ax in axes[n_basins:]:
        fig.delaxes(ax)

    plt.tight_layout()
    plt.show()


# Horizon-1 observed vs. predicted for every test basin, limited to the last 365 rows
# per basin (one year, given the daily series loaded above).
plot_predictions(results_df=results_df, n_timesteps=365)