In [1]:
import random
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import itertools
import warnings
from tqdm.auto import tqdm
from sklearn.preprocessing import LabelEncoder
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
import statsmodels.api as sm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
warnings.simplefilter(action='ignore')

In [2]:
CFG = {
    'TRAIN_WINDOW_SIZE':90, # 90일치로 학습
    'PREDICT_SIZE':21, # 21일치 예측
    'EPOCHS':5,
    'LEARNING_RATE':1e-4,
    'BATCH_SIZE':4096,
    'SEED':41
}

In [3]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CFG['SEED']) # Seed 고정

In [4]:
train_data = pd.read_csv('./new_train.csv').drop(columns=['Unnamed: 0','ID', '제품'])
train_data

Unnamed: 0,대분류,중분류,소분류,브랜드,Avg_price,2022-01-01,2022-01-02,2022-01-03,2022-01-04,2022-01-05,...,2023-03-26,2023-03-27,2023-03-28,2023-03-29,2023-03-30,2023-03-31,2023-04-01,2023-04-02,2023-04-03,2023-04-04
0,B002-C001-0002,B002-C002-0007,B002-C003-0038,B002-00001,7325.000000,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,B002-C001-0003,B002-C002-0008,B002-C003-0044,B002-00002,26333.750000,0,0,0,0,0,...,0,0,0,1,3,2,0,0,2,0
2,B002-C001-0003,B002-C002-0008,B002-C003-0044,B002-00002,10853.492063,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,B002-C001-0003,B002-C002-0008,B002-C003-0044,B002-00002,4791.666667,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,B002-C001-0001,B002-C002-0001,B002-C003-0003,B002-00003,4921.780492,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15885,B002-C001-0003,B002-C002-0008,B002-C003-0042,B002-03799,1888.169643,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
15886,B002-C001-0003,B002-C002-0008,B002-C003-0044,B002-03799,22157.082261,0,0,0,0,0,...,0,0,0,3,0,2,4,1,1,3
15887,B002-C001-0003,B002-C002-0008,B002-C003-0044,B002-03799,11712.896203,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
15888,B002-C001-0003,B002-C002-0008,B002-C003-0044,B002-03799,13600.000000,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,2


In [5]:
'''
# Data Scaling
scale_max_dict = {}
scale_min_dict = {}

for idx in tqdm(range(len(train_data))):
    maxi = np.max(train_data.iloc[idx,5:])
    mini = np.min(train_data.iloc[idx,5:])

    if maxi == mini :
        train_data.iloc[idx,5:] = 0
    else:
        train_data.iloc[idx,5:] = (train_data.iloc[idx,5:] - mini) / (maxi - mini)

    scale_max_dict[idx] = maxi
    scale_min_dict[idx] = mini
'''

'\n# Data Scaling\nscale_max_dict = {}\nscale_min_dict = {}\n\nfor idx in tqdm(range(len(train_data))):\n    maxi = np.max(train_data.iloc[idx,5:])\n    mini = np.min(train_data.iloc[idx,5:])\n\n    if maxi == mini :\n        train_data.iloc[idx,5:] = 0\n    else:\n        train_data.iloc[idx,5:] = (train_data.iloc[idx,5:] - mini) / (maxi - mini)\n\n    scale_max_dict[idx] = maxi\n    scale_min_dict[idx] = mini\n'

In [6]:
# Label Encoding
label_encoder = LabelEncoder()
categorical_columns = ['대분류', '중분류', '소분류', '브랜드']

for col in categorical_columns:
    label_encoder.fit(train_data[col])
    train_data[col] = label_encoder.transform(train_data[col])

In [7]:
train_data = train_data.transpose()

In [8]:
p=d=q=range(0, 3)
pdq=list(itertools.product(p,d,q))
min_a = 9999999

for param in pdq :
    try :
        model = sm.tsa.statespace.SARIMAX(train_data[5:][7945], order=param)
        model_fit = model.fit(trend='c')
        if model_fit.aic < min_a :
            min_a = model_fit.aic
            min_param = param
    except:
        continue

model = sm.tsa.statespace.SARIMAX(train_data[5:][0], order=min_param)
model_fit = model.fit(trend='nc')
aa = model_fit.forecast(21)

In [9]:
p=d=q=range(0, 3)
pdq=list(itertools.product(p,d,q))

for idx in tqdm(range(7946, train_data.shape[1])) :
    min_a = 9999999
    for param in pdq :
        try :
            model = sm.tsa.statespace.SARIMAX(train_data[5:][idx], order=param)
            model_fit = model.fit(trend='nc')
            if model_fit.aic < min_a :
                min_a = model_fit.aic
                min_param = param
        except:
            continue
    model = sm.tsa.statespace.SARIMAX(train_data[5:][idx], order=min_param)
    model_fit = model.fit(trend='nc')
    bb = model_fit.forecast(21)
    aa = pd.concat([aa, bb], axis = 1)

  0%|          | 0/7944 [00:00<?, ?it/s]

In [10]:
aa

Unnamed: 0,predicted_mean,predicted_mean.1,predicted_mean.2,predicted_mean.3,predicted_mean.4,predicted_mean.5,predicted_mean.6,predicted_mean.7,predicted_mean.8,predicted_mean.9,...,predicted_mean.10,predicted_mean.11,predicted_mean.12,predicted_mean.13,predicted_mean.14,predicted_mean.15,predicted_mean.16,predicted_mean.17,predicted_mean.18,predicted_mean.19
2023-04-05,3.4454089999999996e-288,3.552675,1.218174e-41,4.098883,0.675789,4.509632,17.953752,27.728679,11.316113,2.770051,...,0.055409,2.09109,0.571789,0.515015,-0.004277097,3.920677e-08,4.560657,4.644589e-11,1.918563,0.0
2023-04-06,1.221217e-288,3.699119,1.7214819999999998e-41,3.752895,0.571943,4.229972,15.092655,25.195826,12.398003,2.531997,...,2.844015,0.953943,0.205211,-0.423524,-0.005820573,3.73983e-08,5.638609,5.436199e-11,1.37169,0.0
2023-04-07,1.87308e-288,3.158503,1.607588e-41,3.454261,0.500763,4.040284,13.069576,21.515484,15.89659,2.485922,...,0.079925,1.827141,0.098812,0.219681,-0.001832575,1.186158e-08,6.331151,5.571119e-11,0.861996,0.0
2023-04-08,1.958362e-288,3.215777,1.501228e-41,3.252207,0.465818,3.911622,11.639059,20.320862,18.444501,2.477005,...,2.819715,1.071305,0.457369,-0.200956,-0.001246435,3.762124e-09,6.754238,5.594114e-11,0.675677,0.0
2023-04-09,1.754983e-288,2.800392,1.401906e-41,3.107172,0.446301,3.824354,10.627542,18.129004,22.417151,2.475279,...,0.104011,1.69388,0.230255,0.090501,-0.0005295982,1.193229e-09,7.002847,5.598033e-11,0.797479,0.0
2023-04-10,1.8668699999999997e-288,2.801335,1.309155e-41,3.004653,0.435962,3.765161,9.912299,17.615632,26.079757,2.474945,...,2.79584,1.170641,0.161188,-0.097519,-0.000295019,3.784552e-10,7.144247,5.598701e-11,1.03634,0.0
2023-04-11,1.840918e-288,2.47764,1.22254e-41,2.931901,0.430368,3.725012,9.406551,16.288501,30.641983,2.47488,...,0.127676,1.606785,0.386009,0.034958,-0.0001396925,1.200343e-10,7.222358,5.598815e-11,1.2061,0.0
2023-04-12,1.832212e-288,2.444319,1.1416559999999999e-41,2.880326,0.427367,3.69778,9.048938,16.110185,35.206755,2.474867,...,2.772384,1.242015,0.245326,-0.048738,-7.276661e-05,3.807114e-11,7.264327,5.598835e-11,1.233384,0.0
2023-04-13,1.8437369999999998e-288,2.188493,1.0661229999999999e-41,2.843753,0.425752,3.679309,8.79607,15.292079,40.386835,2.474865,...,0.150926,1.546673,0.200556,0.011756,-3.581196e-05,1.207498e-11,7.286257,5.598838e-11,1.157235,0.0
2023-04-14,1.838278e-288,2.135604,9.955879e-42,2.817821,0.424883,3.66678,8.617267,15.269357,45.715295,2.474864,...,2.749338,1.292078,0.341509,-0.02526,-1.82259e-05,3.82981e-12,7.297383,5.598838e-11,1.062584,0.0


In [16]:
cols = list(range(7945, train_data.shape[1]))
aa.columns = cols
result = aa.transpose()

In [17]:
result

Unnamed: 0,2023-04-05,2023-04-06,2023-04-07,2023-04-08,2023-04-09,2023-04-10,2023-04-11,2023-04-12,2023-04-13,2023-04-14,...,2023-04-16,2023-04-17,2023-04-18,2023-04-19,2023-04-20,2023-04-21,2023-04-22,2023-04-23,2023-04-24,2023-04-25
7945,3.445409e-288,1.221217e-288,1.873080e-288,1.958362e-288,1.754983e-288,1.866870e-288,1.840918e-288,1.832212e-288,1.843737e-288,1.838278e-288,...,1.839840e-288,1.839209e-288,1.839466e-288,1.839443e-288,1.839399e-288,1.839433e-288,1.839421e-288,1.839421e-288,1.839424e-288,1.839422e-288
7946,3.552675e+00,3.699119e+00,3.158503e+00,3.215777e+00,2.800392e+00,2.801335e+00,2.477640e+00,2.444319e+00,2.188493e+00,2.135604e+00,...,1.867828e+00,1.701439e+00,1.634985e+00,1.498292e+00,1.432110e+00,1.318594e+00,1.255062e+00,1.159892e+00,1.100356e+00,1.019907e+00
7947,1.218174e-41,1.721482e-41,1.607588e-41,1.501228e-41,1.401906e-41,1.309155e-41,1.222540e-41,1.141656e-41,1.066123e-41,9.955879e-42,...,8.682081e-42,8.107669e-42,7.571259e-42,7.070339e-42,6.602561e-42,6.165730e-42,5.757801e-42,5.376861e-42,5.021124e-42,4.688923e-42
7948,4.098883e+00,3.752895e+00,3.454261e+00,3.252207e+00,3.107172e+00,3.004653e+00,2.931901e+00,2.880326e+00,2.843753e+00,2.817821e+00,...,2.786395e+00,2.777150e+00,2.770595e+00,2.765946e+00,2.762650e+00,2.760313e+00,2.758656e+00,2.757481e+00,2.756648e+00,2.756057e+00
7949,6.757886e-01,5.719432e-01,5.007629e-01,4.658184e-01,4.463007e-01,4.359621e-01,4.303680e-01,4.273670e-01,4.257515e-01,4.248831e-01,...,4.241648e-01,4.240297e-01,4.239571e-01,4.239180e-01,4.238970e-01,4.238857e-01,4.238796e-01,4.238764e-01,4.238746e-01,4.238737e-01
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15885,3.920677e-08,3.739830e-08,1.186158e-08,3.762124e-09,1.193229e-09,3.784552e-10,1.200343e-10,3.807114e-11,1.207498e-11,3.829810e-12,...,3.852641e-13,1.221938e-13,3.875609e-14,1.229223e-14,3.898713e-15,1.236551e-15,3.921955e-16,1.243922e-16,3.945336e-17,1.251338e-17
15886,4.560657e+00,5.638609e+00,6.331151e+00,6.754238e+00,7.002847e+00,7.144247e+00,7.222358e+00,7.264327e+00,7.286257e+00,7.297383e+00,...,7.305412e+00,7.306561e+00,7.307035e+00,7.307206e+00,7.307250e+00,7.307247e+00,7.307231e+00,7.307215e+00,7.307203e+00,7.307194e+00
15887,4.644589e-11,5.436199e-11,5.571119e-11,5.594114e-11,5.598033e-11,5.598701e-11,5.598815e-11,5.598835e-11,5.598838e-11,5.598838e-11,...,5.598838e-11,5.598838e-11,5.598838e-11,5.598838e-11,5.598838e-11,5.598838e-11,5.598838e-11,5.598838e-11,5.598838e-11,5.598838e-11
15888,1.918563e+00,1.371690e+00,8.619959e-01,6.756770e-01,7.974788e-01,1.036340e+00,1.206100e+00,1.233384e+00,1.157235e+00,1.062584e+00,...,1.019092e+00,1.056687e+00,1.091063e+00,1.103193e+00,1.094531e+00,1.078232e+00,1.066862e+00,1.065220e+00,1.070510e+00,1.076929e+00


In [18]:
submit = pd.read_csv('./sample_submission.csv')
submit.head()

Unnamed: 0,ID,2023-04-05,2023-04-06,2023-04-07,2023-04-08,2023-04-09,2023-04-10,2023-04-11,2023-04-12,2023-04-13,...,2023-04-16,2023-04-17,2023-04-18,2023-04-19,2023-04-20,2023-04-21,2023-04-22,2023-04-23,2023-04-24,2023-04-25
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,2,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,3,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,4,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [23]:
submit.iloc[:,7945:] = result
submit.iloc[7945]

ID             7.945000e+03
2023-04-05    3.445409e-288
2023-04-06    1.221217e-288
2023-04-07    1.873080e-288
2023-04-08    1.958362e-288
2023-04-09    1.754983e-288
2023-04-10    1.866870e-288
2023-04-11    1.840918e-288
2023-04-12    1.832212e-288
2023-04-13    1.843737e-288
2023-04-14    1.838278e-288
2023-04-15    1.839180e-288
2023-04-16    1.839840e-288
2023-04-17    1.839209e-288
2023-04-18    1.839466e-288
2023-04-19    1.839443e-288
2023-04-20    1.839399e-288
2023-04-21    1.839433e-288
2023-04-22    1.839421e-288
2023-04-23    1.839421e-288
2023-04-24    1.839424e-288
2023-04-25    1.839422e-288
Name: 7945, dtype: float64

In [24]:
submit.to_csv('./arima_front_0_to_7944.csv', index=False)

In [28]:
result.to_csv('./arima_7945_to_end.csv', index = False)

In [29]:
pd.read_csv('./arima_7945_to_end.csv')

Unnamed: 0,2023-04-05,2023-04-06,2023-04-07,2023-04-08,2023-04-09,2023-04-10,2023-04-11,2023-04-12,2023-04-13,2023-04-14,...,2023-04-16,2023-04-17,2023-04-18,2023-04-19,2023-04-20,2023-04-21,2023-04-22,2023-04-23,2023-04-24,2023-04-25
0,3.445409e-288,1.221217e-288,1.873080e-288,1.958362e-288,1.754983e-288,1.866870e-288,1.840918e-288,1.832212e-288,1.843737e-288,1.838278e-288,...,1.839840e-288,1.839209e-288,1.839466e-288,1.839443e-288,1.839399e-288,1.839433e-288,1.839421e-288,1.839421e-288,1.839424e-288,1.839422e-288
1,3.552675e+00,3.699119e+00,3.158503e+00,3.215777e+00,2.800392e+00,2.801335e+00,2.477640e+00,2.444319e+00,2.188493e+00,2.135604e+00,...,1.867828e+00,1.701439e+00,1.634985e+00,1.498292e+00,1.432110e+00,1.318594e+00,1.255062e+00,1.159892e+00,1.100356e+00,1.019907e+00
2,1.218174e-41,1.721482e-41,1.607588e-41,1.501228e-41,1.401906e-41,1.309155e-41,1.222540e-41,1.141656e-41,1.066123e-41,9.955879e-42,...,8.682081e-42,8.107669e-42,7.571259e-42,7.070339e-42,6.602561e-42,6.165730e-42,5.757801e-42,5.376861e-42,5.021124e-42,4.688923e-42
3,4.098883e+00,3.752895e+00,3.454261e+00,3.252207e+00,3.107172e+00,3.004653e+00,2.931901e+00,2.880326e+00,2.843753e+00,2.817821e+00,...,2.786395e+00,2.777150e+00,2.770595e+00,2.765946e+00,2.762650e+00,2.760313e+00,2.758656e+00,2.757481e+00,2.756648e+00,2.756057e+00
4,6.757886e-01,5.719432e-01,5.007629e-01,4.658184e-01,4.463007e-01,4.359621e-01,4.303680e-01,4.273670e-01,4.257515e-01,4.248831e-01,...,4.241648e-01,4.240297e-01,4.239571e-01,4.239180e-01,4.238970e-01,4.238857e-01,4.238796e-01,4.238764e-01,4.238746e-01,4.238737e-01
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7940,3.920677e-08,3.739830e-08,1.186158e-08,3.762124e-09,1.193229e-09,3.784552e-10,1.200343e-10,3.807114e-11,1.207498e-11,3.829810e-12,...,3.852641e-13,1.221938e-13,3.875609e-14,1.229223e-14,3.898713e-15,1.236551e-15,3.921955e-16,1.243922e-16,3.945336e-17,1.251338e-17
7941,4.560657e+00,5.638609e+00,6.331151e+00,6.754238e+00,7.002847e+00,7.144247e+00,7.222358e+00,7.264327e+00,7.286257e+00,7.297383e+00,...,7.305412e+00,7.306561e+00,7.307035e+00,7.307206e+00,7.307250e+00,7.307247e+00,7.307231e+00,7.307215e+00,7.307203e+00,7.307194e+00
7942,4.644589e-11,5.436199e-11,5.571119e-11,5.594114e-11,5.598033e-11,5.598701e-11,5.598815e-11,5.598835e-11,5.598838e-11,5.598838e-11,...,5.598838e-11,5.598838e-11,5.598838e-11,5.598838e-11,5.598838e-11,5.598838e-11,5.598838e-11,5.598838e-11,5.598838e-11,5.598838e-11
7943,1.918563e+00,1.371690e+00,8.619959e-01,6.756770e-01,7.974788e-01,1.036340e+00,1.206100e+00,1.233384e+00,1.157235e+00,1.062584e+00,...,1.019092e+00,1.056687e+00,1.091063e+00,1.103193e+00,1.094531e+00,1.078232e+00,1.066862e+00,1.065220e+00,1.070510e+00,1.076929e+00
