In [1]:
import sys
from pathlib import Path
# Put the parent of the current working directory on the import path
# (the later `src.*` imports presumably rely on this — run the notebook from its own folder).
sys.path.append(str(Path.cwd().parent))
import time

In [None]:
import glob
import os
from io import StringIO
from pathlib import Path

import lightning.pytorch as pl
import numpy as np
import pandas as pd
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import SMAPE, QuantileLoss

from src.data_models.camels_ch import CamelsCH, CamelsCHConfig, get_all_gauge_ids
from src.data_models.dataset import HydroDataset

---

## Getting the data

In [None]:
# Root of the local CAMELS-CH download. Set the CAMELS_CH_DATA_DIR environment
# variable to run the notebook on another machine; the default is the location
# this notebook was originally written against.
CAMELS_CH_DATA_DIR = os.environ.get(
    "CAMELS_CH_DATA_DIR", "/Users/cooper/Desktop/CAMELS-CH/data"
)

camels_config = CamelsCHConfig(
    # The trailing "" keeps the trailing path separator the original literal had.
    timeseries_dir=os.path.join(
        CAMELS_CH_DATA_DIR, "timeseries", "observation_based", ""
    ),
    timeseries_pattern="CAMELS_CH_obs_based_*.csv",
    static_attributes_dir=os.path.join(CAMELS_CH_DATA_DIR, "static_attributes"),
    # Only the topographic attribute group is switched on; every other group is
    # disabled (presumably this limits what get_static_attributes() returns).
    use_climate=False,
    use_geology=False,
    use_glacier=False,
    use_human_influence=False,
    use_hydrogeology=False,
    use_hydrology=False,
    use_landcover=False,
    use_soil=False,
    use_topographic=True,
)

# Build the loader and load the two gauges used throughout this notebook.
camels = CamelsCH(camels_config)
camels.load_stations(["2018", "6005"])

In [None]:
# Static attribute table returned by the CAMELS-CH loader; left as the last
# expression so the notebook renders it for inspection.
static = camels.get_static_attributes()
static

In [None]:
# Wrap the loaded time series and static attributes in a HydroDataset that
# uses 365 input steps and 5 output steps per sample.
dynamic_features = [
    "discharge_spec(mm/d)",
    "precipitation(mm/d)",
    "temperature_mean(degC)",
]
catchment_features = ["elev_mean", "slope_mean"]

dataset = HydroDataset(
    time_series_df=camels.get_time_series(),
    static_df=camels.get_static_attributes(),
    input_length=365,
    output_length=5,
    features=dynamic_features,
    target="discharge_spec(mm/d)",
    static_features=catchment_features,
)

# Smoke test: report the dataset length, then the shapes found in the first sample.
print(f"Dataset size: {len(dataset)}")
sample = dataset[0]
print(
    "\nSample shapes:",
    f"X: {sample['X'].shape}",
    f"y: {sample['y'].shape}",
    f"static: {sample['static'].shape}",
    f"gauge_id: {sample['gauge_id']}",
    sep="\n",
)

In [None]:
# Keep only the columns the forecasting model needs: the date, the target,
# the two meteorological covariates and the gauge identifier.
selected_columns = [
    "date",
    "discharge_spec(mm/d)",
    "precipitation(mm/d)",
    "temperature_mean(degC)",
    "gauge_id",
]
data = camels.get_time_series()
data = data[selected_columns]

data

In [5]:
# Rows without an observed target cannot be used, so drop them first.
data = data.dropna(subset=["discharge_spec(mm/d)"])

# Fill covariate gaps with a single fillna() that returns a new frame instead of
# writing through .loc into a frame derived from a slice (which triggers
# SettingWithCopyWarning on pandas versions without copy-on-write).
# Missing precipitation becomes 0; missing temperature becomes the mean
# temperature of the remaining rows (computed across all gauges together).
data = data.fillna(
    {
        "precipitation(mm/d)": 0,
        "temperature_mean(degC)": data["temperature_mean(degC)"].mean(),
    }
)

In [None]:
# Integer time index = calendar days since the earliest date in the frame.
# A dense rank over the surviving dates would give consecutive indices to days
# that are not consecutive whenever a date has been dropped for every gauge,
# hiding the gap from TimeSeriesDataSet. With contiguous daily dates both
# approaches give identical values.
# assign() returns a new frame, so nothing is written into a derived slice.
observation_dates = pd.to_datetime(data["date"])
data = data.assign(
    time_idx=(observation_dates - observation_dates.min()).dt.days.astype(int)
)

## Preparing the data

In [7]:
max_encoder_length = 365
max_prediction_length = 1

# Everything up to this index is used for training; the final
# max_prediction_length * 365 time steps are held out.
training_cutoff = data["time_idx"].max() - max_prediction_length * 365

train_rows = data[data["time_idx"] <= training_cutoff]
holdout_rows = data[data["time_idx"] > training_cutoff]

# Training dataset: one series per gauge, target normalised per gauge,
# precipitation and temperature treated as known in the prediction window.
training = TimeSeriesDataSet(
    train_rows,
    time_idx="time_idx",
    target="discharge_spec(mm/d)",
    group_ids=["gauge_id"],
    max_encoder_length=max_encoder_length,
    min_encoder_length=max_encoder_length // 2,
    max_prediction_length=max_prediction_length,
    min_prediction_length=1,
    time_varying_known_reals=["precipitation(mm/d)", "temperature_mean(degC)"],
    time_varying_unknown_reals=["discharge_spec(mm/d)"],
    target_normalizer=GroupNormalizer(groups=["gauge_id"]),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)

# Validation dataset: same configuration as `training`, built from the held-out
# rows in predict mode with randomisation switched off.
validation = TimeSeriesDataSet.from_dataset(
    training,
    holdout_rows,
    predict=True,
    stop_randomization=True,
)

batch_size = 128
loader_options = {"batch_size": batch_size, "num_workers": 0}
train_dataloader = training.to_dataloader(train=True, **loader_options)
val_dataloader = validation.to_dataloader(train=False, **loader_options)

In [None]:
# Stop training when val_loss has not improved for 3 validation checks, and
# keep the 3 checkpoints with the lowest val_loss under ./checkpoints.
early_stopping = EarlyStopping(monitor="val_loss", patience=3, mode="min")
checkpointing = ModelCheckpoint(
    monitor="val_loss",
    dirpath="checkpoints",
    filename="tft-{epoch:02d}-{val_loss:.2f}",
    save_top_k=3,
    mode="min",
)
callbacks = [early_stopping, checkpointing]

# Choose accelerator and devices together. The CPU accelerator only accepts a
# positive integer for `devices`, so pairing accelerator="cpu" with devices=[0]
# (a GPU index list) raises as soon as CUDA is available.
use_gpu = torch.cuda.is_available()

trainer = Trainer(
    max_epochs=30,
    accelerator="gpu" if use_gpu else "cpu",
    devices=[0] if use_gpu else 1,
    gradient_clip_val=0.1,
    limit_train_batches=50,  # cap each epoch at 50 training batches
    enable_checkpointing=True,
    logger=True,
    callbacks=callbacks,
)

In [None]:
# Network size and optimisation settings for the Temporal Fusion Transformer.
tft_hparams = dict(
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,  # one output per quantile (QuantileLoss defaults to 7 quantiles)
    log_interval=10,
    reduce_on_plateau_patience=4,
    optimizer="adam",
)

# Input/output layout is derived from the training dataset definition.
tft = TemporalFusionTransformer.from_dataset(
    training,
    loss=QuantileLoss(),
    **tft_hparams,
)

trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

In [None]:
# Use the checkpoint the ModelCheckpoint callback ranked best in this run,
# instead of a hand-copied absolute file name from one machine and one run.
best_model_path = trainer.checkpoint_callback.best_model_path
if not best_model_path:
    raise RuntimeError(
        "No checkpoint has been recorded yet; run trainer.fit(...) before evaluating."
    )

hindcast_cutoff = data["time_idx"].max() - max_prediction_length * 365

# Load the best model and set it to eval mode
best_model = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
best_model.eval()

# Evaluate on the validation dataset (built via from_dataset); return_y=True
# makes predict() return the targets alongside the predictions.
predictions = best_model.predict(val_dataloader, return_y=True)
# Symmetric mean absolute percentage error between predictions and targets
smape = SMAPE()(predictions.output, predictions.y)
print("SMAPE:", smape.item())

# Hindcast example: predict=True keeps only the last forecast window of each
# series in the rows after the cutoff.
hindcast_dataset = TimeSeriesDataSet.from_dataset(
    training,
    data[lambda x: x["time_idx"] > hindcast_cutoff],
    predict=True,
    stop_randomization=True,
)
hindcast_dataloader = hindcast_dataset.to_dataloader(train=False, batch_size=128)

# plot_prediction() needs the raw network output and the network inputs, so
# request both; a plain predict() call returns only a tensor with no
# `.x` / `.output` attributes.
hindcast_predictions = best_model.predict(
    hindcast_dataloader, mode="raw", return_x=True
)
# Compare the hindcast against the known historical targets for the first series
best_model.plot_prediction(hindcast_predictions.x, hindcast_predictions.output, idx=0)
