In [None]:
#| default_exp distributed.fugue

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

# Fugue Backend

In [None]:
#| hide
from fastcore.test import test_eq

In [None]:
#| export
from typing import Any, Dict

import numpy as np
import pandas as pd
try:
    from fugue import transform
except ModuleNotFoundError as e:
    msg = (
        f'{e}. To use fugue you have to install it.'
        'Please run `pip install fugue`. '
    )
    raise ModuleNotFoundError(msg) from e
from statsforecast.core import StatsForecast
from statsforecast.distributed.core import ParallelBackend
from triad import Schema

In [None]:
#| export
class FugueBackend(ParallelBackend):
    """Parallel backend that runs StatsForecast per series through Fugue.

    The input frame is partitioned by ``unique_id`` and every partition is
    processed by a ``StatsForecast`` instance created with ``n_jobs=1``; the
    partitioning and execution are delegated to ``fugue.transform``.

    Parameters
    ----------
    engine : Any, optional
        Forwarded to ``fugue.transform`` as ``engine``.
    conf : Any, optional
        Forwarded to ``fugue.transform`` as ``engine_conf``.
    **transform_kwargs : Any
        Extra keyword arguments forwarded verbatim to ``fugue.transform``.
    """

    def __init__(self, engine: Any = None, conf: Any = None, **transform_kwargs: Any):
        self._engine = engine
        self._conf = conf
        # Copy so later use is decoupled from the caller's mapping.
        self._transform_kwargs = dict(transform_kwargs)

    def __getstate__(self) -> Dict[str, Any]:
        """Pickle the backend without any instance state.

        Returning an empty dict means ``_engine``, ``_conf`` and
        ``_transform_kwargs`` are never serialized with the object.
        NOTE(review): presumably this keeps unpicklable engine handles out of
        the payload when the bound per-partition methods are shipped to
        workers -- confirm. As a consequence, code that runs on the worker
        side (``_forecast_series``, ``_cv`` and ``_local_model``) must not
        read instance attributes.
        """
        return {}

    def forecast(self, df, models, freq, **kwargs: Any) -> Any:
        """Forecast every series in ``df`` in a distributed fashion.

        Parameters
        ----------
        df
            Input frame containing a ``unique_id`` column (used as the
            partition key) and a ``y`` column (removed from the output schema).
        models
            Models passed to ``StatsForecast``; ``repr(model)`` names each
            output column.
        freq
            Frequency passed to ``StatsForecast``.
        **kwargs : Any
            Forwarded to ``StatsForecast.forecast`` on each partition.

        Returns
        -------
        Any
            Whatever ``fugue.transform`` returns for the configured engine.
        """
        return self._distribute(
            df, self._forecast_series, models, freq, kwargs, mode="forecast"
        )

    def cross_validation(self, df, models, freq, **kwargs: Any) -> Any:
        """Cross-validate every series in ``df`` in a distributed fashion.

        Takes the same arguments as :meth:`forecast`; ``kwargs`` are forwarded
        to ``StatsForecast.cross_validation`` on each partition and the output
        schema additionally carries ``cutoff`` and ``y`` columns.

        Returns
        -------
        Any
            Whatever ``fugue.transform`` returns for the configured engine.
        """
        return self._distribute(df, self._cv, models, freq, kwargs, mode="cv")

    def _distribute(self, df, func, models, freq, kwargs, mode) -> Any:
        """Run ``func`` on each ``unique_id`` partition of ``df`` via Fugue."""
        # Schema expression: the input columns ("*") without "y", followed by
        # the columns described by `_get_output_schema`.
        schema = "*-y+" + str(self._get_output_schema(models, mode=mode))
        return transform(
            df,
            func,
            # The keys must match the parameter names of the per-partition
            # callbacks (`models`, `freq`, `kwargs`).
            params=dict(models=models, freq=freq, kwargs=kwargs),
            schema=schema,
            partition={"by": "unique_id"},
            engine=self._engine,
            engine_conf=self._conf,
            **self._transform_kwargs,
        )

    @staticmethod
    def _local_model(df: pd.DataFrame, models, freq) -> StatsForecast:
        """Build a ``StatsForecast`` with ``n_jobs=1`` for one partition."""
        indexed = df.set_index("unique_id")
        return StatsForecast(df=indexed, models=models, freq=freq, n_jobs=1)

    # NOTE: the signatures of the two callbacks below (parameter names and the
    # `pd.DataFrame` annotations) are kept exactly as they were because they
    # are handed to `fugue.transform`, which inspects the callable it is given.

    def _forecast_series(self, df: pd.DataFrame, models, freq, kwargs) -> pd.DataFrame:
        """Forecast the series contained in a single partition."""
        model = self._local_model(df, models, freq)
        return model.forecast(**kwargs).reset_index()

    def _cv(self, df: pd.DataFrame, models, freq, kwargs) -> pd.DataFrame:
        """Cross-validate the series contained in a single partition."""
        model = self._local_model(df, models, freq)
        return model.cross_validation(**kwargs).reset_index()

    def _get_output_schema(self, models, mode="forecast") -> Schema:
        """Describe the columns appended to the input schema.

        Parameters
        ----------
        models
            Models whose ``repr`` becomes a ``float32`` column each.
        mode : {"forecast", "cv"}
            ``"cv"`` prepends a ``cutoff`` datetime column and a ``y``
            ``float32`` column to the model columns.

        Raises
        ------
        ValueError
            If ``mode`` is neither ``"forecast"`` nor ``"cv"``; previously an
            unknown mode silently produced the forecast schema.
        """
        if mode not in ("forecast", "cv"):
            raise ValueError(
                f"Unknown mode {mode!r}; expected 'forecast' or 'cv'."
            )
        cols = [(repr(model), np.float32) for model in models]
        if mode == "cv":
            cols = [("cutoff", "datetime"), ("y", np.float32)] + cols
        return Schema(cols)

In [None]:
#| hide
# nbdev is optional here: when it cannot be imported, behave as a plain script.
# Only ImportError is caught so that KeyboardInterrupt/SystemExit and unrelated
# failures are no longer silently swallowed.
try:
    from nbdev.imports import IN_NOTEBOOK
except ImportError:
    IN_NOTEBOOK = False

# Smoke test: the Fugue backend must agree with the local StatsForecast run.
if __name__ == "__main__" and not IN_NOTEBOOK:
    from statsforecast.models import Naive
    from statsforecast.utils import generate_series

    df = generate_series(10).reset_index()
    df['unique_id'] = df['unique_id'].astype(str)

    backend = FugueBackend()

    # forecast
    fcst_fugue = backend.forecast(df, models=[Naive()], freq='D', h=12)
    fcst_stats = StatsForecast(models=[Naive()], freq='D').forecast(df=df, h=12)
    test_eq(fcst_fugue, fcst_stats.reset_index())

    # cross validation (distinct names so the forecast results stay inspectable)
    cv_fugue = backend.cross_validation(df, models=[Naive()], freq='D', h=12)
    cv_stats = StatsForecast(models=[Naive()], freq='D').cross_validation(df=df, h=12)
    test_eq(cv_fugue, cv_stats.reset_index())