In [1]:
import sys
from pathlib import Path
sys.path.append(str(Path().absolute().parent))
import time

In [None]:
import glob
import os 
import sys
from io import StringIO
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from src.benchmark_tft.data_loading import combine_camels_data

---

## Getting the data

In [3]:
# Folder holding the CAMELS-CH observation-based CSVs. Set the
# CAMELS_CH_DATA_DIR environment variable to point elsewhere instead of
# editing this machine-specific default.
data_base_dir = os.path.join(
    os.environ.get(
        "CAMELS_CH_DATA_DIR",
        "/Users/cooper/Desktop/CAMELS-CH/data/timeseries/observation_based/",
    ),
    # Joining with "" guarantees a trailing separator, as the original default
    # had. NOTE(review): kept in case combine_camels_data builds paths by
    # plain string concatenation (its body is not visible here).
    "",
)

# Glob pattern matched inside data_base_dir to find the input CSVs.
data_naming_convention = "CAMELS_CH_obs_based_*.csv"

# Columns kept from each CSV: the date, the discharge target and the two
# meteorological forcings used as model inputs.
columns_to_keep = [
    "date",
    "discharge_spec(mm/d)",
    "precipitation(mm/d)",
    "temperature_mean(degC)",
]




In [None]:
camels_raw = combine_camels_data(
    folder_path=data_base_dir,
    data_naming_convention=data_naming_convention,
    columns_to_keep=columns_to_keep,
)

# The TimeSeriesDataSet below groups by `gauge_id` and indexes by `time_idx`.
# Neither is in `columns_to_keep`, so combine_camels_data must add them; fail
# early with a clear message if it does not.
missing_cols = {"gauge_id", "time_idx"} - set(camels_raw.columns)
assert not missing_cols, f"combine_camels_data did not produce columns: {missing_cols}"

# Discharge is the prediction target, so rows without it are unusable.
# Dropping them leaves gaps in `time_idx`; the dataset is built with
# allow_missing_timesteps=True to tolerate that.
camels_discharge = camels_raw.dropna(subset=["discharge_spec(mm/d)"])

# Fill gaps in the meteorological inputs (only these two columns are filled):
# - precipitation: a missing value is treated as no rain (0 mm/d);
# - temperature: each gauge's own mean, because a single pooled mean ignores
#   between-gauge differences. Gauges with no temperature at all fall back to
#   the overall mean.
temp_col = "temperature_mean(degC)"
gauge_temp_mean = camels_discharge.groupby("gauge_id")[temp_col].transform("mean")

# `assign` returns a new frame, so no chained-assignment on a filtered frame.
camels_combined = camels_discharge.assign(
    **{
        "precipitation(mm/d)": camels_discharge["precipitation(mm/d)"].fillna(0),
        temp_col: camels_discharge[temp_col]
        .fillna(gauge_temp_mean)
        .fillna(camels_discharge[temp_col].mean()),
    }
)

In [None]:
# Preview the first rows only; the combined frame holds every loaded gauge.
camels_combined.head()

## Preparing the data

In [None]:
# Window lengths in time_idx steps (presumably days, given the mm/d units).
max_encoder_length = 30  # history fed to the encoder
max_prediction_length = 7  # forecast horizon
validation_length = 30  # most recent steps held out for validation

# Hold out the final `validation_length` steps. The training set only sees data
# up to the cutoff, so validation windows are never trained on. This matters
# because EarlyStopping and ModelCheckpoint both monitor val_loss.
# NOTE(review): uses the global max time_idx; a gauge whose record ends earlier
# gets fewer or no validation windows. Confirm time_idx is a shared calendar
# index across gauges.
training_cutoff = camels_combined["time_idx"].max() - validation_length

training = TimeSeriesDataSet(
    camels_combined[camels_combined["time_idx"] <= training_cutoff],
    time_idx="time_idx",
    target="discharge_spec(mm/d)",
    group_ids=["gauge_id"],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    time_varying_known_reals=["precipitation(mm/d)", "temperature_mean(degC)"],
    time_varying_unknown_reals=["discharge_spec(mm/d)"],
    # Normalizer statistics are fitted per gauge, on training data only.
    target_normalizer=GroupNormalizer(groups=["gauge_id"]),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    # Gaps exist where rows with missing discharge were dropped.
    allow_missing_timesteps=True,
)

# Validation reuses training's fitted encoders/normalizers on the full series.
# It keeps only windows whose first predicted step lies after the cutoff, and
# disables any length randomization.
validation = TimeSeriesDataSet.from_dataset(
    training,
    camels_combined,
    min_prediction_idx=training_cutoff + 1,
    stop_randomization=True,
)

# Create dataloaders
batch_size = 128
train_dataloader = training.to_dataloader(
    train=True, batch_size=batch_size, num_workers=0
)
val_dataloader = validation.to_dataloader(
    train=False, batch_size=batch_size, num_workers=0
)

In [None]:
# Seed Python, NumPy and torch (including dataloader workers), so weight
# initialization in the model cell and batch shuffling during fit are
# reproducible across runs.
SEED = 42
seed_everything(SEED, workers=True)

callbacks = [
    # Patience must exceed the model's reduce_on_plateau_patience (4). With
    # patience=3, training stopped before the LR scheduler could ever lower the
    # learning rate.
    EarlyStopping(monitor="val_loss", patience=10, mode="min"),
    # Keep the three best checkpoints by validation loss.
    ModelCheckpoint(
        monitor="val_loss",
        dirpath="checkpoints",
        filename="tft-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    ),
]

use_gpu = torch.cuda.is_available()
trainer = Trainer(
    max_epochs=30,
    accelerator="gpu" if use_gpu else "cpu",
    # `devices` must be an int, a list or "auto". Lightning 2.x rejects None for
    # the CPU accelerator, so use a single CPU device instead.
    devices=[0] if use_gpu else 1,
    gradient_clip_val=0.1,
    # Cap each "epoch" at 50 training batches so validation runs often.
    limit_train_batches=50,
    enable_checkpointing=True,
    logger=True,
    callbacks=callbacks,
)

In [None]:
# Quantile loss with QuantileLoss's default quantile set. The network emits one
# output per quantile, so output_size is derived from the loss instance instead
# of being hard-coded and risking a mismatch if the quantiles change.
quantile_loss = QuantileLoss()

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=len(quantile_loss.quantiles),  # one output per quantile (7 by default)
    loss=quantile_loss,
    log_interval=10,
    reduce_on_plateau_patience=4,
)

trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)