In [1]:
from time import sleep
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from otc.data.dataloader import TabDataLoader
from otc.models.activation import ReGLU
from otc.models.fttransformer import (
    CategoricalFeatureTokenizer,
    CLSToken,
    FeatureTokenizer,
    FTTransformer,
    MultiheadAttention,
    NumericalFeatureTokenizer,
    Transformer,
)
from otc.models.tabtransformer import TabTransformer

from tqdm import tqdm

import os 
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"


In [2]:
# code adapted from here:
# https://towardsdatascience.com/a-batch-too-large-finding-the-batch-size-that-fits-on-gpus-aef70902a9f1


# dataset information
CAT_CARDINALITY = 9_000
NUM_FEATURES_CAT = 3
NUM_FEATURES_CONT = 41
DATASET_SIZE = 50_000_000


def get_batch_size(
    model: nn.Module,
    device: torch.device,
    min_batch_size: int = 2,
    max_batch_size: Optional[int] = None,
    num_iterations: int = 5,
) -> int:
    #print(model)
    model.to(device)
    model.train(True)
    optimizer = torch.optim.AdamW(model.parameters())

    print("Test batch size")
    batch_size = min_batch_size
    while True:
        if max_batch_size is not None and batch_size >= max_batch_size:
            batch_size = max_batch_size
            break
        if batch_size >= DATASET_SIZE:
            batch_size = batch_size // 2
            break
        try:
            for _ in range(num_iterations):
                # dummy inputs and targets

                x_cat = torch.randint(
                    1, CAT_CARDINALITY, (batch_size, NUM_FEATURES_CAT)
                ).to(device)
                x_cont = torch.rand((batch_size, NUM_FEATURES_CONT)).to(device)
                targets = torch.randint(0, 1, (batch_size, 1)).float().to(device)
                outputs = model(x_cat, x_cont)
                loss = F.binary_cross_entropy_with_logits(outputs, targets)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            batch_size *= 2
            print(f"\tTesting batch size {batch_size}")
            sleep(3)
        except RuntimeError as e:
            print(e)
            print(f"\tOOM at batch size {batch_size}")
            batch_size //= 2
            break
    del model, optimizer
    torch.cuda.empty_cache()
    print(f"Final batch size {batch_size}")
    return batch_size


def get_datasets(batch_size: int, num_workers: int = 2):

    x_cat = torch.randint(0, CAT_CARDINALITY, (DATASET_SIZE, NUM_FEATURES_CAT))
    x_cont = torch.rand((DATASET_SIZE, NUM_FEATURES_CONT))
    weight = torch.ones((DATASET_SIZE, 1))
    y = torch.randint(0, 1, (DATASET_SIZE, 1))

    train_ds = TabDataLoader(
        x_cat,
        x_cont,
        weight,
        y,
        batch_size=batch_size,
        shuffle=False,
    )

    test_ds = TabDataLoader(
        x_cat,
        x_cont,
        weight,
        y,
        batch_size=batch_size,
        shuffle=False,
    )
    return train_ds, test_ds


def main(epochs: int = 2):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available.")

    device = torch.device("cuda")

    # https://github.com/Yura52/rtdl/blob/main/rtdl/modules.py
    # set max cardinality for all categorical features and max dimension from search space
    # params_feature_tokenizer = {
    #     "num_continous": NUM_FEATURES_CONT,
    #     "cat_cardinalities": [CAT_CARDINALITY] * NUM_FEATURES_CAT,
    #     "d_token": 512,
    # }
    # feature_tokenizer = FeatureTokenizer(**params_feature_tokenizer)
    # params_transformer = {
    #     "d_token": 512,
    #     "n_blocks": 6,
    #     "attention_n_heads": 8,
    #     "attention_initialization": "kaiming",
    #     "ffn_activation": ReGLU,
    #     "attention_normalization": nn.LayerNorm,
    #     "ffn_normalization": nn.LayerNorm,
    #     "ffn_dropout": 0.1,
    #     "ffn_d_hidden": int(512 * (4 / 3)),
    #     "attention_dropout": 0.1,
    #     "residual_dropout": 0.1,
    #     "prenormalization": True,
    #     "first_prenormalization": False,
    #     "last_layer_query_idx": None,
    #     "n_tokens": None,
    #     "kv_compression_ratio": None,
    #     "kv_compression_sharing": None,
    #     "head_activation": nn.ReLU,
    #     "head_normalization": nn.LayerNorm,
    #     "d_out": 1,
    # }

    # transformer = Transformer(**params_transformer)

    # model = FTTransformer(feature_tokenizer, transformer)

    module_params = {
            "depth": 12,
            "heads":8,
            "dim": 256,
            "dim_out": 1,
            "mlp_act": nn.ReLU,
            "transformer_act": F.gelu,
            "transformer_norm_first": False,
            "mlp_hidden_mults": (4, 2),
            "transformer_dropout": 0.5,
            "cat_cardinalities": [CAT_CARDINALITY] * NUM_FEATURES_CAT,
            "cat_features": NUM_FEATURES_CAT,
            "num_continuous": NUM_FEATURES_CONT,
        }

    model = TabTransformer(
        **module_params
    )


    batch_size = get_batch_size(
        model=model,
        device=device,
        min_batch_size=32,
        max_batch_size= 1024 * 1024,
    )


if __name__ == "__main__":
    main()


Test batch size
	Testing batch size 64
	Testing batch size 128
	Testing batch size 256
CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
	OOM at batch size 256


RuntimeError: CUDA error: device-side assert triggered