In [84]:
import torch
import numpy as np
import pandas as pd
import os
import time
from glob import glob
from itertools import cycle
import copy
from pathlib import Path
import warnings
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

In [95]:
# Directory holding the preprocessed M5 dataset, relative to the notebook's working directory.
dirs = '../M5_Dataset'

# NOTE(review): read_pickle unpickles arbitrary objects and can execute code while doing so;
# only point it at files produced by our own preprocessing.
data = pd.read_pickle(Path(dirs) / 'dataset.pkl')

In [97]:
# Restrict `data` to these 18 columns in exactly this order. Numpy arrays built
# from `data` later keep this order: column 0 = Node, column 1 = d, last = sold.
data = data.loc[:, [
    'Node', 'd',
    'store_id', 'dept_id', 'state_id', 'cat_id',
    'weekday', 'wday', 'month', 'year',
    'event_name_1', 'event_type_1', 'event_name_2', 'event_type_2',
    'snap_CA', 'snap_TX', 'snap_WI',
    'sold',
]]

In [98]:
# First 70 rows; the displayed output shows one row per Node, all at d == 1.
data.head(70)

Unnamed: 0,Node,d,store_id,dept_id,state_id,cat_id,weekday,wday,month,year,event_name_1,event_type_1,event_name_2,event_type_2,snap_CA,snap_TX,snap_WI,sold
0,0,1,0,0,0,0,2,1,1,2011,-1,-1,-1,-1,0,0,0,297.0
1,1,1,0,1,0,0,2,1,1,2011,-1,-1,-1,-1,0,0,0,674.0
2,2,1,0,2,0,0,2,1,1,2011,-1,-1,-1,-1,0,0,0,2268.0
3,3,1,0,3,0,1,2,1,1,2011,-1,-1,-1,-1,0,0,0,528.0
4,4,1,0,4,0,1,2,1,1,2011,-1,-1,-1,-1,0,0,0,28.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
65,65,1,9,2,2,0,2,1,1,2011,-1,-1,-1,-1,0,0,0,2293.0
66,66,1,9,3,2,1,2,1,1,2011,-1,-1,-1,-1,0,0,0,256.0
67,67,1,9,4,2,1,2,1,1,2011,-1,-1,-1,-1,0,0,0,22.0
68,68,1,9,5,2,2,2,1,1,2011,-1,-1,-1,-1,0,0,0,584.0


In [99]:
# Rows 70..139 (positional); the displayed output shows the same Node sequence at d == 2.
data.iloc[70:140]

Unnamed: 0,Node,d,store_id,dept_id,state_id,cat_id,weekday,wday,month,year,event_name_1,event_type_1,event_name_2,event_type_2,snap_CA,snap_TX,snap_WI,sold
70,0,2,0,0,0,0,3,2,1,2011,-1,-1,-1,-1,0,0,0,284.0
71,1,2,0,1,0,0,3,2,1,2011,-1,-1,-1,-1,0,0,0,655.0
72,2,2,0,2,0,0,3,2,1,2011,-1,-1,-1,-1,0,0,0,2198.0
73,3,2,0,3,0,1,3,2,1,2011,-1,-1,-1,-1,0,0,0,489.0
74,4,2,0,4,0,1,3,2,1,2011,-1,-1,-1,-1,0,0,0,9.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
135,65,2,9,2,2,0,3,2,1,2011,-1,-1,-1,-1,0,0,0,2383.0
136,66,2,9,3,2,1,3,2,1,2011,-1,-1,-1,-1,0,0,0,342.0
137,67,2,9,4,2,1,3,2,1,2011,-1,-1,-1,-1,0,0,0,14.0
138,68,2,9,5,2,2,3,2,1,2011,-1,-1,-1,-1,0,0,0,541.0


In [100]:
# All rows of Node 0 as a plain numpy array (mixed int/float columns are upcast to one dtype).
data.loc[data['Node'].eq(0)].to_numpy()

array([[0.000e+00, 1.000e+00, 0.000e+00, ..., 0.000e+00, 0.000e+00,
        2.970e+02],
       [0.000e+00, 2.000e+00, 0.000e+00, ..., 0.000e+00, 0.000e+00,
        2.840e+02],
       [0.000e+00, 3.000e+00, 0.000e+00, ..., 0.000e+00, 0.000e+00,
        2.140e+02],
       ...,
       [0.000e+00, 1.939e+03, 0.000e+00, ..., 0.000e+00, 0.000e+00,
        3.100e+02],
       [0.000e+00, 1.940e+03, 0.000e+00, ..., 0.000e+00, 0.000e+00,
        4.270e+02],
       [0.000e+00, 1.941e+03, 0.000e+00, ..., 0.000e+00, 0.000e+00,
        3.340e+02]])

In [101]:
# Same selection as above, kept as a DataFrame so the column labels are visible.
data.loc[data['Node'].eq(0)]

Unnamed: 0,Node,d,store_id,dept_id,state_id,cat_id,weekday,wday,month,year,event_name_1,event_type_1,event_name_2,event_type_2,snap_CA,snap_TX,snap_WI,sold
0,0,1,0,0,0,0,2,1,1,2011,-1,-1,-1,-1,0,0,0,297.0
70,0,2,0,0,0,0,3,2,1,2011,-1,-1,-1,-1,0,0,0,284.0
140,0,3,0,0,0,0,1,3,1,2011,-1,-1,-1,-1,0,0,0,214.0
210,0,4,0,0,0,0,5,4,2,2011,-1,-1,-1,-1,1,1,0,175.0
280,0,5,0,0,0,0,6,5,2,2011,-1,-1,-1,-1,1,0,1,182.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
135520,0,1937,0,0,0,0,6,5,5,2016,-1,-1,-1,-1,0,0,0,397.0
135590,0,1938,0,0,0,0,4,6,5,2016,-1,-1,-1,-1,0,0,0,330.0
135660,0,1939,0,0,0,0,0,7,5,2016,-1,-1,-1,-1,0,0,0,310.0
135730,0,1940,0,0,0,0,2,1,5,2016,-1,-1,-1,-1,0,0,0,427.0


In [102]:
# Regroup the rows of `data` node by node: all rows of Node 0 (in their original
# order), then Node 1, and so on. The result is a single 2-D array of shape
# (N_NODES * rows_per_node, n_columns), ready for a (nodes, days, columns) reshape.
N_NODES = 70  # Node ids 0..N_NODES-1 are collected; rows with any other id are not included

# Build every per-node block first and concatenate once. Concatenating inside the
# loop re-copies the growing accumulator on every iteration, which is quadratic.
node_arrays = [data.loc[data['Node'] == node].to_numpy() for node in range(N_NODES)]

# A (nodes, days, columns) reshape only makes sense if every node contributes the
# same number of rows. Otherwise the rows would silently misalign whenever the
# total still happens to match.
node_row_counts = {len(block) for block in node_arrays}
assert len(node_row_counts) == 1, f"unequal row counts per node: {sorted(node_row_counts)}"

array_data = np.concatenate(node_arrays)

In [103]:
# Expected (70 * 1941, 18) = (135870, 18): 70 nodes, 1941 rows each, 18 columns.
np.shape(array_data)

(135870, 18)

In [104]:
# Split the node-grouped 2-D array into a 3-D array indexed as dataset[node, day, column].
# The day-axis length is inferred (-1) rather than hard-coded as 1941, so a different
# history length still reshapes correctly. The column count is taken from the data.
dataset = array_data.reshape(70, -1, array_data.shape[1])

# Layout sanity checks. Column 0 is `Node` and column 1 is `d`, per the column selection.
# Each slab along axis 0 must belong to exactly one node ...
assert (dataset[:, :, 0] == dataset[:, :1, 0]).all(), "a node slab mixes several Node ids"
# ... and its rows must be in strictly increasing day order.
assert (np.diff(dataset[:, :, 1], axis=1) > 0).all(), "days are not increasing within a node"

In [105]:
# Inspect node 1's slab. The array was built from `data`, so its columns are in
# `data.columns` order; reattach those labels instead of showing bare 0..17.
pd.DataFrame(dataset[1], columns=data.columns)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,1.0,1.0,0.0,1.0,0.0,0.0,2.0,1.0,1.0,2011.0,-1.0,-1.0,-1.0,-1.0,0.0,0.0,0.0,674.0
1,1.0,2.0,0.0,1.0,0.0,0.0,3.0,2.0,1.0,2011.0,-1.0,-1.0,-1.0,-1.0,0.0,0.0,0.0,655.0
2,1.0,3.0,0.0,1.0,0.0,0.0,1.0,3.0,1.0,2011.0,-1.0,-1.0,-1.0,-1.0,0.0,0.0,0.0,396.0
3,1.0,4.0,0.0,1.0,0.0,0.0,5.0,4.0,2.0,2011.0,-1.0,-1.0,-1.0,-1.0,1.0,1.0,0.0,476.0
4,1.0,5.0,0.0,1.0,0.0,0.0,6.0,5.0,2.0,2011.0,-1.0,-1.0,-1.0,-1.0,1.0,0.0,1.0,354.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1936,1.0,1937.0,0.0,1.0,0.0,0.0,6.0,5.0,5.0,2016.0,-1.0,-1.0,-1.0,-1.0,0.0,0.0,0.0,427.0
1937,1.0,1938.0,0.0,1.0,0.0,0.0,4.0,6.0,5.0,2016.0,-1.0,-1.0,-1.0,-1.0,0.0,0.0,0.0,411.0
1938,1.0,1939.0,0.0,1.0,0.0,0.0,0.0,7.0,5.0,2016.0,-1.0,-1.0,-1.0,-1.0,0.0,0.0,0.0,434.0
1939,1.0,1940.0,0.0,1.0,0.0,0.0,2.0,1.0,5.0,2016.0,-1.0,-1.0,-1.0,-1.0,0.0,0.0,0.0,627.0
