# HierE2E Baseline

This notebook runs and evaluates HierE2E's baseline method predictions.

- It reads a preprocessed hierarchical dataset.
- It fits HierE2E's optimal reported configuration.
- It evaluates HierE2E's sCRPS and MSSE.

## References
- [GluonTS, DeepVARHierarchicalEstimator](https://ts.gluon.ai/stable/api/gluonts/gluonts.mx.model.deepvar_hierarchical.html?highlight=deepvarhierarchicalestimator#gluonts.mx.model.deepvar_hierarchical.DeepVARHierarchicalEstimator)
- [Syama Sundar Rangapuram, Lucien D Werner, Konstantinos,Benidis, Pedro Mercado, Jan Gasthaus, Tim Januschowski. (2021). End-to-End Learning of Coherent Probabilistic Forecasts for Hierarchical Time Series. Proceedings of the 38th International Conference on Machine Learning (ICML).](https://proceedings.mlr.press/v139/rangapuram21a.html)


<br>
You can run these experiments using GPU with Google Colab.

<a href="https://colab.research.google.com/github/Nixtla/hierarchicalforecast/blob/main/experiments/hierarchical_baselines/nbs/run_hiere2e.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install mxnet-cu112

In [2]:
import mxnet as mx

assert mx.context.num_gpus() > 0

In [3]:
%%capture
!pip install "gluonts[mxnet,pro]" # Install gluonts + mxnet-GPU
!pip install git+https://github.com/Nixtla/hierarchicalforecast.git
!pip install git+https://github.com/Nixtla/datasetsforecast.git@feat/favorita_dataset

In [4]:
import pydantic
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from gluonts.mx.trainer import Trainer
from gluonts.dataset.hierarchical import HierarchicalTimeSeries
from gluonts.mx.model.deepvar_hierarchical import DeepVARHierarchicalEstimator

from hierarchicalforecast.evaluation import scaled_crps, msse
from datasetsforecast.hierarchical import HierarchicalData, HierarchicalInfo

## Auxiliary Functions

In [5]:
class HierarchicalDataset(object):
    # Class with loading, processing and
    # prediction evaluation methods for hierarchical data

    available_datasets = ['Labour','Traffic',
                          'TourismSmall','TourismLarge','Wiki2',
                          'OldTraffic', 'OldTourismLarge']

    @staticmethod
    def _get_hierarchical_scrps(hier_idxs, Y, Yq_hat, q_to_pred):
        # We use the indexes obtained from the aggregation tags
        # to compute scaled CRPS across the hierarchy levels
        scrps_list = []
        for idxs in hier_idxs:
            y      = Y[idxs, :]
            yq_hat = Yq_hat[idxs, :, :]
            scrps  = scaled_crps(y, yq_hat, q_to_pred)
            scrps_list.append(scrps)
        return scrps_list

    @staticmethod
    def _get_hierarchical_msse(hier_idxs, Y, Y_hat, Y_train):
        # We use the indexes obtained from the aggregation tags
        # to compute scaled CRPS across the hierarchy levels
        msse_list = []
        for idxs in hier_idxs:
            y       = Y[idxs, :]
            y_hat   = Y_hat[idxs, :]
            y_train = Y_train[idxs, :]
            crps    = msse(y, y_hat, y_train)
            msse_list.append(crps)
        return msse_list

    @staticmethod
    def _sort_hier_df(Y_df, S_df):
        # NeuralForecast core, sorts unique_id lexicographically
        # deviating from S_df, this class matches S_df and Y_hat_df order.
        Y_df.unique_id = Y_df.unique_id.astype('category')
        Y_df.unique_id = Y_df.unique_id.cat.set_categories(S_df.index)
        Y_df = Y_df.sort_values(by=['unique_id', 'ds'])
        return Y_df

    @staticmethod
    def _nonzero_indexes_by_row(M):
        return [np.nonzero(M[row,:])[0] for row in range(len(M))]

    @staticmethod
    def load_process_data(dataset, directory='./data'):
        # Load data
        data_info = HierarchicalInfo[dataset]
        Y_df, S_df, tags = HierarchicalData.load(directory=directory,
                                                 group=dataset)

        # Parse and augment data
        Y_df['ds'] = pd.to_datetime(Y_df['ds'])
        Y_df = HierarchicalDataset._sort_hier_df(Y_df=Y_df, S_df=S_df)

        # Obtain indexes for plots and evaluation
        hier_levels = ['Overall'] + list(tags.keys())
        hier_idxs = [np.arange(len(S_df))] +\
            [S_df.index.get_indexer(tags[level]) for level in list(tags.keys())]
        hier_linked_idxs = HierarchicalDataset._nonzero_indexes_by_row(S_df.values.T)

        # Final output
        data = dict(Y_df=Y_df, S_df=S_df, tags=tags,
                    # Hierarchical idxs
                    hier_idxs=hier_idxs,
                    hier_levels=hier_levels,
                    hier_linked_idxs=hier_linked_idxs,
                    # Dataset Properties
                    horizon=data_info.papers_horizon,
                    freq=data_info.freq,
                    seasonality=data_info.seasonality)
        return data

In [6]:
# Optimal parameters reported from ICML 2021 code
configs = {"Labour": {"epochs": 50, "num_batches_per_epoch": 50, "scaling": True, "pick_incomplete": False, "batch_size": 32, "num_parallel_samples": 200, "hybridize": False, "learning_rate": 0.001, "context_length": 24, "rank": 0, "assert_reconciliation": False, "num_deep_models": 1, "num_layers": 2, "num_cells": 40, "coherent_train_samples": True, "coherent_pred_samples": True, "likelihood_weight": 0.0, "CRPS_weight": 1.0, "num_samples_for_loss": 200, "sample_LH": False, "rec_weight": 0.0, "seq_axis": [1], "warmstart_epoch_frac": 0.1},
"OldTraffic": {"epochs": 50, "num_batches_per_epoch": 50, "scaling": True, "pick_incomplete": False, "batch_size": 32, "num_parallel_samples": 200, "hybridize": False, "learning_rate": 0.001, "context_length": 40, "rank": 0, "assert_reconciliation": False, "num_deep_models": 1, "num_layers": 2, "num_cells": 40, "coherent_train_samples": True, "coherent_pred_samples": True, "likelihood_weight": 1.0, "CRPS_weight": 0.0, "num_samples_for_loss": 50, "sample_LH": True, "seq_axis": [1], "warmstart_epoch_frac": 0.1},
"TourismSmall": {"epochs": 10, "num_batches_per_epoch": 50, "scaling": True, "pick_incomplete": True, "batch_size": 32, "num_parallel_samples": 200, "hybridize": False, "learning_rate": 0.001, "context_length": 24, "rank": 0, "assert_reconciliation": False, "num_deep_models": 1, "num_layers": 2, "num_cells": 40, "coherent_train_samples": True, "coherent_pred_samples": True, "likelihood_weight": 1.0, "CRPS_weight": 0.0, "num_samples_for_loss": 50, "sample_LH": True, "seq_axis": [], "warmstart_epoch_frac": 0.0},
"OldTourismLarge": {"epochs": 40, "num_batches_per_epoch": 50, "scaling": True, "pick_incomplete": False, "batch_size": 4, "num_parallel_samples": 200, "hybridize": False, "learning_rate": 0.001, "context_length": 36, "rank": 0, "assert_reconciliation": False, "num_deep_models": 1, "num_layers": 2, "num_cells": 40, "coherent_train_samples": True, "coherent_pred_samples": True, "likelihood_weight": 1.0, "CRPS_weight": 0.0, "num_samples_for_loss": 50, "sample_LH": True, "seq_axis": [1], "warmstart_epoch_frac": 0.0},
"Wiki2": {"epochs": 50, "num_batches_per_epoch": 50, "scaling": True, "pick_incomplete": False, "batch_size": 32, "num_parallel_samples": 200, "hybridize": False, "learning_rate": 0.001, "context_length": 15, "rank": 0, "assert_reconciliation": False, "num_deep_models": 1, "num_layers": 2, "num_cells": 40, "coherent_train_samples": True, "coherent_pred_samples": True, "likelihood_weight": 0.0, "CRPS_weight": 1.0, "num_samples_for_loss": 100, "sample_LH": False, "rec_weight": 0.0, "seq_axis": [1], "warmstart_epoch_frac": 0.1}}

## Run HierE2E

In [7]:
def run_hiere2e(config, data):
    #------------------------- Declare DataLoaders ----------------------------#
    # Parse data and parameters
    bottom_cols = data['S_df'].columns
    S = data['S_df'].values

    Y_bottom_df = data['Y_df'].pivot(index='ds', columns='unique_id', values='y')
    Y_bottom_df = Y_bottom_df.loc[:, bottom_cols].to_period()

    # get tags and turn into dictionary
    tags = data['hier_idxs']
    tags = {k: v for k, v in enumerate(tags)}
    prediction_length = data['horizon']

    hts_train = HierarchicalTimeSeries(
        ts_at_bottom_level=Y_bottom_df.iloc[:-prediction_length, :],
        S=S)
    hts_test = HierarchicalTimeSeries(
        ts_at_bottom_level=Y_bottom_df.iloc[-prediction_length:, :],
        S=S,
    )

    #-------------------------- Fit/Predict HierE2E ---------------------------#
    dataset_train = hts_train.to_dataset()

    estimator = DeepVARHierarchicalEstimator(
        freq=data['freq'], # DeepVARHierarchicalEstimator cannot do 'Q' Freq
        prediction_length=prediction_length,
        target_dim=hts_train.num_ts,
        S=S,
        trainer=Trainer(ctx = mx.context.gpu(),
                        epochs=config['epochs'],
                        num_batches_per_epoch=config['num_batches_per_epoch'],
                        hybridize=config['hybridize'],
                        learning_rate=config['learning_rate']),
        scaling=config['scaling'],
        pick_incomplete=config['pick_incomplete'],
        batch_size=config['batch_size'],
        num_parallel_samples=config['num_parallel_samples'],
        context_length=config['context_length'],
        num_layers=config['num_layers'],
        num_cells=config['num_cells'],
        coherent_train_samples=config['coherent_train_samples'],
        coherent_pred_samples=config['coherent_pred_samples'],
        likelihood_weight=config['likelihood_weight'],
        CRPS_weight=config['CRPS_weight'],
        num_samples_for_loss=config['num_samples_for_loss'],
        sample_LH=config['sample_LH'],
        seq_axis=config['seq_axis'],
        warmstart_epoch_frac = config['warmstart_epoch_frac'],
    )

    predictor = estimator.train(dataset_train)
    forecast_it = predictor.predict(dataset_train)

    Y_hat = next(forecast_it).samples
    Y_hat = np.quantile(Y_hat, q=QUANTILES, axis=0)
    Y_hat = np.transpose(Y_hat, (2,1,0))

    Y_test = hts_test.ts_at_all_levels.values # [Q,T,n_series]->[n_series,T,Q]
    Y_test = np.transpose(Y_test, (1,0))

    Y_train = hts_train.ts_at_all_levels.values # [Q,T,n_series]->[n_series,T,Q]
    Y_train = np.transpose(Y_train, (1,0))

    return Y_hat, Y_test, Y_train

### Fit/Predict HierE2E

In [8]:
def run_hiere2e_bootstrap(config, data, n_seeds):
    Y_hat_list = []
    scrps_list = []
    msse_list  = []
    for seed in np.arange(n_seeds):
        print('\n')
        print(f'HierE2E execution {seed}')
        Y_hat, Y_test, Y_train = run_hiere2e(config=config, data=data)
        Y_hat_list.append(Y_hat)

        _scrps = HierarchicalDataset._get_hierarchical_scrps(
                                                    Y=Y_test,
                                                    Yq_hat=Y_hat,
                                                    q_to_pred=QUANTILES,
                                                    hier_idxs=data['hier_idxs'])

        _msse = HierarchicalDataset._get_hierarchical_msse(
                                                    Y=Y_test,
                                                    Y_hat=np.mean(Y_hat, axis=2),
                                                    Y_train=Y_train,
                                                    hier_idxs=data['hier_idxs'])
        scrps_list.append(np.array(_scrps)[:,None])
        msse_list.append(np.array(_msse)[:,None])

    scrps_all = np.concatenate(scrps_list, axis=1)
    msse_all = np.concatenate(scrps_list, axis=1)

    scrps_mean = pd.Series(np.round(np.mean(scrps_all, axis=1),4).astype(str))
    scrps_std = pd.Series((1.96 + np.round(np.std(scrps_all, axis=1),4)).astype(str))

    msse_mean = pd.Series(np.round(np.mean(msse_all, axis=1),4).astype(str))
    msse_std = pd.Series((1.96 + np.round(np.std(msse_all, axis=1),4)).astype(str))

    results_df = pd.DataFrame(dict(level=['Overall']+list(data['tags'].keys())))

    for seed in np.arange(n_seeds):
      results_df[f'scrps{seed}'] = scrps_all[:,seed]
    results_df['SCRPS'] = scrps_mean+'±'+scrps_std

    for seed in np.arange(n_seeds):
      results_df[f'msse{seed}'] = msse_all[:,seed]
    results_df['MSSE'] = msse_mean+'±'+msse_std
    return results_df

In [9]:
# %%capture
DATASET = 'OldTourismLarge' # 'OldTraffic', 'OldTourismLarge'
LEVEL = np.arange(0, 100, 2)
qs = [[50-lv/2, 50+lv/2] for lv in LEVEL]
QUANTILES = np.sort(np.concatenate(qs)/100)

config = configs[DATASET]
data = HierarchicalDataset.load_process_data(dataset=DATASET)

results_df = run_hiere2e_bootstrap(config=config, data=data, n_seeds=5)

100%|██████████| 1.30M/1.30M [00:00<00:00, 3.49MiB/s]
100%|██████████| 335k/335k [00:00<00:00, 4.69MiB/s]
100%|██████████| 968k/968k [00:00<00:00, 13.7MiB/s]




HierE2E execution 0


100%|██████████| 50/50 [00:14<00:00,  3.44it/s, epoch=1/40, avg_epoch_loss=3.4e+3] 
100%|██████████| 50/50 [00:14<00:00,  3.48it/s, epoch=2/40, avg_epoch_loss=3.09e+3]
100%|██████████| 50/50 [00:14<00:00,  3.44it/s, epoch=3/40, avg_epoch_loss=2.99e+3]
100%|██████████| 50/50 [00:15<00:00,  3.23it/s, epoch=4/40, avg_epoch_loss=2.94e+3]
100%|██████████| 50/50 [00:14<00:00,  3.43it/s, epoch=5/40, avg_epoch_loss=2.91e+3]
100%|██████████| 50/50 [00:14<00:00,  3.37it/s, epoch=6/40, avg_epoch_loss=2.89e+3]
100%|██████████| 50/50 [00:14<00:00,  3.47it/s, epoch=7/40, avg_epoch_loss=2.88e+3]
100%|██████████| 50/50 [00:14<00:00,  3.50it/s, epoch=8/40, avg_epoch_loss=2.87e+3]
100%|██████████| 50/50 [00:14<00:00,  3.49it/s, epoch=9/40, avg_epoch_loss=2.85e+3]
100%|██████████| 50/50 [00:15<00:00,  3.21it/s, epoch=10/40, avg_epoch_loss=2.85e+3]
100%|██████████| 50/50 [00:14<00:00,  3.50it/s, epoch=11/40, avg_epoch_loss=2.84e+3]
100%|██████████| 50/50 [00:14<00:00,  3.47it/s, epoch=12/40, avg_epoch_los



HierE2E execution 1


100%|██████████| 50/50 [00:14<00:00,  3.44it/s, epoch=1/40, avg_epoch_loss=3.43e+3]
100%|██████████| 50/50 [00:14<00:00,  3.49it/s, epoch=2/40, avg_epoch_loss=3.1e+3] 
100%|██████████| 50/50 [00:14<00:00,  3.48it/s, epoch=3/40, avg_epoch_loss=3e+3]   
100%|██████████| 50/50 [00:14<00:00,  3.48it/s, epoch=4/40, avg_epoch_loss=2.95e+3]
100%|██████████| 50/50 [00:14<00:00,  3.47it/s, epoch=5/40, avg_epoch_loss=2.91e+3]
100%|██████████| 50/50 [00:14<00:00,  3.49it/s, epoch=6/40, avg_epoch_loss=2.9e+3]
100%|██████████| 50/50 [00:14<00:00,  3.47it/s, epoch=7/40, avg_epoch_loss=2.89e+3]
100%|██████████| 50/50 [00:14<00:00,  3.50it/s, epoch=8/40, avg_epoch_loss=2.86e+3]
100%|██████████| 50/50 [00:14<00:00,  3.46it/s, epoch=9/40, avg_epoch_loss=2.86e+3]
100%|██████████| 50/50 [00:14<00:00,  3.43it/s, epoch=10/40, avg_epoch_loss=2.85e+3]
100%|██████████| 50/50 [00:14<00:00,  3.42it/s, epoch=11/40, avg_epoch_loss=2.84e+3]
100%|██████████| 50/50 [00:14<00:00,  3.38it/s, epoch=12/40, avg_epoch_loss



HierE2E execution 2


100%|██████████| 50/50 [00:14<00:00,  3.38it/s, epoch=1/40, avg_epoch_loss=3.42e+3]
100%|██████████| 50/50 [00:14<00:00,  3.34it/s, epoch=2/40, avg_epoch_loss=3.13e+3]
100%|██████████| 50/50 [00:16<00:00,  3.11it/s, epoch=3/40, avg_epoch_loss=3.02e+3]
100%|██████████| 50/50 [00:14<00:00,  3.48it/s, epoch=4/40, avg_epoch_loss=2.96e+3]
100%|██████████| 50/50 [00:14<00:00,  3.46it/s, epoch=5/40, avg_epoch_loss=2.93e+3]
100%|██████████| 50/50 [00:14<00:00,  3.49it/s, epoch=6/40, avg_epoch_loss=2.91e+3]
100%|██████████| 50/50 [00:14<00:00,  3.46it/s, epoch=7/40, avg_epoch_loss=2.9e+3]
100%|██████████| 50/50 [00:14<00:00,  3.49it/s, epoch=8/40, avg_epoch_loss=2.88e+3]
100%|██████████| 50/50 [00:14<00:00,  3.46it/s, epoch=9/40, avg_epoch_loss=2.87e+3]
100%|██████████| 50/50 [00:14<00:00,  3.47it/s, epoch=10/40, avg_epoch_loss=2.86e+3]
100%|██████████| 50/50 [00:14<00:00,  3.47it/s, epoch=11/40, avg_epoch_loss=2.85e+3]
100%|██████████| 50/50 [00:14<00:00,  3.44it/s, epoch=12/40, avg_epoch_loss



HierE2E execution 3


100%|██████████| 50/50 [00:14<00:00,  3.43it/s, epoch=1/40, avg_epoch_loss=3.37e+3]
100%|██████████| 50/50 [00:14<00:00,  3.47it/s, epoch=2/40, avg_epoch_loss=3.12e+3]
100%|██████████| 50/50 [00:14<00:00,  3.46it/s, epoch=3/40, avg_epoch_loss=3.05e+3]
100%|██████████| 50/50 [00:14<00:00,  3.43it/s, epoch=4/40, avg_epoch_loss=2.98e+3]
100%|██████████| 50/50 [00:14<00:00,  3.49it/s, epoch=5/40, avg_epoch_loss=2.94e+3]
100%|██████████| 50/50 [00:14<00:00,  3.41it/s, epoch=6/40, avg_epoch_loss=2.92e+3]
100%|██████████| 50/50 [00:15<00:00,  3.16it/s, epoch=7/40, avg_epoch_loss=2.9e+3]
100%|██████████| 50/50 [00:14<00:00,  3.42it/s, epoch=8/40, avg_epoch_loss=2.89e+3]
100%|██████████| 50/50 [00:14<00:00,  3.41it/s, epoch=9/40, avg_epoch_loss=2.87e+3]
100%|██████████| 50/50 [00:14<00:00,  3.37it/s, epoch=10/40, avg_epoch_loss=2.86e+3]
100%|██████████| 50/50 [00:14<00:00,  3.39it/s, epoch=11/40, avg_epoch_loss=2.85e+3]
100%|██████████| 50/50 [00:14<00:00,  3.47it/s, epoch=12/40, avg_epoch_loss



HierE2E execution 4


100%|██████████| 50/50 [00:14<00:00,  3.36it/s, epoch=1/40, avg_epoch_loss=3.39e+3]
100%|██████████| 50/50 [00:14<00:00,  3.37it/s, epoch=2/40, avg_epoch_loss=3.12e+3]
100%|██████████| 50/50 [00:14<00:00,  3.34it/s, epoch=3/40, avg_epoch_loss=3.02e+3]
100%|██████████| 50/50 [00:14<00:00,  3.42it/s, epoch=4/40, avg_epoch_loss=2.96e+3]
100%|██████████| 50/50 [00:14<00:00,  3.48it/s, epoch=5/40, avg_epoch_loss=2.92e+3]
100%|██████████| 50/50 [00:14<00:00,  3.43it/s, epoch=6/40, avg_epoch_loss=2.91e+3]
100%|██████████| 50/50 [00:14<00:00,  3.44it/s, epoch=7/40, avg_epoch_loss=2.9e+3] 
100%|██████████| 50/50 [00:14<00:00,  3.47it/s, epoch=8/40, avg_epoch_loss=2.89e+3]
100%|██████████| 50/50 [00:15<00:00,  3.18it/s, epoch=9/40, avg_epoch_loss=2.88e+3]
100%|██████████| 50/50 [00:14<00:00,  3.48it/s, epoch=10/40, avg_epoch_loss=2.86e+3]
100%|██████████| 50/50 [00:14<00:00,  3.46it/s, epoch=11/40, avg_epoch_loss=2.86e+3]
100%|██████████| 50/50 [00:14<00:00,  3.49it/s, epoch=12/40, avg_epoch_los

In [10]:
results_df

Unnamed: 0,level,scrps0,scrps1,scrps2,scrps3,scrps4,SCRPS,msse0,msse1,msse2,msse3,msse4,MSSE
0,Overall,0.256225,0.26863,0.200436,0.232265,0.229295,0.2374±1.9836,0.256225,0.26863,0.200436,0.232265,0.229295,0.2374±1.9836
1,Country,0.210239,0.217909,0.133769,0.182945,0.180083,0.185±1.9896,0.210239,0.217909,0.133769,0.182945,0.180083,0.185±1.9896
2,Country/State,0.21562,0.226534,0.150788,0.19018,0.191497,0.1949±1.9861,0.21562,0.226534,0.150788,0.19018,0.191497,0.1949±1.9861
3,Country/State/Zone,0.233732,0.249884,0.180242,0.218894,0.215577,0.2197±1.9832,0.233732,0.249884,0.180242,0.218894,0.215577,0.2197±1.9832
4,Country/State/Zone/Region,0.26171,0.280988,0.215874,0.247794,0.245319,0.2503±1.9814,0.26171,0.280988,0.215874,0.247794,0.245319,0.2503±1.9814
5,Country/Purpose,0.227422,0.232096,0.159118,0.188818,0.189245,0.1993±1.9871999999999999,0.227422,0.232096,0.159118,0.188818,0.189245,0.1993±1.9871999999999999
6,Country/State/Purpose,0.249511,0.262526,0.194803,0.220073,0.217605,0.2289±1.9842,0.249511,0.262526,0.194803,0.220073,0.217605,0.2289±1.9842
7,Country/State/Zone/Purpose,0.296482,0.309697,0.252556,0.274879,0.267839,0.2803±1.9804,0.296482,0.309697,0.252556,0.274879,0.267839,0.2803±1.9804
8,Country/State/Zone/Region/Purpose,0.355083,0.369403,0.31634,0.334541,0.327194,0.3405±1.9792,0.355083,0.369403,0.31634,0.334541,0.327194,0.3405±1.9792
