### Jane Street Real-Time Market Data Forecasting baseline with LightGBM

Link to the competition: https://www.kaggle.com/competitions/jane-street-real-time-market-data-forecasting/overview


Important information

- Lags: Values of responder_{0...8} lagged by one date_id. The evaluation API serves the entirety of the lagged responders for a date_id on that date_id's first time_id. In other words, all of the previous date's responders will be served at the first time step of the succeeding date.

- The symbol_id column contains encrypted identifiers. Each symbol_id is not guaranteed to appear in all time_id and date_id combinations. Additionally, new symbol_id values may appear in future test sets.
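Because new `symbol_id` values may appear at test time, it can help to check which identifiers a model has never seen. The following is a minimal sketch with toy values; the helper name `unseen_symbols` is hypothetical and not part of the competition API.

```python
import numpy as np

def unseen_symbols(train_symbol_ids, test_symbol_ids):
    """Return the sorted symbol_id values present in test but absent from train."""
    return np.setdiff1d(np.unique(test_symbol_ids), np.unique(train_symbol_ids))

# Toy values for illustration only
train_ids = np.array([1, 7, 9, 10, 14, 16, 19, 33])
test_ids = np.array([0, 1, 2, 7, 33])
print(unseen_symbols(train_ids, test_ids))  # [0 2]
```

In this notebook `symbol_id` is passed to LightGBM as a plain numeric feature, so an unseen identifier is simply routed through the existing split thresholds.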

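The competition metric is the sample-weighted zero-mean R-squared. We derive a custom training loss from it; the matching custom evaluation metric is left as an optional argument in the training call.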

The zero-mean R-squared function is:

$$ 1 - \frac{\sum_{i=1}^n w_i (y_i - \hat{y}_i)^2}{\sum_{i=1}^n w_i y_i^2} $$
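As a quick illustration, the formula can be written directly in NumPy. This is a minimal sketch on toy values; the function name `zero_mean_r2` is an assumption made for this example.

```python
import numpy as np

def zero_mean_r2(y_true, y_pred, weight):
    """Sample-weighted zero-mean R-squared, following the formula above."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    weight = np.asarray(weight, dtype=np.float64)
    numerator = np.sum(weight * (y_true - y_pred) ** 2)
    denominator = np.sum(weight * y_true ** 2)
    return 1.0 - numerator / denominator

# Toy values, not competition data
y = np.array([1.0, -2.0, 0.5])
w = np.array([1.0, 2.0, 0.5])
print(zero_mean_r2(y, y, w))                 # perfect predictions give 1.0
print(zero_mean_r2(y, np.zeros_like(y), w))  # all-zero predictions give 0.0
```

Predicting zero everywhere scores exactly 0, so a positive score means the model beats the all-zero baseline.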

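Here $y_i$ is the true target value, $\hat{y}_i$ the prediction, and $w_i$ the sample weight. The denominator does not depend on the predictions, so maximizing the score is equivalent to minimizing the numerator. The loss function is therefore: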

$$ \text{Loss} = \sum_{i=1}^n w_i (y_i - \hat{y}_i)^2 $$

To use this loss as a custom training objective in LightGBM, we need its gradient and Hessian with respect to each prediction $\hat{y}_i$:

$$ \frac{\partial \text{Loss}}{\partial \hat{y}_i} = -2 w_i (y_i - \hat{y}_i) $$


$$ \frac{\partial^2 \text{Loss}}{\partial \hat{y}_i^2} = 2 w_i $$
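The derivatives can be sanity-checked numerically. The sketch below compares the analytic gradient with a central finite difference on random toy data; all function names are illustrative assumptions.

```python
import numpy as np

def weighted_sq_loss(y_true, y_pred, weight):
    """Loss = sum_i w_i * (y_i - y_hat_i)^2."""
    return np.sum(weight * (y_true - y_pred) ** 2)

def analytic_grad_hess(y_true, y_pred, weight):
    """Closed-form gradient and Hessian diagonal of the loss."""
    grad = -2.0 * weight * (y_true - y_pred)
    hess = 2.0 * weight
    return grad, hess

def numeric_grad(y_true, y_pred, weight, eps=1e-6):
    """Central finite-difference approximation of the gradient."""
    grad = np.zeros_like(y_pred)
    for i in range(len(y_pred)):
        step = np.zeros_like(y_pred)
        step[i] = eps
        up = weighted_sq_loss(y_true, y_pred + step, weight)
        down = weighted_sq_loss(y_true, y_pred - step, weight)
        grad[i] = (up - down) / (2 * eps)
    return grad

# Random toy data with a fixed seed
rng = np.random.default_rng(0)
y_true = rng.normal(size=5)
y_pred = rng.normal(size=5)
weight = rng.uniform(0.5, 3.0, size=5)

grad, hess = analytic_grad_hess(y_true, y_pred, weight)
max_abs_error = np.max(np.abs(grad - numeric_grad(y_true, y_pred, weight)))
print(max_abs_error)  # should be close to zero
```

The Hessian is constant in the predictions, which is why only the gradient needs a numerical check.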

This notebook does not use the lagged responders yet. For more on building lag features, see this [notebook](https://www.kaggle.com/code/motono0223/js24-preprocessing-create-lags).

In [1]:
import numpy as np
import lightgbm as lgb
import polars as pl
import plotly.express as px
from pathlib import Path

In [2]:
data_path = "/home/yang/kaggle/jane/data"

In [3]:
training_data = pl.read_parquet(Path(data_path, "train.parquet", "partition_id=0", "part-0.parquet"))
training_data.head(10)

date_id,time_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,…,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_0,responder_1,responder_2,responder_3,responder_4,responder_5,responder_6,responder_7,responder_8
i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,1,3.889038,,,,,,0.851033,0.242971,0.2634,-0.891687,11,7,76,-0.883028,0.003067,-0.744703,,-0.169586,,-1.335938,-1.707803,0.91013,,1.636431,1.522133,-1.551398,-0.229627,,,1.378301,-0.283712,0.123196,,,…,0.204797,,,-0.808103,,-2.037683,0.727661,,-0.989118,-0.345213,-1.36224,,,,,,-1.251104,-0.110252,-0.491157,-1.02269,0.152241,-0.659864,,,-0.261412,-0.211486,-0.335556,-0.281498,0.738489,-0.069556,1.380875,2.005353,0.186018,1.218368,0.775981,0.346999,0.095504
0,0,7,1.370613,,,,,,0.676961,0.151984,0.192465,-0.521729,11,7,76,-0.865307,-0.225629,-0.582163,,0.317467,,-1.250016,-1.682929,1.412757,,0.520378,0.744132,-0.788658,0.641776,,,0.2272,0.580907,1.128879,,,…,1.172836,,,-1.625862,,-1.410017,1.063013,,0.888355,0.467994,-1.36224,,,,,,-1.065759,0.013322,-0.592855,-1.052685,-0.393726,-0.741603,,,-0.281207,-0.182894,-0.245565,-0.302441,2.965889,1.190077,-0.523998,3.849921,2.626981,5.0,0.703665,0.216683,0.778639
0,0,9,2.285698,,,,,,1.056285,0.187227,0.249901,-0.77305,11,7,76,-0.675719,-0.199404,-0.586798,,-0.814909,,-1.296782,-2.040234,0.639589,,1.597359,0.657514,-1.350148,0.364215,,,-0.017751,-0.317361,-0.122379,,,…,0.535897,,,-0.72542,,-2.29417,1.764551,,-0.120789,-0.063458,-1.36224,,,,,,-0.882604,-0.072482,-0.617934,-0.86323,-0.241892,-0.709919,,,0.377131,0.300724,-0.106842,-0.096792,-0.864488,-0.280303,-0.326697,0.375781,1.271291,0.099793,2.109352,0.670881,0.772828
0,0,10,0.690606,,,,,,1.139366,0.273328,0.306549,-1.262223,42,5,150,-0.694008,3.004091,0.114809,,-0.251882,,-1.902009,-0.979447,0.241165,,-0.392359,-0.224699,-2.129397,-0.855287,,,0.404142,-0.578156,0.105702,,,…,2.413415,,,1.313203,,-0.810125,2.939022,,3.988801,1.834661,-1.36224,,,,,,-0.697595,1.074309,-0.206929,-0.530602,4.765215,0.571554,,,-0.226891,-0.251412,-0.215522,-0.296244,0.408499,0.223992,2.294888,1.097444,1.225872,1.225376,1.114137,0.775199,-1.379516
0,0,14,0.44057,,,,,,0.9552,0.262404,0.344457,-0.613813,44,3,16,-0.947351,-0.030018,-0.502379,,0.646086,,-1.844685,-1.58656,-0.182024,,-0.969949,-0.673813,-1.282132,-1.399894,,,0.043815,-0.320225,-0.031713,,,…,1.253902,,,0.476195,,-0.771732,2.843421,,1.379815,0.411827,-1.36224,,,,,,-0.948601,-0.136814,-0.447704,-1.141761,0.099631,-0.661928,,,3.678076,2.793581,2.61825,3.418133,-0.373387,-0.502764,-0.348021,-3.928148,-1.591366,-5.0,-3.57282,-1.089123,-5.0
0,0,16,1.118269,,,,,,1.092428,0.241437,0.309494,-1.047909,11,7,76,-0.375681,0.195831,-0.408775,,0.592456,,-0.998185,-1.444947,0.748643,,0.36953,0.748203,-1.476237,-0.098337,,,0.276082,0.108393,0.762983,,,…,0.781335,,,-0.852696,,-1.823927,1.409013,,-1.46437,-0.784301,-1.36224,,,,,,-0.81232,0.501891,-0.32754,-0.522513,-0.163335,-0.706811,,,-0.272425,-0.253911,-0.346411,-0.189784,0.505199,0.035095,-1.976849,1.33572,-1.215543,-1.593503,0.84055,0.246794,-0.101013
0,0,19,2.456331,,,,,,0.747231,0.168143,0.193987,-0.602202,4,3,11,-1.062734,0.795971,-0.237297,,0.388065,,-0.913943,-1.914018,0.105461,,0.946351,1.344925,-0.573883,0.964037,,,-0.616086,0.142924,0.266898,,,…,0.881099,,,-2.042819,,-1.802318,1.041058,,-1.487282,-0.671137,-1.36224,,,,,,-0.77469,1.127827,-0.201782,-0.990042,1.306702,-0.356026,,,-0.330015,-0.249928,-0.210434,-0.187166,0.096462,-0.252154,-0.002623,-1.002823,-0.632378,0.112727,-0.807892,-0.704062,0.184303
0,0,33,1.663408,,,,,,1.182569,0.206299,0.253336,-0.946356,11,7,76,-0.766347,-0.406701,-1.033553,,1.628502,,-1.145392,-1.682898,0.679352,,0.767303,0.740793,-1.215067,-1.124142,,,-0.46222,0.263958,1.038144,,,…,1.122132,,,-1.721597,,-0.916624,1.604584,,1.223971,0.653575,-1.36224,,,,,,-0.856799,-0.313983,-0.510727,-0.710232,-0.222114,-0.823108,,,-0.347611,-0.286403,-0.399971,-0.171215,0.209253,0.182766,0.010843,1.409293,0.002821,0.147636,0.965387,-0.263765,0.280629
0,1,1,3.889038,,,,,,0.917613,-0.401377,0.066465,-0.724457,11,7,76,-0.754305,-0.010117,-0.589252,,-0.2128,,-1.07089,-1.121607,0.91013,,1.636431,1.522133,-1.551398,-0.229627,,,1.378301,-0.283712,0.123196,,,…,0.041353,,,-1.055839,,-1.715698,1.856922,,-0.352075,-0.250698,-1.36224,,,,,,-0.766033,-0.177765,-0.470256,-1.287999,0.196725,-0.319822,,,-0.211037,-0.235291,-0.269429,-0.301006,0.80866,-0.10582,1.493438,1.387199,0.095372,1.163139,0.751976,0.348024,0.041855
0,1,7,1.370613,,,,,,0.877172,-0.254713,0.04721,-0.673954,11,7,76,-1.070783,0.749826,-0.530852,,1.098499,,-1.333668,-1.410048,1.412757,,0.520378,0.744132,-0.788658,0.641776,,,0.2272,0.580907,1.128879,,,…,-0.064571,,,-0.938621,,-2.335762,1.684062,,-0.049584,0.194288,-1.36224,,,,,,-0.974702,1.749538,-0.214511,-0.965937,-0.192011,-0.498838,,,-0.293913,-0.240564,-0.208652,-0.294585,3.119074,1.52468,-0.080976,4.015586,2.337222,5.0,0.635277,0.009769,0.251455


In [19]:
training_data.estimated_size("gb")

0.6481935195624828

In [4]:
# convert all null to 0
training_data = training_data.fill_null(0)
training_data.head(5)

date_id,time_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,…,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_0,responder_1,responder_2,responder_3,responder_4,responder_5,responder_6,responder_7,responder_8
i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,1,3.889038,0.0,0.0,0.0,0.0,0.0,0.851033,0.242971,0.2634,-0.891687,11,7,76,-0.883028,0.003067,-0.744703,0.0,-0.169586,0.0,-1.335938,-1.707803,0.91013,0.0,1.636431,1.522133,-1.551398,-0.229627,0.0,0.0,1.378301,-0.283712,0.123196,0.0,0.0,…,0.204797,0.0,0.0,-0.808103,0.0,-2.037683,0.727661,0.0,-0.989118,-0.345213,-1.36224,0.0,0.0,0.0,0.0,0.0,-1.251104,-0.110252,-0.491157,-1.02269,0.152241,-0.659864,0.0,0.0,-0.261412,-0.211486,-0.335556,-0.281498,0.738489,-0.069556,1.380875,2.005353,0.186018,1.218368,0.775981,0.346999,0.095504
0,0,7,1.370613,0.0,0.0,0.0,0.0,0.0,0.676961,0.151984,0.192465,-0.521729,11,7,76,-0.865307,-0.225629,-0.582163,0.0,0.317467,0.0,-1.250016,-1.682929,1.412757,0.0,0.520378,0.744132,-0.788658,0.641776,0.0,0.0,0.2272,0.580907,1.128879,0.0,0.0,…,1.172836,0.0,0.0,-1.625862,0.0,-1.410017,1.063013,0.0,0.888355,0.467994,-1.36224,0.0,0.0,0.0,0.0,0.0,-1.065759,0.013322,-0.592855,-1.052685,-0.393726,-0.741603,0.0,0.0,-0.281207,-0.182894,-0.245565,-0.302441,2.965889,1.190077,-0.523998,3.849921,2.626981,5.0,0.703665,0.216683,0.778639
0,0,9,2.285698,0.0,0.0,0.0,0.0,0.0,1.056285,0.187227,0.249901,-0.77305,11,7,76,-0.675719,-0.199404,-0.586798,0.0,-0.814909,0.0,-1.296782,-2.040234,0.639589,0.0,1.597359,0.657514,-1.350148,0.364215,0.0,0.0,-0.017751,-0.317361,-0.122379,0.0,0.0,…,0.535897,0.0,0.0,-0.72542,0.0,-2.29417,1.764551,0.0,-0.120789,-0.063458,-1.36224,0.0,0.0,0.0,0.0,0.0,-0.882604,-0.072482,-0.617934,-0.86323,-0.241892,-0.709919,0.0,0.0,0.377131,0.300724,-0.106842,-0.096792,-0.864488,-0.280303,-0.326697,0.375781,1.271291,0.099793,2.109352,0.670881,0.772828
0,0,10,0.690606,0.0,0.0,0.0,0.0,0.0,1.139366,0.273328,0.306549,-1.262223,42,5,150,-0.694008,3.004091,0.114809,0.0,-0.251882,0.0,-1.902009,-0.979447,0.241165,0.0,-0.392359,-0.224699,-2.129397,-0.855287,0.0,0.0,0.404142,-0.578156,0.105702,0.0,0.0,…,2.413415,0.0,0.0,1.313203,0.0,-0.810125,2.939022,0.0,3.988801,1.834661,-1.36224,0.0,0.0,0.0,0.0,0.0,-0.697595,1.074309,-0.206929,-0.530602,4.765215,0.571554,0.0,0.0,-0.226891,-0.251412,-0.215522,-0.296244,0.408499,0.223992,2.294888,1.097444,1.225872,1.225376,1.114137,0.775199,-1.379516
0,0,14,0.44057,0.0,0.0,0.0,0.0,0.0,0.9552,0.262404,0.344457,-0.613813,44,3,16,-0.947351,-0.030018,-0.502379,0.0,0.646086,0.0,-1.844685,-1.58656,-0.182024,0.0,-0.969949,-0.673813,-1.282132,-1.399894,0.0,0.0,0.043815,-0.320225,-0.031713,0.0,0.0,…,1.253902,0.0,0.0,0.476195,0.0,-0.771732,2.843421,0.0,1.379815,0.411827,-1.36224,0.0,0.0,0.0,0.0,0.0,-0.948601,-0.136814,-0.447704,-1.141761,0.099631,-0.661928,0.0,0.0,3.678076,2.793581,2.61825,3.418133,-0.373387,-0.502764,-0.348021,-3.928148,-1.591366,-5.0,-3.57282,-1.089123,-5.0


In [5]:
train_feature_list = ["time_id", "symbol_id"] + [f"feature_{idx:02d}" for idx in range(79)]

In [6]:
weight = training_data["weight"]

In [7]:
# keep only the feature columns (same column order as train_feature_list)
training_data_subset = training_data.select(train_feature_list)
training_data_subset.head(10)

time_id,symbol_id,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,…,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78
i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,1,0.0,0.0,0.0,0.0,0.0,0.851033,0.242971,0.2634,-0.891687,11,7,76,-0.883028,0.003067,-0.744703,0.0,-0.169586,0.0,-1.335938,-1.707803,0.91013,0.0,1.636431,1.522133,-1.551398,-0.229627,0.0,0.0,1.378301,-0.283712,0.123196,0.0,0.0,0.0,0.28118,…,0.0,-0.181716,0.0,0.0,0.0,0.564021,2.088506,0.832022,0.0,0.204797,0.0,0.0,-0.808103,0.0,-2.037683,0.727661,0.0,-0.989118,-0.345213,-1.36224,0.0,0.0,0.0,0.0,0.0,-1.251104,-0.110252,-0.491157,-1.02269,0.152241,-0.659864,0.0,0.0,-0.261412,-0.211486,-0.335556,-0.281498
0,7,0.0,0.0,0.0,0.0,0.0,0.676961,0.151984,0.192465,-0.521729,11,7,76,-0.865307,-0.225629,-0.582163,0.0,0.317467,0.0,-1.250016,-1.682929,1.412757,0.0,0.520378,0.744132,-0.788658,0.641776,0.0,0.0,0.2272,0.580907,1.128879,0.0,0.0,0.0,-1.512286,…,0.0,0.0,0.0,0.0,0.0,-10.835207,-0.002704,-0.621836,0.0,1.172836,0.0,0.0,-1.625862,0.0,-1.410017,1.063013,0.0,0.888355,0.467994,-1.36224,0.0,0.0,0.0,0.0,0.0,-1.065759,0.013322,-0.592855,-1.052685,-0.393726,-0.741603,0.0,0.0,-0.281207,-0.182894,-0.245565,-0.302441
0,9,0.0,0.0,0.0,0.0,0.0,1.056285,0.187227,0.249901,-0.77305,11,7,76,-0.675719,-0.199404,-0.586798,0.0,-0.814909,0.0,-1.296782,-2.040234,0.639589,0.0,1.597359,0.657514,-1.350148,0.364215,0.0,0.0,-0.017751,-0.317361,-0.122379,0.0,0.0,0.0,-0.320921,…,0.0,0.0,0.0,0.0,0.0,-1.420632,-3.515137,-4.67776,0.0,0.535897,0.0,0.0,-0.72542,0.0,-2.29417,1.764551,0.0,-0.120789,-0.063458,-1.36224,0.0,0.0,0.0,0.0,0.0,-0.882604,-0.072482,-0.617934,-0.86323,-0.241892,-0.709919,0.0,0.0,0.377131,0.300724,-0.106842,-0.096792
0,10,0.0,0.0,0.0,0.0,0.0,1.139366,0.273328,0.306549,-1.262223,42,5,150,-0.694008,3.004091,0.114809,0.0,-0.251882,0.0,-1.902009,-0.979447,0.241165,0.0,-0.392359,-0.224699,-2.129397,-0.855287,0.0,0.0,0.404142,-0.578156,0.105702,0.0,0.0,0.0,0.544138,…,0.0,0.0,0.0,0.0,0.0,0.382074,2.669135,0.611711,0.0,2.413415,0.0,0.0,1.313203,0.0,-0.810125,2.939022,0.0,3.988801,1.834661,-1.36224,0.0,0.0,0.0,0.0,0.0,-0.697595,1.074309,-0.206929,-0.530602,4.765215,0.571554,0.0,0.0,-0.226891,-0.251412,-0.215522,-0.296244
0,14,0.0,0.0,0.0,0.0,0.0,0.9552,0.262404,0.344457,-0.613813,44,3,16,-0.947351,-0.030018,-0.502379,0.0,0.646086,0.0,-1.844685,-1.58656,-0.182024,0.0,-0.969949,-0.673813,-1.282132,-1.399894,0.0,0.0,0.043815,-0.320225,-0.031713,0.0,0.0,0.0,-0.08842,…,0.0,0.0,0.0,0.0,0.0,-2.0146,-2.321076,-3.711265,0.0,1.253902,0.0,0.0,0.476195,0.0,-0.771732,2.843421,0.0,1.379815,0.411827,-1.36224,0.0,0.0,0.0,0.0,0.0,-0.948601,-0.136814,-0.447704,-1.141761,0.099631,-0.661928,0.0,0.0,3.678076,2.793581,2.61825,3.418133
0,16,0.0,0.0,0.0,0.0,0.0,1.092428,0.241437,0.309494,-1.047909,11,7,76,-0.375681,0.195831,-0.408775,0.0,0.592456,0.0,-0.998185,-1.444947,0.748643,0.0,0.36953,0.748203,-1.476237,-0.098337,0.0,0.0,0.276082,0.108393,0.762983,0.0,0.0,0.0,0.041396,…,0.0,0.0,0.0,0.0,0.0,4.809584,1.041688,3.471813,0.0,0.781335,0.0,0.0,-0.852696,0.0,-1.823927,1.409013,0.0,-1.46437,-0.784301,-1.36224,0.0,0.0,0.0,0.0,0.0,-0.81232,0.501891,-0.32754,-0.522513,-0.163335,-0.706811,0.0,0.0,-0.272425,-0.253911,-0.346411,-0.189784
0,19,0.0,0.0,0.0,0.0,0.0,0.747231,0.168143,0.193987,-0.602202,4,3,11,-1.062734,0.795971,-0.237297,0.0,0.388065,0.0,-0.913943,-1.914018,0.105461,0.0,0.946351,1.344925,-0.573883,0.964037,0.0,0.0,-0.616086,0.142924,0.266898,0.0,0.0,0.0,-0.071354,…,0.0,0.0,0.0,0.0,0.0,1.070634,-0.217551,2.093741,0.0,0.881099,0.0,0.0,-2.042819,0.0,-1.802318,1.041058,0.0,-1.487282,-0.671137,-1.36224,0.0,0.0,0.0,0.0,0.0,-0.77469,1.127827,-0.201782,-0.990042,1.306702,-0.356026,0.0,0.0,-0.330015,-0.249928,-0.210434,-0.187166
0,33,0.0,0.0,0.0,0.0,0.0,1.182569,0.206299,0.253336,-0.946356,11,7,76,-0.766347,-0.406701,-1.033553,0.0,1.628502,0.0,-1.145392,-1.682898,0.679352,0.0,0.767303,0.740793,-1.215067,-1.124142,0.0,0.0,-0.46222,0.263958,1.038144,0.0,0.0,0.0,-0.222309,…,0.0,0.0,0.0,0.0,0.0,0.927966,1.264635,1.45311,0.0,1.122132,0.0,0.0,-1.721597,0.0,-0.916624,1.604584,0.0,1.223971,0.653575,-1.36224,0.0,0.0,0.0,0.0,0.0,-0.856799,-0.313983,-0.510727,-0.710232,-0.222114,-0.823108,0.0,0.0,-0.347611,-0.286403,-0.399971,-0.171215
1,1,0.0,0.0,0.0,0.0,0.0,0.917613,-0.401377,0.066465,-0.724457,11,7,76,-0.754305,-0.010117,-0.589252,0.0,-0.2128,0.0,-1.07089,-1.121607,0.91013,0.0,1.636431,1.522133,-1.551398,-0.229627,0.0,0.0,1.378301,-0.283712,0.123196,0.0,0.0,0.0,0.303496,…,0.0,-0.244077,0.0,0.0,0.0,0.452334,1.012832,0.762938,0.0,0.041353,0.0,0.0,-1.055839,0.0,-1.715698,1.856922,0.0,-0.352075,-0.250698,-1.36224,0.0,0.0,0.0,0.0,0.0,-0.766033,-0.177765,-0.470256,-1.287999,0.196725,-0.319822,0.0,0.0,-0.211037,-0.235291,-0.269429,-0.301006
1,7,0.0,0.0,0.0,0.0,0.0,0.877172,-0.254713,0.04721,-0.673954,11,7,76,-1.070783,0.749826,-0.530852,0.0,1.098499,0.0,-1.333668,-1.410048,1.412757,0.0,0.520378,0.744132,-0.788658,0.641776,0.0,0.0,0.2272,0.580907,1.128879,0.0,0.0,0.0,-1.831976,…,0.0,0.0,0.0,0.0,0.0,-7.63369,-0.004959,-0.497161,0.0,-0.064571,0.0,0.0,-0.938621,0.0,-2.335762,1.684062,0.0,-0.049584,0.194288,-1.36224,0.0,0.0,0.0,0.0,0.0,-0.974702,1.749538,-0.214511,-0.965937,-0.192011,-0.498838,0.0,0.0,-0.293913,-0.240564,-0.208652,-0.294585


In [8]:
training_data_loader = lgb.Dataset(
    training_data_subset,
    label=training_data["responder_6"].to_numpy(),  # 1-D array of targets
    weight=training_data["weight"].to_numpy(),  # 1-D array of sample weights
)
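The dataset above uses every row for training. If a validation set is wanted, one common choice for time-ordered data is to hold out the most recent dates. The following is a minimal sketch under that assumption; `split_by_date` is a hypothetical helper and is not used elsewhere in this notebook.

```python
import numpy as np

def split_by_date(date_id, n_valid_dates):
    """Return boolean masks that hold out the last `n_valid_dates` distinct dates."""
    date_id = np.asarray(date_id)
    unique_dates = np.unique(date_id)  # sorted ascending
    if not 1 <= n_valid_dates < len(unique_dates):
        raise ValueError("n_valid_dates must leave at least one training date")
    valid_dates = unique_dates[-n_valid_dates:]
    valid_mask = np.isin(date_id, valid_dates)
    return ~valid_mask, valid_mask

# Toy date_id column for illustration
date_id = np.array([0, 0, 1, 1, 2, 2, 3])
train_mask, valid_mask = split_by_date(date_id, n_valid_dates=1)
print(train_mask.sum(), valid_mask.sum())  # 6 1
```

Splitting on whole dates keeps every row of a given `date_id` on one side of the split, so the validation rows are never older than the training rows.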

## Customize the loss function

In [9]:
def zero_mean_r2_objective(pred, train):
    """
    Custom LightGBM objective derived from the zero-mean R-squared:
    the weighted squared error sum(w * (y - pred)^2).

    Args:
        pred: Array of current predictions (raw scores).
        train: lgb.Dataset holding the labels and, optionally, sample weights.

    Returns:
        grad: Gradient of the loss with respect to each prediction.
        hess: Hessian (second derivative) of the loss with respect to each prediction.
    """

    # Fall back to unit weights if the dataset has none
    weight = train.get_weight()
    if weight is None:
        weight = np.ones_like(pred)

    # Gradient (first derivative of the loss)
    grad = -2 * weight * (train.get_label() - pred)

    # Hessian (second derivative of the loss)
    hess = 2 * weight

    return grad, hess

## Train LightGBM model to predict responder 6

In [10]:
# Set parameters for LightGBM
# params = {
#     'objective': 'regression',
#     'metric': 'rmse',
#     'boosting_type': 'gbdt',
#     'num_leaves': 31,
#     'learning_rate': 0.05,
#     'feature_fraction': 0.9
# }

params = {
    "objective": zero_mean_r2_objective,  # Custom objective callable, used instead of a built-in objective
    "metric": "None",     # Disable default metrics
    'boosting_type': 'gbdt',
    'num_leaves': 62,
    'learning_rate': 0.05,
    'feature_fraction': 0.9,
}

In [12]:
# Train with the custom objective (validation set and custom metric are left commented out)
model = lgb.train(
    params,
    training_data_loader,
    #valid_sets=[data],  # Add validation sets if available
    #feval=zero_mean_r2_score,  # Optional: custom evaluation metric function
    num_boost_round=1000,
    #callbacks=[lgb.early_stopping(stopping_rounds=10)],  # Early stopping; requires valid_sets
)



[LightGBM] [Info] Using self-defined objective function
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.198448 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 17308
[LightGBM] [Info] Number of data points in the train set: 1944210, number of used features: 72
[LightGBM] [Info] Using self-defined objective function


In [13]:
model.save_model('jane_lgbm_null_to_0_custom_loss.txt')

<lightgbm.basic.Booster at 0x7f252277f5c0>

## Model evaluation

In [14]:
# load example test data
test_data = pl.read_parquet(Path(data_path, "test.parquet", "date_id=0", "part-0.parquet"))
test_data.head(5)

row_id,date_id,time_id,symbol_id,weight,is_scored,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,…,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78
i64,i16,i16,i8,f32,bool,f32,f32,f32,f32,f32,f32,f32,f32,f32,f64,f64,f64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,0,0,3.169998,True,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,,-0.0,,-0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,…,,-0.0,,-0.0,0.0,-0.0,0.0,0.0,,0.0,,,-0.0,,-0.0,0.0,,0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,0.0,0.0,,,0.0,0.0,-0.0,-0.0
1,0,0,1,2.165993,True,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,-0.0,,-0.0,,-0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,-0.0,-0.0,…,,-0.0,,-0.0,0.0,0.0,0.0,0.0,,0.0,,,-0.0,,-0.0,0.0,,0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,0.0,-0.0,,,0.0,0.0,0.0,0.0
2,0,0,2,3.06555,True,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,,-0.0,,-0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,…,,-0.0,,-0.0,0.0,-0.0,-0.0,-0.0,,0.0,,,-0.0,,-0.0,0.0,,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0,,,0.0,0.0,-0.0,-0.0
3,0,0,3,2.698642,True,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,,-0.0,,-0.0,-0.0,-0.0,-0.0,0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,…,,-0.0,,-0.0,0.0,-0.0,0.0,-0.0,,-0.0,,,-0.0,,-0.0,0.0,,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0,,,0.0,0.0,-0.0,-0.0
4,0,0,4,1.80333,True,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,-0.0,,-0.0,,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,…,,-0.0,,-0.0,0.0,0.0,0.0,0.0,,0.0,,,-0.0,,-0.0,0.0,,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,0.0,-0.0,,,0.0,0.0,0.0,0.0


In [15]:
# Note: unlike the training data, nulls are not filled with 0 here
test_data_subset = test_data.select(train_feature_list)
test_data_subset.head(5)

time_id,symbol_id,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,…,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78
i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f64,f64,f64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,,-0.0,,-0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,,,0.0,…,,-0.0,,-0.0,0.0,-0.0,0.0,0.0,,0.0,,,-0.0,,-0.0,0.0,,0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,0.0,0.0,,,0.0,0.0,-0.0,-0.0
0,1,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,-0.0,,-0.0,,-0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,-0.0,-0.0,-0.0,,,0.0,…,,-0.0,,-0.0,0.0,0.0,0.0,0.0,,0.0,,,-0.0,,-0.0,0.0,,0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,0.0,-0.0,,,0.0,0.0,0.0,0.0
0,2,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,,-0.0,,-0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,,,0.0,…,,-0.0,,-0.0,0.0,-0.0,-0.0,-0.0,,0.0,,,-0.0,,-0.0,0.0,,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0,,,0.0,0.0,-0.0,-0.0
0,3,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0,,-0.0,,-0.0,-0.0,-0.0,-0.0,0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,,,0.0,…,,-0.0,,-0.0,0.0,-0.0,0.0,-0.0,,-0.0,,,-0.0,,-0.0,0.0,,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0,,,0.0,0.0,-0.0,-0.0
0,4,0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,-0.0,0.0,-0.0,,-0.0,,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,,,0.0,…,,-0.0,,-0.0,0.0,0.0,0.0,0.0,,0.0,,,-0.0,,-0.0,0.0,,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,-0.0,-0.0,0.0,-0.0,,,0.0,0.0,0.0,0.0


In [16]:
# load saved model to make predictions
model = lgb.Booster(model_file='jane_lgbm_null_to_0_custom_loss.txt')

In [17]:
y_pred = model.predict(test_data_subset)
y_pred



array([0.21535119, 0.21431521, 0.21431521, 0.21431521, 0.21431521,
       0.21431521, 0.21431521, 0.21431521, 0.21300842, 0.21680481,
       0.18400499, 0.18329191, 0.18329191, 0.18329191, 0.18329191,
       0.18329191, 0.18329191, 0.18329191, 0.18329191, 0.18329191,
       0.18329191, 0.18329191, 0.18329191, 0.18329191, 0.18329191,
       0.18329191, 0.18329191, 0.18329191, 0.18329191, 0.18329191,
       0.18329191, 0.18329191, 0.18329191, 0.18329191, 0.18329191,
       0.18329191, 0.18329191, 0.18329191, 0.18329191])

In [18]:
# put predictions into a dataframe with row id
predictions = test_data.select(pl.col("row_id"))
predictions = predictions.with_columns(pl.Series("responder_6", y_pred))
predictions

row_id,responder_6
i64,f64
0,0.215351
1,0.214315
2,0.214315
3,0.214315
4,0.214315
…,…
34,0.183292
35,0.183292
36,0.183292
37,0.183292
