# Saliency Maps with HuggingFace and TextualHeatmap



As described in [Andreas Madsen's distill paper](https://distill.pub/2019/memorization-in-rnns/), the saliency map is computed by measuring the gradient magnitude of the output w.r.t. the input.

$$
\mathrm{connectivity}(t, \tilde{t}) = \left|\left| \frac{\partial y^{\tilde{t}}_{k}}{\partial x^t} \right|\right|_2
$$

In [1]:
from pathlib import Path
import torch
import torch.nn.functional as F

from transformers import AutoConfig, AutoTokenizer
from textualheatmap import TextualHeatmap
from datasets import Dataset

from datagen.dataset import SemCorDataSet
from modelling.model import SynsetClassificationModel


In [2]:
import json

def get_key(idx, sense_keys):
    """Return the first sense key registered for the class index ``idx``.

    Args:
        idx: Class index to look up in the ``'sense-key-idx'`` column.
        sense_keys: Table (indexed with ``.loc`` and column names, so
            presumably a pandas DataFrame) holding the columns
            ``'sense-key-idx'`` and ``'sense-key1'``.

    Returns:
        The ``'sense-key1'`` value of the first matching row, or the string
        ``"NONE"`` when no row carries that index.
    """
    row_mask = sense_keys['sense-key-idx'] == idx
    candidates = sense_keys.loc[row_mask]['sense-key1'].values
    # Temporary fallback until the labels of the test set are fixed.
    return candidates[0] if candidates.size else "NONE"

def compute_textual_saliency(model, tokenizer, text, sense_keys):
    """Build the TextualHeatmap data of a gradient saliency map for ``text``.

    Each non-special token is replaced in turn by the tokenizer's mask token.
    The highest-scoring class logit at that position is then differentiated
    w.r.t. a one-hot encoding of the whole input; the L2 norm of the gradient
    at every input position, divided by the largest such norm, becomes the
    heat of that position.

    Args:
        model: Classifier exposing ``mlmodel.embeddings.word_embeddings`` and
            ``num_classes``. It is called with the keyword arguments
            ``inputs_embeds``, ``labels`` and ``sense-labels`` and must return
            an object with a ``logits`` attribute.
            NOTE(review): logits are indexed as ``logits[position, :]``, so
            they are presumably shaped (sequence, classes) -- confirm against
            the model implementation.
        tokenizer: Tokenizer providing ``encode``, ``convert_ids_to_tokens``
            and ``mask_token_id``.
        text: The sentence to analyse.
        sense_keys: Lookup table handed to ``get_key`` to turn predicted
            class indices into sense keys.

    Returns:
        A list of dicts. Token entries carry ``'token'``, ``'meta'`` (the
        sense keys of the top-3 predictions, empty strings for the special
        tokens) and ``'heat'`` (one value per input token). Spacer entries
        ``{'token': ' ', 'format': True}`` are inserted between words.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=True)
    num_tokens = len(token_ids)
    embeddings = model.mlmodel.embeddings.word_embeddings
    vocab_size = embeddings.num_embeddings
    # Create every tensor on the device of the embedding matrix so that the
    # matmul below also works when the model does not live on the CPU.
    device = embeddings.weight.device

    heatmap_data = []

    for masked_token_index in range(num_tokens):
        if masked_token_index == 0:
            # Leading special token: no gradient is computed, it only
            # attends to itself in the visualisation.
            heatmap_data.append({
                'token': '[CLS]',
                'meta': ['', '', ''],
                'heat': [1] + [0] * (num_tokens - 1)
            })
            heatmap_data.append({
                'token': ' ',
                'format': True
            })
        elif masked_token_index == num_tokens - 1:
            # Trailing special token, handled like the leading one.
            heatmap_data.append({
                'token': ' ',
                'format': True
            })
            heatmap_data.append({
                'token': '[SEP]',
                'meta': ['', '', ''],
                'heat': [0] * (num_tokens - 1) + [1]
            })
        else:
            # Get the actual token
            target_token = tokenizer.convert_ids_to_tokens(
                token_ids[masked_token_index], skip_special_tokens=True)

            # A leading 'Ġ' marks the start of a new word; show it as a
            # separate spacer. startswith() is also safe for empty tokens.
            if target_token.startswith('Ġ'):
                target_token = target_token[1:]
                heatmap_data.append({
                    'token': ' ',
                    'format': True
                })

            # Integers are not differentiable, so use a one-hot encoding
            # of the input.
            token_ids_tensor = torch.tensor(
                token_ids, dtype=torch.int64, device=device)
            token_ids_tensor[masked_token_index] = tokenizer.mask_token_id
            token_ids_tensor_one_hot = F.one_hot(
                token_ids_tensor, vocab_size).float()
            token_ids_tensor_one_hot.requires_grad = True

            # To select the output which the importance measure targets,
            # create a masking tensor that is 1 only at the chosen logit.
            output_mask = torch.zeros(
                (1, num_tokens, model.num_classes), device=device)

            # Compute the gradient of the selected logit w.r.t. the input.
            inputs_embeds = torch.matmul(
                token_ids_tensor_one_hot, embeddings.weight)
            # Label tensors filled with -100 are passed only to satisfy the
            # model's signature.
            dummy = torch.full_like(token_ids_tensor, -100)
            output = model(**{
                "inputs_embeds": inputs_embeds.unsqueeze(dim=0),
                "labels": dummy,
                "sense-labels": dummy,
            })
            logits = output.logits

            # Get the top-3 predictions
            (_, top_3_indices) = torch.topk(logits[masked_token_index, :], 3)

            # !!! contrary to the original impl. we do not look at grads for
            # gold labels, but for the top prediction !!!
            output_mask[0, masked_token_index, top_3_indices[0]] = 1
            predict_mask_correct_token = torch.sum(logits * output_mask)
            predict_mask_correct_token.backward()

            # Compute the connectivity: per-position gradient norm scaled
            # to a maximum of 1.
            connectivity_non_normalized = torch.norm(
                token_ids_tensor_one_hot.grad, dim=1)
            max_connectivity = torch.max(connectivity_non_normalized)
            if max_connectivity > 0:
                connectivity_tensor = (
                    connectivity_non_normalized / max_connectivity
                )
            else:
                # All gradients are zero; dividing would yield NaN heat.
                connectivity_tensor = connectivity_non_normalized

            connectivity = connectivity_tensor.tolist()
            assert len(connectivity) == num_tokens
            # Drop the parameter gradients accumulated by backward() so they
            # do not leak into the next iteration.
            model.zero_grad()
            pred_keys = [
                get_key(idx, sense_keys) for idx in top_3_indices.tolist()
            ]

            heatmap_data.append({
                'token': target_token,
                'meta': pred_keys,
                'heat': connectivity
            })

    return heatmap_data

With this implementation, it is now possible to compare different BERT-like models. In theory, any models can be compared as long as they share the same tokenization. In this case, the BERT and DistilBERT models are very similar, which is what we would expect and want.

In [3]:
def construct_model_name(hf_model: str):
    """Map a checkpoint path or name onto the base model identifier.

    The checks run in a fixed order: ``"bert-wwm"`` first, then
    ``"roberta"``, and finally ``"deberta"``.

    Args:
        hf_model: Checkpoint path or model name to inspect.

    Returns:
        The identifier of the matching base model.

    Raises:
        AssertionError: If none of the known substrings occurs in
            ``hf_model``.
    """
    if "bert-wwm" in hf_model:
        return "bert-large-uncased-whole-word-masking"
    if "roberta" in hf_model:
        return "roberta-base"
    # Only DeBERTa checkpoints are expected to reach this point.
    assert "deberta" in hf_model
    return "microsoft/deberta-base"

def load_model(model_path: str, num_classes: int = 2584):
    """Load a synset classification checkpoint and its tokenizer from disk.

    Config, model and tokenizer are all read with ``local_files_only=True``,
    so ``model_path`` must point to a checkpoint directory that already
    exists locally.

    Args:
        model_path: Path of the checkpoint directory. It is also passed to
            ``construct_model_name`` to pick the base model name.
        num_classes: Value forwarded as ``num_classes`` to
            ``SynsetClassificationModel.from_pretrained``. Defaults to 2584,
            the value that used to be hard-coded here.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    base_model_name = construct_model_name(model_path)
    config = AutoConfig.from_pretrained(model_path, local_files_only=True)
    cl_model = SynsetClassificationModel.from_pretrained(
        model_path,
        config=config,
        local_files_only=True,
        model_name=base_model_name,
        num_classes=num_classes,
        freeze_lm=False,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)

    return cl_model, tokenizer

def print_saliency(text, model, tokenizer, model_display_name):
    """Show the saliency heatmap of ``text`` for a single model.

    The sense-key table is read from the pickled dataset at
    ``./dataset/roberta+semeval2013.pickle`` on every call.

    Args:
        text: Sentence to analyse.
        model: Model handed to ``compute_textual_saliency``.
        tokenizer: Tokenizer handed to ``compute_textual_saliency``.
        model_display_name: Title of the single heatmap facet.
    """
    pickle_path = Path("./dataset/roberta+semeval2013.pickle").with_suffix(".pickle")
    sense_keys = SemCorDataSet.unpickle(pickle_path).all_sense_keys

    # The widget is created before the saliency is computed, as before.
    widget = TextualHeatmap(facet_titles=[model_display_name], show_meta=True)
    facet_data = compute_textual_saliency(model, tokenizer, text, sense_keys)
    widget.set_data([facet_data])

### Roberta Probing

In [4]:
# Load the RoBERTa probing classifier and its tokenizer from a local checkpoint.
roberta_model, roberta_tokenizer = load_model("out/checkpoints/roberta-probing+semcor/checkpoint-185900")

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
# Compute and display the saliency heatmap for a short example sentence.
print_saliency("This thing finally works!", roberta_model, roberta_tokenizer, "Roberta-Probing")

<IPython.core.display.Javascript object>