In [None]:
#| default_exp mstl

# MSTL model

In [None]:
#| hide
from nbdev.showdoc import add_docs, show_doc

In [None]:
#| export
from typing import Dict, List, Optional, Union

import numpy as np
import pandas as pd
import statsmodels.api as sm

In [None]:
#| export
def mstl(
        x: np.ndarray, # time series
        period: Union[int, List[int]], # season length
        blambda: Optional[float] = None, # box-cox transform
        iterate: int = 2, # number of iterations
        s_window: Optional[np.ndarray] = None, # seasonal window
        stl_kwargs: Optional[Dict] = None, # extra keyword arguments for `statsmodels` STL
    ):
    """Multiple seasonal-trend decomposition of the series `x`.

    Each seasonal component is estimated with `statsmodels.tsa.STL`, visiting the
    periods from shortest to longest `iterate` times: before refitting a component
    its previous estimate is added back to the deseasonalized series, and the new
    estimate is removed again afterwards. When the shortest period is not greater
    than 1 no seasonal component is fitted and the trend is estimated with
    `supersmoother.SuperSmoother` instead.

    Returns a DataFrame with columns `data`, `trend`, the seasonal component(s)
    (`seasonal` for a single period, `seasonal{p}` for each period `p` otherwise)
    and `remainder`, so that data = trend + seasonal component(s) + remainder.

    Raises `Exception` when `x` contains NaNs or `blambda` is given (neither is
    supported yet), and `ValueError` when several periods are given with
    `iterate < 1`.
    """
    if s_window is None:
        s_window = 7 + 4 * np.arange(1, 7)
    # Accept a scalar window as well as a sequence of windows.
    s_window = np.atleast_1d(s_window)
    if stl_kwargs is None:
        stl_kwargs = {}
    # Only a univariate series is decomposed: keep the first column of a 2D array.
    # `origx` is taken afterwards so the returned `data` column is 1D as well.
    if x.ndim == 2:
        x = x[:, 0]
    origx = x
    n = len(x)
    # `np.ndim(...) == 0` accepts numpy integer scalars as well as `int`.
    if np.ndim(period) == 0:
        msts = [period]
    else:
        msts = sorted(period)
    if len(msts) == 1:
        # A single seasonal component needs only one pass.
        iterate = 1
    if np.isnan(x).any():
        raise Exception(
            '`mstl` cannot handle missing values. '
            'Please raise an issue to include this feature.'
        ) # TODO: interpolate missing values instead of failing.
    if blambda is not None:
        raise Exception(
            '`blambda` not implemented yet. '
            'Please raise an issue to include this feature.'
        )
    if len(msts) > 1 and iterate < 1:
        # Without at least one pass no STL fit (and hence no trend) exists.
        raise ValueError('`iterate` must be at least 1 when several periods are given.')
    stl_kwargs = {'seasonal_deg': 0, **stl_kwargs}
    if msts[0] > 1:
        seas = np.zeros((len(msts), n))
        deseas = np.copy(x)
        if len(s_window) == 1:
            s_window = np.repeat(s_window, len(msts))
        for _ in range(iterate):
            for i, seas_ in enumerate(msts):
                # Put component i back, re-estimate it, then remove the new estimate.
                deseas = deseas + seas[i]
                fit = sm.tsa.STL(deseas, period=seas_, seasonal=s_window[i], **stl_kwargs).fit()
                seas[i] = fit.seasonal
                deseas = deseas - seas[i]
        # The trend comes from the last STL fit (longest period, final pass).
        trend = fit.trend
    else:
        try:
            from supersmoother import SuperSmoother
        except ImportError as e:
            raise ImportError('supersmoother is required for mstl with period=1') from e
        deseas = x
        t = 1 + np.arange(n)
        trend = SuperSmoother().fit(t, x).predict(t)
    # NaNs are rejected above, so there are no missing positions to restore here.
    remainder = deseas - trend
    output = {'data': origx, 'trend': trend}
    if msts[0] > 1:
        if len(msts) == 1:
            output['seasonal'] = seas[0]
        else:
            for i, seas_ in enumerate(msts):
                output[f'seasonal{seas_}'] = seas[i]
    output['remainder'] = remainder
    return pd.DataFrame(output)

In [None]:
#| hide
x = np.arange(1, 11)
decomposition_short = mstl(x, 12)
# A single period yields one `seasonal` column, and the components must add back up to the input.
assert list(decomposition_short.columns) == ['data', 'trend', 'seasonal', 'remainder']
np.testing.assert_allclose(
    decomposition_short[['trend', 'seasonal', 'remainder']].sum(axis=1), x
)
decomposition_short

In [None]:
#| hide
from statsforecast.utils import AirPassengers as ap

In [None]:
#| hide
decomposition = mstl(ap, 12)
# One panel per component: their scales differ too much to share a single axis.
decomposition.plot(subplots=True, title='AirPassengers: MSTL decomposition (period=12)');

In [None]:
#| hide
decomposition_stl_trend = mstl(ap, 12, stl_kwargs={'trend': 27})
# One panel per component: their scales differ too much to share a single axis.
decomposition_stl_trend.plot(subplots=True, title='AirPassengers: MSTL with STL trend window 27');

In [None]:
#| hide
decomposition_trend = mstl(ap, 1)
# period=1 has no seasonal component: only data, trend and remainder are plotted.
decomposition_trend.plot(subplots=True, title='AirPassengers: trend-only decomposition (period=1)');

In [None]:
#| hide
url = "https://raw.githubusercontent.com/tidyverts/tsibbledata/master/data-raw/vic_elec/VIC2015/demand.csv"
df = pd.read_csv(url)
# `Date` counts days from 1899-12-30; convert the whole column in one vectorized step.
df["Date"] = pd.Timestamp("1899-12-30") + pd.to_timedelta(df["Date"], unit="D")
# `Period` numbers the 30-minute slots of each day starting at 1 ("min" = minutes).
df["ds"] = df["Date"] + pd.to_timedelta((df["Period"] - 1) * 30, unit="min")
# Keep the timestamp and the demand, renaming OperationalLessIndustrial to y for simplicity.
timeseries = df[["ds", "OperationalLessIndustrial"]].rename(
    columns={"OperationalLessIndustrial": "y"}
)

# Filter for the first 149 days of 2012.
start_date = pd.to_datetime("2012-01-01")
end_date = start_date + pd.Timedelta("149D")
mask = (timeseries["ds"] >= start_date) & (timeseries["ds"] < end_date)
timeseries = timeseries[mask]

# Resample the half-hourly readings to hourly totals.
timeseries = timeseries.set_index("ds").resample("H").sum()

# Daily (24) and weekly (24 * 7) seasonality; keep the last four weeks for plotting.
decomposition = mstl(timeseries["y"].values, [24, 24 * 7]).tail(24 * 7 * 4)
decomposition.plot(subplots=True, title="Victoria electricity demand: MSTL (daily + weekly)");