In [141]:
from sklearn.base import BaseEstimator, RegressorMixin
import statsmodels.api as sm
from statsmodels.tsa.statespace.sarimax import SARIMAX


import pandas as pd
import numpy as np
import warnings

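Related references on wrapping statsmodels models in a scikit-learn-style API:

- Exponential smoothing using a scikit-learn wrapper around statsmodels: https://www.aritro.in/post/exponential-smoothing-using-scikit-learn-wrapper-statsmodels/
- Yellowbrick `StatsModelsWrapper` source: https://www.scikit-yb.org/en/latest/_modules/yellowbrick/contrib/statsmodels/base.html#StatsModelsWrapper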

In [142]:
class Sarimax(BaseEstimator, RegressorMixin):
    """
    A universal sklearn-style wrapper for statsmodels SARIMAX.

    The constructor arguments are split in two groups:

    - model arguments (``order`` ... ``validate_specification`` plus any
      extra ``**kwargs``), forwarded to ``SARIMAX(...)`` every time the
      underlying model is (re)built;
    - fit arguments (``method``, ``maxiter``, ``start_params``, ``disp``),
      forwarded to ``SARIMAX.fit(...)``.

    Attributes
    ----------
    sarimax : statsmodels SARIMAX
        The underlying model. Built on an empty series by ``__init__`` and
        rebuilt by ``fit`` and ``set_params``.
    sarimax_res : statsmodels results object
        Whatever ``sarimax.fit(...)`` returned the last time it was called.
    training_index : pandas.Index
        Index of the series used in the last call to ``fit``.
    """

    # Constructor arguments forwarded to SARIMAX(...).
    _MODEL_PARAM_NAMES = (
        'order',
        'seasonal_order',
        'trend',
        'measurement_error',
        'time_varying_regression',
        'mle_regression',
        'simple_differencing',
        'enforce_stationarity',
        'enforce_invertibility',
        'hamilton_representation',
        'concentrate_scale',
        'trend_offset',
        'use_exact_diffuse',
        'dates',
        'freq',
        'missing',
        'validate_specification',
    )

    # Constructor arguments forwarded to SARIMAX.fit(...).
    _FIT_PARAM_NAMES = ('method', 'maxiter', 'start_params', 'disp')

    def __init__(
        self,
        order = (1, 0, 0),
        seasonal_order = (0, 0, 0, 0),
        trend = None,
        measurement_error = False,
        time_varying_regression = False,
        mle_regression = True,
        simple_differencing = False,
        enforce_stationarity = True,
        enforce_invertibility = True,
        hamilton_representation = False,
        concentrate_scale = False,
        trend_offset = 1,
        use_exact_diffuse = False,
        dates = None,
        freq = None,
        missing = 'none',
        validate_specification = True,
        method = 'lbfgs',
        maxiter = 50,
        start_params = None,
        disp = False,
        **kwargs
    ):
        """
        Store every argument as an attribute of the same name and build a
        placeholder model on an empty series.

        Extra ``**kwargs`` are stored as attributes too and are forwarded to
        ``SARIMAX(...)`` together with the model arguments.
        """
        # Explicit assignments (rather than iterating over ``vars()``) keep
        # the stored hyper-parameters independent of whatever other local
        # names exist in this function.
        self.order = order
        self.seasonal_order = seasonal_order
        self.trend = trend
        self.measurement_error = measurement_error
        self.time_varying_regression = time_varying_regression
        self.mle_regression = mle_regression
        self.simple_differencing = simple_differencing
        self.enforce_stationarity = enforce_stationarity
        self.enforce_invertibility = enforce_invertibility
        self.hamilton_representation = hamilton_representation
        self.concentrate_scale = concentrate_scale
        self.trend_offset = trend_offset
        self.use_exact_diffuse = use_exact_diffuse
        self.dates = dates
        self.freq = freq
        self.missing = missing
        self.validate_specification = validate_specification
        self.method = method
        self.maxiter = maxiter
        self.start_params = start_params
        self.disp = disp

        # Remember which extra names were supplied so they can be forwarded
        # to SARIMAX without also forwarding internal state.
        self._extra_param_names = tuple(kwargs)
        for name, value in kwargs.items():
            setattr(self, name, value)

        self.sarimax = None
        self.sarimax_res = None
        self.training_index = None

        self._dummy_create_fit_sarimax()


    def _model_kwargs(self):
        """
        Collect the keyword arguments forwarded to ``SARIMAX(...)``.
        """
        names = self._MODEL_PARAM_NAMES + tuple(
            getattr(self, '_extra_param_names', ())
        )
        return {name: getattr(self, name) for name in names}


    def _fit_kwargs(self, overrides=None):
        """
        Collect the keyword arguments forwarded to ``SARIMAX.fit(...)``.

        Values in `overrides` take precedence over the stored defaults.
        """
        fit_kwargs = {name: getattr(self, name) for name in self._FIT_PARAM_NAMES}
        if overrides:
            fit_kwargs.update(overrides)
        return fit_kwargs


    def _create_sarimax(self, y, exog):
        """
        Create a new statsmodel.SARIMAX and store it in ``self.sarimax``.

        Only the model arguments are forwarded. Fit options and internal
        state (``sarimax``, ``sarimax_res``, ``training_index``) are not,
        so a rebuild never hands the previous model/results to the new one.
        """
        self.sarimax = SARIMAX(endog=y, exog=exog, **self._model_kwargs())

        return

    def _dummy_create_fit_sarimax(self):
        """
        Create a new statsmodel.SARIMAX and fit it with empty pandas.Series.

        This populates ``sarimax``, ``sarimax_res`` and ``training_index``
        right after construction; ``set_params`` reads ``self.sarimax`` to
        rebuild the model, so it needs one to exist before the first fit.
        """
        with warnings.catch_warnings():
            # Warnings raised while building/fitting the placeholder are
            # silenced on purpose.
            warnings.simplefilter("ignore")
            self._create_sarimax(y=pd.Series([], dtype=float), exog=None)
            self.sarimax_res = self.sarimax.fit(**self._fit_kwargs())
            self.training_index = pd.RangeIndex(start=0, stop=0, step=1)

        return


    def fit(self, y, exog=None, **kwargs):
        """
        Build a SARIMAX model on `y` (and `exog`) and fit it.

        Parameters
        ----------
        y : pandas.Series or array-like
            Training series, passed to SARIMAX as ``endog``.
        exog : optional
            Exogenous regressors, passed to SARIMAX as ``exog``.
        **kwargs
            Extra arguments for ``SARIMAX.fit``. They take precedence over
            the ``method``, ``maxiter``, ``start_params`` and ``disp``
            values given to the constructor.

        Returns
        -------
        self : Sarimax
            The fitted wrapper, so calls can be chained.
        """
        self._create_sarimax(y=y, exog=exog)
        self.sarimax_res = self.sarimax.fit(**self._fit_kwargs(kwargs))

        if isinstance(y, (pd.Series, pd.DataFrame)):
            self.training_index = y.index
        else:
            # Inputs without a pandas index get a positional one, so that
            # ``set_params`` can still rebuild the training series.
            self.training_index = pd.RangeIndex(start=0, stop=len(y), step=1)

        return self


    def predict(self, steps, exog=None):
        """
        Forecast `steps` steps ahead.

        Returns whatever ``sarimax_res.forecast(steps=steps, exog=exog)``
        returns.
        """
        predictions = self.sarimax_res.forecast(steps=steps, exog=exog)

        return predictions


    def predict_interval(self, steps, exog=None, alpha=0.05, **kwargs):
        """
        Forecast `steps` steps ahead together with an interval.

        Parameters
        ----------
        steps : int
            Number of steps to forecast.
        exog : optional
            Exogenous regressors for the forecast horizon.
        alpha : float, default 0.05
            Passed to ``conf_int(alpha=alpha)`` of the forecast results.
        **kwargs
            Extra arguments forwarded to ``sarimax_res.get_forecast``.

        Returns
        -------
        predictions : pandas.DataFrame
            Columns ``pred``, ``lower_bound`` and ``upper_bound``.
        """
        # The interval level is applied in ``conf_int`` below; it is not an
        # argument of ``get_forecast``.
        forecast = self.sarimax_res.get_forecast(
                       steps = steps,
                       exog  = exog,
                       **kwargs
                   )

        predictions = pd.concat(
                          (forecast.predicted_mean, forecast.conf_int(alpha=alpha)),
                          axis = 1
                      )
        predictions.columns = ['pred', 'lower_bound', 'upper_bound']

        return predictions

    def extend(self):
        """
        Placeholder; not implemented yet. Does nothing and returns None.
        """
        # TODO: implement.
        return None


    def set_params(self, params=None, **kwargs):
        """
        Update constructor arguments and rebuild the underlying model.

        Accepts either a dict (``set_params({'order': (1, 0, 1)})``) or
        keyword arguments (``set_params(order=(1, 0, 1))``); keyword
        arguments win when a name appears in both. Names that are not
        constructor arguments of this instance are silently ignored.

        NOTE: only ``self.sarimax`` is rebuilt (on the data of the current
        model). ``self.sarimax_res`` is left untouched, so ``fit`` must be
        called again before predictions reflect the new parameters.

        Returns
        -------
        self : Sarimax
        """
        requested = dict(params) if params else {}
        requested.update(kwargs)

        allowed = (
            set(self._MODEL_PARAM_NAMES)
            | set(self._FIT_PARAM_NAMES)
            | set(getattr(self, '_extra_param_names', ()))
        )
        for key, value in requested.items():
            if key in allowed:
                setattr(self, key, value)

        if self.sarimax is not None:
            # Read the data off the current model before it is replaced.
            endog = self.sarimax.endog
            exog = self.sarimax.exog
            self._create_sarimax(
                y = pd.Series(data=endog.ravel(), index=self.training_index),
                exog = exog
            )

        return self


    def __repr__(self):
        """
        Short description of the form ``Sarimax(p,d,q)(P,D,Q)[m]``.
        """
        p, d, q = self.order
        P, D, Q, m = self.seasonal_order

        return f"Sarimax({p},{d},{q})({P},{D},{Q})[{m}]"
        


In [143]:
# Seed the generator so the synthetic training series, and therefore every
# forecast shown below, is the same on each Restart & Run All.
rng = np.random.default_rng(123)
y_train = pd.Series(rng.normal(size=100))

sarimax = Sarimax()
sarimax.fit(y=y_train)
sarimax

In [144]:
# Point forecasts for the 4 steps that follow the training series.
sarimax.predict(steps=4)

100   -0.063707
101   -0.003126
102   -0.000153
103   -0.000008
Name: predicted_mean, dtype: float64

In [145]:
# Same horizon, returned as a frame with `pred`, `lower_bound` and
# `upper_bound` columns (alpha defaults to 0.05).
sarimax.predict_interval(steps=4)

Unnamed: 0,pred,lower_bound,upper_bound
100,-0.063707,-2.05285,1.925436
101,-0.003126,-1.994662,1.98841
102,-0.000153,-1.991695,1.991389
103,-8e-06,-1.99155,1.991535


In [146]:
# Change the (p, d, q) order on the fitted wrapper. set_params rebuilds the
# underlying SARIMAX model but does not call fit; the repr shows the new order.
sarimax.set_params({'order': (1, 0, 110)})
sarimax

In [147]:
# set_params also works on a freshly constructed instance: __init__ already
# builds a placeholder model on an empty series, which set_params rebuilds.
sarimax = Sarimax()
sarimax.set_params({'order': (1, 0, 99)})
sarimax