# Saliency Maps with HuggingFace and TextualHeatmap

This notebook implements the saliency map as described in [Andreas Madsen's distill paper](https://distill.pub/2019/memorization-in-rnns/). However, it applies the method to BERT models rather than RNN models.

The visualization therefore shows which words/sub-words were important for inferring a masked word/sub-word.

* TextualHeatmap: https://github.com/AndreasMadsen/python-textualheatmap
* HuggingFace Transformers: https://github.com/huggingface/transformers

First, install TensorFlow, TextualHeatmap, and Transformers. This notebook uses the TensorFlow implementations in `transformers`; however, the same method could also be applied to the PyTorch implementations.

In [0]:
%tensorflow_version 2.x
!pip install textualheatmap transformers

As described in [Andreas Madsen's distill paper](https://distill.pub/2019/memorization-in-rnns/), the saliency map is computed by measuring the gradient magnitude of the output w.r.t. the input.

$$
\mathrm{connectivity}(t, \tilde{t}) = \left|\left| \frac{\partial y^{\tilde{t}}_{k}}{\partial x^t} \right|\right|_2
$$
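To make the definition concrete, here is a minimal NumPy sketch that is not part of the original notebook. It uses a hypothetical linear toy model: output position $\tilde{t}$ mixes the embedded inputs with fixed weights `alpha`, then projects them to vocabulary logits. The gradient of $y^{\tilde{t}}_k$ w.r.t. each one-hot input $x^t$ is estimated with finite differences instead of `tf.GradientTape`, and the per-position norms are divided by their maximum, as the notebook does later. `toy_logits`, `connectivity`, `alpha`, `W` and `U` are illustrative names only.

```python
import numpy as np

def toy_logits(x_one_hot, W, U, alpha):
    """Hypothetical linear stand-in for a masked language model.

    x_one_hot: (T, V) one-hot inputs, W: (V, D) word embeddings,
    alpha: (T, T) fixed mixing weights, U: (D, V) output projection.
    Returns (T, V) logits; row t_out holds y^{t_out}.
    """
    embeds = x_one_hot @ W      # the x W product, one row per position
    hidden = alpha @ embeds     # each output position mixes all inputs
    return hidden @ U           # vocabulary logits

def connectivity(x_one_hot, W, U, alpha, t_out, k, eps=1e-6):
    """||d y^{t_out}_k / d x^t||_2 for every input position t, estimated
    with central finite differences and normalized by its maximum."""
    T, V = x_one_hot.shape
    grad = np.zeros((T, V))
    for t in range(T):
        for v in range(V):
            step = np.zeros_like(x_one_hot)
            step[t, v] = eps
            up = toy_logits(x_one_hot + step, W, U, alpha)[t_out, k]
            down = toy_logits(x_one_hot - step, W, U, alpha)[t_out, k]
            grad[t, v] = (up - down) / (2 * eps)
    norms = np.linalg.norm(grad, axis=1)  # one L2 norm per input position
    return norms / norms.max()

rng = np.random.default_rng(0)
V, D = 6, 3
token_ids = [1, 4, 2, 5]
x = np.eye(V)[token_ids]                  # (4, 6) one-hot encoding
W = rng.normal(size=(V, D))
U = rng.normal(size=(D, V))
alpha = np.array([[1.0, 0.0, 0.0, 0.0],
                  [0.5, 1.0, 2.0, 0.0],
                  [0.0, 0.0, 1.0, 0.0],
                  [0.0, 0.0, 0.0, 1.0]])

# Values close to [0.25, 0.5, 1.0, 0.0], i.e. |alpha[1]| / max |alpha[1]|
print(np.round(connectivity(x, W, U, alpha, t_out=1, k=token_ids[1]), 3))
```

Because the toy model is linear, the connectivity reduces to $|\alpha_{\tilde{t},t}| / \max_{t'} |\alpha_{\tilde{t},t'}|$. For BERT, the gradient also depends on the input itself, so no such closed form exists and it has to be computed numerically.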

Implementation-wise, this can be done quite easily with `tf.GradientTape`. However, because the gradient cannot be taken w.r.t. an `int32` tensor, which is how the `token_ids` are encoded, a one-hot encoding is used instead. HuggingFace Transformers supports this via `inputs_embeds`, which is the input word embedding itself. Thus, by computing $\mathbf{x} \mathbf{W}$ inside the `tf.GradientTape` scope, where $\mathbf{x}$ is the one-hot encoding and $\mathbf{W}$ is the word-embedding matrix, the gradient w.r.t. $\mathbf{x}$ can be computed.
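The trick rests on the fact that multiplying a one-hot matrix by the embedding matrix gives the same result as an integer lookup into it. A small NumPy check of that identity, with a random matrix standing in for $\mathbf{W}$ (the sizes and values are arbitrary assumptions):

```python
import numpy as np

rng = np.random.default_rng(1)
vocab_size, dim = 10, 4
W = rng.normal(size=(vocab_size, dim))  # stand-in for the word-embedding matrix
token_ids = np.array([3, 7, 1])

# Integer lookup: what the model does internally with token ids.
embeds_lookup = W[token_ids]

# One-hot route: x is a float tensor, so a gradient w.r.t. x exists.
x = np.eye(vocab_size)[token_ids]       # (3, vocab_size)
embeds_matmul = x @ W                   # (3, dim), same values as the lookup

assert np.allclose(embeds_lookup, embeds_matmul)

# For a downstream scalar s = sum(embeds * c), ds/dx = c W^T: one entry per
# (position, vocabulary item).
c = rng.normal(size=(3, dim))
grad_x = c @ W.T                        # (3, vocab_size)
```

The gradient w.r.t. $\mathbf{x}$ has one entry per vocabulary item at every position, which is why `compute_textual_saliency` reduces it with a norm over the vocabulary axis (`axis=2`, since the notebook's tensors carry an extra batch dimension).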

In [0]:
import numpy as np
import tensorflow as tf

def compute_textual_saliency(model, embedding_matrix, tokenizer, text):
    token_ids = tokenizer.encode(text, add_special_tokens=True)
    vocab_size = embedding_matrix.get_shape()[0]

    heatmap_data = []

    for masked_token_index in range(len(token_ids)):
        # print(f'processing token {masked_token_index + 1} / {len(token_ids)}')

        if masked_token_index == 0:
            heatmap_data.append({
                'token': '[CLS]',
                'meta': ['', '', ''],
                'heat': [1] + [0] * (len(token_ids) - 1)
            })
        elif masked_token_index == len(token_ids) - 1:
            heatmap_data.append({
                'token': ' ',
                'format': True
            })
            heatmap_data.append({
                'token': '[SEP]',
                'meta': ['', '', ''],
                'heat': [0] * (len(token_ids) - 1) + [1]
            })
        else:
            # Get the actual token
            target_token = tokenizer.convert_ids_to_tokens(
                token_ids[masked_token_index])

            if target_token[0:2] == '##':
                target_token = target_token[2:]
            else:
                heatmap_data.append({
                    'token': ' ',
                    'format': True
                })

            # Integers are not differentiable, so use a one-hot encoding
            # of the input
            token_ids_tensor = tf.constant([
                token_ids[0:masked_token_index] +
                [tokenizer.mask_token_id] +
                token_ids[masked_token_index + 1:]
            ], dtype='int32')
            token_ids_tensor_one_hot = tf.one_hot(token_ids_tensor, vocab_size)

            # To select the correct output, which is what the importance
            # measure targets, create a masking tensor. tf.gather_nd could also
            # be used, but this is easier.
            output_mask = np.zeros((1, len(token_ids), vocab_size))
            output_mask[0, masked_token_index, token_ids[masked_token_index]] = 1
            output_mask_tensor = tf.constant(output_mask, dtype='float32')

            # Compute gradient of the logits of the correct target, w.r.t. the
            # input
            with tf.GradientTape(watch_accessed_variables=False) as tape:
                tape.watch(token_ids_tensor_one_hot)
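                inputs_embeds = tf.matmul(token_ids_tensor_one_hot, embedding_matrix)
                # Index 0 holds the MLM logits, whether the model returns a
                # tuple (older transformers) or a ModelOutput (newer versions)
                predict = model({"inputs_embeds": inputs_embeds})[0]
                predict_mask_correct_token = tf.reduce_sum(predict * output_mask_tensor)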

            # Get the top-3 predictions
            (_, top_3_indices) = tf.math.top_k(predict[0, masked_token_index, :], 3)
            top_3_predicted_tokens = tokenizer.convert_ids_to_tokens(top_3_indices)

            # compute the connectivity
            connectivity_non_normalized = tf.norm(
                tape.gradient(predict_mask_correct_token, token_ids_tensor_one_hot),
                axis=2)
            connectivity_tensor = (
                connectivity_non_normalized /
                tf.reduce_max(connectivity_non_normalized)
            )
            connectivity = connectivity_tensor[0].numpy().tolist()

            heatmap_data.append({
                'token': target_token,
                'meta': top_3_predicted_tokens,
                'heat': connectivity
            })

    return heatmap_data
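The comment about `tf.gather_nd` can be checked in isolation. The NumPy sketch below, with random numbers standing in for the `predict` logits, shows that the mask-and-sum selection picks out exactly one logit, the same value that direct indexing gives; in TensorFlow, the direct route would be something like `tf.gather_nd(predict, [[0, masked_token_index, target_id]])`. It also mirrors the top-3 selection with `np.argsort` in place of `tf.math.top_k`. The sizes and token ids are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)
seq_len, vocab_size = 5, 8
predict = rng.normal(size=(1, seq_len, vocab_size))  # stand-in for model logits
token_ids = [0, 3, 6, 2, 7]
masked_token_index = 2
target_id = token_ids[masked_token_index]

# Mask-and-sum selection, as in compute_textual_saliency.
output_mask = np.zeros((1, seq_len, vocab_size))
output_mask[0, masked_token_index, target_id] = 1
selected_by_mask = np.sum(predict * output_mask)

# Direct indexing, the NumPy counterpart of tf.gather_nd.
selected_by_index = predict[0, masked_token_index, target_id]

# Top-3 predicted ids at the masked position, highest logit first.
top_3_indices = np.argsort(predict[0, masked_token_index])[::-1][:3]
```

The mask route is convenient because the gradient of the masked sum w.r.t. `predict` is the mask itself, so only the target logit at the masked position contributes to the saliency gradient.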

With this implementation, it is now possible to compare different BERT-like models. In principle any models can be compared, as long as they use the same tokenization. In this case, the saliency maps for BERT and DistilBERT are very similar, which is what we would expect and want, since DistilBERT is a distilled version of BERT.
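Before comparing two models, one could verify the tokenization condition directly. `same_tokenization` below is a hypothetical helper, not part of Transformers or TextualHeatmap; it only uses the `encode` and `convert_ids_to_tokens` tokenizer methods already used in this notebook, and it compares token strings rather than ids, so two vocabularies with different id orderings still count as equal.

```python
def same_tokenization(tokenizer_a, tokenizer_b, text):
    """Hypothetical helper: True if both tokenizers split `text` into the
    same sequence of tokens (special tokens included), so the heatmap
    facets line up token by token."""
    tokens_a = tokenizer_a.convert_ids_to_tokens(
        tokenizer_a.encode(text, add_special_tokens=True))
    tokens_b = tokenizer_b.convert_ids_to_tokens(
        tokenizer_b.encode(text, add_special_tokens=True))
    return tokens_a == tokens_b
```

For example, `same_tokenization(bert_tokenizer, dbert_tokenizer, text)` with the tokenizers loaded in the next cell would be expected to return `True`, since `distilbert-base-uncased` reuses the uncased BERT vocabulary.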

In [0]:
text = ("context the formal study of grammar is an important part of education"
        " from a young age through advanced learning though the rules taught"
        " in schools are not a grammar in the sense most linguists use")

from transformers import TFDistilBertForMaskedLM, DistilBertTokenizer
dbert_model = TFDistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
dbert_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
dbert_embmat = dbert_model.distilbert.embeddings.word_embeddings

from transformers import TFBertForMaskedLM, BertTokenizer
bert_model = TFBertForMaskedLM.from_pretrained('bert-base-uncased')
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_embmat = bert_model.bert.embeddings.word_embeddings

from textualheatmap import TextualHeatmap
heatmap = TextualHeatmap(facet_titles = ['BERT', 'Distil BERT'], show_meta=True)
heatmap.set_data([
    compute_textual_saliency(bert_model, bert_embmat, bert_tokenizer, text),
    compute_textual_saliency(dbert_model, dbert_embmat, dbert_tokenizer, text)
])


I hope you found this useful. If so, please consider sharing or retweeting it.


In [5]:
#@title
%%html
<blockquote class="twitter-tweet"><p lang="en" dir="ltr">Using <a href="https://twitter.com/huggingface?ref_src=twsrc%5Etfw">@huggingface</a> Transformer with TextualHeatmap to make an interactive saliency map in Google Colab. Colab link: <a href="https://t.co/76uZpyYdjE">https://t.co/76uZpyYdjE</a> <a href="https://t.co/15W3MhfxZ3">pic.twitter.com/15W3MhfxZ3</a></p>&mdash; Andreas Madsen (@andreas_madsen) <a href="https://twitter.com/andreas_madsen/status/1243159481372082182?ref_src=twsrc%5Etfw">March 26, 2020</a></blockquote> <script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>