In [1]:
import pandas as pd, torch

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, QuantileLoss
from pytorch_forecasting.data import GroupNormalizer
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from sklearn.metrics import mean_absolute_error, mean_squared_error

/Users/maxi/Documents/GitHub/OpenMeter_Analysis/Venv_OpenMeter/lib/python3.11/site-packages/lightning/fabric/__init__.py:40: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.


In [2]:
# Locations of the raw CSV exports for the two splits.
DATA_PATH = "/Users/maxi/Desktop/train_data.csv"
TEST_PATH = "/Users/maxi/Desktop/test_data.csv"

# Read only the leading rows of each file (50k for train, 10k for test) and
# parse the "Zeitstempel" column as datetimes while loading.
train, test = (
    pd.read_csv(csv_path, nrows=row_limit, parse_dates=["Zeitstempel"])
    for csv_path, row_limit in ((DATA_PATH, 50_000), (TEST_PATH, 10_000))
)

  train = pd.read_csv(DATA_PATH, nrows=50_000, parse_dates=["Zeitstempel"])


In [3]:
# Stack both splits so the time index and the calendar features are derived
# from one common frame; the temporary `split` column remembers the origin.
df = pd.concat([train.assign(split="train"), test.assign(split="test")])
df["meter_id"] = df["location_id"]

# Integer time index: whole hours elapsed since the earliest timestamp found
# in either split.
origin = df["Zeitstempel"].min()
df["time_idx"] = ((df["Zeitstempel"] - origin).dt.total_seconds() // 3600).astype(int)

# Calendar features are stored as strings so they can be used as categoricals.
for col in ("month", "weekday", "hour"):
    df[col] = getattr(df["Zeitstempel"].dt, col).astype(str)

# A row counts as holiday when a holiday type is present. The "None"
# placeholder written below is excluded explicitly: this cell rebinds
# `train`/`test`, so on a re-run the column is already filled and a plain
# `notna()` would flag every row as a holiday.
holiday_type = df["Ferientyp"]
df["is_holiday"] = (holiday_type.notna() & holiday_type.ne("None")).astype(str)
df["Ferientyp"]  = holiday_type.fillna("None").astype(str)
df["post_code"]  = df["post_code"].astype(str)
df["city"]       = df["city"].astype(str)

# Split back into the two frames and drop the helper column.
train = df[df.split == "train"].drop(columns="split")
test  = df[df.split == "test"].drop(columns="split")

In [4]:
# Encoder (history) length and prediction horizon, in time_idx steps.
ENC_LEN, PRED_LEN = 72, 24

# Hold out the last PRED_LEN steps of every meter for validation, so that
# val_loss (and early stopping on it) is not measured on data the model is
# fitted on.
last_step = train.groupby("meter_id")["time_idx"].transform("max")
train_fit = train[train["time_idx"] <= last_step - PRED_LEN]

training = TimeSeriesDataSet(
    train_fit,
    time_idx="time_idx",
    target="Messwert",
    group_ids=["meter_id"],
    static_categoricals=["city", "post_code"],
    time_varying_known_categoricals=["month", "weekday", "hour", "is_holiday", "Ferientyp"],
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=["Messwert"],
    max_encoder_length=ENC_LEN,
    max_prediction_length=PRED_LEN,
    target_normalizer=GroupNormalizer(groups=["meter_id"]),
    allow_missing_timesteps=True,
)

# Validation set: same encoders/normalizer as `training`, built on the full
# training frame; predict=True targets the final window of each series,
# i.e. the steps that were held out above.
# NOTE(review): a category that occurs only inside the held-out tail would be
# unknown to the fitted encoders - confirm this cannot happen for this data.
validation = TimeSeriesDataSet.from_dataset(
    training, train, predict=True, stop_randomization=True
)

batch = 64
train_dl = training.to_dataloader(train=True, batch_size=batch, num_workers=2, persistent_workers=True)
val_dl   = validation.to_dataloader(train=False, batch_size=batch, num_workers=2, persistent_workers=True)

In [None]:
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor

# Seed the random number generators before the model is created.
seed_everything(42)

# Temporal Fusion Transformer configured from the training dataset and
# trained with a quantile loss.
tft = TemporalFusionTransformer.from_dataset(
    training,
    loss=QuantileLoss(),
    learning_rate=0.001,
    dropout=0.2,
    hidden_size=16,
    hidden_continuous_size=8,
    attention_head_size=2,
)

# Early stopping watches "val_loss" (the monitored metric, lower is better)
# with a patience of 3.
early_stop = EarlyStopping(monitor="val_loss", mode="min", patience=3, verbose=True)

trainer = Trainer(
    accelerator="auto",
    devices=1,
    precision=32,
    max_epochs=30,
    gradient_clip_val=0.1,
    callbacks=[early_stop, LearningRateMonitor("epoch")],
)

trainer.fit(tft, train_dataloaders=train_dl, val_dataloaders=val_dl)

Seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/maxi/Documents/GitHub/OpenMeter_Analysis/Venv_OpenMeter/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

   | Name                               | Type                            | Params
----------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0     
1 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/maxi/Documents/GitHub/OpenMeter_Analysis/Venv_OpenMeter/lib/python3.11/site-packages/lightning/fabric/__init__.py:40: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
/Users/maxi/Documents/GitHub/OpenMeter_Analysis/Venv_OpenMeter/lib/python3.11/site-packages/lightning/fabric/__init__.py:40: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
/Users/maxi/Documents/GitHub/OpenMeter_Analysis/Venv_OpenMeter/lib/python3.11/site-packages/lightning/fabric/__init__.py:40: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11

Training: |          | 0/? [00:00<?, ?it/s]

In [None]:
# Build the test dataset with the encoders/normalizer already fitted on
# `training`, and iterate it in a fixed (non-shuffled) order.
test_ds = TimeSeriesDataSet.from_dataset(training, test, stop_randomization=True)
test_dl = test_ds.to_dataloader(train=False, batch_size=batch)

# Point forecasts, one row per test window.
pred = tft.predict(test_dl, mode="prediction")

# Take the actuals from the very same dataloader so that they line up with the
# predictions window by window. The index returned by `predict` describes
# windows, not row positions in `test`, so it cannot be fed to `test.iloc`.
y_true = torch.cat([y[0] for _, y in iter(test_dl)])

# Move everything to CPU (predictions may live on the accelerator) and flatten,
# so both metrics are computed over all forecasted points.
pred_flat = torch.as_tensor(pred).detach().cpu().numpy().ravel()
y_true = y_true.detach().cpu().numpy().ravel()

print("MAE :", mean_absolute_error(y_true, pred_flat))
# Square root of the MSE instead of `squared=False`, which newer scikit-learn
# releases no longer accept.
print("RMSE:", mean_squared_error(y_true, pred_flat) ** 0.5)

In [None]:
# Save next to the notebook: the former "/content/..." target is a Colab
# directory and does not exist on the local machine this notebook runs on.
CKPT_PATH = "tft_stromverbrauch.ckpt"
trainer.save_checkpoint(CKPT_PATH)
print("Checkpoint gespeichert.")