This notebook is used to train NN with autoencoder, this is an implementation of https://www.kaggle.com/code/gogo827jz/jane-street-supervised-autoencoder-mlp  
dataset is constructed by https://www.kaggle.com/code/motono0223/js24-preprocessing-create-lags

# Training Neural Networks (MLP) with PyTorch Lightning

In [2]:
import os
import pickle
import polars as pl
import numpy as np
import pandas as pd

# Load Data

In [3]:
#This block load data. Dataset is different for running locally on kaggle platform and server
LOCAL_TRAINING = True
feature_names = [f"feature_{i:02d}" for i in range(79)] + [f"responder_{idx}_lag_1" for idx in range(9)]
label_name = 'responder_6'
weight_name = 'weight'
train_loc = os.path.join("/kaggle/input/test_training_dataset.pickle")
valid_loc = os.path.join("/kaggle/input/test_validation_dataset.pickle")
if LOCAL_TRAINING:
    df = pl.scan_parquet(
    f"/kaggle/input/k/qinyanghe/js24-preprocessing-create-lags/test_training.parquet"
).collect().to_pandas()
    valid = pl.scan_parquet(
    f"/kaggle/input/k/qinyanghe/js24-preprocessing-create-lags/test_validation.parquet"
).collect().to_pandas()
    df = pd.concat([df, valid]).reset_index(drop=True)# A trick to boost LB from 0.0045->0.005

In [4]:
X_train = df[ feature_names ]
y_train = df[ label_name ]
w_train = df[ "weight" ]
X_valid = valid[ feature_names ]
y_valid = valid[ label_name ]
w_valid = valid[ "weight" ]

X_train.shape, y_train.shape, w_train.shape, X_valid.shape, y_valid.shape, w_valid.shape

((7361640, 88), (7361640,), (7361640,), (338800, 88), (338800,), (338800,))

In [4]:
X_train['responder_6_lag_1'][X_train["responder_6_lag_1"].notnull()].

SyntaxError: invalid syntax (2376126448.py, line 1)

# Training Configurations

In [5]:
import os
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning import (LightningDataModule, LightningModule, Trainer)
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, Timer
from pytorch_lightning.loggers import WandbLogger
from torchvision.transforms.v2 import GaussianNoise
import wandb
import pandas as pd
import numpy as np
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader


# PyTorch Data Module Definition

In [6]:
class CustomDataset(Dataset):
    def __init__(self, df, accelerator):
        self.features = torch.FloatTensor(df[feature_names].values).to(accelerator)
        self.labels = torch.FloatTensor(df[label_name].values).to(accelerator)
        self.weights = torch.FloatTensor(df[weight_name].values).to(accelerator)
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        x = self.features[idx]
        y = self.labels[idx]
        w = self.weights[idx]
        return x, y, w


class DataModule(LightningDataModule):
    def __init__(self, train_df, batch_size, valid_df=None, accelerator='cpu'):
        super().__init__()
        self.df = train_df
        self.batch_size = batch_size
        self.dates = self.df['date_id'].unique()
        self.accelerator = accelerator
        self.train_dataset = None
        self.valid_df = None
        if valid_df is not None:
            self.valid_df = valid_df
        self.val_dataset = None

    def setup(self, fold=0, N_fold=5, stage=None):
        # Split dataset
        selected_dates = [date for ii, date in enumerate(self.dates) if ii % N_fold != fold]
        df_train = self.df.loc[self.df['date_id'].isin(selected_dates)]
        self.train_dataset = CustomDataset(df_train, self.accelerator)
        if self.valid_df is not None:
            df_valid = self.valid_df
            self.val_dataset = CustomDataset(df_valid, self.accelerator)

    def train_dataloader(self, n_workers=0):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=n_workers)

    def val_dataloader(self, n_workers=0):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=n_workers)
    


# Autoencoder NN Model Definition

In [21]:
# Custom R2 metric for validation
def r2_val(y_true, y_pred, sample_weight):
    r2 = 1 - np.average((y_pred - y_true) ** 2, weights=sample_weight) / (np.average((y_true) ** 2, weights=sample_weight) + 1e-38)
    return r2


class AE_NN(LightningModule):
    def __init__(self, input_dim,ae_dims, nn_hidden_dims, ae_dropouts, nn_dropouts, noise_std,loss_weights, lr, weight_decay):
        super().__init__()
        self.save_hyperparameters()
        in_dim = input_dim
        self.ae_1 = nn.BatchNorm1d(in_dim)
        #gaussian noise at this step
        self.ae_2 = nn.Linear(in_dim,ae_dims[0])
        self.ae_3 = nn.BatchNorm1d(ae_dims[0])
        self.ae_4 = nn.SiLU()
        #SiLU as the final output of encoder?
        self.ae_5 = nn.Dropout(ae_dropouts[0])
        self.ae_6 = nn.Linear(ae_dims[0],input_dim)
        #output of ae_6 is recovered features
        #the following is for autoencoder regression
        self.ae_7 = nn.Linear(input_dim,ae_dims[1])
        self.ae_8 = nn.BatchNorm1d(ae_dims[1])
        self.ae_9 = nn.SiLU()
        self.ae_10 = nn.Dropout(ae_dropouts[1])
        self.ae_11 = nn.Linear(ae_dims[1],1)
        self.MLP_layers = nn.ModuleList()
        self.input_dim = input_dim
        in_dim = ae_dims[0] + self.input_dim
        for i, hidden_dim in enumerate(nn_hidden_dims):
            self.MLP_layers.append(nn.BatchNorm1d(in_dim))
            if i > 0:
                self.MLP_layers.append(nn.SiLU())
            if i < len(nn_dropouts):
                self.MLP_layers.append(nn.Dropout(nn_dropouts[i]))
            self.MLP_layers.append(nn.Linear(in_dim, hidden_dim))
            # layers.append(nn.ReLU())
            in_dim = hidden_dim
        self.MLP_layers.append(nn.Linear(in_dim, 1))  # 输出层
        #self.MLP_layers.append(nn.Tanh())
        #why output activation is Tanh?
        self.lr = lr
        self.weight_decay = weight_decay
        self.validation_step_outputs = []
        self.noise_std = noise_std
        self.loss_weights = loss_weights
        self.val_r_square_history = []
        
    def forward(self,inp):
        x = self.ae_1(inp)
        x = x+ self.noise_std * torch.randn(self.input_dim).to(inp.device)
        x = self.ae_2(x)
        x = self.ae_3(x)
        encoder = self.ae_4(x)
        x = self.ae_5(encoder)
        decoder = self.ae_6(x)
        x = self.ae_7(decoder)
        x = self.ae_8(x)
        x = self.ae_9(x)
        x = self.ae_10(x)
        out_ae = self.ae_11(x)

        x = torch.cat((encoder,inp),1)
        for layer in self.MLP_layers:
            x = layer(x)
        return decoder, out_ae, x
    
    def training_step(self,batch):
        x,y,w = batch
        decoder_hat, out_ae_hat,y_hat = self(x)
        decoder_loss = F.mse_loss(decoder_hat,x)  
        y = y.view(-1,1)
        out_ae_loss = F.mse_loss(out_ae_hat,y, reduction = 'none') * w
        y_loss = F.mse_loss(y_hat,y, reduction = 'none') * w
        decoder_loss = decoder_loss.mean()
        out_ae_loss = out_ae_loss.mean()
        y_loss = y_loss.mean()
        
        # y_np = y.clone().detach().cpu().numpy()
        # y_hat_np = y_hat.clone().detach().cpu().numpy()
        # w_np = w.clone().detach().cpu().numpy()
        # val_r_square = r2_val(y_np, y_hat_np,
        #                       w_np)
        r_square = self.r2_val_torch(y,y_hat, w)
        self.log('decoder_loss', decoder_loss, on_step=False, on_epoch=True, batch_size=x.size(0))
        self.log('out_ae_loss', out_ae_loss, on_step=False, on_epoch=True, batch_size=x.size(0))
        self.log('y_loss', y_loss, on_step=False, on_epoch=True, batch_size=x.size(0))
        self.log('r_square', r_square, on_step=False, on_epoch=True, batch_size=x.size(0))
        w_1,w_2,w_3 = self.loss_weights[0],self.loss_weights[1],self.loss_weights[2]
        loss = w_1 * decoder_loss + w_2 * out_ae_loss + w_3 * y_loss
        return loss
        
    def validation_step(self, batch, batch_idx):
        if self.trainer.sanity_checking:
            return  # Skip during sanity check
        x, y, w = batch
        decoder_hat, out_ae_hat, y_hat = self(x)
        
        # Compute validation losses
        decoder_loss = F.mse_loss(decoder_hat, x)
        y = y.view(-1, 1)
        out_ae_loss = F.mse_loss(out_ae_hat, y, reduction='none') * w
        y_loss = F.mse_loss(y_hat, y, reduction='none') * w
        
        # Compute r_square for this batch
        r_square = self.r2_val_torch(y, y_hat, w)
        
        # Log metrics
        self.log('val_decoder_loss', decoder_loss, on_step=False, on_epoch=True, batch_size=x.size(0))
        self.log('val_out_ae_loss', out_ae_loss.mean(), on_step=False, on_epoch=True, batch_size=x.size(0))
        self.log('val_y_loss', y_loss.mean(), on_step=False, on_epoch=True, batch_size=x.size(0))
        self.log('val_r_square', r_square, on_step=False, on_epoch=True, batch_size=x.size(0))
        
        # Store only necessary tensors with detached copies
        self.validation_step_outputs.append((
            y_hat.detach(),  # Detach to free computational graph
            y.detach(),      # Detach to free computational graph
            w.detach()       # Detach to free computational graph
        ))

    def on_validation_epoch_end(self):
        if not self.trainer.sanity_checking:
            try:
                # # Efficient concatenation with proper memory management
                # y_true = torch.cat([x[1] for x in self.validation_step_outputs])
                # y_pred = torch.cat([x[0] for x in self.validation_step_outputs])
                # weights = torch.cat([x[2] for x in self.validation_step_outputs])
                
                # # Compute epoch-level r_square
                # epoch_r_square = self.r2_val_torch(y_true, y_pred, weights)
                
                # # Log and print metrics
                # self.log("val_r_square_epoch", epoch_r_square, prog_bar=True)
                # self.val_r_square_history.append(epoch_r_square.item())
                # print(f"\nValidation R² Score (Epoch {self.trainer.current_epoch}): {epoch_r_square:.4f}")
                current_epoch = self.trainer.current_epoch
                val_decoder_loss = trainer.logged_metrics['val_decoder_loss']
                val_out_ae_loss = trainer.logged_metrics['val_out_ae_loss']
                val_y_loss = trainer.logged_metrics['val_y_loss']
                val_r_square = trainer.logged_metrics['val_r_square']
                self.val_r_square_history.append(val_r_square.item())
                formatted_metrics = {'val_decoder_loss':f"{val_decoder_loss:.5f}",'val_out_ae_loss':f"{val_out_ae_loss:.5f}",
                            'val_y_loss':f"{val_y_loss:.5f}",'val_r_square':f"{val_r_square:.5f}"}
                print(f"Epoch {current_epoch}: {formatted_metrics}")
                if current_epoch!= 0 and current_epoch % 10 == 0:  # Plot every 5 epochs
                    self.plot_r_square_history()
                
            finally:
                # Ensure cleanup happens even if there's an error
                self.validation_step_outputs.clear()

        
    def r2_val_torch(self,y_true: torch.Tensor, y_pred: torch.Tensor, sample_weight: torch.Tensor) -> torch.Tensor:
        """
        Calculate weighted R² score using PyTorch tensors.
        
        Args:
            y_true: Ground truth values
            y_pred: Predicted values
            sample_weight: Weights for each sample
            
        Returns:
            Weighted R² score
        """
        # Ensure all inputs are tensors and on the same device
        if not torch.is_tensor(sample_weight):
            sample_weight = torch.tensor(sample_weight, device=y_true.device)
        
        # Calculate weighted MSE (numerator)
        weighted_mse = torch.sum(sample_weight * (y_pred - y_true) ** 2) / torch.sum(sample_weight)
        
        # Calculate weighted variance (denominator)
        weighted_var = torch.sum(sample_weight * y_true ** 2) / torch.sum(sample_weight) + 1e-38
        
        # Calculate R²
        r2 = 1 - weighted_mse / weighted_var
        
        return r2
    def plot_r_square_history(self):
        """
        Plot the history of validation R² scores
        """
        try:
            import matplotlib.pyplot as plt
            
            plt.figure(figsize=(10, 6))
            plt.plot(self.val_r_square_history, '-o')
            plt.title('Validation R² Score History')
            plt.xlabel('Epoch')
            plt.ylabel('R² Score')
            plt.grid(True)
            plt.savefig(f'r2_history_epoch_{self.trainer.current_epoch}.png')
            plt.close()
        except ImportError:
            print("Matplotlib not available for plotting")
            
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5,
                                                               verbose=True)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'y_loss',
            }
        }

    def on_train_epoch_end(self):
        if self.trainer.sanity_checking:
            return
        epoch = self.trainer.current_epoch
        w_1,w_2,w_3 = self.loss_weights[0],self.loss_weights[1],self.loss_weights[2]
        decoder_loss = trainer.logged_metrics['decoder_loss']
        out_ae_loss = trainer.logged_metrics['out_ae_loss']
        r_square = trainer.logged_metrics['r_square']
        y_loss = trainer.logged_metrics['y_loss']
        # metrics = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in self.trainer.logged_metrics.items()}
        # formatted_metrics = {k: f"{v:.5f}" for k, v in metrics.items()}
        formatted_metrics = {'decoder_loss':f"{decoder_loss:.5f}",'out_ae_loss':f"{out_ae_loss:.5f}",
                            'y_loss':f"{y_loss:.5f}",'r_square':f"{r_square:.5f}"}
        print(f"Epoch {epoch}: {formatted_metrics}")

# Create PyTorch Data Module

In [15]:
class custom_args():
    def __init__(self):
        self.usegpu = True
        self.gpuid = 0
        self.seed = 42
        self.model = 'nn'
        self.use_wandb = False
        self.project = 'js-xs-nn-with-lags'
        self.dname = "./input_df/"
        self.loader_workers = 4
        self.bs = 8192
        self.lr = 1e-3
        self.weight_decay = 5e-4
        self.ae_dropouts = [0.1,0.1]
        self.dropouts = [0.1, 0.1]
        self.ae_dims = [40,120]
        self.n_hidden = [512, 512, 256]
        self.noise_std = 1
        self.loss_weights = [0.1,1,1]
        self.patience = 10
        self.max_epochs = 100
        self.N_fold = 5



args = custom_args()

# checking device
device = torch.device(f'cuda:{args.gpuid}' if torch.cuda.is_available() and args.usegpu else 'cpu')
accelerator = 'gpu' if torch.cuda.is_available() and args.usegpu else 'cpu'
loader_device = 'cpu'


# Initialize Data Module

df[feature_names] = df[feature_names].fillna(method = 'ffill').fillna(0)
valid[feature_names] = valid[feature_names].fillna(method = 'ffill').fillna(0)
data_module = DataModule(df, batch_size=args.bs, valid_df=valid, accelerator=loader_device)

NameError: name 'df' is not defined

# Create Model and Training

In [22]:
import gc
try:
    del df
except NameError:
    pass 
gc.collect()
pl.seed_everything(args.seed)
for fold in range(args.N_fold):
    data_module.setup(fold, args.N_fold)
    # Obtain input dimension
    input_dim = data_module.train_dataset.features.shape[1]
    # Initialize Model
    model = AE_NN(
        input_dim=input_dim,
        ae_dims = args.ae_dims,
        nn_hidden_dims=args.n_hidden,
        ae_dropouts = args.ae_dropouts,
        nn_dropouts=args.dropouts,
        noise_std = args.noise_std,
        loss_weights = args.loss_weights,
        lr=args.lr,
        weight_decay=args.weight_decay
    )
    # Initialize Logger
    if args.use_wandb:
        wandb_run = wandb.init(project=args.project, config=vars(args), reinit=True)
        logger = WandbLogger(experiment=wandb_run)
    else:
        logger = None
    # Initialize Callbacks
    early_stopping = EarlyStopping('y_loss', patience=args.patience, mode='min', verbose=False)
    checkpoint_callback = ModelCheckpoint(monitor='y_loss', mode='min', save_top_k=1, verbose=False, filename=f"./models/nn_{fold}.model") 
    timer = Timer()
    # Initialize Trainer
    trainer = Trainer(
        max_epochs=args.max_epochs,
        accelerator=accelerator,
        devices=[args.gpuid] if args.usegpu else None,
        logger=logger,
        callbacks=[early_stopping, checkpoint_callback, timer],
        enable_progress_bar=True,
        num_sanity_val_steps=0
    )
    # Start Training
    trainer.fit(model, data_module.train_dataloader(args.loader_workers), data_module.val_dataloader(args.loader_workers))
    # You can find trained best model in your local path
    print(f'Fold-{fold} Training completed in {timer.time_elapsed("train"):.2f}s')


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Training: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0: {'val_decoder_loss': '296.75375', 'val_out_ae_loss': '1.21437', 'val_y_loss': '1.29764', 'val_r_square': '-0.07140'}
Epoch 0: {'decoder_loss': '494.49680', 'out_ae_loss': '1.50280', 'y_loss': '1.48437', 'r_square': '0.00958'}


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1: {'val_decoder_loss': '125.58650', 'val_out_ae_loss': '1.21347', 'val_y_loss': '3.11844', 'val_r_square': '-1.08761'}
Epoch 1: {'decoder_loss': '158.46252', 'out_ae_loss': '1.49631', 'y_loss': '1.44562', 'r_square': '0.03529'}


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2: {'val_decoder_loss': '124.58168', 'val_out_ae_loss': '1.22663', 'val_y_loss': '5.28300', 'val_r_square': '-3.32389'}
Epoch 2: {'decoder_loss': '20.99688', 'out_ae_loss': '1.49445', 'y_loss': '1.43332', 'r_square': '0.04347'}


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3: {'val_decoder_loss': '123.97749', 'val_out_ae_loss': '1.24730', 'val_y_loss': '30.71189', 'val_r_square': '-17.51278'}
Epoch 3: {'decoder_loss': '5.08809', 'out_ae_loss': '1.49281', 'y_loss': '1.42359', 'r_square': '0.04995'}


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4: {'val_decoder_loss': '151.54704', 'val_out_ae_loss': '1.29259', 'val_y_loss': '11.21717', 'val_r_square': '-12.45285'}
Epoch 4: {'decoder_loss': '4.22250', 'out_ae_loss': '1.49194', 'y_loss': '1.41813', 'r_square': '0.05362'}


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5: {'val_decoder_loss': '136.65465', 'val_out_ae_loss': '2.64616', 'val_y_loss': '64.88599', 'val_r_square': '-53.48793'}
Epoch 5: {'decoder_loss': '4.04589', 'out_ae_loss': '1.48910', 'y_loss': '1.41483', 'r_square': '0.05582'}


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6: {'val_decoder_loss': '103.50160', 'val_out_ae_loss': '5.11906', 'val_y_loss': '50.50990', 'val_r_square': '-40.62472'}
Epoch 6: {'decoder_loss': '3.90974', 'out_ae_loss': '1.48367', 'y_loss': '1.41235', 'r_square': '0.05746'}


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7: {'val_decoder_loss': '122.42485', 'val_out_ae_loss': '6.86431', 'val_y_loss': '66.36824', 'val_r_square': '-66.17456'}
Epoch 7: {'decoder_loss': '3.79798', 'out_ae_loss': '1.47956', 'y_loss': '1.41017', 'r_square': '0.05893'}


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x781259593e80>>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(


NameError: name 'exit' is not defined

KeyboardInterrupt: 
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x781259593e80>>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


In [23]:
data_module.train_dataset.features.shape

torch.Size([5882536, 88])