In [1]:
import pandas as pd
import numpy as np
import polars as pl
import gc
from matplotlib import pyplot as plt
import matplotlib.cm as cm
from sklearn.model_selection import StratifiedGroupKFold


In [2]:
class CONFIG:
    target_col = "responder_6"
    lag_cols_original = ["date_id", "symbol_id"] + [f"responder_{idx}" for idx in range(9)]
    lag_cols_rename = { f"responder_{idx}" : f"responder_{idx}_lag_1" for idx in range(9)}
    valid_ratio = 0.05
    start_dt = 900

In [3]:
train = pl.scan_parquet(f"./train.parquet").select(
    pl.int_range(pl.len(), dtype=pl.UInt32).alias("id"),
    pl.all(),
).with_columns(
    (pl.col(CONFIG.target_col)*2).cast(pl.Int32).alias("label"),
).filter(pl.col("date_id") >= CONFIG.start_dt)
train.collect()

id,date_id,time_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,…,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_0,responder_1,responder_2,responder_3,responder_4,responder_5,responder_6,responder_7,responder_8,partition_id,label
u32,i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i64,i32
18340954,900,0,0,2.371006,-0.362016,-0.696294,-0.019394,0.334306,2.253981,-0.439977,1.573309,0.352227,-0.044843,11,7,76,-1.099241,0.309228,-0.635026,,0.283982,,-1.131046,-1.035899,-0.278812,-0.20908,0.703597,1.203631,-0.339501,-0.719582,1.039576,1.153007,0.741938,0.226457,0.173083,-0.184055,…,,-1.597145,,-1.013786,1.302025,,0.668544,0.306663,-1.095944,-0.271329,-0.03082,-0.0478,-1.575859,-1.533297,-1.003211,0.300566,-0.365393,-0.894165,0.097043,-0.709909,,,-0.167952,-0.297251,-0.272094,-0.252341,-0.112829,-0.075709,0.541234,-0.566407,-0.626564,-0.746171,-0.716941,-0.455068,-1.3275,5,-1
18340955,900,0,1,3.687028,0.293689,-0.603608,-0.674807,0.024626,1.889744,-0.53366,1.414847,0.289346,-0.053265,11,7,76,-0.914842,0.195611,-0.587876,,-0.285317,,-0.842046,-0.823002,-0.059361,0.087735,1.852966,1.646581,0.009873,-0.733066,0.425621,1.664557,2.13886,-0.236449,-0.290679,0.084217,…,,-1.292938,,-1.349268,1.767661,,2.272888,1.085229,-1.095944,-0.223109,-0.179489,-0.277511,-1.639956,-1.501274,-0.977827,0.123375,-0.45289,-1.133464,0.257559,-0.640283,,,-0.208362,-0.267925,-0.177973,-0.295702,-0.367251,-0.221843,-1.160344,-0.190356,-0.352678,-1.381038,-0.193624,-0.209931,-0.81559,5,0
18340956,900,0,2,1.78284,0.058466,-0.616436,0.16103,0.153395,2.325349,-0.320207,1.460199,0.300257,-0.076107,81,2,59,-0.961296,1.546081,-0.047038,,0.028425,,-1.470194,-1.869943,-0.783728,-0.207264,0.04208,-0.31905,-0.44346,-0.782183,0.349521,1.803328,1.483842,-0.595795,-0.419538,-0.241272,…,,-0.274939,,-0.981101,2.188366,,2.931268,1.082697,-1.095944,0.015759,-0.175433,-0.106617,-1.93917,-2.069989,-0.866594,1.045068,-0.377948,-0.967123,2.72307,-0.110954,,,0.873847,0.790283,-0.040086,-0.036944,0.737472,0.160162,0.733051,-0.372516,-0.768763,-1.387485,-1.048492,-1.776193,-2.540738,5,-2
18340957,900,0,3,1.547719,-0.363907,-0.387501,0.061037,-0.205486,1.717424,-0.534499,1.302091,0.308245,-0.043578,4,3,11,-0.888497,0.19357,-0.663478,,0.167434,,-1.538333,-1.286979,0.369831,0.156803,-0.473849,-1.029982,0.385592,-0.38753,-0.386601,-0.874078,-0.781389,-0.526571,-0.850001,0.172943,…,,-0.802743,,-1.348928,1.5755,,1.633589,0.821374,-1.095944,-0.140429,0.26251,-0.012567,-1.236598,-1.376425,-0.819258,-0.125511,-0.536088,-0.662439,0.331441,-0.272971,,,9.073547,7.78659,2.52084,5.492901,-0.062704,-0.086862,-0.02496,1.408547,-0.392927,-0.129132,1.960046,-0.474515,-0.122464,5,3
18340958,900,0,7,2.320256,-0.254111,-0.727723,-0.168957,0.264305,2.023745,-0.53854,1.686796,0.384101,-0.059297,11,7,76,-1.186049,-0.520586,-0.607425,,1.152275,,-0.919511,-1.271852,0.721635,0.035307,0.375869,2.085817,-0.115594,-0.832953,-0.981918,0.761649,0.603729,1.783795,1.488107,0.04045,…,,-1.75769,,-1.784724,0.627529,,0.374991,0.281838,-1.095944,-0.161465,0.28516,0.159623,-2.008134,-1.425322,-0.859137,-0.265254,-0.689399,-1.158431,-0.357917,-0.6256,,,-0.246888,-0.213367,-0.299885,-0.223796,-0.041207,-0.50557,0.181291,-1.067922,-0.632572,-1.787657,-0.755364,-0.257752,-3.729681,5,-1
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
47127333,1698,967,34,3.242493,2.52516,-0.721981,2.544025,2.477615,0.417557,0.785812,1.117796,2.199436,0.415427,42,5,150,0.804403,1.157257,1.031543,-0.671189,-0.3286,-0.486132,1.730176,-0.006173,-0.001144,-0.213062,0.932618,1.367338,-0.238197,-0.692615,-0.121163,1.090798,1.444294,-0.675626,-1.013264,-0.242888,…,1.377051,-0.396358,0.520262,1.179617,1.127657,2.231928,0.614652,2.412886,-1.101531,-0.384833,-0.275818,-0.40804,2.427115,-0.108427,0.739734,0.830205,0.366287,1.33325,1.075499,1.798264,-0.183443,-0.190222,0.234211,0.347142,-0.044463,0.016936,0.243475,0.166927,0.38494,-0.174297,-0.066046,-0.038767,-0.132337,-0.022426,-0.252461,9,0
47127334,1698,967,35,1.079139,1.857906,-0.790646,2.745439,2.339877,0.845065,0.65137,1.180301,1.966379,0.321543,25,7,195,-0.075294,-0.152726,-0.20417,-0.421137,0.21708,-0.258775,1.874978,0.19988,-0.199219,-0.125619,-1.004547,-0.051933,0.450905,0.009246,0.164127,-0.939974,-1.143421,-0.320071,-0.379835,-0.142429,…,0.687755,-1.189577,0.180146,-0.175486,-1.60435,-0.209283,0.249847,0.288816,-1.101531,-0.343868,-0.253991,-0.278832,2.050639,-0.059506,-0.029396,-0.101381,-0.187759,-0.180839,-0.0861,-0.153405,-0.196077,-0.175292,1.04578,0.739733,0.03372,0.05086,0.850152,0.909382,1.015314,0.235962,0.122539,0.099559,-0.249584,-0.123571,-0.46063,9,0
47127335,1698,967,36,1.033172,2.515527,-0.672298,2.28925,2.521592,0.255077,0.919892,1.172018,2.180496,0.24846,49,7,297,1.026715,-0.096892,0.224309,-0.528109,-0.704952,-0.704818,2.312482,0.32804,-0.108193,,-0.945684,-0.244173,0.205989,-0.357343,,,-1.11075,-0.580242,-0.400568,,…,0.933568,0.032978,-0.519118,-0.290343,-0.806786,0.106295,0.183461,1.830421,-1.101531,-0.341991,-0.249132,-0.34365,2.251358,0.601888,1.035051,-0.283241,0.107244,0.86016,0.024223,0.374852,-0.220933,-0.161584,0.032771,0.036888,0.168908,0.152333,0.395684,-0.292574,-3.215846,-0.535129,-0.178484,-1.80815,-0.065355,-0.000367,-0.12517,9,0
47127336,1698,967,37,1.243116,2.663298,-0.889112,2.313155,3.101428,0.324454,0.618944,1.185663,1.599724,0.319719,34,4,214,0.759314,0.284057,0.41716,-0.611075,-0.513717,-0.891423,1.84994,0.406756,-1.608196,-0.252663,-0.271574,-0.051405,0.098146,-0.653961,0.173676,-0.016497,-0.404509,-0.577262,-0.731429,-0.21646,…,1.876459,-0.143377,0.845516,0.301135,-0.395703,0.738038,-0.04124,1.270645,-1.101531,-0.358106,-0.141883,-0.255192,2.489247,0.537652,0.982107,-0.158009,0.137389,0.478357,0.782692,0.581421,-0.106056,-0.111017,0.163867,0.169331,-0.037563,-0.029483,1.925987,0.479394,3.621867,-0.107114,-0.063599,1.204755,-0.148711,-0.026583,-0.256395,9,0


In [4]:
train.select(pl.col("label")).collect().describe()

statistic,label
str,f64
"""count""",28786384.0
"""null_count""",0.0
"""mean""",0.022574
"""std""",1.448318
"""min""",-10.0
"""25%""",0.0
"""50%""",0.0
"""75%""",0.0
"""max""",10.0


In [5]:
lags = train.select(pl.col(CONFIG.lag_cols_original)).rename(CONFIG.lag_cols_rename).with_columns(
    date_id = pl.col('date_id') + 1
)
lags = lags.group_by(["date_id", "symbol_id"], maintain_order=True).last()


In [6]:
train = train.join(lags, on=["date_id", "symbol_id"], how="left")
train.collect()

id,date_id,time_id,symbol_id,weight,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,…,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_0,responder_1,responder_2,responder_3,responder_4,responder_5,responder_6,responder_7,responder_8,partition_id,label,responder_0_lag_1,responder_1_lag_1,responder_2_lag_1,responder_3_lag_1,responder_4_lag_1,responder_5_lag_1,responder_6_lag_1,responder_7_lag_1,responder_8_lag_1
u32,i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i64,i32,f32,f32,f32,f32,f32,f32,f32,f32,f32
18340954,900,0,0,2.371006,-0.362016,-0.696294,-0.019394,0.334306,2.253981,-0.439977,1.573309,0.352227,-0.044843,11,7,76,-1.099241,0.309228,-0.635026,,0.283982,,-1.131046,-1.035899,-0.278812,-0.20908,0.703597,1.203631,-0.339501,-0.719582,1.039576,1.153007,0.741938,0.226457,0.173083,-0.184055,…,-0.271329,-0.03082,-0.0478,-1.575859,-1.533297,-1.003211,0.300566,-0.365393,-0.894165,0.097043,-0.709909,,,-0.167952,-0.297251,-0.272094,-0.252341,-0.112829,-0.075709,0.541234,-0.566407,-0.626564,-0.746171,-0.716941,-0.455068,-1.3275,5,-1,,,,,,,,,
18340955,900,0,1,3.687028,0.293689,-0.603608,-0.674807,0.024626,1.889744,-0.53366,1.414847,0.289346,-0.053265,11,7,76,-0.914842,0.195611,-0.587876,,-0.285317,,-0.842046,-0.823002,-0.059361,0.087735,1.852966,1.646581,0.009873,-0.733066,0.425621,1.664557,2.13886,-0.236449,-0.290679,0.084217,…,-0.223109,-0.179489,-0.277511,-1.639956,-1.501274,-0.977827,0.123375,-0.45289,-1.133464,0.257559,-0.640283,,,-0.208362,-0.267925,-0.177973,-0.295702,-0.367251,-0.221843,-1.160344,-0.190356,-0.352678,-1.381038,-0.193624,-0.209931,-0.81559,5,0,,,,,,,,,
18340956,900,0,2,1.78284,0.058466,-0.616436,0.16103,0.153395,2.325349,-0.320207,1.460199,0.300257,-0.076107,81,2,59,-0.961296,1.546081,-0.047038,,0.028425,,-1.470194,-1.869943,-0.783728,-0.207264,0.04208,-0.31905,-0.44346,-0.782183,0.349521,1.803328,1.483842,-0.595795,-0.419538,-0.241272,…,0.015759,-0.175433,-0.106617,-1.93917,-2.069989,-0.866594,1.045068,-0.377948,-0.967123,2.72307,-0.110954,,,0.873847,0.790283,-0.040086,-0.036944,0.737472,0.160162,0.733051,-0.372516,-0.768763,-1.387485,-1.048492,-1.776193,-2.540738,5,-2,,,,,,,,,
18340957,900,0,3,1.547719,-0.363907,-0.387501,0.061037,-0.205486,1.717424,-0.534499,1.302091,0.308245,-0.043578,4,3,11,-0.888497,0.19357,-0.663478,,0.167434,,-1.538333,-1.286979,0.369831,0.156803,-0.473849,-1.029982,0.385592,-0.38753,-0.386601,-0.874078,-0.781389,-0.526571,-0.850001,0.172943,…,-0.140429,0.26251,-0.012567,-1.236598,-1.376425,-0.819258,-0.125511,-0.536088,-0.662439,0.331441,-0.272971,,,9.073547,7.78659,2.52084,5.492901,-0.062704,-0.086862,-0.02496,1.408547,-0.392927,-0.129132,1.960046,-0.474515,-0.122464,5,3,,,,,,,,,
18340958,900,0,7,2.320256,-0.254111,-0.727723,-0.168957,0.264305,2.023745,-0.53854,1.686796,0.384101,-0.059297,11,7,76,-1.186049,-0.520586,-0.607425,,1.152275,,-0.919511,-1.271852,0.721635,0.035307,0.375869,2.085817,-0.115594,-0.832953,-0.981918,0.761649,0.603729,1.783795,1.488107,0.04045,…,-0.161465,0.28516,0.159623,-2.008134,-1.425322,-0.859137,-0.265254,-0.689399,-1.158431,-0.357917,-0.6256,,,-0.246888,-0.213367,-0.299885,-0.223796,-0.041207,-0.50557,0.181291,-1.067922,-0.632572,-1.787657,-0.755364,-0.257752,-3.729681,5,-1,,,,,,,,,
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
47127333,1698,967,34,3.242493,2.52516,-0.721981,2.544025,2.477615,0.417557,0.785812,1.117796,2.199436,0.415427,42,5,150,0.804403,1.157257,1.031543,-0.671189,-0.3286,-0.486132,1.730176,-0.006173,-0.001144,-0.213062,0.932618,1.367338,-0.238197,-0.692615,-0.121163,1.090798,1.444294,-0.675626,-1.013264,-0.242888,…,-0.384833,-0.275818,-0.40804,2.427115,-0.108427,0.739734,0.830205,0.366287,1.33325,1.075499,1.798264,-0.183443,-0.190222,0.234211,0.347142,-0.044463,0.016936,0.243475,0.166927,0.38494,-0.174297,-0.066046,-0.038767,-0.132337,-0.022426,-0.252461,9,0,0.501321,0.905332,-0.819582,-0.564046,-0.223018,-0.283954,-0.045938,0.009797,-0.102538
47127334,1698,967,35,1.079139,1.857906,-0.790646,2.745439,2.339877,0.845065,0.65137,1.180301,1.966379,0.321543,25,7,195,-0.075294,-0.152726,-0.20417,-0.421137,0.21708,-0.258775,1.874978,0.19988,-0.199219,-0.125619,-1.004547,-0.051933,0.450905,0.009246,0.164127,-0.939974,-1.143421,-0.320071,-0.379835,-0.142429,…,-0.343868,-0.253991,-0.278832,2.050639,-0.059506,-0.029396,-0.101381,-0.187759,-0.180839,-0.0861,-0.153405,-0.196077,-0.175292,1.04578,0.739733,0.03372,0.05086,0.850152,0.909382,1.015314,0.235962,0.122539,0.099559,-0.249584,-0.123571,-0.46063,9,0,-1.113053,0.69719,-1.619031,-1.222743,-0.706082,-0.291133,0.167733,0.099704,0.32461
47127335,1698,967,36,1.033172,2.515527,-0.672298,2.28925,2.521592,0.255077,0.919892,1.172018,2.180496,0.24846,49,7,297,1.026715,-0.096892,0.224309,-0.528109,-0.704952,-0.704818,2.312482,0.32804,-0.108193,,-0.945684,-0.244173,0.205989,-0.357343,,,-1.11075,-0.580242,-0.400568,,…,-0.341991,-0.249132,-0.34365,2.251358,0.601888,1.035051,-0.283241,0.107244,0.86016,0.024223,0.374852,-0.220933,-0.161584,0.032771,0.036888,0.168908,0.152333,0.395684,-0.292574,-3.215846,-0.535129,-0.178484,-1.80815,-0.065355,-0.000367,-0.12517,9,0,-1.019353,-0.460962,-2.026678,-0.848606,-0.305448,-1.256913,-0.109359,-0.027474,-0.253956
47127336,1698,967,37,1.243116,2.663298,-0.889112,2.313155,3.101428,0.324454,0.618944,1.185663,1.599724,0.319719,34,4,214,0.759314,0.284057,0.41716,-0.611075,-0.513717,-0.891423,1.84994,0.406756,-1.608196,-0.252663,-0.271574,-0.051405,0.098146,-0.653961,0.173676,-0.016497,-0.404509,-0.577262,-0.731429,-0.21646,…,-0.358106,-0.141883,-0.255192,2.489247,0.537652,0.982107,-0.158009,0.137389,0.478357,0.782692,0.581421,-0.106056,-0.111017,0.163867,0.169331,-0.037563,-0.029483,1.925987,0.479394,3.621867,-0.107114,-0.063599,1.204755,-0.148711,-0.026583,-0.256395,9,0,0.23585,0.556479,0.618944,-0.243765,-0.108361,-0.260777,-0.486923,-0.275566,-1.020708


In [7]:
len_train = train.select(pl.col("date_id")).collect().shape[0]
valid_records = int(len_train * CONFIG.valid_ratio)
len_ofl_mdl = len_train - valid_records
last_tr_dt = train.select(pl.col("date_id")).collect().row(len_ofl_mdl)[0]

print(f"len_train: {len_train}, len_ofl_mdl: {len_ofl_mdl}, last_tr_dt: {last_tr_dt}")

training_data = train.filter(pl.col("date_id") <= last_tr_dt)
validation_data = train.filter(pl.col("date_id") > last_tr_dt)

len_train: 28786384, len_ofl_mdl: 27347065, last_tr_dt: 1660


In [8]:
training_data.collect().write_parquet(f"./training_data.parquet", partition_by= "date_id")

In [9]:

validation_data.collect().write_parquet(f"./validation_data.parquet", partition_by= "date_id")