In [34]:
from sklearn.base import BaseEstimator, RegressorMixin
import statsmodels.api as sm
from statsmodels.tsa.statespace.sarimax import SARIMAX

from typing import Optional, Union, Tuple, List, Dict


import pandas as pd
import numpy as np
import warnings
import inspect

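References used for the wrapper design:

- Exponential smoothing using a scikit-learn wrapper around statsmodels: https://www.aritro.in/post/exponential-smoothing-using-scikit-learn-wrapper-statsmodels/
- Yellowbrick `StatsModelsWrapper` source: https://www.scikit-yb.org/en/latest/_modules/yellowbrick/contrib/statsmodels/base.html#StatsModelsWrapper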

In [55]:
from typing import Optional, Union, Tuple, List, Dict



class Sarimax(BaseEstimator, RegressorMixin):
    """
    A universal sklearn-style wrapper for statsmodels SARIMAX.

    Parameters
    ----------
    order, seasonal_order, trend, measurement_error, time_varying_regression,
    mle_regression, simple_differencing, enforce_stationarity,
    enforce_invertibility, hamilton_representation, concentrate_scale,
    trend_offset, use_exact_diffuse, dates, freq, missing,
    validate_specification
        Forwarded unchanged to the `statsmodels` `SARIMAX` constructor every
        time the underlying model is (re)created. See the statsmodels
        documentation for their meaning.
    method, maxiter, start_params, disp
        Default keyword arguments passed to `SARIMAX.fit`.
    fit_kwargs : dict, default `None`
        Additional keyword arguments for `SARIMAX.fit`. They take preference
        over `method`, `maxiter`, `start_params` and `disp`. Keys that are not
        parameters of `SARIMAX.fit` are dropped at construction time.
    predict_kwargs : dict, default `None`
        Keyword arguments filtered against the signature of the results'
        `get_forecast` method. They are stored in `self.predict_kwargs` but
        are not currently applied by `predict` or `predict_interval`.

    Attributes
    ----------
    sarimax : statsmodels SARIMAX
        The underlying (unfitted) statsmodels model.
    sarimax_res : statsmodels results object
        Result of the last call to `self.sarimax.fit`.
    training_index : pandas Index
        Index of the series used in the last fit.

    """

    # Constructor arguments that belong to statsmodels' SARIMAX constructor.
    # Only these are forwarded when the model is built; fit options and the
    # wrapper's own state must not leak into the SARIMAX keyword arguments.
    _SARIMAX_INIT_PARAMS = (
        'order',
        'seasonal_order',
        'trend',
        'measurement_error',
        'time_varying_regression',
        'mle_regression',
        'simple_differencing',
        'enforce_stationarity',
        'enforce_invertibility',
        'hamilton_representation',
        'concentrate_scale',
        'trend_offset',
        'use_exact_diffuse',
        'dates',
        'freq',
        'missing',
        'validate_specification',
    )

    def __init__(
        self,
        order:tuple=(1, 0, 0),
        seasonal_order:tuple=(0, 0, 0, 0),
        trend:Optional[str]=None,
        measurement_error:bool=False,
        time_varying_regression:bool=False,
        mle_regression:bool=True,
        simple_differencing:bool=False,
        enforce_stationarity:bool=True,
        enforce_invertibility:bool=True,
        hamilton_representation:bool=False,
        concentrate_scale:bool=False,
        trend_offset:int=1,
        use_exact_diffuse:bool=False,
        dates = None,
        freq = None,
        missing = 'none',
        validate_specification:bool=True,
        method:str='lbfgs',
        maxiter:int=50,
        start_params = None,
        disp:bool= False,
        fit_kwargs: Optional[dict]=None,
        predict_kwargs: Optional[dict]=None
    ) -> None:

        self.order = order
        self.seasonal_order = seasonal_order
        self.trend = trend
        self.measurement_error = measurement_error
        self.time_varying_regression = time_varying_regression
        self.mle_regression = mle_regression
        self.simple_differencing = simple_differencing
        self.enforce_stationarity = enforce_stationarity
        self.enforce_invertibility = enforce_invertibility
        self.hamilton_representation = hamilton_representation
        self.concentrate_scale = concentrate_scale
        self.trend_offset = trend_offset
        self.use_exact_diffuse = use_exact_diffuse
        self.dates = dates
        self.freq = freq
        self.missing = missing
        self.validate_specification = validate_specification
        self.method = method
        self.maxiter = maxiter
        self.start_params = start_params
        self.disp = disp
        # `None` replaces the former mutable dict defaults, which were shared
        # between all instances. Copies are taken so the caller's dicts are
        # never aliased. With default arguments the fit call is unchanged:
        # `disp` already defaults to False.
        self.fit_kwargs = dict(fit_kwargs) if fit_kwargs is not None else {}
        self.predict_kwargs = dict(predict_kwargs) if predict_kwargs is not None else {}

        self.sarimax = None
        self.sarimax_res = None
        self.training_index = None
        # A dummy model and results object are needed right away so that their
        # method signatures can be inspected below.
        self._dummy_create_fit_sarimax()

        # Remove from fit_kwargs the parameters that are not in the fit method
        # of the statsmodels SARIMAX.
        fit_param_names = inspect.signature(self.sarimax.fit).parameters.keys()
        self.fit_kwargs = {
            k: v for k, v in self.fit_kwargs.items() if k in fit_param_names
        }
        # Remove from predict_kwargs the parameters that are not in the
        # get_forecast method of the statsmodels SARIMAX results.
        forecast_param_names = inspect.signature(
            self.sarimax_res.get_forecast
        ).parameters.keys()
        self.predict_kwargs = {
            k: v for k, v in self.predict_kwargs.items() if k in forecast_param_names
        }


    def _default_fit_kwargs(self) -> dict:
        """
        Build the keyword arguments passed by default to `SARIMAX.fit`.

        Returns
        -------
        dict
            A new dict with the current `method`, `maxiter`, `start_params`
            and `disp` values.

        """
        return {
            'method': self.method,
            'maxiter': self.maxiter,
            'start_params': self.start_params,
            'disp': self.disp,
        }


    def _create_sarimax(
        self,
        y: pd.Series,
        exog: Optional[Union[pd.Series, pd.DataFrame]] = None
        ) -> None:
        """
        A helper function to create a new statsmodels SARIMAX and store it in
        `self.sarimax`.

        Parameters
        ----------
        y : pandas.Series
            The endogenous variable.
        exog : pandas.Series, pandas.DataFrame, default `None`
            The exogenous variables.

        Returns
        -------
        None

        """
        # Forward only the SARIMAX constructor arguments. Passing the whole
        # instance `__dict__` would also send fit options and the previously
        # built model/results as keyword arguments.
        model_kwargs = {
            name: getattr(self, name) for name in self._SARIMAX_INIT_PARAMS
        }
        self.sarimax = SARIMAX(endog=y, exog=exog, **model_kwargs)

        return


    def _dummy_create_fit_sarimax(self):
        """
        A helper function to create a dummy SARIMAX and fit it to an empty
        series, so that `self.sarimax` and `self.sarimax_res` exist before the
        first real call to `fit`.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        # Warnings are silenced only for this dummy fit on an empty series.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self._create_sarimax(y=pd.Series([], dtype=float), exog=None)
            self.sarimax_res = self.sarimax.fit(**self._default_fit_kwargs())
            self.training_index = pd.RangeIndex(start=0, stop=0, step=1)

        return


    def fit(
        self,
        y: pd.Series,
        exog: Optional[Union[pd.Series, pd.DataFrame]] = None
    ) -> None:
        """
        Fit the model to the data.

        Parameters
        ----------
        y : pandas Series
            Training time series.
        exog : pandas Series, pandas DataFrame, default `None`
            Exogenous variable/s included as predictor/s. Must have the same
            number of observations as `y` and their indexes must be aligned so
            that y[i] is regressed on exog[i].

        Returns
        -------
        None

        """

        fit_kwargs = self._default_fit_kwargs()

        # User provided fit_kwargs have preference over the default ones
        if self.fit_kwargs:
            fit_kwargs.update(self.fit_kwargs)

        self._create_sarimax(y=y, exog=exog)
        # Use the merged kwargs so that method/maxiter/start_params/disp are
        # actually applied.
        self.sarimax_res = self.sarimax.fit(**fit_kwargs)
        self.training_index = y.index

        return


    def predict(
        self,
        steps: int,
        last_window: Optional[pd.Series]=None,
        exog: Optional[Union[pd.Series, pd.DataFrame]]=None
    ):
        """
        Forecast `steps` values after the end of the fitted series.

        Parameters
        ----------
        steps : int
            Number of future steps predicted.
        last_window : pandas Series, default `None`
            Currently unused. Kept in the signature for interface
            compatibility; predictions always start right after the data used
            in the last fit.
        exog : pandas Series, pandas DataFrame, default `None`
            Exogenous variable/s included as predictor/s.

        Returns
        -------
        predictions : pandas Series
            Predicted values.

        """

        predictions = self.sarimax_res.forecast(steps=steps, exog=exog)

        return predictions


    def predict_interval(self, steps, exog=None, alpha=0.05, **kwargs):
        """
        Forecast `steps` values together with their confidence interval.

        Parameters
        ----------
        steps : int
            Number of future steps predicted.
        exog : pandas Series, pandas DataFrame, default `None`
            Exogenous variable/s included as predictor/s.
        alpha : float, default `0.05`
            Significance level passed to the forecast's `conf_int` method.
        **kwargs
            Additional keyword arguments passed to `get_forecast`.

        Returns
        -------
        predictions : pandas DataFrame
            Columns `pred`, `lower_bound` and `upper_bound`.

        """
        # The interval width is decided by `conf_int(alpha=...)` below.
        # `get_forecast` has no `alpha`/`return_conf_int` parameters, so they
        # are no longer passed (the former hard-coded alpha=0.05 was ignored).
        forecast = self.sarimax_res.get_forecast(
                       steps = steps,
                       exog  = exog,
                       **kwargs
                   )

        predictions = pd.concat((
                        forecast.predicted_mean.rename("pred"),
                        forecast.conf_int(alpha=alpha)),
                        axis = 1
                     )
        predictions.columns = ['pred', 'lower_bound', 'upper_bound']

        return predictions

    def extend(self):
        """
        Placeholder for extending the fitted model with new observations.

        Raises
        ------
        NotImplementedError
            Always; the method is not implemented yet.

        """
        raise NotImplementedError("Sarimax.extend is not implemented yet.")


    def set_params(self, params=None, **kwargs):
        """
        Set new values for the wrapper parameters and rebuild the underlying
        SARIMAX model with the data of the current model.

        The stored results (`self.sarimax_res`) are not refitted; call `fit`
        again to obtain results for the new parameters.

        Parameters
        ----------
        params : dict, default `None`
            Parameter names mapped to their new values. Names that are not
            existing attributes of the instance are ignored.
        **kwargs
            Parameters given as keyword arguments (sklearn convention). They
            take preference over `params`.

        Returns
        -------
        self

        """
        requested = dict(params) if params is not None else {}
        requested.update(kwargs)

        accepted = {k: v for k, v in requested.items() if k in self.__dict__}
        for key, value in accepted.items():
            setattr(self, key, value)

        # Rebuild the model on the same endog/exog it was created with.
        self._create_sarimax(
            y = pd.Series(data=self.sarimax.endog.ravel(), index=self.training_index),
            exog = self.sarimax.exog
        )

        return self


    def __repr__(self):
        """
        Compact representation: Sarimax(p,d,q)(P,D,Q)[m].
        """
        p, d, q = self.order
        P, D, Q, m = self.seasonal_order

        return f"Sarimax({p},{d},{q})({P},{D},{Q})[{m}]"


In [56]:
# Seed the generator so the simulated series, and therefore the forecasts
# shown in the following cells, are reproducible on Restart & Run All.
rng = np.random.default_rng(123)

sarimax = Sarimax(order=(1, 1, 1))
sarimax.fit(y=pd.Series(rng.normal(size=100)))
sarimax



In [57]:
# Point forecasts for the 4 steps that follow the training series.
sarimax.predict(steps=4)

100    0.089448
101    0.094491
102    0.094429
103    0.094430
Name: predicted_mean, dtype: float64

In [58]:
# Point forecasts plus lower/upper bounds (default alpha=0.05) for the next 4 steps.
sarimax.predict_interval(steps=4)



Unnamed: 0,pred,lower_bound,upper_bound
100,0.089448,-1.962973,2.14187
101,0.094491,-1.957836,2.146818
102,0.094429,-1.957901,2.146759
103,0.09443,-1.9579,2.14676


In [54]:
# Change the order on the existing wrapper; set_params takes a dict of parameters.
# NOTE(review): this cell's execution count (54) is lower than the class-definition
# cell's (55), so it presumably ran against an earlier definition — re-run in order
# to confirm the displayed result.
sarimax.set_params({'order': (1, 0, 110)})
sarimax



In [49]:
# set_params called on a freshly constructed wrapper, before any call to fit.
# NOTE(review): this rebinds `sarimax`, discarding the model fitted in the earlier cell.
sarimax = Sarimax()
sarimax.set_params({'order': (1, 0, 99)})
sarimax

