In [None]:
#| default_exp models.fcgaga

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

# FC-GAGA

The FC-GAGA architecture is a multivariate time series forecasting model that combines a fully-connected, NBEATS-like univariate forecasting model with a hard graph gate mechanism. The FC-GAGA method achieved state-of-the-art performance on two traffic forecasting datasets.

**References**<br>
-[FC-GAGA Original Tensorflow implementation.](https://github.com/boreshkinai/fc-gaga/blob/master/model.py)<br>
-[Boris N. Oreshkin, Arezou Amini, Lucy Coyle, Mark J. Coates (2021). "FC-GAGA: Fully Connected Gated Graph Architecture for Spatio-Temporal Traffic Forecasting". The Association for the Advancement of Artificial Intelligence Conference 2021 (AAAI 2021).](https://arxiv.org/pdf/2007.15531)<br>

In [None]:
#| hide
import logging
import warnings
from fastcore.test import test_eq
from nbdev.showdoc import show_doc
from neuralforecast.common._model_checks import check_model

In [None]:
#| export
from typing import Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from neuralforecast.losses.pytorch import MAE
from neuralforecast.common._base_model import BaseModel

In [None]:
#| hide
from fastcore.test import test_eq
from nbdev.showdoc import show_doc
from neuralforecast.utils import generate_series

import matplotlib.pyplot as plt

In [None]:
#| exporti
def _divide_no_nan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Auxiliary funtion to handle divide by 0
    """
    div = a / b
    div[div != div] = 0.0
    div[div == float('inf')] = 0.0
    return div

In [None]:
#| exporti
class FcBlock(nn.Module):
    """
    Fully connected block producing a backcast and a forecast.

    The input goes through `block_layers` linear layers, each followed by a
    GELU activation. Two linear heads are then applied to the last hidden
    representation: `forecast` (size `output_size`) and `backcast` (size
    `input_size`). The returned backcast is `gelu(inputs - backcast_head)`.

    **Parameters:**<br>
    `block_layers`: int, number of hidden linear layers (must be >= 1).<br>
    `hidden_units`: int, width of every hidden layer.<br>
    `input_size`: int, size of the last dimension of the block input.<br>
    `output_size`: int, size of the last dimension of the forecast.<br>
    `device`: optional, forwarded to every `nn.Linear` constructor.<br>

    **Raises:**<br>
    `ValueError`: if `block_layers` is smaller than 1.
    """
    def __init__(self,
                 block_layers: int,
                 hidden_units: int,
                 input_size: int,
                 output_size: int,
                 device=None):
        super().__init__()
        # Fail early: the heads below consume `hidden_units` features, which
        # only exist if at least one hidden layer is present.
        if block_layers < 1:
            raise ValueError(f"block_layers must be >= 1, got {block_layers}.")

        self.input_size = input_size
        self.output_size = output_size
        self.block_layers = block_layers

        # First layer maps input_size -> hidden_units, the rest hidden -> hidden.
        layer_in_features = [input_size] + [hidden_units] * (block_layers - 1)
        self.fc_layers = nn.ModuleList([
            nn.Linear(in_features=in_features,
                      out_features=hidden_units,
                      device=device)
            for in_features in layer_in_features
        ])

        # Forecast and backcast heads
        self.forecast = nn.Linear(hidden_units, output_size, device=device)
        self.backcast = nn.Linear(hidden_units, input_size, device=device)

    def forward(self, inputs):
        """Return `(backcast, forecast)` for `inputs` of shape `[..., input_size]`."""
        # Hidden layers with GELU activation
        h = inputs
        for layer in self.fc_layers:
            h = F.gelu(layer(h))

        # Residual-style backcast (input minus its reconstruction) and forecast
        backcast = F.gelu(inputs - self.backcast(h))
        forecast = self.forecast(h)
        return backcast, forecast


class FcGagaLayer(nn.Module):
    """
    One FC-GAGA stack: time gate, graph gate and a chain of `FcBlock`s.

    Shape legend used in the comments below:
    B batch, N = `n_series`, T = `input_size`, H = `h`,
    D = `outputsize_multiplier`, X = `hist_input_size`, S = `node_id_dim`.

    **Parameters:**<br>
    `input_size`: int, length T of the input window.<br>
    `h`: int, forecast horizon H.<br>
    `outputsize_multiplier`: int, number of outputs D per horizon step.<br>
    `hist_input_size`: int, number X of historic exogenous variables; the time gate consumes them flattened to `X*T` features.<br>
    `stat_input_size`: int, number of static exogenous features per series.<br>
    `n_series`: int, number N of series (graph nodes).<br>
    `n_blocks`: int, number of chained `FcBlock`s.<br>
    `block_layers`: int, hidden layers per `FcBlock`.<br>
    `hidden_units`: int, width of the time gate and of the `FcBlock` hidden layers.<br>
    `node_id_dim`: int, size S of the node embedding computed from the static features.<br>
    `epsilon`: float, factor applied to the embedding dot products before the exponential.<br>
    """
    def __init__(self,
                 input_size: int,
                 h,
                 outputsize_multiplier,
                 hist_input_size,
                 stat_input_size,
                 n_series: int,
                 n_blocks: int,
                 block_layers: int,
                 hidden_units: int,
                 node_id_dim: int,
                 epsilon: float=10):
        super().__init__()

        self.n_series = n_series
        self.input_size = input_size
        self.h = h
        self.outputsize_multiplier = outputsize_multiplier
        self.hist_input_size = hist_input_size
        self.stat_input_size = stat_input_size
        # Own history (T) + gated history of all N nodes (N*T) + node embedding (S)
        self.fcgaga_input_size = (self.n_series+1) * self.input_size + node_id_dim # [B,N,(N+1)*T+S]
        self.epsilon = epsilon

        # TODO: Avoid one_hot_encoding and try embeddings instead
        # self.node_id_em = nn.Embedding(num_embeddings=num_nodes, embedding_dim=node_id_dim)
        self.node_id_em = nn.Linear(in_features=stat_input_size, out_features=node_id_dim)

        # Time gate: shared trunk (time_gate1) with a forward head of size H
        # (time_gate2) and a backward head of size T (time_gate3).
        # NOTE(review): no activation is applied after time_gate1, so both gates
        # are affine in their input; the reference implementation appears to use
        # a ReLU on this layer -- confirm whether that omission is intended.
        self.time_gate1 = nn.Linear(in_features=(node_id_dim + self.input_size*hist_input_size),
                                    out_features=hidden_units)
        self.time_gate2 = nn.Linear(in_features=hidden_units,
                                    out_features=h)
        self.time_gate3 = nn.Linear(in_features=hidden_units,
                                    out_features=input_size)

        self.n_blocks = n_blocks
        self.blocks = torch.nn.ModuleList()
        for _ in range(n_blocks):
            self.blocks.append(FcBlock(block_layers=block_layers,
                                       hidden_units=hidden_units,
                                       input_size=self.fcgaga_input_size,
                                       output_size=h * outputsize_multiplier))

    def forward(self,
                insample_y: torch.Tensor,
                hist_exog: torch.Tensor,
                stat_exog: torch.Tensor,
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        **Parameters:**<br>
        `insample_y`: torch.Tensor, target history of shape [B,N,T].<br>
        `hist_exog`: torch.Tensor, flattened historic exogenous of shape [B,N,X*T].<br>
        `stat_exog`: torch.Tensor, static exogenous of shape [B,N,stat_input_size].<br>

        **Returns:**<br>
        `backcast`: torch.Tensor [B,N,(N+1)*T+S], backcast of the last block.<br>
        `forecast`: torch.Tensor [B,N,H,D], gated and re-scaled forecast.
        """
        node_id = self.node_id_em(stat_exog) # [B,N,stat_input_size]->[B,N,S]

        # ------------------------------------ Time Gate  -----------------------------------#
        gate_features = torch.cat([node_id, hist_exog], dim=2) # [B,N,S],[B,N,X*T]->[B,N,S+X*T]
        time_gate = self.time_gate1(gate_features)       # [B,N,S+X*T]->[B,N,hidden]
        time_gate_forward = self.time_gate2(time_gate)   # [B,N,hidden]->[B,N,H]
        time_gate_backward = self.time_gate3(time_gate)  # [B,N,hidden]->[B,N,T]

        insample_y = insample_y / (1.0 + time_gate_backward) # [B,N,T]

        # ----------------------------------- Graph Gate  -----------------------------------#
        # Only the first batch element is used to build the node weights.
        # NOTE(review): this assumes static features are identical across the batch -- confirm.
        node_embeddings = node_id[0,:,:] # [B,N,S]->[N,S]
        node_embeddings_dp = torch.einsum("ns,ms->nm", node_embeddings, node_embeddings) # [N,S]x[N,S]->[N,N]
        node_embeddings_dp = torch.exp(self.epsilon * node_embeddings_dp) # [N,N]

        level, _ = torch.max(insample_y, dim=-1, keepdim=True) # [B,N,T]->[B,N,1]

        # Receiving node n gets the history of EVERY node m, weighted by W[n,m]:
        #   all_node_history[b,n,m,t] = W[n,m] * insample_y[b,m,t]
        # (indexing insample_y by n here would only replicate node n's own history).
        all_node_history = torch.einsum("bmt,nm->bnmt", insample_y, node_embeddings_dp) # [B,N,N,T]
        all_node_history = torch.reshape(all_node_history, (-1, self.n_series, self.n_series * self.input_size)) # [B,N,N,T]->[B,N,N*T]
        # Express the neighbours' histories relative to the receiving node's level
        all_node_history = _divide_no_nan(all_node_history - level, level)
        all_node_history = F.gelu(all_node_history)

        history = _divide_no_nan(insample_y, level)
        history = torch.cat([history, all_node_history], dim=2) # [B,N,(N+1)*T]
        history = torch.cat([history, node_id], dim=2) # [B,N,(N+1)*T+S]

        # ------------------------------------ FC blocks ------------------------------------#
        # Each block consumes the previous block's backcast; forecasts are summed.
        backcast, forecast_out = self.blocks[0](history) # [B,N,(N+1)*T+S]->[B,N,(N+1)*T+S],[B,N,H*D]
        for i in range(1, self.n_blocks):
            backcast, forecast_block = self.blocks[i](backcast)
            forecast_out = forecast_out + forecast_block

        # [B,N,H*D]->[B,N,H,D]
        forecast_out = torch.reshape(forecast_out, (-1, self.n_series, self.h, self.outputsize_multiplier))
        # Undo the level normalisation, then apply the forward time gate
        forecast = forecast_out * level[:,:,:,None] # [B,N,H,D] * [B,N,1,1]
        forecast = forecast * (1.0 + time_gate_forward[:,:,:,None]) # [B,N,H,D] * [B,N,H,1]
        return backcast, forecast

In [None]:
# |hide
# Smoke test: one FcGagaLayer forward pass must return tensors of the expected shapes.

# Horizon, lookback window and number of outputs per horizon step
h, input_size, outputsize_multiplier = 7, 10, 2

# Exogenous feature counts and number of series (graph nodes)
hist_input_size, stat_input_size, n_series = 5, 3, 2

# Network size
n_blocks, block_layers, hidden_units, node_id_dim = 2, 2, 16, 8
epsilon = 10.0

# Build the layer under test
fcgaga_layer = FcGagaLayer(
    input_size=input_size,
    h=h,
    outputsize_multiplier=outputsize_multiplier,
    hist_input_size=hist_input_size,
    stat_input_size=stat_input_size,
    n_series=n_series,
    n_blocks=n_blocks,
    block_layers=block_layers,
    hidden_units=hidden_units,
    node_id_dim=node_id_dim,
    epsilon=epsilon,
)

# Random inputs; historic exogenous variables are flattened over the window
batch_size = 4
insample_y = torch.randn(batch_size, n_series, input_size)
hist_exog = torch.randn(batch_size, n_series, hist_input_size * input_size)
stat_exog = torch.randn(batch_size, n_series, stat_input_size)

# Forward pass
backcast, forecast = fcgaga_layer(
    insample_y=insample_y,
    hist_exog=hist_exog,
    stat_exog=stat_exog,
)

# Backcast spans the block input: own + all-node history plus the node embedding
assert tuple(backcast.shape) == (batch_size, n_series, (n_series + 1) * input_size + node_id_dim)
# Forecast carries one value per series, horizon step and loss output
assert tuple(forecast.shape) == (batch_size, n_series, h, outputsize_multiplier)

In [None]:
#| export
class FCGAGA(BaseModel):
    """ FCGAGA

    The FC-GAGA architecture is a multivariate time series forecasting model that combines
    a fully-connected, NBEATS-like univariate forecasting model with a hard graph gate
    mechanism. The FC-GAGA method achieved state-of-the-art performance on two traffic
    forecasting datasets.

    The model holds `n_stacks` independent `FcGagaLayer`s that all receive the same inputs;
    their forecasts are averaged.

    **Parameters:**<br>
    `h`: int, forecast horizon.<br>
    `input_size`: int, considered autoregressive inputs (lags), y=[1,2,3,4] input_size=2 -> lags=[1,2].<br>
    `futr_exog_list`: str list, future exogenous columns (forwarded to `BaseModel`; not read by this model's `forward`).<br>
    `hist_exog_list`: str list, historic exogenous columns.<br>
    `stat_exog_list`: str list, static exogenous columns.<br>
    `hidden_units`: int=128, Number of units of each hidden layer.<br>
    `n_s_hidden`: int=5, size of the node embedding computed from the static features.<br>
    `n_blocks`: int=2, Number of blocks within each stack.<br>
    `block_layers`: int=3, Number of hidden layers within each block.<br>
    `n_stacks`: int=3, Number of stacks of the network.<br>
    `epsilon`: float=10, factor applied to the node-embedding dot products in the graph gate.<br>
    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>
    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>
    `max_steps`: int=1000, maximum number of training steps.<br>
    `learning_rate`: float=1e-3, Learning rate between (0, 1).<br>
    `num_lr_decays`: int=3, Number of learning rate decays, evenly distributed across max_steps.<br>
    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.<br>
    `val_check_steps`: int=100, Number of training steps between every validation loss check.<br>
    `batch_size`: int=32, number of different series in each batch.<br>
    `valid_batch_size`: int=None, number of different series in each validation and test batch, if None uses batch_size.<br>
    `windows_batch_size`: int=1024, number of windows to sample in each training batch, default uses all.<br>
    `inference_windows_batch_size`: int=-1, number of windows to sample in each inference batch, -1 uses all.<br>
    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.<br>
    `step_size`: int=1, step size between each window of temporal data.<br>
    `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>
    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>
    `num_workers_loader`: int=0, workers to be used by `TimeSeriesDataLoader`.<br>
    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>
    `alias`: str, optional,  Custom name of the model.<br>
    `**trainer_kwargs`: int,  keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>

    **References:**<br>
    -[Boris N. Oreshkin, Arezou Amini, Lucy Coyle, Mark J. Coates (2021).
    "FC-GAGA: Fully Connected Gated Graph Architecture for Spatio-Temporal Traffic Forecasting". The Association for the Advancement of Artificial Intelligence Conference 2021 (AAAI 2021).](https://arxiv.org/pdf/2007.15531)
    """
    # Class attributes
    # NOTE(review): `forward` never reads future exogenous inputs -- confirm that
    # advertising EXOGENOUS_FUTR = True is intended.
    EXOGENOUS_FUTR = True
    EXOGENOUS_HIST = True
    EXOGENOUS_STAT = True    
    MULTIVARIATE = True     # If the model produces multivariate forecasts (True) or univariate (False)
    RECURRENT = False       # If the model produces forecasts recursively (True) or direct (False)
    
    def __init__(self,
                 h,
                 input_size,
                 futr_exog_list=None,
                 hist_exog_list=None,
                 stat_exog_list=None,
                 hidden_units: int=128,
                 n_s_hidden: int=5,
                 n_blocks: int=2,
                 block_layers: int=3,
                 n_stacks: int=3,
                 epsilon: float=10,
                 loss = MAE(),
                 valid_loss = None,
                 max_steps: int = 1000,
                 learning_rate: float = 1e-3,
                 num_lr_decays: int = 3,
                 early_stop_patience_steps: int =-1,
                 val_check_steps: int = 100,
                 batch_size: int = 32,
                 valid_batch_size: Optional[int] = None,
                 windows_batch_size: int = 1024,
                 inference_windows_batch_size: int = -1,
                 start_padding_enabled = False,
                 step_size: int = 1,
                 scaler_type: str ='identity',
                 random_seed: int = 1,
                 num_workers_loader: int = 0,
                 drop_last_loader: bool = False,
                 **trainer_kwargs):
        # Inherit BaseModel class
        super(FCGAGA, self).__init__(h=h,
                                     input_size=input_size,
                                     futr_exog_list=futr_exog_list,
                                     hist_exog_list=hist_exog_list,
                                     stat_exog_list=stat_exog_list,
                                     loss=loss,
                                     valid_loss=valid_loss,
                                     max_steps=max_steps,
                                     learning_rate=learning_rate,
                                     num_lr_decays=num_lr_decays,
                                     early_stop_patience_steps=early_stop_patience_steps,
                                     val_check_steps=val_check_steps,
                                     batch_size=batch_size,
                                     windows_batch_size=windows_batch_size,
                                     valid_batch_size=valid_batch_size,
                                     inference_windows_batch_size=inference_windows_batch_size,
                                     start_padding_enabled=start_padding_enabled,
                                     step_size=step_size,
                                     scaler_type=scaler_type,
                                     num_workers_loader=num_workers_loader,
                                     drop_last_loader=drop_last_loader,
                                     random_seed=random_seed,
                                     **trainer_kwargs)
        
        # Architecture
        # NOTE(review): presumably BaseModel normalises a `None` exog list to an
        # empty list before these `len` calls -- confirm.
        self.hist_input_size = len(self.hist_exog_list)
        self.stat_input_size = len(self.stat_exog_list)

        # One independent FcGagaLayer per stack
        self.stacks = torch.nn.ModuleList()
        self.n_stacks = n_stacks
        for _ in range(n_stacks):
            fcgaga_layer = FcGagaLayer(h=h, input_size=input_size,
                                       outputsize_multiplier=self.loss.outputsize_multiplier,
                                       hist_input_size=self.hist_input_size,
                                       stat_input_size=self.stat_input_size,
                                       n_series=self.n_series,
                                       n_blocks=n_blocks,
                                       block_layers=block_layers,
                                       hidden_units=hidden_units,
                                       node_id_dim=n_s_hidden,
                                       epsilon=epsilon)
            self.stacks.append(fcgaga_layer)

    def forward(self, windows_batch):
        """
        Average the forecasts of all stacks and map them to the loss domain.

        `windows_batch` must provide the keys 'insample_y', 'hist_exog' and 'stat_exog'.
        """
        # Parse windows_batch
        insample_y = windows_batch['insample_y']
        hist_exog  = windows_batch['hist_exog']
        stat_exog  = windows_batch['stat_exog']

        # Reshape data for FC-GAGA
        # NOTE(review): these are plain reshapes (no permute); they assume the incoming
        # tensors are laid out series-major as in the comments -- confirm against BaseModel.
        insample_y = torch.reshape(insample_y, (-1, self.n_series, self.input_size)) # [B*N,T] -> [B,N,T]
        hist_exog = torch.reshape(hist_exog, (-1, self.n_series, self.hist_input_size * self.input_size))  # [B*N,X,T] -> [B,N,X*T]
        stat_exog = torch.reshape(stat_exog, (-1, self.n_series, self.stat_input_size)) # [B*N,S] -> [B,N,S]

        # FC-GAGA's forward: every stack sees the same inputs, forecasts are summed
        _, forecast = self.stacks[0](insample_y=insample_y, hist_exog=hist_exog, stat_exog=stat_exog)
        for stack in self.stacks[1:]:
            _, stack_forecast = stack(insample_y=insample_y,
                                      hist_exog=hist_exog,
                                      stat_exog=stat_exog)
            forecast = forecast + stack_forecast
        # Exactly n_stacks forecasts were summed, so the mean divides by n_stacks
        forecast = forecast / self.n_stacks

        # Adapting output's domain, output tuple distribution parameters
        forecast = torch.reshape(forecast, (-1, self.h, self.loss.outputsize_multiplier)) # [B,N,H,D] -> [B*N,H,D]
        output = self.loss.domain_map(forecast)
        return output

In [None]:
# show_doc(FCGAGA)

In [None]:
# show_doc(FCGAGA.fit, name='FCGAGA.fit')

In [None]:
# show_doc(FCGAGA.predict, name='FCGAGA.predict')