# Objective
The objective of this notebook is to build models that forecast our cement production data and wrap them in a [Custom Python Model](https://mlflow.org/docs/latest/models.html#custom-python-models). As we saw previously, the default statsmodels flavour did not provide us with confidence or prediction intervals. In this notebook we develop custom Python models to solve that issue.

# Description of the data
The dataset contains quarterly cement production data.

# Imports, configuration and constants

In [11]:
import matplotlib.pyplot as plt
import mlflow
import pandas as pd
import statsmodels
import cloudpickle

from sys import version_info
from mlflow.models import infer_signature
from prophet import Prophet
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_absolute_percentage_error
from statsmodels.tsa.exponential_smoothing.ets import ETSModel
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.statespace.sarimax import SARIMAX

import utils as ut

In [2]:
import importlib
importlib.reload(ut)

<module 'utils' from '/home/santiagopalmero/repos/fpp3package/python/mlflow_12_2/utils.py'>

In [3]:
plt.rc("figure", figsize=(16, 12))
plt.rc("font", size=13)

In [4]:
PLOT_TITLE = "AUS cement production"
PLOT_YLABEL = "Cement"
PLOT_XLABEL = "Quarter"

PLOT_KWARGS = {
    "title": PLOT_TITLE,
    "ylabel": PLOT_YLABEL,
    "xlabel": PLOT_XLABEL,
}

# Load data

In [5]:
ts_train = ut.read_csv_series("data/ts_train.csv")
ts_test = ut.read_csv_series("data/ts_test.csv")

In [6]:
ts_train.index = pd.to_datetime(ts_train.index)
ts_train = ts_train.asfreq('QS-OCT')

ts_test.index = pd.to_datetime(ts_test.index)
ts_test = ts_test.asfreq('QS-OCT')
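
The `'QS-OCT'` alias used above is pandas' quarter-start frequency with the quarterly cycle anchored on October, so quarters begin on 1 January, 1 April, 1 July and 1 October. The following sketch uses a made-up four-value series (not the cement data) to show the dates the alias generates and how `asfreq` attaches the frequency to an index that was parsed from strings:

```python
import pandas as pd

# Four consecutive quarter starts under the 'QS-OCT' alias, beginning in January 2008.
idx = pd.date_range("2008-01-01", periods=4, freq="QS-OCT")
quarter_months = [ts.month for ts in idx]

# Illustrative values only; an index parsed from strings carries no frequency.
toy = pd.Series(
    [1.0, 2.0, 3.0, 4.0],
    index=pd.to_datetime(["2008-01-01", "2008-04-01", "2008-07-01", "2008-10-01"]),
)
toy_q = toy.asfreq("QS-OCT")

print(quarter_months)
print(toy.index.freq, "->", toy_q.index.freq)
```

Setting the frequency explicitly matters here because statsmodels relies on it to extend the date index when forecasting beyond the training sample.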

# Set experiment

In [7]:
mlflow.set_experiment("Cement_Forecasting")

<Experiment: artifact_location='file:///home/santiagopalmero/repos/fpp3package/python/mlflow/mlruns/570366757150847423', creation_time=1701436591943, experiment_id='570366757150847423', last_update_time=1701436591943, lifecycle_stage='active', name='Cement_Forecasting', tags={'mlflow.note.content': 'Project about cement production forecasting. This '
                        'experiment contains several forecasting models.',
 'project_name': 'forecasting'}>

# Forecasting

## ARIMA

From the section https://otexts.com/fpp3/arima-ets.html#comparing-arima-and-ets-on-seasonal-data we know that the model used in the book is `ARIMA(1,0,1)(2,1,1)[4] w/ drift`.

The concept of drift is explained in more detail in the previous edition of the book, https://otexts.com/fpp2/arima-r.html. How statsmodels handles drift is discussed in https://stackoverflow.com/questions/66651360/arima-forecast-gives-different-results-with-new-python-statsmodels.

There are [differences in implementation](https://www.statsmodels.org/dev/examples/notebooks/generated/statespace_sarimax_faq.html#Differences-between-trend-and-exog-in-SARIMAX) between the `ARIMA` and `SARIMAX` classes. As we have seen previously, `SARIMAX` is the one that corresponds to the model in the book.
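
To see why a constant combined with seasonal differencing behaves like drift, consider a toy series with a linear trend and a fixed quarterly pattern. All values below are made up for illustration. Taking the seasonal difference `y[t] - y[t - 4]` removes the quarterly pattern and leaves a constant equal to four times the per-quarter slope. The sketch is plain Python and says nothing about how statsmodels estimates the constant:

```python
# Toy series (illustrative values): linear trend plus a repeating quarterly pattern.
slope = 2.5
seasonal_pattern = [10.0, -5.0, 3.0, -8.0]
y = [100.0 + slope * t + seasonal_pattern[t % 4] for t in range(24)]

# Seasonal difference at lag 4, i.e. (1 - B^4) y_t.
seasonal_diff = [y[t] - y[t - 4] for t in range(4, len(y))]

# Every seasonal difference equals 4 * slope: the pattern cancels and
# only the accumulated trend over one year remains.
print(sorted(set(seasonal_diff)))
```

In the fitted model the AR terms also act on the differenced series, so the estimated constant is not simply the average seasonal difference.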

In [23]:
order = (1,0,1)
seasonal_order = (2,1,1,4)

sarimax = SARIMAX(
    endog=ts_train, 
    order=order, 
    seasonal_order=seasonal_order, 
    trend="c",
)
res = sarimax.fit()

ts_sarimax_h = res.predict(start=ts_test.index[0], end=ts_test.index[-1])

# Take the square root explicitly: the `squared` argument is deprecated in newer scikit-learn releases.
rmse = mean_squared_error(ts_test, ts_sarimax_h) ** 0.5
mae = mean_absolute_error(ts_test, ts_sarimax_h)
mape = mean_absolute_percentage_error(ts_test, ts_sarimax_h)

 This problem is unconstrained.


RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =            7     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  5.93161D+00    |proj g|=  2.14605D-01

At iterate    5    f=  5.85403D+00    |proj g|=  3.80578D-02

At iterate   10    f=  5.85147D+00    |proj g|=  1.37646D-02

At iterate   15    f=  5.85027D+00    |proj g|=  3.39077D-02

At iterate   20    f=  5.82684D+00    |proj g|=  1.33501D-01

At iterate   25    f=  5.80155D+00    |proj g|=  5.88010D-03

At iterate   30    f=  5.79975D+00    |proj g|=  6.18598D-03

At iterate   35    f=  5.79736D+00    |proj g|=  3.61733D-03

At iterate   40    f=  5.79388D+00    |proj g|=  1.10395D-02

At iterate   45    f=  5.78620D+00    |proj g|=  2.96676D-02

At iterate   50    f=  5.78534D+00    |proj g|=  4.97323D-04

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cau



In [24]:
sarimax_model_path = "sarimax.pickle"
artifacts = {"SARIMAX_model": sarimax_model_path}
res.save(sarimax_model_path)

In [25]:
PYTHON_VERSION = (
    f"{version_info.major}.{version_info.minor}.{version_info.micro}"
)

conda_env = {
    "channels": ["defaults"],
    "dependencies": [
        f"python={PYTHON_VERSION}",
        "pip",
        {
            "pip": [
                f"mlflow=={mlflow.__version__}",
                f"statsmodels=={statsmodels.__version__}",
                f"cloudpickle=={cloudpickle.__version__}",
            ],
        },
    ],
    "name": "sarimax_env",
}

In [40]:
class SARIMAXPythonModel(mlflow.pyfunc.PythonModel):
    
    def load_context(self, context):
        from statsmodels.tsa.statespace.sarimax import SARIMAXResults
        
        self.res = SARIMAXResults.load(context.artifacts["SARIMAX_model"])

    def predict(self, context, model_input, params=None):
        start = model_input["start"].iloc[0]
        end = model_input["end"].iloc[0]
        
        pred = self.res.get_prediction(
            start=start, 
            end=end,
        )
        
        return pred.summary_frame(alpha=0.05)
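
The `predict` method above fixes `alpha=0.05`, which yields 95% intervals. One possible variation, offered as an assumption rather than something this notebook does, is to read the level from the `params` argument. The helper name `resolve_alpha` and the `"alpha"` key are hypothetical, and whether MLflow forwards `params` to `predict` depends on how the model signature is logged, which is not covered here.

```python
def resolve_alpha(params, default=0.05):
    """Hypothetical helper: choose the significance level for the interval.

    `params` mimics the optional dict received by `predict`; the "alpha"
    key is an assumption made for this sketch.
    """
    alpha = default if not params else params.get("alpha", default)
    alpha = float(alpha)
    if not 0.0 < alpha < 1.0:
        raise ValueError(f"alpha must be strictly between 0 and 1, got {alpha}")
    return alpha


print(resolve_alpha(None))            # falls back to the default level
print(resolve_alpha({"alpha": 0.2}))  # would correspond to 80% intervals
```

Inside `predict`, the last line could then read `pred.summary_frame(alpha=resolve_alpha(params))`.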

In [41]:
with mlflow.start_run(run_name="sarimax"):
    mlflow.set_tag(
        "custom", 
        "Testing model development custom MLflow features.",
    )

    mlflow.log_params(
        {
            "order": order,
            "seasonal_order": sarimax.seasonal_order,
            "trend": sarimax.trend,
        }
    )
    mlflow.log_params(res.params)
    mlflow.log_params({"summary": res.summary()})

    mlflow.log_metric("rmse", rmse)
    mlflow.log_metric("mae", mae)
    mlflow.log_metric("mape", mape)

    mlflow_pyfunc_model_path = "SARIMAX_pyfunc"
    mlflow.pyfunc.log_model(
        artifact_path=mlflow_pyfunc_model_path,
        python_model=SARIMAXPythonModel(),
        artifacts=artifacts,
        conda_env=conda_env,
    )


In [42]:
sarimax_uri = "runs:/6a2982bcc8484c6f907a899674235727/SARIMAX_pyfunc"
sarimax_pyfunc = mlflow.pyfunc.load_model(model_uri=sarimax_uri)

In [43]:
start = pd.to_datetime(ts_test.index[0])
end = pd.to_datetime(ts_test.index[-1])

prediction_data = pd.DataFrame({"start": start, "end": end}, index=[0])

sarimax_pyfunc.predict(prediction_data)

| y | mean | mean_se | mean_ci_lower | mean_ci_upper |
|---|---|---|---|---|
| 2008-01-01 | 2316.867722 | 102.281542 | 2116.399583 | 2517.33586 |
| 2008-04-01 | 2489.894004 | 121.780633 | 2251.208349 | 2728.57966 |
| 2008-07-01 | 2531.558849 | 134.862533 | 2267.233141 | 2795.884556 |
| 2008-10-01 | 2482.209738 | 144.097215 | 2199.784387 | 2764.63509 |
| 2009-01-01 | 2290.261027 | 157.396005 | 1981.770527 | 2598.751527 |
| 2009-04-01 | 2473.175645 | 165.513306 | 2148.775528 | 2797.575763 |
| 2009-07-01 | 2502.579827 | 171.477033 | 2166.491018 | 2838.668637 |
| 2009-10-01 | 2451.309548 | 175.899221 | 2106.553411 | 2796.065685 |
| 2010-01-01 | 2255.029916 | 177.084376 | 1907.950917 | 2602.108915 |
| 2010-04-01 | 2455.212408 | 178.363884 | 2105.62562 | 2804.799196 |
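
The interval columns in the output above can be checked by hand. Assuming the bounds are normal-based, which the printed numbers are consistent with, they equal `mean ± z * mean_se`, where `z` is the two-sided normal quantile for the chosen `alpha`. A minimal check on the first row, using only the standard library:

```python
from statistics import NormalDist

# First row of the SARIMAX output above (2008-01-01).
mean = 2316.867722
mean_se = 102.281542

alpha = 0.05
# Two-sided standard normal quantile for a 95% interval (about 1.96).
z = NormalDist().inv_cdf(1 - alpha / 2)

ci_lower = mean - z * mean_se
ci_upper = mean + z * mean_se
print(round(ci_lower, 3), round(ci_upper, 3))
```

The result agrees with the `mean_ci_lower` and `mean_ci_upper` values of that row up to rounding of the printed inputs.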


## ETS
From the book's example in https://otexts.com/fpp3/arima-ets.html we can see that the ETS model finally selected is ETS(M,N,M). From the R documentation https://www.rdocumentation.org/packages/forecast/versions/8.21/topics/ets we know that:
- The first letter denotes the error type ("A", "M" or "Z").
- The second letter denotes the trend type ("N", "A", "M" or "Z").
- The third letter denotes the season type ("N", "A", "M" or "Z").
- In all cases, "N" = none, "A" = additive, "M" = multiplicative and "Z" = automatically selected.
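
The letter codes above map directly onto the `error`, `trend` and `seasonal` arguments of `ETSModel`. The helper below is a hypothetical convenience, not part of statsmodels or of this notebook, that spells out the translation. It rejects "Z" because this sketch does not implement automatic selection.

```python
# Hypothetical helper: translate an R-style ETS code into ETSModel keyword arguments.
_COMPONENT = {"N": None, "A": "add", "M": "mul"}


def ets_code_to_kwargs(code):
    """Convert a code such as "MNM" or "M,N,M" into error/trend/seasonal kwargs."""
    letters = code.replace(",", "").replace(" ", "").upper()
    if len(letters) != 3:
        raise ValueError(f"expected three letters, got {code!r}")
    error, trend, seasonal = letters
    if error not in ("A", "M"):
        raise ValueError(f"unsupported error type {error!r} in this sketch")
    if trend not in _COMPONENT or seasonal not in _COMPONENT:
        raise ValueError(f"unsupported trend/season type in {code!r}")
    return {
        "error": _COMPONENT[error],
        "trend": _COMPONENT[trend],
        "seasonal": _COMPONENT[seasonal],
    }


print(ets_code_to_kwargs("M,N,M"))
```

For ETS(M,N,M) this gives `error="mul"`, `trend=None` and `seasonal="mul"`, the same arguments passed to `ETSModel` in the next cell.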

In [44]:
ets = ETSModel(ts_train, error="mul", trend=None, seasonal="mul")
res = ets.fit()

ts_ets_h = res.predict(start=ts_test.index[0], end=ts_test.index[-1])

# Take the square root explicitly: the `squared` argument is deprecated in newer scikit-learn releases.
rmse = mean_squared_error(ts_test, ts_ets_h) ** 0.5
mae = mean_absolute_error(ts_test, ts_ets_h)
mape = mean_absolute_percentage_error(ts_test, ts_ets_h)

RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =            6     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  6.91324D+00    |proj g|=  2.15029D+00

At iterate    1    f=  6.84699D+00    |proj g|=  1.65803D+00

At iterate    2    f=  6.42640D+00    |proj g|=  4.69153D-01

At iterate    3    f=  6.34069D+00    |proj g|=  4.86073D-01

At iterate    4    f=  6.26445D+00    |proj g|=  4.72329D-01

At iterate    5    f=  6.22885D+00    |proj g|=  3.81426D-01

At iterate    6    f=  6.18920D+00    |proj g|=  2.63642D-01

At iterate    7    f=  6.16398D+00    |proj g|=  4.70483D-01

At iterate    8    f=  6.15694D+00    |proj g|=  2.18301D-01

At iterate    9    f=  6.15418D+00    |proj g|=  9.72367D-02

At iterate   10    f=  6.15325D+00    |proj g|=  6.66425D-02

At iterate   11    f=  6.15227D+00    |proj g|=  6.25598D-02

At iterate   12    f=  6.14968D+00    |proj g|=  9.14247D-02

At iterate   13    f=  6.1

In [45]:
ets_model_path = "ets.pickle"
artifacts = {"ETS_model": ets_model_path}
res.save(ets_model_path)

In [47]:
PYTHON_VERSION = (
    f"{version_info.major}.{version_info.minor}.{version_info.micro}"
)

conda_env = {
    "channels": ["defaults"],
    "dependencies": [
        f"python={PYTHON_VERSION}",
        "pip",
        {
            "pip": [
                f"mlflow=={mlflow.__version__}",
                f"statsmodels=={statsmodels.__version__}",
                f"cloudpickle=={cloudpickle.__version__}",
            ],
        },
    ],
    "name": "ets_env",
}

In [48]:
class ETSPythonModel(mlflow.pyfunc.PythonModel):
    
    def load_context(self, context):
        from statsmodels.tsa.exponential_smoothing.ets import ETSResults
        
        self.res = ETSResults.load(context.artifacts["ETS_model"])

    def predict(self, context, model_input, params=None):
        start = model_input["start"].iloc[0]
        end = model_input["end"].iloc[0]
        
        pred = self.res.get_prediction(
            start=start, 
            end=end,
        )
        
        return pred.summary_frame(alpha=0.05)

In [49]:
with mlflow.start_run(run_name="ets"):
    mlflow.set_tag(
        "custom", 
        "Testing model development custom MLflow features.",
    )

    mlflow.log_params(
        {
            "error": "mul",
            "trend": None,
            "seasonal": "mul",
        }
    )
    mlflow.log_params({"summary": res.summary()})

    mlflow.log_metric("rmse", rmse)
    mlflow.log_metric("mae", mae)
    mlflow.log_metric("mape", mape)

    mlflow_pyfunc_model_path = "ETS_pyfunc"
    mlflow.pyfunc.log_model(
        artifact_path=mlflow_pyfunc_model_path,
        python_model=ETSPythonModel(),
        artifacts=artifacts,
        conda_env=conda_env,
    )


In [50]:
ets_uri = "runs:/fca494d9bc8a4c6695c94bd5667fde6e/ETS_pyfunc"
ets_pyfunc = mlflow.pyfunc.load_model(model_uri=ets_uri)

In [51]:
start = pd.to_datetime(ts_test.index[0])
end = pd.to_datetime(ts_test.index[-1])

prediction_data = pd.DataFrame({"start": start, "end": end}, index=[0])

ets_pyfunc.predict(prediction_data)

| | mean | mean_numerical | pi_lower | pi_upper |
|---|---|---|---|---|
| 2008-01-01 | 2253.187112 | 2247.073401 | 2011.449717 | 2516.969333 |
| 2008-04-01 | 2490.747411 | 2486.9242 | 2156.871388 | 2851.129462 |
| 2008-07-01 | 2573.875726 | 2566.919399 | 2156.407212 | 3041.033789 |
| 2008-10-01 | 2540.758414 | 2535.871206 | 2087.157634 | 3022.11697 |
| 2009-01-01 | 2253.187112 | 2248.851617 | 1811.792721 | 2752.310268 |
| 2009-04-01 | 2490.747411 | 2487.515398 | 1967.328635 | 3142.807288 |
| 2009-07-01 | 2573.875726 | 2570.408218 | 1992.151426 | 3269.563901 |
| 2009-10-01 | 2540.758414 | 2539.05059 | 1923.144015 | 3245.158942 |
| 2010-01-01 | 2253.187112 | 2254.873833 | 1688.024659 | 2955.048658 |
| 2010-04-01 | 2490.747411 | 2482.405954 | 1831.927245 | 3271.167142 |
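
One feature of the output above is that the `mean` column repeats every four quarters. This is what one would expect from a model without a trend component such as ETS(M,N,M): each point forecast is the last estimated level multiplied by the seasonal factor of the corresponding quarter. A quick check on the printed values:

```python
# `mean` column of the ETS output above, in chronological order.
ets_mean = [
    2253.187112, 2490.747411, 2573.875726, 2540.758414,
    2253.187112, 2490.747411, 2573.875726, 2540.758414,
    2253.187112, 2490.747411,
]
seasonal_periods = 4  # quarterly data

# Compare each forecast with the one four quarters later.
repeats_yearly = all(
    ets_mean[i] == ets_mean[i + seasonal_periods]
    for i in range(len(ets_mean) - seasonal_periods)
)
print(repeats_yearly)
```

The `pi_lower` and `pi_upper` columns do not repeat: the interval widens as the horizon grows even though the point forecast cycles.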
