In [None]:
import os, gc, torch

import warnings
warnings.filterwarnings("ignore")
import pandas as pd

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer, MultiNormalizer
from pytorch_forecasting.metrics import RMSE, MultiLoss

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f'Using {device} backend.')

In [None]:
from configurations.config import *
@dataclass
class arguments:
    """Notebook-level run options.

    The fields are type-annotated so that ``@dataclass`` registers them as real
    fields; un-annotated assignments are ignored by the decorator. The defaults
    remain readable on the class itself (``arguments.experiment``), which is how
    the rest of the notebook uses them.
    """
    # Experiment name handed to ExperimentConfig.
    experiment: str = 'traffic'
    # Passed to the trainer progress bar, the prediction progress bar and the plotter.
    show_progress: bool = True

# Build the configuration for the experiment selected in `arguments` and keep a
# handle on its data formatter, which later cells use to read and split the data.
config = ExperimentConfig(experiment=arguments.experiment)
formatter = config.data_formatter

In [None]:
# Read the experiment's dataset through the formatter and report its shape.
df = formatter.read_file()
print(f'Total data shape {df.shape}')

from utils.metric import show_result
from utils.data import create_TimeSeriesDataSet
from utils.model import seed_torch
# Seed first so that everything below runs from the configured seed.
seed_torch(seed=config.seed)

# Split the raw frame into the three partitions used by the experiment.
train, validation, test = formatter.split(df)

# Hyper-parameters for the TFT model; the batch size drives all three loaders.
parameters = config.model_parameters(ModelType.TFT)
batch_size = parameters['batch_size']

# Only the test split keeps its dataset object; the other two keep just the loader.
_, train_dataloader = create_TimeSeriesDataSet(train, formatter, batch_size, train=True)
_, val_dataloader = create_TimeSeriesDataSet(
    validation, formatter, batch_size
)
test_timeseries, test_dataloader = create_TimeSeriesDataSet(
    test, formatter, batch_size
)

In [None]:
import tensorflow as tf
# To inspect the training curves, launch TensorBoard and select the lightning_logs
# folder; this will load the TensorBoard visualization.
import tensorboard as tb
# Replace TensorFlow's gfile module with the one from TensorBoard's TensorFlow stub.
# NOTE(review): presumably a workaround for a TensorFlow/TensorBoard gfile
# incompatibility during logging — confirm it is still required.
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

# Stop training when the validation loss has not decreased for the configured
# number of validation checks.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    min_delta=0,
    patience=parameters['early_stopping_patience'],
    verbose=True,
)

# Checkpoint tracked on the validation loss, written as "best-{epoch}".
best_checkpoint = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    dirpath=config.experiment_folder,
    filename="best-{epoch}",
)

# Unmonitored checkpoint refreshed every epoch, written as "latest-{epoch}".
latest_checkpoint = pl.callbacks.ModelCheckpoint(
    every_n_epochs=1,
    dirpath=config.experiment_folder,
    filename="latest-{epoch}",
)

# Send training logs to TensorBoard, rooted at the experiment folder.
logger = TensorBoardLogger(config.experiment_folder)

# Trainer API reference:
# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-class-api
# NOTE(review): max_time caps the whole fit at one minute while validation only
# runs every second epoch — this looks like a smoke-test setting; confirm before
# using the notebook for a real run.
trainer = pl.Trainer(
    accelerator='auto',
    max_epochs=parameters['epochs'],
    max_time=pd.to_timedelta(1, unit='minutes'),
    check_val_every_n_epoch=2,
    callbacks=[early_stop_callback, best_checkpoint, latest_checkpoint],
    logger=logger,
    enable_model_summary=True,
    enable_progress_bar=arguments.show_progress,
)

# Configure the Temporal Fusion Transformer from a dataset definition plus the
# experiment's hyper-parameters.
# NOTE(review): the network is derived from test_timeseries while the training
# dataset object is discarded — confirm this should not be the training dataset.
# NOTE(review): with reduction='mean' the RMSE metric appears to skip the square
# root (i.e. behaves like MSE); the default is RMSE(reduction='sqrt-mean') — confirm intent.
tft = TemporalFusionTransformer.from_dataset(
    test_timeseries,
    hidden_size=parameters['hidden_layer_size'],
    attention_head_size=parameters['attention_head_size'],
    dropout=parameters['dropout_rate'],
    learning_rate=parameters['learning_rate'],
    optimizer='adam',
    # One independent loss term per target column.
    loss=MultiLoss([RMSE(reduction='mean') for _target in formatter.targets]),
    log_interval=1,
    # reduce_on_plateau_patience=2
)

print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

In [None]:
from datetime import datetime
# Free whatever can be freed before the memory-heavy training step.
gc.collect()

# `start` is kept at notebook level: the final cell reuses it for the total elapsed time.
start = datetime.now()
print(f'\n----Training started at {start}----\n')

trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

end = datetime.now()
print(f'\n----Training ended at {end}, elapsed time {end-start}')
print(f'Best model by validation loss saved at {trainer.checkpoint_callback.best_model_path}')

In [None]:
from classes.PredictionProcessor import PredictionProcessor

# Helper used below to align model predictions with the source data frames.
processor = PredictionProcessor(
    formatter.time_index,
    formatter.group_id,
    formatter.parameters['horizon'],
    formatter.targets,
    formatter.parameters['window'],
)

# %%
from classes.Plotter import *

# Plot helper rooted at the experiment folder; figures are shown only when
# progress display is enabled.
plotter = PlotResults(
    config.experiment_folder,
    formatter.time_index,
    formatter.targets,
    show=arguments.show_progress,
)

# Path recorded by the trainer's first ModelCheckpoint callback (the one that
# monitors val_loss). It is empty when training stopped before any monitored
# checkpoint was written, e.g. when max_time expired before a validation run.
best_model_path = trainer.checkpoint_callback.best_model_path

if best_model_path:
    # Evaluate with the best weights rather than whatever the last training step
    # left in memory; previously the message was printed but nothing was loaded.
    print(f'Loading best model from {best_model_path}')
    tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
else:
    print('No best checkpoint was saved; evaluating the model as it was at the end of training.')

In [None]:

# print('\n---Training prediction--\n')
# train_predictions, train_index = tft.predict(
#     train_dataloader, return_index=True, 
#     show_progress_bar=arguments.show_progress
# )
# train_result_merged = processor.align_result_with_dataset(
#    train, train_predictions, train_index
# )

# show_result(train_result_merged, formatter.targets)
# plotter.summed_plot(train_result_merged, type='Train_error', plot_error=True)
# gc.collect()

In [None]:
print(f'\n---Validation results--\n')

# Predict on the validation loader and keep the index returned alongside the
# predictions; it is needed to align them with the validation frame.
validation_predictions, validation_index = tft.predict(
    val_dataloader,
    return_index=True,
    show_progress_bar=arguments.show_progress,
)

In [None]:
# Join the validation predictions back onto the validation frame.
validation_result_merged = processor.align_result_with_dataset(
    validation,
    validation_predictions,
    validation_index,
)

# Report the metrics for every target, then draw the summed plot.
show_result(validation_result_merged, formatter.targets)
plotter.summed_plot(validation_result_merged, type='Validation')

gc.collect()

In [None]:
print(f'\n---Test results--\n')

# Predict on the test loader, keeping the index needed for alignment.
test_predictions, test_index = tft.predict(
    test_dataloader,
    return_index=True,
    show_progress_bar=arguments.show_progress,
)

# Join the predictions back onto the test frame, then report and plot them.
test_result_merged = processor.align_result_with_dataset(
    test,
    test_predictions,
    test_index,
)
show_result(test_result_merged, formatter.targets)
plotter.summed_plot(test_result_merged, 'Test')

gc.collect()

In [None]:
# Tag every row with the split it came from before stacking the frames.
# train_result_merged['split'] = 'train'
validation_result_merged['split'] = 'validation'
test_result_merged['split'] = 'test'

# Use a dedicated name instead of rebinding `df`: `df` holds the raw data read at
# the top of the notebook, and overwriting it would make a re-run of the split
# cell operate on predictions instead of the source data.
predictions = pd.concat([validation_result_merged, test_result_merged])
predictions.to_csv(os.path.join(plotter.figPath, 'predictions.csv'), index=False)

# Take a single timestamp so the reported end time and elapsed time agree.
finished = datetime.now()
print(f'Ended at {finished}. Elapsed time {finished - start}')