### Paths and imports

In [32]:
import os
import polars as pl
import shutil

from IPython.display import HTML, display

In [33]:
# TODO: Variable is scored. 
# TODO: Lags aren't actually lagged

In [34]:
# Every location below is derived from the kernel's current working directory,
# so the notebook must be launched from its own (submissions) folder.
NOTEBOOK_DIR = os.getcwd()
LOCAL_TEST_DIR = os.path.join(NOTEBOOK_DIR, "local_test_data")

# The project root is the parent folder; the competition data sits beneath it.
PROJ_DIR = os.path.dirname(NOTEBOOK_DIR)
DATA_DIR = os.path.join(PROJ_DIR, "jane-street-real-time-market-data-forecasting")

# Make sure the output folder for the locally generated test set exists
# (a no-op when it is already there).
os.makedirs(LOCAL_TEST_DIR, exist_ok=True)

### Grab a sample of train to turn into test data

In [35]:
def get_test_and_lag_dates(train_df: pl.LazyFrame, test_start_date: int, test_end_date: int = None) -> tuple[pl.Series, pl.Series]:
    """Build the date_id range used as local test data and the matching lag range.

    The lag range is the test range shifted back by one day, so both series
    always have the same length.

    Args:
        train_df: Lazy frame with a ``date_id`` column.
        test_start_date: First date_id of the test window (inclusive).
        test_end_date: Last date_id of the test window (inclusive). ``None``
            means "up to the largest date_id in ``train_df``".

    Returns:
        ``(test_dates, lag_dates)`` as two ``pl.Series`` of consecutive integers.

    Raises:
        ValueError: If ``train_df`` has no date_id values, if the window falls
            outside the dates present in ``train_df``, or if the window is empty
            (``test_end_date < test_start_date``).
    """
    # Fetch both bounds in a single pass over the data instead of two collects
    min_date, max_date = train_df.select(
        pl.col('date_id').min().alias('min_date'),
        pl.col('date_id').max().alias('max_date'),
    ).collect().row(0)

    if min_date is None:
        raise ValueError("train_df contains no date_id values")

    if test_start_date < min_date:
        raise ValueError(f"test_start_date ({test_start_date}) must be >= minimum date in dataset ({min_date})")
    if test_start_date > max_date:
        raise ValueError(f"test_start_date ({test_start_date}) must be <= maximum date in dataset ({max_date})")

    # If no end date specified, use max date
    if test_end_date is None:
        test_end_date = max_date
    elif test_end_date > max_date:
        raise ValueError(f"test_end_date ({test_end_date}) must be <= maximum date in dataset ({max_date})")

    # An inverted window would silently yield empty ranges further down
    if test_end_date < test_start_date:
        raise ValueError(f"test_end_date ({test_end_date}) must be >= test_start_date ({test_start_date})")

    if test_start_date - 1 < min_date:
        # The first lag date precedes the data, so filtering on lag_dates will
        # return one day fewer than filtering on test_dates.
        import warnings
        warnings.warn(
            f"First lag date ({test_start_date - 1}) is before the minimum date in dataset "
            f"({min_date}); the first test day has no lag data"
        )

    # Generate date ranges
    test_dates = pl.Series(range(test_start_date, test_end_date + 1))
    # Lag dates are shifted back by one day but maintain same length
    lag_dates = pl.Series(range(test_start_date - 1, test_end_date))

    print("Created date ranges:")
    print(f"Test dates: {test_dates.min()} to {test_dates.max()} (n={len(test_dates)})")
    print(f"Lag dates: {lag_dates.min()} to {lag_dates.max()} (n={len(lag_dates)})")

    return test_dates, lag_dates

In [36]:
# Scan the training set lazily; rows are only materialised by .collect() below
train = pl.scan_parquet(os.path.join(DATA_DIR, "train.parquet"))

# Hold-out window for the local test set (author's note: there are 1698 days)
start, end = 1670, 1690
test_dates, lag_dates = get_test_and_lag_dates(train, test_start_date=start, test_end_date=end)

# Row predicates for the two windows
on_test_days = pl.col('date_id').is_in(test_dates)
on_lag_days = pl.col('date_id').is_in(lag_dates)

# Pull every row belonging to the test window and to the lag window
test_data = train.filter(on_test_days).collect()
lag_data = train.filter(on_lag_days).collect()

Created date ranges:
Test dates: 1670 to 1690 (n=21)
Lag dates: 1669 to 1689 (n=21)


### Create test and lags from this sample

In [37]:
# Assemble the local test frame, laid out to match the competition's test format
first_test_date = test_data['date_id'].min()
row_number = pl.int_range(0, pl.len())

test_columns = [
    # Two running row counters that differ only in dtype
    row_number.cast(pl.UInt64).alias('id'),
    row_number.cast(pl.Int64).alias('row_id'),
    # Re-base date_id so the first test day becomes 0; ordering is unchanged
    (pl.col('date_id') - first_test_date).cast(pl.Int16).alias('date_id'),
    pl.col('time_id').cast(pl.Int16),
    pl.col('symbol_id').cast(pl.Int8),
    pl.col('weight').cast(pl.Float32),
    # Every row counts towards the score in the local test
    pl.lit(True).alias('is_scored'),
]
# feature_00 .. feature_78, in order
test_columns += [pl.col(f'feature_{i:02d}').cast(pl.Float32) for i in range(79)]
# The target stays in the frame so predictions can be scored locally
test_columns.append(pl.col('responder_6').cast(pl.Float32))

local_test_formatted = test_data.select(test_columns)

# Assemble the lags frame, laid out to match the competition's lags format
first_lag_date = lag_data['date_id'].min()

lag_columns = [
    pl.int_range(0, pl.len()).cast(pl.UInt64).alias('id'),
    # Re-base date_id so the first lag day becomes 0; ordering is unchanged
    (pl.col('date_id') - first_lag_date).cast(pl.Int16).alias('date_id'),
    pl.col('time_id').cast(pl.Int16),
    pl.col('symbol_id').cast(pl.Int8),
]
# responder_0 .. responder_8, each renamed with a _lag_1 suffix
lag_columns += [
    pl.col(f'responder_{i}').cast(pl.Float32).alias(f'responder_{i}_lag_1')
    for i in range(9)
]

local_lags_formatted = lag_data.select(lag_columns)

### Have a peek

In [38]:
def create_title(title):
    """Wrap ``title`` in an ``<h3>`` element and return it as an IPython ``HTML`` object."""
    markup = f"""
    <h3>{title}</h3>
    """
    return HTML(markup)

# Preview both frames before anything is saved: a heading, then the frame
# rendered with at most 4 rows and every column visible.
previews = (
    ("First rows of our formatted test data", local_test_formatted),
    ("First rows of our formatted lags data", local_lags_formatted),
)
for heading, frame in previews:
    with pl.Config(tbl_rows=4, tbl_cols=-1):
        display(create_title(heading))
        display(frame)

id,row_id,date_id,time_id,symbol_id,weight,is_scored,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_6
u64,i64,i16,i16,i8,f32,bool,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,0,0,0,3.405333,true,0.310848,1.131483,0.745982,0.776556,-0.798888,0.189267,-0.661057,-0.29037,-0.228528,11.0,7.0,76.0,-0.721815,0.178626,-0.625948,,-0.60002,,-1.773468,-1.426673,0.752113,-0.128356,1.485508,0.798556,0.939105,0.531187,1.102988,0.510059,0.409313,-0.711646,-0.902749,-0.162899,,,0.619876,0.466504,-1.488351,0.000466,-0.099937,,0.716869,,,0.10863,,-0.957244,1.638018,-0.578405,0.482135,0.009865,,0.25855,,,-0.701604,,-1.752637,2.125332,,0.074168,0.018479,1.027483,-0.525323,-0.239548,-0.619589,-1.70162,-1.398736,-0.939831,0.29814,-0.321387,-0.737133,0.032996,-0.614959,,,0.811054,0.708597,-0.001946,0.004254,-0.119038
1,1,0,0,1,2.619899,true,0.382185,0.458687,0.544353,0.09255,-0.911843,0.130969,-0.905152,-0.301003,-0.187945,11.0,7.0,76.0,-0.920349,2.437098,-0.075225,,-0.504462,,-1.68278,-1.309567,1.357302,-0.007474,1.107871,0.374563,1.059325,1.151365,-1.985968,-0.465673,0.825128,-0.756643,-0.735251,-0.006619,,,0.495858,0.553447,-1.584693,-0.05839,0.086217,,0.514809,,,-0.652876,,-1.598468,1.125729,-0.049231,0.421088,-0.185787,,1.567681,,,-1.137877,,-1.826585,1.230211,,1.252662,0.437808,1.027483,-0.342516,-0.421323,-0.424025,-1.053866,-1.873317,-0.709916,1.826658,-0.019934,-0.572781,2.331671,-0.2095,,,0.859362,0.641228,-0.051073,-0.032598,-0.339291
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
781174,781174,20,967,37,1.55614,true,2.911813,-0.242471,3.700644,3.222936,0.124478,0.883587,2.292406,1.394078,-0.878478,34.0,4.0,214.0,-0.314272,0.124768,-0.041822,-0.825399,-0.342005,-0.531658,-2.057258,0.486214,-1.707943,-0.21099,-0.417635,-0.361594,-0.059514,-0.713784,0.348272,-0.298016,-0.690916,-0.659161,-0.50128,-0.180412,0.869804,-0.370386,0.830234,0.948416,0.238697,-1.055938,-1.072071,0.820401,0.1911,1.930064,0.433063,-0.826805,0.059136,0.765181,-0.456113,0.448141,0.950471,1.14333,0.769916,-0.603411,1.092517,1.124749,0.02841,0.866227,1.507799,-0.433767,0.516783,1.255785,1.952839,0.454603,-0.529127,-0.285267,-0.215432,-1.768409,0.843704,-0.253711,0.093359,-0.056299,-0.51414,0.172416,-0.000188,0.017093,0.088445,0.24878,0.617374,0.155965,0.217655,0.197556
781175,781175,20,967,38,3.576271,true,3.382692,-0.261147,2.76647,2.566119,0.185446,0.974221,2.207694,1.890429,-0.72314,50.0,1.0,522.0,0.175068,0.782206,0.33322,-0.588666,-0.535963,-0.758741,-1.017514,2.328029,-0.324293,-0.112635,1.542221,0.837426,0.203225,-0.386168,0.463996,0.343767,0.129188,-0.747565,-0.995199,-0.136413,0.851929,0.011313,0.929717,0.958114,-0.656523,-0.728048,-0.653399,0.434715,0.826769,0.387983,0.53527,-0.057372,0.106603,0.521531,1.616243,0.414272,0.546593,0.527728,0.170201,-0.197693,-0.097086,1.389461,0.085452,0.545139,0.633178,0.87364,0.668379,0.314732,0.408428,0.454603,-0.43304,-0.288851,-0.542692,-1.597191,2.074712,-0.036793,0.223281,-0.055106,0.350647,1.695994,0.864276,-0.188388,-0.256035,-0.132781,-0.111854,-0.157093,-0.135594,-0.17002


id,date_id,time_id,symbol_id,responder_0_lag_1,responder_1_lag_1,responder_2_lag_1,responder_3_lag_1,responder_4_lag_1,responder_5_lag_1,responder_6_lag_1,responder_7_lag_1,responder_8_lag_1
u64,i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,0,0,0.035418,0.034741,0.414574,2.384467,0.784588,0.952633,2.203616,1.399664,1.403749
1,0,0,1,-0.173574,0.141018,-0.341351,1.997512,0.825174,0.093344,2.674313,1.074391,0.439501
…,…,…,…,…,…,…,…,…,…,…,…,…
780206,20,967,37,-2.508144,-1.549502,-2.070124,0.155362,0.088527,0.903197,0.761249,0.355074,1.63809
780207,20,967,38,-0.717013,-0.048132,-1.166597,-1.290214,-0.498964,-0.248329,-0.090587,-0.010463,-0.136916


In [39]:
# Root folder of the partitioned local test set
test_dir = os.path.join(LOCAL_TEST_DIR, "test.parquet")

# Rebuild the folder from scratch so no stale partitions survive a re-run
if os.path.exists(test_dir):
    shutil.rmtree(test_dir)
os.makedirs(test_dir)

# Largest (re-based) date_id in the test frame; also reused when writing the lags
max_date = local_test_formatted['date_id'].max()

# One "date_id=<n>" sub-folder per day, each holding a single parquet file
for date_id in range(max_date + 1):
    date_dir = os.path.join(test_dir, f"date_id={date_id}")
    os.makedirs(date_dir)

    rows_for_day = local_test_formatted.filter(pl.col('date_id') == date_id)
    rows_for_day.write_parquet(os.path.join(date_dir, "part-0.parquet"))

# Report what ended up on disk
print("Directory structure created:")
for root, _subdirs, files in os.walk(test_dir):
    print(f"Directory: {root}")
    for file in files:
        print(f"  File: {file}")

Directory structure created:
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=15
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=4
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=6
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=1
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=10
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=8
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/

In [40]:
# Root folder of the partitioned lags set
lags_dir = os.path.join(LOCAL_TEST_DIR, "lags.parquet")

# Rebuild the folder from scratch so no stale partitions survive a re-run
if os.path.exists(lags_dir):
    shutil.rmtree(lags_dir)
os.makedirs(lags_dir)

# Both formatted frames re-base date_id to start at 0, but on different
# calendars: the lags frame starts at the first lag day, the test frame at the
# first test day. The gap between those two first days is how far the lags
# frame is already shifted back (1 when the day before the test window exists
# in the data, 0 when it does not).
lag_head_start = test_data['date_id'].min() - lag_data['date_id'].min()

# Save lags data by date_id
for date_id in range(max_date + 1):
    # Create date subdirectory
    date_dir = os.path.join(lags_dir, f"date_id={date_id}")
    os.makedirs(date_dir)

    # For each test date_id we want the previous day's responders. Subtracting
    # one from date_id on top of the shift already baked into the lags frame
    # would lag twice, so the existing shift is accounted for here. When no
    # earlier day exists, fall back to the earliest lag data available.
    source_date = max(0, date_id - 1 + lag_head_start)

    # Filter data for this date and save
    local_lags_formatted.filter(
        pl.col('date_id') == source_date
    ).write_parquet(
        os.path.join(date_dir, "part-0.parquet")
    )

print("\nLags directory structure created:")
for root, dirs, files in os.walk(lags_dir):
    print(f"Directory: {root}")
    for file in files:
        print(f"  File: {file}")


Lags directory structure created:
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=15
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=4
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=6
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=1
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=10
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=8
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_te

In [41]:
def get_last_known_values_at_date(cutoff_date: int, verbose: bool = True) -> dict:
    """Collect, per symbol, the last non-null value of every feature up to a date.

    Args:
        cutoff_date: Only rows with ``date_id <= cutoff_date`` are considered.
        verbose: When True (the default), also print the result as a
            copy-pasteable ``last_known_values = {...}`` literal. Pass False to
            get the dictionary without the very long printout.

    Returns:
        Dict keyed by ``(symbol_id, feature_name)``. A feature that has no
        non-null value for a symbol is stored as ``0.0``.
    """
    feature_names = [f'feature_{i:02d}' for i in range(79)]

    # Load data up to cutoff
    historical_data = pl.scan_parquet(
        os.path.join(DATA_DIR, "train.parquet")
    ).filter(pl.col("date_id") <= cutoff_date)

    # Last non-null value of each feature within each symbol.
    # NOTE(review): "last" follows the row order of train.parquet; this assumes
    # the file is stored chronologically - confirm, or sort explicitly.
    last_values = historical_data.group_by('symbol_id').agg([
        pl.col(name).drop_nulls().last()
        for name in feature_names
    ]).collect()

    # Flatten to {(symbol, feature): value}, defaulting to 0.0 where a symbol
    # never had a non-null value for that feature
    value_dict = {
        (row['symbol_id'], name): row[name] if row[name] is not None else 0.0
        for row in last_values.iter_rows(named=True)
        for name in feature_names
    }

    # Print in copyable format
    if verbose:
        print("last_known_values = {")
        for (sym, feat), val in sorted(value_dict.items()):
            print(f"    ({sym}, '{feat}'): {val:.6f},")
        print("}")

    return value_dict

In [42]:
# Run this with 1698 to get the dictionary for the Kaggle upload.
# last_values = get_last_known_values_at_date(1698)

# For local testing, cut off at the day before the test window starts (start - 1).
last_known_values = get_last_known_values_at_date(start - 1)

last_known_values = {
    (0, 'feature_00'): 0.606275,
    (0, 'feature_01'): -1.111470,
    (0, 'feature_02'): 0.957191,
    (0, 'feature_03'): 0.785588,
    (0, 'feature_04'): -1.686526,
    (0, 'feature_05'): -1.608237,
    (0, 'feature_06'): -0.264967,
    (0, 'feature_07'): -1.092005,
    (0, 'feature_08'): -0.221835,
    (0, 'feature_09'): 11.000000,
    (0, 'feature_10'): 7.000000,
    (0, 'feature_11'): 76.000000,
    (0, 'feature_12'): 0.940446,
    (0, 'feature_13'): 0.597973,
    (0, 'feature_14'): 1.052861,
    (0, 'feature_15'): -1.091804,
    (0, 'feature_16'): -0.432034,
    (0, 'feature_17'): -0.575417,
    (0, 'feature_18'): 5.115178,
    (0, 'feature_19'): -0.046730,
    (0, 'feature_20'): 0.714703,
    (0, 'feature_21'): -0.116039,
    (0, 'feature_22'): 1.779095,
    (0, 'feature_23'): 0.820886,
    (0, 'feature_24'): 0.532221,
    (0, 'feature_25'): 0.298325,
    (0, 'feature_26'): 1.257741,
    (0, 'feature_27'): 0.838528,
    (0, 'feature_28'): 0.552220,
    (0, 