In [None]:
# Fetch the project repository; the project-local imports further down
# (utils, data_formatters, model) presumably come from it.
# NOTE: git refuses to clone again when the target directory already exists.
!git clone https://github.com/Taeksu-Kim/Temporal_Fusion_Transformer.git

In [None]:
cd ./Temporal_Fusion_Transformer

In [None]:
!pip install wget pyunpack patool

In [None]:
!pip install torchsummaryX

In [None]:
#common
import pandas as pd
import numpy as np
import os
import gc
import json
import random
from tqdm import tqdm

from torchsummaryX import summary

import torch
import torch.nn as nn
from torch.utils.data import Dataset

# custom
from utils import data_downloader

from data_formatters.volatility import VolatilityFormatter as data_formatter
# from data_formatters.electricity import ElectricityFormatter as data_formatter

from utils.hyperparam_opt import HyperparamOptManager

from data_formatters import base as base_formatters
import utils.utils as utils

from model import Temporal_Fusion_Transformer

In [None]:
def seed_everything(seed: int = 42):
    """Seed every random number generator this notebook relies on.

    Covers Python's `random`, NumPy's global generator and PyTorch (CPU,
    current CUDA device and all CUDA devices), and switches cuDNN to
    deterministic mode with autotuning turned off.

    Args:
      seed: value handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # type: ignore
    cudnn.benchmark = False  # type: ignore


seed_everything(42)

## Data Load & Preprocess

In [None]:
# Build (or locate) the dataset CSV and load it.
data_csv_path = data_downloader.make_csv('volatility')
# data_csv_path = data_downloader.make_csv('electricity')

raw_data = pd.read_csv(data_csv_path, index_col=0)

# `data_formatter` is imported as a class and rebound to an instance here.
# Instantiate only while it is still the class, so that re-running this cell
# does not try to call the instance.
if isinstance(data_formatter, type):
    data_formatter = data_formatter()

train, valid, test = data_formatter.split_data(raw_data)
train_samples, valid_samples = data_formatter.get_num_samples_for_calibration()

# Folder handed to the hyperparameter manager through params["model_folder"].
model_folder = './fixed'
os.makedirs(model_folder, exist_ok=True)

fixed_params = data_formatter.get_experiment_params()
params = data_formatter.get_default_model_params()
params["model_folder"] = model_folder

In [None]:
# NOTE(review): repeats the read_csv call already made in the previous cell.
raw_data = pd.read_csv(data_csv_path, index_col=0)

In [None]:
# NOTE(review): repeats the split already performed in the data-loading cell.
train, valid, test = data_formatter.split_data(raw_data)

In [None]:
# Display the training split.
train

In [None]:
# Display the validation split.
valid

In [None]:
# Display the test split.
test

In [None]:
# Sets up hyperparam manager
# Every default value is wrapped in a one-element list, so the manager is
# given exactly one candidate per parameter.
print("*** Loading hyperparm manager ***")
opt_manager = HyperparamOptManager({k: [params[k]] for k in params},
                                    fixed_params, model_folder)

# NOTE: this rebinds `params`; re-running the cell feeds the manager's own
# output back in as the search space.
params = opt_manager.get_next_parameters()

In [None]:
# Training -- one iteration only
# Print every parameter returned by the hyperparameter manager.
print("*** Running calibration ***")
print("Params Selected:")
for k in params:
  print("{}: {}".format(k, params[k]))

In [None]:
# NOTE(review): prints the same banner as the previous cell.
print("*** Running calibration ***")

In [None]:
# Default input types.
# Alias for the InputTypes definition in data_formatters.base; later cells
# use it to pick columns by role (ID, TIME, TARGET, ...).
InputTypes = base_formatters.InputTypes

num_encoder_steps = params['num_encoder_steps']

# Display the full parameter dict.
params

In [None]:
class _MissingConfigKey(AttributeError, KeyError):
    """Raised when a Config attribute is read but the key is absent.

    Being an AttributeError makes hasattr() / getattr(obj, name, default)
    work; also being a KeyError keeps code working that caught the error
    raised by the former `__getattr__ = dict.__getitem__` implementation.
    """


class Config(dict):
    """Dictionary whose keys can also be read and written as attributes.

    `config.name` is equivalent to `config['name']`. Used to hold the
    configuration, either built from an existing dict or read from a JSON
    file via `Config.load`.
    """

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, i.e. for config keys.
        try:
            return self[name]
        except KeyError:
            raise _MissingConfigKey(name) from None

    @classmethod
    def load(cls, file):
        """Read a JSON file and return its top-level object as an instance.

        Args:
          file: path of the JSON file to read.

        Returns:
          An instance of `cls` (so subclasses get their own type back).
        """
        with open(file, 'r', encoding='utf-8') as f:
            return cls(json.load(f))

In [None]:
# Wrap the selected params for attribute-style access, then set two extra
# entries (overwriting any values already stored under these keys).
config = Config(params)
config.lstm_num_layers = 1
config.quantiles = [0.1, 0.5, 0.9]
print(config)

In [None]:
# Column names available in the training split.
train.columns

In [None]:
# Column definitions as stored in params; batch_data() later reads each entry
# as a tuple whose item 0 is the column name and item 2 its InputTypes value.
column_definition = params['column_definition']

In [None]:
# Presumably the number of windows that can be drawn from each split --
# TODO confirm against utils.cal_max_sample.
train_max_sample = utils.cal_max_sample(train, InputTypes, config)
valid_max_sample = utils.cal_max_sample(valid, InputTypes, config)

In [None]:
# Display the value computed for the training split.
train_max_sample 

In [None]:
# Build the sampled datasets. Later cells read the result as a dict with at
# least 'inputs', 'outputs' and 'active_entries'.
# NOTE(review): the reason for passing max_sample - 1 is not visible here --
# confirm against utils.batch_sampled_data.
train_data =  utils.batch_sampled_data(train, train_max_sample-1, InputTypes, config)
valid_data =  utils.batch_sampled_data(valid, valid_max_sample-1, InputTypes, config)

## Making Input Data

In [None]:
class tft_dataset(Dataset):
  """Map-style dataset over a dict of equally indexed arrays.

  `data` must provide 'inputs', 'outputs' and 'active_entries'; the sample at
  position i is a dict holding the i-th entry of each of them. Any other keys
  in `data` (such as 'time' or 'identifier') are not exposed.
  """

  # Keys copied into every sample, in this order.
  _SAMPLE_KEYS = ('inputs', 'outputs', 'active_entries')

  def __init__(self, data):
    self.data = data

  def __len__(self):
    # Number of samples = size of the first axis of 'inputs'.
    num_samples = self.data['inputs'].shape[0]
    return num_samples

  def __getitem__(self, index):
    return {key: self.data[key][index] for key in self._SAMPLE_KEYS}

In [None]:
# Number of training samples.
len(tft_dataset(train_data))

In [None]:
# Keys present in the sampled training data.
train_data.keys()

In [None]:
batch_size = 64

train_dataset = tft_dataset(train_data)
valid_dataset = tft_dataset(valid_data)

# Training batches are reshuffled every epoch.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=0, shuffle=True)
# Validation keeps a fixed order: the training loop averages the loss per
# batch, so a stable batch composition (including the final partial batch)
# keeps the validation loss comparable from one epoch to the next.
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, num_workers=0, shuffle=False)

In [None]:
# Pull one batch out of the training loader for inspection: the loop stops
# on its first iteration, leaving that batch bound to `batch`.
for i, batch in enumerate(train_dataloader):
  if i == 0:
    break

In [None]:
# Shape of the 'inputs' tensor in the sample batch.
batch['inputs'].shape

In [None]:
# Column definitions stored in the config.
config['column_definition']

In [None]:
# input_columns
# Hand-written list of input feature names, printed with their positions.
# NOTE(review): presumably mirrors the feature order of the 'inputs' arrays --
# confirm against config['column_definition'].
input_col_list = [
    'log_vol',
    'open_to_close',
    'days_from_start',
    'day_of_week',
    'day_of_month',
    'week_of_year',
    'month',
    'Region',
]

for i, col_name in enumerate(input_col_list):
  print(i, ':', col_name)

In [None]:
# Display the full configuration.
config

### Model Definition

In [None]:
from model import Temporal_Fusion_Transformer

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Build the model from the config and move its parameters to `device`.
model = Temporal_Fusion_Transformer(config)
model.to(device)

In [None]:
# Print a layer summary by tracing the model with a random dummy input.
# NOTE(review): the shape (64, 257, 8) is hard-coded -- presumably
# (batch_size, total time steps, number of input columns); confirm it matches
# batch['inputs'].shape.
summary(model, torch.rand(64, 257, 8).to(device))
# summary(tft, torch.rand(64, 257, 8).to(device), torch.rand(64, 5, 1).to(device))

### Train

In [None]:
# Optimisation settings.
epochs = 5
learning_rate = 1e-5
weight_decay = 1e-2

# Feature switches.
# NOTE(review): `gradient_accumulation` is not referenced by any other
# visible cell; `use_lr_scheduler` gates a `scheduler.step(...)` call in the
# training loop, but no `scheduler` is created in the visible cells.
gradient_accumulation = False
gradient_scaler = True
use_lr_scheduler = False

# Number of consecutive epochs without a validation-loss improvement after
# which the training loop stops.
early_stopping_patience = 2

# Base name of the checkpoint file ('./<save_name>.ckpt').
save_name = 'tft_model'

In [None]:
# AdamW over all model parameters, using the settings defined above.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [None]:
# Gradient scaler for mixed-precision training. It is enabled only when the
# `gradient_scaler` switch is on and CUDA is present; a disabled scaler acts
# as a pass-through, so the same code also runs on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=gradient_scaler and torch.cuda.is_available())

In [None]:
def _reduce_loss(loss):
    """Sum the raw loss over its last axis, then average over the new last axis."""
    return torch.mean(torch.sum(loss, dim=-1), dim=-1)


def train_step(batch_item, epoch, batch, training):
    """Run one optimisation step, or one evaluation pass, on a single batch.

    Relies on the notebook-level `model`, `optimizer`, `scaler` and `device`.

    Args:
      batch_item: mapping with an 'inputs' and an 'outputs' tensor.
      epoch: unused; kept so existing calls keep working.
      batch: unused; kept so existing calls keep working.
      training: truthy to update the model, falsy to only evaluate it.

    Returns:
      `(loss, lr)` when training, where `lr` is the learning rate of the
      first parameter group rounded to 10 decimals; otherwise just `loss`.
      In both cases `loss` is a tensor detached from the autograd graph, so
      callers can accumulate it without keeping the graph alive.
    """
    inputs = batch_item['inputs'].to(device)
    labels = batch_item['outputs'].to(device)

    if training:
        model.train()
        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            output = model(inputs=inputs, labels=labels)
            loss = _reduce_loss(output['loss'])

        # The forward pass runs under autocast, so backward/step go through
        # the GradScaler to guard against gradient underflow in reduced
        # precision. A disabled scaler simply forwards to backward()/step().
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        lr = optimizer.param_groups[0]["lr"]

        return loss.detach(), round(lr, 10)

    model.eval()
    with torch.no_grad():
        output = model(inputs=inputs, labels=labels)
        loss = _reduce_loss(output['loss'])

    return loss

In [None]:
%%time
# train

loss_plot, val_loss_plot = [], []
lrs = []

check_list = []

best_val_acc = 0
best_val_loss = 100

best_epoch = 0
patience = 0

for epoch in range(epochs):
    gc.collect()
    total_loss, total_val_loss = 0, 0
    
    tqdm_dataset = tqdm(enumerate(train_dataloader))
    training = True
    for batch, batch_item in tqdm_dataset:
        batch_loss, lr = train_step(batch_item, epoch, batch, training)
        total_loss += batch_loss
        
        tqdm_dataset.set_postfix({
            'Epoch': epoch + 1,
            'LR' : lr,
            'Loss': '{:04f}'.format(batch_loss.item()),
            'Total Loss' : '{:04f}'.format(total_loss/(batch+1)),
        })
            
    loss_plot.append(total_loss/(batch+1))
    
    tqdm_dataset = tqdm(enumerate(valid_dataloader))
    training = False
    for batch, batch_item in tqdm_dataset:
        batch_loss = train_step(batch_item, epoch, batch, training)
        total_val_loss += batch_loss
        
        tqdm_dataset.set_postfix({
            'Epoch': epoch + 1,
            'Val Loss': '{:04f}'.format(batch_loss.item()),
            'Total Val Loss' : '{:04f}'.format(total_val_loss/(batch+1)),
        })
    val_loss_plot.append(total_val_loss/(batch+1)) 

    cur_val_loss = total_val_loss/(batch+1)
    
    if cur_val_loss < best_val_loss:
        print(f'best_val_acc is updated from {best_val_loss} to {cur_val_loss} on epoch {epoch+1}')
        best_val_loss = cur_val_loss
        best_epoch = epoch+1
        torch.save(model.state_dict(), './'+save_name+'.ckpt')
        patience = 0
    else:
        patience += 1
    
    if use_lr_scheduler == True:
        scheduler.step(metrics=total_val_loss/(batch+1)) 
    
    lrs.append(lr)
    
    if patience == early_stopping_patience:
        break

In [None]:
# Load the model
# Restores the best checkpoint written by the training loop. map_location
# keeps the load working when the checkpoint was saved from a different
# device than the one currently in use.
model.load_state_dict(torch.load('./'+save_name+'.ckpt', map_location=device))

In [None]:
def batch_data(data, config):
    """Convert a long-format dataframe into stacked sliding windows.

    Rows are grouped by the ID column. For every entity that has at least
    `config.total_time_steps` rows, each run of that many consecutive rows
    becomes one sample. Entities with fewer rows are skipped.

    Relies on the notebook-level `utils` module and `InputTypes`.

    Args:
      data: DataFrame containing the columns named in
        `config.column_definition`.
      config: object exposing `total_time_steps`, `num_encoder_steps` and
        `column_definition` (tuples whose item 0 is the column name and
        item 2 its InputTypes value).

    Returns:
      Dict of numpy arrays. 'identifier', 'time' and 'inputs' have shape
      (num_windows, total_time_steps, num_columns); 'outputs' holds only the
      steps from `num_encoder_steps` onwards, and 'active_entries' is an
      array of ones shaped like 'outputs'.

    Raises:
      ValueError: if no entity has enough rows to form a single window.
    """
    lags = config.total_time_steps

    def _sliding_windows(frame):
        """Stack every length-`lags` window of `frame` along a new axis 1."""
        x = frame.values
        time_steps = len(x)
        return np.stack(
            [x[i:time_steps - (lags - 1) + i, :] for i in range(lags)], axis=1)

    id_col = utils.get_single_col_by_input_type(InputTypes.ID, config.column_definition)
    time_col = utils.get_single_col_by_input_type(InputTypes.TIME, config.column_definition)
    target_col = utils.get_single_col_by_input_type(InputTypes.TARGET, config.column_definition)
    input_cols = [
        tup[0]
        for tup in config.column_definition
        if tup[2] not in {InputTypes.ID, InputTypes.TIME}
    ]

    col_mappings = {
        'identifier': [id_col],
        'time': [time_col],
        'outputs': [target_col],
        'inputs': input_cols
    }

    windows = {k: [] for k in col_mappings}
    for _, sliced in data.groupby(id_col):
        # Too short to fill one window: skip the entity entirely so that all
        # keys stay aligned and nothing unusable reaches np.concatenate.
        if len(sliced) < lags:
            continue
        for k, cols in col_mappings.items():
            windows[k].append(_sliding_windows(sliced[cols].copy()))

    if not windows['inputs']:
        raise ValueError(
            'No entity has at least {} time steps; cannot build any window.'.format(lags))

    # Combine all data
    data_map = {k: np.concatenate(v, axis=0) for k, v in windows.items()}

    # Shorten target so we only get decoder steps
    data_map['outputs'] = data_map['outputs'][:, config.num_encoder_steps:, :]

    data_map['active_entries'] = np.ones_like(data_map['outputs'])

    return data_map

In [None]:
def predict(df, return_targets=False):
    """Computes quantile predictions for a given input dataset.

    Relies on the notebook-level `model`, `config`, `device`, `batch_size`,
    `batch_data` and `tft_dataset`.

    Args:
      df: Input dataframe; it is windowed with `batch_data(df, config)`.
      return_targets: Whether to also return the targets aligned with the
        predictions to facilitate evaluation.

    Returns:
      Dict of DataFrames keyed 'p<quantile*100>' (e.g. 'p10', 'p50', 'p90'),
      plus 'targets' when `return_targets` is true. Every frame has the
      columns 'forecast_time', 'identifier' and 't+1' ... 't+H', where H is
      `config.total_time_steps - config.num_encoder_steps`.
    """

    # Window the dataframe that was passed in (not a notebook global).
    test_data = batch_data(df, config)
    test_dataset = tft_dataset(test_data)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, num_workers=0, shuffle=False)

    quantile_keys = ['p{}'.format(int(q * 100)) for q in config.quantiles]
    chunks = {key: [] for key in quantile_keys}

    time = test_data['time']
    identifier = test_data['identifier']
    outputs = test_data['outputs']

    # Inference mode: evaluation behaviour for layers such as dropout, and no
    # autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        for batch_item in tqdm(test_dataloader):
            inputs = batch_item['inputs'].to(device)

            combined = model(inputs)['outputs'].detach().cpu()

            # The last axis is read as consecutive blocks of
            # `config.output_size` values, one block per quantile.
            for i, key in enumerate(quantile_keys):
                chunks[key].append(
                    combined[Ellipsis, i * config.output_size:(i + 1) * config.output_size])

    # One numpy array per quantile, batches joined along the sample axis.
    process_map = {key: torch.cat(parts, dim=0).numpy() for key, parts in chunks.items()}

    # Format output_csv

    def format_outputs(prediction):
        """Returns formatted dataframes for prediction."""

        flat_prediction = pd.DataFrame(
            prediction[:, :, 0],
            columns=[
                't+{}'.format(i+1)
                for i in range(config.total_time_steps - config.num_encoder_steps)
            ])
        cols = list(flat_prediction.columns)
        # The forecast is stamped with the last encoder step of its window.
        flat_prediction['forecast_time'] = time[:, config.num_encoder_steps - 1, 0]
        flat_prediction['identifier'] = identifier[:, 0, 0]

        # Arrange in order
        return flat_prediction[['forecast_time', 'identifier'] + cols]

    if return_targets:
        # Add targets if relevant
        process_map['targets'] = outputs

    return {k: format_outputs(process_map[k]) for k in process_map}

In [None]:
# Batch size read by predict() as a notebook-level variable.
# NOTE(review): same value as the one set in the DataLoader cell.
batch_size = 64

In [None]:
# Predict on the test split and keep the aligned targets for comparison.
result = predict(test, return_targets=True)

In [None]:
# One entry per quantile plus 'targets'.
result.keys()

In [None]:
# Forecasts for the 0.5 quantile.
result['p50']

In [None]:
# Target values laid out in the same format as the forecasts.
result['targets']

In [None]:
# Spot check: log_vol of symbol .AEX on 2018-12-28 in the test split
# (presumably to compare by hand with the tables above).
test[(test['date']=='2018-12-28')&(test['Symbol']=='.AEX')]['log_vol']

In [None]:
# Spot check: log_vol of symbol .AEX on 2018-12-31 in the test split.
test[(test['date']=='2018-12-31')&(test['Symbol']=='.AEX')]['log_vol']

In [None]:
# Spot check: log_vol of symbol .AEX on 2019-01-02 in the test split.
test[(test['date']=='2019-01-02')&(test['Symbol']=='.AEX')]['log_vol']

In [None]:
# Spot check: log_vol of symbol .AEX on 2019-01-03 in the test split.
test[(test['date']=='2019-01-03')&(test['Symbol']=='.AEX')]['log_vol']