In [None]:
#default_exp distributed.models

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#hide
import warnings

# Ignore every UserWarning emitted from this point on
# (presumably to keep the rendered notebook output free of library noise).
warnings.simplefilter('ignore', UserWarning)

# Distributed models

> Wrappers around the dask-based distributed training implementations of LightGBM and XGBoost

In [None]:
#export
from typing import Callable, Dict, List, Optional, Tuple

import dask.dataframe as dd
import lightgbm as lgb
import xgboost as xgb
from dask.distributed import Client, default_client, Future

from mlforecast.core import predictions_flow

In [None]:
from window_ops.expanding import expanding_mean
from window_ops.rolling import rolling_mean, rolling_std

from mlforecast.distributed.core import distributed_preprocess
from mlforecast.utils import generate_daily_series

In [None]:
#exporti
class BaseDistributedModel:
    """Shared behaviour for forecasters trained through a dask client.

    Holds an estimator together with the dask client that is used to ship the
    trained model to the workers and to run the prediction function on each
    series partition. Subclasses must override `model_`.
    """

    def __init__(self, model, client: Optional[Client] = None):
        """Store the estimator and resolve the dask client.

        Parameters
        ----------
        model
            Estimator exposing `fit(X, y, **kwargs)`.
        client : Client, optional
            Dask client to use. When omitted, `default_client()` is used.
        """
        self.model = model
        # Explicit None check: do not depend on the truthiness of the client object.
        self.client = client if client is not None else default_client()

    def fit(self, X: dd.DataFrame, y: dd.Series, **kwargs):
        """Fit the wrapped estimator on `X` and `y`.

        Extra keyword arguments are forwarded to the estimator's `fit`.
        Returns `self` so calls can be chained.
        """
        self.model.fit(X, y, **kwargs)
        return self

    @property
    def model_(self):
        """Trained object that `predict` broadcasts to the workers.

        Must be provided by subclasses; the base implementation raises
        `NotImplementedError`.
        """
        raise NotImplementedError

    def predict(self,
                series: List[Future],
                horizon: int,
                divisions: Optional[Tuple] = None,
                predict_fn: Optional[Callable] = predictions_flow) -> dd.DataFrame:
        """Compute predictions for every future in `series`.

        Parameters
        ----------
        series : list of Future
            Futures handed one by one to `predict_fn` as its first positional
            argument. Must contain at least one element, because the first
            result is used to derive the output metadata.
        horizon : int
            Passed to `predict_fn` as the `horizon` keyword.
        divisions : tuple, optional
            Forwarded to `dd.from_delayed`.
        predict_fn : callable, optional
            Invoked as `predict_fn(item, model=..., horizon=...)`.
            `None` selects `predictions_flow`.

        Returns
        -------
        dd.DataFrame
            Dask dataframe assembled from the per-item prediction futures.
        """
        # The annotation allows None, so honour it instead of failing inside `client.map`.
        if predict_fn is None:
            predict_fn = predictions_flow
        # Send the trained model to all workers once and share that single future.
        model_future = self.client.scatter(self.model_, broadcast=True)
        predictions_futures = self.client.map(predict_fn,
                                              series,
                                              model=model_future,
                                              horizon=horizon)
        # The head of the first result is used as the `meta` of the output collection.
        meta = self.client.submit(lambda x: x.head(), predictions_futures[0]).result()
        return dd.from_delayed(predictions_futures, meta=meta, divisions=divisions)

    def __repr__(self) -> str:
        """Return the representation of the wrapped estimator."""
        return repr(self.model)

In [None]:
#export
class LGBMForecast(BaseDistributedModel):
    """Distributed forecaster built on `lgb.DaskLGBMRegressor`."""

    def __init__(self, params: Optional[Dict] = None, client: Optional[Client] = None):
        """Create the underlying `lgb.DaskLGBMRegressor`.

        Parameters
        ----------
        params : dict, optional
            Keyword arguments for `lgb.DaskLGBMRegressor`. `None` (the default)
            or an empty dict builds the regressor without extra arguments.
        client : Client, optional
            Passed unchanged to `BaseDistributedModel.__init__`.
        """
        # A `None` default replaces the former `{}` so no dict instance is shared between calls.
        super().__init__(lgb.DaskLGBMRegressor(**(params or {})), client)

    @property
    def model_(self):
        """The `booster_` attribute of the wrapped regressor."""
        return self.model.booster_

In [None]:
#export
class XGBForecast(BaseDistributedModel):
    """Distributed forecaster built on `xgb.dask.DaskXGBRegressor`."""

    def __init__(self, params: Optional[Dict] = None, client: Optional[Client] = None):
        """Create the underlying `xgb.dask.DaskXGBRegressor`.

        Parameters
        ----------
        params : dict, optional
            Keyword arguments for `xgb.dask.DaskXGBRegressor`. `None` (the
            default) or an empty dict builds the regressor without extra arguments.
        client : Client, optional
            Passed unchanged to `BaseDistributedModel.__init__`.
        """
        # A `None` default replaces the former `{}` so no dict instance is shared between calls.
        super().__init__(xgb.dask.DaskXGBRegressor(**(params or {})), client)

    @property
    def model_(self):
        """The booster returned by the wrapped regressor's `get_booster()`."""
        return self.model.get_booster()

In [None]:
# Start a dask client with two workers. The forecasters below are created without
# a `client` argument, so they resolve their client through `default_client()`.
client = Client(n_workers=2)

In [None]:
# Synthetic daily data with two static features
# (the 100 is presumably the number of series; confirm against `generate_daily_series`).
series = generate_daily_series(100, n_static_features=2)
# Split the pandas frame into two dask partitions.
partitioned_series = dd.from_pandas(series, npartitions=2)

In [None]:
# Settings handed to `distributed_preprocess` in the cells below.
flow_config = {
    'freq': 'D',
    'lags': [7, 14],
    'lag_transforms': {
        1: [expanding_mean],
        7: [(rolling_mean, 7), (rolling_std, 7)],
    },
    'date_features': ['dayofweek', 'month', 'year'],
    'num_threads': 2,
}

In [None]:
# Preprocess the partitions; this yields the per-partition futures used for
# prediction and the dask dataframe used for training.
ts_futures, train_ddf = distributed_preprocess(partitioned_series, flow_config)
# Features are every column except the timestamp and the target.
X = train_ddf.drop(columns=['ds', 'y'])
y = train_ddf['y']

In [None]:
model = LGBMForecast().fit(X, y)
# Predicting twice from the same futures must produce identical frames.
first_preds = model.predict(ts_futures, 7).compute()
second_preds = model.predict(ts_futures, 7).compute()
assert first_preds.equals(second_preds)

In [None]:
# Swap every categorical column for its integer codes before the XGBoost run
# (presumably because that estimator needs numeric features; confirm).
categorical_cols = series.select_dtypes(include='category').columns
for col in categorical_cols:
    series[col] = series[col].cat.codes
# Rebuild the partitions, futures and training data from the re-encoded frame.
partitioned_series = dd.from_pandas(series, npartitions=2)
ts_futures, train_ddf = distributed_preprocess(partitioned_series, flow_config)
X = train_ddf.drop(columns=['ds', 'y'])
y = train_ddf['y']

In [None]:
model = XGBForecast().fit(X, y)
# Predicting twice from the same futures must produce identical frames.
first_preds = model.predict(ts_futures, 7).compute()
second_preds = model.predict(ts_futures, 7).compute()
assert first_preds.equals(second_preds)

In [None]:
# Shut down the dask client created for the examples above.
client.close()