From 765baa9dc23fb6ac92f1f8e361652184634f5f13 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Morales?=
Date: Tue, 17 Oct 2023 13:54:37 -0600
Subject: [PATCH 1/3] keep only observed categories in aggregate

---
 hierarchicalforecast/utils.py |  4 +++-
 nbs/utils.ipynb               | 16 +++++++++++++---
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/hierarchicalforecast/utils.py b/hierarchicalforecast/utils.py
index 5d5e635..c05b0eb 100644
--- a/hierarchicalforecast/utils.py
+++ b/hierarchicalforecast/utils.py
@@ -197,7 +197,9 @@ def aggregate(
     aggs = []
     tags = {}
     for levels in spec:
-        agg = df.groupby(levels + ['ds'])['y'].sum().reset_index('ds')
+        agg = df.groupby(levels + ['ds'], observed=True)['y'].sum().reset_index('ds')
+        if not agg.index.is_monotonic_increasing:
+            agg = agg.sort_index()
         group = agg.index.get_level_values(0)
         for level in levels[1:]:
             group = group + '/' + agg.index.get_level_values(level).str.replace('/', '_')
diff --git a/nbs/utils.ipynb b/nbs/utils.ipynb
index 1448c50..fdf60ef 100644
--- a/nbs/utils.ipynb
+++ b/nbs/utils.ipynb
@@ -310,7 +310,9 @@
     "    aggs = []\n",
     "    tags = {}\n",
     "    for levels in spec:\n",
-    "        agg = df.groupby(levels + ['ds'])['y'].sum().reset_index('ds')\n",
+    "        agg = df.groupby(levels + ['ds'], observed=True)['y'].sum().reset_index('ds')\n",
+    "        if not agg.index.is_monotonic_increasing:\n",
+    "            agg = agg.sort_index()\n",
     "        group = agg.index.get_level_values(0)\n",
     "        for level in levels[1:]:\n",
     "            group = group + '/' + agg.index.get_level_values(level).str.replace('/', '_')\n",
@@ -354,7 +356,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "075b8d76-b206-4ca6-8722-dd60e4c3b535",
+   "id": "82e70572-9c01-466d-a3e9-7667b92def2c",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -394,7 +396,15 @@
     "    'country/cat1/cat2': ['COUNTRY/a/1', 'COUNTRY/a/2', 'COUNTRY/a/3','COUNTRY/b/2'],\n",
     "}\n",
     "for k, actual in tags.items():\n",
-    "    test_eq(actual, expected_tags[k])"
+    "    test_eq(actual, expected_tags[k])\n",
+    "\n",
+    "# test categoricals don't produce all combinations\n",
+    "df2 = df.copy()\n",
+    "for col in ('cat1', 'cat2'):\n",
+    "    df2[col] = df2[col].astype('category')\n",
+    "\n",
+    "Y_df2, *_ = aggregate(df2, spec)\n",
+    "assert Y_df.shape[0] == Y_df2.shape[0]"
    ]
   },
   {

From be59903d367e15201b9a38959f04293342dcc269 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Morales?=
Date: Tue, 17 Oct 2023 13:57:13 -0600
Subject: [PATCH 2/3] sort full idx

---
 hierarchicalforecast/utils.py | 3 ++-
 nbs/utils.ipynb               | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/hierarchicalforecast/utils.py b/hierarchicalforecast/utils.py
index c05b0eb..a24b028 100644
--- a/hierarchicalforecast/utils.py
+++ b/hierarchicalforecast/utils.py
@@ -197,9 +197,10 @@ def aggregate(
     aggs = []
     tags = {}
     for levels in spec:
-        agg = df.groupby(levels + ['ds'], observed=True)['y'].sum().reset_index('ds')
+        agg = df.groupby(levels + ['ds'], observed=True)['y'].sum()
         if not agg.index.is_monotonic_increasing:
             agg = agg.sort_index()
+        agg = agg.reset_index('ds')
         group = agg.index.get_level_values(0)
         for level in levels[1:]:
             group = group + '/' + agg.index.get_level_values(level).str.replace('/', '_')
diff --git a/nbs/utils.ipynb b/nbs/utils.ipynb
index fdf60ef..9ebbb4c 100644
--- a/nbs/utils.ipynb
+++ b/nbs/utils.ipynb
@@ -310,9 +310,10 @@
     "    aggs = []\n",
     "    tags = {}\n",
     "    for levels in spec:\n",
-    "        agg = df.groupby(levels + ['ds'], observed=True)['y'].sum().reset_index('ds')\n",
+    "        agg = df.groupby(levels + ['ds'], observed=True)['y'].sum()\n",
    "        if not agg.index.is_monotonic_increasing:\n",
     "            agg = agg.sort_index()\n",
+    "        agg = agg.reset_index('ds')\n",
     "        group = agg.index.get_level_values(0)\n",
     "        for level in levels[1:]:\n",
     "            group = group + '/' + agg.index.get_level_values(level).str.replace('/', '_')\n",

From f939f4baf4e842c1cb16b2ca57eb90af1c563fd3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Morales?=
Date: Tue, 17 Oct 2023 15:39:46 -0600
Subject: [PATCH 3/3] more tests

---
 hierarchicalforecast/core.py  |  8 ++++----
 hierarchicalforecast/utils.py |  2 ++
 nbs/core.ipynb                | 25 ++++++++++++++++++++++++-
 nbs/utils.ipynb               |  4 +++-
 4 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/hierarchicalforecast/core.py b/hierarchicalforecast/core.py
index 1366795..d0e26d3 100644
--- a/hierarchicalforecast/core.py
+++ b/hierarchicalforecast/core.py
@@ -3,7 +3,7 @@
 # %% auto 0
 __all__ = ['HierarchicalReconciliation']
 
-# %% ../nbs/core.ipynb 3
+# %% ../nbs/core.ipynb 4
 import re
 import gc
 import time
@@ -17,7 +17,7 @@ import re
 import numpy as np
 import pandas as pd
 
-# %% ../nbs/core.ipynb 5
+# %% ../nbs/core.ipynb 6
 def _build_fn_name(fn) -> str:
     fn_name = type(fn).__name__
     func_params = fn.__dict__
@@ -37,7 +37,7 @@ def _build_fn_name(fn) -> str:
         fn_name += '_' + '_'.join(func_params)
     return fn_name
 
-# %% ../nbs/core.ipynb 9
+# %% ../nbs/core.ipynb 10
 def _reverse_engineer_sigmah(Y_hat_df, y_hat, model_name):
     """
     This function assumes that the model creates prediction intervals
@@ -73,7 +73,7 @@ def _reverse_engineer_sigmah(Y_hat_df, y_hat, model_name):
 
     return sigmah
 
-# %% ../nbs/core.ipynb 10
+# %% ../nbs/core.ipynb 11
 class HierarchicalReconciliation:
     """Hierarchical Reconciliation Class.
 
diff --git a/hierarchicalforecast/utils.py b/hierarchicalforecast/utils.py
index a24b028..61b507e 100644
--- a/hierarchicalforecast/utils.py
+++ b/hierarchicalforecast/utils.py
@@ -202,6 +202,8 @@ def aggregate(
             agg = agg.sort_index()
         agg = agg.reset_index('ds')
         group = agg.index.get_level_values(0)
+        if not pd.api.types.is_string_dtype(group.dtype):
+            group = group.astype(str)
         for level in levels[1:]:
             group = group + '/' + agg.index.get_level_values(level).str.replace('/', '_')
         agg.index = group
diff --git a/nbs/core.ipynb b/nbs/core.ipynb
index 8f6df67..f6e26f9 100644
--- a/nbs/core.ipynb
+++ b/nbs/core.ipynb
@@ -1,5 +1,16 @@
 {
  "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| hide\n",
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -579,7 +590,19 @@
     "\n",
     "# getting df\n",
     "hier_grouped_df, S_grouped_df, tags_grouped = aggregate(df, hierS_grouped_df)\n",
-    "hier_strict_df, S_strict, tags_strict = aggregate(df, hiers_strictly)"
+    "hier_strict_df, S_strict, tags_strict = aggregate(df, hiers_strictly)\n",
+    "\n",
+    "# check categorical input produces same output\n",
+    "df2 = df.copy()\n",
+    "for col in ['Country', 'State', 'Purpose', 'Region']:\n",
+    "    df2[col] = df2[col].astype('category')\n",
+    "\n",
+    "for spec in [hierS_grouped_df, hiers_strictly]:\n",
+    "    Y_orig, S_orig, tags_orig = aggregate(df, spec)\n",
+    "    Y_cat, S_cat, tags_cat = aggregate(df2, spec)\n",
+    "    pd.testing.assert_frame_equal(Y_cat, Y_orig)\n",
+    "    pd.testing.assert_frame_equal(S_cat, S_orig)\n",
+    "    assert all(np.array_equal(tags_orig[k], tags_cat[k]) for k in tags_orig.keys())"
    ]
   },
   {
diff --git a/nbs/utils.ipynb b/nbs/utils.ipynb
index 9ebbb4c..bfc7a76 100644
--- a/nbs/utils.ipynb
+++ b/nbs/utils.ipynb
@@ -315,6 +315,8 @@
     "            agg = agg.sort_index()\n",
     "        agg = agg.reset_index('ds')\n",
     "        group = agg.index.get_level_values(0)\n",
+    "        if not pd.api.types.is_string_dtype(group.dtype):\n",
+    "            group = group.astype(str)\n",
     "        for level in levels[1:]:\n",
     "            group = group + '/' + agg.index.get_level_values(level).str.replace('/', '_')\n",
     "        agg.index = group\n",
@@ -401,7 +403,7 @@
     "\n",
     "# test categoricals don't produce all combinations\n",
     "df2 = df.copy()\n",
-    "for col in ('cat1', 'cat2'):\n",
+    "for col in ('country', 'cat1', 'cat2'):\n",
     "    df2[col] = df2[col].astype('category')\n",
     "\n",
     "Y_df2, *_ = aggregate(df2, spec)\n",
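
For context on the pandas behaviour this series guards against: with categorical group keys, groupby defaults to observed=False and materializes every combination of categories, so a categorical-typed hierarchy could aggregate to rows that never occur in the data, and the group labels come back as a CategoricalIndex rather than plain strings. The sketch below is illustrative only and not part of the patches; the toy cat1/cat2 columns mirror the names used in the tests, and the exact fill value for unobserved groups may vary by pandas version.

import pandas as pd

# Toy frame: only 3 of the 4 possible (cat1, cat2) combinations are observed.
df = pd.DataFrame({
    'cat1': ['a', 'a', 'b'],
    'cat2': ['1', '2', '2'],
    'ds': pd.to_datetime(['2023-01-01'] * 3),
    'y': [1.0, 2.0, 3.0],
})
for col in ('cat1', 'cat2'):
    df[col] = df[col].astype('category')

# Pre-patch behaviour: categorical keys expand to the full cartesian product,
# so the unobserved ('b', '1') group appears as an extra row after sum().
dense = df.groupby(['cat1', 'cat2', 'ds'], observed=False)['y'].sum()
print(len(dense))   # 4

# Patched behaviour: observed=True keeps only the combinations present in the
# data, so object-typed and categorical-typed inputs aggregate to the same rows.
sparse = df.groupby(['cat1', 'cat2', 'ds'], observed=True)['y'].sum()
print(len(sparse))  # 3

# The first index level of a categorical grouper is a CategoricalIndex, not a
# string index, which is why the last patch casts it to str before building the
# '/'-joined ids.
lvl = sparse.reset_index('ds').index.get_level_values(0)
print(lvl.dtype)    # category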