In [None]:
#|default_exp models.GatedTabTransformer

# GatedTabTransformer

This implementation is based on:

- Cholakov, R., & Kolev, T. (2022). <span style="color:dodgerblue">**The GatedTabTransformer. An enhanced deep learning architecture for tabular modeling**</span>. arXiv preprint arXiv:2201.00199. https://arxiv.org/abs/2201.00199

- Huang, X., Khetan, A., Cvitkovic, M., & Karnin, Z. (2020). <span style="color:dodgerblue">**TabTransformer: Tabular Data Modeling Using Contextual Embeddings**</span>. arXiv preprint arXiv:2012.06678. https://arxiv.org/pdf/2012.06678

Official repo: https://github.com/radi-cho/GatedTabTransformer

In [None]:
#|eval: false
#|hide
from tsai.imports import *

In [None]:
#|export
import torch
import torch.nn as nn

from tsai.models.TabTransformer import TabTransformer
from tsai.models.gMLP import gMLP

In [None]:
#|export
class _TabMLP(nn.Module):
    """Head that wraps a single `gMLP` and feeds it a one-channel input.

    The `gMLP` is built with 1 input channel, `c_out` outputs and a sequence length of
    `d_model * len(classes) + len(cont_names)`. `mlp_d_model`, `mlp_d_ffn` and `mlp_layers`
    are passed to it as `d_model`, `d_ffn` and `depth` respectively.
    """

    def __init__(self, classes, cont_names, c_out, d_model, mlp_d_model, mlp_d_ffn, mlp_layers):
        super().__init__()
        n_cat, n_cont = len(classes), len(cont_names)
        # d_model entries per item in `classes`, plus one entry per name in `cont_names`.
        flat_width = d_model * n_cat + n_cont
        self.mlp = gMLP(
            1,
            c_out,
            flat_width,
            d_model=mlp_d_model,
            d_ffn=mlp_d_ffn,
            depth=mlp_layers,
        )

    def forward(self, x):
        """Insert a size-1 axis at position 1 of `x` and return the wrapped gMLP's output."""
        # NOTE(review): presumably x arrives as (batch, flat_width) so it becomes (batch, 1, flat_width) — confirm against caller.
        expanded = torch.unsqueeze(x, 1)
        return self.mlp(expanded)


class GatedTabTransformer(TabTransformer):
    """`TabTransformer` subclass that sets its `mlp` attribute to a gMLP-based `_TabMLP`.

    `classes`, `cont_names`, `c_out` and every keyword from `column_embed` through `norm_cont`
    are handed to `TabTransformer.__init__` unchanged. `mlp_d_model`, `mlp_d_ffn` and
    `mlp_layers` are given to `_TabMLP` together with `classes`, `cont_names`, `c_out` and
    `d_model`.
    """

    def __init__(
        self, classes, cont_names, c_out,
        column_embed=True, add_shared_embed=False, shared_embed_div=8, embed_dropout=0.1, drop_whole_embed=False,
        d_model=32, n_layers=6, n_heads=8, d_k=None, d_v=None, d_ff=None,
        res_attention=True, attention_act='gelu', res_dropout=0.1, norm_cont=True,
        mlp_d_model=32, mlp_d_ffn=64, mlp_layers=4,
    ):
        # Keyword arguments forwarded verbatim to the parent constructor.
        parent_kwargs = dict(
            column_embed=column_embed,
            add_shared_embed=add_shared_embed,
            shared_embed_div=shared_embed_div,
            embed_dropout=embed_dropout,
            drop_whole_embed=drop_whole_embed,
            d_model=d_model,
            n_layers=n_layers,
            n_heads=n_heads,
            d_k=d_k,
            d_v=d_v,
            d_ff=d_ff,
            res_attention=res_attention,
            attention_act=attention_act,
            res_dropout=res_dropout,
            norm_cont=norm_cont,
        )
        super().__init__(classes, cont_names, c_out, **parent_kwargs)

        # Assigned after the parent constructor has run, so this is the `mlp` the instance keeps.
        self.mlp = _TabMLP(classes, cont_names, c_out, d_model, mlp_d_model, mlp_d_ffn, mlp_layers)

In [None]:
from fastcore.test import test_eq
from fastcore.basics import first
from fastai.data.external import untar_data, URLs
from fastai.tabular.data import TabularDataLoaders
from fastai.tabular.core import Categorify, FillMissing
from fastai.data.transforms import Normalize
import pandas as pd

In [None]:
# Fetch the ADULT_SAMPLE archive and read its CSV.
path = untar_data(URLs.ADULT_SAMPLE)
csv_file = path/'adult.csv'
df = pd.read_csv(csv_file)

# Build tabular dataloaders straight from the same CSV file.
dls = TabularDataLoaders.from_csv(
    csv_file,
    path=path,
    y_names="salary",
    cat_names=['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race'],
    cont_names=['age', 'fnlwgt', 'education-num'],
    procs=[Categorify, FillMissing, Normalize],
)

# Run one training batch through the model and check the output shape is (batch size, dls.c).
x_cat, x_cont, yb = first(dls.train)
model = GatedTabTransformer(dls.classes, dls.cont_names, dls.c)
out = model(x_cat, x_cont)
test_eq(out.shape, (dls.train.bs, dls.c))

In [None]:
#|eval: false
#|hide
from tsai.export import get_nb_name
from tsai.imports import create_scripts

# Look up this notebook's name from the current namespace, then hand it to tsai's script-export step.
nb_name = get_nb_name(locals())
create_scripts(nb_name)

<IPython.core.display.Javascript object>

/Users/nacho/notebooks/tsai/nbs/121b_models.GatedTabTransformer.ipynb saved at 2022-11-09 13:12:26
Correct notebook to script conversion! 😃
Wednesday 09/11/22 13:12:28 CET
