In [None]:
import sys

import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
import shap
import sklearn
import torch
from catboost import CatBoostClassifier, Pool
from sklearn.inspection import permutation_importance
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from torch import nn

sys.path.append("..")
from otc.models.classical_classifier import ClassicalClassifier
from otc.models.fttransformer import FTTransformer
from otc.models.tabtransformer import TabTransformer
from otc.models.transformer_classifier import TransformerClassifier

shap.initjs()


In [None]:
# Toy data: three quote-related columns plus a column named "unrelated"
# (presumably a noise feature whose importance should stay low).
X_train = pd.DataFrame(
    [[1, 2, 2, 5], [1, 2, 3, 4], [1, 3, 4, 5]],
    columns=["TRADE_PRICE", "ask_ex", "bid_ex", "unrelated"],
)
y_train = pd.Series([-1, 1, -1])
# Every row spells out all four columns: pandas would otherwise silently pad
# short rows with NaN in "unrelated". The NaNs in ask_ex / bid_ex are deliberate.
X_test = pd.DataFrame(
    [
        [1, 1, 3, 9],
        [1, 1, 3, 7],
        [1, 1, 3, 4],
        [3, 1, 3, 6],
        [1, 1, 1, 2],
        [3, 2, 4, 8],
        [1, np.nan, 1, 3],
        [3, np.nan, np.nan, 5],
    ],
    columns=["TRADE_PRICE", "ask_ex", "bid_ex", "unrelated"],
)
y_test = pd.Series([-1, -1, -1, 1, 1, -1, -1, 1])
assert len(X_test) == len(y_test), "one label per test row"


In [None]:
# Rule-based baseline (presumably the quote rule on exchange quotes) and a
# CatBoost model with its training log silenced.
classical = ClassicalClassifier(layers=[("quote", "ex")], random_state=45)
catboost = CatBoostClassifier(logging_level="Silent")

# Settings for the transformer model, which is currently disabled below.
dl_params = dict(batch_size=8, shuffle=False, device="cpu")

module_params = dict(
    depth=1,
    heads=2,
    dim=2,
    dim_out=1,
    mlp_act=nn.ReLU,
    mlp_hidden_mults=(4, 2),
    attn_dropout=0.5,
    ff_dropout=0.5,
    cat_features=[],
    cat_cardinalities=(),
    num_continuous=X_train.shape[1],  # one continuous input per training column
)

optim_params = dict(lr=0.1, weight_decay=1e-3)

# NOTE(review): disabled. Re-enabling it also requires importing
# CallbackContainer, which this notebook never imports.
# transformer = TransformerClassifier(
#     module=TabTransformer,  # type: ignore
#     module_params=module_params,
#     optim_params=optim_params,
#     dl_params=dl_params,
#     callbacks=CallbackContainer([]),
# )
# transformer.epochs = 5
# transformer.epochs = 5


In [None]:
# Fit each model on the toy training data and collect them for the SHAP loop.
models = []
for model in (classical, catboost):
    model.fit(X_train, y_train)
    models.append(model)


In [None]:
for model in models:
    # shap values with kernel explainer (model-agnostic: only predict_proba and
    # a background sample, here the training set, are needed)
    explainer = shap.KernelExplainer(model.predict_proba, X_train)
    shap_values = explainer.shap_values(X_test)
    # Older shap returns one (observations, features) array per class; newer
    # releases stack the classes on the last axis. Explain the first class.
    if isinstance(shap_values, list):
        shap_values_first_class = shap_values[0]
    elif shap_values.ndim == 3:
        shap_values_first_class = shap_values[..., 0]
    else:
        shap_values_first_class = shap_values
    # label the plot, otherwise the per-model figures are indistinguishable
    print(type(model).__name__)
    shap.summary_plot(shap_values_first_class, X_test, plot_type="bar")


## Compare different calculation methods for CatBoost 🐈

In [None]:
X_train, X_test, y_train, y_test = train_test_split(
    *shap.datasets.iris(), test_size=0.2, random_state=0
)

# logging_level="Silent" suppresses CatBoost's training log, matching the
# CatBoost model configured earlier in this notebook.
model = CatBoostClassifier(logging_level="Silent")
model.fit(X_train, y_train)
print(accuracy_score(y_test, model.predict(X_test)))
print(model.predict_proba(X_test))


# shap values with kernel explainer
explainer = shap.KernelExplainer(model.predict_proba, X_train)
shap_values = explainer.shap_values(X_test)
# Older shap returns one (observations, features) array per class; newer
# releases stack the classes on the last axis. Explain the first class.
if isinstance(shap_values, list):
    shap_values_first_class = shap_values[0]
elif shap_values.ndim == 3:
    shap_values_first_class = shap_values[..., 0]
else:
    shap_values_first_class = shap_values
shap.summary_plot(shap_values_first_class, X_test, plot_type="bar")


In [None]:
# shap values with tree explainer
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
# Older shap returns one (observations, features) array per class; newer
# releases stack the classes on the last axis. Explain the first class, as the
# kernel-explainer cell does.
if isinstance(shap_values, list):
    shap_values_first_class = shap_values[0]
elif shap_values.ndim == 3:
    shap_values_first_class = shap_values[..., 0]
else:
    shap_values_first_class = shap_values
shap.summary_plot(shap_values_first_class, X_test, plot_type="bar")


In [None]:
# see https://catboost.ai/en/docs/concepts/shap-values
shap_values = model.get_feature_importance(data=Pool(X_test, y_test), type="ShapValues")
# For a multiclass model the result has shape
# (observations, classes, features + 1); the last entry on the final axis is the
# expected value, which is dropped before plotting the first class.
assert shap_values.ndim == 3, shap_values.shape
assert shap_values.shape[0] == len(X_test), shap_values.shape
assert shap_values.shape[2] == X_test.shape[1] + 1, shap_values.shape
shap.summary_plot(shap_values[:, 0, :-1], X_test, plot_type="bar")


In [None]:
# CatBoost's built-in "FeatureImportance", evaluated on the test pool.
# Docs: https://catboost.ai/en/docs/concepts/fstr#regular-feature-importance
# (described here as similar to random feature permutation)
test_pool = Pool(X_test, y_test)
catboost_importance = model.get_feature_importance(
    data=test_pool,
    type="FeatureImportance",
    prettified=True,
)
catboost_importance


In [None]:
# random feature permutation sklearn
r = permutation_importance(model, X_test, y_test, n_repeats=30, random_state=0)
# Results are means (and stds) over the repeats; they are not normalized to sum
# to one. Feature names are padded so the name and value don't run together.
name_width = max(len(str(column)) for column in X_test.columns)
for i in r.importances_mean.argsort()[::-1]:
    print(
        f"{str(X_test.columns[i]):<{name_width}}  "
        f"{r.importances_mean[i]:.3f}"
        f" +/- {r.importances_std[i]:.3f}"
    )


## Attention Maps for Transformers

We compute the average attention map over all objects, transformer blocks and attention heads, as done by Gorishniy et al. (see [here](https://github.com/Yura52/tabular-dl-revisiting-models/issues/2)). This differs from the approach of Borisov et al. (see [here](https://github.com/kathrinse/TabSurvey/blob/main/models/basemodel_torch.py)).

In [None]:
import sys
from typing import List

import pandas as pd
import seaborn as sns
import scipy.stats
import torch
from torch import nn

sys.path.append("..")
from otc.models.tabtransformer import TabTransformer


In [None]:
num_features_cont = 5
num_features_cat = 3
num_unique_cat = (2, 2, 2)
batch_size = 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# seed both the random inputs and the model's initialization
torch.manual_seed(0)

# torch.randint's upper bound is exclusive, so randint(0, 1, ...) would yield
# only zeros. Draw each categorical column from [0, cardinality) instead.
x_cat = torch.stack(
    [torch.randint(0, cardinality, (batch_size,)) for cardinality in num_unique_cat],
    dim=1,
).to(device)
x_cont = torch.randn(batch_size, num_features_cont).float().to(device)
# binary targets in {0, 1}
expected_outputs = torch.randint(0, 2, (batch_size, 1)).float().to(device)

model = TabTransformer(
    cat_cardinalities=num_unique_cat,
    num_continuous=num_features_cont,
    dim_out=1,
    mlp_act=nn.ReLU,
    dim=32,
    depth=2,
    heads=6,
    attn_dropout=0.1,
    ff_dropout=0.1,
    mlp_hidden_mults=(4, 2),
).to(device)


In [None]:
class SaveAttentionMaps:
    """
    Forward hook that collects attention maps.

    Register an instance via ``module.register_forward_hook``. On every forward
    pass of the hooked module it appends ``output[1]["attention_probs"]`` to
    ``attention_maps``, so the hooked module must return a tuple whose second
    element is a dict with an ``"attention_probs"`` entry.

    Inspired by:
    https://github.com/Yura52/tabular-dl-revisiting-models/issues/2#issuecomment-1068123629
    """

    def __init__(self):
        # one tensor per hooked forward call, in call order
        self.attention_maps: List[torch.Tensor] = []

    def __call__(self, _, __, output):
        """
        Store the attention probabilities of one forward pass.

        Args:
            _: the hooked module (unused).
            __: the module's positional inputs (unused).
            output: the module's output; ``output[1]["attention_probs"]`` is saved.
        """
        # detach so stored maps never keep an autograd graph alive
        self.attention_maps.append(output[1]["attention_probs"].detach())


In [None]:
# The following hook will save all attention maps from all attention modules.
hook = SaveAttentionMaps()
# Keep the handles so the hooks can be removed afterwards; otherwise every
# re-run of this cell registers one more hook on each block.
handles = [
    block.attention.fn.fn.register_forward_hook(hook)
    for block in model.transformer.blocks
]

# Apply the model to all objects.
model.eval()
try:
    with torch.inference_mode():
        model(x_cat.clone(), x_cont.clone())
finally:
    for handle in handles:
        handle.remove()

# Collect attention maps
n_objects = len(x_cat)
n_blocks = len(model.transformer.blocks)
n_heads = model.transformer.blocks[0].attention.fn.fn.n_heads

attention_maps = torch.cat(hook.attention_maps)

# Flatten all leading axes into one (token x token) map per object, block and
# head, with one token per categorical feature.
# NOTE(review): assumes the hooked maps have this layout and no [CLS] token —
# confirm. The leading size is inferred, so the assert below is a real check.
attention_maps = attention_maps.reshape(-1, num_features_cat, num_features_cat)
assert attention_maps.shape[0] == n_objects * n_blocks * n_heads, attention_maps.shape

# Average over objects, blocks and heads.
average_attention_map = attention_maps.mean(0)
# NOTE(review): without a [CLS] token, row -1 is only the attention paid by the
# last categorical feature's token; confirm this is the intended measure.
feature_importance = average_attention_map[-1]
assert feature_importance.shape == (num_features_cat,)

feature_importance = feature_importance.cpu().numpy()
feature_ranks = scipy.stats.rankdata(-feature_importance)
feature_indices_sorted_by_importance = feature_importance.argsort()[::-1]

print(feature_importance)
print(feature_ranks)
print(feature_indices_sorted_by_importance)


In [None]:
# One label per feature, derived from the data instead of hard-coded.
feature_labels = [f"f{i}" for i in range(1, len(feature_importance) + 1)]
ax = sns.barplot(x=feature_importance, y=feature_labels)
ax.set(xlim=(0, 1), xlabel="average attention", ylabel="feature");


In [None]:
from otc.models.activation import ReGLU
from otc.models.fttransformer import (
    CategoricalFeatureTokenizer,
    CLSToken,
    FeatureTokenizer,
    FTTransformer,
    MultiheadAttention,
    NumericalFeatureTokenizer,
    Transformer,
)

num_features_cont = 5
num_features_cat = 1
cat_cardinalities = [2]
batch_size = 64
d_token = 96  # token width shared by the tokenizer and the transformer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# seed both the random inputs and the model's initialization
torch.manual_seed(0)

# torch.randint's upper bound is exclusive, so randint(0, 1, ...) would yield
# only zeros. Draw each categorical column from [0, cardinality) instead.
x_cat = torch.stack(
    [torch.randint(0, cardinality, (batch_size,)) for cardinality in cat_cardinalities],
    dim=1,
).to(device)
x_cont = torch.randn(batch_size, num_features_cont).float().to(device)
# binary targets in {0, 1}
expected_outputs = torch.randint(0, 2, (batch_size, 1)).float().to(device)

params_feature_tokenizer = {
    # key kept as spelled; presumably matches FeatureTokenizer's keyword — verify
    "num_continous": num_features_cont,
    "cat_cardinalities": cat_cardinalities,
    "d_token": d_token,
}
feature_tokenizer = FeatureTokenizer(**params_feature_tokenizer)
params_transformer = {
    "d_token": d_token,
    "n_blocks": 3,
    "attention_n_heads": 8,
    "attention_initialization": "kaiming",
    "ffn_activation": ReGLU,
    "attention_normalization": nn.LayerNorm,
    "ffn_normalization": nn.LayerNorm,
    "ffn_dropout": 0.1,
    "ffn_d_hidden": d_token * 2,
    "attention_dropout": 0.1,
    "residual_dropout": 0.1,
    "prenormalization": True,
    "first_prenormalization": False,
    "last_layer_query_idx": None,
    "n_tokens": None,
    "kv_compression_ratio": None,
    "kv_compression_sharing": None,
    "head_activation": nn.ReLU,
    "head_normalization": nn.LayerNorm,
    "d_out": 1,
}

transformer = Transformer(**params_transformer)

model = FTTransformer(feature_tokenizer, transformer).to(device)


In [None]:
# Prepare data and model.
n_objects = len(x_cat)  # == batch_size
n_features = num_features_cont + num_features_cat

# The following hook will save all attention maps from all attention modules.
hook = SaveAttentionMaps()
# Keep the handles so the hooks can be removed afterwards; otherwise every
# re-run of this cell registers one more hook on each block.
handles = [
    block.attention.register_forward_hook(hook) for block in model.transformer.blocks
]

# Apply the model to all objects.
model.eval()
try:
    with torch.inference_mode():
        model(x_cat, x_cont)
finally:
    for handle in handles:
        handle.remove()

# Collect attention maps
n_blocks = len(model.transformer.blocks)
n_heads = model.transformer.blocks[0].attention.n_heads
n_tokens = n_features + 1  # one token per feature plus the [CLS] token
attention_maps = torch.cat(hook.attention_maps)
assert attention_maps.shape == (n_objects * n_blocks * n_heads, n_tokens, n_tokens)

# Calculate feature importance and ranks.
average_attention_map = attention_maps.mean(0)
average_cls_attention_map = average_attention_map[-1]  # consider only the [CLS] token
feature_importance = average_cls_attention_map[:-1]  # drop the [CLS] token importance
assert feature_importance.shape == (n_features,)

feature_importance = feature_importance.cpu().numpy()
feature_ranks = scipy.stats.rankdata(-feature_importance)
feature_indices_sorted_by_importance = feature_importance.argsort()[::-1]

print(feature_importance)
print(feature_ranks)
print(feature_indices_sorted_by_importance)


In [None]:
# One label per feature, derived from the data instead of hard-coded.
feature_labels = [f"f{i}" for i in range(1, len(feature_importance) + 1)]
ax = sns.barplot(x=feature_importance, y=feature_labels)
ax.set(xlim=(0, 1), xlabel="average [CLS] attention", ylabel="feature");
