Install the custom fork of `sage-importance`:

```bash
git clone https://github.com/karelze/sage.git
cd sage
pip install .
```

In [None]:
import os
import sys
import pickle
from pathlib import Path

from catboost import CatBoostClassifier, Pool

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rc
import matplotlib.dates as mdates
import matplotlib.ticker as ticker
from matplotlib.ticker import FormatStrFormatter, StrMethodFormatter, PercentFormatter


import pandas as pd
import seaborn as sns
import sklearn
from sklearn.metrics import accuracy_score
from sklearn.ensemble import HistGradientBoostingClassifier
import torch
from torch import nn

sys.path.append("..")
from otc.models.classical_classifier import ClassicalClassifier

from sage import GroupedMarginalImputer, PermutationEstimator, MarginalImputer

from otc.features.build_features import (
    features_categorical,
    features_classical,
    features_classical_size,
    features_ml,
)

from otc.models.fttransformer import FeatureTokenizer, FTTransformer, Transformer
from otc.models.activation import ReGLU
from otc.data.dataset import TabDataset
from otc.data.dataloader import TabDataLoader
from otc.features.build_features import features_classical_size
from otc.optim.early_stopping import EarlyStopping
from otc.optim.scheduler import CosineWarmupScheduler

import wandb
from tqdm.auto import tqdm

In [None]:
SEED = 42

# seed numpy's global RNG with the shared seed; it drives the background sampling below
np.random.seed(SEED)

# set globally here
EXCHANGE = "ise"
STRATEGY = "supervised"
SUBSET = "test"


# Change depending on model!
FEATURES = features_ml

In [None]:
# Google Cloud project used to resolve files and artefacts (gs:// URIs below)
os.environ.update(GCLOUD_PROJECT="flowing-mantis-239216")

## Sage Values🌵

In [None]:
def get_feature_groups(feature_names, feature_str):
    """Build SAGE feature groups for one of the predefined feature sets.

    Related features (a price and its derived change, the quote features of one
    source, all size features of the exchange) are grouped so that SAGE
    attributes importance to the group jointly.

    Args:
        feature_names: Ordered sequence of the column names of the data passed
            to the imputer (list, tuple, pandas Index, ...). Group members are
            resolved to their first position in this sequence.
        feature_str: One of ``"classical"``, ``"classical-size"`` or ``"ml"``.

    Returns:
        ``(groups, group_names)``. ``groups`` is a list with one list of column
        indices per group. ``group_names`` is the dict mapping group name to
        member feature names, iterated in the same order as ``groups``.

    Raises:
        ValueError: If ``feature_str`` is unknown or a group member is missing
            from ``feature_names``.
    """
    # each change feature is grouped with the price it is derived from
    fg_classical = {
        "chg_all_lead (grouped)": ["price_all_lead", "chg_all_lead"],
        "chg_all_lag (grouped)": ["price_all_lag", "chg_all_lag"],
        "chg_ex_lead (grouped)": ["price_ex_lead", "chg_ex_lead"],
        "chg_ex_lag (grouped)": ["price_ex_lag", "chg_ex_lag"],
        "quote_best (grouped)": ["BEST_ASK", "BEST_BID", "prox_best"],
        "quote_ex (grouped)": ["bid_ex", "ask_ex", "prox_ex"],
        "TRADE_PRICE": ["TRADE_PRICE"],
    }

    fg_size = {
        "size_ex (grouped)": [
            "bid_ask_size_ratio_ex",
            "rel_bid_size_ex",
            "rel_ask_size_ex",
            "bid_size_ex",
            "ask_size_ex",
            "depth_ex",
        ],
        "TRADE_SIZE": ["TRADE_SIZE"],
    }

    fg_ml = {
        "STRK_PRC": ["STRK_PRC"],
        "ttm": ["ttm"],
        "option_type": ["option_type"],
        "root": ["root"],
        "myn": ["myn"],
        "day_vol": ["day_vol"],
        "issue_type": ["issue_type"],
    }

    # feature sets are nested: classical < classical-size < ml
    feature_sets = {
        "classical": fg_classical,
        "classical-size": fg_classical | fg_size,
        "ml": fg_classical | fg_size | fg_ml,
    }
    try:
        feature_groups = feature_sets[feature_str]
    except KeyError:
        raise ValueError(
            f"unknown feature_str {feature_str!r}; expected one of {list(feature_sets)}"
        ) from None

    # first position of every column name (same semantics as list.index)
    position = {}
    for i, name in enumerate(feature_names):
        position.setdefault(name, i)

    groups = []
    for group_name, members in feature_groups.items():
        missing = [feature for feature in members if feature not in position]
        if missing:
            raise ValueError(
                f"features {missing} of group {group_name!r} not found in feature_names"
            )
        groups.append([position[feature] for feature in members])

    return groups, feature_groups


In [None]:
# The classical classifier needs the unscaled ("none") version of the test
# set, which is fetched as a wandb artifact.
run = wandb.init(project="thesis", entity="fbv")

dataset = f"fbv/thesis/{EXCHANGE}_{STRATEGY}_none:latest"
artifact = run.use_artifact(dataset)
data_dir = artifact.download()

test_file = Path(data_dir) / "test_set.parquet"
data = pd.read_parquet(
    test_file,
    engine="fastparquet",
    columns=[*features_classical_size, "buy_sell"],
)

# split into features and target
X_test = data.drop(columns="buy_sell")
y_test = data["buy_sell"]
feature_names = X_test.columns

### Classical Classifier🏦

In [None]:
# quick look at the target column of the unscaled test set
y_test.head()

In [None]:
sample_size = 1024  # default

# background sample for the imputer: distinct row labels of the test set
idx = np.random.choice(y_test.index, size=sample_size, replace=False)
X_importance, y_importance = X_test.loc[idx], y_test.loc[idx]

In [None]:
# Sanity check of one rule stack of the classical classifier on the full test set.
# config = [("trade_size", "ex"), ("quote", "best"),  ("quote", "ex"), ("depth", "best"), ("depth", "ex"), ("rev_tick", "all")]  
config = [("nan", "ex")]
clf = ClassicalClassifier(layers=config, random_state=SEED, strategy="random")
# fit on 5 rows only; per the benchmark cell below this only sets headers etc. (no leakage)
clf.fit(X=X_test.head(5), y=y_test.head(5))
# NOTE: not the last expression of the cell, so the score is computed but not displayed
clf.score(X_test, y_test)

# keep the call order: with strategy="random" these calls may share the classifier's RNG state
pred_grauer = clf.predict_proba(X_test)
pred_grauer_label = clf.predict(X_test)

In [None]:
from sklearn.metrics import log_loss, zero_one_loss

In [None]:
# samplewise_log_loss = []
# for true_label, pred_probs in zip(y_importance, y_pred):
#     sample_loss = log_loss([true_label], [pred_probs], labels=[0, 1])
#     samplewise_log_loss.append(sample_loss)

# print("Samplewise Log Loss:", samplewise_log_loss)

In [None]:
# log_loss(y_test.clip(0), pred_grauer)

# fraction of misclassified trades; clip(0) maps negative labels (presumably -1 = sell, confirm) to 0
zero_one_loss(y_test.clip(0), pred_grauer_label.clip(0))


In [None]:
# compare benchmarks: SAGE values of two rule stacks of the classical classifier
configs = [
    [("quote", "best"), ("quote", "ex"), ("rev_tick", "all")],
    [("trade_size", "ex"), ("quote", "best"), ("quote", "ex"), ("depth", "best"), ("depth", "ex"), ("rev_tick", "all")],
]

# the feature groups depend only on the columns, not on the rule stack -> compute once
groups, group_names = get_feature_groups(X_importance.columns.tolist(), "classical-size")

results = []
for config in configs:

    clf = ClassicalClassifier(layers=config, random_state=SEED, strategy="random")
    # only set headers etc, no leakage
    clf.fit(X=X_test.head(5), y=y_test.head(5))

    def call_classical(X):
        """Return the class probabilities of the current rule stack for ``X``."""
        return clf.predict_proba(X)

    # apply group based imputation + estimate importances in terms of zero-one loss
    imputer = GroupedMarginalImputer(call_classical, X_importance.values, groups)
    estimator = PermutationEstimator(imputer, "zero one")

    # calculate values over entire test set
    sage_values = estimator(X_test.values, y_test.values)

    # save sage values + std deviation to data frame
    result = pd.DataFrame(index=group_names, data={"values": sage_values.values, "std": sage_values.std})
    results.append(result)

In [None]:
# human readable column names like quote(best)->quote(ex), one per rule stack
names = [
    "->".join(f"{layer}({subset})" for layer, subset in config)
    for config in tqdm(configs)
]

results_df = pd.concat(results, axis=1, keys=names)

# collapse the two-level columns into "name statistic" strings (required to save to parquet)
results_df.columns = [" ".join(levels).strip() for levels in results_df.columns.values]

In [None]:
# inspect SAGE values (+ std) per feature group and rule stack
results_df

In [None]:
# (duplicate display of the same frame)
results_df

In [None]:
KEY = f"{EXCHANGE}_{STRATEGY}_{SUBSET}_classical_feature_importance"

URI_FI_CLASSICAL = f"gs://thesis-bucket-option-trade-classification/data/results/{KEY}.parquet"

# save to google cloud
results_df.to_parquet(URI_FI_CLASSICAL)

# register the parquet file as a reference artifact; without log_artifact it never reaches wandb
result_set = wandb.Artifact(name=KEY, type="results")
result_set.add_reference(URI_FI_CLASSICAL, name="results")
run.log_artifact(result_set)

In [None]:
# close the wandb run of the classical-classifier section
run.finish()

### Gradient Boosting 🐈

In [None]:
# feature-set identifier -> list of columns the corresponding model uses
FEATURE_MAP = {
    "classical": features_classical,
    "classical-size": features_classical_size,
    "ml": features_ml,
}

run = wandb.init(project="thesis", entity="fbv")

# processed ("log_standardized_clipped" per the artifact name) data for gradient boosting
dataset = f"fbv/thesis/{EXCHANGE}_{STRATEGY}_log_standardized_clipped:latest"
artifact = run.use_artifact(dataset)
data_dir = artifact.download()

test_file = Path(data_dir) / "test_set.parquet"
data = pd.read_parquet(
    test_file,
    engine="fastparquet",
    columns=[*features_ml, "buy_sell"],
)

# split into features and target
X_test = data.drop(columns="buy_sell")
y_test = data["buy_sell"]
feature_names = X_test.columns

In [None]:
# show which artifact was loaded
dataset

In [None]:
# number of background rows for the imputer in the GBM / transformer sections
sample_size = 512

In [None]:
# background sample for the imputer: distinct row labels of the processed test set
idx = np.random.choice(X_test.index, size=sample_size, replace=False)
X_importance, y_importance = X_test.loc[idx], y_test.loc[idx]

In [None]:
# (sample_size, number of ml features)
X_importance.shape

In [None]:
# sanity check: column indices and names of the classical feature groups
groups, group_names = get_feature_groups(X_importance.columns.tolist(), "classical")
print(groups, group_names, sep="\n")

In [None]:
# (feature set, wandb artifact of the trained CatBoost model)
configs = [
    ("classical", "1gzk7msy_CatBoostClassifier_default.cbm:latest"),
    ("classical-size", "3vntumoi_CatBoostClassifier_default.cbm:latest"),
    ("ml", "2t5zo50f_CatBoostClassifier_default.cbm:latest"),
]

results = []

for feature_str, model in configs:

    # slice the background sample to the features the model was trained on
    fs = FEATURE_MAP[feature_str]
    X_importance_fs = X_importance.loc[:, fs]
    X_importance_cols = X_importance_fs.columns.tolist()

    # positions of the categorical features within the *sliced* frame (ml set only)
    if feature_str == "ml":
        cat_features = [t[0] for t in features_categorical]
        cat_idx = [X_importance_cols.index(f) for f in cat_features]

    groups, group_names = get_feature_groups(X_importance_cols, feature_str)

    # load model by identifier from wandb; the file is named like the artifact minus ":alias"
    model_name = model.split("/")[-1].split(":")[0]
    artifact = run.use_artifact(model)
    model_dir = artifact.download()
    clf = CatBoostClassifier()
    clf.load_model(fname=Path(model_dir, model_name))

    def call_catboost(X):
        """Predict class probabilities for the array ``X`` produced by the imputer.

        A callable is used instead of the plain model because catboost does
        not handle the categorical columns otherwise: they are cast to int
        and passed via a ``Pool``.
        """
        if feature_str == "ml":
            # rebuild a frame with the sliced column names so that cat_idx lines up
            X = pd.DataFrame(X, columns=X_importance_cols)
            X[cat_features] = X.iloc[:, cat_idx].astype(int)
            return clf.predict_proba(Pool(X, cat_features=cat_idx))
        return clf.predict_proba(X)

    # apply group based imputation + estimate importances in terms of zero-one loss
    imputer = GroupedMarginalImputer(call_catboost, X_importance_fs, groups)
    estimator = PermutationEstimator(imputer, "zero one")

    # calculate values over the entire test set, restricted to the model's features
    sage_values = estimator(X_test.loc[:, fs].values, y_test.values)

    # save sage values + std deviation to data frame
    result = pd.DataFrame(index=group_names, data={"values": sage_values.values, "std": sage_values.std})
    results.append(result)

In [None]:
# SAGE values of the last model in the loop
result

In [None]:
# debug: probabilities of the last model loaded in the loop above (the "ml" model).
# NOTE(review): unlike call_catboost this passes the raw frame without int-cast
# categoricals / Pool — may fail or differ; confirm before relying on it.
pred = clf.predict_proba(X_importance)

In [None]:
# inspect debug predictions
pred

In [None]:
# (duplicate display of the last SAGE result)
result

In [None]:
# preview: one column pair (values, std) per model, e.g. "gbm(classical) values"
names = [f"gbm({set_name})" for set_name, _ in configs]
results_df = pd.concat(results, axis=1, keys=names)
results_df.columns = [" ".join(levels).strip() for levels in results_df.columns.values]

In [None]:
# inspect SAGE values of all GBM models
results_df

In [None]:
# list to data frame + set human readable names, e.g. "gbm(classical)"
names = [f"gbm({feature_str})" for feature_str, _ in configs]
results_df = pd.concat(results, axis=1, keys=names)
# flatten the (model, statistic) columns; required to save to parquet
results_df.columns = [' '.join(col).strip() for col in results_df.columns.values]

# save to google cloud and register the file as a wandb artifact
KEY = f"{EXCHANGE}_{STRATEGY}_{SUBSET}_gbm_feature_importance_{sample_size}"

URI_FI_GBM = f"gs://thesis-bucket-option-trade-classification/data/results/{KEY}.parquet"

results_df.to_parquet(URI_FI_GBM)

result_set = wandb.Artifact(name=KEY, type="results")
result_set.add_reference(URI_FI_GBM, name="results")
# without log_artifact the reference never reaches wandb
run.log_artifact(result_set)

### Transformer Classifier 🤖

In [None]:
# (feature set, wandb artifact of the pickled transformer classifier)
configs = [
    ("classical", "3jpe46s1_TransformerClassifier_default.pkl:latest"),
    ("classical-size", "1qx3ul4j_TransformerClassifier_default.pkl:latest"),
    ("ml", "2h81aiow_TransformerClassifier_default.pkl:latest"),
]

results = []

for feature_str, model in configs:

    # slice the background sample to the features the model was trained on
    fs = FEATURE_MAP[feature_str]
    X_importance_fs = X_importance.loc[:, fs]
    X_importance_cols = X_importance_fs.columns.tolist()

    groups, group_names = get_feature_groups(X_importance_cols, feature_str)

    # load model by identifier from wandb; the file is named like the artifact minus ":alias"
    model_name = model.split("/")[-1].split(":")[0]
    artifact = run.use_artifact(model)
    model_dir = artifact.download()

    # NOTE: unpickling executes arbitrary code; only load artifacts from the trusted project
    with open(Path(model_dir, model_name), 'rb') as f:
        clf = pickle.load(f)

    # apply group based imputation + estimate importances in terms of zero-one loss;
    # the loaded classifier is passed directly as the model callable
    imputer = GroupedMarginalImputer(clf, X_importance_fs, groups)
    estimator = PermutationEstimator(imputer, "zero one")

    # calculate values over the entire test set, restricted to the model's features
    sage_values = estimator(X_test.loc[:, fs].values, y_test.values)

    # save sage values + std deviation to data frame
    result = pd.DataFrame(index=group_names, data={"values": sage_values.values, "std": sage_values.std})
    results.append(result)

In [None]:
# list of per-model SAGE frames (values + std)
results

In [None]:
# list to data frame + set human readable names, e.g. "fttransformer(classical)"
names = [f"fttransformer({feature_str})" for feature_str, _ in configs]
results_df = pd.concat(results, axis=1, keys=names)
# flatten the (model, statistic) columns; required to save to parquet
results_df.columns = [' '.join(col).strip() for col in results_df.columns.values]

# save to google cloud and register the file as a wandb artifact
KEY = f"{EXCHANGE}_{STRATEGY}_{SUBSET}_fttransformer_feature_importance_{sample_size}"

URI_FI_FTTRANSFORMER = f"gs://thesis-bucket-option-trade-classification/data/results/{KEY}.parquet"

results_df.to_parquet(URI_FI_FTTRANSFORMER)

result_set = wandb.Artifact(name=KEY, type="results")
result_set.add_reference(URI_FI_FTTRANSFORMER, name="results")
run.log_artifact(result_set)

wandb.finish()

## Attention Maps for Transformers

We aggregate the gradient-weighted attention maps of all transformer blocks into a single relevance map, as done [here](https://github.com/hila-chefer/Transformer-MM-Explainability/blob/main/lxmert/lxmert/src/ExplanationGenerator.py#L26) and [here](https://colab.research.google.com/github/hila-chefer/Transformer-MM-Explainability/blob/main/CLIP_explainability.ipynb#scrollTo=fWKGyu2YAeSV).

In [None]:
# LaTeX (xelatex / pgf) serif styling for all following figures
params = {
    "pgf.texsystem": "xelatex",
    "pgf.rcfonts": False,
    "font.serif": [],
    "font.family": "serif",
    "font.sans-serif": [],
    "axes.labelsize": 11,
}
plt.rcParams.update(params)

# render text with LaTeX and load the packages used in labels
plt.rcParams["text.usetex"] = True
plt.rcParams["text.latex.preamble"] = r'\usepackage{amsmath}\usepackage[utf8]{inputenc}'

# centimetres -> inches, for figsize
CM = 1 / 2.54

cmap = mpl.colormaps.get_cmap("plasma")

In [None]:
MODEL = "2h81aiow_TransformerClassifier_default.pkl:latest"

run = wandb.init(project="thesis", entity="fbv")

# the stored file is named like the artifact, minus the ":alias" suffix
model_name = MODEL.rsplit("/", 1)[-1].partition(":")[0]

artifact = run.use_artifact(MODEL)
model_dir = artifact.download()

# NOTE: unpickling executes arbitrary code; only load artifacts from the trusted project
with Path(model_dir, model_name).open("rb") as f:
    model = pickle.load(f)

# wrapped network; called as clf(x_cat, x_cont) and inspected via clf.transformer below
clf = model.clf

In [None]:
# test set in the "log_standardized" processing (per the artifact name)
dataset = f"fbv/thesis/{EXCHANGE}_{STRATEGY}_log_standardized:latest"
artifact = run.use_artifact(dataset)
data_dir = artifact.download()

test_file = Path(data_dir) / "test_set.parquet"
data = pd.read_parquet(
    test_file,
    engine="fastparquet",
    columns=[*features_ml, "buy_sell"],
)

# split into features and target
X_test = data.drop(columns="buy_sell")
y_test = data["buy_sell"]

In [None]:
# quick look at the processed features
X_test.head()

In [None]:
# suffix for the file names of the figures saved below
key = "ise_quotes_mid"

# first 16 at quotes, last 16 at mid
# (the split further below relies on this order and on the 16/16 balance)
idx =  [39342191, 39342189, 39342188, 39342175, 39342174, 39342171,
            39342233, 39342241, 39342388, 39342239, 39342237, 39342193,
            39342194, 39342199, 39342202, 39342204,
        39342276, 39342363, 39342387, 39342437, 39342436, 39342428,
            39342464, 39342540, 39342608, 39342598, 39342620, 39342632,
            39342674, 39342781, 39342804, 39342824]

In [None]:
# inspect the selected trades
X_test.loc[idx]

In [None]:
device = "cuda"
# all selected trades go into a single batch
batch_size = len(idx)

# categorical setup must match the one the model was trained with
cat_features = model.module_params["cat_features"]
cat_unique_counts = model.module_params["cat_cardinalities"]

dl_params = dict(batch_size=batch_size, shuffle=False, device=device)

test_data = TabDataset(
    X_test.loc[idx],
    y_test.loc[idx],
    cat_features=cat_features,
    cat_unique_counts=cat_unique_counts,
)

test_loader = TabDataLoader(
    test_data.x_cat, test_data.x_cont, test_data.weight, test_data.y, **dl_params
)



In [None]:
# batch_size == len(idx), so this single batch holds all selected trades
x_cat, x_cont, weight, y = next(iter(test_loader))

In [None]:
# Forward + backward pass on the selected batch. The backward pass is needed
# because the relevancy below combines attention maps with their gradients.
# NOTE(review): train/eval mode of clf is not set here; if it was pickled in
# train mode, dropout is active during this pass — consider clf.eval() (TODO confirm)
criterion = nn.BCEWithLogitsLoss()

# calculate outputs (flattened to match the shape of y for the loss)
logits = clf(x_cat, x_cont).flatten()

# zero gradients
clf.zero_grad()

# loss + backward pass
loss = criterion(logits, y)
loss.backward()

In [None]:
# Gradient-weighted attention rollout following Chefer et al.:
#   R = I;  for every block: A_bar = mean_heads(clamp(grad * attn, min=0));  R = R + A_bar @ R
# The raw attention maps and gradients per block are also kept in `cams` / `grads`.
# https://github.com/hila-chefer/Transformer-MM-Explainability/blob/main/lxmert/lxmert/src/ExplanationGenerator.py#L26
# https://colab.research.google.com/github/hila-chefer/Transformer-MM-Explainability/blob/main/CLIP_explainability.ipynb#scrollTo=fWKGyu2YAeSV

# only used to read the number of tokens
attn_block = clf.transformer.blocks[0].attention.get_attn()
# cat + cont + [CLS]
n_tokens = attn_block.shape[-1]
# residual connection (identity). Repeat along batch dimension
res = torch.eye(n_tokens, n_tokens).to(device)
res = res.unsqueeze(0).expand(batch_size, n_tokens, n_tokens)

# one_hot = expected_outputs.sum()
# one_hot.backward(retain_graph=True)

cams = []
grads = []

for i, block in enumerate(clf.transformer.blocks):

    # gradients stored by the attention module, presumably from loss.backward() above (confirm)
    grad = block.attention.get_attn_gradients().detach()
    cam = block.attention.get_attn().detach()
    
    cams.append(cam)
    grads.append(grad)
    
    # reshape to [batch_size x num_head, num_tokens, num_tokens]
    cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
    grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
    
    # element-wise product (gradient-weighted attention), not a dot product
    cam = grad * cam
    
    # reshape to [batch_size, num_head, num_tokens, num_tokens]
    cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
    # clamp negative values, calculate mean over heads
    cam = cam.clamp(min=0).mean(dim=1)
    # accumulate: R <- R + A_bar @ R (batched matmul)
    res = res + torch.bmm(cam, res)

# [batch_size, num_tokens, num_tokens]
relevancy = res

In [None]:
# relevance maps of the whole batch as a numpy array on the host
batch_probs = relevancy.detach().cpu().numpy()

In [None]:
# Relevance of every feature token w.r.t. the [CLS] token: row 0 of each map
# (CLS is prepended), without the [CLS] column itself. One row per sample in
# batch (= idx) order, so the quotes/mid split and the detail indices below
# line up with idx. (The previous `batch_probs[-i]` produced 0, 31, 30, ..., 1.)
stack_np = batch_probs[:, 0, 1:].copy()

In [None]:
# number of samples (rows) in the relevance stack
len(stack_np)

In [None]:
# token order of the feature tokenizer (without the [CLS] token):
# continuous features in column order, followed by the categorical ones
cont_features = [col for col in X_test.columns if col not in cat_features]
labels = cont_features + list(cat_features)

In [None]:
# Human readable tick labels for the plots below.
# NOTE(review): hard-coded; must follow the order of `labels` (continuous, then
# categorical features) — verify whenever the feature lists change.
labels_sanitized = ['trade price',
 'bid (ex)',
 'ask (ex)',
 'ask (best)',
 'bid (best)',
 'price lag (ex)',
 'price lead (ex)',
 'price lag (all)',
 'price lead (all)',
 'chg lead (ex)',
 'chg lag (ex)',
 'chg lead (all)',
 'chg lag (all)',
 'prox (ex)',
 'prox (best)',
 'bid ask size ratio (ex)',
 'rel. bid size (ex)',
 'rel. ask size (ex)',
 'trade size',
 'bid size (ex)',
 'ask size (ex)',
 'depth (ex)',
 'strike price',
 'time to maturity',
 'moneyness',
 'day volume',
 'option type',
 'issue type',
 'root']

In [None]:
# split into trades at quotes + at mid (two equal halves; relies on the order of idx)
stack_at_quotes, stack_at_mid = np.split(stack_np, 2)

In [None]:
# Heatmaps of [CLS] relevance per feature (rows) and trade (columns): left panel
# trades at the quotes, right panel trades at the mid (shared axes). Two columns
# are highlighted; per the variable names, 0 is correctly and 8 wrongly classified.
detail_idx_correct = 0 # 39342191
detail_idx_false = 8 # 39342388
alpha = 0.1  # transparency of the highlight bands

fig, ax = plt.subplots(1, 2, figsize=(14*CM,10*CM), sharey=True, sharex=True)
ax[0].imshow(stack_at_quotes.T, cmap='Blues', interpolation='nearest')
ax[0].set_yticks(list(range(len(labels_sanitized))))
# ax[0].set_xticks(range(1, 17, 2))
ax[0].tick_params(axis='both', which='major', labelsize="x-small")

# shade the full column of each detail example
ax[0].axvspan(detail_idx_correct - 0.5, detail_idx_correct + 0.5, color='green', alpha=alpha)
ax[0].axvspan(detail_idx_false - 0.5, detail_idx_false + 0.5, color='red', alpha=alpha)

# ax[0].set_xticks(range(1, 17, 2), fontsize='x-small')
# tick labels must be set after set_yticks
ax[0].set_yticklabels(labels_sanitized)
ax[0].set_xlabel("At Quotes",size="small")

ax[1].imshow(stack_at_mid.T, cmap='Blues', interpolation='nearest')
ax[1].set_yticks(list(range(len(labels_sanitized))))
ax[1].set_xticks(range(0, 16, 2))
ax[1].tick_params(axis='both', which='major', labelsize="x-small")
ax[1].set_xlabel("At Mid", size="small")

plt.tight_layout()
plt.savefig(f"../reports/Graphs/attention_maps_{key}.pdf", bbox_inches="tight")

In [None]:
# raw string: "\m" is an invalid escape sequence in a normal string literal (SyntaxWarning)
labels_right = [r"$\mathtt{[CLS]}$", *labels_sanitized]

In [None]:
# left column: only the [CLS] token is labelled, the remaining rows are elided
labels_left = ['$\\mathtt{[CLS]}$'] + ["..."] * (len(labels_right) - 1)

In [None]:
# [CLS] attention of one head/layer for one sample, drawn as lines from the
# [CLS] token (left) to every token (right); line opacity = attention weight.
plt.figure(figsize=(3*CM,10*CM))

# layout of the two token columns
yoffset = 0
xoffset = 0
width = 0.5
example_sep = 3
word_height = 0.01
pad = 0.025

# layer, head and sample to visualize (0-based)
l = 3
h = 7
batch_idx = 8
# same convention as the heatmap: sample 0 highlighted green, others red
c = "green" if batch_idx == 0 else "red"

# [batch, heads, tokens, tokens]; n_tokens is read from cams itself instead of a
# stale `cam` from an earlier cell
n_tokens = cams[l].shape[-1]
cam = cams[l].reshape(batch_size, -1, n_tokens, n_tokens)
attention = cam[batch_idx, h, :, :]
# renormalize out-of-place: an in-place /= would modify cams[l] through the view
attention = attention / attention.sum(dim=-1, keepdim=True)

for position, word in enumerate(labels_left):
    plt.text(0, yoffset - position * word_height, word,
             ha="right", va="center", size="xx-small")
for position, word in enumerate(labels_right):
    plt.text(width, yoffset - position * word_height, word,
             ha="left", va="center", size="xx-small")

# CLS is prepended, so row 0 is the attention of [CLS] to every token (as in Chefer et al.)
for j, weight in enumerate(attention[0]):
    plt.plot([xoffset + pad, xoffset + width - pad],
             [yoffset, yoffset - word_height * j],
             # clamp: float error after renormalization must not push alpha outside [0, 1]
             color=c, linewidth=2, alpha=min(max(weight.item(), 0.0), 1.0))
plt.axis('off')
plt.savefig(f"../reports/Graphs/attention_head_{h+1}_layer_{l+1}_color_{c}_{key}.pdf", bbox_inches="tight")

In [None]:
# Save the [CLS] attention of every head in every layer for one sample, one PDF each.
batch_idx = 0
# same convention as the heatmap: sample 0 highlighted green, others red
c = "green" if batch_idx == 0 else "red"

heads = 8
layers = 4

for l in tqdm(range(layers)):
    # [batch, heads, tokens, tokens]
    n_tokens = cams[l].shape[-1]
    cam = cams[l].reshape(batch_size, -1, n_tokens, n_tokens)

    for h in range(heads):

        fig = plt.figure(figsize=(3*CM,10*CM))

        # renormalize out-of-place: an in-place /= would modify cams[l] through the view
        attention = cam[batch_idx, h, :, :]
        attention = attention / attention.sum(dim=-1, keepdim=True)

        for position, word in enumerate(labels_left):
            plt.text(0, yoffset - position * word_height, word,
                     ha="right", va="center", size="xx-small")
        for position, word in enumerate(labels_right):
            plt.text(width, yoffset - position * word_height, word,
                     ha="left", va="center", size="xx-small")

        # CLS is prepended, get first row, similar to chefer
        for j, weight in enumerate(attention[0]):
            plt.plot([xoffset + pad, xoffset + width - pad],
                     [yoffset, yoffset - word_height * j],
                     color=c, linewidth=2, alpha=min(max(weight.item(), 0.0), 1.0))
        plt.axis('off')

        plt.tight_layout()
        plt.savefig(f"../reports/Graphs/attention_head_{h+1}_layer_{l+1}_color_{c}_{key}.pdf", bbox_inches="tight")
        # figures are saved to disk; close them so layers * heads figures do not pile up
        plt.close(fig)

In [None]:
from matplotlib.pyplot import cm

# Grid of [CLS]-attention plots for one sample: rows = heads, columns = layers.

# layout of the lines inside each subplot
yoffset = 0
xoffset = 0
width = 0.5
example_sep = 3
word_height = 1
pad = 0.025

batch_idx = 0
# same convention as the other cells: sample 0 highlighted green, others red
color = "green" if batch_idx == 0 else "red"

layers = 4
heads = 4

# figsize goes to subplots directly; a separate plt.figure() would only leave an empty figure
fig, axes = plt.subplots(heads, layers, figsize=(28, 12))

for l in range(layers):

    # [batch, heads, tokens, tokens]
    n_tokens = cams[l].shape[-1]
    cam = cams[l].reshape(batch_size, -1, n_tokens, n_tokens)

    for h in range(heads):
        # selected sample, head h; renormalized out-of-place to keep cams intact
        attention = cam[batch_idx, h, :, :]
        attention = attention / attention.sum(dim=-1, keepdim=True)

        ax = axes[h, l]
        # row 0 = attention of the prepended [CLS] token to every token
        for j, weight in enumerate(attention[0]):
            ax.plot([pad, width - pad],  # x axis
                    [0, word_height * j],
                    color=color, linewidth=2, alpha=min(max(weight.item(), 0.0), 1.0))

        ax.set_xlabel(f"head ({l+1},{h+1})", size='xx-small')
        ax.set_xticks([])
        ax.set_yticks([])

plt.savefig(f"../reports/Graphs/attention_heads_layer_all_{key}.pdf", bbox_inches="tight")

In [None]:

# Raw attention maps, their gradients and the final [CLS] relevance scores.
# `stack_np_copy` was never defined; the scores live in stack_np. Tensors are
# moved to the CPU so the pickle can be loaded on a machine without a GPU.
data = {
    "grads": [t.cpu() for t in grads],
    "cams": [t.cpu() for t in cams],
    "final-scores": stack_np.copy(),
}

In [None]:
# target file for the collected attention data (relative to the notebook)
file_path = 'data.pickle'

# serialize the dictionary and write the bytes in one go
Path(file_path).write_bytes(pickle.dumps(data))