In [None]:
from typing import Union

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm_notebook as tqdm

class WRMSSEEvaluator(object):
    def __init__(self, train_df: pd.DataFrame, valid_df: pd.DataFrame, calendar: pd.DataFrame, prices: pd.DataFrame):
        train_y = train_df.loc[:, train_df.columns.str.startswith('d_')]
        train_target_columns = train_y.columns.tolist()
        weight_columns = train_y.iloc[:, -28:].columns.tolist()

        train_df['all_id'] = 0  # for lv1 aggregation

        id_columns = train_df.loc[:, ~train_df.columns.str.startswith('d_')].columns.tolist()
        valid_target_columns = valid_df.loc[:, valid_df.columns.str.startswith('d_')].columns.tolist()

        if not all([c in valid_df.columns for c in id_columns]):
            valid_df = pd.concat([train_df[id_columns], valid_df], axis=1, sort=False)

        self.train_df = train_df
        self.valid_df = valid_df
        self.calendar = calendar
        self.prices = prices

        self.weight_columns = weight_columns
        self.id_columns = id_columns
        self.valid_target_columns = valid_target_columns

        weight_df = self.get_weight_df()

        self.group_ids = (
            'all_id',
            'state_id',
            'store_id',
            'cat_id',
            'dept_id',
            ['state_id', 'cat_id'],
            ['state_id', 'dept_id'],
            ['store_id', 'cat_id'],
            ['store_id', 'dept_id'],
            'item_id',
            ['item_id', 'state_id'],
            ['item_id', 'store_id']
        )

        for i, group_id in enumerate(tqdm(self.group_ids)):
            train_y = train_df.groupby(group_id)[train_target_columns].sum()
            scale = []
            for _, row in train_y.iterrows():
                series = row.values[np.argmax(row.values != 0):]
                scale.append(((series[1:] - series[:-1]) ** 2).mean())
            setattr(self, f'lv{i + 1}_scale', np.array(scale))
            setattr(self, f'lv{i + 1}_train_df', train_y)
            setattr(self, f'lv{i + 1}_valid_df', valid_df.groupby(group_id)[valid_target_columns].sum())

            lv_weight = weight_df.groupby(group_id)[weight_columns].sum().sum(axis=1)
            setattr(self, f'lv{i + 1}_weight', lv_weight / lv_weight.sum())

    def get_weight_df(self) -> pd.DataFrame:
        day_to_week = self.calendar.set_index('d')['wm_yr_wk'].to_dict()
        weight_df = self.train_df[['item_id', 'store_id'] + self.weight_columns].set_index(['item_id', 'store_id'])
        weight_df = weight_df.stack().reset_index().rename(columns={'level_2': 'd', 0: 'value'})
        weight_df['wm_yr_wk'] = weight_df['d'].map(day_to_week)

        weight_df = weight_df.merge(self.prices, how='left', on=['item_id', 'store_id', 'wm_yr_wk'])
        weight_df['value'] = weight_df['value'] * weight_df['sell_price']
        weight_df = weight_df.set_index(['item_id', 'store_id', 'd']).unstack(level=2)['value']
        weight_df = weight_df.loc[zip(self.train_df.item_id, self.train_df.store_id), :].reset_index(drop=True)
        weight_df = pd.concat([self.train_df[self.id_columns], weight_df], axis=1, sort=False)
        return weight_df

    def rmsse(self, valid_preds: pd.DataFrame, lv: int) -> pd.Series:
        valid_y = getattr(self, f'lv{lv}_valid_df')
        score = ((valid_y - valid_preds) ** 2).mean(axis=1)
        scale = getattr(self, f'lv{lv}_scale')
        return (score / scale).map(np.sqrt)

    def score(self, valid_preds: Union[pd.DataFrame, np.ndarray]) -> float:
        assert self.valid_df[self.valid_target_columns].shape == valid_preds.shape

        if isinstance(valid_preds, np.ndarray):
            valid_preds = pd.DataFrame(valid_preds, columns=self.valid_target_columns)

        valid_preds = pd.concat([self.valid_df[self.id_columns], valid_preds], axis=1, sort=False)

        all_scores = []
        for i, group_id in enumerate(self.group_ids):
            lv_scores = self.rmsse(valid_preds.groupby(group_id)[self.valid_target_columns].sum(), i + 1)
            weight = getattr(self, f'lv{i + 1}_weight')
            lv_scores = pd.concat([weight, lv_scores], axis=1, sort=False).prod(axis=1)
            all_scores.append(lv_scores.sum())
        print(all_scores)
        return np.mean(all_scores)


In [12]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import root_mean_squared_error, mean_squared_error
from statsmodels.tsa.deterministic import DeterministicProcess
from sklearn.linear_model import LinearRegression

DIRECTORY = './data/'
calendar = pd.read_csv(DIRECTORY + 'calendar.csv', parse_dates=['date'])
train_validation = pd.read_csv(DIRECTORY + 'sales_train_validation.csv')
train_evaluation = pd.read_csv(DIRECTORY + 'sales_train_evaluation.csv')
prices = pd.read_csv(DIRECTORY + 'sell_prices.csv')
sample_submission  = pd.read_csv(DIRECTORY + 'sample_submission.csv')

In [13]:
# Testing 28D rolling prediction
train_fold_df = train_validation.iloc[:, :-28]
valid_fold_df = train_validation.iloc[:, -28:]
valid_preds = np.tile(train_fold_df.iloc[:,-28:].mean(axis=1), (28,1)).T

evaluator = WRMSSEEvaluator(train_fold_df, valid_fold_df, calendar, prices)
evaluator.score(valid_preds)

  0%|          | 0/12 [00:00<?, ?it/s]

[1.200108159846032, 1.1513355400185454, 1.1188410615326518, 1.1938087366742873, 1.230389141317233, 1.1457148998869913, 1.1689707810794892, 1.1110844188708668, 1.1130617602352035, 0.9549363083348601, 0.9106288489597526, 0.8709262813021107]


1.0974838281715018

In [6]:
# Testing Multi LR prediction
Y = train_validation.select_dtypes('number').T
dp = DeterministicProcess(index=Y.index, constant=True, order=1)
X = dp.in_sample()

Y_train = Y.iloc[:-28,:]
X_train = X.iloc[:-28,:]

Y_test = Y.iloc[-28:,:]
X_test = X.iloc[-28:,:]

model = LinearRegression(fit_intercept=False)
model.fit(X_train, Y_train)

predictions = model.predict(X_test)

In [7]:
evaluator = WRMSSEEvaluator(train_fold_df, valid_fold_df, calendar, prices)
evaluator.score(predictions.T)

Unnamed: 0,item_id,store_id,d,value,wm_yr_wk
0,HOBBIES_1_001,CA_1,d_1858,0,11605
1,HOBBIES_1_001,CA_1,d_1859,2,11605
2,HOBBIES_1_001,CA_1,d_1860,0,11605
3,HOBBIES_1_001,CA_1,d_1861,1,11605
4,HOBBIES_1_001,CA_1,d_1862,1,11605
...,...,...,...,...,...
853715,FOODS_3_827,WI_3,d_1881,5,11608
853716,FOODS_3_827,WI_3,d_1882,3,11608
853717,FOODS_3_827,WI_3,d_1883,2,11608
853718,FOODS_3_827,WI_3,d_1884,0,11609


  0%|          | 0/12 [00:00<?, ?it/s]

1.1816991455350472

In [8]:
weight_df = evaluator.get_weight_df()

Unnamed: 0,item_id,store_id,d,value,wm_yr_wk
0,HOBBIES_1_001,CA_1,d_1858,0,11605
1,HOBBIES_1_001,CA_1,d_1859,2,11605
2,HOBBIES_1_001,CA_1,d_1860,0,11605
3,HOBBIES_1_001,CA_1,d_1861,1,11605
4,HOBBIES_1_001,CA_1,d_1862,1,11605
...,...,...,...,...,...
853715,FOODS_3_827,WI_3,d_1881,5,11608
853716,FOODS_3_827,WI_3,d_1882,3,11608
853717,FOODS_3_827,WI_3,d_1883,2,11608
853718,FOODS_3_827,WI_3,d_1884,0,11609


In [10]:
weight_df

Unnamed: 0,id,item_id,dept_id,cat_id,store_id,state_id,all_id,d_1858,d_1859,d_1860,...,d_1876,d_1877,d_1878,d_1879,d_1880,d_1881,d_1882,d_1883,d_1884,d_1885
0,HOBBIES_1_001_CA_1_validation,HOBBIES_1_001,HOBBIES_1,HOBBIES,CA_1,CA,0,0.00,16.52,0.00,...,24.78,8.26,24.78,8.26,16.52,16.52,0.00,8.26,8.26,8.26
1,HOBBIES_1_002_CA_1_validation,HOBBIES_1_002,HOBBIES_1,HOBBIES,CA_1,CA,0,0.00,0.00,0.00,...,0.00,0.00,0.00,0.00,0.00,0.00,3.97,3.97,3.97,3.97
2,HOBBIES_1_003_CA_1_validation,HOBBIES_1_003,HOBBIES_1,HOBBIES,CA_1,CA,0,0.00,0.00,0.00,...,2.97,0.00,0.00,0.00,0.00,0.00,0.00,2.97,2.97,0.00
3,HOBBIES_1_004_CA_1_validation,HOBBIES_1_004,HOBBIES_1,HOBBIES,CA_1,CA,0,0.00,0.00,0.00,...,18.56,9.28,4.64,18.56,4.64,13.92,23.20,0.00,27.84,27.84
4,HOBBIES_1_005_CA_1_validation,HOBBIES_1_005,HOBBIES_1,HOBBIES,CA_1,CA,0,2.88,5.76,2.88,...,8.64,5.76,5.76,5.76,8.64,2.88,0.00,0.00,0.00,0.00
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
30485,FOODS_3_823_WI_3_validation,FOODS_3_823,FOODS_3,FOODS,WI_3,WI,0,0.00,0.00,5.76,...,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00
30486,FOODS_3_824_WI_3_validation,FOODS_3_824,FOODS_3,FOODS,WI_3,WI,0,0.00,0.00,0.00,...,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00
30487,FOODS_3_825_WI_3_validation,FOODS_3_825,FOODS_3,FOODS,WI_3,WI,0,3.98,7.96,7.96,...,7.96,0.00,0.00,7.96,0.00,0.00,0.00,7.96,0.00,3.98
30488,FOODS_3_826_WI_3_validation,FOODS_3_826,FOODS_3,FOODS,WI_3,WI,0,2.56,2.56,2.56,...,1.28,1.28,1.28,1.28,0.00,2.56,1.28,1.28,2.56,5.12
