# Forecasting with the Temporal Fusion Transformer (an attempt)

### Required resources: training on four parquet partitions requires 128 GB of RAM.

- [Demand forecasting with the Temporal Fusion Transformer](https://pytorch-forecasting.readthedocs.io/en/stable/tutorials/stallion.html)
- [Tommaso Guerrini - temporal-fusion-transformer-in-pytorch v12](https://www.kaggle.com/code/tomwarrens/temporal-fusion-transformer-in-pytorch?scriptVersionId=106300693)

In [None]:
!pip install pytorch-forecasting
!pip install pytorch-forecasting[mqf2]
!pip install pytorch_optimizer

Collecting pytorch-forecasting
  Downloading pytorch_forecasting-1.2.0-py3-none-any.whl.metadata (13 kB)
Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)
  Downloading lightning-2.5.0.post0-py3-none-any.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.4/40.4 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Downloading pytorch_forecasting-1.2.0-py3-none-any.whl (181 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m181.9/181.9 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning-2.5.0.post0-py3-none-any.whl (815 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m815.2/815.2 kB[0m [31m33.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning, pytorch-forecasting
Successfully installed lightning-2.5.0.post0 pytorch-forecasting-1.2.0
Collecting cpflows (from pytorch-forecasting[mqf2])
  Downloading cpflows-0.1.2.tar.gz (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import os
import sys
import gc
import copy
import random
from pathlib import Path
from tqdm import tqdm
import numpy as np
import polars as pl
import pandas as pd

import torch
import lightning.pytorch as pln
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, RichProgressBar, TQDMProgressBar
from lightning.pytorch.loggers import TensorBoardLogger

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

sys.path.append("/kaggle/input/jane-street-real-time-market-data-forecasting")
import kaggle_evaluation.jane_street_inference_server

import warnings
# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides pandas/torch deprecation and FutureWarnings — consider narrowing the filter.
warnings.filterwarnings("ignore")  # avoid printing out absolute paths

# Show all columns when displaying wide DataFrames
pd.options.display.max_columns = None
#pd.options.display.max_rows = None

# Environment diagnostics: OS release name and kernel version
!cat /etc/os-release | grep -oP "PRETTY_NAME=\"\K([^\"]*)" && uname -r
# Container metadata and CUDA version from environment variables (None when unset)
print(f"CONTAINER_NAME={os.environ.get('CONTAINER_NAME',None)}, BUILD_DATE={os.environ.get('BUILD_DATE',None)}, CUDA={os.environ.get('CUDA_VERSION', None)}")
# Host memory (the training cells below are memory hungry)
!free -h
# NVIDIA driver version reported by nvidia-smi
!nv_version="$(nvidia-smi --query-gpu=driver_version --format=csv,noheader)" && echo "My NVIDIA driver version is '${nv_version}'."
# CUDA toolkit installs under /usr/local
!ls -l /usr/local | grep cuda

def set_seed(seed=42):
    '''Seed Python, NumPy, PyTorch (CPU and CUDA) and Lightning with ``seed`` and
    force deterministic, non-autotuned cuDNN kernels so reruns are reproducible.'''
    # PYTHONHASHSEED is read at interpreter start-up, so this only affects child processes
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    # cuDNN: pick deterministic algorithms and disable benchmark-based autotuning
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    pln.seed_everything(seed)

def reduce_mem_usage(df, float16_as32=True):
    """Downcast numeric columns of ``df`` to smaller dtypes that still hold their range.

    Integer columns (signed or unsigned) are cast to the smallest of
    int8/int16/int32/int64 whose bounds contain the column's min and max
    (bounds inclusive). Float columns whose range fits float16 are cast to
    float32 when ``float16_as32`` is True, otherwise to float16; wider ranges
    go to float32 if they fit, else float64. Columns that are not plain numpy
    int/float dtypes (bool, object, category, datetime, pandas extension dtypes)
    and columns with no non-null values are left untouched.

    Args:
        df: pandas DataFrame. Columns are reassigned in place.
        float16_as32: store float16-range columns as float32 (True) or float16 (False).

    Returns:
        The same DataFrame object after downcasting.
    """
    start_mem = df.memory_usage().sum() / 1024**2
    print('Memory usage of dataframe is {:.2f} MB'.format(start_mem))
    for col in df.columns:
        col_type = df[col].dtype
        # Only plain numpy dtypes; bool is neither np.integer nor np.floating, so it is skipped too.
        if not isinstance(col_type, np.dtype):
            continue
        is_int = np.issubdtype(col_type, np.integer)
        is_float = np.issubdtype(col_type, np.floating)
        if not (is_int or is_float):
            continue
        series = df[col]
        # Empty / all-null columns have NaN min/max; there is no range to downcast to.
        if series.isna().all():
            continue
        if is_int:
            # Python ints make the bound checks exact for uint64 as well.
            c_min, c_max = int(series.min()), int(series.max())
            for int_type in (np.int8, np.int16, np.int32, np.int64):
                info = np.iinfo(int_type)
                if info.min <= c_min and c_max <= info.max:
                    df[col] = series.astype(int_type)
                    break
            # uint64 values above int64 max: leave the column unchanged
        else:
            c_min, c_max = series.min(), series.max()
            if c_min > np.finfo(np.float16).min and c_max < np.finfo(np.float16).max:
                df[col] = series.astype(np.float32 if float16_as32 else np.float16)
            elif c_min > np.finfo(np.float32).min and c_max < np.finfo(np.float32).max:
                df[col] = series.astype(np.float32)
            else:
                df[col] = series.astype(np.float64)
    end_mem = df.memory_usage().sum() / 1024**2
    print('Memory usage after optimization is: {:.2f} MB'.format(end_mem))
    reduction = 100 * (start_mem - end_mem) / start_mem if start_mem else 0.0
    print('Decreased by {:.1f}%'.format(reduction))
    return df

set_seed(seed=2025)  # seed every RNG used in this notebook before any data/model work

Ubuntu 22.04.3 LTS
6.6.56+
CONTAINER_NAME=None, BUILD_DATE=20241217-203356, CUDA=12.2.2
               total        used        free      shared  buff/cache   available
Mem:            31Gi       1.1Gi        23Gi       1.0Mi       7.1Gi        29Gi
Swap:             0B          0B          0B
My NVIDIA driver version is '560.35.03'.
lrwxrwxrwx 1 root root   22 Nov 10  2023 cuda -> /etc/alternatives/cuda
lrwxrwxrwx 1 root root   25 Nov 10  2023 cuda-12 -> /etc/alternatives/cuda-12
drwxr-xr-x 1 root root 4096 Nov 10  2023 cuda-12.2


INFO: Seed set to 2025


## Load data

In [3]:
DATA_DIR = Path("/kaggle/input/jane-street-real-time-market-data-forecasting")
TRAIN_PARTITIONS = range(6, 10)  # the last four training partitions

# Read each partition, widen the id columns to Int64 so all parts share one schema,
# then concatenate once.
frames = []
for i in TRAIN_PARTITIONS:
    part = pl.read_parquet(DATA_DIR / "train.parquet" / f"partition_id={i}" / "part-0.parquet")
    print(f"Block: {i}, DateID: {part['date_id'].min():04d} - {part['date_id'].max():04d}, TimeID: {part['time_id'].min():03d} - {part['time_id'].max():03d}")
    frames.append(part.with_columns(
        pl.col('date_id').cast(pl.Int64),
        pl.col('time_id').cast(pl.Int64),
    ))
df = pl.concat(frames)

del frames, part
_ = gc.collect()

test_path = f"{DATA_DIR}/test.parquet"
test_df = pl.read_parquet(f"{test_path}/date_id=0")
print(f"Test Data:{test_df.shape}")
display(test_df.head(3))

# Number every distinct (date_id, time_id) pair, in order of first appearance, with a
# consecutive integer time index. Selecting only the two key columns avoids aggregating
# every other column into per-group lists (as group_by(...).all() would).
df_time_idx = (
    df.select('date_id', 'time_id')
    .unique(maintain_order=True)
    .with_columns(pl.int_range(pl.len(), dtype=pl.UInt32).alias("time_idx"))
)

df = df.join(df_time_idx, on=["date_id", "time_id"], how="left")
display(df.select(pl.col(["date_id", "time_id", 'time_idx'])))

Block: 9, DateID: 1530 - 1698, TimeID: 000 - 967
Test Data:(39, 85)


row_id,date_id,time_id,symbol_id,weight,is_scored,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,…,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78
i64,i16,i16,i8,f32,bool,f32,f32,f32,f32,f32,f32,f32,f32,f32,f64,f64,f64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,0,0,3.169998,True,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,,-0.0,,-0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,…,,-0.0,,-0.0,0.0,-0.0,0.0,0.0,,0.0,,,-0.0,,-0.0,0.0,,0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,0.0,0.0,,,0.0,0.0,-0.0,-0.0
1,0,0,1,2.165993,True,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,-0.0,,-0.0,,-0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,-0.0,-0.0,…,,-0.0,,-0.0,0.0,0.0,0.0,0.0,,0.0,,,-0.0,,-0.0,0.0,,0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,0.0,-0.0,,,0.0,0.0,0.0,0.0
2,0,0,2,3.06555,True,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,,-0.0,,-0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,…,,-0.0,,-0.0,0.0,-0.0,-0.0,-0.0,,0.0,,,-0.0,,-0.0,0.0,,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0,,,0.0,0.0,-0.0,-0.0


date_id,time_id,time_idx
i64,i64,u32
1530,0,0
1530,0,0
1530,0,0
1530,0,0
1530,0,0
…,…,…
1698,967,163591
1698,967,163591
1698,967,163591
1698,967,163591


In [None]:
# Optional EDA: pairwise correlation of the nine responders. Set to True to draw it.
PLOT_RESPONDER_CORRELATION = False

if PLOT_RESPONDER_CORRELATION:
    import seaborn as sns
    import matplotlib
    import matplotlib.pyplot as plt

    responder_cols = [f"responder_{i}" for i in range(9)]
    # polars' corr() returns a frame without row labels; label the rows with the
    # column names so the heatmap's y-axis shows responder names instead of 0..8
    corr_matrix = df.select(responder_cols).corr().to_pandas()
    corr_matrix.index = corr_matrix.columns
    colors = sns.color_palette('coolwarm', 16)
    levels = np.linspace(-1, 1, 16)
    cmap_plot, norm = matplotlib.colors.from_levels_and_colors(levels, colors, extend="max")

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))

    # hide the upper triangle (including the diagonal) and near-zero correlations
    upper_triangle = np.triu(np.ones_like(corr_matrix, dtype=bool))
    sns.heatmap(
        corr_matrix,
        mask=upper_triangle | (np.abs(corr_matrix.to_numpy()) < 0.01),
        annot=True, ax=ax, cbar=False,
        cmap=cmap_plot,
        norm=norm, annot_kws={"size": 13, "color": 'black'},
    )

    ax.hlines(range(corr_matrix.shape[1]), *ax.get_xlim(), color='black')
    ax.vlines(range(corr_matrix.shape[1]), *ax.get_ylim(), color='black')

    ax.set_title('Correlation Matrix between each time series: absolute values under 0.01 are masked',
                 fontsize=20, color='black', fontweight='bold')
    plt.show()

In [5]:
# Drop columns that are entirely null (a single null_count pass over the frame)
null_counts = df.null_count().row(0)
df = df.drop([col for col, n_null in zip(df.columns, null_counts) if n_null == df.height])

# Drop columns holding a single distinct value; they carry no information
unique_counts = df.select(pl.all().n_unique()).row(0)
bad_cols = [col for col, n_unique in zip(df.columns, unique_counts) if n_unique == 1]
df = df.drop(bad_cols)

feature_cols = [col for col in df.columns if 'feature_' in col]
final_feature = ['time_idx', 'date_id', 'symbol_id', 'weight'] + feature_cols
# Keep only the modelling columns before converting, to limit pandas memory use
df = df.select(final_feature + ['responder_6']).to_pandas()

# Fill gaps with the last value seen for the SAME symbol. Rows of different symbols
# are interleaved at each time step, so a frame-wide ffill would copy a neighbouring
# symbol's value. Per-symbol ffill needs the rows in time order.
if not df['time_idx'].is_monotonic_increasing:
    df = df.sort_values(['time_idx', 'symbol_id'], ignore_index=True)
key_cols = ['time_idx', 'date_id', 'symbol_id']
value_cols = ['weight'] + feature_cols + ['responder_6']
# Values still missing at the start of a symbol's history become 0
filled = df.groupby('symbol_id', sort=False)[value_cols].ffill().fillna(0)
df = pd.concat([df[key_cols], filled], axis=1).astype(
    {'time_idx': np.int32, 'date_id': np.int16, 'symbol_id': np.int8}
)
del filled
print(f"df.shape:{df.shape}")
display(df.head(3))


df.shape:(6274576, 84)


Unnamed: 0,time_idx,date_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_6
0,0,1530,0,3.084694,1.153571,1.563784,0.697396,0.756759,2.580965,0.171311,1.126353,0.536153,0.05715,11,7,76,-0.656288,2.110188,0.145784,0.0,-0.203291,0.0,-1.238222,-2.294707,-0.06356,-0.148218,1.721362,0.64558,1.477857,0.528492,1.153077,0.466157,0.145568,-0.546845,-0.694435,-0.163897,0.0,0.0,0.502917,0.910145,-0.507707,0.218792,0.412922,0.0,0.081268,0.0,0.0,-2.023247,0.0,-1.967165,0.262769,-0.426009,-3.682122,-1.549827,0.0,0.680807,0.0,0.0,-2.786826,0.0,-1.2279,0.044606,0.0,-2.540213,-2.19028,0.385893,-0.460265,-0.415684,-0.45772,-1.333965,-2.23413,-0.352034,3.125156,0.493488,-0.9591,1.284456,-0.275493,0.0,0.0,4.188457,3.666236,0.848177,0.999516,3.071231
1,0,1530,1,2.232906,0.553354,1.730064,0.990195,0.61149,2.023031,0.319015,1.183371,0.562853,0.057789,11,7,76,-1.063518,1.037634,-0.255358,0.0,-0.318528,0.0,-1.46613,-2.160217,0.009386,0.042186,0.319811,0.14307,1.866907,1.238242,-1.986826,-0.476918,0.408439,-0.689795,-0.619278,0.081413,0.0,0.0,1.130648,0.726115,2.071485,0.179241,0.045131,0.0,0.002134,0.0,0.0,-0.828163,0.0,-1.304763,0.870251,-0.09534,-0.888243,-0.159577,0.0,-0.00268,0.0,0.0,-1.736226,0.0,-2.354893,1.309985,0.0,-2.429267,-1.26697,0.385893,-0.24877,-0.286104,-0.455154,-1.797363,-2.535985,-0.734866,1.533782,0.033801,-0.960126,0.306505,-0.522036,0.0,0.0,1.138142,1.579439,0.179564,0.160609,1.979042
2,0,1530,2,2.404948,1.532503,2.095852,0.919688,0.583715,2.330047,0.337096,1.262236,0.49605,0.073556,81,2,59,-1.001967,1.10577,-0.304426,0.0,-0.531873,0.0,-1.301579,-1.615271,0.454406,-0.188808,0.01512,-0.159487,1.379064,0.604568,0.736194,0.522007,-0.183058,-0.632819,-0.839542,-0.20955,0.0,0.0,0.211059,0.788082,-0.57527,0.157013,0.178823,0.0,0.486033,0.0,0.0,-1.121402,0.0,-1.019831,0.741859,-1.735237,-0.707955,-0.510588,0.0,0.793936,0.0,0.0,-1.191118,0.0,-2.190607,1.381697,0.0,-1.829545,-0.867858,0.385893,-0.295958,-0.386221,-0.345102,-1.598371,-2.111468,-0.780465,0.848857,-0.152994,-1.219395,0.359229,-0.636138,0.0,0.0,0.445388,0.300118,-0.043114,-0.065761,-0.50626


In [None]:
# Forecast horizon and look-back window, both measured in time_idx steps
max_prediction_length = 100
max_encoder_length = max_prediction_length
# Hold out the final max_prediction_length steps for validation
training_cutoff = int(df['time_idx'].max() - max_prediction_length)
print(training_cutoff, max_prediction_length)

training = TimeSeriesDataSet(
    df[lambda x: x['time_idx'] <= training_cutoff],
    time_idx='time_idx',
    target='responder_6',
    group_ids=['symbol_id'],  # one series per symbol
    # weight='weight',  # pass the column *name* to weight samples
    min_encoder_length=max_encoder_length // 4,  # allow encoders down to a quarter of the max length
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=[],  # use a string type / categorified string
    static_reals=['symbol_id'],
    time_varying_known_categoricals=[],
    time_varying_known_reals=['time_idx', 'date_id'],
    variable_groups={},  # group of categorical variables can be treated as one variable
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=feature_cols,
    # responder_6 takes negative values, so no positivity-assuming transformation
    # (e.g. 'softplus', whose inverse is undefined for y <= 0) may be used.
    target_normalizer=GroupNormalizer(groups=['symbol_id'], transformation=None),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,  # <--
)

# create validation set (predict=True) which means to predict the last max_prediction_length points in time
# for each series
validation = TimeSeriesDataSet.from_dataset(training, df, predict=True, stop_randomization=True)

# create dataloaders for model
batch_size = 128  # set this between 32 to 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)

# let's see how a naive model does
actuals = torch.cat([y for x, (y, weight) in iter(val_dataloader)])
baseline_predictions = Baseline().predict(val_dataloader).cpu()
print(f"{(actuals - baseline_predictions).abs().mean().item():.4f}")

# Metric.loss takes (y_pred, target) in that order
sm_loss = SMAPE().loss(baseline_predictions, actuals).mean(axis=1).median().item()
print(f"Median loss for naive prediction on validation: {sm_loss:.4f}")
print(training.get_parameters())

163491 100


## Create baseline model

In [None]:
# Baseline: predict every future step as the last observed value, then score it with MAE
baseline_predictions = Baseline().predict(val_dataloader, return_y=True)
baseline_mae = MAE()(baseline_predictions.output, baseline_predictions.y)
print(f"{baseline_mae:.4f}")

## Train the Temporal Fusion Transformer

In [None]:
# Optional: small model/trainer used only by the learning-rate finder in the next cell.
# Set to True to build them.
BUILD_LR_FINDER_MODEL = False

if BUILD_LR_FINDER_MODEL:
    # configure network and trainer
    # NOTE(review): reseeds with 42, unlike the notebook-wide seed used earlier
    pln.seed_everything(42)
    trainer = pln.Trainer(
        accelerator="cpu",
        # clipping gradients is a hyperparameter and important to prevent divergence
        # of the gradient for recurrent neural networks
        gradient_clip_val=0.1,
    )

    tft = TemporalFusionTransformer.from_dataset(
        training,
        # not meaningful for finding the learning rate but otherwise very important
        learning_rate=0.03,
        hidden_size=8,  # most important hyperparameter apart from learning rate
        # number of attention heads. Set to up to 4 for large datasets
        attention_head_size=1,
        dropout=0.1,  # between 0.1 and 0.3 are good values
        hidden_continuous_size=8,  # set to <= hidden_size
        loss=QuantileLoss(),
        optimizer="ranger",
    )
    print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")

In [None]:
# Optional: learning-rate range test. Requires the LR-finder `trainer` and `tft`
# built in the previous cell (enable that cell first). Set to True to run.
RUN_LR_FINDER = False

if RUN_LR_FINDER:
    # find optimal learning rate
    from lightning.pytorch.tuner import Tuner

    res = Tuner(trainer).lr_find(
        tft,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
        max_lr=10.0,
        min_lr=1e-6,
        num_training=100,
    )

    print(f"suggested learning rate: {res.suggestion()}")
    # show=True already displays the figure; no second fig.show() needed
    fig = res.plot(show=True, suggest=True)

In [None]:
from pytorch_forecasting.metrics import MultiHorizonMetric, MultiLoss, SMAPE
from pytorch_forecasting.metrics.base_metrics import AggregationMetric
import torch.nn as nn

class R2Loss(nn.Module):
    """Zero-mean ``1 - R^2`` loss: ``sum((y_pred - y_true)**2) / sum(y_true**2)``.

    A tiny constant (1e-38) is added to the denominator to avoid division by zero.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        residual = y_pred - y_true
        numerator = residual.pow(2).sum()
        denominator = y_true.pow(2).sum() + 1e-38
        return numerator / denominator

def r2_val(y_true, y_pred, sample_weight=None):
    """Weighted zero-mean R^2: ``1 - sum(w * (y - y_hat)**2) / sum(w * y**2)``.

    The denominator uses ``y**2`` (no mean subtraction), matching ``R2Loss``.

    Args:
        y_true: true values; any array-like accepted by ``np.asarray``
            (lists, numpy arrays, CPU torch tensors).
        y_pred: predictions, broadcastable against ``y_true``.
        sample_weight: per-element weights broadcastable against ``y_true``;
            ``None`` means uniform weight 1.

    Returns:
        The weighted R^2 as a numpy float64. It is ``nan`` or ``-inf`` when the
        weighted sum of ``y_true**2`` is zero.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    if sample_weight is None:
        sample_weight = np.ones_like(y_true)
    else:
        sample_weight = np.asarray(sample_weight, dtype=np.float64)
    weighted_residual_sum = np.sum(sample_weight * (y_true - y_pred) ** 2)
    # weighted sum of squared true values (denominator)
    weighted_true_sum = np.sum(sample_weight * y_true ** 2)
    return 1 - weighted_residual_sum / weighted_true_sum

class R2LossMhm(MultiHorizonMetric):
    """Multi-horizon zero-mean ``1 - R^2`` loss for pytorch-forecasting.

    ``loss`` returns one value per (sample, horizon) element:
    ``(y_hat - y)**2 / mean(y**2)``. Averaging those elements (the metric's
    default "mean" reduction) approximates ``sum((y_hat - y)**2) / sum(y**2)``,
    i.e. the same quantity as ``R2Loss``. That avoids dividing each element by
    its own ``y**2``, which explodes for targets near zero.

    NOTE(review): the denominator is taken over the whole ``target`` tensor, so
    any padded decoder positions are included in it — confirm this is
    acceptable for variable-length batches.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def loss(self, y_pred, target):
        # to_prediction reduces a 3-D network output to a (batch, horizon) point
        # forecast, as the MAE metric does; unlike a bare .squeeze() it never drops
        # a batch or horizon axis of size 1 (which would then mis-broadcast).
        point_pred = self.to_prediction(y_pred)
        scale = target.pow(2).mean().clamp_min(torch.finfo(target.dtype).tiny)
        return (point_pred - target).pow(2) / scale

# No local MAE class: pytorch_forecasting.metrics.MAE (imported in the imports cell)
# already computes |to_prediction(y_pred) - target|, so redefining it here would only
# shadow that import under the same name.

class R2LossAgrM(AggregationMetric):
    """Aggregation wrapper around ``R2LossMhm``, used as a TFT logging metric.

    NOTE(review): ``AggregationMetric`` delegates computation to the wrapped
    metric; confirm whether the ``loss`` override below is ever invoked.
    """

    def __init__(self):
        super().__init__(metric=R2LossMhm())

    def loss(self, y_pred, y_true):
        squared_error = (y_pred - y_true).pow(2).sum()
        target_energy = y_true.pow(2).sum()
        return squared_error / (target_energy + 1e-38)

class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar whose validation bar is labelled "running validation..."."""

    def init_validation_tqdm(self):
        validation_bar = super().init_validation_tqdm()
        validation_bar.set_description("running validation...")
        return validation_bar

from lightning.pytorch.callbacks import ModelCheckpoint

# configure network and trainer
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
# Keep the checkpoint with the lowest validation loss, so that
# trainer.checkpoint_callback.best_model_path really points at the best model.
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger("lightning_logs")  # logging results to a tensorboard

trainer = pln.Trainer(
    max_epochs=50,
    accelerator="cpu",
    enable_model_summary=True,
    gradient_clip_val=0.1,
    limit_train_batches=50,  # an "epoch" is 50 training batches, so validation runs every 50 batches
    # fast_dev_run=True,  # comment in to check that network or dataset has no serious bugs
    callbacks=[lr_logger, early_stop_callback, checkpoint_callback, RichProgressBar()],  # or CustomProgressBar()
    logger=logger,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=2,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=QuantileLoss(),
    # alternatives: loss=MAE(), loss=MultiLoss(metrics=[MAE(), SMAPE()], weights=[2.0, 1.0]), loss=R2LossMhm()
    log_interval=10,  # log every 10 batches
    logging_metrics=[R2LossAgrM()],
    optimizer="ranger",  # same lower-case name as in the LR-finder cell
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")
print(tft.hparams)

# fit network
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

## Evaluate performance

In [None]:
# Load the checkpoint the trainer reports as "best". It is ranked by validation loss
# only when a ModelCheckpoint(monitor="val_loss") was registered on the trainer;
# otherwise Lightning's default checkpoint is the last saved one.
checkpoint_cb = trainer.checkpoint_callback
best_model_path = checkpoint_cb.best_model_path if checkpoint_cb is not None else ""
if not best_model_path:
    raise RuntimeError("No checkpoint was saved during training; cannot load the best model.")
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

In [None]:
# calculate the mean absolute error of the restored model on the validation set
predictions = best_tft.predict(
    val_dataloader,
    return_y=True,
    trainer_kwargs={"accelerator": "cpu"},
)
MAE()(predictions.output, predictions.y)

In [None]:
# Optional: weighted zero-mean R^2 (see r2_val) on the validation set. Set to True to run.
RUN_R2_EVAL = False

if RUN_R2_EVAL:
    predictions = best_tft.predict(val_dataloader, return_y=True, trainer_kwargs=dict(accelerator="cpu"))
    # predictions.y is a (target, weight) pair; weight is presumably None here because
    # the dataset was built without a weight column — uniform weights are used then.
    y_target, y_weight = predictions.y
    y_true = y_target.cpu().numpy()
    y_pred = predictions.output.cpu().numpy()
    weights_eval = np.ones_like(y_true) if y_weight is None else y_weight.cpu().numpy()
    val_r2 = r2_val(y_true, y_pred, weights_eval)
    print(f"Validation weighted R2: {val_r2:.4f}")

In [None]:
# Raw mode returns the full network output (all quantiles etc.); return_x also
# returns the model inputs needed for plotting.
raw_predictions = best_tft.predict(val_dataloader, mode="raw", return_x=True)

for example_idx in range(10):  # plot 10 examples
    best_tft.plot_prediction(
        raw_predictions.x,
        raw_predictions.output,
        idx=example_idx,
        add_loss_to_title=True,
    )

### Worst performers

In [None]:
# calculate the metric used to rank the examples: mean SMAPE per validation series
predictions = best_tft.predict(val_dataloader, return_y=True)
targets = predictions.y[0]
mean_losses = SMAPE(reduction="none").loss(predictions.output, targets).mean(1)
indices = mean_losses.argsort(descending=True)  # worst series first
for rank in range(10):  # plot the 10 worst examples
    best_tft.plot_prediction(
        raw_predictions.x,
        raw_predictions.output,
        idx=indices[rank],
        add_loss_to_title=SMAPE(quantiles=best_tft.loss.quantiles),
    )

### Actuals vs predictions by variables

In [None]:
# Actuals vs predictions, aggregated per input variable
predictions = best_tft.predict(val_dataloader, return_x=True)
predictions_vs_actuals = best_tft.calculate_prediction_actual_by_variable(
    predictions.x,
    predictions.output,
)
best_tft.plot_prediction_actual_by_variable(predictions_vs_actuals)

## Predict on selected data

In [None]:
# Quantile predictions for the training samples whose first predicted time_idx is 950
first_prediction_time_idx = 950
best_tft.predict(
    training.filter(lambda x: x.time_idx_first_prediction == first_prediction_time_idx),
    mode="quantiles",
)