In [None]:
#| default_exp models.gmm_tft

# GMM TFT

> Combination of the Temporal Fusion Transformer (TFT) with the `GMM` mixture loss.

In [None]:
#| export
import torch
import torch.nn as nn

from neuralforecast.models.tft import TFT
from neuralforecast.losses.pytorch import MAE

In [None]:
#| hide
from fastcore.test import test_eq
from nbdev.showdoc import show_doc

In [None]:
#| hide
import logging
import warnings
# Keep rendered notebook output short: raise the `pytorch_lightning` logger
# threshold to ERROR and install a blanket "ignore" warnings filter.
logging.getLogger(name="pytorch_lightning").setLevel(level=logging.ERROR)
warnings.filterwarnings(action="ignore")

In [None]:
#| export
class GMM_TFT(TFT):
    """TFT variant whose output layer emits `K` mixture-component means.

    The class reuses `TFT` as is, except that `output_adapter` is replaced by
    a linear layer mapping `hidden_size` features to `K` outputs. In
    `training_step` and `predict_step` the network output is treated as the
    component means, while the standard deviations and the weights are fixed
    tensors filled with `1/K`.

    **Parameters:**<br>
    `h`: forecast horizon, forwarded to `TFT`.<br>
    `input_size`: forwarded to `TFT`.<br>
    `K`: int, number of mixture components, i.e. the `out_features` of
    `output_adapter`.<br>
    `tgt_size`, `hidden_size`: forwarded to `TFT`; `hidden_size` is also the
    `in_features` of `output_adapter`.<br>
    `s_cont_cols`, `s_cat_cols`, `o_cont_cols`, `o_cat_cols`, `k_cont_cols`,
    `k_cat_cols`: column selections forwarded to `TFT`.<br>
    `s_cat_inp_lens`, `s_cont_inp_size`, `k_cat_inp_lens`, `k_cont_inp_size`,
    `o_cat_inp_lens`, `o_cont_inp_size`: input sizes forwarded to `TFT`.<br>
    `n_head`, `attn_dropout`, `dropout`: forwarded to `TFT`.<br>
    `loss`: loss module forwarded to `TFT`. `training_step` calls it with the
    keyword arguments `y`, `means`, `stds`, `weights` and `mask`, and
    `predict_step` calls its `sample` method, so it must support that
    interface (e.g. `neuralforecast.losses.pytorch.GMM`).
    NOTE(review): the default `MAE()` is kept for signature compatibility and
    presumably does not accept these arguments -- confirm, and pass a mixture
    loss explicitly.<br>
    `learning_rate`, `batch_size`, `windows_batch_size`, `step_size`,
    `scaler_type`, `num_workers_loader`, `drop_last_loader`, `random_seed`:
    forwarded to `TFT`.<br>
    `**trainer_kwargs`: extra keyword arguments forwarded to `TFT`.<br>
    """
    def __init__(self,
                 h,
                 input_size,
                 K=100,
                 tgt_size=1,
                 hidden_size=128,
                 s_cont_cols=None,
                 s_cat_cols=None,
                 o_cont_cols=None,
                 o_cat_cols=None,
                 k_cont_cols=None,
                 k_cat_cols=None,
                 s_cat_inp_lens=None,
                 s_cont_inp_size=0,
                 k_cat_inp_lens=None,
                 k_cont_inp_size=1,
                 o_cat_inp_lens=None,
                 o_cont_inp_size=0,
                 n_head=4,
                 attn_dropout=0.0,
                 dropout=0.1,
                 loss=MAE(),
                 learning_rate=1e-3,
                 batch_size=32,
                 windows_batch_size=1024,
                 step_size=1,
                 scaler_type='robust',
                 num_workers_loader=0,
                 drop_last_loader=False,
                 random_seed=1,
                 **trainer_kwargs):
        # Build the plain TFT first; only the output layer is swapped below.
        super().__init__(h=h,
                         input_size=input_size,
                         tgt_size=tgt_size,
                         hidden_size=hidden_size,
                         s_cont_cols=s_cont_cols,
                         s_cat_cols=s_cat_cols,
                         o_cont_cols=o_cont_cols,
                         o_cat_cols=o_cat_cols,
                         k_cont_cols=k_cont_cols,
                         k_cat_cols=k_cat_cols,
                         s_cat_inp_lens=s_cat_inp_lens,
                         s_cont_inp_size=s_cont_inp_size,
                         k_cat_inp_lens=k_cat_inp_lens,
                         k_cont_inp_size=k_cont_inp_size,
                         o_cat_inp_lens=o_cat_inp_lens,
                         o_cont_inp_size=o_cont_inp_size,
                         n_head=n_head,
                         attn_dropout=attn_dropout,
                         dropout=dropout,
                         loss=loss,
                         learning_rate=learning_rate,
                         batch_size=batch_size,
                         windows_batch_size=windows_batch_size,
                         step_size=step_size,
                         scaler_type=scaler_type,
                         num_workers_loader=num_workers_loader,
                         drop_last_loader=drop_last_loader,
                         random_seed=random_seed,
                         **trainer_kwargs)

        # Number of mixture components.
        self.K = K

        # Replace the inherited adapter so the network emits K values
        # (one mean per mixture component) instead of the TFT default.
        self.output_adapter = nn.Linear(in_features=hidden_size, out_features=K)

    def _uniform_mixture_params(self, means_hat):
        """Return `(stds, weights)`: two tensors shaped like `means_hat`, filled with `1/K`."""
        # `ones_like` already allocates with the dtype and device of
        # `means_hat`, so no explicit `.to(device)` is required.
        stds = (1 / self.K) * torch.ones_like(means_hat)
        weights = (1 / self.K) * torch.ones_like(means_hat)
        return stds, weights

    def training_step(self, batch, batch_idx):
        """Compute, log (as `'train_loss'`) and return the loss for one batch.

        Deviates from the original `BaseWindows.training_step` to allow the
        model to receive future exogenous variables available at the time of
        the prediction: the full window is passed to the forward call.
        """
        # Create windows [Ws, L+H, C]
        windows = self._create_windows(batch, step='train')

        # Normalize windows
        if self.scaler is not None:
            windows = self._normalization(windows=windows)

        # Targets and availability mask: last `h` steps of each window.
        y_idx = batch['temporal_cols'].get_loc('y')
        mask_idx = batch['temporal_cols'].get_loc('available_mask')
        outsample_y = windows['temporal'][:, -self.h:, y_idx]
        outsample_mask = windows['temporal'][:, -self.h:, mask_idx]

        # Component means [Ws, H, K]; stds and weights are fixed at 1/K.
        means_hat = self(x=windows)
        stds, weights = self._uniform_mixture_params(means_hat)

        # A trailing axis is added to `y` and `mask` before they are handed
        # to the loss together with the mixture parameters.
        loss = self.loss(y=outsample_y[:,:,None], means=means_hat,
                         stds=stds, weights=weights,
                         mask=outsample_mask[:,:,None])
        self.log('train_loss', loss, prog_bar=True, on_epoch=True)
        return loss

    def predict_step(self, batch, batch_idx):
        """Return the second output of `self.loss.sample` for one batch.

        Deviates from the original `BaseWindows.predict_step` to allow the
        model to receive future exogenous variables available at the time of
        the prediction. The result is inverse-normalized when a scaler is set.
        """
        # Create windows [Ws, L+H, C]
        windows = self._create_windows(batch, step='predict')

        # Normalize windows
        if self.scaler is not None:
            windows = self._normalization(windows=windows)

        # Component means [Ws, H, K]; stds and weights are fixed at 1/K.
        means_hat = self(x=windows)
        stds, weights = self._uniform_mixture_params(means_hat)

        # Only the second value returned by `sample` is kept.
        _, quants = self.loss.sample(weights=weights,
                                     means=means_hat, stds=stds,
                                     num_samples=2000)

        # Inv Normalize
        if self.scaler is not None:
            quants = self._inv_normalization(y_hat=quants,
                                             temporal_cols=batch['temporal_cols'])

        return quants

In [None]:
# Render the API documentation entry for the GMM_TFT class.
show_doc(GMM_TFT)

In [None]:
# Document `fit` under this class's own name rather than the parent's.
show_doc(GMM_TFT.fit, name='GMM_TFT.fit')

In [None]:
# Document `predict` under this class's own name rather than the parent's.
show_doc(GMM_TFT.predict, name='GMM_TFT.predict')

In [None]:
#| hide
from fastcore.test import test_eq
from nbdev.showdoc import show_doc
from neuralforecast.utils import generate_series

## Usage Example

In [None]:
#| eval: false
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from neuralforecast.utils import AirPassengers, AirPassengersPanel
from neuralforecast.tsdataset import TimeSeriesDataset, TimeSeriesLoader

from neuralforecast.losses.pytorch import MQLoss, GMM

# Split on the timestamp found 12 rows from the end of the panel:
# rows before it are used for training, rows from it onwards for testing.
split_date = AirPassengersPanel['ds'].values[-12]
Y_train_df = AirPassengersPanel[AirPassengersPanel.ds < split_date] # 132 train
Y_test_df = AirPassengersPanel[AirPassengersPanel.ds >= split_date] # 12 test

# Fit on the full panel; `test_size=12` is passed to `fit` below.
dataset, *_ = TimeSeriesDataset.from_df(df=AirPassengersPanel)
model = GMM_TFT(h=12, input_size=48,
                hidden_size=100,
                k_cont_cols=['trend'], #['trend', 'y_[lag12]'],
                k_cont_inp_size=2,
                max_epochs=1,
                scaler_type='robust',
                loss=GMM(level=[80, 90]),
                windows_batch_size=None,
                enable_progress_bar=True)

model.fit(dataset=dataset, test_size=12)

# Parse quantile predictions: one column per name in `model.loss.output_names`,
# prefixed with 'TFT' and aligned on the test rows.
y_hat = model.predict(dataset=dataset)
Y_hat_df = pd.DataFrame.from_records(data=y_hat,
                columns=['TFT'+q for q in model.loss.output_names],
                index=Y_test_df.index)

# Assemble train history + test targets + predictions for a single series.
plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)
plot_df = plot_df[plot_df.unique_id=='Airline2']
plot_df = pd.concat([Y_train_df, plot_df])
plot_df = plot_df[plot_df.unique_id=='Airline2'].drop('unique_id', axis=1)

# Plot quantile predictions
fig, ax = plt.subplots()
ax.plot(plot_df['ds'], plot_df['y'], c='black', label='True')
ax.plot(plot_df['ds'], plot_df['TFT-median'], c='blue', label='median')
ax.fill_between(x=plot_df['ds'],
                y1=plot_df['TFT-lo-90'], y2=plot_df['TFT-hi-90'],
                alpha=0.4, label='level 90')
ax.set_title('Airline2')
ax.set_xlabel('ds')
ax.set_ylabel('y')
ax.grid()
ax.legend()
# `plt.show()` renders the figure; the previous bare `plt.plot()` drew nothing.
plt.show()