In [1]:
import pandas as pd
import numpy as np
import hts

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.simplefilter("ignore")

# settings
# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8*" and
# later releases removed the plain "seaborn" alias. Prefer the new name and
# fall back to the legacy one on older installations.
try:
    plt.style.use("seaborn-v0_8")
except OSError:
    plt.style.use("seaborn")
plt.rcParams["figure.figsize"] = (16, 8)

In [2]:
path = 'https://raw.githubusercontent.com/erykml/medium_articles/master/data/tourism.csv'

# load the raw data and drop the "Unnamed: 0" column
df = pd.read_csv(path).drop(columns=["Unnamed: 0"])

# lowercase the column names
df.columns = df.columns.str.lower()

# sum the trips over purpose
df = (
    df.groupby(["quarter", "region", "state"])["trips"]
    .sum()
    .reset_index(drop=False)
)

# cleanup region name: spaces become underscores, apostrophes are removed
# (literal, non-regex replacements, done with vectorized string methods)
df["region"] = (
    df["region"]
    .str.replace(" ", "_", regex=False)
    .str.replace("'", "", regex=False)
)

# map the full state names to abbreviations
mapping_dict = {
    "New South Wales": "NSW",
    "Northern Territory": "NT",
    "Queensland": "QLD",
    "South Australia": "SA",
    "Victoria": "VIC",
    "Western Australia": "WA",
    "ACT": "ACT",
}

# Series.map returns NaN for any value without an entry in mapping_dict.
# Fail loudly instead of letting NaN states leak into the ids built below.
state_abbr = df["state"].map(mapping_dict)
unmapped_states = sorted(df.loc[state_abbr.isna(), "state"].unique())
if unmapped_states:
    raise ValueError(f"States missing from mapping_dict: {unmapped_states}")
df["state"] = state_abbr

# create the bottom level id as "<state>_<region>"
df["state_region"] = df["state"] + "_" + df["region"]

In [3]:
# preview the first rows of the prepared long-format DataFrame
df.head()

Unnamed: 0,quarter,region,state,trips,state_region
0,1998-01-01,Adelaide,SA,658.553895,SA_Adelaide
1,1998-01-01,Adelaide_Hills,SA,9.79863,SA_Adelaide_Hills
2,1998-01-01,Alice_Springs,NT,20.207638,NT_Alice_Springs
3,1998-01-01,Australias_Coral_Coast,WA,132.516409,WA_Australias_Coral_Coast
4,1998-01-01,Australias_Golden_Outback,WA,161.726948,WA_Australias_Golden_Outback


In [4]:
# collect the set of distinct regions observed for each state
# NOTE(review): the rendered output lists Tasmanian regions (e.g. "Tasmania",
# "East_Coast") under NT — confirm the state labels in the source data.
df.groupby("state")["region"].apply(set).to_frame()

Unnamed: 0_level_0,region
state,Unnamed: 1_level_1
ACT,{Canberra}
NSW,"{Snowy_Mountains, Capital_Country, Outback_NSW..."
NT,"{Tasmania, East_Coast, Kakadu_Arnhem, Darwin, ..."
QLD,"{Fraser_Coast, Mackay, Bundaberg, Darling_Down..."
SA,"{Riverland, Yorke_Peninsula, Murraylands, Clar..."
VIC,"{Western_Grampians, Central_Murray, Spa_Countr..."
WA,"{Australias_Coral_Coast, Experience_Perth, Aus..."


In [5]:
# create the bottom level df: one column per state_region id, indexed by quarter
df_bottom_level = df.pivot(index="quarter", columns="state_region", values="trips")

# create the middle level df: sum the trips over the regions within each state.
# "trips" is selected before .sum() so the aggregation never touches the
# non-numeric columns ("region", "state_region"), matching the total level below.
df_middle_level = (
    df.groupby(["quarter", "state"])["trips"]
    .sum()
    .reset_index(drop=False)
    .pivot(index="quarter", columns="state", values="trips")
)

# create the total level df: sum of all trips per quarter, in a "total" column
df_total = (
    df.groupby("quarter")["trips"]
    .sum()
    .to_frame()
    .rename(columns={"trips": "total"})
)

# join the DataFrames on the shared "quarter" index
hierarchy_df = df_bottom_level.join(df_middle_level).join(df_total)

# convert the index to datetimes and resample to quarter-start ("QS") frequency
hierarchy_df.index = pd.to_datetime(hierarchy_df.index)
hierarchy_df = hierarchy_df.resample("QS").sum()

print(f"Number of time series at the bottom level: {df_bottom_level.shape[1]}")
print(f"Number of time series at the middle level: {df_middle_level.shape[1]}")

Number of time series at the bottom level: 77
Number of time series at the middle level: 7


Using the snippet above, we create three DataFrames, one for each level of the hierarchy:

bottom level — this is simply a pivot that transforms the initial DataFrame in the long format to the wide format,

middle level — before pivoting the DataFrame, we first sum the trips over all regions within each state,

total level — the highest level, which is the sum of all states.

In [10]:
# list the column names of the joined hierarchy DataFrame
hierarchy_df.columns

Index(['ACT_Canberra', 'NSW_Blue_Mountains', 'NSW_Capital_Country',
       'NSW_Central_Coast', 'NSW_Central_NSW', 'NSW_Hunter',
       'NSW_New_England_North_West', 'NSW_North_Coast_NSW', 'NSW_Outback_NSW',
       'NSW_Riverina', 'NSW_Snowy_Mountains', 'NSW_South_Coast', 'NSW_Sydney',
       'NSW_The_Murray', 'NT_Alice_Springs', 'NT_Barkly', 'NT_Darwin',
       'NT_East_Coast', 'NT_Hobart_and_the_South', 'NT_Kakadu_Arnhem',
       'NT_Katherine_Daly', 'NT_Lasseter',
       'NT_Launceston,_Tamar_and_the_North', 'NT_MacDonnell', 'NT_North_West',
       'NT_Tasmania', 'NT_Wilderness_West', 'QLD_Brisbane', 'QLD_Bundaberg',
       'QLD_Central_Queensland', 'QLD_Darling_Downs', 'QLD_Fraser_Coast',
       'QLD_Gold_Coast', 'QLD_Mackay', 'QLD_Northern', 'QLD_Outback',
       'QLD_Sunshine_Coast', 'QLD_Tropical_North_Queensland',
       'QLD_Whitsundays', 'SA_Adelaide', 'SA_Adelaide_Hills', 'SA_Barossa',
       'SA_Clare_Valley', 'SA_Eyre_Peninsula', 'SA_Fleurieu_Peninsula',
       'SA_Flinder

In [12]:
states = df["state"].unique()
regions = df["state_region"].unique()

# top level: "total" is the parent of every state
total = {'total': list(states)}

# middle level: each state is the parent of its own bottom-level ids.
# Children are looked up through the "state" column rather than by string
# prefix, so a state code that is a prefix of another can never pick up the
# wrong ids. Order of first appearance is preserved.
state = {
    k: list(df.loc[df["state"] == k, "state_region"].unique())
    for k in states
}

# merge both levels into a single parent -> children mapping
hierarchy = {**total, **state}
hierarchy

{'total': ['SA', 'NT', 'WA', 'VIC', 'NSW', 'QLD', 'ACT'],
 'SA': ['SA_Adelaide',
  'SA_Adelaide_Hills',
  'SA_Barossa',
  'SA_Clare_Valley',
  'SA_Eyre_Peninsula',
  'SA_Fleurieu_Peninsula',
  'SA_Flinders_Ranges_and_Outback',
  'SA_Kangaroo_Island',
  'SA_Limestone_Coast',
  'SA_Murraylands',
  'SA_Riverland',
  'SA_Yorke_Peninsula'],
 'NT': ['NT_Alice_Springs',
  'NT_Barkly',
  'NT_Darwin',
  'NT_East_Coast',
  'NT_Hobart_and_the_South',
  'NT_Kakadu_Arnhem',
  'NT_Katherine_Daly',
  'NT_Lasseter',
  'NT_Launceston,_Tamar_and_the_North',
  'NT_MacDonnell',
  'NT_North_West',
  'NT_Tasmania',
  'NT_Wilderness_West'],
 'WA': ['WA_Australias_Coral_Coast',
  'WA_Australias_Golden_Outback',
  'WA_Australias_North_West',
  'WA_Australias_South_West',
  'WA_Experience_Perth'],
 'VIC': ['VIC_Ballarat',
  'VIC_Bendigo_Loddon',
  'VIC_Central_Highlands',
  'VIC_Central_Murray',
  'VIC_Geelong',
  'VIC_Gippsland',
  'VIC_Goulburn',
  'VIC_High_Country',
  'VIC_Lakes',
  'VIC_Macedon',
  'VIC_Ma

Hierarchical time series forecasting

Finally, we can focus on the modeling part. In this article, I just want to highlight the functionalities of scikit-hts. That is why I present simplified examples, in which I use the entire data set for training and then forecast 4 steps (a year) into the future. Naturally, in a real-life scenario we would employ an appropriate time-series cross-validation scheme and try to tune the hyperparameters of the model to obtain the best fit.

The main class of the library is the HTSRegressor. The usage of the library will be familiar to anyone who has used scikit-learn before (initialize the estimator -> fit to data -> predict). The two arguments we should pay attention to are model and revision_method.
model determines the underlying type of model that will be used to forecast the individual time series. Currently, the library supports:

* auto_arima — from the pmdarima library,

* SARIMAX — from statsmodels,

* Holt-Winters exponential smoothing — also from statsmodels,

* Facebook’s Prophet.

The revision_method argument is responsible for the approach to hierarchical time series forecasting. We can choose from:

* BU — the bottom-up approach,

* AHP — average historical proportions (top-down approach),

* PHA — proportions of historical averages (top-down approach),

* FP — the forecasted proportions (top-down approach),

* OLS — the optimal combination using OLS,

* WLSS — the optimal combination using structurally weighted OLS,

* WLSV — the optimal combination using variance-weighted OLS.

In [17]:
# bottom-up (BU) revision with auto_arima as the underlying model;
# fit on the whole hierarchy and forecast 4 steps (one year) ahead
model_bu_arima = hts.HTSRegressor(
    model="auto_arima",
    revision_method="BU",
    n_jobs=0,
)
model_bu_arima.fit(hierarchy_df, hierarchy)
model_bu_arima.predict(steps_ahead=4)

Fitting models: 100%|███████████████████████████████████████████████████████████████████| 85/85 [02:09<00:00,  1.53s/it]
Fitting models: 100%|██████████████████████████████████████████████████████████████████| 85/85 [00:00<00:00, 341.16it/s]


Unnamed: 0,total,SA,NT,WA,VIC,NSW,QLD,ACT,SA_Adelaide,SA_Adelaide_Hills,...,QLD_Darling_Downs,QLD_Fraser_Coast,QLD_Gold_Coast,QLD_Mackay,QLD_Northern,QLD_Outback,QLD_Sunshine_Coast,QLD_Tropical_North_Queensland,QLD_Whitsundays,ACT_Canberra
1998-01-01,12322.831056,498.118480,2085.916908,3749.558062,2503.256570,1153.078496,1113.909261,1218.993279,564.693726,0.262403,...,0.000000,145.32781,872.889238,0.000000,224.702920,0.000000,707.758764,0.000000,0.000000,498.118480
1998-04-01,22008.377123,509.663190,4017.696349,7804.119591,4796.690950,1683.262316,1649.090863,1547.853863,564.693726,10.059435,...,258.529107,145.32781,867.102345,128.782326,224.989442,37.169537,707.758764,220.532410,60.191399,509.663190
1998-07-01,20538.226362,480.197213,4245.238921,7168.790818,4306.735885,1627.111662,1309.558821,1400.593041,564.693726,18.346699,...,393.374128,145.32781,852.435353,130.645467,209.000282,94.585208,707.758764,246.941565,91.336461,480.197213
1998-10-01,21435.134323,484.564049,4442.530611,7112.956763,4700.126598,1639.040639,1591.031208,1464.884453,564.693726,21.348299,...,337.449665,145.32781,899.617783,140.935474,241.507390,139.104621,707.758764,365.186549,69.560731,484.564049
1999-01-01,22760.978601,487.569983,4111.476461,7825.584788,5187.130124,1743.784385,1827.214557,1578.218303,564.693726,23.115754,...,331.121877,145.32781,846.808200,143.235278,194.385490,45.667888,707.758764,205.066172,91.136698,487.569983
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2016-10-01,24561.451375,514.342591,5280.242780,7539.197479,5515.235407,2399.268723,1821.092848,1492.071547,564.693726,38.343653,...,423.264967,145.32781,928.204356,185.048267,251.822811,108.366693,707.758764,510.409631,120.285651,514.342591
2017-01-01,26074.233094,535.032878,5212.359691,8101.461135,6181.172788,2566.216579,1916.782394,1561.207629,564.693726,39.606186,...,482.862707,145.32781,961.853916,190.555296,188.531398,93.185724,707.758764,389.201282,95.114013,535.032878
2017-04-01,24153.323893,506.177071,5345.480246,7425.359656,5296.295516,2486.697620,1569.574516,1523.739269,564.693726,39.868589,...,525.961566,145.32781,849.196393,183.459663,260.261861,138.774546,707.758764,445.374944,124.286517,506.177071
2017-07-01,23692.308214,499.877709,5524.301621,7192.773125,5079.097720,2389.661102,1578.483223,1428.113715,564.693726,40.130993,...,521.440487,145.32781,851.498943,197.109019,305.113991,151.879606,707.758764,551.244451,120.740969,499.877709


In [18]:
# top-down revision using average historical proportions (AHP), with
# auto_arima as the underlying model; forecast 4 steps (one year) ahead
model_td_arima = hts.HTSRegressor(
    model="auto_arima",
    revision_method="AHP",
    n_jobs=0,
)
model_td_arima.fit(hierarchy_df, hierarchy)
model_td_arima.predict(steps_ahead=4)

Fitting models: 100%|███████████████████████████████████████████████████████████████████| 85/85 [02:10<00:00,  1.54s/it]
Fitting models: 100%|██████████████████████████████████████████████████████████████████| 85/85 [00:00<00:00, 337.30it/s]


Unnamed: 0,total,SA,NT,WA,VIC,NSW,QLD,ACT,SA_Adelaide,SA_Adelaide_Hills,...,QLD_Darling_Downs,QLD_Fraser_Coast,QLD_Gold_Coast,QLD_Mackay,QLD_Northern,QLD_Outback,QLD_Sunshine_Coast,QLD_Tropical_North_Queensland,QLD_Whitsundays,ACT_Canberra
1998-01-01,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
1998-04-01,37124.357581,852.179142,8157.414608,11804.149097,8198.631122,3001.557914,2621.207386,2489.218311,966.094267,47.509643,...,699.378262,250.509345,1495.719782,274.815995,394.118636,186.218724,1209.663206,654.496081,198.450138,852.179142
1998-07-01,23400.042667,537.141370,5141.741496,7440.333262,5167.720886,1891.927237,1652.186560,1568.991856,608.943792,29.946045,...,440.828670,157.899819,942.774744,173.220667,248.418922,117.376471,762.468969,412.538754,125.086116,537.141370
1998-10-01,24977.007511,573.340153,5488.251359,7941.748757,5515.981540,2019.427122,1763.529952,1674.728629,649.981450,31.964155,...,470.536792,168.540931,1006.309783,184.894275,265.160255,125.286652,813.852925,440.340375,133.515862,573.340153
1999-01-01,26828.948841,615.850943,5895.182394,8530.596430,5924.968653,2169.159252,1894.288371,1798.902798,698.174874,34.334164,...,505.425140,181.037541,1080.923473,198.603417,284.820786,134.576137,874.196737,472.989784,143.415508,615.850943
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2016-10-01,26074.777420,598.539151,5729.466692,8290.798290,5758.415649,2108.183404,1841.039242,1748.334992,678.548927,33.369018,...,491.217457,175.948510,1050.538324,193.020603,276.814371,130.793153,849.622752,459.693871,139.384046,598.539151
2017-01-01,27867.718071,639.695597,6123.433382,8860.885967,6154.372912,2253.145244,1967.631851,1868.553118,725.206965,35.663522,...,524.994303,188.046993,1122.774909,206.292988,295.848541,139.786686,908.044082,491.303108,148.968302,639.695597
2017-04-01,26815.298090,615.537594,5892.182885,8526.256005,5921.953989,2168.055569,1893.324544,1797.987504,697.819638,34.316695,...,505.167976,180.945428,1080.373491,198.502366,284.675867,134.507664,873.751940,472.749124,143.342537,615.537594
2017-07-01,26796.419625,615.104244,5888.034679,8520.253363,5917.784824,2166.529218,1891.991609,1796.721688,697.328360,34.292535,...,504.812328,180.818039,1079.612889,198.362617,284.475450,134.412968,873.136803,472.416300,143.241621,615.104244


In [19]:
# optimal combination using OLS, with auto_arima as the underlying model;
# forecast 4 steps (one year) ahead
model_ols_arima = hts.HTSRegressor(
    model="auto_arima",
    revision_method="OLS",
    n_jobs=0,
)
model_ols_arima.fit(hierarchy_df, hierarchy)
model_ols_arima.predict(steps_ahead=4)

Fitting models: 100%|███████████████████████████████████████████████████████████████████| 85/85 [02:11<00:00,  1.55s/it]
Fitting models: 100%|██████████████████████████████████████████████████████████████████| 85/85 [00:00<00:00, 344.83it/s]


Unnamed: 0,total,SA,NT,WA,VIC,NSW,QLD,ACT,SA_Adelaide,SA_Adelaide_Hills,...,QLD_Darling_Downs,QLD_Fraser_Coast,QLD_Gold_Coast,QLD_Mackay,QLD_Northern,QLD_Outback,QLD_Sunshine_Coast,QLD_Tropical_North_Queensland,QLD_Whitsundays,ACT_Canberra
1998-01-01,1372.751117,292.697693,330.762937,-1006.871890,-1195.479681,4860.359067,-1195.132519,-713.584490,403.645578,-160.785744,...,-146.262831,-0.935021,726.626407,-146.262831,78.440089,-146.262831,561.495933,-146.262831,-146.262831,292.697693
1998-04-01,35372.581684,1911.890726,3442.927160,3700.336147,10880.051293,7967.786417,5263.042025,2206.547916,619.584897,64.950606,...,210.631674,97.430377,819.204912,80.884894,177.092010,-10.727895,659.861332,172.634978,12.293967,1911.890726
1998-07-01,22763.681198,1251.681549,2288.833894,2561.375133,4751.117567,6670.882619,4101.382957,1138.407479,542.844929,-3.502098,...,230.340375,-17.705943,689.401600,-32.388285,45.966530,-68.448545,544.725012,83.907813,-71.697291,1251.681549
1998-10-01,24352.344084,1245.712144,2415.377532,2579.835607,5718.788635,6764.292188,4491.752271,1136.585706,537.335497,-6.009930,...,168.520242,-23.601614,730.688359,-27.993949,72.577966,-29.824802,538.829341,196.257125,-99.368692,1245.712144
1999-01-01,26148.710777,1329.065438,2587.252669,2829.429088,6581.593757,7072.791763,4549.199995,1199.378067,533.123706,-8.454266,...,204.103228,18.309160,719.789550,16.216629,67.366841,-81.350762,580.740115,78.047522,-35.881951,1329.065438
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2016-10-01,25866.136274,1101.604114,2352.816805,3040.820444,6205.691503,6912.106387,5470.952685,782.144338,505.533125,-20.816947,...,179.312802,-98.624355,684.252192,-58.903898,7.870646,-135.585471,463.806600,266.457466,-123.666514,1101.604114
2017-01-01,27553.996849,1189.454534,2460.854339,3394.679608,6924.566511,7244.972653,5435.910985,903.558218,509.889608,-15.197931,...,253.570594,-83.964303,732.561803,-38.736817,-40.760715,-136.106389,478.466652,159.909169,-134.178100,1189.454534
2017-04-01,26396.958286,1192.272449,2256.208743,3284.904342,6009.830838,6969.095960,5714.035050,970.610905,518.599695,-6.225441,...,268.522275,-112.111482,591.757101,-73.979629,2.822569,-118.664746,450.319473,187.935652,-133.152775,1192.272449
2017-07-01,26395.601968,1180.361695,2347.344614,3157.733744,5891.345687,6842.127859,6035.422360,941.266008,524.123083,-0.439649,...,256.694070,-119.418608,586.752526,-67.637398,40.367574,-112.866811,443.012347,286.498034,-144.005448,1180.361695
