In [None]:
#all_distributed
#default_exp distributed.core

In [None]:
#hide
%load_ext autoreload
%autoreload 2

# Distributed core

> Building blocks for the distributed pipeline.

In [None]:
#export
import operator
from typing import Callable, List, Optional

import dask.dataframe as dd
from dask.distributed import Client, default_client, futures_of, wait

from mlforecast.core import TimeSeries, simple_predict


In [None]:
import pandas as pd
from nbdev import show_doc
from window_ops.rolling import rolling_mean

from mlforecast.utils import generate_daily_series

In [None]:
# Start a dask client with two workers for the examples below.
# `DistributedTimeSeries` is later built without an explicit client, so it
# relies on this one being returned by `default_client()`.
client = Client(n_workers=2)

In [None]:
#exporti
def _fit_transform(ts, data, **kwargs):
    df = ts.fit_transform(data, **kwargs)
    return ts, df


def _predict(ts, model, horizon, predict_fn, **predict_fn_kwargs):
    return ts.predict(model, horizon, predict_fn, **predict_fn_kwargs)


In [None]:
#export
class DistributedTimeSeries:
    """TimeSeries for distributed forecasting.

    Wraps a base `TimeSeries` and applies it to every partition of a dask
    dataframe. After `fit_transform`, `self.ts` holds one future per partition,
    each pointing to the `TimeSeries` returned by that partition's task.
    """

    def __init__(
        self,
        ts: TimeSeries,
        client: Optional[Client] = None,
    ):
        """Stores the base `ts` and the dask client used to submit tasks.

        When `client` is not given, `default_client()` is used.
        """
        self._base_ts = ts
        self.client = client or default_client()

    def _empty_meta(self, df_future):
        """Returns a zero-row frame with the schema of the dataframe behind `df_future`.

        Only `head(0)` is transferred back, so no actual rows travel to the client.
        """
        return self.client.submit(lambda df: df.head(0), df_future).result()

    def fit_transform(
        self,
        data: dd.DataFrame,
        static_features: Optional[List[str]] = None,
        dropna: bool = True,
        keep_last_n: Optional[int] = None,
    ) -> dd.DataFrame:
        """Applies the transformations to each partition of `data`.

        `data` is persisted on the cluster and one task per partition calls
        `fit_transform` of the base `TimeSeries` with `static_features`,
        `dropna` and `keep_last_n`. The resulting `TimeSeries` futures are
        stored in `self.ts` and the divisions of `data` in
        `self.data_divisions` for later use in `predict`.

        Returns a dask dataframe built from the transformed partitions.
        """
        self.data_divisions = data.divisions
        data = self.client.persist(data)
        # Block until every partition is materialized so that `futures_of`
        # yields the futures of the persisted partitions.
        wait(data)
        partition_futures = futures_of(data)
        self.ts = []
        df_futures = []
        for part_future in partition_futures:
            # pure=False gives the task a unique key, so dask does not reuse
            # the result of an earlier submission with the same arguments.
            future = self.client.submit(
                _fit_transform,
                self._base_ts,
                part_future,
                static_features=static_features,
                dropna=dropna,
                keep_last_n=keep_last_n,
                pure=False,
            )
            # `_fit_transform` returns `(ts, df)`; split it into two futures.
            ts_future = self.client.submit(operator.itemgetter(0), future)
            df_future = self.client.submit(operator.itemgetter(1), future)
            self.ts.append(ts_future)
            df_futures.append(df_future)
        meta = self._empty_meta(df_futures[0])
        return dd.from_delayed(df_futures, meta=meta)

    def predict(
        self,
        model,
        horizon: int,
        predict_fn: Callable = simple_predict,
        **predict_fn_kwargs,
    ) -> dd.DataFrame:
        """Broadcasts `model` across all workers and computes the next `horizon` timesteps.

        `predict_fn(model, new_x, features_order, **predict_fn_kwargs)` is called on each timestep.

        Returns a dask dataframe with one partition of predictions per
        partition seen in `fit_transform`, using the divisions stored there.

        Raises `AttributeError` if `fit_transform` has not been called yet.
        """
        # Fail early with an actionable message. The exception type matches
        # what the missing `self.ts` attribute would raise on its own.
        if not getattr(self, 'ts', None):
            raise AttributeError(
                'No fitted partitions found. Call `fit_transform` before `predict`.'
            )
        model_future = self.client.scatter(model, broadcast=True)
        predictions_futures = [
            self.client.submit(
                _predict,
                ts_future,
                model_future,
                horizon,
                predict_fn=predict_fn,
                **predict_fn_kwargs,
            )
            for ts_future in self.ts
        ]
        # Zero-row metadata, consistent with `fit_transform`.
        meta = self._empty_meta(predictions_futures[0])
        return dd.from_delayed(
            predictions_futures, meta=meta, divisions=self.data_divisions
        )

    def __repr__(self):
        """Returns the repr of the base `TimeSeries` prefixed with `Distributed`."""
        return f'Distributed{self._base_ts!r}'


The `DistributedTimeSeries` class takes a `TimeSeries` object which specifies the desired features. If you have more partitions than workers, it's recommended to set `num_threads=1` in the `TimeSeries` to avoid competing with dask's own parallelism (dask may schedule several tasks on each worker at the same time).

In [None]:
# Feature configuration, reused later in the notebook to build a second
# local `TimeSeries`.
config = dict(
    freq='D',
    lags=[7, 14],
    lag_transforms={
        7 : [(rolling_mean, 7)],
        14: [(rolling_mean, 7)],
    },
    date_features=['dayofweek'],
    # Single-threaded, as recommended in the text above: the example uses
    # more partitions than workers.
    num_threads=1,
)
ts = TimeSeries(**config)
# No client is passed, so the default dask client is used.
dts = DistributedTimeSeries(ts)
dts

In [None]:
# Synthetic data from `generate_daily_series` with 2 static features
# (first argument 100 is presumably the number of series; confirm).
series = generate_daily_series(100, n_static_features=2)
series

In [None]:
# Split the pandas dataframe into 6 dask partitions, which is more than the
# 2 workers started at the top of the notebook.
partitioned_series = dd.from_pandas(series, npartitions=6)
partitioned_series

In [None]:
# Render the documentation of `DistributedTimeSeries.fit_transform`.
show_doc(DistributedTimeSeries.fit_transform)

In [None]:
# Distributed transformation, materialized as a pandas dataframe by `.compute()`.
train_ddf = dts.fit_transform(partitioned_series).compute()

# The same transformation run locally on the full dataframe must match exactly.
# Note this also fits the local `ts`, which a later cell relies on.
local_df = ts.fit_transform(series)
assert train_ddf.equals(local_df)

In [None]:
#hide
# Call `_update_features` on every partition's fitted TimeSeries, concatenate
# the results and check they equal the output of the local `ts`.
# NOTE(review): `_update_features` presumably builds the features for the
# next timestep and may change the state of the object; confirm.
next_feats_futures = client.map(lambda ts: ts._update_features(), dts.ts)
next_feats = pd.concat(client.gather(next_feats_futures))
local_upd = ts._update_features()
assert next_feats.equals(local_upd)

In [None]:
# Render the documentation of `DistributedTimeSeries.predict`.
show_doc(DistributedTimeSeries.predict)

In [None]:
class DummyModel:
    """Stand-in model whose prediction is simply the `lag-7` feature."""

    def predict(self, X):
        """Returns the values of the `lag-7` column of `X`."""
        lag_7 = X['lag-7']
        return lag_7.values
    
horizon = 7
model = DummyModel()
# Distributed forecast, materialized as a pandas dataframe by `.compute()`.
preds = dts.predict(model, horizon).compute()

# Build and fit a new local TimeSeries from the same config to produce the
# reference forecast.
# NOTE(review): presumably needed because the earlier `ts` had its state
# changed by `_update_features` in a previous cell; confirm.
ts = TimeSeries(**config)
ts.fit_transform(series)
local_preds = ts.predict(model, horizon)

# Distributed and local forecasts must be identical.
assert preds.equals(local_preds)

In [None]:
# Close the dask client created at the top of the notebook.
client.close()