# Useful notebooks:

- Preprocessing : https://www.kaggle.com/code/motono0223/js24-preprocessing-create-lags
- Training (XGB) : https://www.kaggle.com/code/motono0223/js24-train-gbdt-model-with-lags-singlemodel
  - trained XGB model : https://www.kaggle.com/datasets/motono0223/js24-trained-gbdt-model
- Training (NN): **this notebook** https://www.kaggle.com/code/voix97/jane-street-rmf-training-nn
  - trained NN model : https://www.kaggle.com/datasets/voix97/js-xs-nn-trained-model
- Inference of NN : https://www.kaggle.com/code/voix97/jane-street-rmf-nn-with-pytorch-lightning
- Inference of NN+XGB:  https://www.kaggle.com/code/voix97/jane-street-rmf-nn-xgb
- EDA(1) : https://www.kaggle.com/code/motono0223/eda-jane-street-real-time-market-data-forecasting
- EDA(2) : https://www.kaggle.com/code/motono0223/eda-v2-jane-street-real-time-market-forecasting

# Training Neural Networks (MLP) with PyTorch Lightning

In [1]:
import os
import pickle
import polars as pl
import numpy as np
import pandas as pd

# Load Data

In [None]:
# --- Ultimate Memory-Optimized Data Loading Strategy ---
# Only the (small) validation split is materialized here; the large training split is
# read fold-by-fold later (DataModule.setup), so it never sits in memory as a whole.
print("Starting ultimate memory-optimized data loading...")

input_path = './input_df' if os.path.exists('./input_df') else '/kaggle/input/js24-preprocessing-create-lags'
# 79 raw features plus the lag-1 value of each of the 9 responders.
feature_names = [f"feature_{i:02d}" for i in range(79)] + [f"responder_{idx}_lag_1" for idx in range(9)]
label_name = 'responder_6'
weight_name = 'weight'


def _lazy_schema(lazy_frame):
    """Return the column -> dtype schema of a polars LazyFrame.

    Prefers `LazyFrame.collect_schema()` (polars >= 1.0), which resolves the schema
    explicitly; falls back to the `.schema` property on older polars versions.
    """
    if hasattr(lazy_frame, "collect_schema"):
        return lazy_frame.collect_schema()
    return lazy_frame.schema


# 1. Load ONLY the validation set. It's small enough to be kept in memory for all folds.
print("Loading and preprocessing validation data (to keep in memory)...")
valid_scan = pl.scan_parquet(f"{input_path}/validation.parquet")
# Downcast every Float64 column to Float32 to halve its memory footprint.
valid_float64_to_32 = {
    col: pl.Float32 for col, dtype in _lazy_schema(valid_scan).items() if dtype == pl.Float64
}
valid_processed_lazy = (
    valid_scan
    .cast(valid_float64_to_32)
    # Fill each null with the previous row's value, then any leading nulls with 0.
    # NOTE(review): forward_fill runs over the whole frame in file order without any
    # grouping, so values can carry over between unrelated rows — confirm this is intended.
    .with_columns(pl.col(feature_names).forward_fill().fill_null(0))
)
valid = valid_processed_lazy.collect().to_pandas()
print(f"Validation data loaded. Shape: {valid.shape}")

# 2. Collect the sorted unique date_ids of BOTH splits; folds are assigned by position in
# this list. Only the date_id column is scanned, so this is fast and low-memory.
print("Scanning for all unique date_ids for fold splitting...")
date_frames = [
    pl.scan_parquet(f"{input_path}/{split}.parquet").select('date_id').unique()
    for split in ("training", "validation")
]
all_dates = sorted(pl.concat(date_frames).unique().collect().to_series().to_list())
print(f"Found {len(all_dates)} unique dates for cross-validation.")

  schema = train_lazy.schema


Loading and preprocessing validation data...
Validation data shape: (1082224, 104)
Loading, preprocessing, and concatenating training data...
Combined training data shape: (22104280, 104)
Data loading and preprocessing complete.


# Training Configurations

In [None]:
import os
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning
from pytorch_lightning import (LightningDataModule, LightningModule, Trainer)
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, Timer
from pytorch_lightning.loggers import WandbLogger
import wandb
import pandas as pd
import numpy as np
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader


class custom_args():
    """Plain attribute bag holding every training setting used by this notebook.

    Settings are stored as instance attributes because `vars(instance)` is passed to
    wandb as the run configuration.
    """

    def __init__(self):
        # Hardware and reproducibility
        self.usegpu, self.gpuid, self.seed = True, 0, 42
        # Experiment bookkeeping
        self.model = 'nn'
        self.use_wandb = False
        self.project = 'js-xs-nn-with-lags'
        # Data access
        self.dname = "./input_df/"
        self.loader_workers = 4
        # Optimization
        self.bs = 8192
        self.lr = 1e-3
        self.weight_decay = 5e-4
        # Architecture: per-stage dropout rates and hidden layer widths
        self.dropouts = [0.1, 0.1]
        self.n_hidden = [512, 512, 256]
        # Stopping criteria and cross-validation
        self.patience = 25
        self.max_epochs = 2000
        self.N_fold = 5


# Shared configuration instance used by the cells below.
my_args = custom_args()

# PyTorch Data Module Definition

In [None]:
# --- Ultimate Memory-Optimized PyTorch Data Module ---
import os

# Define a worker_init_fn to print worker PIDs for verification
def worker_init_fn(worker_id):
    """Log the OS process id of a freshly started DataLoader worker.

    Used as the DataLoaders' `worker_init_fn` so the logs show that multi-process
    loading is actually active.
    """
    print(f"SUCCESS: DataLoader Worker with ID {worker_id} started with PID: {os.getpid()}")

class CustomDataset(Dataset):
    """Row-indexed dataset over the features, target and sample weight of a DataFrame.

    The needed columns are extracted once into float32 NumPy arrays; each item is turned
    into tensors only when requested (just-in-time), so no full-size tensor copy is kept.

    Args:
        df: pandas DataFrame containing the `feature_names`, `label_name` and
            `weight_name` columns.
    """

    def __init__(self, df):
        self.features = df[feature_names].to_numpy(dtype=np.float32)
        # Explicit copies: a single column's `.values` can be a view into the frame's
        # consolidated storage, which would keep the entire source DataFrame alive for as
        # long as this dataset exists (doubling memory for the per-fold training frame).
        self.labels = df[label_name].to_numpy(dtype=np.float32, copy=True)
        self.weights = df[weight_name].to_numpy(dtype=np.float32, copy=True)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(x, y, w)`` float32 tensors for row `idx`."""
        x = torch.tensor(self.features[idx], dtype=torch.float32)
        y = torch.tensor(self.labels[idx], dtype=torch.float32)
        w = torch.tensor(self.weights[idx], dtype=torch.float32)
        return x, y, w

class DataModule(LightningDataModule):
    """Builds per-fold training data from disk while keeping the validation frame in memory.

    Dates are assigned to folds by position in `all_dates`: the i-th date belongs to fold
    ``i % N_fold``. For a given fold:

    * the training set is every row of training.parquet AND validation.parquet whose date
      is NOT in that fold;
    * the validation set is the rows of `valid_df` whose date IS in that fold.

    Keeping the two disjoint prevents validation rows from also being trained on, which
    would bias `val_loss` and therefore early stopping, LR scheduling and checkpointing.

    Args:
        batch_size: batch size of both DataLoaders.
        valid_df: preprocessed validation split as a pandas DataFrame with a 'date_id' column.
        all_dates: sorted list of every date_id across both splits.
    """

    def __init__(self, batch_size, valid_df, all_dates):
        super().__init__()
        self.batch_size = batch_size
        self.valid_df = valid_df
        self.all_dates = all_dates
        self.train_dataset = None
        self.val_dataset = None
        self.input_path = './input_df' if os.path.exists('./input_df') else '/kaggle/input/js24-preprocessing-create-lags'

    def _fold_dates(self, fold, N_fold):
        """Split `all_dates` into ``(train_dates, held_out_dates)`` for the given fold."""
        if not 0 <= fold < N_fold:
            raise ValueError(f"fold must be in [0, {N_fold}), got {fold}")
        train_dates, held_out_dates = [], []
        for i, date in enumerate(self.all_dates):
            (held_out_dates if i % N_fold == fold else train_dates).append(date)
        return train_dates, held_out_dates

    def _load_training_frame(self, train_dates):
        """Read both parquet splits lazily, keep only `train_dates`, and return pandas."""
        def process_lazy(lazy_frame):
            # collect_schema() is the polars >= 1.0 API; `.schema` is the older property.
            if hasattr(lazy_frame, "collect_schema"):
                schema = lazy_frame.collect_schema()
            else:
                schema = lazy_frame.schema
            dtype_map = {col: pl.Float32 for col, dtype in schema.items() if dtype == pl.Float64}
            return (
                lazy_frame.filter(pl.col('date_id').is_in(train_dates))
                .cast(dtype_map)
                .with_columns(pl.col(feature_names).forward_fill().fill_null(0))
            )

        frames = [
            process_lazy(pl.scan_parquet(f"{self.input_path}/{split}.parquet"))
            for split in ("training", "validation")
        ]
        return pl.concat(frames).collect().to_pandas()

    def setup(self, fold=0, N_fold=5, stage=None):
        """Prepare the training and validation datasets for one cross-validation fold.

        Args:
            fold: index of the held-out fold, in ``[0, N_fold)``.
            N_fold: total number of folds.
            stage: unused; accepted for compatibility with the Lightning ``setup`` hook.

        Raises:
            ValueError: if `fold` is out of range, or none of the fold's dates occur in
                `valid_df` (there would be nothing to validate on).
        """
        print(f"Setting up data for fold {fold}...")
        train_dates, held_out_dates = self._fold_dates(fold, N_fold)

        print(f"Lazily preparing training data for {len(train_dates)} dates for fold {fold}...")
        df_train = self._load_training_frame(train_dates)
        print(f"Fold {fold} training data loaded into memory. Shape: {df_train.shape}")
        self.train_dataset = CustomDataset(df_train)

        # Validate only on held-out dates so no validation row is also a training row.
        # Rebuilt on every call because the held-out dates change with the fold.
        held_out_mask = self.valid_df['date_id'].isin(held_out_dates)
        if not held_out_mask.any():
            raise ValueError(f"Fold {fold} has no validation rows: none of its dates are in valid_df.")
        self.val_dataset = CustomDataset(self.valid_df.loc[held_out_mask])
        print(f"Fold {fold} validation rows: {int(held_out_mask.sum())}")

    def _make_loader(self, dataset, shuffle, n_workers):
        """Build a DataLoader over `dataset` using this module's batch size."""
        if dataset is None:
            raise RuntimeError("Dataset is not prepared; call setup(fold, N_fold) first.")
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=n_workers,
            # Reuse worker processes across epochs instead of re-forking them each epoch
            # (only valid when there are workers).
            persistent_workers=n_workers > 0,
            worker_init_fn=worker_init_fn,  # logs each worker's PID
        )

    def train_dataloader(self, n_workers=0):
        """Shuffled loader over the current fold's training rows."""
        return self._make_loader(self.train_dataset, shuffle=True, n_workers=n_workers)

    def val_dataloader(self, n_workers=0):
        """Ordered loader over the current fold's held-out validation rows."""
        return self._make_loader(self.val_dataset, shuffle=False, n_workers=n_workers)

# NN Model Definition

In [None]:
import time

# Custom R2 metric for validation
def r2_val(y_true, y_pred, sample_weight):
    """Weighted R² measured against a zero prediction baseline.

    Returns ``1 - avg_w((y_pred - y_true)^2) / avg_w(y_true^2)``, where ``avg_w`` is the
    weighted average. A tiny epsilon in the denominator avoids dividing by zero when
    every target is zero.
    """
    weighted_sq_error = np.average(np.square(y_pred - y_true), weights=sample_weight)
    weighted_sq_target = np.average(np.square(y_true), weights=sample_weight)
    return 1 - weighted_sq_error / (weighted_sq_target + 1e-38)


class NN(LightningModule):
    """MLP regressor whose output is bounded to (-5, 5) by a scaled tanh.

    Each hidden stage is ``BatchNorm1d -> [SiLU, not on the first stage] -> [Dropout] ->
    Linear``; the head is ``Linear(., 1) -> Tanh``, multiplied by 5 in `forward`.

    NOTE(review): no activation sits between the last hidden Linear and the output Linear,
    so those two layers compose into a single linear map. Inserting one would shift the
    nn.Sequential indices (and thus state_dict keys) of checkpoints saved by this notebook,
    so the layout is intentionally left unchanged here.

    Args:
        input_dim: number of input features.
        hidden_dims: width of each hidden Linear layer.
        dropouts: dropout rates for the first ``len(dropouts)`` stages; later stages get none.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
    """

    def __init__(self, input_dim, hidden_dims, dropouts, lr, weight_decay):
        super().__init__()
        self.save_hyperparameters()
        layers = []
        in_dim = input_dim
        for i, hidden_dim in enumerate(hidden_dims):
            layers.append(nn.BatchNorm1d(in_dim))
            if i > 0:
                layers.append(nn.SiLU())
            if i < len(dropouts):
                layers.append(nn.Dropout(dropouts[i]))
            layers.append(nn.Linear(in_dim, hidden_dim))
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, 1))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)
        self.lr = lr
        self.weight_decay = weight_decay
        # (y_hat, y, w) per validation batch; consumed and cleared in on_validation_epoch_end.
        self.validation_step_outputs = []
        # Wall-clock start of the current training epoch, set in on_train_epoch_start.
        self._epoch_start_time = None

    def forward(self, x):
        """Predict, dropping the trailing size-1 output dimension, scaled into (-5, 5)."""
        return 5 * self.model(x).squeeze(-1)

    @staticmethod
    def _weighted_mse(y_hat, y, w):
        """Batch mean of the squared error multiplied by each sample's weight."""
        return (F.mse_loss(y_hat, y, reduction='none') * w).mean()

    def training_step(self, batch):
        x, y, w = batch
        loss = self._weighted_mse(self(x), y, w)
        self.log('train_loss', loss, on_step=False, on_epoch=True, batch_size=x.size(0))
        return loss

    def validation_step(self, batch):
        x, y, w = batch
        y_hat = self(x)
        loss = self._weighted_mse(y_hat, y, w)
        self.log('val_loss', loss, on_step=False, on_epoch=True, batch_size=x.size(0))
        # Store float32 copies: under mixed precision y_hat can be float16, which would
        # add rounding noise to the epoch-level R² computed from these outputs.
        self.validation_step_outputs.append((y_hat.detach().float(), y.detach().float(), w.detach().float()))
        return loss

    def on_validation_epoch_end(self):
        """Log the weighted R² (`val_r_square`) over all batches of the validation epoch."""
        outputs = self.validation_step_outputs
        # Skip the metric during Lightning's sanity check, and when nothing was collected
        # (torch.cat would fail on an empty list).
        if outputs and not self.trainer.sanity_checking:
            y_hat, y, w = (torch.cat(parts).cpu().numpy() for parts in zip(*outputs))
            val_r_square = r2_val(y, y_hat, w)
            self.log("val_r_square", val_r_square, prog_bar=True, on_step=False, on_epoch=True)
        outputs.clear()

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau halving the LR when `val_loss` stalls for 5 epochs."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5,
                                                               verbose=True)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
            }
        }

    def on_train_epoch_start(self):
        """Record the start time at the beginning of each training epoch."""
        self._epoch_start_time = time.time()

    def on_train_epoch_end(self):
        """Print the epoch's logged metrics and its wall-clock duration."""
        if self.trainer.sanity_checking:
            return

        start = self._epoch_start_time
        epoch_duration = time.time() - start if start is not None else float('nan')

        epoch = self.trainer.current_epoch
        metrics = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in self.trainer.logged_metrics.items()}
        formatted_metrics = {k: f"{v:.5f}" for k, v in metrics.items()}
        print(f"Epoch {epoch}: {formatted_metrics} -- Duration: {epoch_duration:.2f}s")

# Create PyTorch Data Module

In [None]:
args = my_args

# Use the configured GPU only when CUDA is present AND enabled in the config.
_use_cuda = torch.cuda.is_available() and args.usegpu
if _use_cuda:
    device = torch.device(f'cuda:{args.gpuid}')
    accelerator = 'gpu'
else:
    device = torch.device('cpu')
    accelerator = 'cpu'

# The DataModule holds the in-memory validation frame and the sorted date list;
# training rows are read from disk per fold when setup(fold, N_fold) is called.
print("Initializing Data Module with the new fold-by-fold loading strategy...")
data_module = DataModule(batch_size=args.bs, valid_df=valid, all_dates=all_dates)
print("Data Module initialized.")

Initializing Data Module with preprocessed data...
Data Module initialized.


# Create Model and Training

In [None]:
# =========================================================
# FOLD CONTROL
# ---------------------------------------------------------
# Index of the single fold trained by this run: 0 .. N_fold-1
# (0..4 with the default config). Start with 0, then re-run
# the notebook with 1, 2, 3 and 4 to train the remaining folds.
# =========================================================
FOLD_TO_TRAIN: int = 0


In [None]:
import gc
import os

# Make sure the directory that receives the model checkpoint exists.
MODEL_DIR = './models'
os.makedirs(MODEL_DIR, exist_ok=True)

# --- Memory-Optimized Single-Fold Training ---
# Runs exactly ONE fold, selected by FOLD_TO_TRAIN, so each Kaggle run stays short and
# writes a single checkpoint.

pytorch_lightning.seed_everything(args.seed)

fold = FOLD_TO_TRAIN
# An out-of-range fold would silently change which dates are held out, so fail fast.
if not 0 <= fold < args.N_fold:
    raise ValueError(f"FOLD_TO_TRAIN must be in [0, {args.N_fold}), got {fold}")
print(f"\n{'='*20} STARTING TRAINING FOR FOLD {fold} {'='*20}")

# Loads this fold's training rows from disk.
data_module.setup(fold, args.N_fold)

input_dim = data_module.train_dataset.features.shape[1]
model = NN(
    input_dim=input_dim,
    hidden_dims=args.n_hidden,
    dropouts=args.dropouts,
    lr=args.lr,
    weight_decay=args.weight_decay
)

if args.use_wandb:
    wandb_run = wandb.init(project=args.project, config=vars(args), reinit=True)
    logger = WandbLogger(experiment=wandb_run)
else:
    logger = None

early_stopping = EarlyStopping('val_loss', patience=args.patience, mode='min', verbose=True)
# Keep the best checkpoint as ./models/nn_{fold}.model.ckpt. The directory must be given
# via `dirpath`: a path inside `filename` is nested under Lightning's default checkpoint
# directory instead of ./models.
checkpoint_callback = ModelCheckpoint(
    dirpath=MODEL_DIR,
    filename=f"nn_{fold}.model",
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    verbose=True,
)
timer = Timer()

use_gpu = accelerator == 'gpu'
trainer = Trainer(
    max_epochs=args.max_epochs,
    accelerator=accelerator,
    # An explicit device list is only valid for the GPU accelerator; let Lightning pick on CPU.
    devices=[args.gpuid] if use_gpu else 'auto',
    logger=logger,
    callbacks=[early_stopping, checkpoint_callback, timer],
    enable_progress_bar=True,
    # float16 mixed precision requires CUDA; use full precision on the CPU fallback.
    precision='16-mixed' if use_gpu else '32-true',
)

# Worker count comes from the shared config instead of a second hard-coded value.
N_WORKERS = args.loader_workers
print(f"Starting training for fold {fold} with num_workers={N_WORKERS}...")
trainer.fit(
    model,
    data_module.train_dataloader(n_workers=N_WORKERS),
    data_module.val_dataloader(n_workers=N_WORKERS)
)

print(f"Fold-{fold} Training completed in {timer.time_elapsed('train'):.2f}s")
print(f"Best checkpoint: {checkpoint_callback.best_model_path}")

if args.use_wandb:
    wandb.finish()

# Release this fold's large objects. Resetting train_dataset (rather than deleting the
# attribute) keeps data_module in a valid state for a later setup() call.
del model, trainer, checkpoint_callback, early_stopping, timer
data_module.train_dataset = None
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print(f"\n{'='*20} SUCCESSFULLY FINISHED FOLD {fold} {'='*20}")