In [32]:
import numpy as np
import pandas as pd
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.stattools import arma_order_select_ic
from statsmodels.tsa.stattools import adfuller



class ARIMAModel:
    """Wrapper that selects an ARIMA order automatically, fits it, and forecasts.

    The differencing order ``d`` is chosen with repeated ADF tests and the
    ``(p, q)`` pair with ``arma_order_select_ic`` (AIC). The model itself is
    fitted on the data exactly as passed to the constructor; differencing is
    left to ``ARIMA`` via the ``d`` component of the order.
    """

    def __init__(self, data : pd.DataFrame):
        self.model = None    # ARIMA instance, created by fit()
        self.result = None   # fitted results, created by fit()
        self.order = None    # (p, d, q) chosen by fit()
        self.data = data

    def best_order(self, data, d_val=0, max_d=2, alpha=0.05) -> tuple:
        """
        Choose an ARIMA ``(p, d, q)`` order for ``data``.

        The data is differenced until the ADF test p-value is ``<= alpha`` or
        the differencing count reaches ``max_d``; ``p`` and ``q`` are then
        taken from the AIC-minimising order on the (differenced) data.

        :param data: pandas object supporting ``dropna()`` and ``diff()``
        :param d_val: differencing already applied to ``data`` (start count)
        :param max_d: maximum differencing (default = 2)
        :param alpha: ADF p-value threshold for stationarity (default = 0.05)
        :return: tuple ``(p, d, q)``
        """
        # Iterative instead of recursive so that ``max_d`` and ``alpha`` apply
        # at every differencing level (the recursive form dropped ``max_d``).
        while True:
            clean = data.dropna()  # ADF and order selection cannot take NaN
            # ``>=`` rather than ``==`` so a start count above the cap still stops.
            if d_val >= max_d or adfuller(clean)[1] <= alpha:
                break
            # Not stationary yet: difference once more and test again.
            data = data.diff().dropna()
            d_val += 1

        order_result = arma_order_select_ic(clean, ic='aic', trend='n')
        best_p, best_q = order_result.aic_min_order
        return best_p, d_val, best_q

    def fit(self) -> None:
        """
        Select the order for ``self.data`` and fit the ARIMA model.

        Stores the chosen order in ``self.order``, the model in ``self.model``
        and the fitted results in ``self.result``.
        """
        self.order = self.best_order(self.data)
        self.model = ARIMA(self.data, order=self.order)
        self.result = self.model.fit(method_kwargs={'maxiter': 300})

    def forecast(self, steps : int = 10):
        """
        Forecast ``steps`` periods past the end of the fitted data.

        :param steps: number of periods to forecast (default = 10)
        :return: the object returned by ``result.get_forecast`` (not an
                 array; it exposes e.g. ``conf_int()``)
        :raises RuntimeError: if called before ``fit()``
        """
        if self.result is None:
            raise RuntimeError("ARIMAModel.forecast() called before fit()")
        return self.result.get_forecast(steps=steps)
    
# Demo: build a 1000-point random walk (cumulative sum of normal draws),
# which is non-stationary, so the order selection has to difference it.
non_stat = np.random.normal(size=1000).cumsum()
data = pd.DataFrame(non_stat)
print(data)

# Pick the order automatically and fit on the raw series.
model = ARIMAModel(data)
model.fit()

# NOTE: despite the name, this holds a 50-step forecast.
forecast_10 = model.forecast(steps=50)
print(forecast_10.conf_int())

             0
0    -0.682360
1     0.533310
2     0.900554
3    -0.494607
4    -0.066645
..         ...
995 -34.546508
996 -33.853121
997 -32.392818
998 -32.093583
999 -34.913598

[1000 rows x 1 columns]


  warn('Non-stationary starting autoregressive parameters'
  warn('Non-invertible starting MA parameters found.'


        lower y    upper y
1000 -37.164270 -33.053784
1001 -38.003582 -32.302002
1002 -38.466070 -31.601097
1003 -38.732528 -30.894284
1004 -38.960669 -30.231228
1005 -39.273066 -29.683076
1006 -39.722973 -29.290960
1007 -40.282457 -29.041133
1008 -40.863890 -28.869448
1009 -41.365094 -28.689872
1010 -41.716856 -28.432914
1011 -41.912134 -28.075607
1012 -42.006168 -27.648901
1013 -42.089632 -27.219768
1014 -42.247720 -26.857921
1015 -42.523377 -26.603626
1016 -42.900897 -26.450867
1017 -43.316604 -26.352041
1018 -43.690446 -26.240849
1019 -43.962849 -26.062635
1020 -44.119983 -25.798335
1021 -44.197237 -25.471270
1022 -44.260952 -25.134264
1023 -44.377588 -24.844045
1024 -44.584380 -24.635414
1025 -44.874415 -24.506801
1026 -45.202251 -24.422798
1027 -45.506520 -24.331535
1028 -45.738316 -24.188568
1029 -45.882247 -23.976385
1030 -45.961349 -23.710905
1031 -46.024762 -23.432528
1032 -46.124586 -23.186598
1033 -46.292628 -23.002755
1034 -46.527362 -22.882519
1035 -46.796464 -22.800079
1