In [1]:
#default_exp hierarchical

In [2]:
#hide
%load_ext autoreload
%autoreload 2

# Hierarchical Time Series Forecasting

In [3]:
#hide
import warnings
# Ignore every warning of category FutureWarning from this point on
# (no restriction on message text or originating module).
warnings.filterwarnings(action='ignore', category=FutureWarning)

In [4]:
#hide
from fastcore.test import test_eq

In [5]:
#export
from typing import Callable, List, Optional, Tuple

import numpy as np
import pandas as pd

from statsforecast.core import StatsForecast

In [6]:
#export
def _aggregate_key(df: pd.DataFrame, keys: List[List[str]], agg_fn: Callable = np.sum):
    """Aggregates `df` according to `keys` using `agg_fn`."""
    df = df.copy()
    max_len_idx = np.argmax([len(key) for key in keys])
    bottom_comb = keys[max_len_idx]
    orig_cols = df.drop(labels=['ds', 'y'], axis=1).columns.to_list()
    df_keys = []
    for key in keys:
        if key == ['total']:
            df = df.assign(total='total')
        df_key = df.groupby(key + ['ds'])['y'].apply(agg_fn).reset_index()
        df_key['unique_id'] = df_key[key].agg('_'.join, axis=1)
        if key == bottom_comb:
            bottom_keys = df_key['unique_id'].unique()
        df_keys.append(df_key)
    df_keys = pd.concat(df_keys)
    s_df = df_keys[['unique_id'] + orig_cols].drop_duplicates().reset_index(drop=True)
    s_df = s_df.set_index('unique_id')
    y_df = df_keys[['unique_id', 'ds', 'y']].set_index('unique_id')
    #s_mat definition
    s_mat = np.zeros((len(s_df), len(bottom_keys)))
    for idx, label in enumerate(s_df.index, start=0):
        if label == 'total':
            s_mat[idx] = 1
        else:
            s_mat[idx, [label in bt for bt in bottom_keys]] = 1
    return s_df, s_mat, y_df

In [7]:
#hide
# Tourism data: `Quarter` becomes the time column `ds` and `Trips` the target `y`.
df = (
    pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/tourism.csv')
    .rename(columns={'Trips': 'y', 'Quarter': 'ds'})
)
s_df, s_mat, y_df = _aggregate_key(
    df,
    keys=[
        ['total'],
        ['State'],
        ['Purpose'],
        ['State', 'Region'],
        ['State', 'Purpose'],
        ['State', 'Region', 'Purpose'],
    ],
)
# Size checks: stacked observations, number of series, summing-matrix shape.
test_eq(len(y_df), 34_000)
test_eq(y_df.index.nunique(), 425)
test_eq(s_mat.shape, (425, 304))

In [8]:
#export
def bottom_up(y_hat: np.ndarray, s_mat: np.ndarray):
    """Bottom-up reconciliation of `y_hat` with the summing matrix `s_mat`.

    `y_hat` is 2-D with one row per series and one column per horizon step;
    `s_mat` has one row per series and one column per bottom series. The last
    `s_mat.shape[1]` rows of `y_hat` are selected and mapped through `s_mat`,
    so the result has the same shape as `y_hat`.
    """
    _, horizon = y_hat.shape
    total_series, bottom_series = s_mat.shape
    # Ones on the diagonal shifted right by the number of non-bottom rows:
    # multiplying by it keeps only the trailing `bottom_series` entries.
    selector = np.eye(bottom_series, total_series, k=total_series - bottom_series)
    projection = s_mat @ selector
    # One matrix-vector product per horizon step, stacked back as columns.
    per_step = [projection @ y_hat[:, step] for step in range(horizon)]
    return np.vstack(per_step).T

In [26]:
#export
class HierarchicalStatsForecast:
    """Forecast every level of a hierarchy and reconcile the results.

    The constructor aggregates `df` over `keys` with `_aggregate_key`, keeps
    the resulting series description (`s_df`) and summing matrix (`s_mat`),
    and wraps the aggregated series in a `StatsForecast` object (`fcst`)
    created with `sort_df=False`.
    """

    def __init__(self, df: pd.DataFrame, keys: List[List[str]], 
                 models: List, freq: str, 
                 n_jobs: int = 1, ray_address: Optional[str] = None):
        self.s_df, self.s_mat, y_df = _aggregate_key(df, keys=keys)
        self.fcst = StatsForecast(df=y_df, models=models, freq=freq, 
                                  n_jobs=n_jobs, ray_address=ray_address,
                                  sort_df=False)

    def _reconcile(self, base_fcsts: np.ndarray, reconcile_fns: List[Callable]):
        """Yield `(function name, flattened result)` for each reconciliation function.

        Every function receives the same `base_fcsts`; the output of one
        function is never fed into the next one.
        """
        for reconcile_fn in reconcile_fns:
            yield reconcile_fn.__name__, reconcile_fn(base_fcsts, self.s_mat).flatten()

    def forecast(self, h: int, reconcile_fns: Optional[List[Callable]] = None,
                 xreg: Optional[pd.DataFrame] = None, 
                 level: Optional[Tuple] = None):
        """Forecast `h` steps ahead and add reconciled columns.

        Parameters
        ----------
        h : int
            Forecast horizon. Each series must contribute `h` consecutive
            rows to the forecast frame, since values are reshaped to
            `(-1, h)` before reconciliation.
        reconcile_fns : list of callables, optional
            Each is called as `fn(base_forecasts, s_mat)` and must return an
            array whose flattened length equals the number of forecast rows.
            `None` (the default) means `[bottom_up]`.
        xreg, level
            Passed through unchanged to the wrapped `StatsForecast.forecast`.

        Returns
        -------
        The forecast frame of the wrapped object with one extra column
        `{reconcile_fn.__name__}_{column}` per reconciliation function and
        per forecast column. Every column other than `ds` is treated as a
        forecast column.
        """
        if reconcile_fns is None:
            reconcile_fns = [bottom_up]
        fcsts = self.fcst.forecast(h=h, xreg=xreg, level=level)
        model_names = fcsts.drop(columns=['ds']).columns.to_list()
        for model_name in model_names:
            base_fcsts = fcsts[model_name].values.reshape(-1, h)
            for fn_name, reconciled in self._reconcile(base_fcsts, reconcile_fns):
                fcsts[f'{fn_name}_{model_name}'] = reconciled
        return fcsts
    
    def cross_validation(self, h: int, test_size: int, 
                         input_size: Optional[int] = None,
                         reconcile_fns: Optional[List[Callable]] = None):
        """Cross-validate and add reconciled columns, one cutoff at a time.

        Parameters
        ----------
        h : int
            Forecast horizon of every window. Within one cutoff each series
            must contribute `h` consecutive rows, since values are reshaped
            to `(-1, h)` before reconciliation.
        test_size, input_size
            Passed through unchanged to the wrapped
            `StatsForecast.cross_validation`.
        reconcile_fns : list of callables, optional
            Each is called as `fn(base_forecasts, s_mat)`. `None` (the
            default) means `[bottom_up]`.

        Returns
        -------
        The cross-validation frame of the wrapped object with one extra
        column `{reconcile_fn.__name__}_{column}` per reconciliation function
        and per forecast column. Every column other than `ds`, `cutoff` and
        `y` is treated as a forecast column.
        """
        if reconcile_fns is None:
            reconcile_fns = [bottom_up]
        fcsts = self.fcst.cross_validation(h=h, test_size=test_size, input_size=input_size)
        model_names = fcsts.drop(columns=['ds', 'cutoff', 'y']).columns.to_list()
        cutoffs = fcsts['cutoff'].unique()
        for model_name in model_names:
            for cutoff in cutoffs:
                # Rows of the current window; reconciliation is done per window.
                cutoff_idx = fcsts['cutoff'] == cutoff
                base_fcsts = fcsts.loc[cutoff_idx, model_name].values.reshape(-1, h)
                for fn_name, reconciled in self._reconcile(base_fcsts, reconcile_fns):
                    fcsts.loc[cutoff_idx, f'{fn_name}_{model_name}'] = reconciled
        return fcsts

In [20]:
from statsforecast.models import naive

In [11]:
# Number the distinct `ds` labels 1..n in order of first appearance.
ds_int = df[['ds']].drop_duplicates().assign(ds_int=lambda uniq: np.arange(1, len(uniq) + 1))

In [12]:
# Attach the integer position to every row, then drop the original `ds` labels.
df = df.merge(ds_int, on='ds', how='left').drop(columns='ds')

In [13]:
# The integer position takes over the `ds` column name.
df = df.rename({'ds_int': 'ds'}, axis='columns')

In [27]:
# Hierarchy levels run from the grand total down to State/Region/Purpose.
hier_fcst = HierarchicalStatsForecast(
    df=df,
    keys=[
        ['total'],
        ['State'],
        ['Purpose'],
        ['State', 'Region'],
        ['State', 'Purpose'],
        ['State', 'Region', 'Purpose'],
    ],
    models=[naive],
    freq='D',
    n_jobs=-1,
)

In [30]:
# Horizon of 7 steps per window, with test_size=8.
hier_fcsts = hier_fcst.cross_validation(h=7, test_size=8)

In [32]:
# Show the rows of the series whose `unique_id` is "total": base `naive`
# forecasts next to their `bottom_up_naive` reconciled counterpart.
hier_fcsts.query('unique_id == "total"')

Unnamed: 0_level_0,ds,cutoff,y,naive,bottom_up_naive
unique_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
total,73,72,26660.636719,25140.162109,25140.161276
total,74,72,24285.027344,25140.162109,25140.161276
total,75,72,24191.320312,25140.162109,25140.161276
total,76,72,26347.601562,25140.162109,25140.161276
total,77,72,27496.388672,25140.162109,25140.161276
total,78,72,26113.607422,25140.162109,25140.161276
total,79,72,26506.314453,25140.162109,25140.161276
total,74,73,24285.027344,26660.636719,26660.637739
total,75,73,24191.320312,26660.636719,26660.637739
total,76,73,26347.601562,26660.636719,26660.637739
