In [83]:
import warnings

import pandas as pd
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import TimeSeriesSplit
from tqdm import tqdm

# Suppress every warning for the rest of the kernel session; the CV loop below
# refits each model 20 x 13 times and would otherwise flood the output.
# NOTE(review): a global 'ignore' may also hide convergence/deprecation warnings
# from the forecasting libraries — consider scoping it with
# warnings.catch_warnings() around the fitting loop instead.
warnings.filterwarnings('ignore')

In [90]:
%load_ext autoreload
%autoreload 2
from models.models import HoltWintersWrapper, ProphetWrapper


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [85]:
# Monthly CPI index levels (2017-01 .. 2023-04 in the saved output): one column
# per category plus 'date' (YYYY-MM). Read with pandas defaults, so the CSV's
# saved index comes back as an 'Unnamed: 0' column, followed by 'index'.
cpi_csv_path = 'cpi.csv'
cpi = pd.read_csv(cpi_csv_path)

In [86]:
# CPI category columns to forecast, derived from the loaded CSV header so the
# names always match the data. A hand-written list previously used names that
# are not columns of cpi.csv (e.g. 'All items' instead of 'headline CPI',
# 'Alcoholic beverages, tobacco and narcotics' instead of
# 'Alcoholic beverages and tobacco'), so cpi[cpi_columns] raised KeyError.
# Excluded: the saved-index column 'Unnamed: 0', the 'index' label column and
# the 'date' column — everything else in the file is a category.
NON_CATEGORY_COLUMNS = ('Unnamed: 0', 'index', 'date')
cpi_columns = [column for column in cpi.columns if column not in NON_CATEGORY_COLUMNS]

In [87]:
# Candidate forecasters for the CV comparison: Holt-Winters with seasonal
# periods of 6, 3 and 12 observations (months, given the monthly data), then
# Prophet. This order fixes the column order of `results` / `results_df`.
models = [HoltWintersWrapper(seasonal_periods=period) for period in (6, 3, 12)]
models.append(ProphetWrapper())

In [137]:
# Walk-forward CV: 20 splits, each holding out exactly one row (together, the
# last 20 months, one at a time) and training on every row before it.
tscv = TimeSeriesSplit(test_size=1, n_splits=20)

In [138]:
# Rolling-origin evaluation. For every model and every CPI category, refit on
# each training window produced by `tscv`, forecast the held-out row(s), and
# score the category by the MSE over all held-out points.
# Result: results[model name] -> list of MSEs, one per category, in column order.
categories = list(cpi.columns[2:-1])  # skips 'Unnamed: 0', 'index' (first) and 'date' (last)
results = {}
for model in models:
    model_name = model.getModelName()
    print(model_name)
    category_mse = []
    for category in tqdm(categories):
        # Hoisted: the same two-column frame is sliced by every split.
        series = cpi[['date', category]]
        actuals, forecasts = [], []
        for train_index, test_index in tscv.split(series):
            model.fit(None, series.iloc[train_index])
            # extend() rather than append(*...): append accepts exactly one
            # argument, so the old form broke whenever a split yielded more
            # (or fewer) than one test value.
            actuals.extend(series[category].iloc[test_index].values)
            forecasts.extend(model.predict(test_index))
        category_mse.append(mean_squared_error(actuals, forecasts))
    results[model_name] = category_mse

HoltWinters_mul_mul_6


100%|██████████| 13/13 [00:33<00:00,  2.55s/it]


HoltWinters_mul_mul_3


100%|██████████| 13/13 [00:29<00:00,  2.27s/it]


HoltWinters_mul_mul_12


100%|██████████| 13/13 [00:33<00:00,  2.60s/it]


Prophet


100%|██████████| 13/13 [02:11<00:00, 10.11s/it]


In [139]:
# One row per CPI category (the order the CV loop visited them), one column per
# model name; each cell is that model's cross-validated MSE for the category.
category_index = cpi.columns[2:-1]
results_df = pd.DataFrame(results).set_axis(category_index, axis=0)

In [140]:
# Cross-validated MSE per CPI category (rows) and model (columns); lower is better.
results_df

Unnamed: 0,HoltWinters_mul_mul_6,HoltWinters_mul_mul_3,HoltWinters_mul_mul_12,Prophet
Food and non-alcoholic beverages,0.477625,0.46643,0.540259,4.466591
Alcoholic beverages and tobacco,0.404364,0.486202,0.31842,1.258299
Clothing and footwear,0.043005,0.043955,0.049025,0.173556
Housing and utilities,0.158078,0.171314,0.007335,0.257408
Household contents and services,0.102029,0.100969,0.096986,0.735485
Health,0.562189,0.815918,0.077702,1.339092
Transport,5.803694,5.298073,4.387032,21.981848
Communication,0.078897,0.080709,0.074141,0.161152
Recreation and culture,0.197327,0.177366,0.170182,0.294284
Education,1.132585,1.993387,0.132522,2.357395


In [154]:
# For each CPI category (a row of results_df), the name of the model with the
# lowest cross-validated MSE: {category: model name}.
best_model_table = results_df.idxmin(axis=1).to_dict()
best_model_table

{'Food and non-alcoholic beverages': 'HoltWinters_mul_mul_3',
 'Alcoholic beverages and tobacco': 'HoltWinters_mul_mul_12',
 'Clothing and footwear': 'HoltWinters_mul_mul_6',
 'Housing and utilities': 'HoltWinters_mul_mul_12',
 'Household contents and services': 'HoltWinters_mul_mul_12',
 'Health': 'HoltWinters_mul_mul_12',
 'Transport': 'HoltWinters_mul_mul_12',
 'Communication': 'HoltWinters_mul_mul_12',
 'Recreation and culture': 'HoltWinters_mul_mul_12',
 'Education': 'HoltWinters_mul_mul_12',
 'Restaurants and hotels': 'HoltWinters_mul_mul_3',
 'Miscellaneous goods and services': 'HoltWinters_mul_mul_6',
 'headline CPI': 'HoltWinters_mul_mul_12'}

In [149]:
# Fresh (unfitted) instance of every candidate model, keyed by getModelName().
# Deriving the keys from the models themselves guarantees they match the names
# stored in `results` / `best_model_table`, instead of relying on hand-typed
# strings such as 'HoltWinters_mul_mul_3' staying in sync.
model_table = {
    candidate.getModelName(): candidate
    for candidate in (
        HoltWintersWrapper(seasonal_periods=3),
        HoltWintersWrapper(seasonal_periods=6),
        HoltWintersWrapper(seasonal_periods=12),
        ProphetWrapper(),
    )
}

In [156]:
# Final forecasts: refit each category's winning model on the full history and
# keep the first value returned by predict([1]).
# NOTE(review): the CV loop passes the held-out positional index to predict(),
# while here a literal [1] is used; presumably predict() only uses the length of
# its argument (one-step-ahead forecast) — confirm against models.models.
results_table = {}
for category, best_name in best_model_table.items():
    winner = model_table[best_name]
    winner.fit(None, cpi[['date', category]])
    results_table[category] = winner.predict([1])[0]

In [157]:
# Per-category forecast from that category's best model (the value predict([1])
# returns after refitting on the full history — presumably next month; confirm).
results_table

{'Food and non-alcoholic beverages': 118.57721758682389,
 'Alcoholic beverages and tobacco': 110.8198843275468,
 'Clothing and footwear': 104.34829025635169,
 'Housing and utilities': 105.16482492675769,
 'Household contents and services': 108.47489557232664,
 'Health': 110.32447074441325,
 'Transport': 115.26210486568644,
 'Communication': 99.73316015090403,
 'Recreation and culture': 104.92571449138615,
 'Education': 110.35596122662159,
 'Restaurants and hotels': 110.61856547641386,
 'Miscellaneous goods and services': 109.46187980545683,
 'headline CPI': 110.27622340971843}

In [118]:
# Raw CPI frame as loaded. Note the leftover 'Unnamed: 0' and 'index' columns
# ahead of the categories and 'date' as the last column — the cpi.columns[2:-1]
# slices used throughout this notebook depend on exactly this layout.
cpi

Unnamed: 0.1,Unnamed: 0,index,Food and non-alcoholic beverages,Alcoholic beverages and tobacco,Clothing and footwear,Housing and utilities,Household contents and services,Health,Transport,Communication,Recreation and culture,Education,Restaurants and hotels,Miscellaneous goods and services,headline CPI,date
0,0,cpi_M201701,81.6,78.5,92.7,82.1,90.5,79.3,76.7,103.7,95.9,74.2,88.2,78.4,82.4,2017-01
1,1,cpi_M201702,82.2,78.3,93.0,82.1,90.2,82.3,77.3,103.6,96.1,74.2,89.3,81.6,83.1,2017-02
2,2,cpi_M201703,82.4,79.5,93.3,82.8,90.7,83.0,77.2,102.9,96.2,79.4,89.4,81.7,83.6,2017-03
3,3,cpi_M201704,82.3,80.1,93.3,82.8,90.5,83.4,77.0,102.9,96.4,79.4,89.2,82.0,83.6,2017-04
4,4,cpi_M201705,82.7,80.3,93.5,82.8,90.6,83.6,77.8,102.7,96.3,79.4,88.5,82.0,83.8,2017-05
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
72,72,cpi_M202301,114.4,106.5,102.9,104.1,106.6,104.9,109.9,99.4,103.4,104.4,106.8,105.4,107.1,2023-01
73,73,cpi_M202302,115.5,106.9,103.5,104.2,106.6,108.5,110.6,99.8,103.3,104.4,108.8,107.7,107.9,2023-02
74,74,cpi_M202303,116.7,109.2,103.4,104.5,107.8,109.1,112.9,99.7,104.3,110.4,109.6,107.8,109.0,2023-03
75,75,cpi_M202304,117.4,110.2,103.7,104.6,107.7,109.5,113.1,99.8,104.9,110.4,108.6,109.3,109.4,2023-04


In [119]:
# The 13 category columns used as forecast targets: drops 'Unnamed: 0' and
# 'index' (first two columns) and 'date' (last column).
list(cpi.columns[2:-1])

['Food and non-alcoholic beverages',
 'Alcoholic beverages and tobacco',
 'Clothing and footwear',
 'Housing and utilities',
 'Household contents and services',
 'Health',
 'Transport',
 'Communication',
 'Recreation and culture',
 'Education',
 'Restaurants and hotels',
 'Miscellaneous goods and services',
 'headline CPI']