Do custom install of `sage-importance`

```bash
git clone https://github.com/karelze/sage.git
cd sage
pip install .
```

In [7]:
import os
import sys
import pickle
from pathlib import Path

from typing import List

from catboost import CatBoostClassifier

import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
import sklearn
from sklearn.metrics import accuracy_score
from sklearn.ensemble import HistGradientBoostingClassifier
from torch import nn

sys.path.append("..")
from otc.models.classical_classifier import ClassicalClassifier

from sage import GroupedMarginalImputer, PermutationEstimator

from otc.features.build_features import (
    features_categorical,
    features_classical,
    features_classical_size,
    features_ml,
)

import wandb
from tqdm.auto import tqdm

In [None]:
SEED = 42

# set globally here
EXCHANGE = "ise"  
STRATEGY = "supervised"  
SUBSET = "test"  

# Change depending on model!
FEATURES = features_classical_size

In [None]:
# set project name. Required to access files and artefacts
os.environ["GCLOUD_PROJECT"] = "flowing-mantis-239216"

In [None]:
# see https://wandb.ai/fbv/thesis/runs/kwlaw02g/overview?workspace=user-karelze
dataset = f"fbv/thesis/{EXCHANGE}_{STRATEGY}_none:latest"
run = wandb.init(project="thesis", entity="fbv")

artifact = run.use_artifact(dataset)
data_dir = artifact.download()

## Data Preparation 🌊

In [None]:
run = wandb.init(project="thesis", entity="fbv")

dataset = f"fbv/thesis/{EXCHANGE}_{STRATEGY}_none:latest"

artifact = run.use_artifact(dataset)
data_dir = artifact.download()

columns = [
    *FEATURES,
    "buy_sell",
]

data = pd.read_parquet(Path(data_dir, "train_set.parquet"), engine="fastparquet", columns=columns)

y_train_none = data["buy_sell"]
X_train_none = data.drop(columns="buy_sell")

data = pd.read_parquet(Path(data_dir, "test_set.parquet"), engine="fastparquet", columns=columns)

y_test_none = data["buy_sell"]
X_test_none = data.drop(columns="buy_sell")


dataset = f"fbv/thesis/{EXCHANGE}_{STRATEGY}_log_standardized_clipped:latest"

artifact = run.use_artifact(dataset)
data_dir = artifact.download()

data = pd.read_parquet(Path(data_dir, "train_set.parquet"), engine="fastparquet", columns=columns)

y_train_processed = data["buy_sell"]
X_train_processed = data.drop(columns="buy_sell")

data = pd.read_parquet(Path(data_dir, "test_set.parquet"), engine="fastparquet", columns=columns)

y_test_processed = data["buy_sell"]
X_test_processed = data.drop(columns="buy_sell")

feature_names = X_train_none.columns

## Sage Values🌵

In [None]:
# Define feature groups (disjoint)
feature_groups = group_names = {
    'chg_all_lead (grouped)': ['price_all_lead', 'chg_all_lead'],
    'chg_all_lag (grouped)': ['price_all_lag', 'chg_ex_lag'],
    'chg_ex_lead (grouped)': ['price_ex_lead', 'chg_ex_lead', 'chg_all_lag'],
    'chg_ex_lag (grouped)': ['price_ex_lag'],
    'size_ex (grouped)': [ 'bid_ask_size_ratio_ex', 'rel_bid_size_ex',  'rel_ask_size_ex', 'bid_size_ex', 'ask_size_ex','depth_ex'],
    'quote_best (grouped)': ['BEST_ASK', 'BEST_BID', 'prox_best'],
    'quote_ex (grouped)': ['bid_ex', 'ask_ex','prox_ex' ],
    'TRADE_PRICE': ['TRADE_PRICE'],
    'TRADE_SIZE': ['TRADE_SIZE']    
}
group_names = [group for group in feature_groups]
for col in feature_names:
    if np.all([col not in group[1] for group in feature_groups.items()]):
        group_names.append(col)

# Group indices
groups = []
for _, group in feature_groups.items():
    ind_list = []
    for feature in group:
        ind_list.append(feature_names.tolist().index(feature))
    groups.append(ind_list)


### Classical Classifier🏦

In [None]:
clf = ClassicalClassifier(layers=[("trade_size", "ex"), ("rev_lr", "best")], 
                                  random_state=SEED, strategy="random")

clf.fit(X=X_train_none.head(5), y=y_train_none.head(5))

imputer = GroupedMarginalImputer(clf, X_test_none.head(1024).values, groups)
estimator = PermutationEstimator(imputer, "cross entropy")
sage_values = estimator(X_test_none.head(1024).values, y_test_none.head(1024).values)

In [None]:
sage_values.plot(group_names, title=None)

### Gradient Boosting 🐈

In [None]:
# load model by identifier from wandb
model = "17malsep_CatBoostClassifier_default.cbm:v7"
model_name = model.split("/")[-1].split(":")[0]

artifact = run.use_artifact(model)
model_dir = artifact.download()

clf = CatBoostClassifier()
clf.load_model(fname=Path(model_dir, model_name))

In [None]:
imputer = GroupedMarginalImputer(clf, X_test_processed.head(128).values, groups)
estimator = PermutationEstimator(imputer, "cross entropy")
sage_values = estimator(X_test_processed.head(128).values, y_test_processed.head(128).values)

In [None]:
sage_values.plot(group_names, title=None)

### Transformer Classifier 🤖

In [None]:
model = "2rq3hrkw_TransformerClassifier_default.pkl:latest"
model_name = model.split("/")[-1].split(":")[0]

artifact = run.use_artifact(model)
model_dir = artifact.download()
    
with open(Path(model_dir, model_name), 'rb') as f:
    clf = pickle.load(f)

In [None]:
imputer = GroupedMarginalImputer(clf, X_test_processed.head(1024).values, groups)
estimator = PermutationEstimator(imputer, "cross entropy")
sage_values = estimator(X_test_processed.head(1024).values, y_test_processed.head(1024).values)

In [None]:
sage_values.plot(group_names, title=None)

In [None]:
model = "i3pvza1q_TransformerClassifier_default.pkl:latest"
model_name = model.split("/")[-1].split(":")[0]

artifact = run.use_artifact(model)
model_dir = artifact.download()
    
with open(Path(model_dir, model_name), 'rb') as f:
    clf = pickle.load(f)

## Attention Maps for Transformers

We calculate the average attention map from all transformer blocks, as done in the Gorishniy paper (see [here](https://github.com/Yura52/tabular-dl-revisiting-models/issues/2)). This is different from the Borisov paper (see [here](https://github.com/kathrinse/TabSurvey/blob/main/models/basemodel_torch.py)).

In [1]:
from otc.models.activation import ReGLU
from otc.models.fttransformer import (
    CategoricalFeatureTokenizer,
    CLSToken,
    FeatureTokenizer,
    FTTransformer,
    MultiheadAttention,
    NumericalFeatureTokenizer,
    Transformer,
)

In [2]:
import torch

In [46]:
num_features_cont = 5
num_features_cat = 1
cat_cardinalities = [2]
batch_size = 64


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x_cat = torch.randint(0, 1, (batch_size, num_features_cat)).to(
            device
        )
x_cont = (
            torch.randn(batch_size, num_features_cont).float().to(device)
)

params_feature_tokenizer  = {
            "num_continous": num_features_cont,
            "cat_cardinalities": cat_cardinalities,
            "d_token": 96,
        }
feature_tokenizer = FeatureTokenizer(**params_feature_tokenizer)
params_transformer = {
            "d_token": 96,
            "n_blocks": 3,
            "attention_n_heads": 8,
            "attention_initialization": "kaiming",
            "ffn_activation": ReGLU,
            "attention_normalization": nn.LayerNorm,
            "ffn_normalization": nn.LayerNorm,
            "ffn_dropout": 0.1,
            "ffn_d_hidden": 96 * 2,
            "attention_dropout": 0.1,
            "residual_dropout": 0.1,
            "prenormalization": True,
            "first_prenormalization": False,
            "last_layer_query_idx": None,
            "n_tokens": None,
            "kv_compression_ratio": None,
            "kv_compression_sharing": None,
            "head_activation": nn.ReLU,
            "head_normalization": nn.LayerNorm,
            "d_out": 1,
        }

transformer = Transformer(**params_transformer)
model = FTTransformer(feature_tokenizer, transformer).to(device)

In [47]:
class SaveAttentionMaps:
    """
    Hook for attention maps.

    Inspired by:
    https://github.com/Yura52/tabular-dl-revisiting-models/issues/2#issuecomment-1068123629
    """

    def __init__(self):
        self.attention_maps: List[torch.Tensor] = []

    def __call__(self, _, __, output):
        # print(output[1]["attention_probs"].shape)
        self.attention_maps.append(output[1]["attention_probs"])

In [56]:
# The following hook will save all attention maps from all attention modules.
# Prepare data and model.
n_objects = len(x_cat)  # 12
n_features = num_features_cont + num_features_cat

# The following hook will save all attention maps from all attention modules.
hook = SaveAttentionMaps()
for block in model.transformer.blocks:
    block.attention.register_forward_hook(hook)

# Apply the model to all objects.
model.eval()
with torch.inference_mode():
    y = model(x_cat, x_cont)
    # y.requires_grad_(True)
    # print(y.shape)

# Collect attention maps
n_blocks = len(model.transformer.blocks)
n_heads = model.transformer.blocks[0].attention.n_heads
# continuous features + categorical features + CLS token
n_tokens = n_features + 1
# residual connection
res = torch.eye(n_tokens, n_tokens)
res = res.unsqueeze(0).expand(batch_size, n_tokens, n_tokens)

model.zero_grad()

# index = [i for i in range(batch_size)]
# one_hot = np.zeros((logits_per_image.shape[0], logits_per_image.shape[1]), dtype=np.float32)
# one_hot[torch.arange(logits_per_image.shape[0]), index] = 1
# one_hot = torch.from_numpy(one_hot).requires_grad_(True)
# one_hot = torch.sum(one_hot.cuda() * logits_per_image)
# model.zero_grad()

for attention_map in hook.attention_maps:
    # batch size, n_heads, n_tokens, n_tokens
    attention_probs = attention_map.reshape(batch_size, n_heads, n_tokens, n_tokens)

    attention_probs = attention_probs.clone().detach().requires_grad_(True)

    # calculate gradient with respect to output
    grad = torch.autograd.grad(y.sum(), [attention_probs], retain_graph=True)[0].detach()
    cam = attention_probs.detach()
    print(grad.shape)
    print(cam.shape)

    cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
    grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
    cam = grad * cam
    cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
    cam = cam.clamp(min=0).mean(dim=1)
    res = res + torch.bmm(cam, res)

attention_maps = res

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
# https://colab.research.google.com/github/hila-chefer/Transformer-MM-Explainability/blob/main/CLIP_explainability.ipynb#scrollTo=fWKGyu2YAeSV

def interpret(image, texts, model, device, start_layer=start_layer, start_layer_text=start_layer_text):
    batch_size = texts.shape[0]
    images = image.repeat(batch_size, 1, 1, 1)
    logits_per_image, logits_per_text = model(images, texts)
    probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()
    index = [i for i in range(batch_size)]
    one_hot = np.zeros((logits_per_image.shape[0], logits_per_image.shape[1]), dtype=np.float32)
    one_hot[torch.arange(logits_per_image.shape[0]), index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    one_hot = torch.sum(one_hot.cuda() * logits_per_image)
    model.zero_grad()

    image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())

    if start_layer == -1: 
      # calculate index of last layer 
      start_layer = len(image_attn_blocks) - 1
    
    num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
    R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
    R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
    for i, blk in enumerate(image_attn_blocks):
        if i < start_layer:
          continue
        grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
        cam = blk.attn_probs.detach()
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
        cam = cam.clamp(min=0).mean(dim=1)
        R = R + torch.bmm(cam, R)
    image_relevance = R[:, 0, 1:]

    
    text_attn_blocks = list(dict(model.transformer.resblocks.named_children()).values())

    if start_layer_text == -1: 
      # calculate index of last layer 
      start_layer_text = len(text_attn_blocks) - 1

    num_tokens = text_attn_blocks[0].attn_probs.shape[-1]
    R_text = torch.eye(num_tokens, num_tokens, dtype=text_attn_blocks[0].attn_probs.dtype).to(device)
    R_text = R_text.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
    for i, blk in enumerate(text_attn_blocks):
        if i < start_layer_text:
          continue
        grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
        cam = blk.attn_probs.detach()
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
        cam = cam.clamp(min=0).mean(dim=1)
        R_text = R_text + torch.bmm(cam, R_text)
    text_relevance = R_text
   
    return text_relevance, image_relevance
