## Libraries

In [1]:
%pip install -qqq mlforecast

Note: you may need to restart the kernel to use updated packages.


In [2]:
from pathlib import Path

import lightgbm as lgb
import mlforecast
import numpy as np
import polars as pl
from mlforecast import MLForecast
from mlforecast.lag_transforms import ExpandingMean, RollingMean, SeasonalRollingMean

In [3]:
mlforecast.__version__

'1.0.0'

In [4]:
pl.enable_string_cache()

## Data setup

In [5]:
input_path = Path('m5-forecasting-accuracy/')

### Calendar

In [6]:
cal_dtypes = {
    'date': pl.Datetime,
    'd': pl.Categorical,
    'wm_yr_wk': pl.Int32,
    'event_name_1': pl.Categorical,
    'event_type_1': pl.Categorical,
    'event_name_2': pl.Categorical,
    'event_type_2': pl.Categorical,
    'snap_CA': pl.Int32,
    'snap_TX': pl.Int32,
    'snap_WI': pl.Int32,
}
cal = pl.read_csv(input_path / 'calendar.csv', schema_overrides=cal_dtypes, columns=list(cal_dtypes.keys()))
event_cols = [k for k in cal_dtypes if k.startswith('event')]
cal = cal.with_columns(pl.col(event_cols).fill_null('nan'))

### Prices

In [7]:
prices_dtypes = {
    'store_id': pl.Categorical,
    'item_id': pl.Categorical,
    'wm_yr_wk': pl.Int32,
    'sell_price': pl.Float32,
}
prices = pl.read_csv(input_path / 'sell_prices.csv', schema_overrides=prices_dtypes)

### Sales

In [8]:
sales_dtypes = {
    'id': pl.Categorical,
    'item_id': pl.Categorical,
    'dept_id': pl.Categorical,
    'cat_id': pl.Categorical,
    'store_id': pl.Categorical,
    'state_id': pl.Categorical,
    **{f'd_{i}': pl.Float32 for i in range(1942)}
}
sales = pl.read_csv(input_path / 'sales_train_evaluation.csv', schema_overrides=sales_dtypes)

In [28]:
sales_val = pl.read_csv(input_path / 'sales_train_validation.csv', schema_overrides=sales_dtypes)

In [9]:
long = sales.unpivot(
    index=['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id'],
    variable_name='d',
    value_name='y',
)

In [10]:
%%time
print(long.shape[0])
long = long.with_columns(pl.col('d').cast(pl.Categorical))
long = long.join(cal, on=['d'])
dates = sorted(long['date'].unique())
long = long.sort(['id', 'date'])
without_leading_zeros = pl.col('y').gt(0).cast(pl.Int64).cum_max().over('id').cast(pl.Boolean)
above_min_date = pl.col('date') >= dates[-400]
keep_mask = without_leading_zeros & above_min_date
long = long.filter(keep_mask)
print(long.shape[0])

59181090
12159132
CPU times: total: 23.5 s
Wall time: 3.18 s


In [40]:
# Get max date from long (original train + prices merged)
last_date_all = long['date'].max()
valid_horizon = 28
valid_start_date = last_date_all - pl.duration(days=valid_horizon)

# Split long into train and validation (just like time-based CV)
train_long = long.filter(pl.col("date") < valid_start_date)
valid_long = long.filter(pl.col("date") >= valid_start_date)

# Build calendar + prices (X_df) for the validation period only
future_cal = cal.filter((pl.col('date') >= valid_start_date) & (pl.col('date') <= last_date_all))
future_prices = prices.filter(pl.col('wm_yr_wk').is_in(future_cal['wm_yr_wk'].unique()))

# Create IDs to match training IDs
future_prices = future_prices.with_columns(
    id = (pl.col('item_id') + '_' + pl.col('store_id') + '_evaluation')
)

X_df = future_prices.join(future_cal, on='wm_yr_wk')
X_df = X_df.drop(['store_id', 'item_id', 'wm_yr_wk', 'd'])
X_df = X_df.with_columns(pl.col('id').cast(pl.Categorical))


Please use `implode` to return to previous behavior.

See https://github.com/pola-rs/polars/issues/22149 for more information.
  future_prices = prices.filter(pl.col('wm_yr_wk').is_in(future_cal['wm_yr_wk'].unique()))


## Training

Since at the time of making this LightGBM can't handle polars dataframes with categorical features we'll build the features as numpy arrays as described [here](https://nixtla.github.io/mlforecast/docs/how-to-guides/training_with_numpy.html#preprocess-method).

In [12]:
fcst = MLForecast(
    models=[],
    freq='1d',
    lags=[7 * (i+1) for i in range(8)],
    lag_transforms = {
        1 :  [ExpandingMean()],
        7 :  [RollingMean(7), RollingMean(14), RollingMean(28), SeasonalRollingMean(7, 4)],
        14:  [RollingMean(7), RollingMean(14), RollingMean(28), SeasonalRollingMean(7, 4)],
        28:  [RollingMean(7), RollingMean(14), RollingMean(28), SeasonalRollingMean(7, 4)],
    },
    date_features=['year', 'month', 'day', 'weekday', 'quarter', 'week'],    
    num_threads=4,
)

In [20]:
long["event_name_1"].value_counts()

event_name_1,count
cat,u32
"""Purim End""",30490
"""EidAlAdha""",30404
"""MartinLutherKingDay""",30477
"""StPatricksDay""",30490
"""NBAFinalsStart""",30303
…,…
"""LaborDay""",30392
"""Mother's day""",60574
"""OrthodoxChristmas""",30474
"""nan""",11186101


In [41]:
%%time
categoricals = ['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']
X, y = fcst.preprocess(
    train_long,
    id_col='id',
    time_col='date',
    target_col='y',
    static_features=categoricals,
    return_X_y=True,    
    as_numpy=True,
)

CPU times: total: 18.6 s
Wall time: 1.85 s


We'll manually train the model here, which allows us to specify which features should be treated as categorical.

In [42]:
model_params = {
    'verbose': -1,
    'force_col_wise': True,
    'num_leaves': 256,
    'n_estimators': 50,
}
model = lgb.LGBMRegressor(**model_params)
%time model.fit(X, y, feature_name=fcst.ts.features_order_, categorical_feature=categoricals)

CPU times: total: 3min 55s
Wall time: 17.2 s


## Forecasting

We now override the `models_` attribute to generate predictions, as described [here](https://nixtla.github.io/mlforecast/docs/how-to-guides/custom_training.html#custom-training).

In [43]:
fcst.models_ = {'LGBMRegressor': model}
%time preds = fcst.predict(valid_horizon, X_df=X_df)

CPU times: total: 18.9 s
Wall time: 1.49 s


In [56]:
valid_long

id,item_id,dept_id,cat_id,store_id,state_id,y,date,event_name_1,event_type_1,event_name_2,event_type_2,snap_CA,snap_TX,snap_WI,sell_price
cat,cat,cat,cat,cat,cat,f32,datetime[μs],cat,cat,cat,cat,i32,i32,i32,f32
"""HOBBIES_1_001_CA_1_evaluation""","""HOBBIES_1_001""","""HOBBIES_1""","""HOBBIES""","""CA_1""","""CA""",1.0,2016-04-24 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0,8.38
"""HOBBIES_1_001_CA_1_evaluation""","""HOBBIES_1_001""","""HOBBIES_1""","""HOBBIES""","""CA_1""","""CA""",0.0,2016-04-25 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0,8.38
"""HOBBIES_1_001_CA_1_evaluation""","""HOBBIES_1_001""","""HOBBIES_1""","""HOBBIES""","""CA_1""","""CA""",0.0,2016-04-26 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0,8.38
"""HOBBIES_1_001_CA_1_evaluation""","""HOBBIES_1_001""","""HOBBIES_1""","""HOBBIES""","""CA_1""","""CA""",0.0,2016-04-27 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0,8.38
"""HOBBIES_1_001_CA_1_evaluation""","""HOBBIES_1_001""","""HOBBIES_1""","""HOBBIES""","""CA_1""","""CA""",2.0,2016-04-28 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0,8.38
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""FOODS_3_583_WI_3_evaluation""","""FOODS_3_583""","""FOODS_3""","""FOODS""","""WI_3""","""WI""",0.0,2016-05-18 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0,1.45
"""FOODS_3_583_WI_3_evaluation""","""FOODS_3_583""","""FOODS_3""","""FOODS""","""WI_3""","""WI""",1.0,2016-05-19 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0,1.45
"""FOODS_3_583_WI_3_evaluation""","""FOODS_3_583""","""FOODS_3""","""FOODS""","""WI_3""","""WI""",2.0,2016-05-20 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0,1.45
"""FOODS_3_583_WI_3_evaluation""","""FOODS_3_583""","""FOODS_3""","""FOODS""","""WI_3""","""WI""",1.0,2016-05-21 00:00:00,"""nan""","""nan""","""nan""","""nan""",0,0,0,1.45


In [57]:
preds.join(valid_long["id","date","y"],how="inner",on=["id","date"])

id,date,LGBMRegressor,y
cat,datetime[μs],f64,f32
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-24 00:00:00,1.004026,1.0
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-25 00:00:00,0.725626,0.0
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-26 00:00:00,0.733981,0.0
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-27 00:00:00,0.694137,0.0
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-28 00:00:00,0.857808,2.0
…,…,…,…
"""FOODS_3_583_WI_3_evaluation""",2016-05-17 00:00:00,0.841036,0.0
"""FOODS_3_583_WI_3_evaluation""",2016-05-18 00:00:00,0.908861,0.0
"""FOODS_3_583_WI_3_evaluation""",2016-05-19 00:00:00,0.835177,1.0
"""FOODS_3_583_WI_3_evaluation""",2016-05-20 00:00:00,0.890161,2.0


In [None]:
preds.join(eval_long)

id,date,LGBMRegressor
cat,datetime[μs],f64
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-24 00:00:00,1.004026
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-25 00:00:00,0.725626
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-26 00:00:00,0.733981
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-27 00:00:00,0.694137
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-28 00:00:00,0.857808
…,…,…
"""FOODS_3_583_WI_3_evaluation""",2016-05-17 00:00:00,0.841036
"""FOODS_3_583_WI_3_evaluation""",2016-05-18 00:00:00,0.908861
"""FOODS_3_583_WI_3_evaluation""",2016-05-19 00:00:00,0.835177
"""FOODS_3_583_WI_3_evaluation""",2016-05-20 00:00:00,0.890161


In [66]:
from utilsforecast.evaluation import evaluate
from utilsforecast.losses import mse, mae, rmse, mape

# Join predictions and ground truth
df_eval = preds.join(
    valid_long.select(["id", "date", "y"]),
    how="inner",
    on=["id", "date"]
).rename({"id": "unique_id","date":"ds"})  # required by utilsforecast

# Evaluate using Nixtla's metrics
metrics = evaluate(
    df_eval,
    metrics=[rmse, mae, mse, mape]
)

print(metrics)


shape: (121_960, 3)
┌───────────────────────────────┬────────┬───────────────┐
│ unique_id                     ┆ metric ┆ LGBMRegressor │
│ ---                           ┆ ---    ┆ ---           │
│ cat                           ┆ str    ┆ f64           │
╞═══════════════════════════════╪════════╪═══════════════╡
│ HOBBIES_1_001_CA_1_evaluation ┆ rmse   ┆ 1.365652      │
│ HOBBIES_1_002_CA_1_evaluation ┆ rmse   ┆ 0.506787      │
│ HOBBIES_1_003_CA_1_evaluation ┆ rmse   ┆ 0.933877      │
│ HOBBIES_1_004_CA_1_evaluation ┆ rmse   ┆ 1.562539      │
│ HOBBIES_1_005_CA_1_evaluation ┆ rmse   ┆ 1.217481      │
│ …                             ┆ …      ┆ …             │
│ FOODS_3_579_WI_3_evaluation   ┆ mape   ┆ 0.681225      │
│ FOODS_3_580_WI_3_evaluation   ┆ mape   ┆ 0.794133      │
│ FOODS_3_581_WI_3_evaluation   ┆ mape   ┆ 0.672644      │
│ FOODS_3_582_WI_3_evaluation   ┆ mape   ┆ 0.608619      │
│ FOODS_3_583_WI_3_evaluation   ┆ mape   ┆ 0.232426      │
└───────────────────────────────┴───

In [67]:
metrics.describe()

statistic,unique_id,metric,LGBMRegressor
str,str,str,f64
"""count""","""121960""","""121960""",121113.0
"""null_count""","""0""","""0""",847.0
"""mean""",,,1.868217
"""std""",,,13.424169
"""min""",,"""mae""",0.016103
"""25%""",,,0.448177
"""50%""",,,0.715504
"""75%""",,,1.209434
"""max""",,"""rmse""",1565.380037


In [62]:
df_eval

unique_id,date,LGBMRegressor,y
cat,datetime[μs],f64,f32
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-24 00:00:00,1.004026,1.0
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-25 00:00:00,0.725626,0.0
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-26 00:00:00,0.733981,0.0
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-27 00:00:00,0.694137,0.0
"""HOBBIES_1_001_CA_1_evaluation""",2016-04-28 00:00:00,0.857808,2.0
…,…,…,…
"""FOODS_3_583_WI_3_evaluation""",2016-05-17 00:00:00,0.841036,0.0
"""FOODS_3_583_WI_3_evaluation""",2016-05-18 00:00:00,0.908861,0.0
"""FOODS_3_583_WI_3_evaluation""",2016-05-19 00:00:00,0.835177,1.0
"""FOODS_3_583_WI_3_evaluation""",2016-05-20 00:00:00,0.890161,2.0


In [51]:
rmse

2.1082259474118947

## Submission

In [16]:
wide = preds.pivot(values='LGBMRegressor', index='id', on='date')
wide.columns = ['id'] + [f'F{i+1}' for i in range(28)]
wide = wide.with_columns(pl.col('id').cast(pl.Utf8))
wide

id,F1,F2,F3,F4,F5,F6,F7,F8,F9,F10,F11,F12,F13,F14,F15,F16,F17,F18,F19,F20,F21,F22,F23,F24,F25,F26,F27,F28
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""HOBBIES_1_001_CA_1_evaluation""",0.833897,0.840057,0.797003,0.981697,1.001909,1.354826,1.353887,0.81619,0.921808,0.951556,0.967182,1.01913,1.361474,1.169812,0.863993,0.881578,0.860041,0.922818,0.996496,1.349537,1.227884,0.80713,0.778652,0.80238,0.89425,0.927711,1.159154,1.094949
"""HOBBIES_1_002_CA_1_evaluation""",0.301608,0.337076,0.30109,0.29595,0.351041,0.404711,0.380121,0.224811,0.2927,0.307221,0.290396,0.361029,0.403829,0.398394,0.328379,0.325858,0.296514,0.300414,0.358873,0.405415,0.415314,0.328379,0.321958,0.3319,0.328,0.377666,0.433479,0.392115
"""HOBBIES_1_003_CA_1_evaluation""",0.525617,0.507423,0.523173,0.575457,0.773517,0.806967,0.852567,0.534406,0.596685,0.653209,0.601299,0.727181,0.842056,0.906656,0.570054,0.582903,0.600682,0.627493,0.716215,0.792254,0.843897,0.553248,0.562246,0.550727,0.592842,0.717597,0.769499,0.806967
"""HOBBIES_1_004_CA_1_evaluation""",1.669314,1.408619,1.46908,1.55035,1.562228,2.026032,2.472441,1.166942,1.599797,1.691126,1.662757,1.92003,2.279371,2.510155,1.762801,1.578394,1.597008,1.621594,1.846485,2.145817,2.458724,1.855348,1.596231,1.636001,1.588729,1.724641,2.110413,2.31892
"""HOBBIES_1_005_CA_1_evaluation""",1.243967,1.175488,1.119334,1.101344,1.186044,1.419581,1.408517,0.953076,1.022751,1.09324,1.108721,1.256097,1.40743,1.544462,1.133089,1.114874,1.114592,1.161211,1.190928,1.363148,1.458089,1.107447,1.066974,1.027533,1.004561,1.1528,1.285675,1.333709
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""FOODS_3_579_WI_3_evaluation""",0.240898,0.256534,0.23278,0.215711,0.227391,0.241789,0.241789,0.159824,0.215166,0.257338,0.242633,0.303502,0.266146,0.34995,0.328606,0.250371,0.230068,0.308539,0.289797,0.349088,0.349088,0.298956,0.271631,0.292644,0.242351,0.246009,0.264338,0.258524
"""FOODS_3_580_WI_3_evaluation""",6.80407,5.391587,4.988874,5.805819,5.531085,6.623079,9.018246,6.10223,4.628618,9.278724,11.11929,9.949909,8.079201,13.477679,10.683496,8.373447,9.313737,11.042063,6.592893,12.691774,10.939248,7.593142,9.59427,10.841804,5.633552,6.039898,7.073972,7.313361
"""FOODS_3_581_WI_3_evaluation""",0.220488,0.20723,0.20723,0.21237,0.230425,0.238449,0.205814,0.135448,0.185589,0.265597,0.234152,0.300215,0.257665,0.341469,0.308196,0.248857,0.237482,0.283318,0.268566,0.340607,0.324712,0.290475,0.254222,0.300058,0.217975,0.253423,0.255857,0.234148
"""FOODS_3_582_WI_3_evaluation""",2.700242,2.632899,2.294907,2.426135,2.789663,3.302588,3.811948,2.499263,2.589231,2.481936,2.678291,3.274891,3.226891,4.092828,2.964795,2.805928,2.488012,2.94113,2.928715,3.645687,4.292225,2.639021,2.853514,2.595108,2.54386,2.783921,3.286068,3.22004


In [17]:
subm = pl.concat([wide, wide.with_columns(pl.col('id').str.replace('evaluation', 'validation'))])
subm.write_csv('submission.csv')