# Saliency Maps with HuggingFace and TextualHeatmap

This notebook implements the saliency map described in [Andreas Madsen's Distill paper](https://distill.pub/2019/memorization-in-rnns/). However, it applies the method to BERT-style models rather than RNN models.

The visualization therefore shows which words/sub-words were important for inferring a masked word/sub-word.

* TextualHeatmap: https://github.com/AndreasMadsen/python-textualheatmap
* HuggingFace Transformers: https://github.com/huggingface/transformers

First install PyTorch, TextualHeatmap, and Transformers. The code in this notebook uses the PyTorch implementations in `transformers`; the same approach can also be applied to the TensorFlow implementations.

As described in [Andreas Madsen's distill paper](https://distill.pub/2019/memorization-in-rnns/), the saliency map is computed by measuring the gradient magnitude of the output w.r.t. the input.

$$
\mathrm{connectivity}(t, \tilde{t}) = \left|\left| \frac{\partial y^{\tilde{t}}_{k}}{\partial x^t} \right|\right|_2
$$
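
To make the definition concrete, the following sketch estimates the connectivity numerically for a toy function. Everything in it (`connectivity`, `toy_output`, the weights `a` and `w`) is a hypothetical stand-in and not part of the notebook's model; it only illustrates "one L2 gradient norm per input position".

```python
import math

def connectivity(f, xs, eps=1e-5):
    """Central-difference estimate of ||d f / d x^t||_2 for every position t."""
    scores = []
    for t in range(len(xs)):
        grad_sq = 0.0
        for d in range(len(xs[t])):
            plus = [list(x) for x in xs]
            minus = [list(x) for x in xs]
            plus[t][d] += eps
            minus[t][d] -= eps
            partial = (f(plus) - f(minus)) / (2 * eps)
            grad_sq += partial ** 2
        scores.append(math.sqrt(grad_sq))
    return scores

# Hypothetical toy "model": position t contributes a[t] * (w . x^t) to the output.
a = [0.5, -2.0, 1.0]
w = [3.0, 4.0]  # ||w||_2 = 5

def toy_output(xs):
    return sum(a_t * sum(w_d * x_d for w_d, x_d in zip(w, x))
               for a_t, x in zip(a, xs))

xs = [[0.1, 0.2], [0.3, -0.4], [1.0, 1.0]]
scores = connectivity(toy_output, xs)  # approximately |a[t]| * ||w||_2 per position
```

For this linear toy the gradient w.r.t. $x^t$ is $a_t \mathbf{w}$, so the scores come out as roughly 2.5, 10.0 and 5.0: the second position dominates because its weight has the largest magnitude.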

Implementation-wise, this can be done quite easily with automatic differentiation (`tf.GradientTape` in TensorFlow, `backward()` in PyTorch). However, because the gradient cannot be taken w.r.t. an integer type, which is how the `token_ids` are encoded, a one-hot encoding should be used instead. HuggingFace Transformers supports this via `inputs_embeds`, which is the actual input word embedding. Thus, by computing $\mathbf{x} \mathbf{W}$ inside the differentiated computation, where $\mathbf{x}$ is the one-hot encoding and $\mathbf{W}$ the embedding matrix, the gradient w.r.t. $\mathbf{x}$ can be computed.
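
The one-hot trick can be sketched in isolation with PyTorch. The snippet below is a minimal illustration under stated assumptions: `embedding` and `head` are hypothetical stand-ins for a real transformer, and mean pooling replaces attention as the way positions interact. By the chain rule, the gradient w.r.t. the one-hot input equals the gradient w.r.t. the embeddings multiplied by $\mathbf{W}^\top$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, hidden_size, seq_len = 10, 4, 5

# Hypothetical stand-ins for a real model: an embedding table and a tiny head.
embedding = torch.nn.Embedding(vocab_size, hidden_size)
head = torch.nn.Linear(hidden_size, 3)

token_ids = torch.tensor([1, 7, 3, 3, 9])

# Integer ids carry no gradient, so represent them as a float one-hot matrix.
one_hot = F.one_hot(token_ids, vocab_size).float()
one_hot.requires_grad_(True)

# x W reproduces the embedding lookup while staying differentiable w.r.t. x.
inputs_embeds = one_hot @ embedding.weight  # (seq_len, hidden_size)
inputs_embeds.retain_grad()                 # keep the intermediate gradient

# Mean pooling mixes positions so that every input influences the output.
logits = head(torch.tanh(inputs_embeds).mean(dim=0))  # (3,)
logits[0].backward()

# L2 norm over the vocabulary axis gives one score per input position.
connectivity = one_hot.grad.norm(dim=1)
connectivity = connectivity / connectivity.max()
```

With a HuggingFace model, `inputs_embeds` (with a batch dimension added) would be passed to the model instead of `input_ids`, and the scalar being differentiated would be the logit of interest at the masked position.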

In [138]:
import torch
import torch.nn.functional as F

def compute_textual_saliency(model, embedding_matrix, tokenizer, text):
    token_ids = tokenizer.encode(text, add_special_tokens=True)
    vocab_size = embedding_matrix.num_embeddings

    print(token_ids)

    heatmap_data = []

    for masked_token_index in range(len(token_ids)):
        # print(f'processing token {masked_token_index + 1} / {len(token_ids)}')

        if masked_token_index == 0:
            heatmap_data.append({
                'token': '[CLS]',
                'meta': ['', '', ''],
                'heat': [1] + [0] * (len(token_ids) - 1)
            })
        elif masked_token_index == len(token_ids) - 1:
            heatmap_data.append({
                'token': ' ',
                'format': True
            })
            heatmap_data.append({
                'token': '[SEP]',
                'meta': ['', '', ''],
                'heat': [0] * (len(token_ids) - 1) + [1]
            })
        else:
            # Get the actual token
            target_token = tokenizer.convert_ids_to_tokens(
                token_ids[masked_token_index])

            if target_token[0:2] == '##':
                target_token = target_token[2:]
            else:
                heatmap_data.append({
                    'token': ' ',
                    'format': True
                })

            # integers are not differentiable, so use a one-hot encoding
            # of the input
            token_ids_tensor = torch.tensor(token_ids, dtype=torch.int64)
            token_ids_tensor[masked_token_index] = tokenizer.mask_token_id
            token_ids_tensor_one_hot = F.one_hot(token_ids_tensor, vocab_size).float()
            token_ids_tensor_one_hot.requires_grad = True

            # To select the correct output, which is what the importance
            # measure targets, create a masking tensor. Direct indexing could
            # also be used, but this is easier.
            output_mask = torch.zeros((1, len(token_ids), model.num_classes))
            # todo match here with correct label / only mark correct one with 1?
            output_mask[0, masked_token_index, :] = 1

            # Compute gradient of the logits of the correct target, w.r.t. the
            # input

            inputs_embeds = torch.matmul(token_ids_tensor_one_hot, embedding_matrix.weight)
            dummy = torch.full_like(token_ids_tensor, -100)
            output = model(**{"inputs_embeds": inputs_embeds.unsqueeze(dim=0), "labels": dummy, "sense-labels": dummy})
            logits = output.logits

            # todo fixme ? to be matched with original label
            predict_mask_correct_token = torch.sum(logits * output_mask)

            print(predict_mask_correct_token)

            # Get the top-3 predictions
            (_, top_3_indices) = torch.topk(logits[masked_token_index, :], 3)
            # top_3_predicted_tokens = tokenizer.convert_ids_to_tokens(top_3_indices)

            predict_mask_correct_token.backward()

            # compute the connectivity
            connectivity_non_normalized = torch.norm(token_ids_tensor_one_hot.grad, dim=1)
            connectivity_tensor = (
                connectivity_non_normalized /
                torch.max(connectivity_non_normalized)
            )

            print(connectivity_tensor)
            # connectivity_tensor is 1-D: one normalized score per input position
            connectivity = connectivity_tensor.tolist()

            # todo - zero grads!
            # token_ids_tensor_one_hot.zero_()
            # predict_mask_correct_token.zero_()
            model.zero_grad()


            heatmap_data.append({
                'token': target_token,
                'meta': top_3_indices.tolist(), # todo replace with sense-key names
                'heat': connectivity
            })

    return heatmap_data
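
The list returned above mixes two kinds of entries: real tokens with `meta` and `heat`, and spacing entries marked with `'format': True`. As a sanity check before passing the data to `TextualHeatmap`, a small validator can help. The helper below is hypothetical (it is not part of the textualheatmap library) and assumes, based on the function above, that every real token needs a `heat` list with one value per real token.

```python
def check_heatmap_data(heatmap_data):
    """Validate one facet of heatmap data; return the number of real tokens."""
    # Spacing entries ('format': True) carry no heat and are skipped.
    tokens = [e for e in heatmap_data if not e.get('format', False)]
    for entry in tokens:
        heat = entry.get('heat')
        if not isinstance(heat, list) or len(heat) != len(tokens):
            raise ValueError(
                f"token {entry['token']!r}: expected a heat list of "
                f"length {len(tokens)}, got {heat!r}"
            )
    return len(tokens)

# Hand-written example data in the same shape as the function's output.
example = [
    {'token': '[CLS]', 'meta': ['', '', ''], 'heat': [1, 0, 0]},
    {'token': ' ', 'format': True},
    {'token': 'grammar', 'meta': ['a', 'b', 'c'], 'heat': [0.2, 1.0, 0.4]},
    {'token': ' ', 'format': True},
    {'token': '[SEP]', 'meta': ['', '', ''], 'heat': [0, 0, 1]},
]
num_tokens = check_heatmap_data(example)
```

A check like this flags rows whose `heat` is a single number or has the wrong length before they reach the visualization.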

With this implementation, it is now possible to compare different BERT-like models. In theory, any models can be compared as long as their tokenization is the same. In this case the BERT and DistilBERT models are very similar, which is what we would expect and want.

In [139]:
def construct_model_name(hf_model: str):
    if "bert-wwm" in hf_model:
        model_name = "bert-large-uncased-whole-word-masking"
    elif "roberta" in hf_model:
        model_name = "roberta-base"
    else:
        assert "deberta" in hf_model
        model_name = "microsoft/deberta-base"
    return model_name

In [140]:
from transformers import AutoConfig, AutoTokenizer
from textualheatmap import TextualHeatmap
import datagen
from datagen.dataset import SemCorDataSet
import datasets
from modelling.model import SynsetClassificationModel

text = ("context the formal study of grammar is an important part of education"
        " from a young age through advanced learning though the rules taught"
        " in schools are not a grammar in the sense most linguists use")

model_name = "out/checkpoints/roberta-probing+semcor/checkpoint-185900"
base_model_name = construct_model_name('roberta')
config = AutoConfig.from_pretrained(model_name, local_files_only=True)
cl_model = SynsetClassificationModel.from_pretrained(
    model_name,
    config=config,
    local_files_only=True,
    model_name=base_model_name,
    num_classes=2584,
    freeze_lm=False,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
embeddings = cl_model.mlmodel.embeddings.word_embeddings

heatmap = TextualHeatmap(facet_titles = ['Roberta'], show_meta=True)
heatmap.set_data([
    compute_textual_saliency(cl_model, embeddings, tokenizer, text),
])

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).



[0, 46796, 5, 4828, 892, 9, 33055, 16, 41, 505, 233, 9, 1265, 31, 10, 664, 1046, 149, 3319, 2239, 600, 5, 1492, 5850, 11, 1304, 32, 45, 10, 33055, 11, 5, 1472, 144, 38954, 1952, 304, 2]
tensor(-90103.6953, grad_fn=<SumBackward0>)
tensor([0.2899, 0.6737, 0.4252, 0.6075, 0.4417, 0.2493, 0.4475, 0.3164, 0.1965,
        0.4759, 0.2452, 0.1722, 0.3359, 0.2107, 0.1845, 0.2607, 0.2184, 0.2045,
        0.2501, 0.2939, 1.0000, 0.3461, 0.2916, 0.2617, 0.1471, 0.2924, 0.1913,
        0.2512, 0.1861, 0.4380, 0.1778, 0.2066, 0.3312, 0.2248, 0.4726, 0.2581,
        0.2823, 0.3032])
tensor(-72087.4453, grad_fn=<SumBackward0>)
tensor([0.1510, 1.0000, 0.2902, 0.4173, 0.2937, 0.0970, 0.1972, 0.1018, 0.0912,
        0.1221, 0.0967, 0.0738, 0.1285, 0.0718, 0.0566, 0.0777, 0.0896, 0.0787,
        0.0761, 0.0816, 0.1282, 0.0699, 0.0890, 0.0698, 0.0478, 0.0930, 0.0586,
        0.0596, 0.0682, 0.0964, 0.0619, 0.0659, 0.0819, 0.0666, 0.1166, 0.0755,
        0.2576, 0.4283])
tensor(-96877.5781, grad_fn=<SumBack

I hope you found this useful. If so, please consider sharing or retweeting it.