In [None]:
#| default_exp distributed.multiprocess

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

# MultiprocessBackend

> The computational efficiency of `StatsForecast` can be traced to its two core components:<br>1. Its `models`, written in Numba, which optimizes Python code to reach C speeds.<br>2. Its `core.StatsForecast` class, which enables distributed computing.<br>`MultiprocessBackend` is a low-level class on which other distribution methods are built.<br><br>

In [None]:
#| hide
from fastcore.test import test_eq
from nbdev.showdoc import add_docs, show_doc

In [None]:
#| export
from typing import Any

from statsforecast.core import _StatsForecast, ParallelBackend

In [None]:
#| export

# This parent class holds the `forecast` and `cross_validation` methods shared
# with `core.StatsForecast`, which the `FugueBackend` and the `RayBackend` build on.

# This parent class is inherited by [FugueBackend](https://nixtla.github.io/statsforecast/distributed.fugue.html)
# and [RayBackend](https://nixtla.github.io/statsforecast/distributed.ray.html).

class MultiprocessBackend(ParallelBackend):
    """MultiprocessBackend Parent Class for Distributed Computation.

    Every call to `forecast` or `cross_validation` builds a fresh
    `_StatsForecast` configured with this backend's `n_jobs` and delegates
    the work to it.

    Parameters
    ----------
    n_jobs : int
        Number of jobs used in the parallel processing, use -1 for all cores.
    """

    def __init__(self, n_jobs: int) -> None:
        # The worker count is stored before the base initializer runs.
        self.n_jobs = n_jobs
        super().__init__()

    def _new_statsforecast(self, models, freq, fallback_model):
        """Build a `_StatsForecast` that uses this backend's `n_jobs`."""
        return _StatsForecast(
            models=models,
            freq=freq,
            fallback_model=fallback_model,
            n_jobs=self.n_jobs,
        )

    def forecast(self, df, models, freq, fallback_model=None, **kwargs: Any) -> Any:
        """Forecast `df` through a `_StatsForecast` run with `n_jobs` workers.

        Parameters
        ----------
        df
            Input data, forwarded as `df=` to `_StatsForecast.forecast`.
        models
            Models handed to the `_StatsForecast` constructor.
        freq
            Frequency handed to the `_StatsForecast` constructor.
        fallback_model : optional
            Handed to the `_StatsForecast` constructor as `fallback_model`.
        **kwargs
            Forwarded unchanged to `_StatsForecast.forecast` (e.g. `h`).

        Returns
        -------
        Whatever `_StatsForecast.forecast` returns.
        """
        engine = self._new_statsforecast(models, freq, fallback_model)
        return engine.forecast(df=df, **kwargs)

    def cross_validation(self, df, models, freq, fallback_model=None, **kwargs: Any) -> Any:
        """Cross-validate on `df` through a `_StatsForecast` run with `n_jobs` workers.

        Parameters
        ----------
        df
            Input data, forwarded as `df=` to `_StatsForecast.cross_validation`.
        models
            Models handed to the `_StatsForecast` constructor.
        freq
            Frequency handed to the `_StatsForecast` constructor.
        fallback_model : optional
            Handed to the `_StatsForecast` constructor as `fallback_model`.
        **kwargs
            Forwarded unchanged to `_StatsForecast.cross_validation` (e.g. `h`).

        Returns
        -------
        Whatever `_StatsForecast.cross_validation` returns.
        """
        engine = self._new_statsforecast(models, freq, fallback_model)
        return engine.cross_validation(df=df, **kwargs)

In [None]:
# Render the documentation for `MultiprocessBackend` (nbdev `show_doc`, `title_level=3`).
show_doc(MultiprocessBackend, title_level=3)

In [None]:
#| hide
from statsforecast import StatsForecast
from statsforecast.models import Naive
from statsforecast.utils import generate_series

In [None]:
#| hide
# Ten synthetic series with `unique_id` moved out of the index and cast to str,
# built in one chain instead of mutating `df` after creation.
df = (
    generate_series(10)
    .reset_index()
    .astype({'unique_id': str})
)

class FailNaive:
    """Stand-in model used to exercise the `fallback_model` path.

    `forecast` accepts no arguments, so any call that passes arguments raises
    `TypeError`. `__repr__` returns 'Naive' so the fallback's output can be
    compared with a plain `Naive` run (assumes the core labels results by
    `repr(model)` -- TODO confirm).
    """
    def forecast(self):
        pass
    def __repr__(self):
        return 'Naive'

def test_mp_back(n_jobs=1):
    """Check `MultiprocessBackend` against the regular `StatsForecast` API.

    Covers `forecast` and `cross_validation`, each once with `Naive` and once
    with `FailNaive` falling back to `Naive`.

    Parameters
    ----------
    n_jobs : int
        Number of jobs handed to `MultiprocessBackend`.
    """
    h, freq = 12, 'D'
    backend = MultiprocessBackend(n_jobs=n_jobs)
    # Reference results; the fallback runs must reproduce them as well,
    # so each is computed once and reused.
    expected_fcst = StatsForecast(models=[Naive()], freq=freq).forecast(df=df, h=h)
    expected_cv = StatsForecast(models=[Naive()], freq=freq).cross_validation(df=df, h=h)

    # forecast
    fcst = backend.forecast(df, models=[Naive()], freq=freq, h=h)
    test_eq(fcst, expected_fcst)
    # cross validation
    cv = backend.cross_validation(df, models=[Naive()], freq=freq, h=h)
    test_eq(cv, expected_cv)
    # forecast: failing model falls back to Naive
    fcst = backend.forecast(df, models=[FailNaive()], freq=freq, fallback_model=Naive(), h=h)
    test_eq(fcst, expected_fcst)
    # cross validation: failing model falls back to Naive
    cv = backend.cross_validation(df, models=[FailNaive()], freq=freq, fallback_model=Naive(), h=h)
    test_eq(cv, expected_cv)

test_mp_back()

In [None]:
#| hide
#| eval: false
# Same checks with 10 jobs; nbdev skips cells marked `eval: false` during automated test runs.
test_mp_back(n_jobs=10)