<a href="https://colab.research.google.com/github/KarelZe/thesis/blob/speedy-transformer/05_tab_transformer_draft.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [40]:
"""
Implementation of a TabTransformer.
Based on paper:
https://arxiv.org/abs/2012.06678
Implementation adapted from: https://github.com/lucidrains/tab-transformer-pytorch
"""
from __future__ import annotations

from typing import Any, Callable, cast

import torch
import torch.nn.functional as F
from torch import einsum, nn

import torch
import torch.nn.functional as F
from torch import nn


class GeGLU(nn.Module):
    r"""
    Implementation of the GeGLU activation function.
    Given by:
    $\operatorname{GeGLU}(x, W, V, b, c)=\operatorname{GELU}(x W+b) \otimes(x V+c)$
    Proposed in https://arxiv.org/pdf/2002.05202v1.pdf.
    Args:
        nn (torch.Tensor): module
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of GeGlU activation.
        Args:
            x (torch.Tensor): input tensor.
        Returns:
            torch.Tensor: output tensor.
        """
        assert x.shape[-1] % 2 == 0
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)


class ReGLU(nn.Module):
    r"""
    Implementation of the GeGLU activation function.
    Given by:
    Proposed in https://arxiv.org/pdf/2002.05202v1.pdf.
    Args:
        nn (torch.Tensor): module
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of GeGlU activation.
        Args:
            x (torch.Tensor): input tensor.
        Returns:
            torch.Tensor: output tensor.
        """
        assert x.shape[-1] % 2 == 0
        x, gates = x.chunk(2, dim=-1)
        return x * F.relu(gates)

class Residual(nn.Module):
    """
    PyTorch implementation of residual connections.
    Args:
        nn (nn.Module): module
    """

    def __init__(self, fn: nn.Module):
        """
        Residual connection.
        Args:
            fn (nn.Module): network.
        """
        super().__init__()
        self.fn = fn

    def forward(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Forward pass of residual connections.
        Args:
            x (torch.Tensor): input tensor.
        Returns:
            torch.Tensor: output tensor.
        """
        out = self.fn(x, **kwargs)
        if isinstance(out, tuple):
            out, _ = out
        return out + x


class PreNorm(nn.Module):
    """
    PyTorch implementation of pre-normalization.
    Args:
        nn (nn.module): module.
    """

    def __init__(self, dim: int, fn: nn.Module):
        """
        Pre-normalization.
        Consists of layer for layer normalization followed by another network.
        Args:
            dim (int): Number of dimensions of normalized shape.
            fn (nn.Module): network.
        """
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Forward pass of pre-normalization layers.
        Args:
            x (torch.Tensor): input tensor.
        Returns:
            torch.Tensor: output tensor.
        """
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """
    PyTorch implementation of feed forward network.
    Args:
        nn (nn.module): module.
    """

    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
        """
        Feed forward network.
        Network consists of input layer, GEGLU activation, dropout layer,
        and output layer.
        Args:
            dim (int): dimension of input and output layer.
            mult (int, optional): Scaling factor for output dimension of input layer or
            input dimension of output layer. Defaults to 4.
            dropout (float, optional): Degree of dropout. Defaults to 0.0.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GeGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
        )

    def forward(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Forward pass of feed forward network.
        Args:
            x (torch.Tensor): input tensor.
        Returns:
            torch.Tensor: output tensor.
        """
        return self.net(x)


class Attention(nn.Module):
    """
    Pytorch implementation of attention.
    Args:
        nn (nn.Module): module.
    """

    def __init__(
        self, dim: int, n_heads: int = 8, dim_head: int = 16, dropout: float = 0.0
    ):
        """
        Attention.
        Args:
            dim (int): Number of dimensions.
            n_heads (int, optional): Number of attention heads. Defaults to 8.
            dim_head (int, optional): Dimension of attention heads. Defaults to 16.
            dropout (float, optional): Degree of dropout. Defaults to 0.0.
        """
        super().__init__()
        inner_dim = dim_head * n_heads
        self.n_heads = n_heads
        self.scale = dim_head**-0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """
        Forward pass of attention module.
        Args:
            x (torch.Tensor): input tensor.
        Returns:
            Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Tuple with tokens and
            attention_stats
        """
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        b, n, _ = q.shape
        # reshape and permute: b n (h d) -> b h n d
        q, k, v = map(
            lambda t: t.reshape(b, n, self.n_heads, -1).permute(0, 2, 1, 3), (q, k, v)
        )
        attention_logits = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
        attention_probs = attention_logits.softmax(dim=-1)
        attention_probs = self.dropout(attention_probs)
        out = einsum("b h i j, b h j d -> b h i d", attention_probs, v)
        # reshape and permute: b h i j, b h j d -> b h i d
        out = out.permute(0, 2, 1, 3).reshape(b, n, -1)

        return self.to_out(out), {
            "attention_logits": attention_logits,
            "attention_probs": attention_probs,
        }


class Transformer(nn.Module):
    """
    Transformer.
    Based on paper:
    https://arxiv.org/abs/1706.03762
    Args:
        nn (nn.Module): Module with transformer.
    """

    def __init__(
        self,
        num_tokens: int,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        attn_dropout: float,
        ff_dropout: float,
    ):
        """
        Classical transformer.
        Args:
            num_tokens (int): Number of tokens i. e., unique classes + special tokens.
            dim (int): Number of dimensions.
            depth (int): Depth of encoder / decoder.
            heads (int): Number of attention heads.
            dim_head (int): Dimensions of attention heads.
            attn_dropout (float): Degree of dropout in attention.
            ff_dropout (float): Degree of dropout in feed-forward network.
        """
        super().__init__()
        self.embeds = nn.Embedding(num_tokens, dim)  # (Embed the categorical features.)
        self.blocks = nn.ModuleList([])

        for _ in range(depth):
            self.blocks.append(
                nn.ModuleDict(
                    {
                        "attention": Residual(
                            PreNorm(
                                dim,
                                Attention(
                                    dim,
                                    n_heads=heads,
                                    dim_head=dim_head,
                                    dropout=attn_dropout,
                                ),
                            )
                        ),
                        "ffn": Residual(
                            PreNorm(dim, FeedForward(dim, dropout=ff_dropout))
                        ),
                    }
                )
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of transformer.
        Args:
            x (torch.Tensor): input tensor.
        Returns:
            torch.Tensor: output tensor.
        """
        x = self.embeds(x)

        for layer in self.blocks:
            layer = cast(nn.ModuleDict, layer)
            x = layer["attention"](x)
            x = layer["ffn"](x)

        return x


class MLP(nn.Module):
    """
    Pytorch model of a vanilla multi-layer perceptron.
    Args:
        nn (nn.Module): module with implementation of MLP.
    """

    def __init__(self, dims: list[int], act: Callable[..., nn.Module]):
        """
        Multilayer perceptron.
        Depth of network is given by `len(dims)`. Capacity is given by entries
        of `dim`. Activation function is used after each linear layer. There is
        no activation function for the final linear layer, as it is sometimes part
        of the loss function already e. g., `nn.BCEWithLogitsLoss()`.
        Args:
            dims (List[int]): List with dimensions of layers.
            act (Callable[..., nn.Module]): Activation function of each linear
            layer.
        """
        super().__init__()
        dims_pairs = list(zip(dims[:-1], dims[1:]))
        layers = []
        for dim_in, dim_out in dims_pairs:
            linear = nn.Linear(dim_in, dim_out)
            layers.append(linear)
            layers.append(act())

        # drop last layer, as a sigmoid layer is included from BCELogitLoss
        del layers[-1]

        self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward propagate tensor through MLP.
        Args:
            x (torch.Tensor): input tensor.
        Returns:
            torch.Tensor: output tensor.
        """
        return self.mlp(x)


class TabTransformer(nn.Module):
    """
    PyTorch model of TabTransformer.
    Based on paper:
    https://arxiv.org/abs/2012.06678
    Args:
        nn (nn.Module): Module with implementation of TabTransformer.
    """

    def __init__(
        self,
        *,
        cat_cardinalities: tuple[int, ...] | tuple[()],
        num_continuous: int,
        dim: int = 32,
        depth: int = 4,
        heads: int = 8,
        dim_head: int = 16,
        dim_out: int = 1,
        mlp_hidden_mults: tuple[(int, int)] = (4, 2),
        mlp_act: Callable[..., nn.Module] = nn.ReLU,
        num_special_tokens: int = 9,
        continuous_mean_std: torch.Tensor | None = None,
        attn_dropout: float = 0.0,
        ff_dropout: float = 0.0,
        **kwargs: Any,
    ):
        """
        TabTransformer.
        Originally introduced in https://arxiv.org/abs/2012.06678.
        Args:
            cat_cardinalities ([List[int] | Tuple[()]): List with number of categories
            for each categorical feature. If no categorical variables are present,
            use empty tuple. For categorical variables e. g., option type ('C' or 'P'),
            the list would be `[1]`.
            num_continuous (int): Number of continous features.
            dim (int, optional): Dimensionality of transformer. Defaults to 32.
            depth (int, optional): Depth of encoder / decoder of transformer.
            Defaults to 4.
            heads (int, optional): Number of attention heads. Defaults to 8.
            dim_head (int, optional): Dimensionality of attention head. Defaults to 16.
            dim_out (int, optional): Dimension of output layer of MLP. Set to one for
            binary classification. Defaults to 1.
            mlp_hidden_mults (Tuple[(int, int)], optional): multipliers for dimensions
            of hidden layer in MLP. Defaults to (4, 2).
            mlp_act (Callable[..., nn.Module], optional): Activation function used
            in MLP. Defaults to nn.ReLU().
            num_special_tokens (int, optional): Number of special tokens in transformer.
            Defaults to 2.
            continuous_mean_std (torch.Tensor | None): List with mean and
            std deviation of each continous feature. Shape eq. `[num_continous x 2]`.
            Defaults to None.
            attn_dropout (float, optional): Degree of attention dropout used in
            transformer. Defaults to 0.0.
            ff_dropout (float, optional): Dropout in feed forward net. Defaults to 0.0.
        """
        super().__init__()
        assert all(
            map(lambda n: n > 0, cat_cardinalities)
        ), "number of each category must be positive"

        # categories related calculations

        self.num_categories = len(cat_cardinalities)
        self.cardinality_categories = sum(cat_cardinalities)

        # create category embeddings table

        self.num_special_tokens = num_special_tokens
        total_tokens = self.cardinality_categories + num_special_tokens

        # for automatically offsetting unique category ids to the correct position
        #  in the categories embedding table

        categories_offset = F.pad(
            torch.tensor(list(cat_cardinalities)), (1, 0), value=num_special_tokens
        )  # Prepend num_special_tokens.
        print("categories_offset")
        print(categories_offset)
        categories_offset = categories_offset.cumsum(dim=-1)[:-1]
        self.register_buffer("categories_offset", categories_offset)
        print("categories_offset (cum sum)")
        print(categories_offset)

        # continuous

        if continuous_mean_std is not None:
            assert continuous_mean_std.shape == (num_continuous, 2,), (
                f"continuous_mean_std must have a shape of ({num_continuous}, 2)"
                f"where the last dimension contains the mean and variance respectively"
            )
        self.register_buffer("continuous_mean_std", continuous_mean_std)

        self.norm = nn.LayerNorm(num_continuous)
        self.num_continuous = num_continuous

        # transformer

        self.transformer = Transformer(
            num_tokens=total_tokens,
            dim=dim,
            depth=depth,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
        )

        # mlp to logits
        input_size = (dim * self.num_categories) + num_continuous
        j = input_size // 8

        hidden_dimensions = list(map(lambda t: j * t, mlp_hidden_mults))
        all_dimensions = [input_size, *hidden_dimensions, dim_out]

        self.mlp = MLP(all_dimensions, act=mlp_act)

    def forward(self, x_cat: torch.Tensor | None, x_cont: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of TabTransformer.
        Args:
            x_cat (torch.Tensor | None): tensor with categorical data.
            x_cont (torch.Tensor): tensor with continous data.
        Returns:
            torch.Tensor: probabilities
        """
        flat_categ: torch.Tensor | None = None

        # print("x_cat")
        # print(x_cat)

        # print(x_cont)

        if x_cat is not None:
            assert x_cat.shape[-1] == self.num_categories, (
                f"you must pass in {self.num_categories} "
                f"values for your categories input"
            )
            print("categories_offset")
            print(self.categories_offset)

            print("x_cat")
            print(x_cat)

            x_cat += self.categories_offset
            print("x_cat + categories_offset")
            print(x_cat)
            x = self.transformer(x_cat)
            flat_categ = x.flatten(1)

        assert x_cont.shape[1] == self.num_continuous, (
            f"you must pass in {self.num_continuous} "
            f"values for your continuous input"
        )

        if self.continuous_mean_std is not None:
            mean, std = self.continuous_mean_std.unbind(dim=-1)  # type: ignore
            x_cont = (x_cont - mean) / std

        normed_cont = self.norm(x_cont)

        # Adaptation to work without categorical data
        x = (
            torch.cat((flat_categ, normed_cont), dim=-1)
            if flat_categ is not None
            else normed_cont
        )

        return self.mlp(x)

In [43]:
num_features_cont = 5
num_features_cat = 3
num_unique_cat = tuple([6, 4, 5])
batch_size = 8



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x_cat = torch.randint(0, 2, (batch_size, num_features_cat)).to(
            device
        )
x_cont = (
            torch.randn(batch_size, num_features_cont).float().to(device)
        )
expected_outputs = (
            torch.randint(0, 1, (batch_size, num_features_cat)).float().to(device)
        )

net = TabTransformer(
            cat_cardinalities=num_unique_cat,
            num_continuous=num_features_cont,
            dim_out=1,
            mlp_act=nn.ReLU,
            dim=32,
            depth=2,
            heads=6,
            attn_dropout=0.1,
            ff_dropout=0.1,
            mlp_hidden_mults=(4, 2),
).to(device)

net(x_cat, x_cont)

categories_offset
tensor([9, 6, 4, 5])
categories_offset (cum sum)
tensor([ 9, 15, 19])
categories_offset
tensor([ 9, 15, 19])
x_cat
tensor([[0, 1, 1],
        [0, 0, 0],
        [0, 1, 1],
        [0, 1, 0],
        [1, 0, 1],
        [0, 0, 1],
        [1, 1, 0],
        [0, 1, 1]])
x_cat + categories_offset
tensor([[ 9, 16, 20],
        [ 9, 15, 19],
        [ 9, 16, 20],
        [ 9, 16, 19],
        [10, 15, 20],
        [ 9, 15, 20],
        [10, 16, 19],
        [ 9, 16, 20]])


tensor([[-0.0902],
        [-0.0463],
        [-0.0528],
        [-0.0579],
        [-0.0432],
        [-0.0020],
        [-0.0780],
        [-0.0400]], grad_fn=<AddmmBackward0>)

In [81]:
# [special tokens,cat cardinality 1, cat cardinality 2 etc.]
x_cat = torch.randint(0, 3, (batch_size, num_features_cat)).to(
            device
        )

In [108]:
class ColumnEmbedding(nn.Module):

  def __init__(self, cardinalities: list[int], d_token: int, dropout: float = 0.0, bias=False):

    super().__init__()
  
    assert cardinalities, "cardinalities must be non-empty"
    assert d_token > 0, "d_token must be positive"

    self.dropout = nn.Dropout(p=dropout)

    # embeddings for every class in every column
    category_offsets = torch.tensor([0] + cardinalities[:-1]).cumsum(0)
    self.register_buffer("category_offsets", category_offsets, persistent=False)
    self.ie = nn.Embedding(sum(cardinalities), d_token)
    
    # embeddings for entire column
    self.se = nn.Parameter(torch.empty(len(cardinalities), d_token).uniform_(-1, 1)
    )

    self.bias = (
          nn.Parameter(torch.Tensor(len(cardinalities), d_token)._zero()) if bias else None
    )

  def forward(self, x: Tensor) -> Tensor:
    # [num_cat_columns, d_model] (+) [batch_size, num_cat_]
    x = self.ie(x + self.category_offsets[None])

    # dim [batch_size, num_cat_cols, d_token] + [num_cat_cols, d_token]
    # add elemnt-wisely
    x = x + self.se
    

    # add bias term, not part of paper but works in Gorishny
    if self.bias is not None:
        x = x + self.bias[None]

    # add dropout, not part of paper, but could work
    return self.dropout(x)

In [109]:
col_embed = ColumnEmbedding([6,4,5], 3)(x_cat)

In [82]:
x_cat

tensor([[2, 2, 1],
        [0, 0, 0],
        [2, 2, 1],
        [0, 0, 2],
        [0, 0, 0],
        [2, 0, 0],
        [2, 2, 2],
        [1, 0, 1]])

In [84]:
col_embed

tensor([[[-0.9568,  0.2226,  0.5660],
         [-0.3934, -0.9328,  1.0054],
         [ 0.7568, -0.7914, -1.3837]],

        [[-0.7228,  0.0340,  0.6935],
         [ 0.2731, -1.3477, -0.9526],
         [ 0.0608, -0.1256, -0.5442]],

        [[-0.9568,  0.2226,  0.5660],
         [-0.3934, -0.9328,  1.0054],
         [ 0.7568, -0.7914, -1.3837]],

        [[-0.7228,  0.0340,  0.6935],
         [ 0.2731, -1.3477, -0.9526],
         [ 2.1985, -0.4491, -1.1389]],

        [[-0.7228,  0.0340,  0.6935],
         [ 0.2731, -1.3477, -0.9526],
         [ 0.0608, -0.1256, -0.5442]],

        [[-0.9568,  0.2226,  0.5660],
         [ 0.2731, -1.3477, -0.9526],
         [ 0.0608, -0.1256, -0.5442]],

        [[-0.9568,  0.2226,  0.5660],
         [-0.3934, -0.9328,  1.0054],
         [ 2.1985, -0.4491, -1.1389]],

        [[ 1.8193, -0.0998,  0.6289],
         [ 0.2731, -1.3477, -0.9526],
         [ 0.7568, -0.7914, -1.3837]]], grad_fn=<AddBackward0>)

In [112]:
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

class Transformer(nn.Module):

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int,cat_cardinalities, dropout: float = 0.5,):
        super().__init__()
        self.model_type = 'Transformer'
        self.col_embedding = ColumnEmbedding(cat_cardinalities, d_model)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout,
                                                 batch_first=True, norm_first=False)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)


    def forward(self, x_cat: Tensor) -> Tensor:
        """
        Args:
            src: Tensor, shape [seq_len, batch_size]
            src_mask: Tensor, shape [seq_len, seq_len]

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        x_embed = self.col_embedding(x_cat)
        output = self.transformer_encoder(x_embed)
        # output = self.decoder(output)
        return output

In [90]:
x_cat.shape

torch.Size([8, 3])

In [113]:
out = Transformer(ntoken=3,d_model=32,nhead=2,d_hid=4, nlayers=1, cat_cardinalities=[5,5,5])(x_cat)

In [114]:
out

tensor([[[ 8.8864e-01,  9.4383e-02,  5.2546e-01,  7.2947e-01,  1.2627e-01,
          -2.4456e-01,  7.6873e-01,  9.9869e-01, -5.6915e-01,  1.2840e+00,
           1.5431e+00, -1.8993e-01,  5.8414e-01, -1.9980e-01,  3.8460e-01,
          -6.0572e-02, -7.0591e-01,  5.0290e-01, -2.3723e-01, -1.6604e+00,
          -9.8042e-01,  1.2650e+00,  1.6362e+00, -2.4667e+00, -1.3068e+00,
           5.9329e-01, -1.6280e+00,  1.0104e-01,  4.0894e-01,  5.2002e-02,
          -1.9921e+00, -2.4524e-01],
         [ 1.1112e+00, -4.7838e-01,  8.2323e-01,  6.4713e-01, -1.1281e+00,
           2.1425e+00, -1.7950e-01,  1.2102e+00, -1.9020e-01, -1.4245e+00,
          -3.8994e-01, -2.5254e-01,  6.0541e-01,  7.1715e-01,  3.0771e-01,
          -4.6271e-01, -8.0492e-02, -8.3950e-01,  5.2788e-01, -2.3134e+00,
           7.0031e-01,  5.2626e-01,  1.5085e+00,  7.8332e-01, -4.5750e-01,
          -9.0921e-01, -1.5029e+00,  9.3786e-01,  7.1906e-01, -1.6726e+00,
          -1.0569e+00,  7.0625e-02],
         [ 3.8769e-02, -2.

In [115]:
out.flatten(1)

tensor([[ 8.8864e-01,  9.4383e-02,  5.2546e-01,  7.2947e-01,  1.2627e-01,
         -2.4456e-01,  7.6873e-01,  9.9869e-01, -5.6915e-01,  1.2840e+00,
          1.5431e+00, -1.8993e-01,  5.8414e-01, -1.9980e-01,  3.8460e-01,
         -6.0572e-02, -7.0591e-01,  5.0290e-01, -2.3723e-01, -1.6604e+00,
         -9.8042e-01,  1.2650e+00,  1.6362e+00, -2.4667e+00, -1.3068e+00,
          5.9329e-01, -1.6280e+00,  1.0104e-01,  4.0894e-01,  5.2002e-02,
         -1.9921e+00, -2.4524e-01,  1.1112e+00, -4.7838e-01,  8.2323e-01,
          6.4713e-01, -1.1281e+00,  2.1425e+00, -1.7950e-01,  1.2102e+00,
         -1.9020e-01, -1.4245e+00, -3.8994e-01, -2.5254e-01,  6.0541e-01,
          7.1715e-01,  3.0771e-01, -4.6271e-01, -8.0492e-02, -8.3950e-01,
          5.2788e-01, -2.3134e+00,  7.0031e-01,  5.2626e-01,  1.5085e+00,
          7.8332e-01, -4.5750e-01, -9.0921e-01, -1.5029e+00,  9.3786e-01,
          7.1906e-01, -1.6726e+00, -1.0569e+00,  7.0625e-02,  3.8769e-02,
         -2.0612e-01, -2.4222e-01, -5.

In [None]:
class TabTransformer(nn.Module):
    """
    PyTorch model of TabTransformer.
    Based on paper:
    https://arxiv.org/abs/2012.06678
    Args:
        nn (nn.Module): Module with implementation of TabTransformer.
    """

    def __init__(
        self,
        *,
        cat_cardinalities: tuple[int, ...] | tuple[()],
        num_continuous: int,
        dim: int = 32,
        depth: int = 4,
        heads: int = 8,
        dim_head: int = 16,
        dim_out: int = 1,
        mlp_hidden_mults: tuple[(int, int)] = (4, 2),
        mlp_act: Callable[..., nn.Module] = nn.ReLU,
        num_special_tokens: int = 9,
        continuous_mean_std: torch.Tensor | None = None,
        attn_dropout: float = 0.0,
        ff_dropout: float = 0.0,
        **kwargs: Any,
    ):
        """
        TabTransformer.
        Originally introduced in https://arxiv.org/abs/2012.06678.
        Args:
            cat_cardinalities ([List[int] | Tuple[()]): List with number of categories
            for each categorical feature. If no categorical variables are present,
            use empty tuple. For categorical variables e. g., option type ('C' or 'P'),
            the list would be `[1]`.
            num_continuous (int): Number of continous features.
            dim (int, optional): Dimensionality of transformer. Defaults to 32.
            depth (int, optional): Depth of encoder / decoder of transformer.
            Defaults to 4.
            heads (int, optional): Number of attention heads. Defaults to 8.
            dim_head (int, optional): Dimensionality of attention head. Defaults to 16.
            dim_out (int, optional): Dimension of output layer of MLP. Set to one for
            binary classification. Defaults to 1.
            mlp_hidden_mults (Tuple[(int, int)], optional): multipliers for dimensions
            of hidden layer in MLP. Defaults to (4, 2).
            mlp_act (Callable[..., nn.Module], optional): Activation function used
            in MLP. Defaults to nn.ReLU().
            num_special_tokens (int, optional): Number of special tokens in transformer.
            Defaults to 2.
            continuous_mean_std (torch.Tensor | None): List with mean and
            std deviation of each continous feature. Shape eq. `[num_continous x 2]`.
            Defaults to None.
            attn_dropout (float, optional): Degree of attention dropout used in
            transformer. Defaults to 0.0.
            ff_dropout (float, optional): Dropout in feed forward net. Defaults to 0.0.
        """
        super().__init__()
        assert all(
            map(lambda n: n > 0, cat_cardinalities)
        ), "number of each category must be positive"

        # categories related calculations

        self.num_categories = len(cat_cardinalities)
        self.cardinality_categories = sum(cat_cardinalities)

        # create category embeddings table

        self.num_special_tokens = num_special_tokens
        total_tokens = self.cardinality_categories + num_special_tokens

        # for automatically offsetting unique category ids to the correct position
        #  in the categories embedding table

        categories_offset = F.pad(
            torch.tensor(list(cat_cardinalities)), (1, 0), value=num_special_tokens
        )  # Prepend num_special_tokens.
        print("categories_offset")
        print(categories_offset)
        categories_offset = categories_offset.cumsum(dim=-1)[:-1]
        self.register_buffer("categories_offset", categories_offset)
        print("categories_offset (cum sum)")
        print(categories_offset)

        # continuous

        if continuous_mean_std is not None:
            assert continuous_mean_std.shape == (num_continuous, 2,), (
                f"continuous_mean_std must have a shape of ({num_continuous}, 2)"
                f"where the last dimension contains the mean and variance respectively"
            )
        self.register_buffer("continuous_mean_std", continuous_mean_std)

        self.norm = nn.LayerNorm(num_continuous)
        self.num_continuous = num_continuous

        # transformer
        self.transformer = Transformer(ntoken=total_tokens,d_model=dim, nhead=heads,d_hid=dim * 4, nlayers=depth,)

        self.transformer = Transformer(
            num_tokens=total_tokens,
            dim=dim,
            depth=depth,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
        )

        # mlp to logits
        input_size = (dim * self.num_categories) + num_continuous
        j = input_size // 8

        hidden_dimensions = list(map(lambda t: j * t, mlp_hidden_mults))
        all_dimensions = [input_size, *hidden_dimensions, dim_out]

        self.mlp = MLP(all_dimensions, act=mlp_act)

    def forward(self, x_cat: torch.Tensor | None, x_cont: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of TabTransformer.
        Args:
            x_cat (torch.Tensor | None): tensor with categorical data.
            x_cont (torch.Tensor): tensor with continous data.
        Returns:
            torch.Tensor: probabilities
        """
        flat_categ: torch.Tensor | None = None

        # print("x_cat")
        # print(x_cat)

        # print(x_cont)

        if x_cat is not None:
            assert x_cat.shape[-1] == self.num_categories, (
                f"you must pass in {self.num_categories} "
                f"values for your categories input"
            )
            x = self.transformer(x_cat)
            flat_categ = x.flatten(1)

        assert x_cont.shape[1] == self.num_continuous, (
            f"you must pass in {self.num_continuous} "
            f"values for your continuous input"
        )

        if self.continuous_mean_std is not None:
            mean, std = self.continuous_mean_std.unbind(dim=-1)  # type: ignore
            x_cont = (x_cont - mean) / std

        normed_cont = self.norm(x_cont)

        # Adaptation to work without categorical data
        x = (
            torch.cat((flat_categ, normed_cont), dim=-1)
            if flat_categ is not None
            else normed_cont
        )

        return self.mlp(x)