### Research course submission

**Name:** Abhijith Sreesylesh Babu

**Paper:** Self-Attention Attribution: Interpreting Information Interactions Inside Transformer

# Interpreting Information Interactions Inside Transformer

In this paper, the authors study how the interaction between the tokens in a text input are affecting the predictions of a transformer model

### Transformer models

Transformer models are the most used models in the modern language models. They have the ability to understand the relationship between the input tokens using attention mechanism. Each layer of a transformer has multiple attention heads, each head containing a self attention matrix

### Attribution in transformer models

In a transformer model, every attention head is a (n x n) matrix where n is the number of tokens in the batch of input. This matrix shows the attention between all pairs of words in the input batch. The change in output of a model with respect to a change in its attention head gives the attribution of each attention.

The attribution can be found by calculating gradients from backpropogation. Since integrated gradients are known to be a good method of attribution in sequential models, here we use similar methods to find the attribution of the attention head.

By finding the attribution of each attention in attention head, we can understand the word interactions that contributed the most to the output

In [1]:
# importing necessary packages

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import os
import contextlib
from tqdm import tqdm
from IPython.core.display import display, HTML
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
  from IPython.core.display import display, HTML


### Model used in this experiment

This paper is implemented in BERT model which is a transformer model that can be fine tuned to do various tasks. For simplicity, I used a model similar to BERT, namely bart-large-mnli which is already trained to do MNLI (multi-genre natural language inference). The bart model has 12 layers, with each layer having 16 attention heads.

In [19]:
# Loading the model and tokenizer

model_name = "facebook/bart-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, output_attentions=True)

# Setting the model to evaluation mode
with open(os.devnull, 'w') as fnull:
    with contextlib.redirect_stdout(fnull):
        model.eval()

### bart-large-mnli model

The model takes 2 strings as parameters, and says whether the second string entails the first string, contradicts it, or is not related to the first string at all.

Here I am taking an example where the second sentence follows the first sentence.

In [3]:
# Finding the class labels of the model

print(model.config.id2label)

{0: 'contradiction', 1: 'neutral', 2: 'entailment'}


In [4]:
# Defining the input texts

premise = "John is very poor in science"
hypothesis = "John failed in physics exam"

# Since our hypothesis is an entailment, setting the target class accordingly

target_class = 2

In [5]:
# Vectorizing the text to pass them as input to the model

input_vectors = tokenizer(premise, hypothesis, return_tensors="pt")
input_vectors = input_vectors["input_ids"]

In [6]:
# defining the function to compute the integrated gradients

def _compute_integrated_gradients(attention_matrix, embeddings, steps=5):

    # Creating a baseline attention matrix
    baseline = torch.zeros_like(attention_matrix)  # No attention baseline

    # Interpolating between the baseline and the actual attention matrix
    interpolated_attn = [(baseline + (float(i) / steps) * (attention_matrix - baseline)) for i in range(steps + 1)]
    gradients = []
    
    # Computing the gradients for each interpolated attention matrix
    for attn_head in interpolated_attn:
        attn_head_ = attn_head.clone()

        # Setting the gradients to be computed
        embeddings.requires_grad_()
        embeddings.retain_grad()
        embeddings.retain_graph = True

        # Forward pass to find loss
        output = model.model.encoder(inputs_embeds=embeddings, attention_mask = attn_head_, return_dict=True)
        class_logits = model.classification_head(output.last_hidden_state[:, 0, :])
        loss = -1*class_logits[0, target_class]
        
        # Backward pass to compute gradients
        loss.backward(retain_graph=True)
        gradients.append(embeddings.grad.mean(dim=2))

    # Integrating the gradients to get the attribution
    avg_gradients = torch.mean(torch.stack(gradients), dim=0)
    attributions = (attention_matrix - baseline) * avg_gradients
    return attributions 

In [7]:
# defining the function to compute the attention attributions
def compute_attention_attributions():

    # generating embeddings of the inputs for forward and backward pass
    embedding_layer = model.model.shared
    embeddings = embedding_layer(input_vectors)
    embeddings.requires_grad_()
    embeddings.retain_grad()
    
    # Forward pass to get the attentions
    output = model.model.encoder(inputs_embeds=embeddings, return_dict=True)
    attentions = output.attentions
        
    # Compute attribution scores for the attentions
    attributions = []
    for layer_attention in attentions:
        for at_head in tqdm(layer_attention[0]):
            at_head = at_head.mean(dim=0) 
            at_head = at_head.unsqueeze(0) 
            head_attr = _compute_integrated_gradients(at_head, embeddings)
            attributions.append(head_attr)

    return attributions
        

In [8]:
attributions = compute_attention_attributions()

100%|██████████| 16/16 [00:36<00:00,  2.27s/it]
100%|██████████| 16/16 [00:39<00:00,  2.49s/it]
100%|██████████| 16/16 [00:42<00:00,  2.63s/it]
100%|██████████| 16/16 [00:42<00:00,  2.68s/it]
100%|██████████| 16/16 [00:39<00:00,  2.48s/it]
100%|██████████| 16/16 [00:47<00:00,  3.00s/it]
100%|██████████| 16/16 [00:48<00:00,  3.05s/it]
100%|██████████| 16/16 [00:50<00:00,  3.18s/it]
100%|██████████| 16/16 [00:49<00:00,  3.11s/it]
100%|██████████| 16/16 [00:45<00:00,  2.86s/it]
100%|██████████| 16/16 [00:48<00:00,  3.00s/it]
100%|██████████| 16/16 [00:49<00:00,  3.08s/it]


In [9]:
#normalizing the attributions for ease of visualization

def normalize_attributions(attributions):
    attributions = np.array(attributions)
    min_attr = attributions.min()
    max_attr = attributions.max()
    normalized = 255 * (attributions - min_attr) / (max_attr - min_attr + 1e-11)
    return normalized.astype(int)


### Visualizing the attributions

Here I visualize the attribution based on color. The more green the word are, they had more interactions with words that contributed to the output of the model

In [10]:
for idx, layer in enumerate(attributions):
    if idx%16 != 0:
        continue
    # layer = layer.sum(dim=1)
    layer = layer.detach().clone().numpy()
    
    norm_layer = normalize_attributions(layer)
    # print(norm_layer)
    html_string = ""
    for index, val in enumerate(norm_layer[0]):
        str_print = tokenizer.decode(input_vectors.numpy()[0][index])
        if str_print[0] == "<":
            continue
        html_string += f"<span style='color:rgba(0,{val}, 0)'> {str_print} </span>"
    # print(html_string)
    display(HTML(html_string)) 

### Masking the attentionheads based on its importance

If the attribution provided by the attention head is smaller, they are making very less impact on the output of the model. So if we remove the attention heads with less importance, that wont make a big impact on accuracy of the model.

Here the importance of attention head is the maximum value of all the attributions inside the attention head. 

In [11]:
#normalizing the importance for ease of visualization

def normalize_importance(imp):
    imp_arr = np.array(imp)
    min_imp = imp_arr.min()
    max_imp = imp_arr.max()
    normalized = 255 * (imp_arr - min_imp) / (max_imp - min_imp + 1e-11)
    return normalized.astype(int)

In [12]:
# Calculating the importance of attention heads and finding the threshold value
importance_scores = []
for attn_head in attributions:
    max_val = max(attn_head[0])
    importance_scores.append(max_val.item())
importance_scores = normalize_importance(importance_scores)
print(importance_scores)
threshold_val = sorted(importance_scores)[int(len(importance_scores)*0.35)]
print(threshold_val)


[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   1   0   1
   0   0   1   0   1   0   1   1   0   3   1   1   1   1   1   1   1   1
   1   2   2   2   1   3   2   3   3   3   3   4   7   7   5   8   5   9
   6   6   6   7   6   7   7   9   9   8  10  10  13  11  11  11  13  10
  16  16  13  14   9  14  20  12  16  18  17  17  18  19  17  17  15  22
  26  23  21  17  20  20  27  23  28  36  21  43  34  42  33  42  35  44
  40  41  34  34  27  22  25  27  23  16  24  35  28  22  17  34  20  28
  49  39  54  23  45  31  32  36  44  55  44  48  49  52  53  24  36  40
  42  43  21  42  26  27  27  27  34  65  32  34  48  45  49  19  46  93
  53  69  60 172  77 113  82  52  41  55  76  89 156  70  35  35  98  60
 146 216 253 136 137  66 218 199 137 208 118 135]
10


In [13]:
attention_importances = [importance_scores[i * 16:(i + 1) * 16] for i in range(12)]

In [14]:
mask = []
threshold = threshold_val
for layer in attention_importances:
    layer_mask = []
    for head in layer:
        head_mask = head > threshold
        head_mask = int(head_mask)
        layer_mask.append(head_mask)
    mask.append(layer_mask)


for i in mask:
    print(i)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


Pruning the model based on the above mask will lead us to the important attention heads in the model.

In [15]:
import types

def prune_tree(model, mask):
    encoder = model.model.encoder
    layers = encoder.layers
    for idx, layer in enumerate(layers):
        attn_weights = layer.self_attn
        original_forward = attn_weights.forward 

        def new_forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
            outputs = original_forward(hidden_states, attention_mask)
            attention_scores = outputs[0]

            # Apply pruning mask
            attention_scores = attention_scores * mask[idx].unsqueeze(0).unsqueeze(0)

            return (attention_scores,) + outputs[1:]
        layer.self_attn.forward = types.MethodType(new_forward, layer.self_attn)
    return model

In [16]:
pruned_model = prune_tree(model, mask)

### Adversarial Attack

The attention attribution method gives the pair of words that contributes most to the output of the model. If we use the same words in another situation, it can manipulate the model's output.

One example given in the paper is [floods, iowa] and [ice, florida]. These pairs of words contradict each other. If we use them arbitarily in a entailment input, we might be able to manipulate the model

In [17]:
adv_set_1 = ["ice", "florida"]
adv_set_2 = ["floods", "iowa"]

premise_new = f"Titanic {adv_set_1[0]} is a very sad {adv_set_1[1]} movie"
hypothesis_new = f"It {adv_set_2[0]} rains heavily in {adv_set_2[1]} the summer"

In [20]:
# Creating the input vectors for the safe text

input_vectors_new = tokenizer(premise_new, hypothesis_new, return_tensors="pt")
input_vectors_new = input_vectors_new["input_ids"]
embedding_layer_new = model.model.shared
embeddings_new = embedding_layer_new(input_vectors_new)


output = model.model.encoder(inputs_embeds=embeddings_new, return_dict=True)
class_logits = model.classification_head(output.last_hidden_state[:, 0, :])
class_labels = model.config.id2label
for i in range(3):
    print(class_labels[i],class_logits[0][i].item())
    # print(i.item())

contradiction -0.0021725024562329054
neutral -0.03944947198033333
entailment 0.023743169382214546


Here we can see that even though the inputs are neutral, the model gave higher score for contradiction than neutral