In [39]:
# ライブラリーの読み込み
import numpy as np
import pandas as pd
import pmdarima as pm
from pmdarima import utils
from pmdarima import arima
from pmdarima import model_selection
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_percentage_error
from matplotlib import pyplot as plt

# Plot style and default figure size.
plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = [12, 9]

# First date kept for modelling (YYYYMMDD) and the single vegetable modelled here.
# NOTE(review): assumes the `date` column is stored as an integer YYYYMMDD — confirm.
START_DATE = 20160101
TARGET_KIND = 'だいこん'  # daikon radish

train_df = pd.read_csv('data/train.csv')
test_df = pd.read_csv('data/test.csv')

# Drop vegetables that occur only in the training data: keep just the kinds
# that also appear in test.csv.
kinds = test_df['kind'].unique()
train_df = train_df[train_df['kind'].isin(kinds)]

# Keep rows from START_DATE (2016-01-01) onward. This excludes everything from
# 2005 through 2015 — including 2005, whose data begins partway through the year.
train_df = train_df.query('@START_DATE <= date').reset_index(drop=True)

# Restrict to the single target vegetable.
train_df = train_df.query('kind == @TARGET_KIND')


In [40]:
def rmspe(y_true, y_pred):
    """Root mean squared percentage error, as a fraction (0.1 means 10%).

    Parameters
    ----------
    y_true : array-like
        Observed values. Zeros produce inf/nan (division by y_true).
    y_pred : array-like or scalar
        Predictions, matched to ``y_true`` by position; numpy broadcasting
        applies, so a scalar constant forecast is accepted.

    Returns
    -------
    numpy.float64
        sqrt(mean(((y_true - y_pred) / y_true) ** 2)).
    """
    # Convert to plain arrays so pandas Series are compared by position rather
    # than by index label. df_test is reset to 0..n-1 while model forecasts are
    # indexed after the training rows; label alignment would give all-NaN.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return np.sqrt(np.mean(np.square((y_true - y_pred) / y_true)))

# Chronological split of the target series: rows dated before SPLIT_DATE are
# used for fitting, rows on or after it are held out for evaluation.
# With the current data the hold-out is 22 rows (see output), not 12 months.
SPLIT_DATE = 20221001
df_train = train_df.query('date < @SPLIT_DATE').reset_index(drop=True)['mode_price']
df_test = train_df.query('@SPLIT_DATE <= date').reset_index(drop=True)['mode_price']

# Every row must land in exactly one side of the split.
assert len(df_train) + len(df_test) == len(train_df)

print(df_train.shape)
print(df_test.shape)

(1735,)
(22,)


In [41]:
# Build the model: fit a seasonal ARIMA on the training series using
# pmdarima's automatic order search (trace=True enables progress output).
# NOTE(review): m=365 treats one row as one calendar day, but df_train holds
# 1735 rows for roughly 2016-01 .. 2022-09 (~250 rows/year), so a 365-row
# season does not line up with a year — confirm the intended period.
# A seasonal period this large is also very expensive to fit; the saved
# output shows this cell was interrupted before it finished.
# NOTE(review): maxiter=10 presumably caps optimizer iterations per candidate
# fit, so fits may stop before converging — confirm.
# NOTE(review): n_jobs is likely only used by a non-stepwise grid search,
# while this call runs the stepwise search (per the output) — confirm.
arima_model = pm.auto_arima(
    y=df_train,
    m=365,
    seasonal=True,
    maxiter=10,
    n_jobs=-1,
    trace=True,
)

Performing stepwise search to minimize aic




KeyboardInterrupt: 

In [None]:
# Forecast.
## Fitted (in-sample) values over the training period.
train_pred = arima_model.predict_in_sample()
## Out-of-sample forecasts, one step per test row, with confidence intervals.
test_pred, test_pred_ci = arima_model.predict(
    n_periods=df_test.shape[0],
    return_conf_int=True,
)
# Evaluate on the test set. Compare by position: df_test is indexed 0..n-1,
# while the forecast Series is indexed after the training rows, so passing the
# two Series directly would align on labels and yield NaN.
print(rmspe(df_test.to_numpy(), np.asarray(test_pred)))

0.45851813928454144


In [None]:
# Display the test-period forecasts.
# NOTE(review): the saved output has 30 values indexed 1727-1756, while the
# split above now yields 22 test rows and 1735 training rows — it is from an
# earlier run with a different split and is stale; re-run to refresh.
test_pred

1727    1297.800315
1728    1297.271456
1729    1283.511978
1730    1283.508116
1731    1283.064630
1732    1283.485251
1733    1304.219487
1734    1296.856194
1735    1304.190857
1736    1304.219355
1737    1303.608884
1738    1303.331937
1739    1303.333924
1740    1303.333924
1741    1303.333924
1742    1303.333924
1743    1303.333924
1744    1303.333924
1745    1303.333924
1746    1303.333924
1747    1303.333924
1748    1303.333924
1749    1303.333924
1750    1303.333924
1751    1303.333924
1752    1303.333924
1753    1303.333924
1754    1303.333924
1755    1303.333924
1756    1303.333924
dtype: float64