### Try TemporalFusionTransformer
* Read: https://towardsdatascience.com/temporal-fusion-transformer-a-primer-on-deep-forecasting-in-python-4eb37f3f3594

In [None]:
import warnings
warnings.filterwarnings('ignore')

import os
import sys

import pandas as pd
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_lightning.loggers import WandbLogger


# Project root. Defaults to the original workstation path; set the
# CMU11785_PROJECT_DIR environment variable to run the notebook on another machine.
DIR_PROJECT = os.environ.get(
    'CMU11785_PROJECT_DIR',
    '/media/user/12TB1/HanLi/GitHub/CMU11785-project/',
)
# Make the project-local packages importable (needed for `criterions` below).
sys.path.append(os.path.join(DIR_PROJECT, 'src'))
sys.path.append(os.path.join(DIR_PROJECT, 'utils'))
DIR_TRAINED = os.path.join(DIR_PROJECT, 'local_trained')

# Project-local import: must come after the sys.path updates above.
from criterions import Pearson

DIR_DATA = os.path.join(DIR_PROJECT, 'local_data')
NUM_WORKERS = 8 # Use 4 for AWS


# Baseline hyperparameters: 
# TFT_quantile_loss_tune_10
# https://wandb.ai/11785_project/11785_project_tuning
args = {
    # ------------------------------
    # Basic config
    'random_seed': 11785,
    'n_samples': 1000,               # selects which pre-built dataset files are loaded
    'batch_size': 64,
    'n_workers' : NUM_WORKERS,
    'criterion': {
        'quantile': QuantileLoss(),
        'pearson': Pearson.Pearson(),   # Miao's implementation
        'other': None,                  # TODO: check out other loss (e.g., MSE)
    },
    # ------------------------------
    # Hyperparameters
    'lr_s': 7e-3,
    'hidden_size': 512,
    'attention_head_size': 2,        # use multihead for large hidden size
    'dropout': 0.1,
    'hidden_continuous_size': 4,     # set to <= hidden_size
    'output_size': 7,                # 7 quantiles for QuantileLoss by default
    'reduce_on_plateau_patience': 4, # reduce learning rate if no improvement in validation loss after x epochs
    'gradient_clip_val': 0.1,
    # ------------------------------
    # Logging
    'log_interval': 5,               # log every n batches, set to None when try to find best lr
    'wandb_entity': '11785_project',
    'wandb_project': '11785_project_tuning',
    'wandb_name': 'TFT_baseline_0426',
}

# Apply the configured seed so that shuffling, weight initialisation and dropout
# are repeatable across "Restart & Run All". The trailing semicolon keeps the
# returned seed value out of the cell output.
pl.seed_everything(args['random_seed']);

### Create dataset and dataloaders

In [None]:
# Directory holding the serialized pytorch-forecasting datasets, and the sample
# count that selects which set of files to read.
dir_pf_dataset = os.path.join(DIR_DATA, 'pf_dataset_tft')
n = args['n_samples']

# File name of each serialized split, keyed by split name.
pf_files = {split: f'pf_{split}_{n}_samples.pf' for split in ('train', 'val', 'test')}

# Restore the three splits from disk.
train_dataset = TimeSeriesDataSet.load(os.path.join(dir_pf_dataset, pf_files['train']))
val_dataset = TimeSeriesDataSet.load(os.path.join(dir_pf_dataset, pf_files['val']))
test_dataset = TimeSeriesDataSet.load(os.path.join(dir_pf_dataset, pf_files['test']))

# Wrap each split in a dataloader; only the training loader is built with train=True.
loader_kwargs = dict(batch_size=args['batch_size'], num_workers=args['n_workers'])
train_dataloader = train_dataset.to_dataloader(train=True, **loader_kwargs)
val_dataloader = val_dataset.to_dataloader(train=False, **loader_kwargs)
test_dataloader = test_dataset.to_dataloader(train=False, **loader_kwargs)

# Report which files were loaded.
print("Load existing datasets completed:")
print(f"Train dataset:  {pf_files['train']}")
print(f"Val dataset:    {pf_files['val']}")
print(f"Test dataset:   {pf_files['test']}")

### Configure network and trainer

In [None]:
# Stop training once "val_loss" has failed to improve by at least 1e-4 for 10 validation checks.
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
lr_logger = LearningRateMonitor()  # log the learning rate

# Keep the two checkpoints with the lowest validation SMAPE.
# The directory is derived from DIR_PROJECT instead of repeating the absolute path.
# NOTE(review): the '500' filename prefix does not match args['n_samples'] (1000) — confirm it is intentional.
checkpoint_callback = ModelCheckpoint(
    monitor='val_SMAPE',
    mode='min',  # SMAPE is an error measure: lower is better
    dirpath=os.path.join(DIR_PROJECT, 'logs', 'model_checkpoints'),
    save_top_k=2,
    filename='500-default-{epoch:02d}-{val_SMAPE:.2f}'
)

# Weights & Biases run; entity / project / run name come from the shared config
# so the notebook has a single source of truth for them.
logger = WandbLogger(
    entity=args['wandb_entity'],
    project=args['wandb_project'],
    name=args['wandb_name'],
    log_model=True
)


trainer = pl.Trainer(
    max_epochs=20,
    gpus=1,
    weights_summary="top",
    gradient_clip_val=args['gradient_clip_val'],
    limit_train_batches=30,  # use only 30 training batches per epoch, i.e. run validation every 30 batches
    # fast_dev_run=True,  # uncomment to check that the network or dataset has no serious bugs
    callbacks=[lr_logger, early_stop_callback, checkpoint_callback],
    logger=logger,
)


# Build the TFT from the training dataset's definition, using the hyperparameters in `args`.
tft_model = TemporalFusionTransformer.from_dataset(
    train_dataset,
    learning_rate=args['lr_s'],
    hidden_size=args['hidden_size'],  # most important hyperparameter apart from learning rate
    attention_head_size=args['attention_head_size'], # number of attention heads. Set to up to 4 for large datasets
    dropout=args['dropout'],  # between 0.1 and 0.3 are good values
    hidden_continuous_size=args['hidden_continuous_size'],
    output_size=args['output_size'],
    loss=args['criterion']['quantile'],
    # loss=args['criterion']['pearson'],
    log_interval=args['log_interval'],  # log every n batches; see the note on 'log_interval' in `args`
    reduce_on_plateau_patience=args['reduce_on_plateau_patience'], # reduce learning rate if no improvement in validation loss after x epochs
)

# Create the output directory before the (long) training run so that saving the
# weights afterwards cannot fail on a missing directory.
os.makedirs(DIR_TRAINED, exist_ok=True)

# fit network
trainer.fit(
    tft_model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
# Persist only the weights (state dict) of the trained model.
torch.save(tft_model.state_dict(), os.path.join(DIR_TRAINED, 'tft_baseline.pth'))

In [29]:
# Entries of the batch input dict whose tensor shapes are reported, in print order.
inp_keys = (
    'encoder_cat', 'encoder_cont', 'encoder_target', 'encoder_lengths',
    'decoder_cat', 'decoder_cont', 'decoder_target', 'decoder_lengths',
    'decoder_time_idx', 'groups', 'target_scale',
)

# Peek at the first three test batches and print the shape of every input tensor
# and of the first target element.
for i_b, data in enumerate(test_dataloader):
    inp, target = data[0], data[1]
    print('=' * 100)
    for key in inp_keys:
        print(f"==>> inp['{key}'].shape: ", inp[key].shape)

    print("==>> target[0].shape: ", target[0].shape)

    # Make prediction, calculate metric
    # Y_hat = model(process(inp))
    # loss = criterion(Y_hat, target[0])
    if i_b >= 2:
        break

==>> inp['encoder_cat'].shape:  torch.Size([64, 14, 1])
==>> inp['encoder_cont'].shape:  torch.Size([64, 14, 305])
==>> inp['encoder_target'].shape:  torch.Size([64, 14])
==>> inp['encoder_lengths'].shape:  torch.Size([64])
==>> inp['decoder_cat'].shape:  torch.Size([64, 3, 1])
==>> inp['decoder_cont'].shape:  torch.Size([64, 3, 305])
==>> inp['decoder_target'].shape:  torch.Size([64, 3])
==>> inp['decoder_lengths'].shape:  torch.Size([64])
==>> inp['decoder_time_idx'].shape:  torch.Size([64, 3])
==>> inp['groups'].shape:  torch.Size([64, 1])
==>> inp['target_scale'].shape:  torch.Size([64, 2])
==>> target[0].shape:  torch.Size([64, 3])
==>> inp['encoder_cat'].shape:  torch.Size([64, 14, 1])
==>> inp['encoder_cont'].shape:  torch.Size([64, 14, 305])
==>> inp['encoder_target'].shape:  torch.Size([64, 14])
==>> inp['encoder_lengths'].shape:  torch.Size([64])
==>> inp['decoder_cat'].shape:  torch.Size([64, 3, 1])
==>> inp['decoder_cont'].shape:  torch.Size([64, 3, 305])
==>> inp['decoder_

tensor([[-9.3967e-01, -9.9649e-01, -4.4687e-01],
        [ 1.1992e+00,  9.9912e-01,  4.5174e-01],
        [-1.0053e-01,  2.0277e-01,  3.5411e-01],
        [ 6.1046e-01,  1.8834e-02, -3.1667e+00],
        [-7.3379e-01,  5.4562e-02, -1.8218e-01],
        [ 2.2193e-01,  1.2711e+00, -2.5285e-01],
        [ 1.6738e-01,  2.9543e-01, -9.3722e-02],
        [-1.7903e-01,  7.2142e-01,  1.9773e-01],
        [ 2.8210e-01, -3.0660e-01,  1.1602e-01],
        [-1.3792e-02, -2.9306e-01,  2.1163e-01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-4.4023e-01,  7.0060e-01, -1.1780e-01],
        [ 7.5802e-01,  1.7436e+00,  9.9800e-01],
        [ 1.7983e-01, -9.2113e-01, -3.8856e-01],
        [-3.2562e-01, -3.9280e-01, -2.9481e-01],
        [-7.7240e-01,  8.9730e-01,  8.0999e-01],
        [-2.2687e-01,  2.5099e-01, -3.3362e-03],
        [-9.9442e-02,  5.3375e-01, -2.8614e-01],
        [ 8.1182e-01,  1.5101e-01, -5.9204e-01],
        [ 2.6194e+00,  5.8686e-01, -3.1339e-01],
        [ 1.7240e-02

### Load trained model (optional)

In [8]:
# Rebuild a TFT with the same hyperparameters as the training cell, then load saved weights into it.
# NOTE(review): the architecture is derived from `test_dataset` here but from `train_dataset`
# at training time — presumably both share the same dataset definition; confirm.
tft_model = TemporalFusionTransformer.from_dataset(
    test_dataset,
    learning_rate=args['lr_s'],
    hidden_size=args['hidden_size'],  # most important hyperparameter apart from learning rate
    attention_head_size=args['attention_head_size'], # number of attention heads. Set to up to 4 for large datasets
    dropout=args['dropout'],  # between 0.1 and 0.3 are good values
    hidden_continuous_size=args['hidden_continuous_size'],
    output_size=args['output_size'],
    loss=args['criterion']['quantile'],
    # loss=args['criterion']['pearson'],
    log_interval=args['log_interval'],  # log every n batches; see the note on 'log_interval' in `args`
    reduce_on_plateau_patience=args['reduce_on_plateau_patience'], # reduce learning rate if no improvement in validation loss after x epochs
)
# NOTE(review): this reads 'tft_baseline_426.pth', while the training cell writes
# 'tft_baseline.pth' — confirm which file is the intended one.
# map_location='cpu' lets weights saved from a GPU run be loaded on a machine without
# CUDA; load_state_dict copies the values into the model's own parameters.
tft_model.load_state_dict(
    torch.load(os.path.join(DIR_TRAINED, 'tft_baseline_426.pth'), map_location='cpu')
)
# Switch to evaluation mode; as the last expression this also displays the model summary.
tft_model.eval()


TemporalFusionTransformer(
  (loss): QuantileLoss()
  (logging_metrics): ModuleList(
    (0): SMAPE()
    (1): MAE()
    (2): RMSE()
    (3): MAPE()
  )
  (input_embeddings): MultiEmbedding(
    (embeddings): ModuleDict(
      (investment_id): Embedding(1000, 77)
    )
  )
  (prescalers): ModuleDict(
    (encoder_length): Linear(in_features=1, out_features=4, bias=True)
    (target_center): Linear(in_features=1, out_features=4, bias=True)
    (target_scale): Linear(in_features=1, out_features=4, bias=True)
    (f_0): Linear(in_features=1, out_features=4, bias=True)
    (f_1): Linear(in_features=1, out_features=4, bias=True)
    (f_2): Linear(in_features=1, out_features=4, bias=True)
    (f_3): Linear(in_features=1, out_features=4, bias=True)
    (f_4): Linear(in_features=1, out_features=4, bias=True)
    (f_5): Linear(in_features=1, out_features=4, bias=True)
    (f_6): Linear(in_features=1, out_features=4, bias=True)
    (f_7): Linear(in_features=1, out_features=4, bias=True)
    (f_8

### Test

In [None]:
# Sliding-window evaluation, for example:
#   X(1 2 3 4 5) -> Yhat(6 7 8)
#   X(2 3 4 5 6) -> Yhat(7 8 9)
# Each window is scored against its own ground truth:
#   Metric = f(Yhat(6 7 8), Y(6 7 8))

# Run the test loop over the test dataloader and keep the returned results.
test_results = trainer.test(
    model=tft_model,
    dataloaders=test_dataloader,
)