### Paths and imports

In [37]:
import os
import polars as pl
import shutil

from IPython.display import HTML, display

In [38]:
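# TODO: `is_scored` is hard-coded to True below; the real test set also contains unscored rows.
# TODO: Lags look misaligned: verify that lags partition `date_id=d` holds the responders from the day before test `date_id=d`.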

In [39]:
# Locate directories relative to the kernel's working directory. Jupyter usually
# starts the kernel in the notebook's own folder (here, the submissions folder);
# if it does not, every path below shifts accordingly.
NOTEBOOK_DIR = os.getcwd()
LOCAL_TEST_DIR = os.path.join(NOTEBOOK_DIR, "local_test_data")  # generated test/lags files go here

PROJ_DIR = os.path.dirname(NOTEBOOK_DIR)  # one level up: the main project folder
DATA_DIR = os.path.join(PROJ_DIR, "jane-street-real-time-market-data-forecasting")

# exist_ok=True keeps the cell re-runnable when the folder is already there.
os.makedirs(LOCAL_TEST_DIR, exist_ok=True)

### Grab a sample of train to turn into test data

In [43]:
def get_test_and_lag_dates(train_df: pl.LazyFrame, test_start_date: int, test_end_date: int = None) -> tuple[pl.Series, pl.Series]:
    """Build the date_id ranges for a local test window and its one-day lags.

    Args:
        train_df: Lazy training frame containing a ``date_id`` column.
        test_start_date: First date_id (inclusive) of the local test window.
        test_end_date: Last date_id (inclusive) of the window. Defaults to the
            largest date_id in ``train_df``.

    Returns:
        ``(test_dates, lag_dates)``: equal-length Series of consecutive date_ids
        where ``lag_dates[i] == test_dates[i] - 1``.

    Raises:
        ValueError: if the window starts before the first date_id in the data,
            ends after the last one, or is empty (start after end).
    """
    # Fetch both bounds in a single query instead of collecting the frame twice.
    min_date, max_date = (
        train_df.select(
            pl.col('date_id').min().alias('min_date'),
            pl.col('date_id').max().alias('max_date'),
        )
        .collect()
        .row(0)
    )

    if test_start_date < min_date:
        raise ValueError(f"test_start_date ({test_start_date}) must be >= minimum date in dataset ({min_date})")

    # If no end date specified, use max date
    if test_end_date is None:
        test_end_date = max_date
    elif test_end_date > max_date:
        raise ValueError(f"test_end_date ({test_end_date}) must be <= maximum date in dataset ({max_date})")

    # An inverted window would otherwise silently produce empty ranges.
    if test_start_date > test_end_date:
        raise ValueError(f"test_start_date ({test_start_date}) must be <= test_end_date ({test_end_date})")

    # Inclusive range of test days; each lag day is the day before its test day,
    # so both Series always have the same length.
    test_dates = pl.Series(range(test_start_date, test_end_date + 1))
    lag_dates = test_dates - 1

    if test_start_date == min_date:
        # The day before the first test day is not in the data.
        print(f"Warning: lag date {test_start_date - 1} precedes the first date in the dataset ({min_date}); "
              f"lags for the first test day will be missing")

    print("Created date ranges:")
    print(f"Test dates: {test_dates.min()} to {test_dates.max()} (n={len(test_dates)})")
    print(f"Lag dates: {lag_dates.min()} to {lag_dates.max()} (n={len(lag_dates)})")

    return test_dates, lag_dates

Created date ranges:
Test dates: 1694 to 1698 (n=5)
Lag dates: 1693 to 1697 (n=5)

Test data shape: (187792, 93)
Lag data shape: (187792, 93)


In [None]:
train = pl.scan_parquet(os.path.join(DATA_DIR, "train.parquet"))

# Local test window (inclusive date_ids). date_ids in train go up to 1698, and the
# window must lie inside train's date range (checked by get_test_and_lag_dates).
start = 1650
end = 1670
test_dates, lag_dates = get_test_and_lag_dates(train, test_start_date=start, test_end_date=end)

# Lag days and test days together form one contiguous span (start - 1 .. end), so
# read them with a single scan of train.parquet and split in memory.
window_data = train.filter(
    pl.col('date_id').is_between(lag_dates.min(), test_dates.max())
).collect()

test_data = window_data.filter(pl.col('date_id').is_in(test_dates))
lag_data = window_data.filter(pl.col('date_id').is_in(lag_dates))
del window_data  # free the combined frame; only the two splits are used below

assert test_data.height > 0, "no training rows fall inside the test window"
assert lag_data.height > 0, "no training rows fall inside the lag window"

### Create test and lags from this sample

In [41]:
# Create test data matching competition format
local_test_formatted = test_data.select([
    pl.int_range(0, pl.len()).cast(pl.UInt64).alias('id'),
    pl.int_range(0, pl.len()).cast(pl.Int64).alias('row_id'),
    # Shift date_ids to start at 0 while preserving order
    (pl.col('date_id') - test_data['date_id'].min()).cast(pl.Int16).alias('date_id'),
    pl.col('time_id').cast(pl.Int16),
    pl.col('symbol_id').cast(pl.Int8),
    pl.col('weight').cast(pl.Float32),
    pl.lit(True).alias('is_scored'),  # All rows scored in our local test
    
    # Get all feature columns in order
    *[pl.col(f'feature_{i:02d}').cast(pl.Float32) for i in range(79)],
    # Keep responder_6 for scoring
    pl.col('responder_6').cast(pl.Float32)
])

# Create lags data matching competition format
local_lags_formatted = lag_data.select([
    pl.int_range(0, pl.len()).cast(pl.UInt64).alias('id'),
    # Shift date_ids to start at 0 while preserving order
    (pl.col('date_id') - lag_data['date_id'].min()).cast(pl.Int16).alias('date_id'),
    pl.col('time_id').cast(pl.Int16),
    pl.col('symbol_id').cast(pl.Int8),
    # Get all responders with _lag_1 suffix
    *[pl.col(f'responder_{i}').cast(pl.Float32).alias(f'responder_{i}_lag_1') 
      for i in range(9)]
])

### Have a peek

In [42]:
def create_title(title):
    """Return an HTML ``<h3>`` heading used to label a displayed table."""
    return HTML(f"""
    <h3>{title}</h3>
    """)

# Look at both formatted frames before saving: 4 rows, every column.
frames_to_preview = {
    "First rows of our formatted test data": local_test_formatted,
    "First rows of our formatted lags data": local_lags_formatted,
}
with pl.Config(tbl_rows=4, tbl_cols=-1):
    for preview_title, preview_frame in frames_to_preview.items():
        display(create_title(preview_title))
        display(preview_frame)

id,row_id,date_id,time_id,symbol_id,weight,is_scored,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_6
u64,i64,i16,i16,i8,f32,bool,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,0,0,0,3.152647,true,2.919276,1.92291,2.582785,3.105847,2.293566,-0.370599,0.138311,-0.130688,0.275721,11.0,7.0,76.0,-0.94706,0.664533,-0.269485,,-0.996777,,-1.604461,-1.644228,1.278864,-0.155247,1.313421,0.492202,1.66719,1.436057,1.102688,0.757871,0.637915,-0.465332,-1.108307,-0.170706,,,-0.593546,0.044052,1.378604,0.040623,0.129284,,0.640934,,,-0.220958,,-1.448401,1.199859,0.081077,-0.166701,-0.009289,,0.374428,,,-0.566921,,-2.329654,1.716035,,-0.765468,-0.347655,1.137164,-0.322298,-0.455329,-0.303883,-1.375988,-1.668673,-0.82594,0.803257,-0.310364,-1.189274,0.368912,-0.591277,,,-0.269808,-0.241873,-0.312287,-0.272453,-0.256121
1,1,0,0,1,2.89644,true,2.87712,1.995329,2.278081,3.169037,2.582665,-0.38225,0.212546,-0.17336,0.36353,11.0,7.0,76.0,-0.970736,2.009739,-0.195846,,-0.379707,,-1.482726,-2.315325,1.266528,0.008514,0.886489,0.875512,1.193528,0.457954,-1.116456,-0.068779,0.815845,-0.758722,-0.658705,0.007348,,,-0.342805,0.707429,1.829124,0.102788,0.132547,,0.014285,,,-0.846931,,-1.391555,0.952372,0.250118,-1.121867,-0.064084,,0.388397,,,-1.082998,,-1.733198,0.761915,,-1.50701,-0.539962,1.137164,-0.469725,-0.346583,-0.347141,-1.382208,-2.353191,-0.749504,1.993429,-0.056408,-0.916642,0.388281,-0.458708,,,-0.155287,-0.129516,-0.307254,-0.279362,-0.154576
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
187790,187790,4,967,37,1.243116,true,2.663298,-0.889112,2.313155,3.101428,0.324454,0.618944,1.185663,1.599724,0.319719,34.0,4.0,214.0,0.759314,0.284057,0.41716,-0.611075,-0.513717,-0.891423,1.84994,0.406756,-1.608196,-0.252663,-0.271574,-0.051405,0.098146,-0.653961,0.173676,-0.016497,-0.404509,-0.577262,-0.731429,-0.21646,3.018564,-0.472061,3.13922,3.065858,0.842925,0.053283,-0.074403,0.500129,0.08263,0.336223,0.643934,-0.422367,-0.418195,0.203037,-0.702278,0.543305,-0.195764,0.693364,0.953293,0.352567,0.471775,1.876459,-0.143377,0.845516,0.301135,-0.395703,0.738038,-0.04124,1.270645,-1.101531,-0.358106,-0.141883,-0.255192,2.489247,0.537652,0.982107,-0.158009,0.137389,0.478357,0.782692,0.581421,-0.106056,-0.111017,0.163867,0.169331,-0.037563,-0.029483,-0.148711
187791,187791,4,967,38,3.193685,true,2.728506,-0.745238,2.788789,2.343393,0.454731,0.862839,0.964795,2.089673,0.344931,50.0,1.0,522.0,0.406531,0.618247,1.01327,-0.952069,-0.679168,-0.597603,0.375125,1.97537,-0.440974,-0.072018,1.741353,1.380735,-0.110494,-0.874806,0.553424,0.532243,0.263214,-0.757856,-0.869204,-0.062955,3.619233,-0.386316,3.54456,3.120631,-1.443649,-0.257411,-0.309567,1.366358,-0.220885,0.029798,1.094489,-0.051078,-0.114243,0.517313,0.852201,0.522199,-0.027275,0.471593,1.213111,0.263278,0.915804,1.862022,0.503819,1.310126,0.662521,1.654948,1.090367,0.535922,0.653011,-1.101531,-0.622853,-0.363631,-0.395652,-0.016812,2.016734,0.241486,0.253229,0.228745,0.462717,0.799635,0.706102,-0.376377,-0.286764,-0.359046,-0.246135,-0.288941,-0.247774,-0.138548


id,date_id,time_id,symbol_id,responder_0_lag_1,responder_1_lag_1,responder_2_lag_1,responder_3_lag_1,responder_4_lag_1,responder_5_lag_1,responder_6_lag_1,responder_7_lag_1,responder_8_lag_1
u64,i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,0,0,-0.05731,-0.081993,-0.204956,-1.030227,0.700867,0.588168,-1.891638,1.016729,0.702558
1,0,0,1,-0.01258,-0.027038,-0.004415,-0.282159,0.236547,0.108119,-0.671197,0.361489,0.356874
…,…,…,…,…,…,…,…,…,…,…,…,…
187790,4,967,37,0.23585,0.556479,0.618944,-0.243765,-0.108361,-0.260777,-0.486923,-0.275566,-1.020708
187791,4,967,38,0.542563,0.513193,0.814393,0.032767,0.025435,0.311465,-0.044797,0.011133,-0.0793


In [35]:
# Define paths
test_dir = os.path.join(LOCAL_TEST_DIR, "test.parquet")

# Start from an empty directory so partitions left by an earlier, longer window don't linger
if os.path.exists(test_dir):
    shutil.rmtree(test_dir)
os.makedirs(test_dir)

# Last local date_id (later cells may rely on this name)
max_date = local_test_formatted['date_id'].max()

# Write one hive-style `date_id=<d>/part-0.parquet` partition per date_id that is
# present in the data; looping over range(max_date + 1) would write empty files
# for any day missing from the sample.
for date_id in local_test_formatted['date_id'].unique().sort().to_list():
    date_dir = os.path.join(test_dir, f"date_id={date_id}")
    os.makedirs(date_dir)

    local_test_formatted.filter(
        pl.col('date_id') == date_id
    ).write_parquet(
        os.path.join(date_dir, "part-0.parquet")
    )

print("Directory structure created:")
for root, dirs, files in os.walk(test_dir):
    print(f"Directory: {root}")
    for file in files:
        print(f"  File: {file}")

Directory structure created:
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=1
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=0
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=3
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=2
  File: part-0.parquet


In [36]:
# Define paths for lags
lags_dir = os.path.join(LOCAL_TEST_DIR, "lags.parquet")

# Start from an empty directory so partitions left by an earlier, longer window don't linger
if os.path.exists(lags_dir):
    shutil.rmtree(lags_dir)
os.makedirs(lags_dir)

# One lags partition per test date_id, computed here rather than reusing a
# variable left over from the previous cell.
test_date_ids = local_test_formatted['date_id'].unique().sort().to_list()

for date_id in test_date_ids:
    date_dir = os.path.join(lags_dir, f"date_id={date_id}")
    os.makedirs(date_dir)

    # local_lags_formatted already labels each row with the test date_id it is
    # served alongside: its responders come from the previous trading day. So
    # partition `date_id=d` takes the rows whose date_id == d. The earlier
    # `d - 1` lookup lagged every day after the first by two days, and wrote
    # files whose date_id column disagreed with the partition path.
    local_lags_formatted.filter(
        pl.col('date_id') == date_id
    ).write_parquet(
        os.path.join(date_dir, "part-0.parquet")
    )

print("\nLags directory structure created:")
for root, dirs, files in os.walk(lags_dir):
    print(f"Directory: {root}")
    for file in files:
        print(f"  File: {file}")


Lags directory structure created:
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=1
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=0
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=3
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=2
  File: part-0.parquet
