In [1]:
!pip -q install rtdl_num_embeddings delu rtdl_revisiting_models 

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import rtdl_num_embeddings
from rtdl_num_embeddings import compute_bins
from torch.utils.data import TensorDataset, DataLoader, Dataset, ConcatDataset

from sklearn.model_selection import train_test_split

from sklearn.metrics import r2_score
import pandas as pd
import math
import numpy as np
import delu
from tqdm import tqdm
import polars as pl
from collections import OrderedDict
import sys

from tanm_reference import Model, make_parameter_groups


from torch import Tensor
from typing import List, Callable, Union, Any, TypeVar, Tuple

import joblib

import gc

# Load data

In [3]:

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Raw feature columns: feature_00 ... feature_78.
feature_train_list = [f"feature_{idx:02d}" for idx in range(79)]

target_col = "responder_6"

# Model inputs: the raw features followed by the nine lagged responder columns.
feature_train = [
    *feature_train_list,
    *(f"responder_{idx}_lag_1" for idx in range(9)),
]

# Inclusive date_id window used to select the training rows.
start_dt, end_dt = 1100, 1599

# Columns handled as categorical; everything else in feature_train is continuous.
feature_cat = ["feature_09", "feature_10", "feature_11", 'symbol_id']
feature_cont = [name for name in feature_train if name not in feature_cat]
# Columns that get standardized (non-categorical raw features + lagged responders).
std_feature = (
    [name for name in feature_train_list if name not in feature_cat]
    + [f"responder_{idx}_lag_1" for idx in range(9)]
)

batch_size = 8192  # alternative kept for reference: 2048
num_epochs = 30

# Statistics keyed by column name, used by `standardize`.
# Alternative source: "/kaggle/input/jane-street-data-preprocessing/data_stats.pkl"
# NOTE(review): joblib.load unpickles the file; only load artifacts you trust.
data_stats = joblib.load('/kaggle/input/my-own-js/data_stats.pkl')
means, stds = data_stats['mean'], data_stats['std']

def standardize(df, feature_cols, means, stds):
    """Return `df` with every column in `feature_cols` replaced by (value - mean) / std.

    Args:
        df: polars frame exposing `with_columns`.
        feature_cols: names of the columns to rescale.
        means: mapping from column name to the value subtracted from that column.
        stds: mapping from column name to the value that column is divided by.

    Returns:
        The frame produced by `df.with_columns`, with the rescaled columns
        keeping their original names.
    """
    scaled_exprs = []
    for col in feature_cols:
        centered = pl.col(col) - means[col]
        scaled_exprs.append((centered / stds[col]).alias(col))
    return df.with_columns(scaled_exprs)

In [4]:
# Lazily scan both preprocessed parquet files and stack them into one lazy frame,
# so the train/validation split below is driven purely by date_id.
motono0223_train_original = pl.scan_parquet(
    "/kaggle/input/js24-preprocessing-create-lags/training.parquet"
)
motono0223_valid_original = pl.scan_parquet(
    "/kaggle/input/js24-preprocessing-create-lags/validation.parquet"
)
motono0223_original = pl.concat([motono0223_train_original, motono0223_valid_original])

# Training rows: start_dt <= date_id <= end_dt. Validation rows: date_id > end_dt.
train_original = motono0223_original.filter(
    pl.col("date_id").ge(start_dt) & pl.col("date_id").le(end_dt)
)
valid_original = motono0223_original.filter(pl.col("date_id").gt(end_dt))


In [5]:
# train_original = pl.scan_parquet("/kaggle/input/jane-street-data-preprocessing/training.parquet")
# valid_original = pl.scan_parquet("/kaggle/input/jane-street-data-preprocessing/validation.parquet")

# def get_category_mapping(df, column):
#     unique_values = df.select([column]).unique().collect().to_series()
#     return {cat: idx for idx, cat in enumerate(unique_values)}

# category_mappings = {col: get_category_mapping(all_original, col) for col in feature_cat + ['symbol_id']}

# Raw category code -> contiguous index, per categorical column.
# Each mapping is built by enumerating the raw codes in the order listed, so the
# first listed code gets index 0, the next index 1, and so on.
category_mappings = {
    'feature_09': {
        code: idx
        for idx, code in enumerate([
            2, 4, 9, 11, 12, 14, 15, 25, 26, 30, 34,
            42, 44, 46, 49, 50, 57, 64, 68, 70, 81, 82,
        ])
    },
    'feature_10': {
        code: idx
        for idx, code in enumerate([1, 2, 3, 4, 5, 6, 7, 10, 12])
    },
    'feature_11': {
        code: idx
        for idx, code in enumerate([
            9, 11, 13, 16, 24, 25, 34, 40, 48, 50, 59, 62, 63, 66, 76,
            150, 158, 159, 171, 195, 214, 230, 261, 297, 336, 376, 388,
            410, 522, 534, 539,
        ])
    },
    # symbol_id 0..38 and time_id 0..967 map to themselves.
    'symbol_id': {code: code for code in range(39)},
    'time_id': {code: code for code in range(968)},
}


def encode_column(df, column, mapping):
    """Replace the values of `column` with their index from `mapping`.

    Args:
        df: polars frame exposing `with_columns`.
        column: name of the column to encode; the result keeps the same name.
        mapping: dict from raw category value to integer index.

    Returns:
        The frame produced by `df.with_columns`, where `column` is an Int16
        column. Values passed to the lookup that are missing from `mapping`
        are encoded as -1.
    """
    encoded = pl.col(column).map_elements(
        lambda category: mapping.get(category, -1),
        return_dtype=pl.Int16,
    )
    return df.with_columns(encoded.alias(column))

# Encode every categorical column of both splits with the same mapping.
for col in feature_cat:
    train_original, valid_original = (
        encode_column(lazy_frame, col, category_mappings[col])
        for lazy_frame in (train_original, valid_original)
    )

In [6]:
# Rescale both splits with the same (means, stds) statistics.
train_original, valid_original = (
    standardize(lazy_frame, std_feature, means, stds)
    for lazy_frame in (train_original, valid_original)
)

In [7]:
# Training rows: start_dt <= date_id <= end_dt; no sort is applied here.
# Disabled variants kept for reference:
#   .sort(['date_id', 'time_id'])
#   .select(feature_train + [target_col, 'weight', 'symbol_id', 'time_id'])
train_data = (
    train_original
    .filter(pl.col("date_id").ge(start_dt))
    .filter(pl.col("date_id").le(end_dt))
)

# Validation rows: date_id > end_dt, ordered by (date_id, time_id).
# Disabled variant kept for reference:
#   .select(feature_train + [target_col, 'weight', 'symbol_id', 'time_id'])
valid_data = (
    valid_original
    .filter(pl.col("date_id") > end_dt)
    .sort(['date_id', 'time_id'])
)


In [8]:
# Execute the lazy queries and keep, in this order: the model inputs, the target,
# the sample weight, symbol_id and time_id. Later cells slice the arrays by position.
train_numpy, valid_numpy = (
    lazy_frame.collect().to_pandas()[feature_train + [target_col, 'weight', 'symbol_id', 'time_id']].values
    for lazy_frame in (train_data, valid_data)
)
train_numpy.shape, valid_numpy.shape

((18424912, 92), (3679368, 92))

In [9]:
pd.set_option('display.max_columns', 100)

# Fraction of NaN entries per column of the training array, shown as a one-row table.
# The column labels are passed as a flat list: wrapping them in an extra list (as
# `columns=[labels]`) makes pandas build a one-level MultiIndex instead of a plain Index.
nan_stats = pd.DataFrame(
    [np.isnan(train_numpy).mean(axis=0)],
    columns=feature_train + [target_col, 'weight', 'symbol_id', 'time_id'],
)
nan_stats

Unnamed: 0,feature_00,feature_01,feature_02,feature_03,feature_04,feature_05,feature_06,feature_07,feature_08,feature_09,feature_10,feature_11,feature_12,feature_13,feature_14,feature_15,feature_16,feature_17,feature_18,feature_19,feature_20,feature_21,feature_22,feature_23,feature_24,feature_25,feature_26,feature_27,feature_28,feature_29,feature_30,feature_31,feature_32,feature_33,feature_34,feature_35,feature_36,feature_37,feature_38,feature_39,feature_40,feature_41,feature_42,feature_43,feature_44,feature_45,feature_46,feature_47,feature_48,feature_49,feature_50,feature_51,feature_52,feature_53,feature_54,feature_55,feature_56,feature_57,feature_58,feature_59,feature_60,feature_61,feature_62,feature_63,feature_64,feature_65,feature_66,feature_67,feature_68,feature_69,feature_70,feature_71,feature_72,feature_73,feature_74,feature_75,feature_76,feature_77,feature_78,responder_0_lag_1,responder_1_lag_1,responder_2_lag_1,responder_3_lag_1,responder_4_lag_1,responder_5_lag_1,responder_6_lag_1,responder_7_lag_1,responder_8_lag_1,responder_6,weight,symbol_id,time_id
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.013187,0.0,0.0,0.0,0.0,0.0,0.0,0.024793,5.427434e-08,0.004132,0.0,0.0,0.0,0.001576,0.0,0.0,0.0,0.0,0.001576,0.001576,0.0,0.0,0.0,0.001576,0.009858,0.009858,0.0,0.0,0.0,0.0,0.0,0.070248,2.170974e-07,0.018595,0.070248,2.170974e-07,0.018595,0.000148,0.000148,0.0,0.0,0.0,0.070248,0.0,0.018595,0.070248,0.0,0.018595,0.0,0.0,0.009858,0.0,0.0,0.0,2e-05,5e-06,6e-06,0.000148,0.000148,0.0,0.0,0.0,0.0,0.0,0.0,0.009978,0.009978,0.001646,0.001646,0.000374,0.000374,0.018598,0.018598,0.018598,0.018598,0.018598,0.018598,0.018598,0.018598,0.018598,0.0,0.0,0.0,0.0


In [10]:
# Replace every NaN in both arrays with 0.
# Each call returns a new array that is bound back to the same name, one array at a
# time, so the previous array can be released before the next call runs.
# NOTE(review): with its default arguments np.nan_to_num also replaces +/-inf with
# very large finite values — confirm that is intended.
train_numpy = np.nan_to_num(train_numpy,nan=0)
valid_numpy = np.nan_to_num(valid_numpy,nan=0)

In [11]:
%%time

# y_valid_data = y_valid.collect().to_numpy().squeeze(-1)
# w_valid_data = w_valid.collect().to_numpy().squeeze(-1)

# Copy both arrays into float32 tensors (features, target, weight, symbol_id and
# time_id all live in the same 2-D tensor; the training loop slices it by column).
train_data_tensor = torch.tensor(train_numpy, dtype=torch.float32)
valid_data_tensor = torch.tensor(valid_numpy, dtype=torch.float32)

train_ds = TensorDataset(train_data_tensor)
valid_ds = TensorDataset(valid_data_tensor)

# Disabled second validation split, kept for reference:
# valid_data2_tensor = torch.tensor(valid_data_2.collect().to_numpy(), dtype=torch.float32)
# valid2_ds = TensorDataset(valid_data2_tensor)
# valid2_dl = DataLoader(valid2_ds, batch_size=batch_size, num_workers=4, pin_memory=True, shuffle=False)

# When all_data is True the validation rows are appended to the training set.
all_data = False
if all_data:
    train_ds = ConcatDataset([train_ds, valid_ds])

# Training batches are shuffled; validation batches keep their order.
train_dl = DataLoader(train_ds, batch_size=batch_size, num_workers=4, pin_memory=True, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=batch_size, num_workers=4, pin_memory=True, shuffle=False)

CPU times: user 1.86 s, sys: 4.51 s, total: 6.36 s
Wall time: 3.21 s


In [12]:
# The tensors built above hold their own copies, so drop the NumPy arrays and
# ask the garbage collector to run.
del train_numpy
del valid_numpy
gc.collect()

0

# Define Model

In [13]:
# Regression setup: no class count.
n_classes = None

# Number of continuous inputs passed to the model (alternative tried: 89).
n_cont_features = 85

# Four categorical inputs, in the order the training loop concatenates them:
# feature_09, feature_10, feature_11, symbol_id.
n_cat_features = 4
# Cardinality declared for each categorical input
# (alternative tried: [83, 13, 540, 40]).
cat_cardinalities = [23, 10, 32, 40]

## TabM Model

In [14]:
class LogCoshLoss(nn.Module):
    """Mean log-cosh regression loss: ``mean(log(cosh(y_pred - y_true)))``.

    The loss is evaluated with the algebraically equivalent form
    ``|d| + softplus(-2|d|) - log(2)``. Evaluating ``cosh`` directly overflows
    to ``inf`` in float32 once ``|d|`` is large, which would turn the whole
    batch loss (and its gradients) into inf/NaN; the rewritten form stays finite.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        """Return the scalar mean log-cosh of the element-wise error.

        Args:
            y_pred: tensor of predictions.
            y_true: tensor of targets, broadcastable against `y_pred`.

        Returns:
            0-dim tensor holding the mean loss over all elements.
        """
        abs_err = torch.abs(y_pred - y_true)
        # log(cosh(d)) = log((e^|d| + e^-|d|) / 2) = |d| + log(1 + e^(-2|d|)) - log(2)
        loss = abs_err + F.softplus(-2.0 * abs_err) - math.log(2.0)
        return torch.mean(loss)

In [15]:
%%time

# Architecture: plain TabM, no feature binning.
arch_type = 'tabm'
bins = None

# Alternative kept for reference: TabM-mini with piecewise-linear embeddings.
# arch_type = 'tabm-mini'
# bins_input = train_data_tensor[:, :-4][:, [col for col in range(train_data_tensor[:, :-4].shape[1]) if col not in [9, 10, 11]]]
# bins = compute_bins(bins_input[torch.randperm(len(bins_input))[:1000000]], ...)
# del bins_input
# gc.collect()

k = 16  # number of ensemble outputs

# Numerical-embedding alternatives kept for reference (currently disabled):
#   {'type': 'PiecewiseLinearEmbeddings', 'd_embedding': 16, 'activation': True, 'version': 'B'}  (requires bins)
#   {'type': 'PeriodicEmbeddings', 'd_embedding': 16, 'lite': True}
model = Model(
    n_num_features=n_cont_features,
    cat_cardinalities=cat_cardinalities,
    n_classes=n_classes,
    backbone=dict(type='MLP', n_blocks=3, d_block=256, dropout=0.15),
    bins=bins,
    num_embeddings=None,
    arch_type=arch_type,
    k=k,
).to(device)

# Parameter groups come from make_parameter_groups(model) rather than model.parameters().
optimizer = torch.optim.AdamW(make_parameter_groups(model), lr=5e-4, weight_decay=1e-4)

# loss_fn = nn.MSELoss()
class R2Loss(nn.Module):
    """Ratio of the summed squared error to the summed squared target.

    Computes ``sum((y_pred - y_true)^2) / (sum(y_true^2) + 1e-38)``. The
    denominator uses the raw (uncentred) targets.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        """Return the 0-dim loss tensor for predictions `y_pred` and targets `y_true`."""
        squared_error = (y_pred - y_true).pow(2).sum()
        target_energy = y_true.pow(2).sum()
        # The tiny constant only guards against an exactly-zero denominator.
        return squared_error / (target_energy + 1e-38)

# Training criterion: mean squared error.
# Alternatives kept for reference: nn.HuberLoss(delta=0.2), R2Loss()
loss_fn = nn.MSELoss()

CPU times: user 488 ms, sys: 225 ms, total: 713 ms
Wall time: 1.28 s


# Training

In [16]:
# Early stopping tracks a score where larger is better (mode="max") and is given
# `patience` as its tolerance.
patience = 5
early_stopping = delu.tools.EarlyStopping(patience, mode="max")

# Best validation score seen so far and the epoch that produced it.
best = dict(val=-math.inf, epoch=-1)

# Timer used in the per-epoch log line; started here.
timer = delu.tools.Timer()
timer.run()

In [None]:
def r2_val(y_true, y_pred, sample_weight):
    """Weighted, zero-mean R² score.

    Computes ``1 - sum(w * (y_true - y_pred)^2) / sum(w * y_true^2)``. The
    denominator uses the raw targets (no mean subtraction).

    Args:
        y_true: array of targets.
        y_pred: array of predictions, same shape as `y_true`.
        sample_weight: array of per-sample weights, same shape as `y_true`.

    Returns:
        The score as a NumPy scalar.
    """
    weighted_sq_error = sample_weight * np.square(y_true - y_pred)
    weighted_sq_target = sample_weight * np.square(y_true)
    return 1 - np.sum(weighted_sq_error) / np.sum(weighted_sq_target)

# Standard deviation of the Gaussian noise added to the continuous inputs as a
# training-time augmentation. It is NOT applied during validation.
TRAIN_NOISE_STD = 0.035
# Positions of feature_09 / feature_10 / feature_11 inside the feature slice.
CAT_FEATURE_IDX = [9, 10, 11]


def _split_batch(batch_tensor):
    """Slice one 2-D batch into model inputs, target and sample weight.

    The last four columns of `batch_tensor` are, in order: target, weight,
    symbol_id, time_id; everything before them is the feature slice.
    time_id is not used by the model and is therefore not returned.

    Returns:
        (x_cont, x_cat, y, w), all on `device`. `x_cat` is int64 and holds the
        three categorical features followed by symbol_id.
    """
    features = batch_tensor[:, :-4].to(device)
    y = batch_tensor[:, -4].to(device)
    w = batch_tensor[:, -3].to(device)
    symbol = batch_tensor[:, -2].to(device)

    cont_idx = [col for col in range(features.shape[1]) if col not in CAT_FEATURE_IDX]
    x_cont = features[:, cont_idx]
    x_cat = torch.concat(
        [features[:, CAT_FEATURE_IDX], symbol.unsqueeze(-1)], axis=1
    ).to(torch.int64)
    return x_cont, x_cat, y, w


for epoch in range(num_epochs):
    # ---------------------------------------------------------------- training
    model.train()

    train_pred_list = []
    with tqdm(train_dl, total=len(train_dl), leave=True) as phar:
        for train_tensor in phar:
            optimizer.zero_grad()
            x_cont_input, x_cat_input, y_input, w_input = _split_batch(train_tensor[0])
            # Noise augmentation on the continuous inputs (training only).
            x_cont_input = x_cont_input + torch.randn_like(x_cont_input) * TRAIN_NOISE_STD

            # One prediction per ensemble member: the loss is taken over all k
            # outputs, each compared against the same target.
            output = model(x_cont_input, x_cat_input).squeeze(-1)
            loss = loss_fn(output.flatten(0, 1), y_input.repeat_interleave(k))

            loss.backward()
            optimizer.step()

            # Keep only detached CPU copies for the epoch-level metric, so no
            # autograd history or GPU memory is held across the whole epoch.
            train_pred_list.append(
                (output.detach().mean(1).cpu(), y_input.cpu(), w_input.cpu())
            )

            # Iterating over the bar already advances it; no manual update() here.
            phar.set_postfix(
                OrderedDict(
                    epoch=f'{epoch+1}/{num_epochs}',
                    loss=f'{loss.item():.6f}',
                    lr=f'{optimizer.param_groups[0]["lr"]:.3e}'
                )
            )

    weights_train = torch.cat([x[2] for x in train_pred_list]).numpy()
    y_train = torch.cat([x[1] for x in train_pred_list]).numpy()
    prob_train = torch.cat([x[0] for x in train_pred_list]).numpy()
    train_r2 = r2_val(y_train, prob_train, weights_train)

    # -------------------------------------------------------------- validation
    model.eval()
    valid_loss_list = []
    valid_pred_list = []
    with torch.no_grad():
        for valid_tensor in tqdm(valid_dl):
            # Clean inputs: no augmentation noise, so the metric used for model
            # selection and early stopping is deterministic.
            x_cont_valid, x_cat_valid, y_valid, w_valid = _split_batch(valid_tensor[0])

            y_pred = model(x_cont_valid, x_cat_valid).squeeze(-1)

            val_loss = loss_fn(y_pred.flatten(0, 1), y_valid.repeat_interleave(k))
            valid_loss_list.append(val_loss.item())
            # The reported prediction is the mean over the k ensemble outputs.
            valid_pred_list.append((y_pred.mean(1).cpu(), y_valid.cpu(), w_valid.cpu()))

    # Unweighted mean of the per-batch losses.
    valid_loss_mean = sum(valid_loss_list) / len(valid_loss_list)

    weights_eval = torch.cat([x[2] for x in valid_pred_list]).numpy()
    y_eval = torch.cat([x[1] for x in valid_pred_list]).numpy()
    prob_eval = torch.cat([x[0] for x in valid_pred_list]).numpy()
    val_r2 = r2_val(y_eval, prob_eval, weights_eval)

    print(f"Epoch {epoch + 1}: train_r2 = {train_r2:.6f}, val_loss_mean={valid_loss_mean:.6f}, val_r2={val_r2:.6f}, [time] {timer}")

    # ------------------------------------------- checkpointing / early stopping
    if val_r2 > best["val"]:
        print("🌸 New best epoch! 🌸")
        best = {"val": val_r2, "epoch": epoch}
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Stored as a plain Python float so the checkpoint holds no NumPy objects.
            'r2': float(val_r2),
        }
        torch.save(checkpoint, f'epoch{epoch}_r2_{val_r2}.pt')
    print()

    early_stopping.update(val_r2)
    if early_stopping.should_stop():
        print("Early stop")
        break


# Always save the final weights as well, whether or not they are the best ones.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}

torch.save(checkpoint, 'last_tabm.pt')

100%|██████████| 2250/2250 [02:33<00:00, 14.70it/s, epoch=1/30, loss=0.720350, lr=5.000e-04]
100%|██████████| 450/450 [00:18<00:00, 24.18it/s]


Epoch 1: train_r2 = 0.018012, val_loss_mean=0.641858, val_r2=0.007528, [time] 0:02:52.135034
🌸 New best epoch! 🌸



 18%|█▊        | 409/2250 [00:29<02:06, 14.50it/s, epoch=2/30, loss=0.738875, lr=5.000e-04]