In [None]:
#|default_exp distributed.core

In [None]:
#|hide
# Development convenience: reload edited modules automatically before running cells.
%load_ext autoreload
%autoreload 2

# Distributed core

> Building blocks for the distributed pipeline.

In [None]:
#|export
import operator
from typing import Callable, List, Optional

import dask.dataframe as dd
import pandas as pd
from dask.distributed import Client, default_client, futures_of, wait

from mlforecast.core import TimeSeries

In [None]:
import pandas as pd
from nbdev import show_doc
from window_ops.rolling import rolling_mean

from mlforecast.utils import generate_daily_series

In [None]:
# Dask cluster with 2 single-threaded workers. It is used by the examples and checks below,
# and picked up by `DistributedTimeSeries` via `default_client()`.
client = Client(n_workers=2, threads_per_worker=1)

In [None]:
#|exporti
def _fit_transform(ts, data, id_col, time_col, target_col, static_features, dropna, keep_last_n):
    df = ts.fit_transform(data, id_col, time_col, target_col, static_features, dropna, keep_last_n)
    return ts, df


def _predict(ts, model, horizon, dynamic_dfs, predict_fn, **predict_fn_kwargs):
    return ts.predict(model, horizon, dynamic_dfs, predict_fn, **predict_fn_kwargs)

In [None]:
#|export
def _predict_with_kwargs(ts, model, horizon, dynamic_dfs, predict_fn, predict_fn_kwargs):
    """Worker-side entry point used by `DistributedTimeSeries.predict`.

    `predict_fn_kwargs` arrives as one positional dict and is only unpacked here.
    This keeps the user's keyword arguments from being mixed with `Client.submit`'s
    own keyword arguments (e.g. `key`, `workers`, `priority`, `pure`).
    """
    return _predict(ts, model, horizon, dynamic_dfs, predict_fn, **predict_fn_kwargs)


class DistributedTimeSeries:
    """TimeSeries for distributed forecasting.

    Wraps a base `TimeSeries` and applies it to every partition of a dask DataFrame.
    After `fit_transform`, `self.ts` holds one future per partition that points at that
    partition's fitted `TimeSeries`. `predict` runs on those futures.
    """

    def __init__(
        self,
        ts: TimeSeries,
        client: Optional[Client] = None,
    ):
        self._base_ts = ts
        # Fall back to the currently registered dask client when none is given.
        self.client = client or default_client()

    def fit_transform(
        self,
        data: dd.DataFrame,
        id_col: str = 'index',
        time_col: str = 'ds',
        target_col: str = 'y',
        static_features: Optional[List[str]] = None,
        dropna: bool = True,
        keep_last_n: Optional[int] = None,
    ) -> dd.DataFrame:
        """Applies the transformations to each partition of `data`.

        `data` is persisted on the cluster. Each partition is then submitted to a
        worker together with the base `TimeSeries`. `id_col`, `time_col`, `target_col`,
        `static_features`, `dropna` and `keep_last_n` are forwarded unchanged to
        `TimeSeries.fit_transform`.

        The fitted per-partition `TimeSeries` objects are kept as futures in
        `self.ts`, and the input divisions in `self.data_divisions`, for `predict`.

        Returns a dask DataFrame assembled from the transformed partitions.
        """
        self.data_divisions = data.divisions
        data = self.client.persist(data)
        wait(data)
        # NOTE(review): this assumes `futures_of` yields the futures in partition order.
        # `predict` relies on that when it reuses `data_divisions`.
        partition_futures = futures_of(data)
        self.ts = []
        df_futures = []
        for part_future in partition_futures:
            future = self.client.submit(
                _fit_transform,
                self._base_ts,
                part_future,
                id_col,
                time_col,
                target_col,
                static_features=static_features,
                dropna=dropna,
                keep_last_n=keep_last_n,
                pure=False,
            )
            # Split the (ts, df) pair so each piece can be referenced on its own.
            ts_future = self.client.submit(operator.itemgetter(0), future)
            df_future = self.client.submit(operator.itemgetter(1), future)
            self.ts.append(ts_future)
            df_futures.append(df_future)
        # Empty frame carrying the output schema, used as dask metadata.
        meta = self.client.submit(lambda x: x.head(0), df_futures[0]).result()
        ret = dd.from_delayed(df_futures, meta=meta)
        assert not isinstance(ret, dd.Series)  # mypy
        return ret

    def predict(
        self,
        models,
        horizon: int,
        dynamic_dfs: Optional[List[pd.DataFrame]] = None,
        predict_fn: Optional[Callable] = None,
        **predict_fn_kwargs,
    ) -> dd.DataFrame:
        """Broadcasts `models` across all workers and computes the next `horizon` timesteps.

        `predict_fn(model, new_x, features_order, **predict_fn_kwargs)` is called on each timestep.

        A single model is wrapped in a list before broadcasting. This method requires a
        prior call to `fit_transform`, whose per-partition fitted `TimeSeries` futures
        (`self.ts`) and input divisions (`self.data_divisions`) are reused here.

        Returns a dask DataFrame with one partition per fitted partition, using the
        divisions of the data passed to `fit_transform`.

        Raises AttributeError if `fit_transform` has not been called yet.
        """
        if not hasattr(self, 'ts'):
            # Fail before broadcasting anything to the cluster.
            raise AttributeError('`fit_transform` must be called before `predict`.')
        if not isinstance(models, list):
            models = [models]
        models_future = self.client.scatter(models, broadcast=True)
        if dynamic_dfs is not None:
            dynamic_dfs_futures = self.client.scatter(dynamic_dfs, broadcast=True)
        else:
            dynamic_dfs_futures = None
        # `predict_fn_kwargs` travels as a single positional dict (see `_predict_with_kwargs`)
        # rather than being spread into `submit`, which reserves several keyword names.
        predictions_futures = [
            self.client.submit(
                _predict_with_kwargs,
                ts_future,
                models_future,
                horizon,
                dynamic_dfs_futures,
                predict_fn,
                predict_fn_kwargs,
            )
            for ts_future in self.ts
        ]
        # Empty frame carrying the output schema, built the same way as in `fit_transform`.
        meta = self.client.submit(lambda x: x.head(0), predictions_futures[0]).result()
        ret = dd.from_delayed(
            predictions_futures, meta=meta, divisions=self.data_divisions
        )
        assert not isinstance(ret, dd.Series)  # mypy
        return ret

    def __repr__(self):
        ts_repr = repr(self._base_ts)
        return f'Distributed{ts_repr}'

The `DistributedTimeSeries` class takes a `TimeSeries` object that specifies the desired features and applies it to each partition of a dask DataFrame. If you have more partitions than workers, it's recommended to set `num_threads=1` so that the `TimeSeries` threads don't compete with dask's own parallelism (dask may schedule several tasks on each worker).

In [None]:
# Feature specification shared by the distributed `TimeSeries` here and the local one used for comparison later.
config = {
    'freq': 'D',
    'lags': [7, 14],
    'lag_transforms': {
        7: [(rolling_mean, 7)],
        14: [(rolling_mean, 7)],
    },
    'date_features': ['dayofweek'],
    'num_threads': 1,
}
ts = TimeSeries(**config)
dts = DistributedTimeSeries(ts)
dts

DistributedTimeSeries(freq=<Day>, transforms=['lag-7', 'lag-14', 'rolling_mean_lag-7_window_size-7', 'rolling_mean_lag-14_window_size-7'], date_features=['dayofweek'], num_threads=1)

In [None]:
# Synthetic daily data: 100 series, each with two static features.
series = generate_daily_series(100, n_static_features=2)
series

Unnamed: 0_level_0,ds,y,static_0,static_1
unique_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
id_00,2000-01-01,39.811983,79,45
id_00,2000-01-02,103.274013,79,45
id_00,2000-01-03,176.574744,79,45
id_00,2000-01-04,258.987900,79,45
id_00,2000-01-05,344.940404,79,45
...,...,...,...,...
id_99,2000-06-25,453.400509,69,35
id_99,2000-06-26,30.229478,69,35
id_99,2000-06-27,101.313713,69,35
id_99,2000-06-28,145.724335,69,35


In [None]:
# Split into 6 dask partitions (more partitions than the 2 workers), so each worker handles several.
partitioned_series = dd.from_pandas(series, npartitions=6)
partitioned_series

Unnamed: 0_level_0,ds,y,static_0,static_1
npartitions=6,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
id_00,datetime64[ns],float64,category[known],category[known]
id_18,...,...,...,...
...,...,...,...,...
id_84,...,...,...,...
id_99,...,...,...,...


In [None]:
# Render the API documentation for `DistributedTimeSeries.fit_transform`.
show_doc(DistributedTimeSeries.fit_transform)

---

### DistributedTimeSeries.fit_transform

>      DistributedTimeSeries.fit_transform (data:dask.dataframe.core.DataFrame,
>                                           id_col:str='index',
>                                           time_col:str='ds',
>                                           target_col:str='y', static_features:
>                                           Optional[List[str]]=None,
>                                           dropna:bool=True,
>                                           keep_last_n:Optional[int]=None)

Applies the transformations to each partition of `data`.

In [None]:
# The distributed result, collected back into pandas, must equal a local fit on the full frame.
distributed_train = dts.fit_transform(partitioned_series).compute()

local_train = ts.fit_transform(series)
pd.testing.assert_frame_equal(distributed_train, local_train)

In [None]:
#|hide
# The fitted per-partition TimeSeries, taken together, should yield the same next-step
# features as the locally fitted one.
per_partition_feats = client.gather(
    client.map(lambda part_ts: part_ts._update_features(), dts.ts)
)
pd.testing.assert_frame_equal(pd.concat(per_partition_feats), ts._update_features())

In [None]:
# Render the API documentation for `DistributedTimeSeries.predict`.
show_doc(DistributedTimeSeries.predict)

---

### DistributedTimeSeries.predict

>      DistributedTimeSeries.predict (models, horizon:int,
>                                     dynamic_dfs:Optional[List[pandas.core.fram
>                                     e.DataFrame]]=None,
>                                     predict_fn:Optional[Callable]=None,
>                                     **predict_fn_kwargs)

Broadcasts `models` across all workers and computes the next `horizon` timesteps.

`predict_fn(model, new_x, features_order, **predict_fn_kwargs)` is called on each timestep.

In [None]:
class DummyModel:
    """Toy model whose prediction is simply the `lag-7` feature."""

    def predict(self, X):
        return X['lag-7'].values


horizon = 7
models = [DummyModel() for _ in range(2)]
distributed_preds = dts.predict(models, horizon).compute()

# Reference: the same models run through a fresh local TimeSeries fitted on the full data.
ts = TimeSeries(**config)
ts.fit_transform(series)
local_preds = ts.predict(models, horizon)

pd.testing.assert_frame_equal(distributed_preds, local_preds)

In [None]:
# Shut down the dask client created earlier in this notebook.
client.close()