# Импорт библиотек

In [89]:
import pandas as pd
import numpy as np

from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import mean_absolute_error

from catboost import CatBoostRegressor
from catboost import Pool

# Глобальные переменные

In [90]:
RANDOM_STATE = 123

# Загрузка данных

In [91]:
train_df = pd.read_csv('train_dataset.csv')

In [92]:
test_df = pd.read_csv('test_dataset.csv')

In [93]:
# джойним дату и время в один признак

# train
train_df['date_time'] = pd.to_datetime(train_df['date'] + ' ' + train_df['time'].astype(str) + ':00')

# test
test_df['date_time'] = pd.to_datetime(test_df['date'] + ' ' + test_df['time'].astype(str) + ':00')

In [94]:
# дату установим как индекс df

# train
train_df = train_df.set_index('date_time').sort_index()

# test
test_df = test_df.set_index('date_time').sort_index()

In [95]:
ouliers_date_list = [
 '01-01',
 '01-02',
 '01-04',
 '01-05',
 '01-07',
 '02-23',
 '03-05',
 '03-08',
 '03-11',
 '04-07',
 '04-15',
 '05-01',
 '05-06',
 '05-09',
 '05-10',
 '11-04',
 '12-28',
 '12-29',
 '12-30',
 '12-31'
]

In [96]:
test_index = test_df.index

## Объединяем train и test

In [97]:
train_test_df = pd.concat([train_df, test_df])

## Обрабатываем NA в temp_pred

In [98]:
na_values_index = train_test_df[train_test_df['temp_pred'].isna()].index.strftime("%Y-%m-%d").unique()
na_values_index = pd.to_datetime(na_values_index)

In [99]:
for day in na_values_index:
    try:
        # вычисляем среднюю температуру за предыдыщий день
        fix_pred_temp = train_test_df.loc[(day - pd.Timedelta(days=1)).strftime("%Y-%m-%d")]\
            .groupby('date')\
            .agg({'temp':'mean'})

        # создаем лист со значениями
        fix_pred_temp_list = np.ones(train_test_df.loc[day.strftime("%Y-%m-%d")]['temp_pred'].shape[0])*\
                                     fix_pred_temp.values[0]

        # замена
        train_test_df.loc[day.strftime("%Y-%m-%d"), ['temp_pred']] = fix_pred_temp_list
        
    except:
        train_test_df.loc[day.strftime("%Y-%m-%d")].fillna(0)

## Агрегация значений

In [100]:
train_test_df_agg = train_test_df.groupby('date').agg({'target':sum, 'temp':'mean', 'temp_pred':'mean'})

## shift, ma

In [101]:
ROLLING = 8

# shift target
train_test_df_agg['target_shift_1'] = train_test_df_agg['target'].shift(1)

# shift temp
train_test_df_agg['temp_shift_1'] = train_test_df_agg['temp'].shift(1)

# ma
train_test_df_agg['ma'] = train_test_df_agg['target'].shift(1).rolling(ROLLING).mean()
train_test_df_agg['temp_ma'] = train_test_df_agg['temp'].shift(1).rolling(ROLLING).mean()



In [103]:
# temp удалим на этапе формирования X, y

train_test_df_agg.head(10)

Unnamed: 0_level_0,target,temp,temp_pred,target_shift_1,temp_shift_1,ma,temp_ma
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
2019-01-01,11126.866,4.875,4.375,,,,
2019-01-02,11547.21,1.35,1.5,11126.866,4.875,,
2019-01-03,12235.564,-1.1625,-1.375,11547.21,1.35,,
2019-01-04,12763.044,-1.5,-0.125,12235.564,-1.1625,,
2019-01-05,12735.145,0.9375,1.5,12763.044,-1.5,,
2019-01-06,12744.419,-3.3,-3.75,12735.145,0.9375,,
2019-01-07,12719.935,-1.425,-3.0,12744.419,-3.3,,
2019-01-08,13185.565,-0.925,0.25,12719.935,-1.425,,
2019-01-09,13949.53,-2.125,-0.5,13185.565,-0.925,12382.2185,-0.14375
2019-01-10,14339.844,-6.475,-3.5,13949.53,-2.125,12735.0515,-1.01875


In [104]:
# удаляем пропуски
train_test_df_agg = train_test_df_agg.dropna(axis=0)

## временные признаки

In [105]:
# выделим временные признаки

# преобразуем inxex в dt
train_test_df_agg.index = pd.to_datetime(train_test_df_agg.index)

train_test_df_agg['day_name'] = train_test_df_agg.index.day_name()

## праздники

In [106]:
def ouliers_date_feature(row):
    try:
        if np.isin(pd.Timestamp(row.name).strftime("%m-%d"), ouliers_date_list):
            return pd.Timestamp(row.name).strftime("%m-%d")
        else:
            return str('not_holiday_date')
    except:
        return str('not_holiday_date')

In [107]:
train_test_df_agg['holiday_date'] = train_test_df_agg.apply(ouliers_date_feature, axis=1)

# train, test

In [108]:
train_sample = train_test_df_agg[~train_test_df_agg.index.isin(test_index)]

In [109]:
test_sample = train_test_df_agg[train_test_df_agg.index.isin(test_index)]

In [110]:
lr_tscv = TimeSeriesSplit(n_splits=5)

In [111]:
features = ['target_shift_1',
            'temp_pred',
            'ma',
            'day_name',
            'holiday_date',
            'temp_ma'
                    ]

In [112]:
# X, y train

X_train = train_sample[features]

y_train = train_sample['target']

In [113]:
# X, y test

X_test = test_sample[features]

y_test = test_sample['target']

## CatBoost 

In [114]:
cat_train_pool = Pool(data=X_train,
                 label=y_train,
                 cat_features = ['day_name', 'holiday_date'],
                  has_header=True
                 )

  self._init_pool(data, label, cat_features, text_features, embedding_features, pairs, weight, group_id, group_weight, subgroup_id, pairs_weight, baseline, feature_names, thread_count)


In [115]:
cat_regressor = CatBoostRegressor(random_seed=RANDOM_STATE, eval_metric='MAE', verbose=False)

In [116]:
cat_grid = {
        'iterations' : [2000],
        'verbose': [100],
        'l2_leaf_reg': [1, 3, 7],
        'auto_class_weights': ['None', 'Balanced', 'SqrtBalanced'],
        'depth': np.arange(2, 8) #[2, 4, 8],

       }

In [117]:
lr_tscv = TimeSeriesSplit(n_splits=5)

In [118]:
grid_search_result = cat_regressor.grid_search(cat_grid,
                                     cat_train_pool,
                                     cv=lr_tscv,
#                                      stratified=True,
                                              plot=True)

MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))

0:	learn: 11217.6104288	test: 11420.6806044	best: 11420.6806044 (0)	total: 1.16ms	remaining: 2.32s
100:	learn: 606.0983792	test: 677.4476565	best: 677.4476565 (100)	total: 70.4ms	remaining: 1.32s
200:	learn: 297.7797339	test: 264.8827033	best: 264.8827033 (200)	total: 116ms	remaining: 1.04s
300:	learn: 274.4545556	test: 239.5659381	best: 239.5659381 (300)	total: 163ms	remaining: 919ms
400:	learn: 258.7904533	test: 227.0917658	best: 227.0917658 (400)	total: 212ms	remaining: 846ms
500:	learn: 245.0985770	test: 217.3637141	best: 217.2309020 (495)	total: 262ms	remaining: 784ms
600:	learn: 232.7319180	test: 206.2768321	best: 206.2768321 (600)	total: 307ms	remaining: 715ms
700:	learn: 219.8719276	test: 195.8068372	best: 195.7678940 (699)	total: 355ms	remaining: 658ms
800:	learn: 208.8132022	test: 186.4672955	best: 186.4672955 (800)	total: 401ms	remaining: 600ms
900:	learn: 197.8373093	test: 175.8303647	best: 175.8303647 (900)	total: 456ms	remaining: 556ms
1000:	learn: 191.6305216	test: 171.7

1900:	learn: 134.6519452	test: 140.4290825	best: 140.3266299 (1890)	total: 1.38s	remaining: 71.8ms
1999:	learn: 133.0928793	test: 140.2944152	best: 140.2552953 (1965)	total: 1.45s	remaining: 0us

bestTest = 140.2552953
bestIteration = 1965

3:	loss: 140.2552953	best: 140.2552953 (3)	total: 5.22s	remaining: 1m 5s
0:	learn: 11220.9555291	test: 11427.7039455	best: 11427.7039455 (0)	total: 726us	remaining: 1.45s
100:	learn: 607.3842100	test: 675.7700019	best: 675.7700019 (100)	total: 62.6ms	remaining: 1.18s
200:	learn: 289.1887504	test: 266.2662287	best: 266.2662287 (200)	total: 127ms	remaining: 1.14s
300:	learn: 262.1805523	test: 235.3636847	best: 235.3636847 (300)	total: 192ms	remaining: 1.08s
400:	learn: 241.9221413	test: 217.9962590	best: 217.9819153 (399)	total: 255ms	remaining: 1.01s
500:	learn: 223.9880781	test: 204.9001311	best: 204.9001311 (500)	total: 319ms	remaining: 956ms
600:	learn: 208.4724495	test: 187.9334961	best: 187.8730553 (599)	total: 388ms	remaining: 903ms
700:	learn:

1600:	learn: 134.9494169	test: 161.6534041	best: 161.6409581 (1597)	total: 1.41s	remaining: 351ms
1700:	learn: 131.9311978	test: 161.0645128	best: 161.0138380 (1655)	total: 1.49s	remaining: 262ms
1800:	learn: 129.2377640	test: 160.7425006	best: 160.7405228 (1799)	total: 1.57s	remaining: 174ms
1900:	learn: 127.2952465	test: 160.3550147	best: 160.3550147 (1900)	total: 1.66s	remaining: 86.5ms
1999:	learn: 124.6578045	test: 160.2762632	best: 160.0980468 (1943)	total: 1.74s	remaining: 0us

bestTest = 160.0980468
bestIteration = 1943

7:	loss: 160.0980468	best: 140.2552953 (3)	total: 11.6s	remaining: 1m 6s
0:	learn: 11227.6467335	test: 11434.1100079	best: 11434.1100079 (0)	total: 1.14ms	remaining: 2.29s
100:	learn: 632.8686061	test: 698.4477680	best: 698.4477680 (100)	total: 73.9ms	remaining: 1.39s
200:	learn: 288.6585530	test: 271.8100456	best: 271.8100456 (200)	total: 162ms	remaining: 1.45s
300:	learn: 260.8447036	test: 240.6641219	best: 240.6641219 (300)	total: 240ms	remaining: 1.36s
400:

1300:	learn: 153.0737207	test: 181.8074068	best: 181.8074068 (1300)	total: 1.57s	remaining: 844ms
1400:	learn: 147.5114734	test: 178.7979408	best: 178.7954157 (1398)	total: 1.68s	remaining: 720ms
1500:	learn: 143.0723582	test: 176.3781686	best: 176.3781686 (1500)	total: 1.8s	remaining: 600ms
1600:	learn: 139.3468737	test: 174.4585053	best: 174.4585053 (1600)	total: 1.92s	remaining: 478ms
1700:	learn: 135.7568032	test: 172.7009038	best: 172.6952783 (1699)	total: 2.04s	remaining: 359ms
1800:	learn: 131.9032911	test: 170.6579124	best: 170.6402037 (1799)	total: 2.15s	remaining: 238ms
1900:	learn: 128.8695529	test: 169.2984669	best: 169.2902634 (1899)	total: 2.28s	remaining: 119ms
1999:	learn: 125.8526807	test: 168.1958033	best: 168.1658270 (1992)	total: 2.39s	remaining: 0us

bestTest = 168.165827
bestIteration = 1992

11:	loss: 168.1658270	best: 140.2552953 (3)	total: 21s	remaining: 1m 13s
0:	learn: 11220.2246903	test: 11425.6687030	best: 11425.6687030 (0)	total: 1.69ms	remaining: 3.38s
10

900:	learn: 110.8862424	test: 168.1488830	best: 168.1238914 (899)	total: 2.14s	remaining: 2.61s
1000:	learn: 103.2310498	test: 167.1171890	best: 167.1171890 (1000)	total: 2.38s	remaining: 2.37s
1100:	learn: 96.5210969	test: 165.4302634	best: 165.4302634 (1100)	total: 2.62s	remaining: 2.14s
1200:	learn: 91.2290062	test: 164.4431151	best: 164.4431151 (1200)	total: 2.85s	remaining: 1.9s
1300:	learn: 85.9212453	test: 164.0295029	best: 163.8368389 (1243)	total: 3.11s	remaining: 1.67s
1400:	learn: 80.9040763	test: 162.9885631	best: 162.9885631 (1400)	total: 3.34s	remaining: 1.43s
1500:	learn: 76.5488391	test: 162.8928099	best: 162.5625592 (1460)	total: 3.59s	remaining: 1.19s
1600:	learn: 72.5703148	test: 162.9079016	best: 162.5625592 (1460)	total: 3.82s	remaining: 953ms
1700:	learn: 69.1148825	test: 162.8514781	best: 162.5625592 (1460)	total: 4.06s	remaining: 714ms
1800:	learn: 65.9935230	test: 162.8167685	best: 162.5625592 (1460)	total: 4.31s	remaining: 477ms
1900:	learn: 63.1468926	test: 1

In [88]:
mean_absolute_error(y_test, cat_regressor.predict(X_test))

128.76219301668067