# TabRet: Pre-trainable Transformer-based Model for Unseen Columns

This notebook shows how the [TabRet](https://arxiv.org/pdf/2303.15747.pdf) model can be used. TabRet is a pre-trainable Transformer-based model for tabular data, designed to work on downstream tasks that contain columns not seen in pre-training. Unlike other methods, TabRet adds a learning step before fine-tuning, called retokenizing, which calibrates feature embeddings using the masked autoencoding loss. In this notebook, TabRet is pretrained on a large collection of public health surveys (the BRFSS dataset) and then fine-tuned on a healthcare classification task (stroke prediction).

A key challenge in using pretrained models for tabular data is that each downstream table has a different set of columns, and it is difficult to know at pre-training time which columns will appear in the downstream task. To address this issue, TabRet was proposed as a pre-trainable Transformer network that can adapt to unseen columns in downstream tasks. The training and fine-tuning process for TabRet consists of the following steps:
1. Pretraining: TabRet is pre-trained with a reconstruction loss under masking augmentation.
2. Retokenizing: When unseen columns appear in a downstream task, their tokenizers are trained through masked modeling, with the mixer frozen, before fine-tuning.
3. Fine-tuning: The mixer is fine-tuned for a specific target variable while the rest of the model is frozen.

<p align="center">
    <img src="https://user-images.githubusercontent.com/16665038/266691262-3514bd11-9dd4-4bad-a4a0-4afe4bfca8fd.png" alt="tabret" />
</p>

To execute this notebook successfully, please ensure you've completed the data preprocessing notebook first to prepare the necessary data for pretraining and finetuning.

Please be aware that the pretraining phase can be both time-consuming and resource-intensive. For your convenience, checkpoints for a pretrained model are included in this notebook. It's recommended to first explore the pretrained model for retokenizing and fine-tuning tasks before opting to run the full pretraining process.

The model's core implementation is built using PyTorch which can be found in the 'model' directory. To simplify the training process and enhance usability, high-level classes have been defined using PyTorch Lightning in this notebook. This allows for easier experiment tracking, distributed training, and other advanced features, making the training workflow more straightforward and efficient.

In [None]:
# imports
from os.path import join
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import optuna
import pandas as pd
import pytorch_lightning as pl
import torch
import xgboost as xgb
from model.tabret import TabRet
from model.tabret_cls import TabRetClassifier
from model.utils import get_diff_columns
from pl_bolts.optimizers import LinearWarmupCosineAnnealingLR
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    RichProgressBar,
)
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, QuantileTransformer
from torch import Tensor, optim
from torch.utils.data import DataLoader, Dataset

In [None]:
# Constants
SEED = 4
PRETRAINED_DIR = "/ssd003/projects/aieng/public/ssl_bootcamp_resources/pretrained_ckpts/tabret"
CHECKPOINT_DIR = "./checkpoints"
LOG_DIR = "./logs"
DATA_DIR = "./datasets"

torch.set_float32_matmul_precision("medium")
pl.seed_everything(SEED)

## Dataset

### Dataset Class

The TabularDataset class serves as a customizable container for tabular data, compatible with PyTorch's Dataset API. The class allows for easy handling of continuous and categorical features, as well as target variables, simplifying the data preparation steps for model training. By providing an interface to return data in PyTorch-compatible formats, it seamlessly integrates with PyTorch's DataLoader for batch processing.

In [None]:
class TabularDataset(Dataset):
    def __init__(
        self,
        data: pd.DataFrame,
        task: str = "binary",
        continuous_columns: Optional[Sequence[str]] = None,
        categorical_columns: Optional[Sequence[str]] = None,
        target: Optional[Union[str, Sequence[str]]] = None,
    ) -> None:
        """Tabular dataset for tabular data.

        Args:
        ----
            data (pandas.DataFrame): DataFrame.
            task (str): One of "binary", "multiclass", "regression".
                Defaults to "binary".
            continuous_columns (sequence of str, optional): Sequence of names of
                continuous features (columns). Defaults to None.
            categorical_columns (sequence of str, optional): Sequence of names of
                categorical features (columns). Defaults to None.
            target (str or sequence of str, optional): Name(s) of the target
                column(s). If None, `np.zeros` is set as target. Defaults to None.
        """
        super().__init__()
        self.data = data
        self.task = task
        self.num = data.shape[0]
        self.continuous_columns = continuous_columns if continuous_columns else []
        self.categorical_columns = categorical_columns if categorical_columns else []

        if target:
            self.target = data[target].values
            if isinstance(target, str):
                self.target = self.target.reshape(-1, 1)
        else:
            self.target = np.zeros((self.num, 1))

    def __len__(self) -> int:
        return self.num

    def __getitem__(self, idx: int) -> Dict[str, Tensor]:
        """
        Args:
        ----
            idx (int): The index of the sample in the dataset.

        Returns
        -------
            dict:
                The returned dict has the keys {"target", "continuous", "categorical"}.
                "continuous" and "categorical" map column names to tensors; if there
                are no such features, the corresponding value is `{}`.
        """
        x = {
            "continuous": {key: torch.tensor(self.data[key].values[idx]).float() for key in self.continuous_columns}
            if self.continuous_columns
            else {},
            "categorical": {key: torch.tensor(self.data[key].values[idx]).long() for key in self.categorical_columns}
            if self.categorical_columns
            else {},
        }
        if self.task == "multiclass":
            x["target"] = torch.LongTensor(self.target[idx])
        elif self.task in ["binary", "regression"]:
            x["target"] = torch.tensor(self.target[idx])
        else:
            raise ValueError(f"task: {self.task} must be 'multiclass' or 'binary' or 'regression'")
        return x

### Pre-training Data: BRFSS

The DataFrame that houses the complete BRFSS data is created by executing the preprocessing notebook. This data consists of both numeric and categorical features. For numerical features, the Quantile Transformation method from the scikit-learn library is applied for transformation. Meanwhile, categorical features are processed using the Ordinal Encoder. It's worth noting that for baseline methods, only the categorical features undergo transformation via the Ordinal Encoder.
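
As a rough intuition for what a quantile transformation to a normal distribution does, the following is a minimal rank-based sketch using only the standard library. It is not scikit-learn's `QuantileTransformer`, which interpolates between estimated quantiles and treats ties and clipping differently; the function name `rank_to_normal` and the sample values are made up for illustration.

```python
from statistics import NormalDist
from typing import List


def rank_to_normal(values: List[float]) -> List[float]:
    """Map each value to the standard-normal quantile of its empirical rank.

    Hypothetical helper: ties are broken by position, which a real
    implementation would handle more carefully.
    """
    n = len(values)
    order = sorted(range(n), key=lambda i: values[i])  # indices from smallest to largest
    standard_normal = NormalDist()
    out = [0.0] * n
    for rank, idx in enumerate(order):
        p = (rank + 0.5) / n  # empirical CDF value, kept strictly inside (0, 1)
        out[idx] = standard_normal.inv_cdf(p)
    return out


# Illustrative values only, not taken from BRFSS.
bmi = [18.5, 22.0, 24.3, 31.0, 45.2]
z = rank_to_normal(bmi)
```

Because the output depends only on the ordering of the inputs, an extreme value such as the largest entry here is mapped to a moderate normal score instead of dominating the scale.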

In [None]:
brfss_df = pd.read_csv(join(DATA_DIR, "brfss", "all.csv"))
brfss_df

In [None]:
pre_categorical_columns = list(
    brfss_df.loc[:, (brfss_df.dtypes == "object") | (brfss_df.dtypes == "int64")].drop("Diabetes", axis=1).columns
)
pre_continuous_columns = list(brfss_df.loc[:, brfss_df.dtypes == "float64"].columns)
pre_cat_cardinality_dict = {col: len(brfss_df[col].unique()) for col in pre_categorical_columns}
pre_target_columns = ["Diabetes"]

In [None]:
len(pre_categorical_columns)

In [None]:
len(pre_continuous_columns)

In [None]:
pre_cont_enc = QuantileTransformer(output_distribution="normal").fit(brfss_df[pre_continuous_columns])
pre_cate_enc = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1).fit(
    brfss_df[pre_categorical_columns]
)

In [None]:
# A copy of the original dataframe is saved for later use for retokenization
brfss_df_copy = brfss_df.copy()

In [None]:
brfss_df[pre_continuous_columns] = pre_cont_enc.transform(brfss_df[pre_continuous_columns])
brfss_df[pre_categorical_columns] = pre_cate_enc.transform(brfss_df[pre_categorical_columns])
brfss_df

In [None]:
pre_train_df, pre_val_df = train_test_split(brfss_df, test_size=0.24, random_state=SEED, stratify=brfss_df["Diabetes"])
pre_train_df.size

In [None]:
pre_train_dataset = TabularDataset(
    data=pre_train_df,
    categorical_columns=pre_categorical_columns,
    continuous_columns=pre_continuous_columns,
    target=pre_target_columns,
)

pre_val_dataset = TabularDataset(
    data=pre_val_df,
    categorical_columns=pre_categorical_columns,
    continuous_columns=pre_continuous_columns,
    target=pre_target_columns,
)

## TabRet Model


The model architecture is designed to effectively process and analyze tabular data through multiple specialized layers, each with its own functionality:

1. Feature Tokenizer: This initial layer converts the input data, denoted as x, into an embedded form suitable for further processing.

2. Alignment Layer: Employing a linear layer, this component adjusts the token dimensions to make them compatible with the Encoder layer that follows. This is required because different dimensions for Feature Tokenizer and Encoder are employed for the model’s flexibility.

3. Random Masking: In the pre-training and retokenizing phases, some tokens are randomly masked by the Random Masking process. Specifically, a certain number of tokens, determined by a set mask ratio, are chosen uniformly at random from the set of tokens for each data point, and these chosen tokens are then replaced with a 'mask token'. If no tokens are initially selected for masking, the protocol is overridden to ensure that at least one token is replaced with a 'mask token'.

4. Encoder: This consists of an N-layer Transformer equipped with Pre-Normalization. It serves as the core processing unit, efficiently encoding the tokenized input.

5. Post-Encoder: Post-encoding is performed by adding a mask token, positional embedding, and an additional Transformer block to the output of the Encoder.

6. Projector: Finally, this layer maps the processed tokens back into the original column feature spaces using linear layers, completing the forward pass through the model.
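
The Random Masking step in the list above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the code in the `model` directory: the function name `random_masking`, the string `MASK_TOKEN`, and the choice to round the number of masked tokens down are all hypothetical, and the real model replaces token embeddings with a learned mask vector instead of strings.

```python
import random
from typing import List, Tuple

MASK_TOKEN = "[MASK]"  # hypothetical stand-in for the learned mask-token embedding


def random_masking(
    tokens: List[str], mask_ratio: float, rng: random.Random
) -> Tuple[List[str], List[bool]]:
    """Replace a uniformly random subset of tokens with MASK_TOKEN."""
    n_tokens = len(tokens)
    if n_tokens == 0:
        raise ValueError("need at least one token to mask")
    # Assumption: round down, then force at least one masked token.
    n_mask = max(1, int(n_tokens * mask_ratio))
    masked_positions = set(rng.sample(range(n_tokens), n_mask))
    is_masked = [i in masked_positions for i in range(n_tokens)]
    masked_tokens = [MASK_TOKEN if m else t for t, m in zip(tokens, is_masked)]
    return masked_tokens, is_masked


# One token per column of a single data point (column names from the stroke table).
tokens = [
    "avg_glucose_level", "_BMI5", "hypertension", "heart_disease", "work_type",
    "Residence_type", "SEX", "_AGEG5YR", "MARITAL", "SMOKE100",
]
masked_tokens, is_masked = random_masking(tokens, mask_ratio=0.5, rng=random.Random(0))
# A ratio of 0 would select nothing, so the override masks exactly one token.
_, is_masked_min = random_masking(tokens, mask_ratio=0.0, rng=random.Random(1))
```

The boolean list plays the role of the mask returned alongside the loss, marking which positions the model has to reconstruct.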


<p align="center">
    <img width="422" alt="Screen Shot 2023-09-05 at 10 26 14 PM" src="https://user-images.githubusercontent.com/16665038/266691464-17d952d3-1c93-4809-8161-92f585965811.png">
</p>


We use the same model configuration that is outlined in the original paper. Specifically, the Encoder is designed with 6 blocks and includes FFN dropout. The Decoder, on the other hand, is streamlined and consists of a single block.


In [None]:
# TabRet architecture config
enc_transformer_config = {
    "n_blocks": 6,
    "residual_dropout": 0.0,
    "ffn_d_hidden": 512,
    "d_token": 384,
    "attention_dropout": 0.1,
    "ffn_dropout": 0.1,
    "attention_n_heads": 8,
    "attention_initialization": "kaiming",
    "ffn_activation": "ReGLU",
    "attention_normalization": "LayerNorm",
    "ffn_normalization": "LayerNorm",
    "prenormalization": True,
    "first_prenormalization": False,
    "last_layer_query_idx": None,
    "n_tokens": None,
    "kv_compression_ratio": None,
    "kv_compression_sharing": None,
}
dec_transformer_config = {
    "n_blocks": 1,
    "residual_dropout": 0.0,
    "ffn_d_hidden": 128,
    "d_token": 96,
    "attention_dropout": 0.1,
    "ffn_dropout": 0.0,
    "attention_n_heads": 8,
    "attention_initialization": "kaiming",
    "ffn_activation": "ReGLU",
    "attention_normalization": "LayerNorm",
    "ffn_normalization": "LayerNorm",
    "prenormalization": True,
    "first_prenormalization": False,
    "last_layer_query_idx": None,
    "n_tokens": None,
    "kv_compression_ratio": None,
    "kv_compression_sharing": None,
}

## Pre-training

In the pretraining phase, a masked autoencoder approach is employed to prepare the model for better generalization and performance in downstream tasks.

1. Mask Token Replacement: The first step involves replacing the embeddings of randomly selected columns with a specialized form of embedding known as a "mask token." This is akin to obscuring some of the information, forcing the model to learn to predict or "fill in the gaps" during training. This strategy is effective for making the model more robust and capable of handling unseen or missing data.

2. Shuffle Augmentation: Alongside the mask token replacement, a second strategy called "Shuffle Augmentation" is used. For a randomly chosen subset of columns, the values within each column are permuted across the rows of a minibatch, producing an augmented version of the data. Note that this permutes values within a column, not the order of the columns themselves: a shuffled entry no longer matches the rest of its row, so the augmentation acts as a form of input corruption.
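
Shuffle augmentation can be illustrated on a toy minibatch stored as a dict of columns. This is a sketch under assumptions: the helper name `shuffle_augment`, the way the number of shuffled columns is rounded, and the sample values are hypothetical and do not come from the TabRet code.

```python
import random
from typing import Dict, List, Tuple


def shuffle_augment(
    batch: Dict[str, List], ratio: float, rng: random.Random
) -> Tuple[Dict[str, List], List[str]]:
    """Permute values across rows for a random subset of columns."""
    columns = list(batch)
    # Assumption: the number of shuffled columns is floor(ratio * n_columns).
    n_shuffle = int(len(columns) * ratio)
    chosen = rng.sample(columns, n_shuffle)
    augmented = {col: list(values) for col, values in batch.items()}  # shallow copy
    for col in chosen:
        rng.shuffle(augmented[col])  # in-place permutation of this column only
    return augmented, chosen


# Toy minibatch of four rows; values are illustrative only.
batch = {
    "_BMI5": [21.5, 30.2, 25.0, 27.8],
    "SEX": [0, 1, 1, 0],
    "MARITAL": [2, 0, 1, 1],
    "SMOKE100": [1, 0, 0, 1],
}
augmented, chosen = shuffle_augment(batch, ratio=0.5, rng=random.Random(0))
```

The example uses a ratio of 0.5 only so that a four-column toy batch has something to shuffle; the pretraining configuration in this notebook sets the ratio to 0.1.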

<p align="center">
    <img width="422" alt="Screen Shot 2023-09-05 at 10 26 14 PM" src="https://user-images.githubusercontent.com/16665038/266691464-17d952d3-1c93-4809-8161-92f585965811.png">
</p>


Below, the TabRetFTTrans class serves as a high-level interface for pretraining the TabRet model using PyTorch Lightning. This class simplifies the model training and validation process, making it more convenient and straightforward.
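
The `configure_optimizers` method of this class pairs AdamW with `LinearWarmupCosineAnnealingLR`. The sketch below shows the idealized shape of such a schedule: a linear ramp from the warmup start value to the base learning rate, followed by cosine decay to the minimum. The function `warmup_cosine_lr` is a hypothetical helper written for this illustration; the library's exact per-epoch values may differ slightly (for example in how the warmup endpoint is indexed).

```python
import math


def warmup_cosine_lr(
    epoch: int,
    base_lr: float,
    warmup_epochs: int,
    max_epochs: int,
    warmup_start_lr: float,
    eta_min: float,
) -> float:
    """Idealized linear-warmup plus cosine-annealing learning rate."""
    if max_epochs <= warmup_epochs:
        raise ValueError("max_epochs must be larger than warmup_epochs")
    if epoch < warmup_epochs:
        # Linear ramp from warmup_start_lr towards base_lr.
        return warmup_start_lr + (base_lr - warmup_start_lr) * epoch / warmup_epochs
    # Cosine decay from base_lr down to eta_min over the remaining epochs.
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * progress))


# Values taken from the pretraining hyperparameters used later in this notebook.
schedule = [
    warmup_cosine_lr(e, base_lr=1.5e-5, warmup_epochs=40, max_epochs=1000,
                     warmup_start_lr=1e-6, eta_min=1e-6)
    for e in range(1001)
]
```

Plotting `schedule` against the epoch index gives a quick picture of how long the model trains near the peak learning rate before the decay dominates.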

In [None]:
class TabRetFTTrans(pl.LightningModule):
    def __init__(
        self,
        continuous_columns: List[str],
        cat_cardinality_dict: Dict[str, int],
        enc_transformer_config: Dict[str, Any],
        dec_transformer_config: Dict[str, Any],
        mask_ratio: float = 0.7,
        col_shuffle: Optional[Dict[str, Any]] = None,
        epochs: int = 1000,
        lr: float = 1e-4,
        warmup_epochs: int = 10,
        warmup_start_lr: float = 1e-6,
        min_lr: float = 1e-6,
    ) -> None:
        super().__init__()
        self.model = TabRet.make(
            continuous_columns=continuous_columns,
            cat_cardinality_dict=cat_cardinality_dict,
            enc_transformer_config=enc_transformer_config,
            dec_transformer_config=dec_transformer_config,
        )
        self.mask_ratio = mask_ratio
        self.col_shuffle = col_shuffle
        self.epochs = epochs
        self.lr = lr
        self.warmup_epochs = warmup_epochs
        self.warmup_start_lr = warmup_start_lr
        self.min_lr = min_lr
        self.save_hyperparameters()

    def forward(self, inp: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        loss, preds, mask = self.model(
            x_num=inp["continuous"],
            x_cat=inp["categorical"],
            mask_ratio=self.mask_ratio,
            col_shuffle=dict(self.col_shuffle),  # type: ignore
        )
        return loss, preds, mask

    def training_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Tensor:
        loss, _, _ = self(batch)
        self.log("loss", loss)
        return loss

    def validation_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None:
        loss, _, _ = self(batch)
        self.log("val_loss", loss, sync_dist=True)

    def configure_optimizers(
        self,
    ) -> Dict[str, Union[optim.Optimizer, optim.lr_scheduler._LRScheduler]]:
        optimizer = optim.AdamW(self.model.optimization_param_groups(), lr=self.lr)
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.warmup_epochs,
            max_epochs=self.epochs,
            warmup_start_lr=self.warmup_start_lr,
            eta_min=self.min_lr,
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

In [None]:
# Pretraining hyperparameters
pre_mask_ratio = 0.7
pre_epochs = 1000
pre_batch_size = 4096
pre_eval_batch_size = 4096
num_workers = 4
pre_lr = 1.5e-5
pre_warmup_epochs = 40
pre_warmup_start = 1e-6
pre_lr_min = 1e-6
column_shuffle = {
    "ratio": 0.1,
    "mode": "shuffle",
}

In [None]:
pre_model = TabRetFTTrans(
    continuous_columns=pre_continuous_columns,
    cat_cardinality_dict=pre_cat_cardinality_dict,
    enc_transformer_config=enc_transformer_config,
    dec_transformer_config=dec_transformer_config,
    mask_ratio=pre_mask_ratio,
    col_shuffle=column_shuffle,
    epochs=pre_epochs,
    lr=pre_lr,
    warmup_epochs=pre_warmup_epochs,
    warmup_start_lr=pre_warmup_start,
    min_lr=pre_lr_min,
)

In [None]:
pre_train_dataloader = DataLoader(
    pre_train_dataset,
    batch_size=pre_batch_size,
    num_workers=num_workers,
    pin_memory=True,
)

pre_val_dataloader = DataLoader(
    pre_val_dataset,
    batch_size=pre_eval_batch_size,
    num_workers=num_workers,
)

In [None]:
pre_trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    # strategy=DDPStrategy(find_unused_parameters=False),
    max_epochs=pre_epochs,
    precision=16,
    check_val_every_n_epoch=2,
    callbacks=[
        ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            filename="best",
            save_top_k=1,
            save_last=True,
            verbose=True,
            dirpath=join(CHECKPOINT_DIR, "pretraining"),
        ),
        LearningRateMonitor(logging_interval="step"),
        RichProgressBar(),
        # logger=WandbLogger(
        #     project="tabret-pretraining",
        #     entity="vector-ssl-bootcamp",
        #     save_dir=join(LOG_DIR, "pretraining"),
        # ),
    ],
    log_every_n_steps=10,
)

In [None]:
pre_trainer.fit(
    model=pre_model,
    train_dataloaders=pre_train_dataloader,
    val_dataloaders=pre_val_dataloader,
)

## Retokenizing and Fine-tuning

### Stroke Dataset

Stroke prediction is used as a downstream task.

Regarding the data transformation, for features that are common to both datasets, transformation models (like QuantileTransformer for continuous features and OrdinalEncoder for categorical ones) are first fitted using the pre-training dataset and then applied to the downstream dataset. On the other hand, for features unique to the downstream dataset, transformations are directly fitted and applied using that specific dataset. This way, the preprocessing aligns well with the pre-training scheme, allowing for a seamless transition to fine-tuning the model.
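
For a shared categorical column, "fit on the pre-training data, apply to the downstream data" can be sketched as below. This mimics the behaviour requested from `OrdinalEncoder` with `handle_unknown="use_encoded_value"` and `unknown_value=-1`, but it is a simplified stand-in: the helpers `fit_ordinal` and `transform_ordinal` and the category labels are hypothetical.

```python
from typing import Dict, Hashable, Iterable, List


def fit_ordinal(values: Iterable[Hashable]) -> Dict[Hashable, int]:
    """Assign an integer code to each distinct category, in sorted order."""
    return {category: code for code, category in enumerate(sorted(set(values)))}


def transform_ordinal(
    values: Iterable[Hashable], mapping: Dict[Hashable, int], unknown_value: int = -1
) -> List[int]:
    """Encode values with a fitted mapping; unseen categories get unknown_value."""
    return [mapping.get(value, unknown_value) for value in values]


# Illustrative labels only; the real columns are already coded differently.
pretraining_marital = ["Married", "Divorced", "Married", "Never married"]
downstream_marital = ["Married", "Widowed", "Divorced"]

mapping = fit_ordinal(pretraining_marital)  # fitted on the pre-training table
codes = transform_ordinal(downstream_marital, mapping)  # applied to the downstream table
```

Reusing the pre-training mapping keeps the integer codes of a shared column consistent with what the pretrained tokenizer saw, while a category that only exists downstream ("Widowed" here) is flagged with -1.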

The dataset is split into 80% training data and 20% test data. From the training data, 100 samples are set aside as the fine-tuning set, and the remainder is used for validation.

In [None]:
stroke_df = pd.read_csv(join(DATA_DIR, "stroke", "stroke.csv"))
stroke_df

In [None]:
# keep a copy of the original dataframe for later use for the baseline model
stroke_df_copy = stroke_df.copy()

In [None]:
continuous_columns = [
    "avg_glucose_level",
    "_BMI5",
]

categorical_columns = [
    "hypertension",
    "heart_disease",
    "work_type",
    "Residence_type",
    "SEX",
    "_AGEG5YR",
    "MARITAL",
    "SMOKE100",
]

common_cont_columns = ["_BMI5"]

common_cate_columns = [
    "SEX",
    "_AGEG5YR",
    "MARITAL",
    "SMOKE100",
]

diff_cont_columns = [
    "avg_glucose_level",
]

diff_cate_columns = [
    "hypertension",
    "heart_disease",
    "work_type",
    "Residence_type",
]

target_columns = ["stroke"]

In [None]:
if common_cont_columns:
    common_cont_enc = QuantileTransformer(output_distribution="normal").fit(brfss_df_copy[common_cont_columns])

if common_cate_columns:
    common_cate_enc = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1).fit(
        brfss_df_copy[common_cate_columns]
    )

if diff_cont_columns:
    diff_cont_enc = QuantileTransformer(output_distribution="normal").fit(stroke_df[diff_cont_columns])

if diff_cate_columns:
    diff_cate_enc = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1).fit(
        stroke_df[diff_cate_columns]
    )

In [None]:
if common_cont_columns:
    stroke_df[common_cont_columns] = common_cont_enc.transform(stroke_df[common_cont_columns])

if common_cate_columns:
    stroke_df[common_cate_columns] = common_cate_enc.transform(stroke_df[common_cate_columns])

if diff_cont_columns:
    stroke_df[diff_cont_columns] = diff_cont_enc.transform(stroke_df[diff_cont_columns])

if diff_cate_columns:
    stroke_df[diff_cate_columns] = diff_cate_enc.transform(stroke_df[diff_cate_columns])

In [None]:
stroke_df

In [None]:
cat_cardinality_dict = {}


def get_cardinality_dict(enc: OrdinalEncoder) -> Dict[str, int]:
    return {key: len(cardinality) for key, cardinality in zip(enc.feature_names_in_, enc.categories_)}


if common_cate_columns:
    cat_cardinality_dict.update(get_cardinality_dict(common_cate_enc))
if diff_cate_columns:
    cat_cardinality_dict.update(get_cardinality_dict(diff_cate_enc))

cat_cardinality_dict = {key: cat_cardinality_dict[key] for key in categorical_columns}

In [None]:
train_df, test_df = train_test_split(stroke_df, test_size=0.20, random_state=SEED, stratify=stroke_df["stroke"])

fine_df, fval_df = train_test_split(
    train_df,
    train_size=100,
    random_state=SEED,
    stratify=train_df["stroke"],
)

In [None]:
train_dataset = TabularDataset(
    data=fine_df,
    categorical_columns=categorical_columns,
    continuous_columns=continuous_columns,
    target=target_columns,
)

val_dataset = TabularDataset(
    data=fval_df,
    categorical_columns=categorical_columns,
    continuous_columns=continuous_columns,
    target=target_columns,
)

test_dataset = TabularDataset(
    data=test_df,
    categorical_columns=categorical_columns,
    continuous_columns=continuous_columns,
    target=target_columns,
)

### Retokenizing

In the retokenizing step, several actions are taken to adapt the model for columns that were not seen during pre-training. First, new tokenizers are added specifically for these newly appearing columns to convert their data into a form that the model can understand. To ensure compatibility, parts of the decoder, specifically the Positional Embedding and Projector, are initialized to align with the fine-tuning table's requirements.

Once the new tokenizers and parts of the decoder are set up, the parameters for all existing components—old tokenizers, the encoder, and the decoder—are frozen. This means that these parts of the model are not updated during the retokenizing process, ensuring that the knowledge gained during pre-training is preserved.

The training of the new tokenizers employs the same masked modeling approach used in the initial pre-training. Essentially, the columns that the model hasn't seen before are treated as masked out. During this process, special 'mask' tokens are fed into the Post-Encoder as placeholders for these unseen columns. This allows the model to learn how to handle these new columns without disrupting the learned patterns for the existing ones.

<p align="center">
<img width="457" alt="Screen Shot 2023-09-05 at 10 26 26 PM" src="https://user-images.githubusercontent.com/16665038/266691732-ad9339ef-62ec-41f9-a12d-ea3e2f5e80aa.png">
</p>
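
The bookkeeping behind retokenizing can be sketched without the model itself: work out which downstream columns were never seen in pre-training, then mark only their tokenizers as trainable. The helper `split_seen_unseen` and the parameter-group names in `trainable` are hypothetical and are not the API of `get_diff_columns` or `TabRet`; the pre-training column list is truncated to the shared columns named in this notebook.

```python
from typing import Dict, List, Tuple


def split_seen_unseen(
    downstream_columns: List[str], pretrained_columns: List[str]
) -> Tuple[List[str], List[str]]:
    """Split downstream columns into those seen and unseen during pre-training."""
    pretrained = set(pretrained_columns)
    seen = [c for c in downstream_columns if c in pretrained]
    unseen = [c for c in downstream_columns if c not in pretrained]
    return seen, unseen


# Assumption: only the shared BRFSS columns are listed; the real table has more.
pretrained_columns = ["_BMI5", "SEX", "_AGEG5YR", "MARITAL", "SMOKE100"]
downstream_columns = [
    "avg_glucose_level", "_BMI5", "hypertension", "heart_disease", "work_type",
    "Residence_type", "SEX", "_AGEG5YR", "MARITAL", "SMOKE100",
]

seen, unseen = split_seen_unseen(downstream_columns, pretrained_columns)

# Hypothetical parameter groups: True means "updated during retokenizing".
trainable: Dict[str, bool] = {f"tokenizer.{c}": c in unseen for c in downstream_columns}
trainable["encoder"] = False  # pretrained knowledge is kept fixed
trainable["decoder"] = False
```

The resulting `unseen` list matches the `diff_cont_columns` and `diff_cate_columns` defined for the stroke dataset above.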

Below, the TabRetokenize class serves as a high-level interface for the retokenizing step using PyTorch Lightning.


In [None]:
class TabRetokenize(pl.LightningModule):
    def __init__(
        self,
        pre_continuous_columns: List[str],
        pre_cat_cardinality_dict: Dict[str, int],
        continuous_columns: List[str],
        cat_cardinality_dict: Dict[str, int],
        enc_transformer_config: Dict[str, Any],
        dec_transformer_config: Dict[str, Any],
        model_path: str,
        epochs: int = 200,
        batch_size: int = 32,
        lr: float = 1.5e-3,
        warmup_epochs: int = 10,
        warmup_start_lr: float = 1e-6,
        min_lr: float = 1e-6,
        mask_ratio: float = 0.5,
        para_freeze: bool = True,
        mask_token_freeze: bool = True,
        except_decoder: bool = False,
    ) -> None:
        super().__init__()
        self.epochs = epochs
        self.lr = lr
        self.warmup_epochs = warmup_epochs
        self.warmup_start_lr = warmup_start_lr
        self.min_lr = min_lr
        self.mask_ratio = mask_ratio
        self.batch_size = batch_size
        self.save_hyperparameters()
        self.model = TabRet.make(
            continuous_columns=pre_continuous_columns,
            cat_cardinality_dict=pre_cat_cardinality_dict,
            enc_transformer_config=enc_transformer_config,
            dec_transformer_config=dec_transformer_config,
        )
        state_dict = torch.load(model_path)["state_dict"]
        self.model.load_state_dict(state_dict, strict=False)
        diff_columns, continuous_columns, cat_cardinality_dict = get_diff_columns(
            continuous_columns,
            cat_cardinality_dict,
            pre_continuous_columns,
            pre_cat_cardinality_dict,
        )
        self.model.add_attribute(
            continuous_columns=continuous_columns,
            cat_cardinality_dict=cat_cardinality_dict,
        )
        if para_freeze:
            self.model.freeze_parameters_wo_specific_columns(diff_columns)

        if not mask_token_freeze:
            self.model.unfreeze_mask_token()

        if except_decoder:
            self.model.unfreeze_decoder()

        self.save_hyperparameters()

    def forward(self, inp: Dict[str, Any]) -> Tuple[Tensor, Tensor, Tensor]:
        loss, preds, mask = self.model(
            x_num=inp["continuous"],
            x_cat=inp["categorical"],
            mask_ratio=self.mask_ratio,
        )
        return loss, preds, mask

    def training_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Tensor:
        loss, _, _ = self(batch)
        self.log("loss", loss)
        return loss

    def validation_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None:
        loss, _, _ = self(batch)
        self.log("val_loss", loss)

    def configure_optimizers(
        self,
    ) -> Dict[str, Union[optim.Optimizer, optim.lr_scheduler._LRScheduler]]:
        optimizer = optim.AdamW(self.model.optimization_param_groups(), lr=self.lr, weight_decay=1e-5)
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.warmup_epochs,
            max_epochs=self.epochs,
            warmup_start_lr=self.warmup_start_lr,
            eta_min=self.min_lr,
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

In [None]:
# Retokenize hyperparameters
ret_mask_ratio = 0.5
ret_epochs = 100
ret_batch_size = 32
ret_eval_batch_size = 32
num_workers = 4
ret_lr = 1.5e-3
ret_warmup_epochs = 10
ret_warmup_start = 1e-7
ret_lr_min = 1e-7
ret_patience = 50
para_freeze = True
mask_token_freeze = True
except_decoder = False

In [None]:
pretrained_model = join(PRETRAINED_DIR, "pretrained-epoch=1000.ckpt")

In [None]:
ret_model = TabRetokenize(
    pre_continuous_columns=pre_continuous_columns,
    pre_cat_cardinality_dict=pre_cat_cardinality_dict,
    continuous_columns=continuous_columns,
    cat_cardinality_dict=cat_cardinality_dict,
    enc_transformer_config=enc_transformer_config,
    dec_transformer_config=dec_transformer_config,
    model_path=pretrained_model,
    epochs=ret_epochs,
    batch_size=ret_batch_size,
    lr=ret_lr,
    warmup_epochs=ret_warmup_epochs,
    warmup_start_lr=ret_warmup_start,
    min_lr=ret_lr_min,
    mask_ratio=ret_mask_ratio,
    para_freeze=para_freeze,
    mask_token_freeze=mask_token_freeze,
    except_decoder=except_decoder,
)

In [None]:
ret_train_dataloader = DataLoader(
    train_dataset,
    batch_size=ret_batch_size,
    num_workers=num_workers,
    pin_memory=True,
)

ret_val_dataloader = DataLoader(
    val_dataset,
    batch_size=ret_eval_batch_size,
    num_workers=num_workers,
)

In [None]:
ret_trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=ret_epochs,
    precision=16,
    check_val_every_n_epoch=1,
    callbacks=[
        ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            filename="best",
            save_top_k=1,
            save_last=True,
            verbose=True,
            dirpath=join(CHECKPOINT_DIR, "retokenize"),
        ),
        LearningRateMonitor(logging_interval="step"),
        EarlyStopping(
            monitor="val_loss",
            min_delta=0.00,
            patience=ret_patience,
            verbose=True,
            mode="min",
        ),
        RichProgressBar(),
        # logger=WandbLogger(
        #     project="tabret-retokenize",
        #     entity="vector-ssl-bootcamp",
        #     save_dir=join(LOG_DIR, "retokenize"),
        # ),
    ],
    log_every_n_steps=1,
)

In [None]:
ret_trainer.fit(
    model=ret_model,
    train_dataloaders=ret_train_dataloader,
    val_dataloaders=ret_val_dataloader,
)

### Fine-tuning

In the fine-tuning step, the target value 'y' for this task is introduced as a new column in the data and treated as a "masked" entry, similar to how new features were handled during the retokenizing phase.

Initially, all the parameters that the model learned during the pre-training and retokenizing phases are kept constant, or "frozen." This helps to maintain the general understanding the model has developed thus far.

The training then zeroes in on specific components, particularly the Positional Embedding and Projector, but now for the newly added target value column. This is similar to how these components were adapted during the retokenizing phase for newly seen columns. The goal is to enable the model to accurately predict or classify the target values based on its learned knowledge.

One of the key advantages of this step is efficiency. Since only specific components are trained, there's a significant reduction in the number of learning parameters. This not only speeds up the training process but also reduces the required sample size of the training dataset for the downstream task, making fine-tuning a more resource-efficient way to adapt the model for specific applications.

<p align="center">
<img width="488" alt="Screen Shot 2023-09-05 at 10 26 36 PM" src="https://user-images.githubusercontent.com/16665038/266691879-547996b7-9e3e-41f3-b266-a8865000703e.png">
</p>

Below, the TabRetFinetune class serves as a high-level interface for the fine-tuning step using PyTorch Lightning.

In [None]:
class TabRetFinetune(pl.LightningModule):
    def __init__(
        self,
        pre_continuous_columns: List[str],
        pre_cat_cardinality_dict: Dict[str, int],
        continuous_columns: List[str],
        cat_cardinality_dict: Dict[str, int],
        enc_transformer_config: Dict[str, Any],
        dec_transformer_config: Dict[str, Any],
        model_path: str,
        epochs: int = 200,
        lr: float = 1.5e-3,
        warmup_epochs: int = 10,
        warmup_start_lr: float = 1e-6,
        min_lr: float = 1e-6,
        output_dim: int = 1,
    ) -> None:
        super().__init__()
        self.model = TabRet.make(
            continuous_columns=pre_continuous_columns,
            cat_cardinality_dict=pre_cat_cardinality_dict,
            enc_transformer_config=enc_transformer_config,
            dec_transformer_config=dec_transformer_config,
        )
        self.model.add_attribute(
            continuous_columns=continuous_columns,
            cat_cardinality_dict=cat_cardinality_dict,
        )

        self.epochs = epochs
        self.lr = lr
        self.warmup_epochs = warmup_epochs
        self.warmup_start_lr = warmup_start_lr
        self.min_lr = min_lr

        self.save_hyperparameters()
        state_dict = torch.load(model_path)["state_dict"]
        self.model.load_state_dict(state_dict, strict=False)
        self.model.freeze_parameters()
        output_dim += 1
        self.model = TabRetClassifier(self.model, output_dim)
        self.model.show_trainable_parameter()

    def forward(self, inp: Dict[str, Any]) -> Tuple[Tensor, Tensor]:
        logits = self.model(
            x_num=inp["continuous"],
            x_cat=inp["categorical"],
        )
        loss = self.model.loss_fn(logits, inp["target"].squeeze(-1))
        return loss, logits

    def training_step(
        self, batch: Dict[str, Tensor], batch_idx: int
    ) -> Tensor:
        loss, _ = self(batch)
        self.log("loss", loss)
        return loss

    def validation_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None:
        loss, _ = self(batch)
        self.log("val_loss", loss, sync_dist=True)

    def test_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:
        loss, logits = self(batch)
        preds = torch.softmax(logits, dim=-1)[:, 1]
        return {"test_loss": loss, "preds": preds, "target": batch["target"]}

    def test_epoch_end(self, outputs: List[Any]) -> None:
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        preds = torch.cat([x["preds"] for x in outputs])
        labels = torch.cat([x["target"] for x in outputs])
        acc = accuracy_score(labels.cpu(), preds.cpu().round())
        auc = roc_auc_score(labels.cpu(), preds.cpu())
        self.log("avg_test_loss", avg_loss)
        self.log("test_acc", acc)
        self.log("test_auc", auc)

    def configure_optimizers(
        self,
    ) -> Dict[str, Union[optim.Optimizer, optim.lr_scheduler._LRScheduler]]:
        optimizer = optim.AdamW(
            self.model.optimization_param_groups(),
            lr=self.lr,
        )
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.warmup_epochs,
            max_epochs=self.epochs,
            warmup_start_lr=self.warmup_start_lr,
            eta_min=self.min_lr,
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
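
Since `output_dim` is incremented to 2, the classifier presumably emits two logits per row, and `test_step` keeps the softmax probability of class 1. As a minimal plain-Python sketch (the helper names `softmax`, `positive_probability` and `sigmoid`, and the example logits, are illustrative and not part of the notebook), the snippet below shows that this probability is the same as a sigmoid applied to the difference between the two logits:

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def positive_probability(logits: List[float]) -> float:
    """Probability of class 1, mirroring softmax(logits)[:, 1] for a single row."""
    return softmax(logits)[1]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


row = [0.3, 1.7]  # hypothetical logits for (class 0, class 1)
print(positive_probability(row), sigmoid(row[1] - row[0]))
```

This is why a two-output softmax head and a single-output sigmoid head are interchangeable for binary classification, up to parameterization.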

In [None]:
# Finetune hyperparameters
fine_epochs = 200
fine_batch_size = 32
fine_eval_batch_size = 32
num_workers = 4
fine_lr = 5e-4
fine_warmup_epochs = 10
fine_warmup_start = 1e-7
min_lr = 1e-7
fine_patience = 20
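
To see what these settings imply for the learning rate, here is a rough plain-Python sketch of a linear warmup followed by cosine annealing. It approximates the shape of the schedule under the assumption of one scheduler step per epoch; it is not the exact formula used by `LinearWarmupCosineAnnealingLR`, and `lr_at_epoch` is a hypothetical helper whose defaults simply copy the values from the cell above.

```python
import math


def lr_at_epoch(
    epoch: int,
    base_lr: float = 5e-4,
    warmup_epochs: int = 10,
    max_epochs: int = 200,
    warmup_start_lr: float = 1e-7,
    min_lr: float = 1e-7,
) -> float:
    """Approximate learning rate at a given 0-indexed epoch."""
    if epoch < warmup_epochs:
        # Linear ramp from warmup_start_lr up to base_lr.
        frac = epoch / max(1, warmup_epochs)
        return warmup_start_lr + frac * (base_lr - warmup_start_lr)
    # Cosine decay from base_lr down to min_lr over the remaining epochs.
    progress = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for epoch in (0, 5, 10, 105, 200):
    print(epoch, f"{lr_at_epoch(epoch):.2e}")
```

The rate starts at `fine_warmup_start`, peaks at `fine_lr` once warmup ends, and decays toward `min_lr` by the final epoch. If early stopping triggers first, the tail of the schedule is never reached.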

In [None]:
retokenized_model = join(CHECKPOINT_DIR, "retokenize", "last.ckpt")

In [None]:
fine_model = TabRetFinetune(
    pre_continuous_columns=pre_continuous_columns,
    pre_cat_cardinality_dict=pre_cat_cardinality_dict,
    continuous_columns=continuous_columns,
    cat_cardinality_dict=cat_cardinality_dict,
    enc_transformer_config=enc_transformer_config,
    dec_transformer_config=dec_transformer_config,
    model_path=retokenized_model,
    epochs=fine_epochs,
    lr=fine_lr,
    warmup_epochs=fine_warmup_epochs,
    warmup_start_lr=fine_warmup_start,
    min_lr=min_lr,
    output_dim=1,
)

In [None]:
train_dataloader = DataLoader(
    train_dataset,
    batch_size=fine_batch_size,
    num_workers=num_workers,
    pin_memory=True,
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=fine_eval_batch_size,
    num_workers=num_workers,
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=len(test_dataset),
    num_workers=num_workers,
)

In [None]:
fine_trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=fine_epochs,
    precision=16,
    check_val_every_n_epoch=1,
    callbacks=[
        ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            filename="best",
            save_top_k=1,
            save_last=True,
            verbose=True,
            dirpath=join(CHECKPOINT_DIR, "finetune"),
        ),
        LearningRateMonitor(logging_interval="step"),
        EarlyStopping(
            monitor="val_loss",
            min_delta=0.00,
            patience=fine_patience,
            verbose=False,
            mode="min",
        ),
        RichProgressBar(),
    ],
    # Optional: uncomment to log to Weights & Biases.
    # logger=WandbLogger(
    #     project="tabret-finetune",
    #     entity="vector-ssl-bootcamp",
    #     save_dir=join(LOG_DIR, "finetune"),
    # ),
    log_every_n_steps=1,
)
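
The `EarlyStopping` callback above halts training once `val_loss` has failed to improve for `patience` consecutive validation checks. The following is a simplified plain-Python sketch of that patience rule for `mode="min"`, not Lightning's actual implementation; `stopping_epoch` and the loss values are made up for illustration.

```python
from typing import List, Optional


def stopping_epoch(
    val_losses: List[float], patience: int, min_delta: float = 0.0
) -> Optional[int]:
    """Return the 0-indexed epoch at which training would stop, or None."""
    best = float("inf")
    waited = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement: remember the new best and reset the counter.
            best = loss
            waited = 0
        else:
            waited += 1
            if waited >= patience:
                return epoch
    return None


# Hypothetical validation losses: the best value occurs at epoch 2.
losses = [0.70, 0.61, 0.55, 0.56, 0.58, 0.57, 0.60]
print(stopping_epoch(losses, patience=3))
```

With `fine_patience = 20` and validation run every epoch, training continues for 20 epochs past the best validation loss before stopping, so the `best` checkpoint and the `last` checkpoint can differ.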

In [None]:
fine_trainer.fit(model=fine_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
fine_trainer.test(dataloaders=test_dataloader, ckpt_path=join(CHECKPOINT_DIR, "finetune", "last.ckpt"))

## Baseline

We use an XGBoost classifier as the baseline model, training and testing it on the same data subsets used for finetuning. Hyperparameters are tuned with the Optuna library. Categorical features are converted to numerical format using one-hot encoding, after which they are treated as numerical features.
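
As a rough illustration of what one-hot encoding with `handle_unknown="ignore"` produces, the plain-Python sketch below expands a categorical column into indicator columns named `{column}-{category}` (the same naming scheme as `get_columns` in the cells that follow) and maps categories not seen at fit time to all zeros. The helpers `fit_categories` and `one_hot` and the toy values are hypothetical; the notebook itself relies on scikit-learn's `OneHotEncoder`.

```python
from typing import Dict, List


def fit_categories(rows: List[Dict[str, str]], columns: List[str]) -> Dict[str, List[str]]:
    """Collect the sorted set of categories seen for each column."""
    return {col: sorted({row[col] for row in rows}) for col in columns}


def one_hot(row: Dict[str, str], categories: Dict[str, List[str]]) -> Dict[str, int]:
    """Encode one row; categories not seen during fitting become all zeros."""
    encoded = {}
    for col, cates in categories.items():
        for cate in cates:
            encoded[f"{col}-{cate}"] = int(row[col] == cate)
    return encoded


# Toy data, made up for illustration.
train_rows = [{"smoking": "never"}, {"smoking": "former"}, {"smoking": "current"}]
cats = fit_categories(train_rows, ["smoking"])

print(one_hot({"smoking": "former"}, cats))
print(one_hot({"smoking": "unknown"}, cats))  # unseen category -> all zeros
```

Ignoring unknown categories matters here because the encoder for the shared columns is fitted on the pretraining table and then applied to the downstream table, which may contain values the encoder never saw.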


In [None]:
if common_cate_columns:
    base_common_enc = OneHotEncoder(handle_unknown="ignore").fit(brfss_df_copy[common_cate_columns])

if diff_cate_columns:
    base_diff_enc = OneHotEncoder(handle_unknown="ignore").fit(stroke_df_copy[diff_cate_columns])

In [None]:
def get_columns(columns: List[str], categories: List[np.ndarray]) -> List[str]:
    columns_after = []
    for col, cates in zip(columns, categories):
        for cate in cates:
            columns_after.append(f"{col}-{cate}")
    return columns_after


if common_cate_columns:
    new_df = pd.DataFrame(
        base_common_enc.transform(stroke_df_copy[common_cate_columns]).toarray().astype("int64"),
        columns=get_columns(common_cate_columns, base_common_enc.categories_),
        index=stroke_df_copy.index,  # keep rows aligned for the concat below
    )
    stroke_df_copy = pd.concat([stroke_df_copy, new_df], axis=1)
    stroke_df_copy = stroke_df_copy.drop(common_cate_columns, axis=1)

if diff_cate_columns:
    new_df = pd.DataFrame(
        base_diff_enc.transform(stroke_df_copy[diff_cate_columns]).toarray().astype("int64"),
        columns=get_columns(diff_cate_columns, base_diff_enc.categories_),
        index=stroke_df_copy.index,  # keep rows aligned for the concat below
    )
    stroke_df_copy = pd.concat([stroke_df_copy, new_df], axis=1)
    stroke_df_copy = stroke_df_copy.drop(diff_cate_columns, axis=1)

In [None]:
continuous_columns = list(stroke_df_copy.drop(target_columns, axis=1).columns)
categorical_columns = []
feature_columns = continuous_columns + categorical_columns

In [None]:
train_base, test_base = train_test_split(
    stroke_df_copy, test_size=0.20, random_state=SEED, stratify=stroke_df_copy["stroke"]
)

fine_base, fval_base = train_test_split(
    train_base,
    train_size=100,
    random_state=SEED,
    stratify=train_base["stroke"],
)

In [None]:
def objective(trial: optuna.Trial) -> float:
    params = {
        "max_depth": trial.suggest_int("max_depth", 1, 11),
        "n_estimators": trial.suggest_int("n_estimators", 100, 5900, step=200),
        "min_child_weight": trial.suggest_int("min_child_weight", 1, 100),
        "subsample": trial.suggest_float("subsample", 0.5, 1.0),
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 0.7),
        "colsample_bylevel": trial.suggest_float("colsample_bylevel", 0.5, 1.0),
        "colsample_bytree": trial.suggest_float("colsample_bytree", 0.5, 1.0),
        "gamma": trial.suggest_float("gamma", 1e-8, 7),
        "lambda": trial.suggest_float("lambda", 1, 4),
        "alpha": trial.suggest_float("alpha", 1e-8, 1e2),
    }

    params.update({"objective": "binary:logistic", "eval_metric": "auc"})
    model = xgb.XGBClassifier(
        **params,
        random_state=SEED,
        early_stopping_rounds=20,
    )
    model.fit(
        fine_base[feature_columns],
        fine_base[target_columns],
        eval_set=[(fval_base[feature_columns], fval_base[target_columns])],
        verbose=False,
    )
    pred = model.predict_proba(fval_base[feature_columns])[:, 1]
    auc = roc_auc_score(fval_base[target_columns], pred)
    return auc

In [None]:
# In-memory study: no storage or study name is given, so there is nothing to reload.
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=SEED),
)
study.optimize(objective, callbacks=[optuna.study.MaxTrialsCallback(500)])
best_params = study.best_params
best_params.update({"objective": "binary:logistic", "eval_metric": "auc"})
best_params

In [None]:
base_model = xgb.XGBClassifier(
    **best_params,
    random_state=SEED,
    early_stopping_rounds=20,
)

In [None]:
base_model.fit(
    fine_base[feature_columns],
    fine_base[target_columns],
    eval_set=[(fval_base[feature_columns], fval_base[target_columns])],
    verbose=False,
)

In [None]:
prediction = base_model.predict_proba(test_base[feature_columns])[:, 1]
target = test_base[target_columns]
label = (prediction > 0.5).astype(np.int64)
base_score = {
    "ACC": accuracy_score(target, label),
    "AUC": roc_auc_score(target, prediction),
}
base_score
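
For reference, the two reported metrics can also be computed by hand. The sketch below is a plain-Python version, assuming binary labels in {0, 1}: accuracy thresholds the predicted probability at 0.5, and AUC is the fraction of (positive, negative) pairs in which the positive example receives the higher score, with ties counted as one half. The function names and toy numbers are illustrative; the notebook uses scikit-learn's `accuracy_score` and `roc_auc_score`.

```python
from typing import List


def accuracy(labels: List[int], probs: List[float], threshold: float = 0.5) -> float:
    """Share of examples whose thresholded prediction matches the label."""
    preds = [int(p > threshold) for p in probs]
    return sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)


def roc_auc(labels: List[int], probs: List[float]) -> float:
    """Pairwise (Mann-Whitney) form of the area under the ROC curve."""
    pos = [p for p, y in zip(probs, labels) if y == 1]
    neg = [p for p, y in zip(probs, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative label")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Toy labels and scores, made up for illustration.
y_true = [0, 0, 1, 1]
y_prob = [0.1, 0.4, 0.35, 0.8]
print(accuracy(y_true, y_prob), roc_auc(y_true, y_prob))
```

Accuracy depends on the 0.5 threshold, whereas AUC depends only on how the scores rank the examples, so the two can diverge on an imbalanced target.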