## Libraries

In [1]:
%pip install -qqq mlforecast

Note: you may need to restart the kernel to use updated packages.


In [2]:
from pathlib import Path

import lightgbm as lgb
import mlforecast
import numpy as np
import polars as pl
from mlforecast import MLForecast
from mlforecast.lag_transforms import ExpandingMean, RollingMean, SeasonalRollingMean

In [3]:
mlforecast.__version__

'1.0.1'

In [4]:
pl.enable_string_cache()

## Data setup

In [5]:
input_path = Path('m5-forecasting-accuracy/')

### Calendar

In [6]:
cal_dtypes = {
    'date': pl.Datetime,
    'd': pl.Categorical,
    'wm_yr_wk': pl.Int32,
    'event_name_1': pl.Categorical,
    'event_type_1': pl.Categorical,
    'event_name_2': pl.Categorical,
    'event_type_2': pl.Categorical,
    'snap_CA': pl.Int32,
    'snap_TX': pl.Int32,
    'snap_WI': pl.Int32,
}
cal = pl.read_csv(input_path / 'calendar.csv', schema_overrides=cal_dtypes, columns=list(cal_dtypes.keys()))
event_cols = [k for k in cal_dtypes if k.startswith('event')]
cal = cal.with_columns(pl.col(event_cols).fill_null('nan'))

### Prices

In [7]:
prices_dtypes = {
    'store_id': pl.Categorical,
    'item_id': pl.Categorical,
    'wm_yr_wk': pl.Int32,
    'sell_price': pl.Float32,
}
prices = pl.read_csv(input_path / 'sell_prices.csv', schema_overrides=prices_dtypes)

### Sales

In [8]:
sales_dtypes = {
    'id': pl.Categorical,
    'item_id': pl.Categorical,
    'dept_id': pl.Categorical,
    'cat_id': pl.Categorical,
    'store_id': pl.Categorical,
    'state_id': pl.Categorical,
    **{f'd_{i}': pl.Float32 for i in range(1942)}
}
sales = pl.read_csv(input_path / 'sales_train_evaluation.csv', schema_overrides=sales_dtypes)

In [9]:
import polars as pl

id_vars = ['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']
value_vars = [col for col in sales.columns if col.startswith("d_")]

long = (
    sales.lazy()
    .melt(id_vars=id_vars, value_vars=value_vars, variable_name="d", value_name="y")
    .collect()
)


  .melt(id_vars=id_vars, value_vars=value_vars, variable_name="d", value_name="y")


In [10]:
%%time
print(long.shape[0])
long = long.with_columns(pl.col('d').cast(pl.Categorical))
long = long.join(cal, on=['d'])
dates = sorted(long['date'].unique())
long = long.sort(['id', 'date'])
without_leading_zeros = pl.col('y').gt(0).cast(pl.Int64).cum_max().over('id').cast(pl.Boolean)
above_min_date = pl.col('date') >= dates[-400]
keep_mask = without_leading_zeros & above_min_date
long = long.filter(keep_mask)
print(long.shape[0])

59181090
12159132
CPU times: total: 22.8 s
Wall time: 2.76 s


In [11]:
long = long.join(prices, on=['store_id', 'item_id', 'wm_yr_wk'])


In [59]:
import polars as pl

categoricals2 = [
    'id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id',
    'event_name_1', 'event_type_1', 'event_name_2', 'event_type_2'
]

mapping_dict = {}

# Step 1: Cast columns to categorical
df = long.with_columns([
    pl.col(col).cast(pl.Categorical).alias(col) for col in categoricals2
])

# Step 2: Build mapping and convert to physical ints
for col in categoricals2:
    # Extract unique categories
    categories = df.select(pl.col(col)).unique().sort(by=col).to_series().to_list()
    mapping_dict[col] = {cat: i for i, cat in enumerate(categories)}

    # Convert to physical integer codes
    df = df.with_columns(pl.col(col).to_physical().alias(col))


In [207]:
import datetime

In [209]:
# Get max date from long (original train + prices merged)
last_date_all = df['date'].max()
valid_horizon = 28

valid_start_date = last_date_all - datetime.timedelta(days=valid_horizon)
train_end = valid_start_date - datetime.timedelta(days=valid_horizon)

# Split long into train and validation (just like time-based CV)
train_long = df.filter(pl.col("date") < valid_start_date)
valid_long = df.filter(pl.col("date") >= valid_start_date)

# Build calendar + prices (X_df) for the validation period only
future_cal = cal.filter((pl.col('date') >= valid_start_date) & (pl.col('date') <= last_date_all))
future_prices = prices.filter(pl.col('wm_yr_wk').is_in(future_cal['wm_yr_wk'].unique()))

# Create IDs to match training IDs
future_prices = future_prices.with_columns(
    id = (pl.col('item_id') + '_' + pl.col('store_id') + '_evaluation')
)

X_df = future_prices.join(future_cal, on='wm_yr_wk')
X_df = X_df.drop(['store_id', 'item_id', 'wm_yr_wk', 'd'])
X_df = X_df.with_columns(pl.col('id').cast(pl.Categorical))


Please use `implode` to return to previous behavior.

See https://github.com/pola-rs/polars/issues/22149 for more information.
  future_prices = prices.filter(pl.col('wm_yr_wk').is_in(future_cal['wm_yr_wk'].unique()))


In [61]:
for h in range(0, 28 + 1):
    train_long = train_long.with_columns(
        pl.col('y').shift(-h).over('id').alias(f'y_t+{h}')
    )

In [62]:
df_clean = train_long.drop_nulls()

In [63]:
df_clean = df_clean.with_columns(
    pl.arange(0, df_clean.height).alias("row_number")
)

In [64]:
y_cols = [f'y_t+{h}' for h in range(0, 28 + 1)]
X_train = df_clean.drop(y_cols + ['date', 'id'])
y_train = df_clean

## Training

Since at the time of making this LightGBM can't handle polars dataframes with categorical features we'll build the features as numpy arrays as described [here](https://nixtla.github.io/mlforecast/docs/how-to-guides/training_with_numpy.html#preprocess-method).

In [91]:
fcst = MLForecast(
    models=[],
    freq='1d',
    lags=[7 * (i+1) for i in range(8)],
    lag_transforms = {
        1 :  [ExpandingMean()],
        7 :  [RollingMean(7), RollingMean(14), RollingMean(28), SeasonalRollingMean(7, 4)],
        14:  [RollingMean(7), RollingMean(14), RollingMean(28), SeasonalRollingMean(7, 4)],
        28:  [RollingMean(7), RollingMean(14), RollingMean(28), SeasonalRollingMean(7, 4)],
    },
    date_features=['year', 'month', 'day', 'weekday', 'quarter', 'week'],    
    num_threads=4,
)

In [123]:
%%time
categoricals = ['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']
X_train, y = fcst.preprocess(
    train_long.drop(["d", "wm_yr_wk"]),
    id_col='id',
    time_col='date',
    target_col='y',
    static_features=categoricals,
    return_X_y=True,    
    as_numpy=False,
)

CPU times: total: 3.97 s
Wall time: 867 ms


In [124]:
%%time
categoricals = ['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']
_, y = fcst.preprocess(
    train_long.drop(["d", "wm_yr_wk"]+y_cols).to_pandas(),
    id_col='id',
    time_col='date',
    target_col='y',
    static_features=categoricals,
    return_X_y=True,    
    as_numpy=False,
)

CPU times: total: 3 s
Wall time: 2.41 s


In [125]:
X_train

id,item_id,dept_id,cat_id,store_id,state_id,event_name_1,event_type_1,event_name_2,event_type_2,snap_CA,snap_TX,snap_WI,sell_price,y_t+0,y_t+1,y_t+2,y_t+3,y_t+4,y_t+5,y_t+6,y_t+7,y_t+8,y_t+9,y_t+10,y_t+11,y_t+12,y_t+13,y_t+14,y_t+15,y_t+16,y_t+17,y_t+18,y_t+19,y_t+20,y_t+21,y_t+22,y_t+23,y_t+24,y_t+25,y_t+26,y_t+27,y_t+28,lag7,lag14,lag21,lag28,lag35,lag42,lag49,lag56,expanding_mean_lag1,rolling_mean_lag7_window_size7,rolling_mean_lag7_window_size14,rolling_mean_lag7_window_size28,seasonal_rolling_mean_lag7_season_length7_window_size4,rolling_mean_lag14_window_size7,rolling_mean_lag14_window_size14,rolling_mean_lag14_window_size28,seasonal_rolling_mean_lag14_season_length7_window_size4,rolling_mean_lag28_window_size7,rolling_mean_lag28_window_size14,rolling_mean_lag28_window_size28,seasonal_rolling_mean_lag28_season_length7_window_size4,year,month,day,weekday,quarter,week
u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,i32,i32,i32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i32,i8,i8,i8,i8,i8
5063,2100,5158,5159,2004,5160,2003,2003,2003,2003,0,0,1,13.98,0.0,2.0,1.0,1.0,1.0,0.0,0.0,1.0,0.0,2.0,0.0,2.0,0.0,0.0,2.0,2.0,2.0,0.0,1.0,2.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,2.0,5.0,0.0,0.0,3.0,1.0,4.0,0.0,0.0,2.0,0.732143,0.714286,0.714286,0.857143,1.0,0.714286,1.0,0.928571,2.0,0.714286,0.857143,0.678571,1.25,2015,6,14,7,2,24
5063,2100,5158,5159,2004,5160,2003,2003,2003,2003,0,1,1,13.98,2.0,1.0,1.0,1.0,0.0,0.0,1.0,0.0,2.0,0.0,2.0,0.0,0.0,2.0,2.0,2.0,0.0,1.0,2.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,2.0,5.0,1.0,0.0,1.0,2.0,1.0,0.0,1.0,1.0,0.0,0.719298,0.571429,0.571429,0.857143,1.0,0.571429,1.0,0.928571,1.0,0.857143,0.857143,0.714286,0.75,2015,6,15,1,2,25
5063,2100,5158,5159,2004,5160,1538,1540,2003,2003,0,0,0,13.98,1.0,1.0,1.0,0.0,0.0,1.0,0.0,2.0,0.0,2.0,0.0,0.0,2.0,2.0,2.0,0.0,1.0,2.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,2.0,5.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,2.0,0.0,0.741379,0.428571,0.571429,0.821429,0.25,0.714286,1.071429,0.964286,0.5,0.714286,0.857143,0.714286,0.75,2015,6,16,2,2,25
5063,2100,5158,5159,2004,5160,2003,2003,2003,2003,0,0,0,13.98,1.0,1.0,0.0,0.0,1.0,0.0,2.0,0.0,2.0,0.0,0.0,2.0,2.0,2.0,0.0,1.0,2.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,2.0,5.0,1.0,0.0,2.0,1.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.745763,0.428571,0.642857,0.821429,0.5,0.857143,1.142857,0.964286,0.5,0.571429,0.785714,0.714286,0.5,2015,6,17,3,2,25
5063,2100,5158,5159,2004,5160,1539,1536,2003,2003,0,0,0,13.98,1.0,0.0,0.0,1.0,0.0,2.0,0.0,2.0,0.0,0.0,2.0,2.0,2.0,0.0,1.0,2.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,2.0,5.0,1.0,0.0,2.0,2.0,0.0,0.0,0.0,0.0,2.0,0.0,1.0,0.0,0.75,0.428571,0.642857,0.75,0.0,0.857143,1.142857,0.964286,0.5,0.285714,0.785714,0.714286,0.75,2015,6,18,4,2,25
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
35565,5000,7262,6691,5062,26311,2003,2003,2003,2003,0,0,0,4.98,0.0,1.0,0.0,0.0,0.0,,,,,,,,,,,,,,,,,,,,,,,,,2.0,3.0,0.0,1.0,3.0,3.0,0.0,0.0,0.437158,0.714286,0.714286,0.678571,1.5,0.714286,0.785714,0.75,1.75,0.428571,0.714286,0.678571,1.75,2016,4,19,2,2,16
35565,5000,7262,6691,5062,26311,2003,2003,2003,2003,0,0,0,4.98,1.0,0.0,0.0,0.0,,,,,,,,,,,,,,,,,,,,,,,,,,0.0,2.0,1.0,0.0,0.0,1.0,0.0,1.0,0.435967,0.428571,0.642857,0.678571,0.75,0.857143,0.928571,0.785714,0.75,0.428571,0.642857,0.642857,0.25,2016,4,20,3,2,16
35565,5000,7262,6691,5062,26311,2003,2003,2003,2003,0,0,0,4.98,0.0,0.0,0.0,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,0.4375,0.571429,0.714286,0.714286,0.75,0.857143,0.785714,0.785714,0.5,0.714286,0.785714,0.714286,0.5,2016,4,21,4,2,16
35565,5000,7262,6691,5062,26311,2003,2003,2003,2003,0,0,0,4.98,0.0,0.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,2.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.436314,0.857143,0.857143,0.785714,0.75,0.857143,0.714286,0.785714,0.25,0.857143,0.857143,0.75,0.25,2016,4,22,5,2,16


In [None]:
X_train = X_train.drop_nulls()
y_train = X_train.select(y_cols)
X_train = X_train.drop(y_cols)


In [128]:
X_train

id,item_id,dept_id,cat_id,store_id,state_id,event_name_1,event_type_1,event_name_2,event_type_2,snap_CA,snap_TX,snap_WI,sell_price,lag7,lag14,lag21,lag28,lag35,lag42,lag49,lag56,expanding_mean_lag1,rolling_mean_lag7_window_size7,rolling_mean_lag7_window_size14,rolling_mean_lag7_window_size28,seasonal_rolling_mean_lag7_season_length7_window_size4,rolling_mean_lag14_window_size7,rolling_mean_lag14_window_size14,rolling_mean_lag14_window_size28,seasonal_rolling_mean_lag14_season_length7_window_size4,rolling_mean_lag28_window_size7,rolling_mean_lag28_window_size14,rolling_mean_lag28_window_size28,seasonal_rolling_mean_lag28_season_length7_window_size4,year,month,day,weekday,quarter,week
u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,i32,i32,i32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i32,i8,i8,i8,i8,i8
5063,2100,5158,5159,2004,5160,2003,2003,2003,2003,0,0,1,13.98,0.0,0.0,3.0,1.0,4.0,0.0,0.0,2.0,0.732143,0.714286,0.714286,0.857143,1.0,0.714286,1.0,0.928571,2.0,0.714286,0.857143,0.678571,1.25,2015,6,14,7,2,24
5063,2100,5158,5159,2004,5160,2003,2003,2003,2003,0,1,1,13.98,0.0,1.0,2.0,1.0,0.0,1.0,1.0,0.0,0.719298,0.571429,0.571429,0.857143,1.0,0.571429,1.0,0.928571,1.0,0.857143,0.857143,0.714286,0.75,2015,6,15,1,2,25
5063,2100,5158,5159,2004,5160,1538,1540,2003,2003,0,0,0,13.98,0.0,1.0,0.0,0.0,1.0,0.0,2.0,0.0,0.741379,0.428571,0.571429,0.821429,0.25,0.714286,1.071429,0.964286,0.5,0.714286,0.857143,0.714286,0.75,2015,6,16,2,2,25
5063,2100,5158,5159,2004,5160,2003,2003,2003,2003,0,0,0,13.98,1.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.745763,0.428571,0.642857,0.821429,0.5,0.857143,1.142857,0.964286,0.5,0.571429,0.785714,0.714286,0.5,2015,6,17,3,2,25
5063,2100,5158,5159,2004,5160,1539,1536,2003,2003,0,0,0,13.98,0.0,0.0,0.0,0.0,2.0,0.0,1.0,0.0,0.75,0.428571,0.642857,0.75,0.0,0.857143,1.142857,0.964286,0.5,0.285714,0.785714,0.714286,0.75,2015,6,18,4,2,25
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
35565,5000,7262,6691,5062,26311,2003,2003,2003,2003,0,0,0,4.98,3.0,3.0,0.0,0.0,1.0,1.0,0.0,0.0,0.414201,1.0,0.857143,0.642857,1.5,0.714286,0.642857,0.642857,1.0,0.285714,0.642857,0.5,0.5,2016,3,22,2,1,12
35565,5000,7262,6691,5062,26311,2003,2003,2003,2003,0,0,0,4.98,0.0,1.0,0.0,1.0,0.0,2.0,0.0,0.0,0.415929,0.857143,0.857143,0.642857,0.5,0.857143,0.642857,0.607143,0.5,0.428571,0.571429,0.535714,0.75,2016,3,23,3,1,12
35565,5000,7262,6691,5062,26311,1527,1536,2003,2003,0,0,0,4.98,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.414706,0.857143,0.857143,0.642857,0.0,0.857143,0.642857,0.607143,0.0,0.428571,0.571429,0.535714,0.0,2016,3,24,4,1,12
35565,5000,7262,6691,5062,26311,2003,2003,2003,2003,0,0,0,4.98,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.419355,0.857143,0.857143,0.642857,0.0,0.857143,0.642857,0.535714,0.0,0.428571,0.428571,0.535714,0.5,2016,3,25,5,1,12


In [129]:
from sklearn.multioutput import MultiOutputRegressor
from sklearn.base import clone
from joblib import Parallel, delayed

class MultiOutputLGBMRegressor(MultiOutputRegressor):
    def __init__(self, estimator, n_jobs=None, feature_name=None, categorical_feature=None):
        super().__init__(estimator, n_jobs=n_jobs)
        self.feature_name = feature_name
        self.categorical_feature = categorical_feature

    def fit(self, X, Y, **fit_params):
        if len(Y.shape) != 2:
            raise ValueError("Y must be 2D for multi-output regression.")
        
        self.estimators_ = Parallel(n_jobs=self.n_jobs)(
            delayed(self._fit_single)(
                i, X, Y[:, i]
            ) for i in range(Y.shape[1])
        )
        return self

    def _fit_single(self, i, X, y):
        estimator = clone(self.estimator)
        estimator.fit(
            X, y,
            feature_name=self.feature_name,
            categorical_feature=self.categorical_feature
        )
        return estimator

    def predict(self, X):
        return super().predict(X)


In [130]:
fcst.ts.features_order_

['id',
 'item_id',
 'dept_id',
 'cat_id',
 'store_id',
 'state_id',
 'event_name_1',
 'event_type_1',
 'event_name_2',
 'event_type_2',
 'snap_CA',
 'snap_TX',
 'snap_WI',
 'sell_price',
 'lag7',
 'lag14',
 'lag21',
 'lag28',
 'lag35',
 'lag42',
 'lag49',
 'lag56',
 'expanding_mean_lag1',
 'rolling_mean_lag7_window_size7',
 'rolling_mean_lag7_window_size14',
 'rolling_mean_lag7_window_size28',
 'seasonal_rolling_mean_lag7_season_length7_window_size4',
 'rolling_mean_lag14_window_size7',
 'rolling_mean_lag14_window_size14',
 'rolling_mean_lag14_window_size28',
 'seasonal_rolling_mean_lag14_season_length7_window_size4',
 'rolling_mean_lag28_window_size7',
 'rolling_mean_lag28_window_size14',
 'rolling_mean_lag28_window_size28',
 'seasonal_rolling_mean_lag28_season_length7_window_size4',
 'year',
 'month',
 'day',
 'weekday',
 'quarter',
 'week']

In [131]:
X_train

id,item_id,dept_id,cat_id,store_id,state_id,event_name_1,event_type_1,event_name_2,event_type_2,snap_CA,snap_TX,snap_WI,sell_price,lag7,lag14,lag21,lag28,lag35,lag42,lag49,lag56,expanding_mean_lag1,rolling_mean_lag7_window_size7,rolling_mean_lag7_window_size14,rolling_mean_lag7_window_size28,seasonal_rolling_mean_lag7_season_length7_window_size4,rolling_mean_lag14_window_size7,rolling_mean_lag14_window_size14,rolling_mean_lag14_window_size28,seasonal_rolling_mean_lag14_season_length7_window_size4,rolling_mean_lag28_window_size7,rolling_mean_lag28_window_size14,rolling_mean_lag28_window_size28,seasonal_rolling_mean_lag28_season_length7_window_size4,year,month,day,weekday,quarter,week
u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,i32,i32,i32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i32,i8,i8,i8,i8,i8
5063,2100,5158,5159,2004,5160,2003,2003,2003,2003,0,0,1,13.98,0.0,0.0,3.0,1.0,4.0,0.0,0.0,2.0,0.732143,0.714286,0.714286,0.857143,1.0,0.714286,1.0,0.928571,2.0,0.714286,0.857143,0.678571,1.25,2015,6,14,7,2,24
5063,2100,5158,5159,2004,5160,2003,2003,2003,2003,0,1,1,13.98,0.0,1.0,2.0,1.0,0.0,1.0,1.0,0.0,0.719298,0.571429,0.571429,0.857143,1.0,0.571429,1.0,0.928571,1.0,0.857143,0.857143,0.714286,0.75,2015,6,15,1,2,25
5063,2100,5158,5159,2004,5160,1538,1540,2003,2003,0,0,0,13.98,0.0,1.0,0.0,0.0,1.0,0.0,2.0,0.0,0.741379,0.428571,0.571429,0.821429,0.25,0.714286,1.071429,0.964286,0.5,0.714286,0.857143,0.714286,0.75,2015,6,16,2,2,25
5063,2100,5158,5159,2004,5160,2003,2003,2003,2003,0,0,0,13.98,1.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.745763,0.428571,0.642857,0.821429,0.5,0.857143,1.142857,0.964286,0.5,0.571429,0.785714,0.714286,0.5,2015,6,17,3,2,25
5063,2100,5158,5159,2004,5160,1539,1536,2003,2003,0,0,0,13.98,0.0,0.0,0.0,0.0,2.0,0.0,1.0,0.0,0.75,0.428571,0.642857,0.75,0.0,0.857143,1.142857,0.964286,0.5,0.285714,0.785714,0.714286,0.75,2015,6,18,4,2,25
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
35565,5000,7262,6691,5062,26311,2003,2003,2003,2003,0,0,0,4.98,3.0,3.0,0.0,0.0,1.0,1.0,0.0,0.0,0.414201,1.0,0.857143,0.642857,1.5,0.714286,0.642857,0.642857,1.0,0.285714,0.642857,0.5,0.5,2016,3,22,2,1,12
35565,5000,7262,6691,5062,26311,2003,2003,2003,2003,0,0,0,4.98,0.0,1.0,0.0,1.0,0.0,2.0,0.0,0.0,0.415929,0.857143,0.857143,0.642857,0.5,0.857143,0.642857,0.607143,0.5,0.428571,0.571429,0.535714,0.75,2016,3,23,3,1,12
35565,5000,7262,6691,5062,26311,1527,1536,2003,2003,0,0,0,4.98,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.414706,0.857143,0.857143,0.642857,0.0,0.857143,0.642857,0.607143,0.0,0.428571,0.571429,0.535714,0.0,2016,3,24,4,1,12
35565,5000,7262,6691,5062,26311,2003,2003,2003,2003,0,0,0,4.98,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.419355,0.857143,0.857143,0.642857,0.0,0.857143,0.642857,0.535714,0.0,0.428571,0.428571,0.535714,0.5,2016,3,25,5,1,12


In [134]:
import lightgbm as lgb

model_params = {
    'verbose': -1,
    'force_col_wise': True,
    'num_leaves': 256,
    'n_estimators': 50,
}

base_model = lgb.LGBMRegressor(**model_params)

multi_model = MultiOutputLGBMRegressor(
    estimator=base_model,
    feature_name=fcst.ts.features_order_,
    categorical_feature=categoricals,
    n_jobs=4
)

%time multi_model.fit(X_train.to_pandas(), y_train.to_numpy())


CPU times: total: 16.3 s
Wall time: 5min 4s


In [135]:
fcst.models_ = {'LGBMRegressor': multi_model}

We'll manually train the model here, which allows us to specify which features should be treated as categorical.

## Forecasting

We now override the `models_` attribute to generate predictions, as described [here](https://nixtla.github.io/mlforecast/docs/how-to-guides/custom_training.html#custom-training).

In [137]:
from datetime import datetime, timedelta


In [136]:
X_df

sell_price,id,date,event_name_1,event_type_1,event_name_2,event_type_2,snap_CA,snap_TX,snap_WI
f32,cat,datetime[μs],cat,cat,cat,cat,i32,i32,i32
8.38,"""HOBBIES_1_001_CA_1_evaluation""",2016-04-24 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0
8.38,"""HOBBIES_1_001_CA_1_evaluation""",2016-04-25 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0
8.38,"""HOBBIES_1_001_CA_1_evaluation""",2016-04-26 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0
8.38,"""HOBBIES_1_001_CA_1_evaluation""",2016-04-27 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0
8.38,"""HOBBIES_1_001_CA_1_evaluation""",2016-04-28 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0
…,…,…,…,…,…,…,…,…,…
1.0,"""FOODS_3_827_WI_3_evaluation""",2016-05-18 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0
1.0,"""FOODS_3_827_WI_3_evaluation""",2016-05-19 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0
1.0,"""FOODS_3_827_WI_3_evaluation""",2016-05-20 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0
1.0,"""FOODS_3_827_WI_3_evaluation""",2016-05-21 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0


In [140]:
X_df

id,date,event_name_1,event_type_1,event_name_2,event_type_2,snap_CA,snap_TX,snap_WI,sell_price
u32,datetime[μs],u32,u32,u32,u32,i32,i32,i32,f32
5063,2016-04-24 00:00:00,2003,2003,2003,2003,0,0,0,13.98
5063,2016-04-25 00:00:00,2003,2003,2003,2003,0,0,0,13.98
5063,2016-04-26 00:00:00,2003,2003,2003,2003,0,0,0,13.98
5063,2016-04-27 00:00:00,2003,2003,2003,2003,0,0,0,13.98
5063,2016-04-28 00:00:00,2003,2003,2003,2003,0,0,0,13.98
…,…,…,…,…,…,…,…,…,…
35565,2016-05-18 00:00:00,2003,2003,2003,2003,0,0,0,4.98
35565,2016-05-19 00:00:00,2003,2003,2003,2003,0,0,0,4.98
35565,2016-05-20 00:00:00,2003,2003,2003,2003,0,0,0,4.98
35565,2016-05-21 00:00:00,2003,2003,2003,2003,0,0,0,4.98


In [173]:
current_X

row_idx,id,item_id,dept_id,cat_id,store_id,state_id,y,date,event_name_1,event_type_1,event_name_2,event_type_2,snap_CA,snap_TX,snap_WI,sell_price
u32,u32,u32,u32,u32,u32,u32,f32,datetime[μs],u32,u32,u32,u32,i32,i32,i32,f32
0,5063,2100,5158,5159,2004,5160,2.0,2015-04-19 00:00:00,2003,2003,2003,2003,0,0,0,13.98
1,5063,2100,5158,5159,2004,5160,0.0,2015-04-20 00:00:00,2003,2003,2003,2003,0,0,0,13.98
2,5063,2100,5158,5159,2004,5160,0.0,2015-04-21 00:00:00,2003,2003,2003,2003,0,0,0,13.98
3,5063,2100,5158,5159,2004,5160,0.0,2015-04-22 00:00:00,2003,2003,2003,2003,0,0,0,13.98
4,5063,2100,5158,5159,2004,5160,0.0,2015-04-23 00:00:00,2003,2003,2003,2003,0,0,0,13.98
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
11305407,35565,5000,7262,6691,5062,26311,1.0,2016-04-20 00:00:00,2003,2003,2003,2003,0,0,0,4.98
11305408,35565,5000,7262,6691,5062,26311,0.0,2016-04-21 00:00:00,2003,2003,2003,2003,0,0,0,4.98
11305409,35565,5000,7262,6691,5062,26311,0.0,2016-04-22 00:00:00,2003,2003,2003,2003,0,0,0,4.98
11305410,35565,5000,7262,6691,5062,26311,0.0,2016-04-23 00:00:00,2003,2003,2003,2003,0,0,0,4.98


In [195]:
current_date = forecast_start + pl.duration(days=h)


In [196]:
current_date

In [203]:
last_date_all

datetime.datetime(2016, 5, 22, 0, 0)

In [211]:
from datetime import timedelta
import numpy as np

forecast_start = valid_start_date
horizon = 28

all_preds = []

for h in range(horizon):
    current_date = forecast_start + datetime.timedelta(days=h)

    # Step 1: Historical data
    current_X = df.filter(pl.col("date") <= current_date).drop(["d", "wm_yr_wk"])

    # Step 2: Add row number as index before preprocessing
    current_X = current_X.with_row_count("row_idx")
    date_lookup = current_X.select(["row_idx", "id", "date"])

    # Step 3: Preprocess (drop only wm_yr_wk, keep real date)
    X_h_pl, _ = fcst.preprocess(
        current_X,
        id_col='id',
        time_col='date',
        target_col='y',
        static_features=categoricals,
        return_X_y=True,
        as_numpy=False,
    )

    # Step 4: Merge to restore date after preprocessing
    X_h_pl = X_h_pl.join(date_lookup, on="row_idx", how="left")

    # Step 5: Filter only for current_date
    X_today = X_h_pl.filter(pl.col("date") == current_date)

    ids = X_today["id"].to_numpy()
    X_np = X_today.drop(["id", "row_idx", "date"]).to_numpy().astype(np.float32)

    # Step 6: Predict
    y_pred_h = multi_model.estimators_[h].predict(X_np)

    # Step 7: Store results
    preds_df = pl.DataFrame({
        "id": ids,
        "date": [current_date] * len(ids),
        "y_pred": y_pred_h
    })

    all_preds.append(preds_df)

# Step 8: Concatenate all predictions
final_preds = pl.concat(all_preds)


  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")
  current_X = current_X.with_row_count("row_idx")


In [212]:
from sklearn.metrics import mean_squared_error
import numpy as np


import datetime


# Step 1: Prepare actual values
actuals = valid_long.select(["id", "date", "y"])

# Step 2: Join predictions with actuals
merged = final_preds.join(actuals, on=["id", "date"], how="inner")

# Step 3: Compute RMSE
rmse = mean_squared_error(merged["y"], merged["y_pred"], squared=False)
print(f"Validation RMSE: {rmse:.4f}")


Validation RMSE: 2.9999


