<a target="_blank" href="https://colab.research.google.com/github/yandex-research/rtdl-revisiting-models/blob/main/package/example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

---

**See also** [RTDL](https://github.com/yandex-research/rtdl)
-- **other projects on tabular deep learning**.

---

- This notebook provides a usage example of the
  [rtdl_revisiting_models](https://github.com/yandex-research/rtdl-revisiting-models)
  package.
- Hyperparameters are not tuned and may be suboptimal.

In [1]:
%pip install delu==0.0.23
%pip install rtdl_revisiting_models

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
# ruff: noqa: E402
import math
import warnings
from typing import Dict, Literal

warnings.simplefilter("ignore")
import delu  # Deep Learning Utilities: https://github.com/Yura52/delu
import numpy as np
import scipy.special
import sklearn.datasets
import sklearn.metrics
import sklearn.model_selection
import sklearn.preprocessing
import torch
import torch.nn.functional as F
import torch.optim
from torch import Tensor
from tqdm.std import tqdm

warnings.resetwarnings()

from rtdl_revisiting_models import MLP, ResNet, FTTransformer

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Set random seeds in all libraries.
delu.random.seed(0)

0

## Dataset

In [4]:
# >>> Dataset.
TaskType = Literal["regression", "binclass", "multiclass"]

task_type: TaskType = "regression"
n_classes = None
dataset = sklearn.datasets.fetch_california_housing()
X_cont: np.ndarray = dataset["data"]
Y: np.ndarray = dataset["target"]

# NOTE: uncomment to solve a classification task.
# n_classes = 2
# assert n_classes >= 2
# task_type: TaskType = 'binclass' if n_classes == 2 else 'multiclass'
# X_cont, Y = sklearn.datasets.make_classification(
#     n_samples=20000,
#     n_features=8,
#     n_classes=n_classes,
#     n_informative=3,
#     n_redundant=2,
# )

# >>> Continuous features.
X_cont: np.ndarray = X_cont.astype(np.float32)
n_cont_features = X_cont.shape[1]

# >>> Categorical features.
# NOTE: the above datasets do not have categorical features, but,
# for the demonstration purposes, it is possible to generate them.
cat_cardinalities = [
    # NOTE: uncomment the two lines below to add two categorical features.
    # 4,  # Allowed values: [0, 1, 2, 3].
    # 7,  # Allowed values: [0, 1, 2, 3, 4, 5, 6].
]
X_cat = (
    np.column_stack(
        [np.random.randint(0, c, (len(X_cont),)) for c in cat_cardinalities]
    )
    if cat_cardinalities
    else None
)

# >>> Labels.
# Regression labels must be represented by float32.
if task_type == "regression":
    Y = Y.astype(np.float32)
else:
    assert n_classes is not None
    Y = Y.astype(np.int64)
    assert set(Y.tolist()) == set(
        range(n_classes)
    ), "Classification labels must form the range [0, 1, ..., n_classes - 1]"

# >>> Split the dataset.
all_idx = np.arange(len(Y))
trainval_idx, test_idx = sklearn.model_selection.train_test_split(
    all_idx, train_size=0.8
)
train_idx, val_idx = sklearn.model_selection.train_test_split(
    trainval_idx, train_size=0.8
)
data_numpy = {
    "train": {"x_cont": X_cont[train_idx], "y": Y[train_idx]},
    "val": {"x_cont": X_cont[val_idx], "y": Y[val_idx]},
    "test": {"x_cont": X_cont[test_idx], "y": Y[test_idx]},
}
if X_cat is not None:
    data_numpy["train"]["x_cat"] = X_cat[train_idx]
    data_numpy["val"]["x_cat"] = X_cat[val_idx]
    data_numpy["test"]["x_cat"] = X_cat[test_idx]

In [6]:
data_numpy

{'train': {'x_cont': array([[ 1.7792566 ,  0.5621234 ,  1.1550618 , ..., -0.6669884 ,
           0.470754  , -1.0061344 ],
         [ 0.55808   ,  0.5621234 , -0.5218781 , ...,  0.9387544 ,
          -0.6907074 ,  0.59090346],
         [-0.44789276,  1.3993868 , -0.69279766, ..., -0.99130523,
           0.7537073 , -1.6643003 ],
         ...,
         [ 1.4408891 , -0.61646837,  1.4222528 , ...,  1.3905935 ,
          -1.1164329 ,  0.7613144 ],
         [ 0.38069886,  0.4390957 , -0.26269946, ...,  1.4723341 ,
          -0.6907074 ,  0.64432883],
         [-0.15040943,  0.2539367 ,  0.87688   , ..., -0.60717887,
           1.9112622 , -0.591341  ]], dtype=float32),
  'y': array([ 2.5195925 , -0.21453126,  1.0480323 , ...,  0.58601344,
         -0.48656097, -1.0850265 ], dtype=float32)},
 'val': {'x_cont': array([[ 1.3481432 , -1.571041  ,  1.1630402 , ..., -0.5179027 ,
          -1.5368205 ,  1.267422  ],
         [-0.74417   ,  0.18374428, -1.0771673 , ..., -1.4639283 ,
          -0.2

## Preprocessing

In [5]:
# >>> Feature preprocessing.
# NOTE
# The choice between preprocessing strategies depends on a task and a model.

# (A) Simple preprocessing strategy.
# preprocessing = sklearn.preprocessing.StandardScaler().fit(
#     data_numpy['train']['x_cont']
# )

# (B) Fancy preprocessing strategy.
# The noise is added to improve the output of QuantileTransformer in some cases.
X_cont_train_numpy = data_numpy["train"]["x_cont"]
noise = (
    np.random.default_rng(0)
    .normal(0.0, 1e-5, X_cont_train_numpy.shape)
    .astype(X_cont_train_numpy.dtype)
)
preprocessing = sklearn.preprocessing.QuantileTransformer(
    n_quantiles=max(min(len(train_idx) // 30, 1000), 10),
    output_distribution="normal",
    subsample=10**9,
).fit(X_cont_train_numpy + noise)
del X_cont_train_numpy

for part in data_numpy:
    data_numpy[part]["x_cont"] = preprocessing.transform(data_numpy[part]["x_cont"])

# >>> Label preprocessing.
if task_type == "regression":
    Y_mean = data_numpy["train"]["y"].mean().item()
    Y_std = data_numpy["train"]["y"].std().item()
    for part in data_numpy:
        data_numpy[part]["y"] = (data_numpy[part]["y"] - Y_mean) / Y_std

# >>> Convert data to tensors.
data = {
    part: {k: torch.as_tensor(v, device=device) for k, v in data_numpy[part].items()}
    for part in data_numpy
}

if task_type != "multiclass":
    # Required by F.binary_cross_entropy_with_logits
    for part in data:
        data[part]["y"] = data[part]["y"].float()

## Model

In [6]:
# The output size.
d_out = n_classes if task_type == "multiclass" else 1

# # NOTE: uncomment to train MLP
# model = MLP(
#     d_in=n_cont_features + sum(cat_cardinalities),
#     d_out=d_out,
#     n_blocks=2,
#     d_block=384,
#     dropout=0.1,
# ).to(device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)

# # NOTE: uncomment to train ResNet
# model = ResNet(
#     d_in=n_cont_features + sum(cat_cardinalities),
#     d_out=d_out,
#     n_blocks=2,
#     d_block=192,
#     d_hidden=None,
#     d_hidden_multiplier=2.0,
#     dropout1=0.3,
#     dropout2=0.0,
# ).to(device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)

model = FTTransformer(
    n_cont_features=n_cont_features,
    cat_cardinalities=cat_cardinalities,
    d_out=d_out,
    **FTTransformer.get_default_kwargs(),
).to(device)
optimizer = model.make_default_optimizer()

## Training

In [8]:
def apply_model(batch: Dict[str, Tensor]) -> Tensor:
    if isinstance(model, (MLP, ResNet)):
        x_cat_ohe = (
            [
                F.one_hot(column, cardinality)
                for column, cardinality in zip(batch["x_cat"].T, cat_cardinalities)
            ]
            if "x_cat" in batch
            else []
        )
        return model(torch.column_stack([batch["x_cont"]] + x_cat_ohe)).squeeze(-1)

    elif isinstance(model, FTTransformer):
        return model(batch["x_cont"], batch.get("x_cat")).squeeze(-1)

    else:
        raise RuntimeError(f"Unknown model type: {type(model)}")


loss_fn = (
    F.binary_cross_entropy_with_logits
    if task_type == "binclass"
    else F.cross_entropy
    if task_type == "multiclass"
    else F.mse_loss
)


@torch.no_grad()
def evaluate(part: str) -> float:
    model.eval()

    eval_batch_size = 8096
    y_pred = (
        torch.cat(
            [
                apply_model(batch)
                for batch in delu.iter_batches(data[part], eval_batch_size)
            ]
        )
        .cpu()
        .numpy()
    )
    y_true = data[part]["y"].cpu().numpy()

    if task_type == "binclass":
        y_pred = np.round(scipy.special.expit(y_pred))
        score = sklearn.metrics.accuracy_score(y_true, y_pred)
    elif task_type == "multiclass":
        y_pred = y_pred.argmax(1)
        score = sklearn.metrics.accuracy_score(y_true, y_pred)
    else:
        assert task_type == "regression"
        score = -(sklearn.metrics.mean_squared_error(y_true, y_pred) ** 0.5 * Y_std)
    return score  # The higher -- the better.


print(f'Test score before training: {evaluate("test"):.4f}')

Test score before training: -1.1242


In [9]:
# For demonstration purposes (fast training and bad performance),
# one can set smaller values:
# n_epochs = 20
# patience = 2
n_epochs = 1_000_000_000
patience = 16

batch_size = 256
epoch_size = math.ceil(len(train_idx) / batch_size)
timer = delu.tools.Timer()
early_stopping = delu.tools.EarlyStopping(patience, mode="max")
best = {
    "val": -math.inf,
    "test": -math.inf,
    "epoch": -1,
}

print(f"Device: {device.type.upper()}")
print("-" * 88 + "\n")
timer.run()
for epoch in range(n_epochs):
    for batch in tqdm(
        delu.iter_batches(data["train"], batch_size, shuffle=True),
        desc=f"Epoch {epoch}",
        total=epoch_size,
    ):
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(apply_model(batch), batch["y"])
        loss.backward()
        optimizer.step()

    val_score = evaluate("val")
    test_score = evaluate("test")
    print(f"(val) {val_score:.4f} (test) {test_score:.4f} [time] {timer}")

    early_stopping.update(val_score)
    if early_stopping.should_stop():
        break

    if val_score > best["val"]:
        print("🌸 New best epoch! 🌸")
        best = {"val": val_score, "test": test_score, "epoch": epoch}
    print()

print("\n\nResult:")
print(best)

Device: CUDA
----------------------------------------------------------------------------------------



Epoch 0: 100%|██████████| 52/52 [00:00<00:00, 141.41it/s]


(val) -0.6054 (test) -0.6060 [time] 0:00:00.392096
🌸 New best epoch! 🌸



Epoch 1: 100%|██████████| 52/52 [00:00<00:00, 142.24it/s]


(val) -0.5811 (test) -0.5813 [time] 0:00:00.781521
🌸 New best epoch! 🌸



Epoch 2: 100%|██████████| 52/52 [00:00<00:00, 141.99it/s]


(val) -0.5645 (test) -0.5619 [time] 0:00:01.171774
🌸 New best epoch! 🌸



Epoch 3: 100%|██████████| 52/52 [00:00<00:00, 142.67it/s]


(val) -0.5521 (test) -0.5471 [time] 0:00:01.560298
🌸 New best epoch! 🌸



Epoch 4: 100%|██████████| 52/52 [00:00<00:00, 142.85it/s]


(val) -0.5586 (test) -0.5554 [time] 0:00:01.948285



Epoch 5: 100%|██████████| 52/52 [00:00<00:00, 142.28it/s]


(val) -0.5579 (test) -0.5481 [time] 0:00:02.337521



Epoch 6: 100%|██████████| 52/52 [00:00<00:00, 142.77it/s]


(val) -0.5256 (test) -0.5217 [time] 0:00:02.725734
🌸 New best epoch! 🌸



Epoch 7: 100%|██████████| 52/52 [00:00<00:00, 143.34it/s]


(val) -0.5284 (test) -0.5184 [time] 0:00:03.112472



Epoch 8: 100%|██████████| 52/52 [00:00<00:00, 143.29it/s]


(val) -0.5207 (test) -0.5202 [time] 0:00:03.499205
🌸 New best epoch! 🌸



Epoch 9: 100%|██████████| 52/52 [00:00<00:00, 143.01it/s]


(val) -0.5205 (test) -0.5162 [time] 0:00:03.886972
🌸 New best epoch! 🌸



Epoch 10: 100%|██████████| 52/52 [00:00<00:00, 142.58it/s]


(val) -0.5163 (test) -0.5096 [time] 0:00:04.275435
🌸 New best epoch! 🌸



Epoch 11: 100%|██████████| 52/52 [00:00<00:00, 142.92it/s]


(val) -0.5210 (test) -0.5186 [time] 0:00:04.663017



Epoch 12: 100%|██████████| 52/52 [00:00<00:00, 142.74it/s]


(val) -0.5281 (test) -0.5192 [time] 0:00:05.051185



Epoch 13: 100%|██████████| 52/52 [00:00<00:00, 142.94it/s]


(val) -0.5279 (test) -0.5250 [time] 0:00:05.438946



Epoch 14: 100%|██████████| 52/52 [00:00<00:00, 142.78it/s]


(val) -0.5033 (test) -0.5041 [time] 0:00:05.826898
🌸 New best epoch! 🌸



Epoch 15: 100%|██████████| 52/52 [00:00<00:00, 142.89it/s]


(val) -0.5029 (test) -0.5053 [time] 0:00:06.214594
🌸 New best epoch! 🌸



Epoch 16: 100%|██████████| 52/52 [00:00<00:00, 142.53it/s]


(val) -0.5033 (test) -0.5058 [time] 0:00:06.603162



Epoch 17: 100%|██████████| 52/52 [00:00<00:00, 142.48it/s]


(val) -0.5018 (test) -0.5043 [time] 0:00:06.991846
🌸 New best epoch! 🌸



Epoch 18: 100%|██████████| 52/52 [00:00<00:00, 142.91it/s]


(val) -0.4999 (test) -0.5011 [time] 0:00:07.379796
🌸 New best epoch! 🌸



Epoch 19: 100%|██████████| 52/52 [00:00<00:00, 143.02it/s]


(val) -0.4963 (test) -0.5003 [time] 0:00:07.767410
🌸 New best epoch! 🌸



Epoch 20: 100%|██████████| 52/52 [00:00<00:00, 142.98it/s]


(val) -0.4949 (test) -0.4951 [time] 0:00:08.155145
🌸 New best epoch! 🌸



Epoch 21: 100%|██████████| 52/52 [00:00<00:00, 142.64it/s]


(val) -0.5071 (test) -0.5103 [time] 0:00:08.543719



Epoch 22: 100%|██████████| 52/52 [00:00<00:00, 142.69it/s]


(val) -0.5000 (test) -0.4968 [time] 0:00:08.932403



Epoch 23: 100%|██████████| 52/52 [00:00<00:00, 142.73it/s]


(val) -0.4999 (test) -0.4960 [time] 0:00:09.320922



Epoch 24: 100%|██████████| 52/52 [00:00<00:00, 142.90it/s]


(val) -0.4972 (test) -0.4918 [time] 0:00:09.708883



Epoch 25: 100%|██████████| 52/52 [00:00<00:00, 142.92it/s]


(val) -0.5017 (test) -0.4949 [time] 0:00:10.096916



Epoch 26: 100%|██████████| 52/52 [00:00<00:00, 142.78it/s]


(val) -0.4861 (test) -0.4902 [time] 0:00:10.484974
🌸 New best epoch! 🌸



Epoch 27: 100%|██████████| 52/52 [00:00<00:00, 142.53it/s]


(val) -0.4967 (test) -0.4994 [time] 0:00:10.873696



Epoch 28: 100%|██████████| 52/52 [00:00<00:00, 142.66it/s]


(val) -0.4894 (test) -0.4882 [time] 0:00:11.262060



Epoch 29: 100%|██████████| 52/52 [00:00<00:00, 142.54it/s]


(val) -0.4931 (test) -0.4957 [time] 0:00:11.650885



Epoch 30: 100%|██████████| 52/52 [00:00<00:00, 142.59it/s]


(val) -0.4904 (test) -0.4864 [time] 0:00:12.039530



Epoch 31: 100%|██████████| 52/52 [00:00<00:00, 142.49it/s]


(val) -0.4903 (test) -0.4881 [time] 0:00:12.428463



Epoch 32: 100%|██████████| 52/52 [00:00<00:00, 142.68it/s]


(val) -0.4900 (test) -0.4873 [time] 0:00:12.817033



Epoch 33: 100%|██████████| 52/52 [00:00<00:00, 142.68it/s]


(val) -0.4903 (test) -0.4914 [time] 0:00:13.205344



Epoch 34: 100%|██████████| 52/52 [00:00<00:00, 142.44it/s]


(val) -0.4866 (test) -0.4870 [time] 0:00:13.594308



Epoch 35: 100%|██████████| 52/52 [00:00<00:00, 142.44it/s]


(val) -0.4943 (test) -0.4980 [time] 0:00:13.983209



Epoch 36: 100%|██████████| 52/52 [00:00<00:00, 142.55it/s]


(val) -0.4928 (test) -0.4938 [time] 0:00:14.372026



Epoch 37: 100%|██████████| 52/52 [00:00<00:00, 142.77it/s]


(val) -0.4890 (test) -0.4859 [time] 0:00:14.760250



Epoch 38: 100%|██████████| 52/52 [00:00<00:00, 142.82it/s]


(val) -0.4988 (test) -0.4951 [time] 0:00:15.148408



Epoch 39: 100%|██████████| 52/52 [00:00<00:00, 142.36it/s]


(val) -0.4779 (test) -0.4846 [time] 0:00:15.537572
🌸 New best epoch! 🌸



Epoch 40: 100%|██████████| 52/52 [00:00<00:00, 142.33it/s]


(val) -0.4934 (test) -0.4953 [time] 0:00:15.927020



Epoch 41: 100%|██████████| 52/52 [00:00<00:00, 143.07it/s]


(val) -0.4937 (test) -0.4992 [time] 0:00:16.314155



Epoch 42: 100%|██████████| 52/52 [00:00<00:00, 142.79it/s]


(val) -0.4821 (test) -0.4850 [time] 0:00:16.702016



Epoch 43: 100%|██████████| 52/52 [00:00<00:00, 142.35it/s]


(val) -0.4896 (test) -0.4926 [time] 0:00:17.091193



Epoch 44: 100%|██████████| 52/52 [00:00<00:00, 142.54it/s]


(val) -0.4811 (test) -0.4880 [time] 0:00:17.479864



Epoch 45: 100%|██████████| 52/52 [00:00<00:00, 142.55it/s]


(val) -0.5041 (test) -0.5027 [time] 0:00:17.868686



Epoch 46: 100%|██████████| 52/52 [00:00<00:00, 142.73it/s]


(val) -0.4868 (test) -0.4854 [time] 0:00:18.256951



Epoch 47: 100%|██████████| 52/52 [00:00<00:00, 142.37it/s]


(val) -0.4910 (test) -0.4969 [time] 0:00:18.646070



Epoch 48: 100%|██████████| 52/52 [00:00<00:00, 142.46it/s]


(val) -0.4899 (test) -0.4919 [time] 0:00:19.034916



Epoch 49: 100%|██████████| 52/52 [00:00<00:00, 142.79it/s]


(val) -0.4874 (test) -0.4883 [time] 0:00:19.423114



Epoch 50: 100%|██████████| 52/52 [00:00<00:00, 142.74it/s]


(val) -0.4774 (test) -0.4917 [time] 0:00:19.811330
🌸 New best epoch! 🌸



Epoch 51: 100%|██████████| 52/52 [00:00<00:00, 142.52it/s]


(val) -0.5030 (test) -0.5027 [time] 0:00:20.200258



Epoch 52: 100%|██████████| 52/52 [00:00<00:00, 142.68it/s]


(val) -0.4833 (test) -0.4838 [time] 0:00:20.589445



Epoch 53: 100%|██████████| 52/52 [00:00<00:00, 142.44it/s]


(val) -0.4864 (test) -0.4882 [time] 0:00:20.978184



Epoch 54: 100%|██████████| 52/52 [00:00<00:00, 142.42it/s]


(val) -0.4844 (test) -0.4925 [time] 0:00:21.367320



Epoch 55: 100%|██████████| 52/52 [00:00<00:00, 95.65it/s] 


(val) -0.4904 (test) -0.4878 [time] 0:00:21.936014



Epoch 56: 100%|██████████| 52/52 [00:00<00:00, 87.34it/s]


(val) -0.4822 (test) -0.4937 [time] 0:00:22.556915



Epoch 57: 100%|██████████| 52/52 [00:00<00:00, 81.30it/s]


(val) -0.4838 (test) -0.4879 [time] 0:00:23.221617



Epoch 58: 100%|██████████| 52/52 [00:00<00:00, 83.74it/s]


(val) -0.4765 (test) -0.4934 [time] 0:00:23.868118
🌸 New best epoch! 🌸



Epoch 59: 100%|██████████| 52/52 [00:00<00:00, 117.30it/s]


(val) -0.4880 (test) -0.4944 [time] 0:00:24.335913



Epoch 60: 100%|██████████| 52/52 [00:00<00:00, 142.97it/s]


(val) -0.4881 (test) -0.4918 [time] 0:00:24.723671



Epoch 61: 100%|██████████| 52/52 [00:00<00:00, 103.96it/s]


(val) -0.4836 (test) -0.4913 [time] 0:00:25.248160



Epoch 62: 100%|██████████| 52/52 [00:00<00:00, 136.96it/s]


(val) -0.4863 (test) -0.4978 [time] 0:00:25.651465



Epoch 63: 100%|██████████| 52/52 [00:00<00:00, 137.29it/s]


(val) -0.4904 (test) -0.4999 [time] 0:00:26.053718



Epoch 64: 100%|██████████| 52/52 [00:00<00:00, 137.01it/s]


(val) -0.4844 (test) -0.4967 [time] 0:00:26.456816



Epoch 65: 100%|██████████| 52/52 [00:00<00:00, 137.06it/s]


(val) -0.4825 (test) -0.5001 [time] 0:00:26.859736



Epoch 66: 100%|██████████| 52/52 [00:00<00:00, 136.83it/s]


(val) -0.4796 (test) -0.4906 [time] 0:00:27.263285



Epoch 67: 100%|██████████| 52/52 [00:00<00:00, 137.07it/s]


(val) -0.4923 (test) -0.4965 [time] 0:00:27.666159



Epoch 68: 100%|██████████| 52/52 [00:00<00:00, 136.35it/s]


(val) -0.4821 (test) -0.4951 [time] 0:00:28.071089



Epoch 69: 100%|██████████| 52/52 [00:00<00:00, 137.22it/s]


(val) -0.4862 (test) -0.4959 [time] 0:00:28.473585



Epoch 70: 100%|██████████| 52/52 [00:00<00:00, 136.75it/s]


(val) -0.4851 (test) -0.4964 [time] 0:00:28.877470



Epoch 71: 100%|██████████| 52/52 [00:00<00:00, 137.28it/s]


(val) -0.4886 (test) -0.4987 [time] 0:00:29.279790



Epoch 72: 100%|██████████| 52/52 [00:00<00:00, 137.10it/s]


(val) -0.4823 (test) -0.4917 [time] 0:00:29.682639



Epoch 73: 100%|██████████| 52/52 [00:00<00:00, 137.43it/s]


(val) -0.4946 (test) -0.4992 [time] 0:00:30.084554



Epoch 74: 100%|██████████| 52/52 [00:00<00:00, 137.33it/s]

(val) -0.4856 (test) -0.4986 [time] 0:00:30.486724


Result:
{'val': -0.47645997466732953, 'test': -0.49340017062723307, 'epoch': 58}



