In [None]:
#default_exp hierarchical

In [None]:
#hide
%load_ext autoreload
%autoreload 2

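# Hierarchical Time Series Forecasting

> Aggregate bottom-level series along user-defined key hierarchies and forecast every aggregation level with `StatsForecast`.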

In [None]:
#hide
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

In [None]:
#hide
from fastcore.test import test_eq

In [None]:
#export
from typing import Callable, List, Optional

import numpy as np
import pandas as pd

from statsforecast.core import StatsForecast

In [102]:
#export
def _aggregate_key(df: pd.DataFrame, keys: List[List[str]], agg_fn: Callable = np.sum):
    """Aggregates `df` according to `keys` using `agg_fn`."""
    df = df.copy()
    df_keys = []
    orig_cols = df.drop(labels=['ds', 'y'], axis=1).columns.to_list()
    for key in keys:
        if key == ['total']:
            df = df.assign(total='total')
        df_key = df.groupby(key + ['ds'])['y'].apply(agg_fn).reset_index()
        df_key['unique_id'] = df_key[key].agg('_'.join, axis=1)
        df_keys.append(df_key)
    df_keys = pd.concat(df_keys)
    s_df = df_keys[['unique_id'] + orig_cols].drop_duplicates()
    y_df = df_keys[['unique_id', 'ds', 'y']].set_index('unique_id')
    return s_df, y_df

In [105]:
#hide
# Smoke test of `_aggregate_key` on the public tourism dataset (needs network access).
df = (
    pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/tourism.csv')
    .rename(columns={'Trips': 'y', 'Quarter': 'ds'})
)
tourism_keys = [
    ['total'],
    ['State'],
    ['Purpose'],
    ['State', 'Region'],
    ['State', 'Purpose'],
    ['State', 'Region', 'Purpose'],
]
s_df, y_df = _aggregate_key(df, tourism_keys)
# 425 distinct series across all levels; 34_000 / 425 = 80 observations each.
test_eq(len(y_df), 34_000)
test_eq(y_df.index.nunique(), 425)

In [106]:
#export
class HierarchicalStatsForecast:
    
    def __init__(self, df: pd.DataFrame, keys: List[List[str]], 
                 models: List, freq: str, 
                 n_jobs: int = 1, ray_address: Optional[str] = None):
        s_df, y_df = _aggregate_key(df, keys=keys)[['unique_id', 'ds', 'y']]
        self.fcst = StatsForecast(df=y_df, model=models, freq=freq, 
                                  n_jobs=n_jobs, ray_address=ray_address)