### Try TemporalFusionTransformer
* Background reading: [Temporal Fusion Transformer: A Primer on Deep Forecasting in Python](https://towardsdatascience.com/temporal-fusion-transformer-a-primer-on-deep-forecasting-in-python-4eb37f3f3594)

In [None]:
import os

import pandas as pd
import numpy as np

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_lightning.loggers import WandbLogger

import warnings
# Install a process-wide "ignore" filter so no warning of any category is shown.
# NOTE(review): this presumably exists to quiet library chatter, but it also
# hides deprecation notices — consider narrowing the filter by category/module.
warnings.filterwarnings('ignore')

# ------------------------------
# Machine-specific settings
# Both values may be overridden with an environment variable of the same name,
# so the notebook can run on another machine (e.g. AWS) without editing this
# cell. When the variables are unset, the original defaults below apply.
DIR_DATA = os.environ.get(
    'DIR_DATA',
    '/media/user/12TB1/HanLi/GitHub/CMU11785-project/local_data',
)
NUM_WORKERS = int(os.environ.get('NUM_WORKERS', 32)) # Use 4 for AWS

# Single source of truth for every tunable value read by the cells below.
args = {
    # ------------------------------
    # Basic config
    'random_seed': 11785,            # passed to pl.seed_everything; None skips seeding
    'n_samples': 10,                 # selects which pf_*_{n}_samples.pf files are loaded
    'batch_size': 64,                # training batch size (validation uses 5x this)
    'n_workers' : NUM_WORKERS,       # dataloader worker processes
    'criterion': {
        'quantile': QuantileLoss(),
        'pearson': None,    # TODO: Miao to implement loss class
        'other': None,      # TODO: check out other loss
    },
    # ------------------------------
    # Hyperparameters
    'lr_s': 2e-1,                    # handed to the model as `learning_rate`
    'hidden_size': 256,
    'attention_head_size': 1,        # use multihead for large hidden size
    'dropout': 0.1,
    'hidden_continuous_size': 8,     # set to <= hidden_size
    'output_size': 7,                # 7 quantiles for QuantileLoss by default
    'reduce_on_plateau_patience': 4, # reduce learning rate if no improvement in validation loss after x epochs
    'gradient_clip_val': 0.1,
    # ------------------------------
    # Logging
    'log_interval': 5,               # log every n batches, set to None when try to find best lr
    'wandb_entity': '11785_project',
    'wandb_project': '11785_pf_test',
    'wandb_name': 'test_run_1',
}

### Load datasets and create dataloaders

In [None]:
# Load the pre-built train / validation / test datasets from disk and wrap the
# train and validation splits in dataloaders.
dir_pf_dataset = os.path.join(DIR_DATA, 'pf_dataset')
n = args['n_samples']  # the sample count is part of each dataset's file name

train_dataset = TimeSeriesDataSet.load(os.path.join(dir_pf_dataset, f'pf_train_{n}_samples.pf'))
val_dataset = TimeSeriesDataSet.load(os.path.join(dir_pf_dataset, f'pf_val_{n}_samples.pf'))
# The test split gets its own name. Assigning it to `val_dataset` would replace
# the validation split, so validation (and anything tuned on it) would run on
# test data.
test_dataset = TimeSeriesDataSet.load(os.path.join(dir_pf_dataset, f'pf_test_{n}_samples.pf'))

# create dataloaders for model (validation uses a 5x larger batch than training)
train_dataloader = train_dataset.to_dataloader(train=True, batch_size=args['batch_size'], num_workers=args['n_workers'])
val_dataloader = val_dataset.to_dataloader(train=False, batch_size=args['batch_size']*5, num_workers=args['n_workers'])

print("Load existing dataset completed.")

### Configure network and trainer

In [None]:
# Seed through Lightning unless seeding has been disabled in the config.
if args['random_seed'] is not None:
    pl.seed_everything(args['random_seed'])

# Trainer on one GPU with gradient clipping taken from the config.
trainer = pl.Trainer(gradient_clip_val=args['gradient_clip_val'], gpus=1)

# Build the TFT from the training dataset; every hyperparameter comes from `args`.
tft_model = TemporalFusionTransformer.from_dataset(
    train_dataset,
    # --- optimisation ---
    learning_rate=args['lr_s'],
    reduce_on_plateau_patience=args['reduce_on_plateau_patience'],  # epochs without val-loss improvement before the lr is reduced
    # --- architecture ---
    hidden_size=args['hidden_size'],                        # most important hyperparameter apart from learning rate
    attention_head_size=args['attention_head_size'],        # number of attention heads; up to 4 for large datasets
    hidden_continuous_size=args['hidden_continuous_size'],
    dropout=args['dropout'],                                # 0.1 - 0.3 are good values
    # --- loss / output ---
    loss=args['criterion']['quantile'],
    output_size=args['output_size'],
    # --- logging ---
    log_interval=args['log_interval'],                      # log every n batches
)

print(f"Number of parameters in network: {tft_model.size()/1e3:.1f}k")

# Run the learning-rate finder over the range [1e-6, 10.0].
res = trainer.tuner.lr_find(
    tft_model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    min_lr=1e-6,
    max_lr=10.0,
)

# Report the suggested learning rate and plot the finder curve with it marked.
print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()