# 🧠 Temporal Fusion Transformer (TFT) - Beispiel

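Ein praktisches Beispiel für die Zeitreihen-Vorhersage mit PyTorch Forecasting und dem Temporal Fusion Transformer (TFT) anhand des Stallion-Beispieldatensatzes (monatliche Absatzmengen je Vertriebsagentur und Artikel/SKU).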

In [None]:
# 📦 Install required packages
# Uncomment the line below if the packages are not installed yet. `%pip` (not `!pip`)
# installs into the environment of the running kernel; pyarrow is needed to read the
# parquet-based example data.
# %pip install "pytorch-forecasting>=1.0" lightning pandas scikit-learn pyarrow

In [None]:
# 📁 1. Load and prepare data
import pandas as pd
from pytorch_forecasting.data.examples import get_stallion_data

# Stallion example data: monthly sales volumes per agency and SKU.
stallion_raw = get_stallion_data()

# Keep a single agency to keep the example small. `.copy()` makes `data` an
# independent frame, so adding columns below does not trigger pandas'
# SettingWithCopyWarning on a slice of `stallion_raw`.
AGENCY = "Agency_01"
data = stallion_raw.loc[stallion_raw["agency"] == AGENCY].copy()

# Consecutive integer month counter (year * 12 + month), shifted to start at 0.
# The raw data stores its timestamps in the "date" column; there is no "month" column.
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] -= data["time_idx"].min()

# Sort by series and time so every (agency, sku) series is contiguous and ordered.
data = data.sort_values(["agency", "sku", "time_idx"]).reset_index(drop=True)
data.head()

In [None]:
# 📁 2. Define the TFT dataset
from pytorch_forecasting.data import TimeSeriesDataSet
from torch.utils.data import DataLoader

# Window lengths in time_idx steps (months): 12 steps of history feed the encoder,
# the following 6 steps are forecast.
max_encoder_length = 12
max_prediction_length = 6

# Hold back the final `max_prediction_length` steps: training only sees rows up to here.
training_cutoff = data["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    data.loc[data["time_idx"] <= training_cutoff],
    # what is predicted, along which time axis, for which series
    time_idx="time_idx",
    target="volume",
    group_ids=["agency", "sku"],
    # feature roles
    static_categoricals=["agency", "sku"],
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=["volume"],
    # window sizes
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    # extra derived inputs
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

train_dataloader = training.to_dataloader(train=True, batch_size=64)

In [None]:
# 📁 3. Train the model
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.models import TemporalFusionTransformer

try:
    # pytorch-forecasting >= 1.0 is built on the unified `lightning` package; a Trainer
    # from the standalone `pytorch_lightning` package rejects its models in `fit`.
    from lightning.pytorch import Trainer, seed_everything
except ImportError:
    # Older pytorch-forecasting releases are built on `pytorch_lightning`.
    # NOTE(review): an old pytorch-forecasting installed next to the `lightning`
    # package would still take the first branch — pin versions in the install cell.
    from pytorch_lightning import Trainer, seed_everything

SEED = 42
seed_everything(SEED)  # reproducible weight initialisation and batch shuffling

# Validation windows: `reduce_on_plateau_patience` drives a learning-rate scheduler
# that watches "val_loss". Without a validation dataloader that metric is never
# produced and the scheduler cannot act.
# NOTE: these are the same windows that are predicted in step 4, so that plot is not
# a strictly held-out evaluation — use a separate period for a fair test.
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)
val_dataloader = validation.to_dataloader(train=False, batch_size=64)

# Define the model; sizes/dimensions are inferred from the training dataset.
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    loss=QuantileLoss(),  # must be a metric instance, not the metric's name as a string
    # The Trainer below runs with logger=False, so prediction figures triggered by a
    # positive log_interval would have nowhere to go; values <= 0 disable them.
    log_interval=-1,
    reduce_on_plateau_patience=4,
)

trainer = Trainer(
    max_epochs=20,
    gradient_clip_val=0.1,
    enable_checkpointing=True,
    logger=False,
    enable_model_summary=True,
)

trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

In [None]:
# 📁 4. Create forecasts
# `predict=True` keeps, per series, only the window ending at its last observation,
# i.e. the steps after `training_cutoff` (see TimeSeriesDataSet docs).
test = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)
test_dataloader = test.to_dataloader(train=False, batch_size=64)

prediction = tft.predict(test_dataloader, mode="raw", return_x=True)
if hasattr(prediction, "output"):
    # pytorch-forecasting >= 1.0 returns a `Prediction` named tuple (output, x, ...),
    # which has more than two fields and therefore cannot be unpacked into two names.
    raw_predictions, x = prediction.output, prediction.x
else:
    # Older releases return a plain (output, x) pair.
    raw_predictions, x = prediction

In [None]:
# 📈 5. Plot the forecast for the first series (idx=0)
# The trailing ";" suppresses the repr of the returned Figure, which would otherwise be
# rendered a second time next to the inline backend's automatic output.
tft.plot_prediction(x, raw_predictions, idx=0);