# Neural Network Training Notebook For Jane Street Real-Time Market Data Forecasting

In [1]:
import os
import joblib 

import pandas as pd
import polars as pl
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as lightning
from pytorch_lightning import (LightningDataModule, LightningModule, Trainer)
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, Timer
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

In [2]:
target="responder_6"
start_date_id = 1448 #use last 250 days
lags_cols = ["date_id", "symbol_id"] + [f"responder_{idx}" for idx in range(9)]
validation_ratio=0.05
features = ["symbol_id", "time_id"] + [f"feature_{idx:02d}" for idx in range(79)] + [f"responder_{idx}_lag_1" for idx in range(9)]
SEED=42

In [3]:
train = pl.scan_parquet(
    f"/kaggle/input/jane-street-real-time-market-data-forecasting/train.parquet"
).select(
    pl.int_range(pl.len(), dtype=pl.UInt32).alias("id"),
    pl.all(),
).with_columns(
    (pl.col(target)*2).cast(pl.Int32).alias("label"),
).filter(
    pl.col("date_id").gt(start_date_id)
)

In [4]:
lags = train.select(pl.col(lags_cols))
lags_cols_rename = { f"responder_{idx}" : f"responder_{idx}_lag_1" for idx in range(9)}
lags = lags.rename(lags_cols_rename)
lags = lags.with_columns(
    date_id = pl.col('date_id') + 1,  # 1 day lag
    )
lags = lags.group_by(["date_id", "symbol_id"], maintain_order=True).last()  # pick up last record of previous date

In [5]:
train = train.join(lags, on=["date_id", "symbol_id"],  how="left")

In [6]:
len_set   = train.select(pl.col("date_id")).collect().shape[0]
len_validation = int(len_set * validation_ratio)
len_train = len_set - len_validation
last_train_date  = train.select(pl.col("date_id")).collect().row(len_train)[0]

print(f"{len_set=}")
print(f"{len_train=}")
print(f"Last offline train date = {last_train_date}\n")

training_data = train.filter(pl.col("date_id").le(last_train_date)).collect()
validation_data = train.filter(pl.col("date_id").gt(last_train_date)).collect()

len_set=9217296
len_train=8756432
Last offline train date = 1686



In [7]:
train_df=training_data.to_pandas()
val_df=validation_data.to_pandas()

In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [9]:
def r2_val(y_true, y_pred, sample_weight):
    r2 = 1 - np.average((y_pred - y_true) ** 2, weights=sample_weight) / (np.average((y_true) ** 2, weights=sample_weight) + 1e-38)
    return r2

In [10]:
class JaneStreetDataset(Dataset):
    def __init__(self, df):
        self.features = torch.FloatTensor(df[features].values)
        self.labels = torch.FloatTensor(df[target].values)
        self.weights = torch.FloatTensor(df["weight"].values)
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        x = self.features[idx]
        y = self.labels[idx]
        w = self.weights[idx]
        return x, y, w

class DataModule(LightningDataModule):
    def __init__(self, train_df, batch_size, val_df):
        super().__init__()
        self.df = train_df.copy()
        self.batch_size = batch_size
        self.dates = self.df['date_id'].unique()
        self.train_dataset = None
        self.val_df = val_df.copy()
        self.val_dataset = None

    def setup_folds(self, fold=0, N_fold=5):
        selected_dates = [date for ii, date in enumerate(self.dates) if ii % N_fold != fold]
        df_train = self.df.loc[self.df['date_id'].isin(selected_dates)]
        self.train_dataset = JaneStreetDataset(df_train)
        if self.val_df is not None:
            self.val_dataset = JaneStreetDataset(self.val_df)

    def setup_single(self):
        self.train_dataset = JaneStreetDataset(self.df)
        if self.val_df is not None:
            self.val_dataset = JaneStreetDataset(self.val_df)
            
    def train_dataloader(self, n_workers=0):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=n_workers)

    def val_dataloader(self, n_workers=0):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=n_workers)

In [11]:
class NN(LightningModule):
    def __init__(self, input_dim, hidden_dims, dropouts, lr, weight_decay):
        super().__init__()
        self.save_hyperparameters()
        layers = []
        in_dim = input_dim
        for i, hidden_dim in enumerate(hidden_dims):
            layers.append(nn.BatchNorm1d(in_dim))
            if i > 0:
                layers.append(nn.ReLU())
            if i < len(dropouts):
                layers.append(nn.Dropout(dropouts[i]))
            layers.append(nn.Linear(in_dim, hidden_dim))
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, 1)) 
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)
        self.lr = lr
        self.weight_decay = weight_decay
        self.validation_step_outputs = []

    def forward(self, x):
        return 5 * self.model(x).squeeze(-1)  

    def training_step(self, batch):
        x, y, w = batch
        y_pred = self(x)
        loss = F.mse_loss(y_pred, y, reduction='none') * w  
        loss = loss.mean()
        self.log('train_loss', loss, on_step=False, on_epoch=True, batch_size=x.size(0))
        return loss

    def validation_step(self, batch):
        x, y, w = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y, reduction='none') * w
        loss = loss.mean()
        self.log('val_loss', loss, on_step=False, on_epoch=True, batch_size=x.size(0))
        self.validation_step_outputs.append((y_hat, y, w))
        return loss

    def on_validation_epoch_end(self):
        """Calculate validation WRMSE at the end of the epoch."""
        y = torch.cat([x[1] for x in self.validation_step_outputs]).cpu().numpy()
        if self.trainer.sanity_checking:
            prob = torch.cat([x[0] for x in self.validation_step_outputs]).cpu().numpy()
        else:
            prob = torch.cat([x[0] for x in self.validation_step_outputs]).cpu().numpy()
            weights = torch.cat([x[2] for x in self.validation_step_outputs]).cpu().numpy()
            val_r_square = r2_val(y, prob, weights)
            self.log("val_r_square", val_r_square, prog_bar=True, on_step=False, on_epoch=True)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5,
                                                               verbose=True)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
            }
        }

    def on_train_epoch_end(self):
        if self.trainer.sanity_checking:
            return
        epoch = self.trainer.current_epoch
        metrics = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in self.trainer.logged_metrics.items()}
        formatted_metrics = {k: f"{v:.5f}" for k, v in metrics.items()}
        print(f"Epoch {epoch}: {formatted_metrics}")

In [12]:
train_df[features] = train_df[features].fillna(method = 'ffill').fillna(0)
val_df[features] = val_df[features].fillna(method = 'ffill').fillna(0)
data_module = DataModule(train_df, batch_size=8192, val_df=val_df)

  train_df[features] = train_df[features].fillna(method = 'ffill').fillna(0)
  val_df[features] = val_df[features].fillna(method = 'ffill').fillna(0)


In [13]:
lightning.seed_everything(42)
RUN_CROSSVAL=False
if RUN_CROSSVAL:
    for fold in range(5):
        data_module.setup_folds(fold, 5)
        # Obtain input dimension
        input_dim = data_module.train_dataset.features.shape[1]
        # Initialize Model
        model = NN(
            input_dim=input_dim,
            hidden_dims=[512, 512, 256],
            dropouts=[0.1, 0.1],
            lr=1e-3,
            weight_decay=5e-4
        )
        # Initialize Callbacks
        early_stopping = EarlyStopping('val_loss', patience=25, mode='min', verbose=False)
        checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, verbose=False, filename=f"/kaggle/working/models/nn_{fold}.model") 
        timer = Timer()
        # Initialize Trainer
        trainer = Trainer(
            max_epochs=2000,
            accelerator=device,
            #devices=[0] if torch.cuda.is_available() else None,
            logger=None,
            callbacks=[early_stopping, checkpoint_callback, timer],
            enable_progress_bar=True
        )
        # Start Training
        trainer.fit(model, data_module.train_dataloader(4), data_module.val_dataloader(4))
        print(f'Fold-{fold} Training completed in {timer.time_elapsed("train"):.2f}s')

In [14]:
RUN_SINGLE=True
if RUN_SINGLE:
    data_module.setup_single()
    input_dim = data_module.train_dataset.features.shape[1]
    model = NN(
        input_dim=input_dim,
        hidden_dims=[512, 512, 256],
        dropouts=[0.1, 0.1],
        lr=1e-3,
        weight_decay=5e-4
    )
    early_stopping = EarlyStopping('val_loss', patience=25, mode='min', verbose=False)
    checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, verbose=False, 
                                          filename="/kaggle/working/single_model/model") 
    timer = Timer()
    trainer = Trainer(
        max_epochs=2000,
        accelerator=device,
        logger=None,
        callbacks=[early_stopping, checkpoint_callback, timer],
        enable_progress_bar=False
    )

    trainer.fit(model, data_module.train_dataloader(4), data_module.val_dataloader(4))
    print(f'Training completed in {timer.time_elapsed("train"):.2f}s')

Epoch 0: {'val_loss': '1.05061', 'val_r_square': '-0.00204', 'train_loss': '1.57411'}
Epoch 1: {'val_loss': '1.05146', 'val_r_square': '-0.00285', 'train_loss': '1.51174'}
Epoch 2: {'val_loss': '1.04941', 'val_r_square': '-0.00090', 'train_loss': '1.50733'}
Epoch 3: {'val_loss': '1.04678', 'val_r_square': '0.00161', 'train_loss': '1.50604'}
Epoch 4: {'val_loss': '1.04722', 'val_r_square': '0.00120', 'train_loss': '1.50341'}
Epoch 5: {'val_loss': '1.04750', 'val_r_square': '0.00093', 'train_loss': '1.49982'}
Epoch 6: {'val_loss': '1.04600', 'val_r_square': '0.00236', 'train_loss': '1.49641'}
Epoch 7: {'val_loss': '1.04478', 'val_r_square': '0.00352', 'train_loss': '1.49425'}
Epoch 8: {'val_loss': '1.04578', 'val_r_square': '0.00257', 'train_loss': '1.49223'}
Epoch 9: {'val_loss': '1.04598', 'val_r_square': '0.00237', 'train_loss': '1.49020'}
Epoch 10: {'val_loss': '1.04855', 'val_r_square': '-0.00007', 'train_loss': '1.48968'}
Epoch 11: {'val_loss': '1.04565', 'val_r_square': '0.00269',