In [None]:
#| default_exp distributed.multiprocess

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

# Multiprocess Backend

In [None]:
#| hide
from fastcore.test import test_eq

In [None]:
#| export
from typing import Any

from statsforecast.core import StatsForecast
from statsforecast.distributed.core import ParallelBackend

In [None]:
#| export
class MultiprocessBackend(ParallelBackend):
    """Parallel backend that runs StatsForecast locally with a fixed `n_jobs`.

    Every call builds a fresh `StatsForecast` from the given data, models and
    frequency (passing this backend's `n_jobs`). It then forwards the remaining
    keyword arguments to the matching StatsForecast method.
    """

    def __init__(self, n_jobs: int) -> None:
        # `n_jobs` is stored before the base-class initializer runs.
        self.n_jobs = n_jobs
        super().__init__()

    def _make_statsforecast(self, df, models, freq):
        """Return a StatsForecast bound to `df` that uses this backend's `n_jobs`."""
        return StatsForecast(df=df, models=models, freq=freq, n_jobs=self.n_jobs)

    def forecast(self, df, models, freq, **kwargs: Any) -> Any:
        """Forecast `df` with `models` at frequency `freq`.

        Extra keyword arguments (e.g. `h`) go to `StatsForecast.forecast`,
        and its result is returned unchanged.
        """
        engine = self._make_statsforecast(df, models, freq)
        return engine.forecast(**kwargs)

    def cross_validation(self, df, models, freq, **kwargs: Any) -> Any:
        """Cross-validate `models` on `df` at frequency `freq`.

        Extra keyword arguments (e.g. `h`) go to
        `StatsForecast.cross_validation`, and its result is returned unchanged.
        """
        engine = self._make_statsforecast(df, models, freq)
        return engine.cross_validation(**kwargs)

In [None]:
#| hide
from statsforecast.models import Naive
from statsforecast.utils import generate_series
try:
    from nbdev.imports import IN_NOTEBOOK
except ImportError:
    # nbdev (or its IN_NOTEBOOK flag) is unavailable: treat this as a plain script run.
    # Only ImportError is caught so that unrelated failures are not silently hidden.
    IN_NOTEBOOK = False

df = generate_series(10).reset_index()
df['unique_id'] = df['unique_id'].astype(str)

def test_mp_back(n_jobs=1):
    """Check that MultiprocessBackend matches a direct StatsForecast run.

    Runs `forecast` and `cross_validation` (horizon 12, daily frequency,
    `Naive` model) through a `MultiprocessBackend(n_jobs=n_jobs)`.
    Each result must equal the output of a plain `StatsForecast` call on
    the same `df`.
    """
    backend = MultiprocessBackend(n_jobs=n_jobs)
    # forecast
    fcst = backend.forecast(df, models=[Naive()], freq='D', h=12)
    fcst_stats = StatsForecast(models=[Naive()], freq='D').forecast(df=df, h=12)
    test_eq(fcst, fcst_stats)
    # cross-validation
    fcst = backend.cross_validation(df, models=[Naive()], freq='D', h=12)
    fcst_stats = StatsForecast(models=[Naive()], freq='D').cross_validation(df=df, h=12)
    test_eq(fcst, fcst_stats)

test_mp_back()
# The n_jobs > 1 check runs only as a script entry point, never inside a notebook.
# NOTE(review): presumably because worker processes need a `__main__` guard
# (spawn start method) and are unreliable from a notebook kernel — confirm.
if __name__ == "__main__" and not IN_NOTEBOOK:
    test_mp_back(n_jobs=2)