# Libraries

In [1]:
import gc
import os

import matplotlib.cm as cm
import numpy as np
import pandas as pd
import polars as pl
from matplotlib import pyplot as plt
from sklearn.model_selection import StratifiedGroupKFold

# Configurations

In [2]:
class CONFIG:
    """Notebook-wide settings used when building the lagged train/validation sets."""

    # Responder column used as the prediction target.
    target_col = "responder_6"
    # Columns copied into the lag frame: the two join keys plus responder_0..responder_8.
    lag_cols_original = ["date_id", "symbol_id", *(f"responder_{k}" for k in range(9))]
    # Rename map that gives every responder column a "_lag_1" suffix.
    lag_cols_rename = dict(
        (f"responder_{k}", f"responder_{k}_lag_1") for k in range(9)
    )
    # Fraction of the loaded rows set aside for validation.
    valid_ratio = 0.05
    # Only rows with date_id strictly greater than this value are loaded.
    start_dt = 1100
# Directory that contains train.parquet. Set the JS_DATA_PATH environment
# variable to run the notebook on another machine; when it is unset the
# original local path is used, so existing runs behave exactly as before.
DATA_PATH = os.environ.get(
    "JS_DATA_PATH",
    "/Users/liwito/OneDrive-HKUSTConnect/Create/kaggle/local/jane-street-real-time-market-data-forecasting",
)

In [13]:
# Lazily scan the full training file; rows are only materialized by .collect().
train = pl.scan_parquet(
    f"{DATA_PATH}/train.parquet"
)

# Distinct-date and row counts before the start-date filter.
before_filter = train.select(
    pl.col("date_id").n_unique().alias("total_dates"),
    pl.len().alias("total_records")
).collect()

# The same counts after keeping only date_id > CONFIG.start_dt. Reading the
# threshold from CONFIG keeps this check in sync with the filter used when the
# training data is loaded, instead of repeating the literal 1100 here.
after_filter = train.filter(
    pl.col("date_id").gt(CONFIG.start_dt)
).select(
    pl.col("date_id").n_unique().alias("remaining_dates"),
    pl.len().alias("remaining_records")
).collect()

# Number of records for each date_id (a count per date, not an average).
# The groups come back in no particular date order.
records_per_date = train.group_by("date_id").agg(
    pl.len().alias("records")
).collect()
print(records_per_date)
print(after_filter)
print(before_filter)


shape: (1_699, 2)
┌─────────┬─────────┐
│ date_id ┆ records │
│ ---     ┆ ---     │
│ i16     ┆ u32     │
╞═════════╪═════════╡
│ 271     ┆ 16980   │
│ 161     ┆ 11037   │
│ 792     ┆ 29040   │
│ 432     ┆ 16980   │
│ 1355    ┆ 37752   │
│ …       ┆ …       │
│ 1548    ┆ 37752   │
│ 110     ┆ 11886   │
│ 917     ┆ 30008   │
│ 89      ┆ 12735   │
│ 896     ┆ 30008   │
└─────────┴─────────┘
shape: (1, 2)
┌─────────────────┬───────────────────┐
│ remaining_dates ┆ remaining_records │
│ ---             ┆ ---               │
│ u32             ┆ u32               │
╞═════════════════╪═══════════════════╡
│ 598             ┆ 22104280          │
└─────────────────┴───────────────────┘
shape: (1, 2)
┌─────────────┬───────────────┐
│ total_dates ┆ total_records │
│ ---         ┆ ---           │
│ u32         ┆ u32           │
╞═════════════╪═══════════════╡
│ 1699        ┆ 47127338      │
└─────────────┴───────────────┘


# Load training data

In [None]:
# Build the lazy training frame from the single train.parquet file:
# add a row id, derive the label, then keep only dates after CONFIG.start_dt.
train = pl.scan_parquet(f"{DATA_PATH}/train.parquet")

# "id" is a UInt32 running row number, added ahead of the date filter in the query.
train = train.select(
    pl.int_range(pl.len(), dtype=pl.UInt32).alias("id"),
    pl.all(),
)

# "label" is the target column multiplied by 2 and cast to Int32.
train = train.with_columns(
    (pl.col(CONFIG.target_col) * 2).cast(pl.Int32).alias("label"),
)

train = train.filter(pl.col("date_id") > CONFIG.start_dt)

# Create lag data from the training data

In [4]:
# Build the one-day lag frame: for every (date_id, symbol_id) keep the last
# responder values of that date, then shift date_id forward by one so the row
# lines up with the following date when joined back on the keys.
lags = (
    train
    .select(pl.col(CONFIG.lag_cols_original))
    .rename(CONFIG.lag_cols_rename)
    # lagged by 1 day
    .with_columns((pl.col("date_id") + 1).alias("date_id"))
    # pick up last record of previous date, in the frame's existing row order
    .group_by(["date_id", "symbol_id"], maintain_order=True)
    .last()
)
lags

# Merge training data and lags data

In [5]:
# Attach the lagged responder columns to every training row. A left join keeps
# rows that have no lag record (their lag columns are null).
# NOTE: `train` is rebound here, so re-running this cell without re-running the
# load cell joins the lag columns a second time.
train = train.join(
    other=lags,
    on=["date_id", "symbol_id"],
    how="left",
)
train

# Split training data and validation data

In [6]:
# Collect the date_id column once (the original collected it twice) and sort
# it, so the positional lookup below does not depend on the row order left by
# the upstream join. For data already ordered by date_id the result is the same.
date_ids = train.select(pl.col("date_id")).collect().to_series().sort()

len_train = len(date_ids)
# Number of rows targeted for validation, and the number left for training.
valid_records = int(len_train * CONFIG.valid_ratio)
len_ofl_mdl = len_train - valid_records
# date_id found at the split position. The split below is made on whole dates,
# so every row of this date goes to training and no date is shared by both sets.
last_tr_dt = date_ids[len_ofl_mdl]

print(f"\n len_train = {len_train}")
print(f"\n len_ofl_mdl = {len_ofl_mdl}")
print(f"\n---> Last offline train date = {last_tr_dt}\n")

# Both splits stay lazy; they are materialized only when written out.
training_data = train.filter(pl.col("date_id").le(last_tr_dt))
validation_data = train.filter(pl.col("date_id").gt(last_tr_dt))


 len_train = 22104280

 len_ofl_mdl = 20999066

---> Last offline train date = 1669



In [7]:
# Show the validation split. It is built by filtering the lazy `train` frame,
# so evaluating the bare name here does not collect any rows.
validation_data

# Save data as Parquet files

In [8]:
# Materialize the training split and write it out partitioned by date_id.
# The target name is a plain string (no f-string needed), matching the
# validation write below.
(
    training_data
    .collect()
    .write_parquet("training.parquet", partition_by="date_id")
)

In [9]:
# Materialize the validation split and write it out partitioned by date_id.
(
    validation_data
    .collect()
    .write_parquet("validation.parquet", partition_by="date_id")
)