# Imports

In [1]:
import tensorflow as tf
# from tensorflow.python.framework.ops import disable_eager_execution 
# disable_eager_execution()
# from tensorflow.python.framework.ops import enable_eager_execution
# enable_eager_execution()
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [2]:
# Import common tensorflow layers and activations
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Layer
from tensorflow.keras.layers import Lambda, Multiply, Add, Rescaling
from tensorflow.keras.activations import relu, sigmoid, softmax
from tensorflow.keras import Model
from tensorflow.keras import Sequential
from tensorflow.keras import Input

from local_tabnet import TabNet as LocalTabNet
from tabnet import TabNet as TFTabNet

# TabNet experiment hyperparameters

In [3]:
# Global seeds so weight initialisation and tf.data shuffling repeat across
# Restart & Run All (the train/validation split is seeded via random_state).
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

BATCH_SIZE = 16384          # rows per tf.data batch
LAMBDA = 0.0001             # presumably the TabNet sparsity coefficient — TODO confirm
N_A = 64                    # presumably attention-transformer width — TODO confirm
N_D = 64                    # presumably decision-step width — TODO confirm
VIRTUAL_BATCH_SIZE = 512
BATCH_MOMENTUM = 0.7
N_STEPS = 5
GAMMA = 1.5
LEARNING_RATE = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.02,
    decay_steps=500,
    decay_rate=0.95,
)
# NOTE(review): the LocalTabNet cell hard-codes its own hyperparameters and
# builds Adam(learning_rate=0.01); OPTIMIZER and the constants above are not
# wired into it.
OPTIMIZER = tf.keras.optimizers.Adam(LEARNING_RATE)

# Data definition

In [19]:
# Load the poker-hand data. The files are read as headerless CSV with explicit
# column names: with pandas' default header=0 the first hand of each file would
# be silently consumed as the header row.
columns = ['S1', 'C1', 'S2', 'C2', 'S3', 'C3', 'S4', 'C4', 'S5', 'C5', 'Hand']
training = pd.read_csv("./poker_hand/poker-hand-training-true.data", header=None, names=columns)
data_test = pd.read_csv("./poker_hand/poker-hand-testing.data", header=None, names=columns)

response = 'Hand'
covariates = [c for c in training.columns if c != response]

# Features become float; the 'Hand' label keeps its integer dtype.
training[covariates] = training[covariates].astype('float')
data_test[covariates] = data_test[covariates].astype('float')
assert training.notna().all().all(), "unexpected missing values in training data"
assert data_test.notna().all().all(), "unexpected missing values in test data"

# Split the training file into train and validation; the separate file is the test set.
data_train, data_val = train_test_split(training, test_size=0.3, random_state=42)
assert len(data_train) + len(data_val) == len(training)

# Split into X and y
X_train = data_train.drop(response, axis=1, errors='ignore')
y_train = data_train[response]
X_val = data_val.drop(response, axis=1, errors='ignore')
y_val = data_val[response]
X_test = data_test.drop(response, axis=1, errors='ignore')
y_test = data_test[response]


def _make_dataset(features, target, shuffle=False):
    """Build a batched tf.data pipeline of (features, label) pairs.

    Labels get a trailing unit axis (shape (n, 1)). The axis is added on the
    NumPy array: `Series[..., np.newaxis]` is deprecated multi-dimensional
    indexing on pandas objects and raises in pandas 2.
    """
    dataset = tf.data.Dataset.from_tensor_slices(
        (features.to_numpy(), target.to_numpy()[:, np.newaxis])
    )
    if shuffle:
        dataset = dataset.shuffle(buffer_size=4048)
    # NOTE(review): drop_remainder=True yields zero batches for any split with
    # fewer than BATCH_SIZE rows (check the validation split); confirm the
    # model really needs fixed-size batches before trusting val metrics.
    return dataset.batch(BATCH_SIZE, drop_remainder=True)


# Make tensorflow datasets (only the training pipeline is shuffled)
train_dataset = _make_dataset(X_train, y_train, shuffle=True)
val_dataset = _make_dataset(X_val, y_val)
test_dataset = _make_dataset(X_test, y_test)


# Model creation and training

In [20]:
# import tensorflow as tf
# import tabnet
# from importlib import reload
# reload(tabnet)

# from tabnet import TabNet

# online_implementation = TabNet(
#     feature_columns=None,
#     output_dim=64,
#     feature_dim=128,
#     num_features=X_train.shape[1],
#     num_decision_steps=5,
#     relaxation_factor=1.5,
#     sparsity_coefficient=0.0001,
#     virtual_batch_size=512,
#     norm_type="batch",
#     batch_momentum=0.7,
# )

# online_implementation = Sequential([
#     online_implementation,
#     Dense(1, activation="sigmoid")
# ])


# online_implementation.compile(
#     optimizer=tf.keras.optimizers.Adam(
#         learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
#             initial_learning_rate=0.02,
#             decay_steps=500,
#             decay_rate=0.95
#         ),
#     ),
#     loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
#     metrics=[
#         tf.keras.metrics.AUC(from_logits=False, name="auc"),
#         tf.keras.metrics.BinaryAccuracy(name="binary_accuracy"),
#         tf.keras.metrics.Precision(name="precision", thresholds=0.5),
#         tf.keras.metrics.Recall(name="recall", thresholds=0.5)
#     ]
# )


In [21]:
# sample = next(iter(train_dataset))
# x = np.arange(len(sample[0].numpy()))
# online_implementation(sample[0])

In [22]:
class CategoryEmebddingShimLayer(Layer):
    """Replace integer-coded categorical columns with learned embedding vectors.

    Column ``cat_idxs[i]`` of a 2-D ``(batch, features)`` input is read as a
    1-based category code (the lookup index is ``code - 1``) into a trainable
    table of shape ``(num_cats[i], width_i)``. The code column is replaced in
    place by its ``width_i``-wide embedding; every other column passes through.

    Args:
        cat_idxs: Column indices (into the original input) of the categorical
            features.
        num_cats: Rows of each embedding table, one per ``cat_idxs`` entry;
            must be at least the largest code appearing in that column.
        embed_dim: Embedding width, either one int shared by every categorical
            column or a list/tuple with one width per ``cat_idxs`` entry.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: if ``cat_idxs`` or ``num_cats`` is not given.
        AssertionError: if the per-column sequences have mismatched lengths.
    """

    def __init__(
            self,
            cat_idxs=None,
            num_cats=None,
            embed_dim=1,
            **kwargs):
        super(CategoryEmebddingShimLayer, self).__init__(**kwargs)
        if cat_idxs is None or num_cats is None:
            raise ValueError("cat_idxs and num_cats must both be provided")
        # Plain Python ints keep get_config serialisable (callers may pass numpy arrays).
        self.cat_idxs = [int(idx) for idx in cat_idxs]
        self.num_cats = [int(n) for n in num_cats]
        if isinstance(embed_dim, (list, tuple)):
            assert len(embed_dim) == len(self.cat_idxs), f"embed_dim {len(embed_dim)} must be same length as cat_idxs {len(self.cat_idxs)}"
            self._embed_dims = [int(d) for d in embed_dim]
        else:
            self._embed_dims = [int(embed_dim)] * len(self.cat_idxs)
        self.embed_dim = embed_dim
        self.embeddings = []
        assert len(self.cat_idxs) == len(self.num_cats), f"cat_idxs {len(self.cat_idxs)} must be same length as num_cats {len(self.num_cats)}"

    def build(self, input_shape):
        # One trainable (num_cats[i], width_i) table per categorical column.
        for i, (nrow, ncol) in enumerate(zip(self.num_cats, self._embed_dims)):
            embedding = self.add_weight(
                shape=(nrow, ncol),
                initializer="uniform",
                trainable=True,
                name=f"embedding_{i}"
            )
            self.embeddings.append(embedding)
        super(CategoryEmebddingShimLayer, self).build(input_shape)

    def call(self, inputs):
        x = inputs  # (B, D); assumes a floating dtype so it can be concatenated with embeddings
        # Replace columns right-to-left: widening a column (width > 1) then never
        # shifts the position of a column that still has to be replaced.
        order = sorted(range(len(self.cat_idxs)), key=lambda i: self.cat_idxs[i], reverse=True)
        for i in order:
            cat_idx = self.cat_idxs[i]
            codes = tf.cast(x[:, cat_idx] - 1, tf.int32)  # (B,) 0-based table rows
            x_cat = tf.nn.embedding_lookup(self.embeddings[i], codes)  # (B, width_i)
            x = tf.concat([x[:, :cat_idx], x_cat, x[:, cat_idx + 1:]], axis=1)
        return x

    def compute_output_shape(self, input_shape):
        # Each categorical column (width 1) becomes width_i columns.
        width = input_shape[1]
        if width is not None:
            width = width - len(self.cat_idxs) + sum(self._embed_dims)
        return (input_shape[0], width)

    def get_config(self):
        config = super(CategoryEmebddingShimLayer, self).get_config()
        config.update({
            "cat_idxs": list(self.cat_idxs),
            "num_cats": list(self.num_cats),
            "embed_dim": list(self._embed_dims),
        })
        return config
    


In [23]:
# Sanity-check the shim layer on a constant input. Every column holds code 1,
# so each categorical column must become row 0 of its embedding table (the
# layer looks up `code - 1`) and every other column must pass through as 1.0.
cat_idxs = [0, 1, 2, 3, 4]
num_cats = [4, 13, 4, 13, 4]

layer = CategoryEmebddingShimLayer(cat_idxs, num_cats, embed_dim=1)
dummy_shape = (7, 10)
dummy_in = tf.ones(dummy_shape, dtype=tf.float32)
dummy_out = layer(dummy_in)
assert dummy_out.shape == dummy_shape

# With embed_dim=1 column positions are unchanged, so compare them directly.
expected_row = tf.concat([layer.embeddings[i][0] for i in range(len(cat_idxs))], axis=-1)
passthrough_idxs = [c for c in range(dummy_shape[1]) if c not in cat_idxs]
np.testing.assert_allclose(
    tf.gather(dummy_out, cat_idxs, axis=1).numpy(),
    np.broadcast_to(expected_row.numpy(), (dummy_shape[0], len(cat_idxs))),
)
np.testing.assert_allclose(tf.gather(dummy_out, passthrough_idxs, axis=1).numpy(), 1.0)

dummy_out, expected_row

(<tf.Tensor: shape=(7, 10), dtype=float32, numpy=
 array([[ 0.11982822, -0.03665085,  0.01150975,  0.07373473, -0.06221407,
          1.        ,  1.        ,  1.        ,  1.        ,  1.        ],
        [ 0.11982822, -0.03665085,  0.01150975,  0.07373473, -0.06221407,
          1.        ,  1.        ,  1.        ,  1.        ,  1.        ],
        [ 0.11982822, -0.03665085,  0.01150975,  0.07373473, -0.06221407,
          1.        ,  1.        ,  1.        ,  1.        ,  1.        ],
        [ 0.11982822, -0.03665085,  0.01150975,  0.07373473, -0.06221407,
          1.        ,  1.        ,  1.        ,  1.        ,  1.        ],
        [ 0.11982822, -0.03665085,  0.01150975,  0.07373473, -0.06221407,
          1.        ,  1.        ,  1.        ,  1.        ,  1.        ],
        [ 0.11982822, -0.03665085,  0.01150975,  0.07373473, -0.06221407,
          1.        ,  1.        ,  1.        ,  1.        ,  1.        ],
        [ 0.11982822, -0.03665085,  0.01150975,  0.07373

In [28]:
# Embedding-table sizes for the shim layer. It looks up `code - 1`, so each
# table needs as many rows as the column's largest (1-based) code. Counting
# distinct values (nunique) under-sizes a table whenever a smaller code never
# occurs. Codes seen only in the test file would also index out of range, so
# take the max over both files.
real_num_cats = [
    int(max(training[col].max(), data_test[col].max())) for col in covariates
]

local_implementation = LocalTabNet(
    dim_features=X_train.shape[1],
    dim_attention=16,
    dim_output=10,                # one softmax unit per hand class (assumes labels 0..9)
    output_activation='softmax',
    sparsity=1e-7,
    num_steps=4,
    gamma=1.5,
    # every covariate column is treated as categorical and embedded
    preprocess_layers=CategoryEmebddingShimLayer(np.arange(X_train.shape[1]), real_num_cats, embed_dim=1),
    )

# Integer class labels + softmax probabilities -> sparse categorical loss/accuracy.
local_implementation.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[
        tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
        ]
)
local_implementation.build(X_train[:1].shape)

In [29]:
# Layer-by-layer summary (output shapes and parameter counts) of local_implementation.
local_implementation.summary()

Model: "tab_net_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
category_emebdding_shim_laye multiple                  85        
_________________________________________________________________
shared_feature_layer (Shared multiple                  5888      
_________________________________________________________________
feat_0 (FeatureTransformer)  multiple                  12480     
_________________________________________________________________
feat_1 (FeatureTransformer)  multiple                  12480     
_________________________________________________________________
feat_2 (FeatureTransformer)  multiple                  12480     
_________________________________________________________________
feat_3 (FeatureTransformer)  multiple                  12480     
_________________________________________________________________
feat_4 (FeatureTransformer)  multiple                  12

In [26]:
# Train for up to 1000 epochs. Stop once validation loss has not improved for
# 30 epochs (restoring the best weights), and halve the learning rate (floor
# 1e-5) whenever validation loss stalls for 3 epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=30,
    restore_best_weights=True,
)
lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=3,
    min_lr=1e-5,
)
history = local_implementation.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=1000,
    callbacks=[early_stopping, lr_on_plateau],
)

Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72/1000
E

In [None]:
# Compare predicted vs. true hand class for the first examples of one training batch.
sample_features, sample_labels = next(iter(train_dataset))
# The model emits one softmax row per example, so the predicted class is the argmax.
class_probs = local_implementation(sample_features, training=False).numpy()
y_pred = class_probs.argmax(axis=-1)
y_true = sample_labels.numpy().flatten()
x = np.arange(len(y_true))
sample_plot = 200

fig, ax = plt.subplots()
ax.scatter(x[:sample_plot], y_true[:sample_plot], label="True", alpha=0.5)
ax.scatter(x[:sample_plot], y_pred[:sample_plot], label="Predictions", marker="x")
ax.set(xlabel="Example index", ylabel="Hand class", title="Predicted vs. true hand class")
ax.legend()

# Fraction of the whole sampled batch classified correctly.
float((y_pred == y_true).mean())

In [None]:
# Plot history loss and RMSE for training and validation set; train solid line, validation dashed line
fig, (top_ax, bottom_ax) = plt.subplots(2, 1, figsize=(10, 10), sharex=True)

hist = local_implementation.history.history

top_ax.plot(hist['loss'], label='train_loss', c='b')
top_ax.plot(hist['val_loss'], label='val_loss', linestyle='--', c='b')
second_ax = top_ax.twinx()
# Plot precision and recall on second axis in orange and red respectively
second_ax.plot(hist['precision'], label='train_precision', c='orange')
second_ax.plot(hist['val_precision'], label='val_precision', linestyle='--', c='orange')
second_ax.plot(hist['recall'], label='train_recall', c='r')
second_ax.plot(hist['val_recall'], label='val_recall', linestyle='--', c='r')
top_ax.set_ylabel('loss')
second_ax.set_ylabel('Precision/Recall')
# Merge top ax legend entries
handles, labels = top_ax.get_legend_handles_labels()
handles2, labels2 = second_ax.get_legend_handles_labels()
top_ax.legend(handles + handles2, labels + labels2)


bottom_ax.plot(hist['lr'], label='lr', c='g')
bottom_ax.set_xlabel('Epoch')
bottom_ax.set_ylabel('Learning rate')

# New plot with auc and accuracy
fig, ax = plt.subplots()
ax.plot(hist['auc'], label='train_auc', c='orange')
ax.plot(hist['val_auc'], label='val_auc', linestyle='--', c='orange')
twinax = ax.twinx()
twinax.plot(hist['binary_accuracy'], label='train_accuracy', c='r')
twinax.plot(hist['val_binary_accuracy'], label='val_accuracy', linestyle='--', c='r')
ax.set_xlabel('Epoch')
ax.set_ylabel('AUC')
twinax.set_ylabel('Accuracy')
# Merge legend entries
handles, labels = ax.get_legend_handles_labels()
twinhandles, twinlabels = twinax.get_legend_handles_labels()
ax.legend(handles + twinhandles, labels + twinlabels)




In [None]:
# Evaluate the trained model on the held-out test set; returns [loss, accuracy]
# as configured in compile().
local_implementation.evaluate(test_dataset)