In [None]:
import warnings
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import torch

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

import seaborn as sns
import matplotlib.pyplot as plt

import tensorflow as tf 
import tensorboard as tb 
# Replace TensorFlow's `io.gfile` with the stub implementation shipped inside
# TensorBoard.
# NOTE(review): presumably a workaround for TensorBoard logging failing when
# the full TensorFlow package is installed next to it — confirm it is still
# required for the installed versions.
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

### Data Preprocessing

In [None]:
# Read the three raw CSV tables; paths are relative to the working directory.
sales, calendar, prices = (
    pd.read_csv(csv_name)
    for csv_name in ("sales_train_validation.csv", "calendar.csv", "sell_prices.csv")
)

In [None]:
# Choose the item subset *before* reshaping: melting and merging the complete
# table only to discard most rows afterwards wastes time and memory.
# melt() keeps the row order of `sales`, so this picks exactly the same items
# as taking the first 10 unique ids from the long frame.
selected_items = sales["item_id"].unique()[:10]  # first 10 items
sales_subset = sales[sales["item_id"].isin(selected_items)]

# Wide -> long: every non-id column becomes one row, keyed by "d", with its
# value stored in "demand".
sales_long = sales_subset.melt(
    id_vars=["id", "item_id", "dept_id", "cat_id", "store_id", "state_id"],
    var_name="d",
    value_name="demand",
)

# Attach the calendar attributes through the shared "d" key.
sales_long = sales_long.merge(
    calendar[["d", "date", "wm_yr_wk", "event_name_1", "event_type_1", "snap_CA", "snap_TX", "snap_WI"]],
    on="d",
    how="left",
)

# Attach prices per store / item / week. how="left" keeps rows that have no
# matching price row; their price columns are NaN after this merge.
sales_long = sales_long.merge(prices, on=["store_id", "item_id", "wm_yr_wk"], how="left")

In [None]:
# Calendar features derived from the merged "date" column.
sales_long["date"] = pd.to_datetime(sales_long["date"])
# Integer step index: days elapsed since the earliest date in the frame.
sales_long["time_idx"] = (sales_long["date"] - sales_long["date"].min()).dt.days
sales_long["month"] = sales_long["date"].dt.month
sales_long["day_of_week"] = sales_long["date"].dt.dayofweek
sales_long["day"] = sales_long["date"].dt.day

# Missing-value handling. These columns are fed to TimeSeriesDataSet as
# categoricals / reals, which does not accept NaN in real-valued inputs and
# expects categoricals to be strings.
# The event columns are presumably empty on days without an event, so missing
# entries get an explicit "none" category.
event_cols = ["event_name_1", "event_type_1"]
sales_long[event_cols] = sales_long[event_cols].fillna("none").astype(str)

# sell_price is NaN wherever the left merge found no price row. Fill inside
# each (item, store) series in chronological order: carry the last known price
# forward, then back-fill any leading gap with the first known price.
# The result is aligned back to sales_long by index, so row order is unchanged.
sales_long["sell_price"] = (
    sales_long.sort_values("time_idx")
    .groupby(["item_id", "store_id"])["sell_price"]
    .transform(lambda price: price.ffill().bfill())
)

# A series without any price at all cannot be filled; fail here with a clear
# message instead of deep inside the dataset construction.
assert not sales_long[event_cols + ["sell_price"]].isna().any().any(), (
    "unfilled missing values remain in event columns or sell_price"
)

### Time series dataset

In [None]:
# Window sizes are expressed in time_idx steps (days).
max_prediction_length = 28                      # forecast horizon: 28 days
max_encoder_length = 4 * max_prediction_length  # 112 days (16 weeks) of history

# Hold out the last `max_prediction_length` steps; everything up to and
# including the cutoff is used for training.
training_cutoff = sales_long["time_idx"].max() - max_prediction_length
train_frame = sales_long[sales_long["time_idx"] <= training_cutoff]

# One time series per (item, store) pair.
series_keys = ["item_id", "store_id"]

training = TimeSeriesDataSet(
    train_frame,
    time_idx="time_idx",
    target="demand",
    group_ids=series_keys,
    # Encoder windows may be as short as half the maximum length.
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["item_id", "dept_id", "cat_id", "store_id", "state_id"],
    time_varying_known_categoricals=["event_name_1", "event_type_1"],
    time_varying_known_reals=["sell_price", "day", "day_of_week", "month"],
    time_varying_unknown_reals=["demand"],
    # Normalize the target separately for every series.
    target_normalizer=GroupNormalizer(groups=list(series_keys), transformation="softplus"),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

In [None]:
# Build the validation set from the full frame, reusing the configuration of
# `training`; predict=True with randomization switched off.
validation = TimeSeriesDataSet.from_dataset(
    training,
    sales_long,
    predict=True,
    stop_randomization=True,
)

In [None]:
# Both loaders share the batch size and worker count; only the train flag differs.
batch_size = 128
loader_kwargs = dict(batch_size=batch_size, num_workers=4)

train_dataloader = training.to_dataloader(train=True, **loader_kwargs)
val_dataloader = validation.to_dataloader(train=False, **loader_kwargs)


### TFT model

In [None]:
# Seed the Python, NumPy and torch RNGs before the model is created and
# trained, so weight initialisation and batch shuffling are repeatable across
# Restart & Run All.
pl.seed_everything(42)

# Stop training once val_loss has not improved for 5 validation checks.
early_stop_callback = EarlyStopping(monitor="val_loss", patience=5, mode="min")
# Record the learning rate once per epoch.
lr_logger = LearningRateMonitor(logging_interval="epoch")
# TensorBoard event files are written under lightning_logs/m5_tft.
logger = TensorBoardLogger("lightning_logs", name="m5_tft")

trainer = pl.Trainer(
    max_epochs=30,
    # Use a single GPU when CUDA is available, otherwise the CPU.
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    gradient_clip_val=0.1,
    callbacks=[early_stop_callback, lr_logger],
    logger=logger,
    log_every_n_steps=50,
)


In [None]:
# Hyperparameters collected in one place so they are easy to tune.
tft_hparams = dict(
    learning_rate=0.001,
    hidden_size=64,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=64,
    loss=QuantileLoss(),
    output_size=7,  # one output per quantile; must match the loss (7 by default)
    log_interval=10,
    reduce_on_plateau_patience=4,
)

# The network's input/output layout is derived from the training dataset.
tft = TemporalFusionTransformer.from_dataset(training, **tft_hparams)

In [None]:
# Train the TFT on the training loader, validating on the held-out loader.
trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

In [None]:
# Locate the checkpoint that the trainer's checkpoint callback ranked best.
best_model_path = trainer.checkpoint_callback.best_model_path
print("Best model saved to:", best_model_path)

# Reload that checkpoint and produce raw predictions, together with the
# network inputs, for the validation loader.
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
raw_predictions, x = best_tft.predict(
    val_dataloader,
    mode="raw",
    return_x=True,
)

# Plot the prediction for the sample at index 0.
best_tft.plot_prediction(x, raw_predictions, idx=0)