In [1]:
import os
import sys
import time
import warnings
from pathlib import Path
sys.path.append(str(Path().absolute().parent))

In [2]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
)
import torch
from torch.nn import MSELoss
from torch.optim import Adam
import matplotlib.pyplot as plt
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import glob
from pathlib import Path

from src.data_models.camels_ch import CamelsCH, CamelsCHConfig, get_all_gauge_ids
from src.data_models.dataset import HydroDataset
from src.data_models.preprocessing import (
    scale_time_series,
    scale_static_attributes,
    inverse_scale_static_attributes,
    inverse_scale_time_series,
)
from src.data_models.caravanify import Caravanify, CaravanifyConfig

from utils.metrics import nash_sutcliffe_efficiency
from src.data_models.datamodule import HydroDataModule

from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer

from src.preprocessing.grouped import GroupedTransformer
from src.preprocessing.log_scale import LogTransformer
from src.preprocessing.standard_scale import StandardScaleTransformer

---

## Testing Caravanify

In [3]:
# Root of the Caravan-formatted CAMELS-CH export. Set the CARAVAN_CH_DIR
# environment variable to point elsewhere instead of editing the notebook.
CARAVAN_CH_DIR = Path(
    os.environ.get(
        "CARAVAN_CH_DIR",
        "/Users/cooper/Desktop/CAMELS-CH/data/CARAVANIFY/CH/post_processed",
    )
)

caravan_config = CaravanifyConfig(
    attributes_dir=str(CARAVAN_CH_DIR / "attributes"),
    timeseries_dir=str(CARAVAN_CH_DIR / "timeseries" / "csv"),
    gauge_id_prefix="CH",
    use_hydroatlas_attributes=True,
    use_caravan_attributes=True,
    use_other_attributes=True,
)

caravan = Caravanify(caravan_config)
# Train on every gauge available in the export.
ids_for_training = caravan.get_all_gauge_ids()

print(f"Total number of stations: {len(ids_for_training)}")

caravan.load_stations(ids_for_training)

# Pull the loaded data: long-format daily series (one row per gauge_id/date)
# and a static-attribute table with one row per gauge_id.
ts_data = caravan.get_time_series()
static_data = caravan.get_static_attributes()

Total number of stations: 135


In [4]:
# List every time-series column returned by Caravanify before choosing a subset below.
ts_data.columns

Index(['gauge_id', 'date', 'snow_depth_water_equivalent_mean',
       'surface_net_solar_radiation_mean',
       'surface_net_thermal_radiation_mean', 'surface_pressure_mean',
       'temperature_2m_mean', 'dewpoint_temperature_2m_mean',
       'u_component_of_wind_10m_mean', 'v_component_of_wind_10m_mean',
       'volumetric_soil_water_layer_1_mean',
       'volumetric_soil_water_layer_2_mean',
       'volumetric_soil_water_layer_3_mean',
       'volumetric_soil_water_layer_4_mean', 'snow_depth_water_equivalent_min',
       'surface_net_solar_radiation_min', 'surface_net_thermal_radiation_min',
       'surface_pressure_min', 'temperature_2m_min',
       'dewpoint_temperature_2m_min', 'u_component_of_wind_10m_min',
       'v_component_of_wind_10m_min', 'volumetric_soil_water_layer_1_min',
       'volumetric_soil_water_layer_2_min',
       'volumetric_soil_water_layer_3_min',
       'volumetric_soil_water_layer_4_min', 'snow_depth_water_equivalent_max',
       'surface_net_solar_radiati

In [5]:
# Parse the date strings once and reuse the parsed series for the day-of-year feature.
parsed_dates = pd.to_datetime(ts_data["date"])
ts_data["date"] = parsed_dates
ts_data["julian_day"] = parsed_dates.dt.dayofyear  # optional seasonal feature

# Explicit list of series to keep for modelling: forcings plus the streamflow
# target. Uncomment "julian_day" to also feed the seasonal signal to the model.
ts_columns = [
    "potential_evaporation_sum_ERA5_LAND",
    "potential_evaporation_sum_FAO_PENMAN_MONTEITH",
    "streamflow",
    # "julian_day",
    "temperature_2m_mean",
    "total_precipitation_sum",
]

In [6]:
# Preview the time series. The tail rows (CH_5032, late 2020) have NaN streamflow
# while the forcings are present, so targets can be missing where inputs exist.
ts_data

Unnamed: 0,gauge_id,date,snow_depth_water_equivalent_mean,surface_net_solar_radiation_mean,surface_net_thermal_radiation_mean,surface_pressure_mean,temperature_2m_mean,dewpoint_temperature_2m_mean,u_component_of_wind_10m_mean,v_component_of_wind_10m_mean,...,v_component_of_wind_10m_max,volumetric_soil_water_layer_1_max,volumetric_soil_water_layer_2_max,volumetric_soil_water_layer_3_max,volumetric_soil_water_layer_4_max,total_precipitation_sum,potential_evaporation_sum_ERA5_LAND,potential_evaporation_sum_FAO_PENMAN_MONTEITH,streamflow,julian_day
0,CH_2009,1981-01-02,422.27,16.46,-5.42,85.37,-8.89,-11.15,-0.23,0.77,...,1.03,0.37,0.38,0.38,0.43,4.79,0.34,0.17,0.82,2
1,CH_2009,1981-01-03,427.85,13.50,-1.53,84.77,-5.78,-7.29,0.12,1.02,...,1.12,0.37,0.38,0.38,0.43,13.44,0.17,0.19,0.97,3
2,CH_2009,1981-01-04,470.54,5.70,-7.70,83.81,-4.00,-4.80,0.73,0.23,...,0.98,0.37,0.38,0.38,0.43,50.94,-0.00,0.00,1.21,4
3,CH_2009,1981-01-05,496.06,19.65,-30.43,84.55,-8.98,-10.48,0.89,-0.29,...,0.47,0.37,0.38,0.38,0.43,10.42,0.39,0.00,1.50,5
4,CH_2009,1981-01-06,508.01,7.90,-15.45,84.38,-10.22,-11.95,0.58,0.18,...,0.93,0.37,0.38,0.38,0.43,21.49,0.12,0.00,1.54,6
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1972210,CH_5032,2020-12-27,8.48,24.67,-34.24,89.58,-0.54,-5.98,1.57,3.19,...,5.60,0.39,0.40,0.42,0.41,2.68,3.35,0.56,,362
1972211,CH_5032,2020-12-28,12.66,31.03,-42.11,87.82,-0.40,-3.35,1.34,2.81,...,4.16,0.40,0.39,0.42,0.41,8.38,1.79,0.24,,363
1972212,CH_5032,2020-12-29,19.42,36.83,-43.51,88.78,-0.48,-2.68,2.58,2.45,...,3.83,0.39,0.39,0.42,0.42,8.07,1.55,0.23,,364
1972213,CH_5032,2020-12-30,22.49,27.96,-33.81,89.88,-3.14,-4.45,1.46,0.93,...,1.61,0.39,0.38,0.42,0.42,1.49,0.36,0.03,,365


In [7]:
# Keep only the selected series plus the identifiers used for grouping and splitting.
whole_data = ts_columns + ["gauge_id", "date"]
ts_data = ts_data[whole_data]

# Optional visual check of each gauge's 2015-2020 series. Off by default because
# it draws one large multi-panel figure per gauge.
PLOT_RECENT_SERIES = False
if PLOT_RECENT_SERIES:
    for gauge_id, group in ts_data.groupby("gauge_id"):
        recent = group.set_index("date").loc["2015-01-01":"2020-12-31"]
        recent[ts_columns].plot(subplots=True, figsize=(20, 20), title=gauge_id)
        plt.show()



In [8]:
# Static catchment attributes to use as model inputs, plus the gauge_id join key.
statics_to_keep = [
    "gauge_id",
    "p_mean",
    "area",
    "ele_mt_sav",
    "high_prec_dur",
    "frac_snow",
    "high_prec_freq",
    "slp_dg_sav",
    "cly_pc_sav",
    "aridity_ERA5_LAND",
    "aridity_FAO_PM",
]

# Surface any requested attribute the source does not provide; the filter below
# would otherwise drop it silently and shrink the static input size.
missing_statics = sorted(set(statics_to_keep) - set(static_data.columns))
if missing_statics:
    warnings.warn(
        f"Requested static attributes not found and skipped: {missing_statics}"
    )

# Filter in the source table's column order (not the order of statics_to_keep).
static_columns = [col for col in static_data.columns if col in statics_to_keep]

static_data = static_data[static_columns]
static_data

Unnamed: 0,gauge_id,area,cly_pc_sav,ele_mt_sav,slp_dg_sav,aridity_ERA5_LAND,aridity_FAO_PM,frac_snow,high_prec_dur,high_prec_freq,p_mean
0,CH_2009,701.274189,14.506965,1565.340147,248.671068,0.999746,0.258537,0.386329,1.134021,0.030118,4.954752
1,CH_2011,1566.180595,14.634048,2081.147355,247.628593,0.559558,0.235603,0.479000,1.223587,0.034089,4.197982
2,CH_2016,811.489711,19.279472,536.226852,65.255587,0.788556,0.401406,0.000000,1.114286,0.029365,3.908174
3,CH_2018,229.895078,19.096711,489.142185,47.522328,0.761725,0.420621,0.000000,1.146226,0.033267,3.691103
4,CH_2019,554.489324,12.929081,2130.274450,272.113992,0.462579,0.157501,0.420558,1.206490,0.027996,5.468537
...,...,...,...,...,...,...,...,...,...,...,...
130,CH_4018,263.819995,17.028774,593.942048,25.426573,1.093487,0.500403,0.128006,1.147651,0.035115,3.105997
131,CH_4020,134.347492,18.861457,702.837469,40.909075,1.079755,0.355538,0.222678,1.202469,0.033336,4.191385
132,CH_5009,141.339042,20.967316,1104.135496,66.945681,0.686178,0.346108,0.154859,1.121519,0.030324,4.981448
133,CH_5014,96.878626,20.955220,1032.105499,80.504107,0.697724,0.350408,0.147758,1.127717,0.028407,4.880079


In [9]:
# Dynamic model inputs: every retained series except the identifiers and the target.
_non_feature_cols = {"gauge_id", "date", "streamflow"}
features = [col for col in ts_data.columns if col not in _non_feature_cols]
# Column list for the data module: the features with the target appended last.
ts_columns = [*features, "streamflow"]


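# 1. Load and prepare CAMELS-CH data (alternative native loader — the cells below are disabled; the Caravan-format data loaded above is used instead)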

In [10]:
# camels_config = CamelsCHConfig(
#     timeseries_dir="/Users/cooper/Desktop/CAMELS-CH/data/timeseries/observation_based/",
#     timeseries_pattern="CAMELS_CH_obs_based_*.csv",
#     static_attributes_dir="/Users/cooper/Desktop/CAMELS-CH/data/static_attributes",
#     use_climate=False,
#     use_geology=False,
#     use_glacier=False,
#     use_human_influence=False,
#     use_hydrogeology=False,
#     use_hydrology=False,
#     use_landcover=False,
#     use_soil=False,
#     use_topographic=False,
# )

In [11]:
# all_gauge_ids = get_all_gauge_ids(camels_config)

# ids_for_training = all_gauge_ids[:5]

# camels = CamelsCH(camels_config)
# camels.load_stations(ids_for_training)

In [12]:
# data = camels.get_time_series()
# data = data[
#     [
#         "gauge_id",
#         "date",
#         "discharge_spec(mm/d)",
#     ]
# ]

# data

In [13]:
# static = camels.get_static_attributes()
# sc = static.columns

# # for i in range(len(sc)):
# #     print(f"{i}: {sc[i]}")
# static_attributes = [
#     "gauge_id",
#     "area", 
#     "elev_mean",  
#     "slope_mean",  
#     "aridity",  
#     "p_seasonality",  
#     "frac_snow",  
#     "porosity",  
#     "conductivity",  
#     "p_mean",  
#     "geo_porosity",  
# ]
# static = static[static_attributes]
# static

# 2. Configure preprocessing

In [14]:
# Column groups handed to the data module's preprocessing.
dynamic_feature_cols = features
static_feature_cols = [col for col in static_columns if col != "gauge_id"]
target_cols = ["streamflow"]

# Dynamic features: standard scaling only. A log step can be prepended, e.g.
# ("log", LogTransformer(columns=dynamic_feature_cols)).
feature_pipeline = Pipeline(
    [("scaler", StandardScaleTransformer(columns=dynamic_feature_cols))]
)

# Target: standard scaling wrapped in a GroupedTransformer keyed on gauge_id
# (presumably one scaler fitted per basin — confirm in GroupedTransformer).
# A ("log", LogTransformer(columns=target_cols)) step can be prepended.
target_pipeline = GroupedTransformer(
    Pipeline([("scaler", StandardScaleTransformer(columns=target_cols))]),
    columns=target_cols,
    group_identifier="gauge_id",
    n_jobs=-1,
)

# Static attributes: standard scaling only.
static_pipeline = Pipeline(
    [("scaler", StandardScaleTransformer(columns=static_feature_cols))]
)

# name -> {"pipeline": ..., "columns": ...}, in features/target/static order.
preprocessing_configs = {
    name: {"pipeline": pipe, "columns": cols}
    for name, pipe, cols in (
        ("features", feature_pipeline, dynamic_feature_cols),
        ("target", target_pipeline, target_cols),
        ("static_features", static_pipeline, static_feature_cols),
    )
}

In [15]:
# static_columns still includes "gauge_id" here; it is stripped when the data module is configured.
static_columns

['gauge_id',
 'area',
 'cly_pc_sav',
 'ele_mt_sav',
 'slp_dg_sav',
 'aridity_ERA5_LAND',
 'aridity_FAO_PM',
 'frac_snow',
 'high_prec_dur',
 'high_prec_freq',
 'p_mean']

# 3. Create DataModule

In [16]:
# Window lengths in (daily) time steps: encoder input and forecast horizon.
output_length = 10
input_length = 40

# gauge_id is the grouping key, not a static model input.
static_columns = [c for c in static_columns if c != "gauge_id"]

print("TS columns:", ts_columns)
print("Static columns:", static_columns)


data_module = HydroDataModule(
    time_series_df=ts_data,
    static_df=static_data,
    # static_df=None,
    group_identifier="gauge_id",
    preprocessing_config=preprocessing_configs,
    batch_size=128,
    input_length=input_length,
    output_length=output_length,
    num_workers=4,
    features=ts_columns,
    static_features=static_columns,
    # static_features=None,
    target="streamflow",
    min_train_years=2,
    val_years=1,
    test_years=3,
    max_missing_pct=10,
    # Label the domain after the data source: these are CH_* (CAMELS-CH) gauges.
    domain_id="CH",
)

data_module.prepare_data()
data_module.setup()

TS columns: ['potential_evaporation_sum_ERA5_LAND', 'potential_evaporation_sum_FAO_PENMAN_MONTEITH', 'temperature_2m_mean', 'total_precipitation_sum', 'streamflow']
Static columns: ['area', 'cly_pc_sav', 'ele_mt_sav', 'slp_dg_sav', 'aridity_ERA5_LAND', 'aridity_FAO_PM', 'frac_snow', 'high_prec_dur', 'high_prec_freq', 'p_mean']
Original basins: 135
Retained basins: 135
Domain CA: Created 1700755 valid sequences from 135 catchments
Domain CA: Created 42660 valid sequences from 135 catchments
Domain CA: Created 141290 valid sequences from 135 catchments


In [None]:
# By now "gauge_id" has been removed, leaving only the static model inputs.
static_columns

In [None]:
# Target the data module was configured with (constructed with target="streamflow").
data_module.target

## 4. Create model and train

In [None]:
from src.models.lstm import LitLSTM
from src.models.ealstm import LitEALSTM
from src.models.TSMixer import LitTSMixer, TSMixerConfig
from src.models.evaluators import TSForecastEvaluator
from torch.optim import Adam
from torch.nn import MSELoss

# 5. Evaluate and plot results

In [None]:
# Alternative models, kept for quick switching. static_columns no longer contains
# "gauge_id" at this point, so its length is already the static input size.
# model = LitLSTM(
#     input_size=len(ts_columns),
#     hidden_size=16,
#     num_layers=1,
#     output_size=output_length,
#     target=data_module.target,
# )

# model = LitEALSTM(
#     input_size_dyn=len(ts_columns),
#     input_size_stat=len(static_columns),
#     hidden_size=64,
#     output_size=output_length,
# )

# Seed Python, NumPy and torch RNGs (and dataloader workers) so runs are repeatable.
SEED = 42
pl.seed_everything(SEED, workers=True)

# TSMixer sized from the data-module inputs. ts_columns also contains the
# streamflow target, which is fed as a dynamic input.
tsmixer_config = TSMixerConfig(
    input_len=input_length,
    output_len=output_length,
    input_size=len(ts_columns),
    static_size=len(static_columns),
    hidden_size=80,
    learning_rate=7e-4,
    dropout=0.1,
    num_layers=2,
)

model = LitTSMixer(tsmixer_config)

# CPU training. Keep the checkpoint with the lowest val_loss and stop after
# 3 epochs without improvement.
trainer = pl.Trainer(
    max_epochs=15,
    accelerator="cpu",
    devices=1,
    callbacks=[
        ModelCheckpoint(
            monitor="val_loss",
            dirpath="checkpoints",
            filename="best-checkpoint",
            save_top_k=1,
            mode="min",
        ),
        EarlyStopping(monitor="val_loss", patience=3, mode="min"),
        LearningRateMonitor(logging_interval="epoch"),
    ],
)

# Train the model
trainer.fit(model, datamodule=data_module)

In [None]:
quality_report = data_module.quality_report

# Basins the data module dropped during quality filtering
# (the keys of the report's "excluded_basins" mapping).
excluded_basins = list(quality_report["excluded_basins"].keys())

# Keep the training id list in sync with what the data module retained.
# A set gives O(1) membership checks, and `gauge_id` avoids shadowing builtin `id`.
excluded_set = set(excluded_basins)
ids_for_training = [
    gauge_id for gauge_id in ids_for_training if gauge_id not in excluded_set
]

# Display the excluded basins; a notebook renders only a cell's last expression.
excluded_basins

In [None]:
# Full data-quality report from the data module; its "excluded_basins" entry drives the id filtering above.
quality_report

In [None]:
# Evaluate the checkpoint with the lowest val_loss (saved by ModelCheckpoint)
# rather than the last-epoch weights, which EarlyStopping can leave up to
# `patience` epochs past the best. The best weights are loaded into `model`.
trainer.test(model, datamodule=data_module, ckpt_path="best")
raw_results = model.test_results

# Compute metrics for every forecast horizon 1..output_len.
evaluator = TSForecastEvaluator(
    data_module, horizons=list(range(1, model.config.output_len + 1))
)
results_df, overall_metrics, basin_metrics = evaluator.evaluate(raw_results)

# Get overall summary
overall_summary = evaluator.summarize_metrics(overall_metrics)

# Get per-basin summary
basin_summary = evaluator.summarize_metrics(basin_metrics, per_basin=True)

In [None]:
# Metrics aggregated over all basins (presumably one row per forecast horizon).
overall_summary

In [None]:
# Metrics broken down per basin (presumably indexed by basin and horizon).
basin_summary

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns


def plot_metric_summary(
    summary_df: pd.DataFrame, metric: str, per_basin: bool = False, figsize=(10, 6)
):
    """Bar-plot a single metric from an evaluator summary against forecast horizon.

    Args:
        summary_df: Frame returned by ``TSForecastEvaluator.summarize_metrics``.
            Overall mode: the index is treated as the horizon and
            ``summary_df[metric]`` supplies one value per horizon.
            Per-basin mode: ``summary_df[metric].unstack(level=0)`` must give
            horizons as rows and basins as columns, with the row index named
            ``"horizon"`` and the column index named ``"basin_id"`` (required by
            the seaborn column names below — TODO confirm against the evaluator).
        metric: Name of the metric column to plot, e.g. ``"NSE"``.
        per_basin: If True, draw grouped bars (one per basin) at each horizon,
            with basins ordered by their value at the first horizon, highest first.
        figsize: Matplotlib figure size in inches.

    Returns:
        matplotlib.axes.Axes: The axes the chart was drawn on.
    """
    fig, ax = plt.subplots(figsize=figsize)

    if per_basin:
        # Rows = horizons, columns = basins.
        df_plot = summary_df[metric].unstack(level=0)

        # Order basins by their value at the first horizon (descending).
        sorted_basins = df_plot.iloc[0].sort_values(ascending=False).index
        df_plot = df_plot[sorted_basins]

        sns.barplot(
            data=df_plot.melt(ignore_index=False).reset_index(),
            x="horizon",
            y="value",
            hue="basin_id",
            hue_order=list(sorted_basins),
            palette="Blues",
            ax=ax,
        )
        ax.set_title(f"{metric} by Basin and Horizon")
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", title="Basin ID")

    else:
        # Pin the category order so bar i is the i-th index entry; the value
        # labels below are placed by that position.
        sns.barplot(
            x=summary_df.index,
            y=summary_df[metric],
            order=list(summary_df.index),
            color="steelblue",
            ax=ax,
        )
        ax.set_title(f"Overall {metric} by Horizon")

        for i, v in enumerate(summary_df[metric]):
            if pd.notna(v):  # no label for missing scores
                ax.text(i, v, f"{v:.2f}", ha="center", va="bottom")

    ax.set_xlabel("Forecast Horizon")
    ax.set_ylabel(metric)
    fig.tight_layout()
    sns.despine(ax=ax)
    plt.show()
    return ax


# Usage: overall NSE per horizon, then the per-basin breakdown in a wider figure
# to leave room for the legend. Trailing semicolons keep return values from being echoed.
plot_metric_summary(overall_summary, "NSE");
plot_metric_summary(basin_summary, "NSE", per_basin=True, figsize=(12, 6));