In [1]:
import sys
from pathlib import Path
sys.path.append(str(Path().absolute().parent))
import time

In [2]:
import torch
import lightning.pytorch as pl
from lightning.pytorch import Trainer
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.data import GroupNormalizer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
import pandas as pd
from io import StringIO
import numpy as np
import glob 
from pathlib import Path

from src.benchmark_tft.data_loading import combine_camels_data
from src.data_models.camels_ch import CamelsCH, CamelsCHConfig, get_all_gauge_ids

  from tqdm.autonotebook import tqdm


---

## Getting the data

In [3]:
import os

# Root of the local CAMELS-CH download. Override it with the CAMELS_CH_ROOT
# environment variable instead of editing an absolute path in the notebook.
CAMELS_ROOT = Path(os.environ.get("CAMELS_CH_ROOT", "/Users/cooper/Desktop/CAMELS-CH"))

# Only the observation-based time series are loaded; every static-attribute
# group is switched off.
camels_config = CamelsCHConfig(
    # Keep the trailing "/" in case the loader joins dir + pattern as strings.
    timeseries_dir=f"{CAMELS_ROOT / 'data' / 'timeseries' / 'observation_based'}/",
    timeseries_pattern="CAMELS_CH_obs_based_*.csv",
    static_attributes_dir=str(CAMELS_ROOT / "data" / "static_attributes"),
    use_climate=False,
    use_geology=False,
    use_glacier=False,
    use_human_influence=False,
    use_hydrogeology=False,
    use_hydrology=False,
    use_landcover=False,
    use_soil=False,
    use_topographic=False,
)

gauge_ids = get_all_gauge_ids(camels_config)

print(f"There are {len(gauge_ids)} gauge ids")

camels = CamelsCH(camels_config)
camels.load_stations(gauge_ids)

There are 331 gauge ids
Loaded time series data for 331 stations


In [6]:
# Columns used downstream: the date, the discharge target, precipitation and
# mean temperature as inputs, and the gauge key. Everything else is dropped.
MODEL_COLUMNS = [
    "date",
    "discharge_spec(mm/d)",
    "precipitation(mm/d)",
    "temperature_mean(degC)",
    "gauge_id",
]
data = camels.get_time_series()[MODEL_COLUMNS]

data

Unnamed: 0,date,discharge_spec(mm/d),precipitation(mm/d),temperature_mean(degC),gauge_id
0,1981-01-01,0.429,0.11,-3.35,2474
1,1981-01-02,0.423,0.14,-7.63,2474
2,1981-01-03,0.416,3.40,-1.13,2474
3,1981-01-04,0.424,2.46,-5.10,2474
4,1981-01-05,0.420,1.20,-11.88,2474
...,...,...,...,...,...
4835905,2020-12-27,2.240,2.66,-2.03,2486
4835906,2020-12-28,2.048,6.02,-0.45,2486
4835907,2020-12-29,1.725,3.41,-0.04,2486
4835908,2020-12-30,1.587,2.92,-1.99,2486


In [8]:
# Rows without a discharge value cannot serve as targets, so drop them.
data = data.dropna(subset=["discharge_spec(mm/d)"])

temperature = data["temperature_mean(degC)"]
# Mean temperature of each row's own gauge, aligned row-by-row with `data`.
gauge_mean_temperature = temperature.groupby(data["gauge_id"]).transform("mean")

# `assign` returns a new frame instead of writing through `.loc` into a frame
# derived from earlier selections.
data = data.assign(
    **{
        # NOTE(review): a missing precipitation value is treated as 0 mm/d
        # ("no precipitation") — confirm this suits the data source.
        "precipitation(mm/d)": data["precipitation(mm/d)"].fillna(0),
        # Fill temperature gaps with the gauge's own mean. A mean over all
        # gauges mixes stations with different climates. Gauges with no
        # temperature at all fall back to the overall mean.
        "temperature_mean(degC)": temperature.fillna(gauge_mean_temperature).fillna(
            temperature.mean()
        ),
    }
)

In [9]:
# Display the frame after dropping rows without discharge and filling gaps in
# the precipitation/temperature inputs.
data

Unnamed: 0,date,discharge_spec(mm/d),precipitation(mm/d),temperature_mean(degC),gauge_id
0,1981-01-01,0.429,0.11,-3.35,2474
1,1981-01-02,0.423,0.14,-7.63,2474
2,1981-01-03,0.416,3.40,-1.13,2474
3,1981-01-04,0.424,2.46,-5.10,2474
4,1981-01-05,0.420,1.20,-11.88,2474
...,...,...,...,...,...
4835905,2020-12-27,2.240,2.66,-2.03,2486
4835906,2020-12-28,2.048,6.02,-0.45,2486
4835907,2020-12-29,1.725,3.41,-0.04,2486
4835908,2020-12-30,1.587,2.92,-1.99,2486


## Preparing the data

In [None]:
MAX_ENCODER_LENGTH = 30  # days of history fed to the encoder
MAX_PREDICTION_LENGTH = 7  # days forecast ahead
# Days at the end of the record held out for validation. A short window keeps
# validation cheap; lengthen it (e.g. 365) to cover a full seasonal cycle.
VALIDATION_DAYS = 30

# TimeSeriesDataSet needs an integer time index. Use days since the earliest
# date so all gauges share one calendar. Days removed by the dropna above leave
# gaps, which `allow_missing_timesteps=True` tolerates.
camels_combined = data.assign(date=pd.to_datetime(data["date"]))
camels_combined = camels_combined.assign(
    time_idx=(camels_combined["date"] - camels_combined["date"].min()).dt.days,
    # use string labels for the group id
    gauge_id=camels_combined["gauge_id"].astype(str),
)

training_cutoff = camels_combined["time_idx"].max() - VALIDATION_DAYS

# Train only on data up to the cutoff so no validation target is seen in training.
training = TimeSeriesDataSet(
    camels_combined[camels_combined["time_idx"] <= training_cutoff],
    time_idx="time_idx",
    target="discharge_spec(mm/d)",
    group_ids=["gauge_id"],
    max_encoder_length=MAX_ENCODER_LENGTH,
    max_prediction_length=MAX_PREDICTION_LENGTH,
    time_varying_known_reals=["precipitation(mm/d)", "temperature_mean(degC)"],
    time_varying_unknown_reals=["discharge_spec(mm/d)"],
    target_normalizer=GroupNormalizer(groups=["gauge_id"]),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)

# Validation windows may use pre-cutoff history in the encoder, but every
# predicted day lies after the cutoff.
# NOTE(review): a gauge that only has data after the cutoff is unknown to the
# training encoders and would fail here — confirm none exist.
validation = TimeSeriesDataSet.from_dataset(
    training,
    camels_combined,
    min_prediction_idx=training_cutoff + 1,
    stop_randomization=True,
)

# Create dataloaders. Validation keeps its last partial batch so every
# validation window is evaluated.
batch_size = 128
train_dataloader = training.to_dataloader(
    train=True, batch_size=batch_size, num_workers=0
)
val_dataloader = validation.to_dataloader(
    train=False, batch_size=batch_size, num_workers=0
)

In [None]:
# Seed Python/NumPy/torch (and dataloader workers) so model initialisation and
# batch shuffling are reproducible across runs.
SEED = 42
pl.seed_everything(SEED, workers=True)

# Early stopping must be more patient than the model's learning-rate scheduler
# (reduce_on_plateau_patience=4 in the model cell). Otherwise training stops
# before the learning rate is ever reduced.
EARLY_STOPPING_PATIENCE = 10

callbacks = [
    EarlyStopping(monitor="val_loss", patience=EARLY_STOPPING_PATIENCE, mode="min"),
    # Keep the three best checkpoints (by val_loss) in ./checkpoints.
    ModelCheckpoint(
        monitor="val_loss",
        dirpath="checkpoints",
        filename="tft-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    ),
]

use_gpu = torch.cuda.is_available()
trainer = Trainer(
    max_epochs=30,
    accelerator="gpu" if use_gpu else "cpu",
    devices=[0] if use_gpu else 1,
    gradient_clip_val=0.1,
    # Each "epoch" runs only 50 training batches, so validation/early stopping
    # are checked frequently.
    limit_train_batches=50,
    enable_checkpointing=True,
    logger=True,
    callbacks=callbacks,
)

In [None]:
# Training is off by default because the evaluation cell loads a saved
# checkpoint. Set to True to retrain.
TRAIN_MODEL = False

quantile_loss = QuantileLoss()

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    # One output per quantile of the loss, so the two cannot drift apart.
    output_size=len(quantile_loss.quantiles),
    loss=quantile_loss,
    log_interval=10,
    reduce_on_plateau_patience=4,
)

# Print number of parameters
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

if TRAIN_MODEL:
    trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

In [None]:
import matplotlib.pyplot as plt

def evaluate_and_plot(tft, val_dataloader, horizon_step=0):
    """Predict on a dataloader, report MAE and plot forecasts against actuals.

    Parameters
    ----------
    tft : TemporalFusionTransformer
        Model whose ``predict`` is called with ``return_y=True`` on the CPU.
    val_dataloader : DataLoader
        Dataloader built from a ``TimeSeriesDataSet``.
    horizon_step : int, default 0
        Forecast step to plot (0 = first predicted day). The MAE is always
        computed over all steps.

    Returns
    -------
    numpy.float64
        Mean absolute error over all samples and forecast steps.

    Raises
    ------
    ValueError
        If predictions and targets do not have the same shape.
    """
    predictions = tft.predict(
        val_dataloader, return_y=True, trainer_kwargs=dict(accelerator="cpu")
    )

    # `predictions.y` is a (target, weight) pair; only the target is needed.
    actuals = predictions.y[0].cpu().numpy()
    outputs = predictions.output.cpu().numpy()

    # Presumably (n_samples, prediction_length); force 2-D so one forecast step
    # can be selected even if a 1-D array comes back.
    actuals = actuals.reshape(len(actuals), -1)
    outputs = outputs.reshape(len(outputs), -1)
    if actuals.shape != outputs.shape:
        raise ValueError(
            f"Shape mismatch: actuals {actuals.shape} vs outputs {outputs.shape}"
        )

    abs_errors = np.abs(actuals - outputs)
    mae = np.mean(abs_errors)
    print(f"Mean Absolute Error: {mae:.4f}")
    # Error per forecast step shows how accuracy degrades with lead time.
    step_mae = abs_errors.mean(axis=0)
    print("MAE per forecast step: " + ", ".join(f"{v:.4f}" for v in step_mae))

    # Plot a single forecast step. Plotting the full 2-D arrays would draw one
    # line per step. Samples appear in dataloader order and may span several
    # gauges, so the x-axis is a sample index, not calendar time.
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(actuals[:, horizon_step], label="Actual", alpha=0.7)
    ax.plot(outputs[:, horizon_step], label="Forecast", alpha=0.7)
    ax.set(
        title=f"TFT Predictions vs Actuals (forecast step {horizon_step + 1})",
        xlabel="Validation sample",
        ylabel="Discharge (mm/d)",
    )
    ax.legend()
    ax.grid(True)
    plt.show()

    return mae


# Load and evaluate model.
# Prefer the best checkpoint from a training run in this session (empty string
# if `trainer.fit` was not called). Otherwise fall back to a saved checkpoint in
# the ModelCheckpoint directory ("checkpoints", relative to the notebook's
# working directory) instead of an absolute user path.
best_model_path = trainer.checkpoint_callback.best_model_path or Path(
    "checkpoints/tft-epoch=08-val_loss=0.29.ckpt"
)
# Map tensors to CPU so a GPU-trained checkpoint also loads on a CPU-only machine.
best_tft = TemporalFusionTransformer.load_from_checkpoint(
    best_model_path, map_location="cpu"
)
mae = evaluate_and_plot(best_tft, val_dataloader)