## Import

In [10]:
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn

import pandas as pd

import os

import numpy as np

from skorch import NeuralNetRegressor
from skorch.callbacks import EarlyStopping, Checkpoint, LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from skorch.helper import predefined_split
from skorch.dataset import Dataset

from models import FFNeuralNetwork, LSTMNeuralNetwork, LSTMDataset
from utilities import create_scaled_data_by_col, rmsle

%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Loading Data and Scaling

In [11]:
data_dir = 'data/'
df = pd.read_csv(os.path.join(data_dir, 'train_data.csv'))
df['date'] = pd.to_datetime(df['date'])
df = df.sort_values(by=['date', 'store_nbr'])
display(df.head())

Unnamed: 0,date,store_nbr,family,sales,onpromotion,city,state,store_type,cluster,oil,...,prev_1_sales,prev_7_sales,prev_14_sales,prev_1_transactions,prev_7_transactions,prev_14_transactions,is_payday,is_earth_quake,is_weekday,is_weekend
0,2013-02-01,1,0,3.0,0,0,0,0,13,97.46,...,0.0,0.0,0.0,0.0,0.0,0.0,False,False,1,0
1,2013-02-01,1,1,0.0,0,0,0,0,13,97.46,...,0.0,0.0,0.0,0.0,0.0,0.0,False,False,1,0
2,2013-02-01,1,2,0.0,0,0,0,0,13,97.46,...,0.0,0.0,0.0,0.0,0.0,0.0,False,False,1,0
3,2013-02-01,1,3,941.0,0,0,0,0,13,97.46,...,0.0,0.0,0.0,0.0,0.0,0.0,False,False,1,0
4,2013-02-01,1,4,0.0,0,0,0,0,13,97.46,...,0.0,0.0,0.0,0.0,0.0,0.0,False,False,1,0


In [12]:
min_max_cols = ['store_nbr', 'city', 'state', 'store_type', 'cluster', 'h_type_nat', 'h_description_nat', 'h_transferred_nat',
                'h_type_loc', 'h_description_loc', 'h_transferred_loc', 'month', 'day', 'day_of_week', 'is_payday',
                'is_earth_quake', 'is_weekday', 'is_weekend'
                ]
normalize_cols = ['onpromotion', 'oil', 'dow_avg_sales', 'dow_rolling_1_sales', 'dow_rolling_3_sales', 'prev_1_sales',
                'prev_7_sales', 'prev_14_sales', 'dow_avg_transactions', 'dow_rolling_1_transactions',
                'dow_rolling_3_transactions', 'prev_1_transactions', 'prev_7_transactions', 'prev_14_transactions'
                ]
x_cols = min_max_cols + normalize_cols
y_cols = ['sales']
split_col = 'family'

print(min_max_cols)
print(df.columns)

final_run = False

if final_run:
    train_df = df
else:
    rows_before = (df['date'] < '2017-08-01')
    rows_after = ~rows_before

    print('rows_before', rows_before.sum())
    print('rows_after', rows_after.sum())
    print('rows_total', len(df))

    train_df = df[rows_before]
    val_df = df[rows_after]

train_df_by_cluster = {}
scaler_x_by_cluster = {}
scaler_y_by_cluster = {}

for cluster in df[split_col].unique():
    cluster_df, cluster_min_max_scaler, cluster_normalize_scaler, cluster_y_scaler = create_scaled_data_by_col(train_df, min_max_cols, normalize_cols, y_cols, split_col, cluster)
    train_df_by_cluster[cluster] = cluster_df
    scaler_x_by_cluster[cluster] = (cluster_min_max_scaler, cluster_normalize_scaler)
    scaler_y_by_cluster[cluster] = cluster_y_scaler

if not final_run:
    val_df_by_cluster = {}

    for cluster in df[split_col].unique():
        val_cluster_min_max_scaler, val_cluster_normalize_scaler = scaler_x_by_cluster[cluster]
        val_cluster_y_scaler = scaler_y_by_cluster[cluster]

        val_cluster_df = val_df[val_df[split_col] == cluster]
        val_cluster_df = val_cluster_df.drop(columns=split_col)

        val_cluster_x_min_max = val_cluster_df[min_max_cols].values.astype(np.float32)
        val_cluster_x_normalize = val_cluster_df[normalize_cols].values.astype(np.float32)
        val_cluster_y = val_cluster_df[y_cols].values.reshape(-1, len(y_cols)).astype(np.float32)

        val_cluster_x_min_max = val_cluster_min_max_scaler.transform(val_cluster_x_min_max)
        val_cluster_x_normalize = val_cluster_normalize_scaler.transform(val_cluster_x_normalize)
        val_cluster_y = val_cluster_y_scaler.transform(val_cluster_y)

        val_cluster_df[min_max_cols] = val_cluster_x_min_max
        val_cluster_df[normalize_cols] = val_cluster_x_normalize
        val_cluster_df[y_cols] = val_cluster_y

        val_df_by_cluster[cluster] = val_cluster_df

['store_nbr', 'city', 'state', 'store_type', 'cluster', 'h_type_nat', 'h_description_nat', 'h_transferred_nat', 'h_type_loc', 'h_description_loc', 'h_transferred_loc', 'month', 'day', 'day_of_week', 'is_payday', 'is_earth_quake', 'is_weekday', 'is_weekend']
Index(['date', 'store_nbr', 'family', 'sales', 'onpromotion', 'city', 'state',
       'store_type', 'cluster', 'oil', 'h_type_nat', 'h_description_nat',
       'h_transferred_nat', 'h_type_loc', 'h_description_loc',
       'h_transferred_loc', 'transactions', 'year', 'month', 'day',
       'day_of_week', 'dow_avg_sales', 'dow_rolling_1_sales',
       'dow_rolling_3_sales', 'dow_rolling_7_sales', 'dow_avg_transactions',
       'dow_rolling_1_transactions', 'dow_rolling_3_transactions',
       'dow_rolling_7_transactions', 'prev_1_sales', 'prev_7_sales',
       'prev_14_sales', 'prev_1_transactions', 'prev_7_transactions',
       'prev_14_transactions', 'is_payday', 'is_earth_quake', 'is_weekday',
       'is_weekend'],
      dtype='ob

## NN Training

In [13]:
net_by_cluster = {}
train_params = {
                "criterion": nn.L1Loss,
                "optimizer": torch.optim.AdamW,
                "optimizer__weight_decay": 1e-8,
                #'train_split' : None,
                #"train_split": predefined_split(Dataset(val_x, val_y)),
                "lr": 0.001,
                "batch_size": 32,
                "max_epochs": 1000,
                "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                "iterator_train__shuffle": False,
                "iterator_train__num_workers": 2,
                "iterator_train__pin_memory": True,
                "iterator_valid__shuffle": False,
                "iterator_valid__num_workers": 2,
                "iterator_valid__pin_memory": True,
                "verbose": 2,
        }

net_params = {
    'input_dim': len(x_cols),
    'out_dim': 1,
    'hidden_dim': 200,
    'num_hidden_layers': 6,
    }

In [5]:
for cluster in df[split_col].unique():
    train_df = train_df_by_cluster[cluster]
    train_x = train_df[x_cols].values.astype(np.float32)
    train_y = train_df[y_cols].values.reshape(-1, len(y_cols)).astype(np.float32)

    if not final_run:
        val_df = val_df_by_cluster[cluster]
        train_params['train_split'] = predefined_split(Dataset(val_df[x_cols].values.astype(np.float32), val_df[y_cols].values.reshape(-1, len(y_cols)).astype(np.float32)))
    else:
        train_params['train_split'] = None

    callbacks = [EarlyStopping(patience=10, threshold=0.0001, threshold_mode='abs', monitor='valid_loss', lower_is_better=True),
            Checkpoint(monitor='valid_loss_best', f_params=f'sales_forecaster_{cluster}.pt', dirname='models/'),
            LRScheduler(policy=ReduceLROnPlateau, monitor='train_loss', factor=0.5, patience=5, threshold=0.001, threshold_mode='abs', mode='min', verbose=True)
            ]

    train_params['callbacks'] = callbacks

    net = NeuralNetRegressor(FFNeuralNetwork(**net_params), **train_params)

    print(cluster)

     
    net.fit(train_x, train_y)
    net_by_cluster[cluster] = net

0




  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.4307[0m        [32m0.4698[0m     +  4.1980
      2        [36m0.4219[0m        [32m0.4683[0m     +  4.1347
      3        [36m0.4196[0m        0.4694        4.1510
      4        [36m0.4174[0m        0.4707        4.2521
      5        [36m0.4155[0m        0.4757        4.1395
      6        [36m0.4136[0m        0.4734        4.2411
      7        [36m0.4125[0m        0.4704        4.2442
      8        [36m0.4107[0m        0.4728        4.3170
      9        [36m0.4091[0m        0.4697        4.2836
     10        [36m0.4077[0m        0.4718        4.2592
     11        [36m0.4062[0m        0.4702        4.2804
Stopping since valid_loss has not improved in the last 10 epochs.
1




  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.1678[0m        [32m0.2509[0m     +  4.2360
      2        [36m0.1549[0m        0.2510        4.2950
      3        [36m0.1538[0m        0.2528        4.3367
      4        [36m0.1528[0m        0.2550        4.2269
      5        0.1528        0.2545        4.2620
      6        0.1528        0.2554        4.3451
      7        [36m0.1519[0m        0.2557        4.3433
      8        [36m0.1506[0m        0.2565        4.2788
      9        0.1512        0.2573        4.2632
     10        [36m0.1500[0m        0.2586        4.3233
Stopping since valid_loss has not improved in the last 10 epochs.
2




  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.3475[0m        [32m0.5007[0m     +  4.2809
      2        [36m0.3367[0m        0.5082        4.3656
      3        [36m0.3338[0m        0.5035        4.1999
      4        [36m0.3314[0m        0.5056        4.3248
      5        [36m0.3290[0m        0.5101        4.2832
      6        [36m0.3272[0m        0.5015        4.3515
      7        [36m0.3259[0m        [32m0.4945[0m     +  4.2395
      8        [36m0.3252[0m        0.4973        4.3235
      9        [36m0.3236[0m        0.4952        4.2196
     10        [36m0.3226[0m        0.4972        4.2369
     11        [36m0.3218[0m        [32m0.4943[0m     +  4.2873
     12        [36m0.3207[0m        [32m0.4943[0m     +  4.3350
     13        [36m0.3188[0m        0.4948        4.2264
     14        [36m0.3187[0m        0.4972        4.3855
     15        [36m0.3176[0m        



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.1771[0m        [32m0.2524[0m     +  4.2712
      2        [36m0.1623[0m        [32m0.2481[0m     +  4.3634
      3        [36m0.1575[0m        [32m0.2451[0m     +  4.2680
      4        [36m0.1542[0m        0.2489        4.2553
      5        [36m0.1517[0m        0.2470        4.2615
      6        [36m0.1485[0m        [32m0.2429[0m     +  4.2489
      7        [36m0.1461[0m        [32m0.2419[0m     +  4.2631
      8        [36m0.1437[0m        0.2449        4.2755
      9        [36m0.1414[0m        0.2424        4.1918
     10        [36m0.1406[0m        0.2453        4.2885
     11        [36m0.1385[0m        0.2425        4.2298
     12        [36m0.1370[0m        [32m0.2382[0m     +  4.2409
     13        [36m0.1361[0m        [32m0.2375[0m     +  4.2168
     14        [36m0.1358[0m        0.2418        4.2699
     15    



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.1150[0m        [32m0.0197[0m     +  4.2958
      2        [36m0.1024[0m        [32m0.0195[0m     +  4.3178
      3        [36m0.1004[0m        [32m0.0191[0m     +  4.2740
      4        [36m0.0996[0m        [32m0.0186[0m     +  4.4214
      5        [36m0.0988[0m        0.0187        4.3444
      6        [36m0.0988[0m        [32m0.0183[0m     +  4.3935
      7        [36m0.0981[0m        0.0185        4.4099
      8        0.0989        0.0188        4.3617
      9        [36m0.0978[0m        0.0191        5.3830
     10        [36m0.0977[0m        0.0189        4.9651
     11        [36m0.0970[0m        0.0194        4.2544
     12        [36m0.0968[0m        0.0190        4.0513
     13        [36m0.0964[0m        0.0185        3.9940
     14        [36m0.0959[0m        0.0190        4.0262
     15        [36m0.0956[0m        



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.1809[0m        [32m0.1865[0m     +  4.1560
      2        [36m0.1648[0m        [32m0.1824[0m     +  4.1197
      3        [36m0.1588[0m        [32m0.1814[0m     +  4.2058
      4        [36m0.1550[0m        [32m0.1746[0m     +  4.1649
      5        [36m0.1508[0m        0.1747        4.1699
      6        [36m0.1494[0m        [32m0.1732[0m     +  4.0736
      7        [36m0.1468[0m        [32m0.1714[0m     +  4.1021
      8        [36m0.1451[0m        0.1748        4.1055
      9        [36m0.1443[0m        0.1736        4.0946
     10        [36m0.1440[0m        0.1718        4.1653
     11        0.1440        0.1721        3.9984
     12        [36m0.1419[0m        0.1742        3.9931
     13        [36m0.1412[0m        0.1736        3.9743
     14        [36m0.1408[0m        0.1726        4.0921
     15        [36m0.1398[0



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2488[0m        [32m0.3126[0m     +  4.1218
      2        [36m0.2347[0m        [32m0.3104[0m     +  3.9878
      3        [36m0.2288[0m        0.3124        4.0494
      4        [36m0.2241[0m        0.3134        4.1828
      5        [36m0.2203[0m        0.3153        4.2142
      6        [36m0.2170[0m        0.3152        4.0823
      7        [36m0.2151[0m        [32m0.3097[0m     +  4.1009
      8        [36m0.2128[0m        0.3138        4.1995
      9        [36m0.2113[0m        0.3116        4.4984
     10        [36m0.2098[0m        0.3116        4.1515
     11        [36m0.2098[0m        0.3101        4.1006
     12        [36m0.2091[0m        0.3110        4.1306
     13        [36m0.2071[0m        [32m0.3049[0m     +  4.0095
     14        [36m0.2067[0m        0.3156        4.0896
     15        [36m0.2066[0m        



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2457[0m        [32m0.3541[0m     +  4.1649
      2        [36m0.2139[0m        0.3600        4.2424
      3        [36m0.2020[0m        0.3846        4.1415
      4        [36m0.1959[0m        0.3695        4.1057
      5        [36m0.1896[0m        0.3592        4.0490
      6        [36m0.1852[0m        [32m0.3338[0m     +  4.1499
      7        [36m0.1834[0m        0.3832        4.1573
      8        [36m0.1809[0m        0.3529        4.0071
      9        [36m0.1795[0m        0.3710        4.1537
     10        [36m0.1780[0m        0.3820        4.1746
     11        [36m0.1773[0m        0.3660        4.0503
     12        [36m0.1747[0m        0.3878        4.0941
     13        [36m0.1736[0m        0.3870        3.9750
     14        [36m0.1729[0m        0.3752        3.9381
     15        0.1735        0.3817        4.0695
Stoppi



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.1517[0m        [32m0.1600[0m     +  4.0357
      2        [36m0.1367[0m        0.1606        3.9940
      3        [36m0.1318[0m        [32m0.1539[0m     +  4.0729
      4        [36m0.1260[0m        [32m0.1531[0m     +  4.2444
      5        [36m0.1222[0m        [32m0.1494[0m     +  4.2224
      6        [36m0.1204[0m        0.1531        4.2387
      7        [36m0.1186[0m        0.1577        4.3109
      8        [36m0.1176[0m        0.1525        4.2510
      9        [36m0.1164[0m        0.1510        4.2506
     10        [36m0.1156[0m        0.1538        4.2569
     11        [36m0.1143[0m        0.1495        4.3177
     12        [36m0.1138[0m        0.1513        4.4597
     13        [36m0.1135[0m        0.1529        4.1657
     14        [36m0.1124[0m        [32m0.1475[0m     +  4.1104
     15        [36m0.1119[0



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2133[0m        [32m0.2150[0m     +  4.4854
      2        [36m0.1908[0m        [32m0.1971[0m     +  4.4571
      3        [36m0.1803[0m        [32m0.1884[0m     +  4.4838
      4        [36m0.1755[0m        [32m0.1883[0m     +  4.5340
      5        [36m0.1729[0m        0.1899        4.5118
      6        [36m0.1707[0m        0.1931        4.5052
      7        [36m0.1688[0m        [32m0.1880[0m     +  4.5001
      8        [36m0.1672[0m        0.1880        4.4579
      9        [36m0.1658[0m        [32m0.1869[0m     +  4.5098
     10        [36m0.1649[0m        0.1881        4.4872
     11        [36m0.1648[0m        0.1902        4.4345
     12        [36m0.1631[0m        0.1968        4.2824
     13        [36m0.1623[0m        0.2004        4.2914
     14        [36m0.1621[0m        0.1945        4.2920
     15        [36m



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2224[0m        [32m0.2180[0m     +  4.3677
      2        [36m0.2054[0m        0.2204        4.3364
      3        [36m0.2007[0m        0.2242        4.2659
      4        [36m0.1965[0m        0.2256        4.2551
      5        [36m0.1925[0m        0.2263        4.3915
      6        [36m0.1904[0m        0.2311        4.4313
      7        [36m0.1884[0m        0.2310        4.3330
      8        [36m0.1866[0m        0.2311        4.3699
      9        [36m0.1863[0m        0.2315        4.2571
     10        [36m0.1850[0m        0.2291        4.2429
Stopping since valid_loss has not improved in the last 10 epochs.
11




  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.1475[0m        [32m0.0761[0m     +  4.2673
      2        [36m0.1372[0m        0.0798        4.3435
      3        [36m0.1356[0m        0.0796        4.2313
      4        [36m0.1337[0m        0.0806        4.2987
      5        [36m0.1300[0m        0.0819        4.2428
      6        [36m0.1261[0m        0.0816        4.1350
      7        [36m0.1227[0m        0.0828        4.2384
      8        [36m0.1206[0m        0.0801        4.3322
      9        [36m0.1181[0m        0.0808        4.2556
     10        [36m0.1173[0m        0.0825        4.2481
Stopping since valid_loss has not improved in the last 10 epochs.
12




  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2125[0m        [32m0.2150[0m     +  4.3469
      2        [36m0.1851[0m        [32m0.2035[0m     +  4.3921
      3        [36m0.1752[0m        0.2065        4.3203
      4        [36m0.1692[0m        0.2058        4.3253
      5        [36m0.1643[0m        0.2088        4.2765
      6        [36m0.1609[0m        0.2077        4.3058
      7        [36m0.1592[0m        0.2169        4.3029
      8        [36m0.1582[0m        0.2195        4.3269
      9        [36m0.1566[0m        0.2199        4.3053
     10        [36m0.1556[0m        0.2172        4.3267
     11        [36m0.1541[0m        0.2129        4.2611
Stopping since valid_loss has not improved in the last 10 epochs.
13




  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2428[0m        [32m0.3289[0m     +  4.3449
      2        [36m0.2337[0m        0.3553        4.3794
      3        [36m0.2316[0m        0.3541        4.3525
      4        [36m0.2291[0m        0.3492        4.2758
      5        [36m0.2276[0m        0.3389        4.3888
      6        [36m0.2264[0m        0.3404        4.2317
      7        [36m0.2246[0m        [32m0.3189[0m     +  4.3873
      8        [36m0.2243[0m        0.3255        4.3163
      9        [36m0.2227[0m        0.3303        4.2211
     10        [36m0.2222[0m        0.3494        4.3295
     11        [36m0.2211[0m        0.3454        4.2878
     12        [36m0.2205[0m        0.3408        4.3322
     13        [36m0.2199[0m        0.3433        4.3237
     14        [36m0.2192[0m        0.3413        4.3787
     15        [36m0.2185[0m        0.3501        4.32



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.5589[0m        [32m0.6397[0m     +  4.3009
      2        [36m0.5487[0m        [32m0.6342[0m     +  4.3797
      3        [36m0.5440[0m        0.6366        4.2927
      4        [36m0.5417[0m        [32m0.6322[0m     +  4.3381
      5        [36m0.5405[0m        0.6373        4.2968
      6        [36m0.5395[0m        0.6397        4.2564
      7        [36m0.5389[0m        [32m0.6294[0m     +  4.3660
      8        [36m0.5379[0m        0.6338        4.3174
      9        [36m0.5370[0m        0.6294        4.3115
     10        [36m0.5363[0m        0.6388        4.3441
     11        [36m0.5346[0m        0.6352        4.2712
     12        [36m0.5341[0m        0.6303        4.2877
     13        0.5342        0.6325        4.3440
     14        [36m0.5333[0m        0.6486        4.3571
     15        [36m0.5324[0m        0.6358   



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2766[0m        [32m0.3082[0m     +  4.3104
      2        [36m0.2670[0m        0.3134        4.1967
      3        [36m0.2641[0m        0.3185        4.3164
      4        [36m0.2615[0m        0.3230        4.2510
      5        [36m0.2593[0m        0.3202        4.3273
      6        [36m0.2586[0m        0.3289        4.2274
      7        [36m0.2565[0m        0.3222        4.3068
      8        [36m0.2564[0m        0.3242        4.2432
      9        [36m0.2541[0m        0.3301        4.3314
     10        [36m0.2529[0m        0.3348        4.4801
Stopping since valid_loss has not improved in the last 10 epochs.
16




  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2230[0m        [32m0.3729[0m     +  4.3226
      2        [36m0.2111[0m        [32m0.3588[0m     +  4.4144
      3        [36m0.2082[0m        [32m0.3579[0m     +  4.5334
      4        [36m0.2066[0m        0.3580        4.2812
      5        [36m0.2044[0m        0.3586        4.2593
      6        [36m0.2026[0m        [32m0.3557[0m     +  4.3513
      7        [36m0.2017[0m        [32m0.3557[0m     +  4.3678
      8        [36m0.2009[0m        [32m0.3543[0m     +  4.3631
      9        [36m0.1997[0m        [32m0.3534[0m     +  4.3548
     10        [36m0.1984[0m        [32m0.3509[0m     +  4.3536
     11        [36m0.1976[0m        [32m0.3502[0m     +  4.2934
     12        [36m0.1963[0m        0.3508        4.2980
     13        [36m0.1960[0m        [32m0.3459[0m     +  4.3553
     14        [36m0.1951[0m        0.356



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.4211[0m        [32m0.2058[0m     +  4.3964
      2        [36m0.4041[0m        [32m0.1977[0m     +  4.3906
      3        [36m0.4000[0m        [32m0.1946[0m     +  4.3036
      4        [36m0.3988[0m        0.1959        4.2737
      5        [36m0.3977[0m        [32m0.1929[0m     +  4.2890
      6        [36m0.3963[0m        [32m0.1928[0m     +  4.4223
      7        0.3966        [32m0.1923[0m     +  4.2657
      8        [36m0.3954[0m        0.2038        4.4266
      9        [36m0.3950[0m        0.1996        4.3828
     10        [36m0.3947[0m        0.1963        4.3712
     11        [36m0.3945[0m        0.2015        4.4476
     12        [36m0.3937[0m        0.1998        4.3478
     13        0.3940        0.1983        4.3463
     14        [36m0.3935[0m        0.1927        4.4615
     15        [36m0.3930[0m        



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2070[0m        [32m0.3233[0m     +  4.4217
      2        [36m0.1737[0m        [32m0.3037[0m     +  4.4294
      3        [36m0.1578[0m        [32m0.2929[0m     +  4.4370
      4        [36m0.1517[0m        [32m0.2913[0m     +  4.2399
      5        [36m0.1466[0m        [32m0.2843[0m     +  4.3671
      6        [36m0.1413[0m        [32m0.2806[0m     +  4.4152
      7        [36m0.1390[0m        0.2929        4.4231
      8        [36m0.1362[0m        0.2892        4.3511
      9        [36m0.1329[0m        0.2845        4.3521
     10        [36m0.1301[0m        [32m0.2802[0m     +  4.4142
     11        [36m0.1288[0m        [32m0.2760[0m     +  4.4027
     12        [36m0.1262[0m        [32m0.2735[0m     +  4.3413
     13        [36m0.1245[0m        0.2775        4.3896
     14        [36m0.1236[0m        0.2802        



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2259[0m        [32m0.2561[0m     +  4.4177
      2        [36m0.2081[0m        0.2597        4.4265
      3        [36m0.1978[0m        0.2933        4.4619
      4        [36m0.1919[0m        0.2729        4.4145
      5        [36m0.1888[0m        [32m0.2537[0m     +  4.4392
      6        [36m0.1867[0m        [32m0.2518[0m     +  4.5099
      7        [36m0.1855[0m        0.2527        4.3712
      8        [36m0.1854[0m        [32m0.2516[0m     +  4.4079
      9        [36m0.1835[0m        0.2540        4.4831
     10        [36m0.1811[0m        [32m0.2513[0m     +  4.4409
     11        0.1812        0.2711        4.3968
     12        [36m0.1800[0m        0.2629        4.4256
     13        [36m0.1783[0m        0.2790        4.4681
     14        [36m0.1779[0m        0.2579        4.4342
     15        [36m0.1772[0m        



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2278[0m        [32m0.3833[0m     +  4.4612
      2        [36m0.2173[0m        [32m0.3784[0m     +  4.4770
      3        [36m0.2158[0m        [32m0.3730[0m     +  4.4004
      4        [36m0.2137[0m        [32m0.3716[0m     +  4.3424
      5        [36m0.2132[0m        0.3726        4.3652
      6        [36m0.2119[0m        0.3726        4.4807
      7        [36m0.2107[0m        [32m0.3658[0m     +  4.4819
      8        [36m0.2097[0m        0.3715        4.5869
      9        0.2099        0.3716        4.4783
     10        [36m0.2089[0m        0.3719        4.4442
     11        [36m0.2075[0m        0.3728        4.5026
     12        0.2075        0.3695        4.4374
     13        [36m0.2061[0m        0.3741        4.4315
     14        [36m0.2060[0m        0.3773        4.4749
     15        [36m0.2056[0m        0.3768   



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.3727[0m        [32m0.4067[0m     +  4.5219
      2        [36m0.3660[0m        [32m0.4034[0m     +  4.5110
      3        [36m0.3635[0m        [32m0.4009[0m     +  4.4793
      4        [36m0.3617[0m        0.4012        4.4277
      5        [36m0.3595[0m        [32m0.3997[0m     +  4.3147
      6        [36m0.3578[0m        [32m0.3972[0m     +  4.4335
      7        [36m0.3562[0m        0.4014        4.4574
      8        [36m0.3552[0m        0.4088        4.4449
      9        [36m0.3545[0m        0.4009        4.6099
     10        [36m0.3531[0m        0.4004        4.4387
     11        [36m0.3520[0m        0.4041        4.4118
     12        [36m0.3510[0m        0.4009        4.5149
     13        [36m0.3497[0m        0.3989        4.4893
     14        [36m0.3486[0m        0.4006        4.4603
     15        [36m0.3482[0



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2613[0m        [32m0.2292[0m     +  4.5459
      2        [36m0.2425[0m        [32m0.2246[0m     +  4.4323
      3        [36m0.2365[0m        0.2325        4.4557
      4        [36m0.2326[0m        0.2324        4.4522
      5        [36m0.2294[0m        0.2413        4.4664
      6        [36m0.2283[0m        0.2375        4.4554
      7        [36m0.2256[0m        0.2384        4.4626
      8        [36m0.2241[0m        0.2353        4.4703
      9        [36m0.2208[0m        0.2507        4.3908
     10        [36m0.2198[0m        0.2546        4.3638
     11        [36m0.2175[0m        0.2444        4.3098
Stopping since valid_loss has not improved in the last 10 epochs.
23




  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2349[0m        [32m0.4418[0m     +  4.2676
      2        [36m0.2233[0m        [32m0.4392[0m     +  4.4144
      3        [36m0.2198[0m        [32m0.4317[0m     +  4.3821
      4        [36m0.2181[0m        0.4329        4.3929
      5        [36m0.2159[0m        [32m0.4303[0m     +  4.3161
      6        [36m0.2148[0m        0.4315        4.3699
      7        [36m0.2136[0m        0.4333        4.3807
      8        [36m0.2126[0m        0.4331        4.3375
      9        [36m0.2114[0m        0.4313        4.3552
     10        [36m0.2101[0m        0.4310        4.3016
     11        [36m0.2097[0m        0.4317        4.3672
     12        [36m0.2084[0m        0.4316        4.2952
     13        [36m0.2076[0m        0.4321        4.3258
     14        [36m0.2076[0m        0.4330        4.3182
Stopping since valid_loss has not impr



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.1380[0m        [32m0.1251[0m     +  4.3620
      2        [36m0.1258[0m        [32m0.1202[0m     +  4.4195
      3        [36m0.1200[0m        [32m0.1168[0m     +  4.2968
      4        [36m0.1171[0m        [32m0.1141[0m     +  4.3587
      5        [36m0.1138[0m        [32m0.1119[0m     +  4.3818
      6        [36m0.1128[0m        0.1130        4.2670
      7        [36m0.1113[0m        [32m0.1106[0m     +  4.3368
      8        [36m0.1106[0m        [32m0.1095[0m     +  4.3261
      9        [36m0.1098[0m        0.1107        4.3882
     10        [36m0.1091[0m        [32m0.1073[0m     +  4.3830
     11        [36m0.1089[0m        0.1081        4.2967
     12        [36m0.1074[0m        0.1079        4.3442
     13        [36m0.1072[0m        0.1079        4.3058
     14        [36m0.1066[0m        0.1087        4.3228
  



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2420[0m        [32m0.2816[0m     +  4.3417
      2        [36m0.2115[0m        0.3055        4.4288
      3        [36m0.2009[0m        [32m0.2814[0m     +  4.3467
      4        [36m0.1955[0m        [32m0.2796[0m     +  4.3632
      5        [36m0.1913[0m        0.2913        4.4017
      6        [36m0.1894[0m        0.2949        4.3926
      7        [36m0.1867[0m        0.3160        4.3377
      8        [36m0.1839[0m        0.3038        4.3444
      9        0.1844        0.3049        4.3449
     10        [36m0.1818[0m        0.3084        4.3716
     11        [36m0.1813[0m        0.3004        4.4095
     12        [36m0.1796[0m        0.3077        4.3043
     13        [36m0.1780[0m        0.3029        4.2557
Stopping since valid_loss has not improved in the last 10 epochs.
26




  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2144[0m        [32m0.3715[0m     +  4.4338
      2        [36m0.2054[0m        [32m0.3699[0m     +  4.3490
      3        [36m0.2002[0m        [32m0.3674[0m     +  4.3847
      4        [36m0.1977[0m        [32m0.3671[0m     +  4.4370
      5        [36m0.1924[0m        [32m0.3633[0m     +  4.3467
      6        [36m0.1894[0m        0.3662        4.3579
      7        [36m0.1874[0m        0.3652        4.4154
      8        [36m0.1854[0m        0.3645        4.3109
      9        [36m0.1839[0m        0.3644        4.3408
     10        [36m0.1825[0m        0.3666        4.5077
     11        0.1826        0.3647        4.3489
     12        [36m0.1817[0m        0.3663        4.3634
     13        [36m0.1807[0m        0.3647        4.4379
     14        [36m0.1805[0m        0.3655        4.3694
Stopping since valid_loss has not impr



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.2565[0m        [32m0.3506[0m     +  4.4953
      2        [36m0.2375[0m        [32m0.3406[0m     +  4.4787
      3        [36m0.2276[0m        [32m0.3401[0m     +  4.3988
      4        [36m0.2227[0m        [32m0.3378[0m     +  4.4208
      5        [36m0.2182[0m        [32m0.3377[0m     +  4.3528
      6        [36m0.2153[0m        0.3382        4.3638
      7        [36m0.2139[0m        [32m0.3366[0m     +  4.4338
      8        [36m0.2121[0m        [32m0.3348[0m     +  4.3539
      9        [36m0.2099[0m        0.3351        4.1575
     10        [36m0.2085[0m        [32m0.3336[0m     +  4.1377
     11        [36m0.2077[0m        0.3345        4.1976
     12        [36m0.2063[0m        [32m0.3328[0m     +  4.2023
     13        [36m0.2059[0m        0.3383        4.1124
     14        [36m0.2048[0m        0.3332        



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.1590[0m        [32m0.1362[0m     +  4.1583
      2        [36m0.1437[0m        [32m0.1336[0m     +  4.1304
      3        [36m0.1385[0m        [32m0.1313[0m     +  4.1948
      4        [36m0.1323[0m        [32m0.1271[0m     +  4.1664
      5        [36m0.1279[0m        [32m0.1247[0m     +  4.2345
      6        [36m0.1260[0m        0.1252        4.2619
      7        [36m0.1242[0m        0.1265        4.1564
      8        [36m0.1236[0m        0.1251        4.1020
      9        [36m0.1222[0m        0.1285        4.1314
     10        [36m0.1213[0m        0.1287        4.2469
     11        [36m0.1206[0m        [32m0.1232[0m     +  4.3453
     12        [36m0.1195[0m        0.1249        4.3620
     13        [36m0.1191[0m        0.1256        4.3654
     14        [36m0.1185[0m        0.1248        4.1692
     15        [36m



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.1753[0m        [32m0.1514[0m     +  4.0953
      2        [36m0.1616[0m        0.1547        4.2524
      3        [36m0.1583[0m        [32m0.1509[0m     +  4.2040
      4        [36m0.1560[0m        [32m0.1503[0m     +  4.1289
      5        [36m0.1542[0m        [32m0.1477[0m     +  4.0640
      6        [36m0.1526[0m        0.1487        4.1585
      7        [36m0.1515[0m        [32m0.1470[0m     +  4.0953
      8        [36m0.1506[0m        0.1502        4.1308
      9        [36m0.1493[0m        0.1494        4.1544
     10        [36m0.1491[0m        0.1481        4.1146
     11        [36m0.1479[0m        0.1499        4.1187
     12        0.1479        0.1505        4.2340
     13        [36m0.1474[0m        0.1502        4.0484
     14        [36m0.1468[0m        0.1480        4.1090
     15        [36m0.1460[0m        



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.1325[0m        [32m0.1160[0m     +  4.2150
      2        [36m0.1080[0m        [32m0.1113[0m     +  4.1338
      3        [36m0.0922[0m        [32m0.1096[0m     +  4.0590
      4        [36m0.0882[0m        0.1110        4.1530
      5        [36m0.0862[0m        0.1101        4.0933
      6        [36m0.0833[0m        0.1109        4.0371
      7        [36m0.0819[0m        0.1106        4.0174
      8        [36m0.0818[0m        0.1118        4.0990
      9        [36m0.0798[0m        0.1114        4.1012
     10        [36m0.0791[0m        [32m0.1078[0m     +  4.1422
     11        [36m0.0786[0m        0.1104        4.1622
     12        [36m0.0777[0m        0.1102        4.1061
     13        [36m0.0768[0m        0.1103        4.1152
     14        [36m0.0747[0m        0.1094        4.2103
     15        0.0761        0.1138   



  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.1016[0m        [32m1.3946[0m     +  4.1521
      2        [36m0.0903[0m        [32m1.2724[0m     +  4.1336
      3        [36m0.0876[0m        1.2728        4.1086
      4        [36m0.0861[0m        1.3153        4.1654
      5        [36m0.0841[0m        1.3924        4.1199
      6        0.0847        1.5598        4.0842
      7        [36m0.0821[0m        1.2871        4.1251
      8        0.0824        1.4577        4.0982
      9        0.0830        1.6859        4.0939
     10        [36m0.0811[0m        1.4115        4.0695
     11        0.0816        1.5884        4.1599
Stopping since valid_loss has not improved in the last 10 epochs.
32




  epoch    train_loss    valid_loss    cp     dur
-------  ------------  ------------  ----  ------
      1        [36m0.1733[0m        [32m0.1435[0m     +  4.0772
      2        [36m0.1621[0m        [32m0.1426[0m     +  4.1007
      3        [36m0.1592[0m        [32m0.1403[0m     +  4.0873
      4        [36m0.1577[0m        [32m0.1402[0m     +  4.1215
      5        [36m0.1565[0m        [32m0.1399[0m     +  4.1222
      6        [36m0.1549[0m        0.1401        4.1017
      7        [36m0.1540[0m        0.1401        4.1105
      8        [36m0.1529[0m        0.1399        4.1057
      9        [36m0.1522[0m        [32m0.1397[0m     +  4.1780
     10        [36m0.1515[0m        [32m0.1394[0m     +  4.1811
     11        [36m0.1513[0m        0.1410        4.1203
     12        [36m0.1504[0m        0.1400        4.2257
     13        [36m0.1498[0m        0.1407        4.1253
     14        [36m0.1493[0m        0.1424        4.1368
     15    

# Load Nets from Checkpoints

In [14]:
for cluster in df[split_col].unique():
    net = NeuralNetRegressor(FFNeuralNetwork(**net_params), **train_params)
    net.initialize()
    net.load_params(f_params=f'models/sales_forecaster_{cluster}.pt')
    net_by_cluster[cluster] = net

## LSTM Training

In [23]:
endogenous_cols = [
        'sales', 'onpromotion', 'oil', 
       'dow_avg_sales', 'dow_rolling_3_sales', 'dow_rolling_7_sales',
       'dow_avg_transactions', 'dow_rolling_3_transactions',
       'dow_rolling_7_transactions', 'rolling_7_sales', 'rolling_14_sales',
       'rolling_7_transactions', 'rolling_14_transactions']

exogenous_cols = [
    'h_type_nat', 'h_description_nat', 'h_transferred_nat', 'h_type_loc',
    'h_description_loc', 'h_transferred_loc', 'month', 'day', 'day_of_week', 'store_nbr'
    ]

out_cols = ['sales']


lstm_net_by_cluster = {}
lstm_net_params = {
    'input_dim': 512,
    'endogenous_dim': len(endogenous_cols)*54,
    'endogenous_len': 5,
    'exogenous_dim': len(exogenous_cols),
    'hidden_dim': 1024,
    'out_dim': 54,
    'out_seq_len': 15,
    'num_layers': 4
}

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 1000

for cluster in df[split_col].unique():
    train_df = train_df_by_cluster[cluster]
    train_dataset = LSTMDataset(train_df, 5, 15, 'date', endogenous_cols, exogenous_cols, out_cols, 'store_nbr')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=30, shuffle=False)

    scaler_y = scaler_y_by_cluster[cluster]

    val_df = val_df_by_cluster[cluster]
    val_dataset = LSTMDataset(val_df, 5, 15, 'date', endogenous_cols, exogenous_cols, out_cols, 'store_nbr')

    model = LSTMNeuralNetwork(**lstm_net_params)
    optim = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-8)
    criterion = nn.L1Loss()
    model.train()
    model.to(device)
    model.zero_grad()

    val = val_dataset[0]
    val_endog = torch.tensor(val['endog']).unsqueeze(0).to(device).to(torch.float32)
    val_exog = torch.tensor(val['exog']).unsqueeze(0).to(device).to(torch.float32)
    val_y = scaler_y.inverse_transform(val['label'].reshape(-1,1).astype(np.float32))

    for epoch in range(num_epochs):
        for sample in train_loader:
            endog = sample['endog'].to(device).to(torch.float32)
            exog = sample['exog'].to(device).to(torch.float32)
            y = sample['label'].to(device).to(torch.float32)
            optim.zero_grad()
            output = model((endog, exog))
            loss = criterion(output, y)
            loss.backward()
            optim.step()
        with torch.no_grad():
            model.eval()
            val_pred = model((val_endog, val_exog))
            val_pred = val_pred.view(-1,1).cpu().detach().numpy()
            val_pred = scaler_y.inverse_transform(val_pred)

            val_loss1 = rmsle(val_y, val_pred.clip(0))
            val_loss2 = np.abs(val_y - val_pred).sum()
            print(f'Epoch {epoch+1}/{num_epochs}, RMSLE: {val_loss1.item():6.5f} L1: {val_loss2:6.5f}')
            model.train()


    lstm_net_by_cluster[cluster] = model

Epoch 1/1000, RMSLE: 0.59281 L1: 2621.91406
Epoch 2/1000, RMSLE: 0.59805 L1: 2591.96606
Epoch 3/1000, RMSLE: 0.61355 L1: 2627.24707
Epoch 4/1000, RMSLE: 0.61980 L1: 2648.24902
Epoch 5/1000, RMSLE: 0.63566 L1: 2669.86035
Epoch 6/1000, RMSLE: 0.64438 L1: 2688.27637
Epoch 7/1000, RMSLE: 0.66854 L1: 2746.78809
Epoch 8/1000, RMSLE: 0.68094 L1: 2764.41455
Epoch 9/1000, RMSLE: 0.69042 L1: 2781.10596
Epoch 10/1000, RMSLE: 0.69891 L1: 2787.39185
Epoch 11/1000, RMSLE: 0.69744 L1: 2791.86279
Epoch 12/1000, RMSLE: 0.69394 L1: 2789.51074
Epoch 13/1000, RMSLE: 0.70647 L1: 2797.06030
Epoch 14/1000, RMSLE: 0.69428 L1: 2787.90845
Epoch 15/1000, RMSLE: 0.70769 L1: 2804.49365
Epoch 16/1000, RMSLE: 0.69218 L1: 2788.97607
Epoch 17/1000, RMSLE: 0.71373 L1: 2809.92358
Epoch 18/1000, RMSLE: 0.68829 L1: 2788.63428
Epoch 19/1000, RMSLE: 0.71235 L1: 2810.70557
Epoch 20/1000, RMSLE: 0.68650 L1: 2790.86255
Epoch 21/1000, RMSLE: 0.71120 L1: 2808.48950


## Random Forest

In [15]:
from sklearn.ensemble import RandomForestRegressor
cluster_rfs = {}

for cluster in df[split_col].unique():
    train_df = train_df_by_cluster[cluster]

    train_x = train_df[x_cols].values.astype(np.float32)
    train_y = train_df[y_cols].values.reshape(-1, len(y_cols)).astype(np.float32)

    rf = RandomForestRegressor(n_estimators=100, max_depth=15, random_state=42, n_jobs=4)
    rf.fit(train_x, train_y.squeeze())

    cluster_rfs[cluster] = rf

## XGBoost

In [18]:
import xgboost as xgb
cluster_xgb = {}
for cluster in df[split_col].unique():
    train_x = train_x_by_cluster[cluster]
    train_y = train_y_by_cluster[cluster]

    xgb_model = xgb.XGBRegressor(n_estimators=1000, max_depth=12, learning_rate=0.001, random_state=42, n_jobs=2)
    xgb_model.fit(train_x, train_y.squeeze())

    cluster_xgb[cluster] = xgb_model

## Predict on Training Data

In [27]:
net_train_preds = []
rf_train_preds = []

for cluster in df[split_col].unique():
    train_df = train_df_by_cluster[cluster]

    train_x = train_df[x_cols].values.astype(np.float32)
    train_y = train_df[y_cols].values.reshape(-1, len(y_cols)).astype(np.float32)
    y_scaler = scaler_y_by_cluster[cluster]

    net = net_by_cluster[cluster]
    rf = cluster_rfs[cluster]

    net_preds = net.predict(train_x)
    rf_preds = rf.predict(train_x)

    train_df['sales_nn'] = net_preds
    train_df['sales_rf'] = rf_preds

## Validation Loss Evaluation

In [20]:
def rmsle(y_true, y_pred):
    return np.sqrt(np.mean(np.square(np.log1p(y_true) - np.log1p(y_pred))))

rf_preds = []
net_preds = []
xgb_preds = []
val_y_true = []

for cluster in df[split_col].unique():
    val_cluster_df = val_df_by_cluster[cluster]
    val_x = val_cluster_df[x_cols].values.astype(np.float32)
    val_y = val_cluster_df[y_cols].values.reshape(-1, len(y_cols)).astype(np.float32)

    rf = cluster_rfs[cluster]
    net = net_by_cluster[cluster]

    rf_preds.append(scaler_y_by_cluster[cluster].inverse_transform(rf.predict(val_x).reshape(-1, 1)))
    net_preds.append(scaler_y_by_cluster[cluster].inverse_transform(net.predict(val_x).reshape(-1, 1)).clip(0))    
    #xgb_preds.append(scaler_y_by_cluster[cluster].inverse_transform(cluster_xgb[cluster].predict(val_x).reshape(-1, 1)))
    val_y_true.append(scaler_y_by_cluster[cluster].inverse_transform(val_y))

    #rf_preds.append(rf.predict(val_x).reshape(-1, 1))
    #net_preds.append(net.predict(val_x).reshape(-1, 1).clip(0) )
    #val_y_true.append(val_y)   


rf_preds = np.concatenate(rf_preds)
net_preds = np.concatenate(net_preds)
#xgb_preds = np.concatenate(xgb_preds)
val_y_true = np.concatenate(val_y_true)

print(f'RF RMSLE: {rmsle(val_y_true, rf_preds)}')
#print(f'XGB RMSLE: {rmsle(val_y_true, xgb_preds)}')
print(f'NN RMSLE: {rmsle(val_y_true, net_preds)}')

RF RMSLE: 0.4077625548953881
NN RMSLE: 0.39476779103279114


## Loading Test Data

In [21]:
test_df = pd.read_csv(os.path.join(data_dir, 'test_data.csv'), index_col=0)
display(test_df.head())

test_x_by_cluster = {}
test_id_by_cluster = {}

for cluster in df[split_col].unique():
    test_cluster_min_max_scaler, test_cluster_normalize_scaler = scaler_x_by_cluster[cluster]
    test_cluster_y_scaler = scaler_y_by_cluster[cluster]

    test_cluster_x_df = test_df[test_df[split_col] == cluster]
    test_cluster_x_df = test_cluster_x_df.drop(columns=split_col)

    test_cluster_x_min_max = test_cluster_x_df[min_max_cols].values.astype(np.float32)
    test_cluster_x_normalize = test_cluster_x_df[normalize_cols].values.astype(np.float32)

    test_cluster_x_min_max = test_cluster_min_max_scaler.transform(test_cluster_x_min_max)
    test_cluster_x_normalize = test_cluster_normalize_scaler.transform(test_cluster_x_normalize)
    #test_cluster_x_normalize = test_cluster_x_normalize

    test_x_by_cluster[cluster] = np.concatenate([test_cluster_x_min_max, test_cluster_x_normalize], axis=1)
    test_id_by_cluster[cluster] = test_cluster_x_df.index


test_preds_dfs = []

for cluster in df[split_col].unique():
    test_x = test_x_by_cluster[cluster]
    id = test_id_by_cluster[cluster]
    rf = cluster_rfs[cluster]
    net = net_by_cluster[cluster]

    pred_rf = scaler_y_by_cluster[cluster].inverse_transform(rf.predict(test_x).reshape(-1, 1))
    #pred_xgb = scaler_y_by_cluster[cluster].inverse_transform(cluster_xgb[cluster].predict(test_x).reshape(-1, 1))
    pred_nn = scaler_y_by_cluster[cluster].inverse_transform(net_by_cluster[cluster].predict(test_x).reshape(-1, 1)).clip(0)

    #pred_rf = rf.predict(test_x).reshape(-1, 1)
    #pred_nn = net.predict(test_x).reshape(-1, 1).clip(0)

    
    #cluster_df = pd.DataFrame(np.concatenate([pred_nn], axis=1), index=id, columns=['sales_nn'])
    cluster_df = pd.DataFrame(np.concatenate([pred_rf, pred_nn], axis=1), index=id, columns=['sales_rf', 'sales_nn'])

    test_preds_dfs.append(cluster_df)

test_preds_df = pd.concat(test_preds_dfs)

test_df = test_df.merge(test_preds_df, on='id', how='left')

sub_df_nn = test_df[['sales_nn']]
sub_df_rf = test_df[['sales_rf']]
#sub_df_xgb = test_df[['sales_xgb']]

sub_df_rf = sub_df_rf.rename(columns={'sales_rf': 'sales'})
#sub_df_xgb = sub_df_xgb.rename(columns={'sales_xgb': 'sales'})
sub_df_nn = sub_df_nn.rename(columns={'sales_nn': 'sales'})


display(sub_df_nn.head())
display(sub_df_rf.head())
#display(sub_df_xgb.head())

sub_df_nn.to_csv('data/submission_nn.csv')
#sub_df_xgb.to_csv('data/submission_xgb.csv')
sub_df_rf.to_csv('data/submission_rf.csv')

Unnamed: 0_level_0,store_nbr,family,onpromotion,city,state,store_type,cluster,oil,h_type_nat,h_description_nat,...,dow_rolling_1_transactions,dow_rolling_3_transactions,dow_rolling_7_transactions,prev_1_transactions,prev_7_transactions,prev_14_transactions,is_payday,is_earth_quake,is_weekday,is_weekend
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
3000888,1,0,0,0,0,0,13,46.8,0,0,...,1892.0,1864.0,1888.857143,1766.0,1892.0,1903.0,False,False,1,0
3000889,1,1,0,0,0,0,13,46.8,0,0,...,1892.0,1864.0,1888.857143,1766.0,1892.0,1903.0,False,False,1,0
3000890,1,2,2,0,0,0,13,46.8,0,0,...,1892.0,1864.0,1888.857143,1766.0,1892.0,1903.0,False,False,1,0
3000891,1,3,20,0,0,0,13,46.8,0,0,...,1892.0,1864.0,1888.857143,1766.0,1892.0,1903.0,False,False,1,0
3000892,1,4,0,0,0,0,13,46.8,0,0,...,1892.0,1864.0,1888.857143,1766.0,1892.0,1903.0,False,False,1,0


Unnamed: 0_level_0,sales
id,Unnamed: 1_level_1
3000888,4.415288
3000889,0.00012
3000890,4.730282
3000891,2429.969482
3000892,0.0


Unnamed: 0_level_0,sales
id,Unnamed: 1_level_1
3000888,4.0387
3000889,0.004688
3000890,4.437217
3000891,2406.571371
3000892,0.16605
