### Paths and imports

In [14]:
import os
import polars as pl
import shutil

from IPython.display import HTML, display

In [15]:
# Local test window, expressed as train `date_id` values.
# Both ends are inclusive (the date ranges below are built with `end + 1`).
# Note that there are 1698 days
# NOTE(review): presumably 1698 is the last date_id in train -- confirm.
test_date_start = 1690
test_date_end = 1698

# Last date_id before the test window starts.
latest_fill = test_date_start - 1

In [None]:
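# TODO: `is_scored` variable -- every row is currently marked as scored.
# TODO: Lags aren't actually lagged -- the lag date range is identical to the test date range.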

: 

In [16]:
# Setup paths relative to this notebook's location
# os.getcwd() is the kernel's working directory; it equals the notebook's
# folder only when the kernel was started from there.
NOTEBOOK_DIR = os.getcwd()  # submissions folder
PROJ_DIR = os.path.dirname(NOTEBOOK_DIR)  # main project folder
# Competition data folder, expected next to the submissions folder.
DATA_DIR = os.path.join(PROJ_DIR, "jane-street-real-time-market-data-forecasting")
# Output folder for the locally generated test/lags parquet files.
LOCAL_TEST_DIR = os.path.join(NOTEBOOK_DIR, "local_test_data")

# Create local test directory if it doesn't exist
os.makedirs(LOCAL_TEST_DIR, exist_ok=True)

### Grab a sample of train to turn into test data

In [17]:
def get_test_and_lag_dates(train_df: pl.LazyFrame, test_start_date: int, test_end_date: int = None, lag_offset: int = 0) -> tuple[pl.Series, pl.Series]:
    """Build the date_id ranges used to carve a local test set and its lags out of train.

    Args:
        train_df: Lazy frame containing a ``date_id`` column.
        test_start_date: First date_id of the test window (inclusive).
        test_end_date: Last date_id of the test window (inclusive). When None,
            the maximum date_id found in ``train_df`` is used.
        lag_offset: Number of days by which the lag range trails the test
            range. The default of 0 returns lag dates identical to the test
            dates, which is what this function has always returned; pass 1 to
            get the previous day's dates directly.

    Returns:
        ``(test_dates, lag_dates)``: two Series of consecutive date_ids with
        the same length.

    Raises:
        ValueError: if ``train_df`` has no date_id values, if the requested
            window falls outside the dates present in ``train_df``, if the
            window is empty (end before start), or if ``lag_offset`` is
            negative or reaches before the first available date.
    """
    # Fetch both bounds with a single pass over the data.
    bounds = train_df.select(
        pl.col('date_id').min().alias('min_date'),
        pl.col('date_id').max().alias('max_date'),
    ).collect()
    min_date, max_date = bounds.row(0)

    if min_date is None:
        raise ValueError("train_df contains no date_id values")

    if test_start_date < min_date:
        raise ValueError(f"test_start_date ({test_start_date}) must be >= minimum date in dataset ({min_date})")
    if test_start_date > max_date:
        raise ValueError(f"test_start_date ({test_start_date}) must be <= maximum date in dataset ({max_date})")

    # If no end date specified, use max date
    if test_end_date is None:
        test_end_date = max_date
    elif test_end_date > max_date:
        raise ValueError(f"test_end_date ({test_end_date}) must be <= maximum date in dataset ({max_date})")

    # An end before the start would yield empty ranges that only fail much later.
    if test_end_date < test_start_date:
        raise ValueError(f"test_end_date ({test_end_date}) must be >= test_start_date ({test_start_date})")

    # A negative offset would make the "lags" come from the future.
    if lag_offset < 0:
        raise ValueError(f"lag_offset ({lag_offset}) must be >= 0")
    if test_start_date - lag_offset < min_date:
        raise ValueError(
            f"lag_offset ({lag_offset}) reaches before the minimum date in dataset ({min_date})"
        )

    # Generate date ranges
    test_dates = pl.Series(range(test_start_date, test_end_date + 1))
    # Lag dates trail the test dates by `lag_offset` days and keep the same
    # length. With the default offset of 0 they are identical to the test dates.
    lag_dates = pl.Series(range(test_start_date - lag_offset, test_end_date - lag_offset + 1))

    print("Created date ranges:")
    print(f"Test dates: {test_dates.min()} to {test_dates.max()} (n={len(test_dates)})")
    print(f"Lag dates: {lag_dates.min()} to {lag_dates.max()} (n={len(lag_dates)})")

    return test_dates, lag_dates

In [18]:
# Lazily scan the full training set; nothing is read until .collect().
train = pl.scan_parquet(os.path.join(DATA_DIR, "train.parquet"))

test_dates, lag_dates = get_test_and_lag_dates(
    train,
    test_start_date=test_date_start,
    test_end_date=test_date_end,
)

# Materialise the rows for the test dates, then for the lag dates.
test_data, lag_data = (
    train.filter(pl.col('date_id').is_in(wanted_dates)).collect()
    for wanted_dates in (test_dates, lag_dates)
)

Created date ranges:
Test dates: 1690 to 1698 (n=9)
Lag dates: 1690 to 1698 (n=9)


### Create test and lags from this sample

In [19]:
# Expressions reused below: the 79 feature columns and the 9 responders
# (the latter renamed with a `_lag_1` suffix), all cast to Float32.
feature_exprs = [pl.col(f'feature_{idx:02d}').cast(pl.Float32) for idx in range(79)]
lagged_responder_exprs = [
    pl.col(f'responder_{idx}').cast(pl.Float32).alias(f'responder_{idx}_lag_1')
    for idx in range(9)
]

# First date_id of each sample; subtracted so the output date_ids start at 0
# while preserving order.
first_test_date = test_data['date_id'].min()
first_lag_date = lag_data['date_id'].min()

# Test data in the competition layout: ids, keys, weight, is_scored, features,
# plus responder_6 which is kept for scoring.
local_test_formatted = test_data.select(
    pl.int_range(0, pl.len()).cast(pl.UInt64).alias('id'),
    pl.int_range(0, pl.len()).cast(pl.Int64).alias('row_id'),
    (pl.col('date_id') - first_test_date).cast(pl.Int16).alias('date_id'),
    pl.col('time_id').cast(pl.Int16),
    pl.col('symbol_id').cast(pl.Int8),
    pl.col('weight').cast(pl.Float32),
    pl.lit(True).alias('is_scored'),  # every row is scored in the local test
    *feature_exprs,
    pl.col('responder_6').cast(pl.Float32),
)

# Lags data in the competition layout: id, keys and the renamed responders.
local_lags_formatted = lag_data.select(
    pl.int_range(0, pl.len()).cast(pl.UInt64).alias('id'),
    (pl.col('date_id') - first_lag_date).cast(pl.Int16).alias('date_id'),
    pl.col('time_id').cast(pl.Int16),
    pl.col('symbol_id').cast(pl.Int8),
    *lagged_responder_exprs,
)

### Have a peek

In [20]:
def create_title(title):
    """Wrap `title` in an <h3> element and return it as an IPython HTML object.

    The title is interpolated as-is (no HTML escaping).
    """
    markup = f"""
    <h3>{title}</h3>
    """
    return HTML(markup)

# Look at the data before saving: a heading followed by a 4-row preview with
# all columns shown, first for the test frame and then for the lags frame.
previews = (
    ("First rows of our formatted test data", local_test_formatted),
    ("First rows of our formatted lags data", local_lags_formatted),
)
for heading, frame in previews:
    with pl.Config(tbl_rows=4, tbl_cols=-1):
        display(create_title(heading))
        display(frame)

id,row_id,date_id,time_id,symbol_id,weight,is_scored,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_6
u64,i64,i16,i16,i8,f32,bool,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,0,0,0,3.427295,true,2.835166,-0.854095,3.444795,3.097683,2.432463,-0.296623,-0.003247,0.174131,0.113434,11.0,7.0,76.0,-1.054923,1.92637,-0.196895,,-0.767655,,-0.627067,-1.335889,1.073516,-0.081029,1.935252,0.845424,1.367642,1.223538,0.96692,1.280783,0.581346,-0.73954,-0.984937,-0.04214,,,1.058765,1.729322,2.006166,0.053617,0.141686,,0.484563,,,-0.469076,,-1.02392,1.38806,-0.226523,0.097061,0.27471,,0.105249,,,-0.007427,,-2.226455,1.170089,,-0.698984,-0.204521,0.454603,-0.531477,-0.449067,-0.575404,-1.744474,-1.873604,-0.652655,1.372647,-0.134405,-1.111272,0.807675,-0.423163,,,-0.054287,-0.063634,-0.289314,-0.231896,-0.03525
1,1,0,0,1,3.114638,true,2.80876,-0.345852,2.843017,2.771165,2.587666,-0.294217,-0.002645,0.190838,0.122362,11.0,7.0,76.0,-1.056737,0.667128,-0.385839,,-0.772138,,-1.143342,-1.639617,0.821671,0.054117,0.870084,0.740512,1.704525,0.601606,-1.778843,-0.244074,0.93686,-0.569216,-0.754085,0.047782,,,0.03495,0.231156,-2.1351,0.014636,0.0781,,0.562146,,,-1.134097,,-1.460074,1.109892,-0.457376,-1.192078,-0.965832,,0.40512,,,-0.589856,,-2.119151,1.22752,,-0.20853,-0.092294,0.454603,-0.251779,-0.329865,-0.332368,-1.31082,-2.794437,-0.622243,0.317424,-0.357538,-0.793062,0.456985,-0.486649,,,-0.07044,-0.054276,-0.207491,-0.261169,-1.332828
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
338798,338798,8,967,37,1.243116,true,2.663298,-0.889112,2.313155,3.101428,0.324454,0.618944,1.185663,1.599724,0.319719,34.0,4.0,214.0,0.759314,0.284057,0.41716,-0.611075,-0.513717,-0.891423,1.84994,0.406756,-1.608196,-0.252663,-0.271574,-0.051405,0.098146,-0.653961,0.173676,-0.016497,-0.404509,-0.577262,-0.731429,-0.21646,3.018564,-0.472061,3.13922,3.065858,0.842925,0.053283,-0.074403,0.500129,0.08263,0.336223,0.643934,-0.422367,-0.418195,0.203037,-0.702278,0.543305,-0.195764,0.693364,0.953293,0.352567,0.471775,1.876459,-0.143377,0.845516,0.301135,-0.395703,0.738038,-0.04124,1.270645,-1.101531,-0.358106,-0.141883,-0.255192,2.489247,0.537652,0.982107,-0.158009,0.137389,0.478357,0.782692,0.581421,-0.106056,-0.111017,0.163867,0.169331,-0.037563,-0.029483,-0.148711
338799,338799,8,967,38,3.193685,true,2.728506,-0.745238,2.788789,2.343393,0.454731,0.862839,0.964795,2.089673,0.344931,50.0,1.0,522.0,0.406531,0.618247,1.01327,-0.952069,-0.679168,-0.597603,0.375125,1.97537,-0.440974,-0.072018,1.741353,1.380735,-0.110494,-0.874806,0.553424,0.532243,0.263214,-0.757856,-0.869204,-0.062955,3.619233,-0.386316,3.54456,3.120631,-1.443649,-0.257411,-0.309567,1.366358,-0.220885,0.029798,1.094489,-0.051078,-0.114243,0.517313,0.852201,0.522199,-0.027275,0.471593,1.213111,0.263278,0.915804,1.862022,0.503819,1.310126,0.662521,1.654948,1.090367,0.535922,0.653011,-1.101531,-0.622853,-0.363631,-0.395652,-0.016812,2.016734,0.241486,0.253229,0.228745,0.462717,0.799635,0.706102,-0.376377,-0.286764,-0.359046,-0.246135,-0.288941,-0.247774,-0.138548


id,date_id,time_id,symbol_id,responder_0_lag_1,responder_1_lag_1,responder_2_lag_1,responder_3_lag_1,responder_4_lag_1,responder_5_lag_1,responder_6_lag_1,responder_7_lag_1,responder_8_lag_1
u64,i16,i16,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,0,0,0,-0.005415,-0.23267,-0.130127,-0.028247,-0.098344,-0.231422,-0.03525,-0.06153,-0.430392
1,0,0,1,-0.020811,-0.067131,0.13835,-0.729628,-0.660754,-1.335486,-1.332828,-0.508476,-2.208436
…,…,…,…,…,…,…,…,…,…,…,…,…
338798,8,967,37,1.925987,0.479394,3.621867,-0.107114,-0.063599,1.204755,-0.148711,-0.026583,-0.256395
338799,8,967,38,1.228778,0.512562,-0.050865,0.160883,0.080756,-0.078237,-0.138548,-0.038771,-0.21194


In [24]:
# Root of the partitioned test output: test.parquet/date_id=<n>/part-0.parquet
test_dir = os.path.join(LOCAL_TEST_DIR, "test.parquet")

# Start from a clean directory on every run.
if os.path.exists(test_dir):
    shutil.rmtree(test_dir)
os.makedirs(test_dir)

# Highest (zero-based) date_id in the formatted test data.
max_date = local_test_formatted['date_id'].max()

# One sub-directory and one parquet file per date_id.
for date_id in range(max_date + 1):
    date_dir = os.path.join(test_dir, f"date_id={date_id}")
    os.makedirs(date_dir)

    rows_for_date = local_test_formatted.filter(pl.col('date_id') == date_id)
    rows_for_date.write_parquet(os.path.join(date_dir, "part-0.parquet"))

# Report what was written.
print("Directory structure created:")
for root, _subdirs, filenames in os.walk(test_dir):
    print(f"Directory: {root}")
    for filename in filenames:
        print(f"  File: {filename}")

Directory structure created:
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=4
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=6
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=1
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=8
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=0
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/test.parquet/date_id=5
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/lo

In [25]:
# Root of the partitioned lags output: lags.parquet/date_id=<n>/part-0.parquet
lags_dir = os.path.join(LOCAL_TEST_DIR, "lags.parquet")

# Remove existing lags directory if it exists, then recreate it empty.
if os.path.exists(lags_dir):
    shutil.rmtree(lags_dir)
os.makedirs(lags_dir)

# Lags for the first test day must come from the train day immediately before
# the test window. Serving day 0 its own responders would leak responder_6
# (the scoring target) into the lag features. The columns and dtypes mirror
# local_lags_formatted. If that day is not present in train, the result is an
# empty frame with the same schema.
# `id` restarts at 0 here rather than continuing local_lags_formatted's index.
# NOTE(review): the date_id column holds the *source* day's id relative to the
# test start (so -1 here, and d-1 in the partitions below), matching how the
# later partitions were already written. Confirm whether consumers expect the
# served date_id instead.
prior_day_lags = (
    train
    .filter(pl.col('date_id') == test_date_start - 1)
    .select([
        pl.int_range(0, pl.len()).cast(pl.UInt64).alias('id'),
        (pl.col('date_id') - test_date_start).cast(pl.Int16).alias('date_id'),
        pl.col('time_id').cast(pl.Int16),
        pl.col('symbol_id').cast(pl.Int8),
        *[pl.col(f'responder_{i}').cast(pl.Float32).alias(f'responder_{i}_lag_1')
          for i in range(9)],
    ])
    .collect()
)

# Save lags data by date_id
for date_id in range(max_date + 1):
    # Create date subdirectory
    date_dir = os.path.join(lags_dir, f"date_id={date_id}")
    os.makedirs(date_dir)

    # For each date_id in test, we want the previous day's responders.
    if date_id == 0:
        day_lags = prior_day_lags
    else:
        day_lags = local_lags_formatted.filter(pl.col('date_id') == date_id - 1)

    day_lags.write_parquet(os.path.join(date_dir, "part-0.parquet"))

print("\nLags directory structure created:")
for root, dirs, files in os.walk(lags_dir):
    print(f"Directory: {root}")
    for file in files:
        print(f"  File: {file}")


Lags directory structure created:
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=4
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=6
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=1
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=8
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=0
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_testing/local_test_data/lags.parquet/date_id=5
  File: part-0.parquet
Directory: /monfs01/projects/ys68/JaneStreet-Kaggle/submission_test

In [26]:
def get_last_known_values_at_date(cutoff_date: int) -> dict:
    """Collect, per symbol, the last non-null value of every feature up to a cutoff.

    Reads train.parquet from DATA_DIR, keeps rows with ``date_id <= cutoff_date``
    and, for each ``symbol_id``, takes the last non-null entry of each of the 79
    feature columns in the order the rows are scanned. Features with no
    non-null value for a symbol fall back to 0.0.

    The mapping is also printed as a ``last_known_values = {...}`` literal so it
    can be copied elsewhere.

    Returns:
        dict keyed by ``(symbol_id, feature_name)`` with float values.
    """
    feature_names = [f'feature_{idx:02d}' for idx in range(79)]
    train_path = os.path.join(DATA_DIR, "train.parquet")

    # One row per symbol, one column per feature.
    per_symbol = (
        pl.scan_parquet(train_path)
        .filter(pl.col("date_id") <= cutoff_date)
        .group_by('symbol_id')
        .agg([
            pl.col(name).filter(pl.col(name).is_not_null()).last().alias(name)
            for name in feature_names
        ])
        .collect()
    )

    # Flatten to {(symbol, feature): value}, substituting 0.0 for missing values.
    value_dict = {
        (record['symbol_id'], name): (0.0 if record[name] is None else record[name])
        for record in per_symbol.iter_rows(named=True)
        for name in feature_names
    }

    # Print in copyable format, sorted by (symbol, feature).
    print("last_known_values = {")
    for (sym, feat), val in sorted(value_dict.items()):
        print(f"    ({sym}, '{feat}'): {val:.6f},")
    print("}")

    return value_dict

In [27]:
# Run this with 1698 to get the dictionary for the Kaggle upload.
# The call also prints the dictionary as a copyable literal.
last_values = get_last_known_values_at_date(1698)

# NOTE(review): presumably the local-testing variant, cut off at the day
# before the test window -- confirm before re-enabling.
#last_known_values = get_last_known_values_at_date(latest_fill)

last_known_values = {
    (0, 'feature_00'): 2.827219,
    (0, 'feature_01'): -0.875262,
    (0, 'feature_02'): 2.520329,
    (0, 'feature_03'): 1.967673,
    (0, 'feature_04'): 0.516955,
    (0, 'feature_05'): 0.528624,
    (0, 'feature_06'): 0.759274,
    (0, 'feature_07'): 1.303269,
    (0, 'feature_08'): 0.172209,
    (0, 'feature_09'): 11.000000,
    (0, 'feature_10'): 7.000000,
    (0, 'feature_11'): 76.000000,
    (0, 'feature_12'): 0.422208,
    (0, 'feature_13'): 0.544176,
    (0, 'feature_14'): 0.088627,
    (0, 'feature_15'): -0.838364,
    (0, 'feature_16'): -0.385189,
    (0, 'feature_17'): -0.794437,
    (0, 'feature_18'): 1.732842,
    (0, 'feature_19'): 0.257682,
    (0, 'feature_20'): 1.515887,
    (0, 'feature_21'): -0.202102,
    (0, 'feature_22'): 1.285443,
    (0, 'feature_23'): 0.787999,
    (0, 'feature_24'): 1.921844,
    (0, 'feature_25'): 1.281180,
    (0, 'feature_26'): 1.223592,
    (0, 'feature_27'): 1.002706,
    (0, 'feature_28'): 0.587736,
    (0, 'featu

: 