In [1]:
import sys
from pathlib import Path
sys.path.append(str(Path().absolute().parent))
import time

In [2]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import torch
from torch.nn import MSELoss
from torch.optim import Adam
import matplotlib.pyplot as plt
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import glob
from pathlib import Path

from src.data_models.camels_ch import CamelsCH, CamelsCHConfig, get_all_gauge_ids
from src.data_models.dataset import HydroDataset
from src.data_models.preprocessing import (
    scale_time_series,
    scale_static_attributes,
    inverse_scale_static_attributes,
    inverse_scale_time_series,
)

from utils.metrics import nash_sutcliffe_efficiency
from src.data_models.datamodule import HydroDataModule

---

# 1. Load and prepare CAMELS-CH data

In [3]:
import os

# Root of the CAMELS-CH "data" directory. Set the CAMELS_CH_DATA_DIR environment
# variable to point elsewhere instead of editing a user-specific absolute path.
CAMELS_CH_DATA_DIR = Path(
    os.environ.get("CAMELS_CH_DATA_DIR", "/Users/cooper/Desktop/CAMELS-CH/data")
)

# Which CAMELS-CH inputs to load: observation-based time series plus the static
# attribute groups switched on below.
camels_config = CamelsCHConfig(
    # Kept as a string with a trailing "/" exactly like the original value, in case
    # the loader concatenates it with `timeseries_pattern` — TODO confirm.
    timeseries_dir=f"{CAMELS_CH_DATA_DIR}/timeseries/observation_based/",
    timeseries_pattern="CAMELS_CH_obs_based_*.csv",
    static_attributes_dir=str(CAMELS_CH_DATA_DIR / "static_attributes"),
    use_climate=True,
    use_geology=True,
    use_glacier=False,
    use_human_influence=False,
    use_hydrogeology=False,
    use_hydrology=False,
    use_landcover=True,
    use_soil=True,
    use_topographic=True,
)

In [4]:
# Collect every gauge ID known to the configured CAMELS-CH dataset and load all of them.
all_gauge_ids = get_all_gauge_ids(camels_config)
# Alias (not a copy) of all_gauge_ids; a later cell rebinds it to the
# quality-checked subset.
ids_for_training = all_gauge_ids


camels = CamelsCH(camels_config)
# Loads the time series and the static attribute groups enabled in camels_config
# (see the "Loading ... attributes" messages in the output).
camels.load_stations(ids_for_training)

Loaded time series data for 331 stations
Loading climate attributes
Loading geology attributes
Loading landcover attributes
Loading soil attributes
Loading topographic attributes
Loaded static attributes for 331 stations


In [5]:
# Static catchment attributes. Keep the raw frame under its own name so this cell
# stays idempotent and the full attribute table remains inspectable.
static_raw = camels.get_static_attributes()

# Attributes kept for modelling. The first entry ("gauge_id") is the join key, not
# a feature — later cells rely on this via `static_attributes[1:]` and
# `len(static_attributes) - 1`.
static_attributes = [
    "gauge_id",
    "area",
    "elev_mean",
    "slope_mean",
    "aridity",
    "p_seasonality",
    "frac_snow",
    "porosity",
    "conductivity",
    "p_mean",
    "geo_porosity",
]
# NOTE: in the displayed output, p_seasonality / frac_snow appear to be missing
# (NaN) for some basins (e.g. 6008-6010); impute or drop them before feeding these
# attributes to a model.
static = static_raw[static_attributes]
static

Unnamed: 0,gauge_id,area,elev_mean,slope_mean,aridity,p_seasonality,frac_snow,porosity,conductivity,p_mean,geo_porosity
0,2004,712.7,644.60,5.53,0.597,0.159,0.039,44.855,81.482,3.059,0.101
1,2007,209.3,1228.27,8.13,0.369,-0.118,0.170,49.508,32.419,4.983,0.070
2,2009,5239.4,2124.19,25.72,0.440,0.078,0.436,45.708,36.593,3.558,0.045
3,2011,3372.4,2286.71,25.82,0.442,0.106,0.474,45.247,34.908,3.401,0.038
4,2014,1583.5,1331.63,22.10,0.316,0.279,0.223,48.837,25.776,4.869,0.091
...,...,...,...,...,...,...,...,...,...,...,...
326,6007,1531.4,1672.11,28.61,0.451,0.228,0.379,47.955,31.384,4.058,0.012
327,6008,229.7,878.60,21.26,0.536,,,46.943,58.987,4.742,0.025
328,6009,121.6,1248.36,31.66,0.427,,,48.761,31.715,5.446,0.010
329,6010,60.2,912.20,24.84,0.434,,,47.189,38.748,5.953,0.009


In [6]:
# Dynamic inputs: keep only the identifier columns and snow water equivalent,
# which is both the sole input feature and the forecast target below.
data = camels.get_time_series()[["gauge_id", "date", "swe(mm)"]]



# 2. Configure preprocessing

In [7]:
# Scaling options handed to HydroDataModule: dynamic features and the target are
# scaled per basin, static attributes globally, and nothing is log-transformed.
preprocessing_config = dict(
    features=dict(scale_method="per_basin", log_transform=[]),
    target=dict(scale_method="per_basin", log_transform=False),
    static_features=dict(scale_method="global"),
)

# 3. Create DataModule

In [12]:
# Windowed train/val/test splits per gauge: 30-step input windows predicting the
# next 5 steps of SWE (daily data). Static attributes are currently disabled; to
# use them, pass `static_df=static` and `static_features=static_attributes[1:]`.
data_module = HydroDataModule(
    # What to model
    time_series_df=data,
    static_df=None,
    group_identifier="gauge_id",
    features=["swe(mm)"],
    target="swe(mm)",
    preprocessing_config=preprocessing_config,
    # Window geometry
    input_length=30,
    output_length=5,
    # Temporal split
    train_years=10,
    val_years=1,
    min_test_years=1,
    # Loader settings
    batch_size=32,
    num_workers=4,
)

# 4. Create model and train

In [13]:
from src.models.lstm import LitLSTM
from src.models.ealstm import LitEALSTM
from torch.optim import Adam
from torch.nn import MSELoss

# 5. Evaluate and plot results

In [14]:
# EA-LSTM forecaster. Its input size, horizon and target must agree with the data
# module above, which feeds only ["swe(mm)"] as dynamic input and predicts
# "swe(mm)" over output_length=5 steps. The previous target "discharge_spec(mm/d)"
# is not among the columns selected into `data`.
pl.seed_everything(42, workers=True)  # reproducible weight init and shuffling

model = LitEALSTM(
    input_size_dyn=1,  # == len(features) in HydroDataModule — TODO confirm no extra dynamic inputs are appended
    input_size_stat=len(static_attributes) - 1,  # static columns minus "gauge_id"
    hidden_size=64,
    output_size=5,  # must match output_length in HydroDataModule
    target="swe(mm)",
)
# NOTE(review): HydroDataModule is built with static_df=None, so presumably no
# static inputs reach the model although input_size_stat > 0 — confirm how
# LitEALSTM handles this, or enable static features in the data module.

# "auto" selects CUDA / MPS when available and falls back to CPU, so the notebook
# also runs on machines without a GPU.
trainer = pl.Trainer(
    max_epochs=5,
    accelerator="auto",
    devices=1,
    callbacks=[
        # Keep only the checkpoint with the lowest validation loss.
        ModelCheckpoint(
            monitor="val_loss",
            dirpath="checkpoints",
            filename="best-checkpoint",
            save_top_k=1,
            mode="min",
        ),
        # Stop after 3 epochs without val_loss improvement.
        EarlyStopping(monitor="val_loss", patience=3, mode="min"),
    ],
)

# Train the model
trainer.fit(model, data_module)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/cooper/Desktop/CAMELS-CH/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default



Quality Check Summary:
Original basins: 331
Retained basins: 297
Excluded basins: 34



Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [None]:
# Basins dropped by the data module's quality checks (the "Quality Check Summary"
# printed during training presumably reflects the same report).
quality_report = data_module.quality_report
excluded_basins = list(quality_report["excluded_basins"].keys())

# Restrict the gauge list to basins that passed the checks. A set makes each
# membership test O(1). This only affects code run after this cell; the model
# trained above is unchanged.
excluded_set = set(excluded_basins)
ids_for_training = [
    gauge_id for gauge_id in ids_for_training if gauge_id not in excluded_set
]

# Show the excluded basins. A bare expression only displays when it is the last
# line of the cell.
excluded_basins

In [None]:
# Display the gauge IDs that remain after removing the quality-check exclusions.
ids_for_training

In [None]:
# Evaluate on the test split. Afterwards the model exposes its forecasts and
# per-horizon metrics via `model.test_results`.
trainer.test(model=model, datamodule=data_module)
test_results = model.test_results

In [None]:
# Get the results
results_df = model.test_results["forecast_df"]
horizon_metrics = model.test_results["horizon_metrics"]

horizons = []
nse_values = []
for horizon, metrics in horizon_metrics.items():
    horizons.append(horizon)
    nse_values.append(metrics["NSE"])

# Create bar plot
plt.figure(figsize=(10, 6))
colors = sns.color_palette("Blues", 1)
plt.bar(horizons, nse_values, color=colors)

# Customize plot
plt.xlabel("Forecast Horizon (Days)")
plt.ylabel("Nash-Sutcliffe Efficiency")
plt.title("Forecast Skill by Prediction Horizon")
plt.grid(True, linestyle="--", alpha=0.3)
sns.despine()

# Set x-axis ticks to show all horizons
plt.xticks(horizons)

# Add value labels on top of each bar
for i, v in enumerate(nse_values):
    plt.text(i + 1, v, f"{v:.3f}", ha="center", va="bottom")

plt.tight_layout()
plt.show()

In [None]:
# Preview the first ten forecast rows.
results_df.iloc[:10]

In [None]:
# Process results_df to get NSE by basin and horizon
# Per-basin NSE for each forecast horizon 1..max horizon, ranked by horizon-1 skill.
MAX_LEGEND_ENTRIES = 20  # beyond this a per-basin legend covers the whole plot

basin_metrics = {}
# groupby(sort=False) keeps basins in order of first appearance (like .unique())
# without re-scanning the whole frame once per basin.
for basin, basin_data in results_df.groupby("basin_id", sort=False):
    max_horizon = int(basin_data["horizon"].max())
    basin_nse = []
    for horizon in range(1, max_horizon + 1):
        horizon_data = basin_data[basin_data["horizon"] == horizon]
        basin_nse.append(
            nash_sutcliffe_efficiency(
                horizon_data["observed"].values, horizon_data["prediction"].values
            )
        )
    basin_metrics[basin] = basin_nse


def _horizon1_sort_key(basin):
    """Horizon-1 NSE of `basin`; NaN (undefined NSE) is ranked last."""
    nse = basin_metrics[basin][0]
    # NaN breaks comparison-based sorting, e.g. for basins whose observed SWE is
    # constant (NSE undefined).
    return -np.inf if np.isnan(nse) else nse


basin_metrics = {
    basin: basin_metrics[basin]
    for basin in sorted(basin_metrics, key=_horizon1_sort_key, reverse=True)
}

if not basin_metrics:
    print("No test results to plot.")
else:
    n_basins = len(basin_metrics)
    n_horizons = max(len(values) for values in basin_metrics.values())
    bar_width = 0.8 / n_basins
    # Skip the two palest shades so every bar stays visible on white.
    colors = sns.color_palette("Blues", n_basins + 2)[2:]

    fig, ax = plt.subplots(figsize=(12, 6))
    for i, (basin, basin_nse) in enumerate(basin_metrics.items()):
        x = np.arange(len(basin_nse)) + i * bar_width
        ax.bar(x, basin_nse, bar_width, label=f"Basin {basin}", color=colors[i])

    ax.set_xlabel("Forecast Horizon (Days)", fontsize=12)
    ax.set_ylabel("Nash-Sutcliffe Efficiency", fontsize=12)
    ax.set_title("Forecast Skill by Basin and Horizon", fontsize=14, pad=20)
    ax.grid(True, linestyle="--", alpha=0.3)
    if n_basins <= MAX_LEGEND_ENTRIES:
        ax.legend(title="Basin ID", title_fontsize=10, fontsize=10)

    # Centre each horizon's tick under its group of bars.
    ax.set_xticks(np.arange(n_horizons) + bar_width * (n_basins - 1) / 2)
    ax.set_xticklabels(np.arange(1, n_horizons + 1))

    # Remove top and right spines
    sns.despine(ax=ax)

    plt.tight_layout()
    plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns


def plot_predictions(results_df, n_timesteps=None, ylabel="Discharge", max_basins=None):
    """Plot observed vs. predicted values at forecast horizon 1, one panel per basin.

    Args:
        results_df: Forecast table with ``basin_id``, ``horizon``, ``observed``
            and ``prediction`` columns (e.g. ``model.test_results["forecast_df"]``).
        n_timesteps: If truthy, only the last ``n_timesteps`` rows of each basin
            are plotted.
        ylabel: Y-axis label. Defaults to "Discharge" for backward compatibility;
            pass the real target name when plotting another variable.
        max_basins: If not None, only the first ``max_basins`` basins (in order of
            appearance) are plotted, keeping the figure height bounded.

    Panels are arranged in a 2-column grid. Each panel title reports the basin's
    NSE over the plotted window. If there is no horizon-1 data, a message is
    printed and nothing is drawn.
    """
    horizon_1_data = results_df[results_df["horizon"] == 1]

    if n_timesteps:
        # Last n_timesteps rows of each basin, keeping the original row order.
        horizon_1_data = (
            horizon_1_data.groupby("basin_id").tail(n_timesteps).reset_index(drop=True)
        )

    basins = horizon_1_data["basin_id"].unique()
    if max_basins is not None:
        basins = basins[:max_basins]
    n_basins = len(basins)
    if n_basins == 0:
        # plt.subplots(0, ...) would raise; report instead of crashing.
        print("No horizon-1 predictions to plot.")
        return

    n_cols = 2
    n_rows = (n_basins + n_cols - 1) // n_cols
    # squeeze=False always yields a 2-D array, so .flatten() is safe for any grid.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False
    )
    axes = axes.flatten()

    for ax, basin in zip(axes, basins):
        basin_data = horizon_1_data[horizon_1_data["basin_id"] == basin]

        nse = nash_sutcliffe_efficiency(
            basin_data["observed"].values, basin_data["prediction"].values
        )

        x = np.arange(len(basin_data))
        ax.plot(x, basin_data["observed"], label="Observed", color="#1d4ed8")
        ax.plot(
            x,
            basin_data["prediction"],
            label="Predicted",
            color="#dc2626",
            alpha=0.8,
        )

        ax.set_title(f"Basin {basin} (NSE: {nse:.3f})", fontsize=12)
        ax.set_xlabel("Time Step", fontsize=10)
        ax.set_ylabel(ylabel, fontsize=10)
        ax.grid(True, linestyle="--", alpha=0.3)
        ax.legend(fontsize=9)
        sns.despine(ax=ax)

    # Drop the unused trailing panel when the number of basins is odd.
    for ax in axes[n_basins:]:
        fig.delaxes(ax)

    plt.tight_layout()
    plt.show()


# Example usage: the last 365 steps (one year of daily data) for a few basins.
# The target in this notebook is "swe(mm)", so label the axis accordingly. Plotting
# every basin (the quality check above retained 297) would produce an extremely
# tall figure.
plot_predictions(results_df, n_timesteps=365, ylabel="SWE (mm)", max_basins=10)