# Inference

In this notebook I've experimented an approach to do online retraining of a simple lightgbm model.

- EDA: https://www.kaggle.com/code/simonedegasperis/starter-eda
- train: https://www.kaggle.com/code/simonedegasperis/lgbm-model-training

Kudos to https://www.kaggle.com/code/motono0223/js24-preprocessing-create-lags for historical data with added lags feature that I've used to train the initial version of the model.

For retraining I've sampled from each date of historical data a random 1% of the data and I've added online data. I tried to retrain the model each 100 batches by gradually increasing the cache which was defined as global variable and I've then averaged the solution of the initial model and the new retrained model.

The purpose of choosing a small fraction of data for retraining was to try to stay within 1 minute limit between 2 consecutive batches.
I hope you will find the solution helpfull to build a better model.

In [1]:
# imports
import os
import glob
import numpy as np
import pandas as pd
import polars as pl
import lightgbm as lgb
import xgboost as xgb
import pickle
import kaggle_evaluation.jane_street_inference_server

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from pytorch_lightning import (LightningDataModule, LightningModule, Trainer)
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, Timer
from statsmodels.tsa.arima.model import ARIMA

In [3]:
class CONFIG:
    seed = 42
    target_col = "responder_6"
    all_cols = ["date_id", "symbol_id", "time_id", "weight"] + [f"feature_{idx:02d}" for idx in range(79)]+ [f"responder_{idx}_lag_1" for idx in range(9)] + [target_col]
    test_cols = ["row_id", "date_id", "symbol_id", "time_id"] + [f"feature_{idx:02d}" for idx in range(79)]+ [f"responder_{idx}_lag_1" for idx in range(9)] + [target_col]
    feature_cols = ["symbol_id", "time_id"] + [f"feature_{idx:02d}" for idx in range(79)]+ [f"responder_{idx}_lag_1" for idx in range(9)]
    only_features = ["row_id", "date_id", "symbol_id", "time_id"] + [f"feature_{idx:02d}" for idx in range(79)]
    only_lags = ["row_id", "date_id", "symbol_id", "time_id"] + [f"responder_{idx}_lag_1" for idx in range(9)] 
    data_paths = ["/kaggle/input/lgbm-model-training/lgbm_model_0.json","/kaggle/input/js24-preprocessing-create-lags/training.parquet/"]
    

In [4]:
files = glob.glob(os.path.join(CONFIG.data_paths[1], "*/*parquet"))

In [5]:
files.sort()

In [6]:
pl_train = pl.concat([pl.read_parquet(_f, columns=CONFIG.all_cols).sample(fraction=0.01) for _f in files])

In [7]:
pl_train = pl_train.sort(["date_id", "time_id"])
pl_train = pl_train.with_row_count(name="row_id")
pl_train = pl_train.with_columns(pl.col("row_id").cast(pl.Int64))  # Ensure row_id is uint32


  pl_train = pl_train.with_row_count(name="row_id")


In [8]:
# load model
lgbm_model = lgb.Booster(model_file=CONFIG.data_paths[0])

In [9]:
# Params used to retrain
input_params = {"num_leaves": 31, "feature_fraction": 0.8, "n_estimators": 50, "learning_rate": 0.1}

In [10]:
# Define Parameters
params = {
    'objective': 'regression',
    'metric': 'rmse',                                      # Root Mean Squared Error
    'boosting_type': 'gbdt',                               # Gradient Boosted Decision Trees
    'num_leaves': input_params['num_leaves'],
    'learning_rate': input_params['learning_rate'],
    'feature_fraction': input_params['feature_fraction'],
    'n_estimators': input_params['n_estimators']      
}

# Inference

In [11]:
# we use historical data with new data hold in a cache to retrain the model
pl_train.head()

row_id,date_id,symbol_id,time_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,…,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_0_lag_1,responder_1_lag_1,responder_2_lag_1,responder_3_lag_1,responder_4_lag_1,responder_5_lag_1,responder_6_lag_1,responder_7_lag_1,responder_8_lag_1,responder_6
i64,i16,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,1101,14,4,1.500934,-0.215145,1.179754,0.368541,0.038441,2.425909,-0.666875,-0.349467,-0.167756,0.757481,44,3,16,-0.875412,0.787212,0.033569,,-0.341584,-0.484018,-2.238651,-1.171494,-0.038426,-0.194049,-0.384562,-0.559253,-0.755148,-0.12496,-0.27222,0.089473,0.053283,-0.810824,-0.437539,-0.240939,…,,,0.540745,,-0.680334,2.469463,,1.278344,0.748299,0.659961,-0.502885,-0.205055,-0.369538,-1.549924,-0.962902,-0.889017,0.174719,-0.21597,-0.604603,1.113999,0.40423,,,-0.155046,-0.141096,-0.197576,-0.202803,,,,,,,,,,0.484299
1,1101,7,8,2.401405,0.093593,0.997003,0.40583,0.129569,2.524339,-0.971238,-0.659492,-0.630072,0.359726,11,7,76,-0.586538,1.09278,0.064694,,1.88143,1.767255,-0.660556,-0.788979,0.415083,0.194781,0.787271,1.810799,-0.971642,-1.003009,-0.459541,-0.053719,-0.055635,1.282601,2.557373,0.18462,…,,,-2.608151,,-1.203008,-0.18091,,-0.75804,-0.281572,0.659961,0.107453,-0.237982,0.162273,-1.2511,-1.219487,-0.234532,1.713841,0.373077,-0.503657,0.248767,-0.404681,,,-0.28696,-0.285293,-0.210075,-0.221053,,,,,,,,,,0.139638
2,1101,8,9,1.935522,-0.380708,0.912805,0.29319,0.444959,2.26886,-1.130615,-0.180222,-0.367418,0.406533,15,1,62,-0.951315,-0.064544,-0.155454,,0.198297,0.09163,-0.593237,-1.17119,-0.215801,-0.106486,0.189968,0.358131,-0.339604,0.419127,0.520018,0.018472,-0.332786,0.005121,0.107359,-0.078926,…,,,-0.617989,,-0.663403,0.731777,,0.163533,-0.233739,0.659961,-0.336044,-0.284863,-0.372323,-1.379376,-1.203877,-0.400485,-0.065636,-0.070379,-0.938926,-0.016331,-0.210375,,,-0.265744,-0.203676,-0.290017,-0.244613,,,,,,,,,,1.242788
3,1101,29,16,1.611318,0.068022,1.369779,-0.121312,0.021666,2.044464,-0.452192,-0.14105,-0.142251,0.41963,57,1,336,-0.392691,1.63144,0.711133,,-0.853772,-0.674342,-0.68171,-0.760426,0.458043,0.079794,-0.067636,-0.696882,0.634566,1.346142,0.440314,-1.532042,-1.723571,-0.607855,-0.980738,0.083502,…,,,0.113628,,-2.539725,1.587884,-0.212709,-0.22311,-0.504644,0.659961,-0.387672,-0.298322,-0.3092,-0.862436,-1.649611,-0.491142,1.954723,0.344242,-0.497607,0.564821,0.517118,-0.208764,-0.218946,-0.066378,-0.093635,-0.128404,-0.157404,,,,,,,,,,1.183075
4,1101,31,22,1.05595,0.283426,1.454438,-0.1975,-0.124206,2.06975,-0.630104,0.068969,-0.033901,0.464361,44,3,24,0.091977,1.596405,1.848182,,-0.325091,-0.588025,-0.145013,-0.714127,-0.41438,-0.039237,-0.778567,-0.642517,0.240914,1.054543,-1.100804,-1.319936,-0.728654,-0.675175,-0.598646,-0.064911,…,-1.526856,,-1.261106,-2.338275,-2.559795,0.090784,-1.327383,-3.561138,-1.713163,0.659961,-0.170258,-0.311777,-0.258067,-0.927648,-1.393083,0.304107,2.576701,2.759911,-0.241112,0.438135,1.25564,0.071643,0.044431,0.706335,0.602444,0.248825,0.216596,,,,,,,,,,3.424516


In [12]:
# Initialize an empty DataFrame as the global cache
cache = pl.DataFrame()
batch_count = 1
new_lgbm_model = None
# hist_data = pl.DataFrame()
# train = pl.DataFrame()

In [13]:
lags_ : pl.DataFrame | None = None
# Replace this function with your inference code.
# You can return either a Pandas or Polars dataframe, though Polars is recommended.
# Each batch of predictions (except the very first) must be returned within 1 minute of the batch features being provided.
def predict(test: pl.DataFrame, lags: pl.DataFrame | None) -> pl.DataFrame | pd.DataFrame:
    """Make a prediction."""
    # All the responders from the previous day are passed in at time_id == 0. We save them in a global variable for access at every time_id.
    # Use them as extra features, if you like.
    global cache          # Declare the global cache
    global batch_count
    global new_lgbm_model
    # global train
    # global hist_data

    # Replace this section with your own predictions
    predictions = test.select(
        'row_id',
        pl.lit(0.0).alias('responder_6'),
    )

    if not lags is None:
        lags = lags.group_by(["date_id", "symbol_id"], maintain_order=True).last() # pick up last record of previous date
        lags = lags.drop(["time_id"])
        test = test.join(lags, on=["date_id", "symbol_id"],  how="left")
    else:
        test = test.with_columns(
            ( pl.lit(0.0).alias(f'responder_{idx}_lag_1') for idx in range(9) )
        )

    if lags is not None:
        print(f"Filling cache for batch count {batch_count}")
        print("cache")
        print(cache.shape)
        # print(cache.columns)
        print("test")
        print(test.shape)
        # print(test.columns)
        # Update the global cache with new rows from test
        cache = pl.concat([cache, test], rechunk=True)
        

    # initialize preds
    preds = np.zeros((test.shape[0],))

    # lightgbm model
    X = test[CONFIG.feature_cols].to_numpy()
  
    # re-train a model on the fly every 30 batches
    if batch_count % 100 == 0 and batch_count>=100:
        print("---------------------------------------------------------------------------------------------")
        print("Using cache data to retrain the model")
        labels = cache[['date_id', 'symbol_id', 'responder_6_lag_1']]
        labels = labels.group_by(["date_id", "symbol_id"], maintain_order=True).last()  # pick up last record of previous date
        lag_cols_rename = {"responder_6_lag_1": "responder_6"}
        labels = labels.rename(lag_cols_rename)
        # I shift 1 day back because we know that responder_6_lag_1 correspond to the last recrd of the previous day
        labels = labels.with_columns(
            date_id = pl.col('date_id') - 1,  # lagged by 1 day
        )
        train = cache.group_by(["date_id", "symbol_id"], maintain_order=True).last()  # pick up last record of previous date
        train = train.join(labels, on=["date_id", "symbol_id"],  how="left")
        # merge with historic data before retraining
        # print("Shape before merging with historic data")
        # print("New data")
        train = train.drop(["is_scored","weight"])
        # print(train.columns)
        # print(train.shape)
        # print(train.dtypes)
        # print("Historical data")
        hist_data = pl_train.select(train.columns)
        # print(hist_data.columns)
        # print(hist_data.shape)
        # print(hist_data.dtypes)

        # Recasting columns of df1 to match the column types of df2
        train = train.select([
            pl.col(col).cast(hist_data.schema[col]) for col in hist_data.columns
        ])

        train = pl.concat([train, hist_data], rechunk=True)
        print("Shape after merging with historic data")  
        print(train.shape)
        # after this process we will obtain the labels
        X_train = train[CONFIG.feature_cols].to_numpy()
        y_train = train.select(CONFIG.target_col).to_numpy().flatten()

        print("shape train data")
        print(X_train.shape)

        print("shape labels")
        print(y_train.shape)       

        train_data = lgb.Dataset(X_train, label=y_train)

        # Re-train the model
        new_lgbm_model = lgb.train(
            params,
            train_data,
            num_boost_round=10
        )

    if new_lgbm_model:
        # lightgbm model
        y_pred1 = new_lgbm_model.predict(X, num_iteration=lgbm_model.best_iteration)
        y_pred2 = lgbm_model.predict(X, num_iteration=lgbm_model.best_iteration)
        y_pred = (y_pred1+y_pred2)/2
    else:
        # lightgbm model
        y_pred = lgbm_model.predict(X, num_iteration=lgbm_model.best_iteration)

    preds = y_pred
    print(f"predict> preds.shape =", preds.shape)
    
    predictions = \
    test.select('row_id').\
    with_columns(
        pl.Series(
            name   = 'responder_6', 
            values = np.clip(preds, a_min = -5, a_max = 5),
            dtype  = pl.Float64,
        )
    )

    if isinstance(predictions, pl.DataFrame):
        assert predictions.columns == ['row_id', 'responder_6']
    elif isinstance(predictions, pd.DataFrame):
        assert (predictions.columns == ['row_id', 'responder_6']).all()
    else:
        raise TypeError('The predict function must return a DataFrame')
    # Confirm has as many rows as the test data.
    assert len(predictions) == len(test)

    batch_count+=1

    return predictions

In [14]:
# test = pl.read_parquet('/kaggle/input/jane-street-real-time-market-data-forecasting/test.parquet')
# lags = pl.read_parquet('/kaggle/input/jane-street-real-time-market-data-forecasting/lags.parquet')
# predict(test, lags)

In [15]:
inference_server = kaggle_evaluation.jane_street_inference_server.JSInferenceServer(predict)

if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    inference_server.serve()
else:
    inference_server.run_local_gateway(
        (
            '/kaggle/input/jane-street-real-time-market-data-forecasting/test.parquet',
            '/kaggle/input/jane-street-real-time-market-data-forecasting/lags.parquet',
        )
    )

Filling cache for batch count 1
cache
(0, 0)
test
(39, 94)
predict> preds.shape = (39,)
