In [None]:
!pip install tab-transformer-pytorch -q --no-index --find-links=/kaggle/input/jane-street-import/tab-transformer-pytorch
!pip install hyper_connections -q --no-index --find-links=/kaggle/input/jane-street-import/hyper_connections

In [None]:
import os,gc
import pickle
import polars as pl
import numpy as np
import pandas as pd
import joblib

In [None]:
class CONFIG():
    """Static settings shared by the preprocessing and inference cells."""

    def __init__(self):
        # Parquet outputs of the upstream lag-creation notebook.
        self.train_data_path = '/kaggle/input/data-create-create-lags/training.parquet'
        self.valid_data_path = '/kaggle/input/data-create-create-lags/validation.parquet'
        # 79 raw features followed by the lag-1 column of each of the 9 responders.
        raw_features = [f"feature_{i:02d}" for i in range(79)]
        lag_features = [f"responder_{idx}_lag_1" for idx in range(9)]
        self.feature_names = raw_features + lag_features
        self.label_name = 'responder_6'
        self.weight_name = 'weight'
        # Categorical columns; every other entry of feature_names is treated as continuous
        # (original order preserved).
        self.feature_cat = ["feature_09", "feature_10", "feature_11"]
        categorical = set(self.feature_cat)
        self.feature_cont = list(filter(lambda name: name not in categorical, self.feature_names))
        self.train_start_dt = 1100


my_config = CONFIG()

In [None]:
# Per-column mean/std lookups consumed by `standardize` in `predict`
# (presumably computed on the training data — confirm against the stats-building notebook).
# NOTE(review): joblib.load unpickles the file, which can execute code — only load trusted artifacts.
data_stats = joblib.load('/kaggle/input/my-own-js/data_stats.pkl')
means = data_stats['mean']
stds = data_stats['std']

def standardize(df, feature_cols, means, stds):
    """Z-score each column of ``feature_cols`` as ``(x - means[col]) / stds[col]``.

    Columns are overwritten under their original names; all other columns pass through
    unchanged. Nulls stay null (they are filled later by the caller).
    """
    scaled = []
    for name in feature_cols:
        scaled.append(((pl.col(name) - means[name]) / stds[name]).alias(name))
    return df.with_columns(scaled)

# Raw category code -> contiguous index for each categorical feature
# (feature_09: 22 codes -> 0..21, feature_10: 9 codes -> 0..8, feature_11: 31 codes -> 0..30).
# Codes absent from these tables are encoded as -1 by `encode_column`.
# NOTE(review): these sizes presumably have to match `cat_cardinalities` stored in the model checkpoint.
category_mappings = {'feature_09': {2: 0, 4: 1, 9: 2, 11: 3, 12: 4, 14: 5, 15: 6, 25: 7, 26: 8, 30: 9, 34: 10, 42: 11, 44: 12, 46: 13, 49: 14, 50: 15, 57: 16, 64: 17, 68: 18, 70: 19, 81: 20, 82: 21},
 'feature_10': {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 10: 7, 12: 8},
 'feature_11': {9: 0, 11: 1, 13: 2, 16: 3, 24: 4, 25: 5, 34: 6, 40: 7, 48: 8, 50: 9, 59: 10, 62: 11, 63: 12, 66: 13,
  76: 14, 150: 15, 158: 16, 159: 17, 171: 18, 195: 19, 214: 20, 230: 21, 261: 22, 297: 23, 336: 24, 376: 25, 388: 26, 410: 27, 522: 28, 534: 29, 539: 30},
}

def encode_column(df, column, mapping):
    """Replace the raw codes in ``column`` with their index from ``mapping`` (as Int16).

    Codes not present in ``mapping`` become -1; null handling follows the defaults of
    polars ``map_elements``.

    NOTE(review): FTTransformer adds per-column offsets to categorical indices, so -1 likely
    lands on a neighbouring column's embedding (or a special token) instead of failing —
    confirm this matches how unknown codes were handled in training.
    """
    lookup = mapping.get
    encoded = pl.col(column).map_elements(lambda code: lookup(code, -1), return_dtype=pl.Int16)
    return df.with_columns(encoded.alias(column))

In [None]:
import os
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
# import pytorch_lightning as pl
from pytorch_lightning import (LightningDataModule, LightningModule, Trainer)
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, Timer
from pytorch_lightning.loggers import WandbLogger
import wandb
import pandas as pd
import numpy as np
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tab_transformer_pytorch import FTTransformer


In [None]:
class R2Loss(nn.Module):
    """Loss ``sum((y_pred - y_true)^2) / sum(y_true^2)``, i.e. 1 - R^2 against a zero baseline."""

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        residual_ss = (y_pred - y_true).pow(2).sum()
        # Uncentred "total" sum of squares; the tiny epsilon avoids 0/0 when every target is zero.
        total_ss = y_true.pow(2).sum()
        return residual_ss / (total_ss + 1e-38)

# Weighted R^2 metric used to monitor training and validation.
def r2_val(y_true, y_pred, sample_weight):
    """Return ``1 - wavg((y_pred - y_true)^2) / wavg(y_true^2)`` with weights ``sample_weight``.

    The denominator uses uncentred targets (zero baseline) plus a tiny epsilon to avoid 0/0.
    """
    mean_sq_err = np.average((y_pred - y_true) ** 2, weights=sample_weight)
    mean_sq_target = np.average(y_true ** 2, weights=sample_weight)
    return 1 - mean_sq_err / (mean_sq_target + 1e-38)


class FTTransformerModel(LightningModule):
    """LightningModule wrapping ``tab_transformer_pytorch.FTTransformer`` for single-output regression.

    Trained with a sample-weighted MSE loss; logs weighted R^2 (via ``r2_val``) each epoch.

    Args:
        n_cont_features: number of continuous input columns.
        cat_cardinalities: number of distinct encoded values for each categorical column.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
    """

    def __init__(self, n_cont_features, cat_cardinalities, lr, weight_decay):
        super().__init__()
        # Records the constructor arguments in checkpoints so load_from_checkpoint can rebuild the model.
        self.save_hyperparameters()
        # The attribute name `model` is part of the checkpoint's state_dict keys; do not rename it.
        self.model = FTTransformer(
            categories=cat_cardinalities,    # number of unique values within each categorical column
            num_continuous=n_cont_features,  # number of continuous values
            dim=4,                           # embedding dimension (paper uses 32)
            dim_out=1,                       # single regression output
            depth=3,                         # transformer depth (paper recommends 6)
            heads=2,                         # attention heads (paper recommends 8)
            attn_dropout=0.2,                # post-attention dropout
            ff_dropout=0.2,                  # feed-forward dropout
        )
        self.lr = lr
        self.weight_decay = weight_decay
        # Per-epoch buffers of (prediction, target, weight) batches used for the R^2 metrics.
        self.training_step_outputs = []
        self.validation_step_outputs = []

    def forward(self, x_cont, x_cat):
        """Return a 1-D tensor with one prediction per row.

        Note the argument order: this method takes (continuous, categorical), while the
        wrapped FTTransformer is called with (categorical, continuous).
        """
        return self.model(x_cat, x_cont).squeeze(-1)

    @staticmethod
    def _weighted_mse(y_hat, y, w):
        """Mean of the per-sample squared errors multiplied by the sample weights."""
        return (F.mse_loss(y_hat, y, reduction='none') * w).mean()

    @staticmethod
    def _epoch_r2(outputs):
        """Weighted R^2 over buffered (prediction, target, weight) batches; None if the buffer is empty."""
        if not outputs:
            return None
        y_hat = torch.cat([o[0] for o in outputs]).detach().cpu().numpy()
        y = torch.cat([o[1] for o in outputs]).cpu().numpy()
        w = torch.cat([o[2] for o in outputs]).cpu().numpy()
        return r2_val(y, y_hat, w)

    def training_step(self, batch):
        """Compute and log the weighted MSE for one batch of (x_cont, x_cat, y, w)."""
        x_cont, x_cat, y, w = batch
        y_hat = self(x_cont, x_cat)
        loss = self._weighted_mse(y_hat, y, w)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=x_cont.size(0))
        # Detach before buffering: storing the graph-attached tensor would keep every step's
        # autograd graph (and its activations) alive until the end of the epoch.
        self.training_step_outputs.append((y_hat.detach(), y, w))
        return loss

    def validation_step(self, batch):
        """Compute and log the weighted MSE for one validation batch of (x_cont, x_cat, y, w)."""
        x_cont, x_cat, y, w = batch
        y_hat = self(x_cont, x_cat)
        loss = self._weighted_mse(y_hat, y, w)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=x_cont.size(0))
        self.validation_step_outputs.append((y_hat.detach(), y, w))
        return loss

    def on_validation_epoch_end(self):
        """Log the weighted validation R^2 (skipped during Lightning's sanity check) and reset the buffer."""
        if not self.trainer.sanity_checking:
            val_r_square = self._epoch_r2(self.validation_step_outputs)
            if val_r_square is not None:
                self.log("val_r_square", val_r_square, prog_bar=True, on_step=False, on_epoch=True)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Adam with ReduceLROnPlateau, halving the LR after 5 epochs without val_loss improvement."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        # NOTE(review): `verbose` is deprecated in newer PyTorch releases; drop it if it warns or errors.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5,
                                                               verbose=True)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
            }
        }

    def on_train_epoch_end(self):
        """Log the weighted training R^2, reset the buffer, and print all logged metrics for the epoch."""
        if self.trainer.sanity_checking:
            return

        train_r_square = self._epoch_r2(self.training_step_outputs)
        if train_r_square is not None:
            self.log("train_r_square", train_r_square, prog_bar=True, on_step=False, on_epoch=True)
        self.training_step_outputs.clear()

        epoch = self.trainer.current_epoch
        metrics = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in self.trainer.logged_metrics.items()}
        formatted_metrics = {k: f"{v:.5f}" for k, v in metrics.items()}
        print(f"Epoch {epoch}: {formatted_metrics}")

In [None]:
# Use the first GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Checkpoint presumably produced by the FTTransformerModel training run; load_from_checkpoint
# rebuilds the module from the hyperparameters stored by save_hyperparameters(), then the
# weights are moved to `device`. predict() switches it to eval mode before inference.
model_path = '/kaggle/input/ft-transformer-model/nn_0.model.ckpt'
model = FTTransformerModel.load_from_checkpoint(model_path).to(device)

In [None]:
# Most recent raw lags frame received from the gateway, and its last row per (date_id, symbol_id).
lags_ : pl.DataFrame | None = None
lags_last : pl.DataFrame | None = None


def predict(test: pl.DataFrame, lags: pl.DataFrame | None) -> pl.DataFrame | pd.DataFrame:
    """Predict ``responder_6`` for one batch served by the inference gateway.

    Args:
        test: current batch; must contain ``row_id``, ``date_id``, ``time_id``, ``symbol_id``
            and the raw ``feature_XX`` columns.
        lags: lagged responders when the gateway supplies them, else ``None``. The last row per
            (date_id, symbol_id) is cached in the module-level ``lags_last`` and reused for
            later batches.

    Returns:
        A polars DataFrame with columns ``row_id`` and ``responder_6`` (clipped to [-5, 5]) and
        one row per row of ``test``.
    """
    global lags_
    global lags_last

    id_column_types = {
        'date_id': pl.Int32,
        'time_id': pl.Int32,
        'symbol_id': pl.Int32
    }

    if lags is not None:
        lags_ = lags
        # Keep only the last record of each (date_id, symbol_id) and cast the ids once so the
        # join keys below have matching dtypes.
        lags_last = (
            lags.group_by(["date_id", "symbol_id"], maintain_order=True)
            .last()
            .cast(id_column_types)
        )

    for col in my_config.feature_cat:
        test = encode_column(test, col, category_mappings[col])

    test = test.cast(id_column_types)

    if lags_last is not None:
        test = test.join(lags_last, on=["date_id", "symbol_id"], how="left")
    else:
        # No lags received yet: add the lag columns as nulls so they are zero-filled below,
        # the same outcome as a left join that finds no match.
        missing_cols = [c for c in my_config.feature_names if c not in test.columns]
        test = test.with_columns([pl.lit(None, dtype=pl.Float32).alias(c) for c in missing_cols])

    # Standardize first, then fill nulls with 0 (0 after standardizing corresponds to the stored mean).
    test = standardize(test, my_config.feature_cont, means, stds)
    test = test.with_columns([
        pl.col(col).fill_null(0).alias(col) for col in my_config.feature_names
    ])

    x_cont = torch.tensor(test[my_config.feature_cont].to_numpy(), dtype=torch.float32).to(device)
    x_cat = torch.tensor(test[my_config.feature_cat].to_numpy(), dtype=torch.int64).to(device)

    model.eval()
    with torch.no_grad():
        # forward() already squeezes the output dim. reshape(-1) (instead of a second squeeze)
        # keeps a 1-row batch as a length-1 array rather than collapsing it to a 0-d scalar.
        preds = model(x_cont, x_cat).reshape(-1).cpu().numpy()

    predictions = test.select('row_id').with_columns(
        pl.Series(
            name='responder_6',
            values=np.clip(preds, a_min=-5, a_max=5),
            dtype=pl.Float64,
        )
    )

    # The gateway expects a DataFrame with columns 'row_id', 'responder_6'
    # and as many rows as the test data.
    assert isinstance(predictions, pl.DataFrame | pd.DataFrame)
    assert list(predictions.columns) == ['row_id', 'responder_6']
    assert len(predictions) == len(test)

    return predictions

In [None]:
import kaggle_evaluation.jane_street_inference_server

# Wrap `predict` in the competition inference server.
# During the scoring rerun (KAGGLE_IS_COMPETITION_RERUN is set) the server serves gateway requests;
# otherwise a local gateway is run over the public test and lags parquet files to exercise `predict`.
inference_server = kaggle_evaluation.jane_street_inference_server.JSInferenceServer(predict)
if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    inference_server.serve()
else:
    inference_server.run_local_gateway(
        (
            '/kaggle/input/jane-street-real-time-market-data-forecasting/test.parquet',
            '/kaggle/input/jane-street-real-time-market-data-forecasting/lags.parquet',
        )
    )