
# Demand forecasting with the Temporal Fusion Transformer (Stallion tutorial)

This notebook mirrors the PyTorch Forecasting tutorial that trains a **TemporalFusionTransformer (TFT)**
on the small **Stallion** beverage sales dataset. The goal is to forecast **6 months** of `volume`
per (agency, SKU) using time-varying and static features.

We will:
1. Load and enrich the dataset with a time index and engineered features.
2. Build `TimeSeriesDataSet` objects and dataloaders.
3. Establish a simple **Baseline**.
4. Configure and train a **TFT** with PyTorch Lightning.
5. (Optional) Explore **hyperparameter tuning** with Optuna.
6. Evaluate on validation data and visualize predictions/interpretability outputs.


In [None]:

import warnings
warnings.filterwarnings("ignore")  # keep logs clean

import copy
from pathlib import Path

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger

import numpy as np
import pandas as pd
import torch

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters



## Load data

We use the built-in **Stallion** dataset helper and add useful features:

- `time_idx` (monotonic index per time step),
- calendar month (categorical),
- log-volume,
- cross-sectional mean volume per time step, by `sku` and by `agency` (one mean over all series sharing that month and SKU/agency — not a rolling window),
- special-day indicator columns relabelled from 0/1 to `"-"`/the holiday's name, so that they can later be grouped into a single `special_days` categorical (reverse one-hot).


In [None]:

from pytorch_forecasting.data.examples import get_stallion_data

# Raw Stallion frame from the pytorch_forecasting example helper.
data = get_stallion_data()

# Integer month counter (year * 12 + month), shifted so the earliest month is 0.
month_counter = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] = month_counter - month_counter.min()

# Calendar month as a string-valued categorical (categories must be strings).
data["month"] = data["date"].dt.month.astype(str).astype("category")

# Log of volume; the tiny offset keeps zero volumes finite.
data["log_volume"] = np.log(data["volume"] + 1e-8)

# Mean volume over all rows sharing the same time step and SKU (then agency).
for level in ("sku", "agency"):
    grouped_volume = data.groupby(["time_idx", level], observed=True)["volume"]
    data[f"avg_volume_by_{level}"] = grouped_volume.transform("mean")

# One-hot special-day indicator columns.
special_days = [
    "easter_day",
    "good_friday",
    "new_year",
    "christmas",
    "labor_day",
    "independence_day",
    "revolution_day_memorial",
    "regional_games",
    "fifa_u_17_world_cup",
    "football_gold_cup",
    "beer_capital",
    "music_fest",
]


def _flag_to_label(flags):
    """Map a 0/1 indicator column to "-" (0) or the column's own name (1)."""
    return flags.map({0: "-", 1: flags.name})


# Reverse the one-hot encoding: each indicator becomes a named categorical.
data[special_days] = data[special_days].apply(_flag_to_label).astype("category")

# Reproducible peek at five rows (cell output).
data.sample(5, random_state=521)


In [None]:

# Quick summary statistics of the enriched frame; include="all" asks pandas to
# describe every column, not only the numeric ones.
data.describe(include="all")



## Create dataset and dataloaders

We build a `TimeSeriesDataSet` describing inputs/targets and metadata such as which variables are
static vs time-varying and known vs unknown. We also specify **group-wise normalization**.
Validation consists of the last 6 months for each series.


In [None]:

# Forecast horizon and encoder history, both counted in time_idx steps.
max_prediction_length = 6
max_encoder_length = 24

# The last max_prediction_length steps are held out of the training frame.
training_cutoff = data["time_idx"].max() - max_prediction_length
is_training_row = data["time_idx"] <= training_cutoff

# Real-valued inputs declared as known for future time steps.
known_reals = ["time_idx", "price_regular", "discount_in_percent"]

# Real-valued inputs declared as unknown for future time steps (includes the target).
unknown_reals = [
    "volume",
    "log_volume",
    "industry_volume",
    "soda_volume",
    "avg_max_temp",
    "avg_volume_by_agency",
    "avg_volume_by_sku",
]

# Target normalizer grouped by (agency, sku) with a softplus transformation.
volume_normalizer = GroupNormalizer(groups=["agency", "sku"], transformation="softplus")

training = TimeSeriesDataSet(
    data[is_training_row],
    time_idx="time_idx",
    target="volume",
    group_ids=["agency", "sku"],
    # window lengths: encoder between half and full max_encoder_length
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    # static features
    static_categoricals=["agency", "sku"],
    static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
    # time-varying features; the special-day columns are treated as one categorical group
    time_varying_known_categoricals=["special_days", "month"],
    variable_groups={"special_days": special_days},
    time_varying_known_reals=known_reals,
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=unknown_reals,
    # normalization and automatically added features
    target_normalizer=volume_normalizer,
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# Validation reuses the training dataset's configuration on the full frame and
# predicts the last max_prediction_length points of each series.
validation = TimeSeriesDataSet.from_dataset(
    training,
    data,
    predict=True,
    stop_randomization=True,
)

# Dataloaders; the validation loader uses a 10x larger batch.
batch_size = 128
val_batch_size = batch_size * 10
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=val_batch_size, num_workers=0)



## Baseline

A naive baseline repeats the last observed value across the horizon.


In [None]:

# Naive reference: run the Baseline model on the validation loader and keep the
# targets (return_y=True) so the predictions can be scored.
baseline_model = Baseline()
baseline_predictions = baseline_model.predict(val_dataloader, return_y=True)

# Baseline MAE on the validation set (cell output).
baseline_mae = MAE()
baseline_mae(baseline_predictions.output, baseline_predictions.y)



## Train the Temporal Fusion Transformer

We'll set seeds, construct a small TFT model, and (optionally) use Lightning's LR finder.


In [None]:

# Fix the random seed before anything is constructed.
pl.seed_everything(42)

# CPU trainer with gradient clipping, used only for the learning-rate search.
trainer_lr = pl.Trainer(accelerator="cpu", gradient_clip_val=0.1)

# Hyperparameters of the small TFT handed to the learning-rate finder.
lr_search_hparams = {
    "learning_rate": 0.03,  # may be revised from the finder's suggestion
    "hidden_size": 8,
    "attention_head_size": 1,
    "dropout": 0.1,
    "hidden_continuous_size": 8,
    "loss": QuantileLoss(),
    "optimizer": "ranger",
}
tft_lr = TemporalFusionTransformer.from_dataset(training, **lr_search_hparams)

# Report the model size in thousands of parameters.
lr_search_param_count_k = tft_lr.size() / 1e3
print(f"Number of parameters in network (LR search setup): {lr_search_param_count_k:.1f}k")


In [None]:

# Optional learning-rate search; comment this cell out if it is not wanted.
from lightning.pytorch.tuner import Tuner

# Search learning rates between min_lr and max_lr with the LR-search model.
lr_tuner = Tuner(trainer_lr)
res = lr_tuner.lr_find(
    tft_lr,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    min_lr=1e-6,
    max_lr=10.0,
)

# Print the finder's suggestion and plot the search with the suggestion marked.
suggested_lr = res.suggestion()
print(f"suggested learning rate: {suggested_lr}")
_ = res.plot(show=True, suggest=True)



### Full training

We now configure callbacks and train a slightly larger TFT.
Use TensorBoard to monitor training: `tensorboard --logdir lightning_logs`.


In [None]:

# Callbacks: learning-rate logging, and early stopping on "val_loss"
# (mode "min", min_delta 1e-4, patience 10).
lr_logger = LearningRateMonitor()
early_stop = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, mode="min")

# TensorBoard logger pointed at "lightning_logs".
logger = TensorBoardLogger("lightning_logs")

trainer = pl.Trainer(
    accelerator="cpu",
    max_epochs=50,
    limit_train_batches=50,  # demo-friendly; increase for full training
    gradient_clip_val=0.1,
    enable_model_summary=True,
    logger=logger,
    callbacks=[lr_logger, early_stop],
)

# Hyperparameters of the TFT that is actually trained.
tft_hparams = {
    "learning_rate": 0.03,
    "hidden_size": 16,
    "attention_head_size": 2,
    "dropout": 0.1,
    "hidden_continuous_size": 8,
    "loss": QuantileLoss(),
    "log_interval": 10,
    "optimizer": "ranger",
    "reduce_on_plateau_patience": 4,
}
tft = TemporalFusionTransformer.from_dataset(training, **tft_hparams)

# Report the model size in thousands of parameters.
param_count_k = tft.size() / 1e3
print(f"Number of parameters in network: {param_count_k:.1f}k")

# Train with the training loader, validating on the validation loader.
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)



## Hyperparameter tuning (Optuna) — optional

`optimize_hyperparameters` can search across ranges for key TFT hyperparameters.
This can be time-consuming; reduce `n_trials` or `max_epochs` for quick runs.


In [None]:

# WARNING: This can take a long time. Adjust n_trials and max_epochs for your setup.
# Intentionally left commented out. Uncomment the lines below to run the search,
# pickle the resulting study to "optuna_tft_study.pkl", and print the best
# trial's parameters.
# import pickle
# study = optimize_hyperparameters(
#     train_dataloader,
#     val_dataloader,
#     model_path="optuna_tft_study",
#     n_trials=20,
#     max_epochs=50,
#     gradient_clip_val_range=(0.01, 1.0),
#     hidden_size_range=(8, 128),
#     hidden_continuous_size_range=(8, 128),
#     attention_head_size_range=(1, 4),
#     learning_rate_range=(0.001, 0.1),
#     dropout_range=(0.1, 0.3),
#     trainer_kwargs=dict(limit_train_batches=30),
#     reduce_on_plateau_patience=4,
#     use_learning_rate_finder=False,
# )
# with open("optuna_tft_study.pkl", "wb") as f:
#     pickle.dump(study, f)
# print(study.best_trial.params)



## Evaluate performance & visualize

We reload the best checkpoint, compute MAE on validation data, and visualize predictions.


In [None]:

# Restore the model from the best checkpoint recorded by the trainer's
# checkpoint callback.
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

# Predict on the validation loader on CPU, keeping the targets for scoring.
cpu_trainer_kwargs = {"accelerator": "cpu"}
predictions = best_tft.predict(
    val_dataloader,
    return_y=True,
    trainer_kwargs=cpu_trainer_kwargs,
)

# Validation MAE (cell output).
validation_mae = MAE()
validation_mae(predictions.output, predictions.y)


In [None]:

# Raw-mode predictions on the validation loader, returned together with the
# model inputs (return_x=True) that plot_prediction needs.
raw = best_tft.predict(
    val_dataloader,
    mode="raw",
    return_x=True,
    trainer_kwargs={"accelerator": "cpu"},
)

# Plot the first few validation samples, with the loss in each title.
n_examples = 3  # fewer plots for convenience
for example_idx in range(n_examples):
    best_tft.plot_prediction(raw.x, raw.output, idx=example_idx, add_loss_to_title=True)


In [None]:

# Rank validation samples by SMAPE and plot the worst ones.
# NOTE: plotting reuses `raw` from the previous cell.
predictions = best_tft.predict(
    val_dataloader,
    return_y=True,
    trainer_kwargs={"accelerator": "cpu"},
)

# Unreduced SMAPE against the first element of predictions.y, averaged over
# dimension 1 (presumably the prediction horizon — confirm).
targets = predictions.y[0]
unreduced_smape = SMAPE(reduction="none").loss(predictions.output, targets)
mean_losses = unreduced_smape.mean(1)

# Largest mean loss first.
indices = mean_losses.argsort(descending=True)

for rank in range(3):
    # A fresh SMAPE metric per plot, configured with the model loss's quantiles.
    title_metric = SMAPE(quantiles=best_tft.loss.quantiles)
    best_tft.plot_prediction(
        raw.x,
        raw.output,
        idx=indices[rank],
        add_loss_to_title=title_metric,
    )


In [None]:

# Compare predictions with actuals across the binned values of each variable.
cpu_only = {"accelerator": "cpu"}
pred_x = best_tft.predict(val_dataloader, return_x=True, trainer_kwargs=cpu_only)

# Aggregate from the model inputs and outputs, then plot the result.
model_inputs, model_outputs = pred_x.x, pred_x.output
pv = best_tft.calculate_prediction_actual_by_variable(model_inputs, model_outputs)
# The trailing semicolon suppresses the notebook's echo of the return value.
best_tft.plot_prediction_actual_by_variable(pv);
