## Keras Implementation of TabNet

TabNet is proposed in [this article](https://arxiv.org/abs/1908.07442) as a neural network architecture capable of learning a canonical representation of tabular data. This architecture has been shown to perform well against the current gold-standard gradient boosting models for learning on tabular data.

**Taken**

This implementation closely follows [the TabNet implementation in PyTorch linked here](https://github.com/dreamquark-ai/tabnet/tree/b6e1ebaf694f37ad40a6ba525aa016fd3cec15da). That implementation is [explained in this helpful video by Sebastien Fischman](https://www.youtube.com/watch?v=ysBaZO8YmX8).

<img src="images/tabnet_schematic.jpg" width="700" height="500" align="center"/>

In [3]:
import multiprocessing as mp
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow_addons.activations import sparsemax

import global_variables as gv
import utilities

### Step 1. Unsupervised pre-training (imputation)

In [2]:
# Load the preprocessed dataset from the path configured in global_variables.

df = pd.read_csv(gv.tabnet_data)
# Remove the 'Unnamed: 0' column (presumably a saved row index -- confirm).
df = df.drop(columns=["Unnamed: 0"])

In [4]:
# Bare expression: renders the loaded DataFrame as this cell's output.
df

Unnamed: 0,30850-0.0,30780-0.0,30690-0.0,30790-0.0,23101-0.0,23099-0.0,48-0.0,23100-0.0,30710-0.0,30760-0.0,...,1448-0.0.1,outcome_myocardial_infarction,outcome_cardiomyopathies,outcome_ischemic_heart_disease,outcome_heart_failure,hypertension.1,outcome_peripheral_vascular_disease,outcome_cardiac_arrest,outcome_cerebral_infarction,outcome_arrhythmia
0,0.508,3.888,6.477,,45.2,35.6,74.0,25.0,0.34,1.706,...,3.0,0,0,0,0,0,0,0,0,1
1,13.088,3.520,5.512,15.40,74.6,36.5,120.0,42.9,3.94,1.173,...,-1.0,1,0,1,0,1,0,0,0,0
2,4.675,3.041,5.028,,79.6,28.5,110.0,31.7,0.45,1.169,...,3.0,0,0,0,0,0,0,0,0,0
3,,,,,71.7,29.7,112.0,30.3,,,...,3.0,0,0,1,0,1,0,1,1,1
4,1.788,2.887,5.565,,40.2,29.8,67.0,17.0,0.87,2.115,...,3.0,0,0,0,0,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
101724,,3.875,6.190,169.20,46.9,35.8,94.0,26.2,3.83,1.008,...,3.0,0,0,0,0,1,0,0,1,1
101725,9.036,2.467,4.035,,66.3,36.9,114.0,38.7,2.24,1.087,...,2.0,0,0,0,1,1,0,0,0,0
101726,0.485,3.802,6.507,,41.6,37.1,82.0,24.5,0.52,1.857,...,3.0,0,0,1,0,0,0,0,0,0
101727,0.725,2.863,4.664,5.09,44.9,46.3,102.0,38.7,2.75,1.159,...,3.0,0,0,1,0,1,0,0,0,0



### Step 2. Supervised Fine Tuning

In [None]:
# Load the fine-tuning dataset from the path configured in global_variables.
df = pd.read_csv(gv.data_link)
# Show every column when a frame is rendered instead of eliding the middle ones.
pd.set_option("display.max_columns", None)
# Remove the 'Unnamed: 0' column (presumably a saved row index -- confirm).
df = df.drop(columns=["Unnamed: 0"])
df.head()

#### Fully Connected Block

In [None]:
def GLU(x):
    """Gate the input with its own sigmoid, element-wise.

    Returns ``x * sigmoid(x)``.

    NOTE(review): despite the name, this is the self-gated (SiLU/swish) form,
    not a gated linear unit that splits the features into a value half and a
    gate half -- confirm this is intended.
    """
    gate = tf.sigmoid(x)
    return x * gate

class FCBlock(layers.Layer):
    """Dense projection followed by batch normalisation and the ``GLU`` gate."""

    def __init__(self, units):
        """Create the sub-layers; ``units`` is the width of the dense projection."""
        super().__init__()
        self.layer = layers.Dense(units)
        self.bn = layers.BatchNormalization()

    def call(self, x):
        """Apply dense -> batch norm -> ``GLU`` to ``x`` and return the result."""
        projected = self.layer(x)
        normalised = self.bn(projected)
        return GLU(normalised)

#### Feature Transformer 

In [None]:
class SharedBlock(layers.Layer):
    """Two stacked ``FCBlock`` layers with a scaled skip from the first one.

    The output is ``layer2(layer1(x)) + mult * layer1(x)``.
    """

    def __init__(self, units, mult=tf.sqrt(0.5)):
        """Build both ``FCBlock`` layers with ``units`` and store the skip scale."""
        super().__init__()
        self.layer1 = FCBlock(units)
        self.layer2 = FCBlock(units)
        self.mult = mult

    def call(self, x):
        """Run ``x`` through both blocks and add the scaled first-block output."""
        first = self.layer1(x)
        second = self.layer2(first)
        scaled_skip = self.mult * first
        return second + scaled_skip

class DecisionBlock(SharedBlock):
    """Variant of ``SharedBlock`` where each ``FCBlock`` gets a scaled residual.

    Each stage returns ``input * mult + layer(input)``.
    NOTE(review): the additions presumably require ``x`` to already have
    ``units`` features (or be broadcastable to that) -- confirm against callers.
    """

    def __init__(self, units, mult=tf.sqrt(0.5)):
        """Forward ``units`` and ``mult`` unchanged to ``SharedBlock``."""
        super().__init__(units, mult)

    def call(self, x):
        """Apply both residual stages in sequence and return the final output."""
        scaled_input = x * self.mult
        first = scaled_input + self.layer1(x)
        scaled_first = first * self.mult
        second = scaled_first + self.layer2(first)
        return second

#### Attentive Transformer

In [None]:
class Prior(layers.Layer):
    """Running prior scale that is multiplied by ``gamma - mask`` on every call.

    The accumulated value is kept on the instance as ``self.P`` and is returned
    by each call. Use ``reset()`` to start a fresh accumulation.
    """

    def __init__(self, gamma=1.1):
        """Store ``gamma`` and start from the neutral prior.

        Args:
            gamma: Constant from which each mask is subtracted before it is
                multiplied into the running prior.
        """
        super().__init__()
        self.gamma = gamma
        # Initialise P here so that call() works even when reset() was never
        # invoked; previously that raised AttributeError on the first call.
        self.reset()

    def reset(self):
        """Set the running prior back to its neutral value of 1.0."""
        self.P = 1.0

    def call(self, mask):
        """Fold ``mask`` into the running prior and return the updated prior.

        Args:
            mask: Value subtracted from ``gamma``; the difference scales the
                stored prior.

        Returns:
            The updated prior, which is also stored on ``self.P``.
        """
        self.P = self.P * (self.gamma - mask)
        return self.P
    
class AttentiveTransformer(layers.Layer):
    """Dense + batch norm whose output is scaled by a prior and passed to sparsemax."""

    def __init__(self, units):
        """Create the dense layer (``units`` wide) and the batch normalisation."""
        super().__init__()
        self.layer = layers.Dense(units)
        self.bn = layers.BatchNormalization()

    def call(self, x, prior):
        """Return ``sparsemax(prior * bn(dense(x)))``."""
        normalised = self.bn(self.layer(x))
        weighted = prior * normalised
        return sparsemax(weighted)