# Interpreting BERT Models (Part 1)

In this notebook we demonstrate how to interpret BERT models using the `Captum` library. In this case study we focus on a Question Answering model fine-tuned on the SQuAD dataset, using the transformers library from Hugging Face: https://huggingface.co/transformers/

We show how to use interpretation hooks to examine and better understand embeddings, sub-embeddings, BERT, and attention layers.

Note: Before running this tutorial, please install the `seaborn`, `pandas`, `matplotlib` and `transformers` Python packages (`transformers` is from Hugging Face; tested with version `4.3.0.dev0`).

In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from transformers import BertTokenizer, BertForQuestionAnswering, BertConfig

from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


The first step is to fine-tune a BERT model on the SQuAD dataset. This can be easily accomplished by following the steps described on Hugging Face's official web site: https://github.com/huggingface/transformers#run_squadpy-fine-tuning-on-squad-for-question-answering

Note that the fine-tuning described there is done on a `bert-base-uncased` pre-trained model. The cell below instead loads `bert-large-uncased-whole-word-masking-finetuned-squad`, a checkpoint that has already been fine-tuned on SQuAD.

Once a fine-tuned model is available, we can load the tokenizer and the model using the commands below.

In [3]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [4]:

# model_path can be replaced with the real path of a locally saved model (<PATH-TO-SAVED-MODEL>)
model_path = 'bert-large-uncased-whole-word-masking-finetuned-squad'

# load model
model = BertForQuestionAnswering.from_pretrained(model_path)
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = BertTokenizer.from_pretrained(model_path)

A helper function to perform forward pass of the model and make predictions.

In [5]:
def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    output = model(inputs, token_type_ids=token_type_ids,
                 position_ids=position_ids, attention_mask=attention_mask, )
    return output.start_logits, output.end_logits

Next, we define a custom forward function that lets us access the start and end positions of our prediction through the `position` input argument (`0` selects the start logits, `1` the end logits).

In [6]:
def squad_pos_forward_func(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
    pred = predict(inputs,
                   token_type_ids=token_type_ids,
                   position_ids=position_ids,
                   attention_mask=attention_mask)
    pred = pred[position]
    return pred.max(1).values
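To make the role of `position` concrete, here is a toy sketch that mimics the selection logic with plain Python lists instead of tensors. The logit values and the helper name `select_max_logit` are made up for illustration; in the notebook the same step is done by `pred[position]` followed by `pred.max(1).values`.

```python
# Toy stand-in for the (start_logits, end_logits) pair returned by predict():
# one row per example in the batch, one score per token. Values are made up.
start_logits = [[0.1, 2.5, -1.0, 0.3]]
end_logits = [[-0.4, 0.2, 3.1, 0.0]]


def select_max_logit(pred, position=0):
    """Pick the start (0) or end (1) logits and return the max score per example."""
    chosen = pred[position]
    return [max(row) for row in chosen]


start_score = select_max_logit((start_logits, end_logits), position=0)
end_score = select_max_logit((start_logits, end_logits), position=1)
print(start_score, end_score)
```

The attribution target is therefore a single scalar per example: the highest start (or end) logit over the sequence.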


Let's compute attributions with respect to the `BertEmbeddings` layer.

To do so, we need to define baselines / references, numericalize both the baselines and the inputs. We will define helper functions to achieve that.

The cell below defines numericalized special tokens that will be later used for constructing inputs and corresponding baselines/references.

In [7]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

Below we define a set of helper functions for constructing references / baselines for word tokens, token types and position ids. We also provide separate helper functions for constructing attention masks and BERT embeddings, both for the input and for the reference.

In [8]:
def construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id):
    question_ids = tokenizer.encode(question, add_special_tokens=False)
    text_ids = tokenizer.encode(text, add_special_tokens=False)

    # construct input token ids
    input_ids = [cls_token_id] + question_ids + [sep_token_id] + text_ids + [sep_token_id]

    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(question_ids) + [sep_token_id] + \
        [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(question_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

def construct_whole_bert_embeddings(input_ids, ref_input_ids,
                                    token_type_ids=None, ref_token_type_ids=None,
                                    position_ids=None, ref_position_ids=None):
    input_embeddings = model.bert.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    # the reference embeddings must be built from the reference token type and position ids
    ref_input_embeddings = model.bert.embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)

    return input_embeddings, ref_input_embeddings
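The following toy sketch shows what the input / reference pair looks like for a tiny example, using plain lists so that it runs without the tokenizer or the model. The word ids and the `_toy` names are hypothetical; the special token ids are assumed to be those of the BERT uncased vocabulary (`[PAD]`=0, `[CLS]`=101, `[SEP]`=102).

```python
# Assumed special token ids of the BERT uncased vocabulary.
PAD_ID, CLS_ID, SEP_ID = 0, 101, 102


def build_pair(question_ids, text_ids):
    """Mirror construct_input_ref_pair: the reference keeps [CLS]/[SEP] and pads the rest."""
    input_ids = [CLS_ID] + question_ids + [SEP_ID] + text_ids + [SEP_ID]
    ref_ids = ([CLS_ID] + [PAD_ID] * len(question_ids) + [SEP_ID]
               + [PAD_ID] * len(text_ids) + [SEP_ID])
    return input_ids, ref_ids


def token_types(seq_len, sep_ind):
    """Mirror construct_input_ref_token_type_pair for the input sequence."""
    return [0 if i <= sep_ind else 1 for i in range(seq_len)]


# Hypothetical word ids for a 2-token question and a 3-token text.
question_ids_toy = [2040, 2081]
text_ids_toy = [3000, 3001, 3002]

input_ids_toy, ref_ids_toy = build_pair(question_ids_toy, text_ids_toy)
types_toy = token_types(len(input_ids_toy), sep_ind=len(question_ids_toy))

print(input_ids_toy)
print(ref_ids_toy)
print(types_toy)
```

Both sequences have the same length, which Integrated Gradients needs in order to interpolate between them. Note that with `sep_ind = len(question_ids)`, as passed in this notebook, the first `[SEP]` token receives token type 1.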


Let's define the `question - text` pair that we'd like to use as an input for our BERT model, and interpret what the model was focusing on when predicting an answer to the question from the given input text.

## original text: 
At the age of just 13, Japan's Momiji Nishiya made history on Monday by winning the first-ever Olympic gold medal in women's street skateboarding at the Games in Tokyo.

Nishiya topped a youthful podium with Rayssa Leal of Brazil, also 13, taking the silver medal and Japan's Funa Nakayama, 16, winning bronze. With an average age of 14 years and 191 days it is the youngest individual podium in the history of the Olympic Games.

In [9]:
question, text = "who made the history", "At the age of just 13, Japan's Momiji Nishiya made history on Monday by winning the first-ever Olympic gold medal in women's street skateboarding at the Games in Tokyo.  Nishiya topped a youthful podium with Rayssa Leal of Brazil, also 13, taking the silver medal and Japan's Funa Nakayama, 16, winning bronze. With an average age of 14 years and 191 days it is the youngest individual podium in the history of the Olympic Games."

Let's numericalize the question and the input text, and generate the corresponding baselines / references for all three sub-embedding types (word, token type and position embeddings) using the helper functions defined above.

In [10]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

Also, let's define the ground truth for prediction's start and end positions.

In [11]:
ground_truth = 'Momiji Nishiya'

ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False)
ground_truth_end_ind = indices.index(ground_truth_tokens[-1])
ground_truth_start_ind = ground_truth_end_ind - len(ground_truth_tokens) + 1
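The lookup above relies on `indices.index(...)` returning the first occurrence of the last answer token. Since the pieces `ni ##shi ##ya` appear twice in this text, a more defensive variant is to search for the whole answer as a contiguous sub-sequence. The sketch below is one possible way to do that; `find_sublist` is a hypothetical helper, not part of the notebook, and the token list is shortened for illustration.

```python
def find_sublist(sequence, sub):
    """Return (start, end) inclusive indices of the first occurrence of `sub`, or None."""
    if not sub:
        return None
    n = len(sub)
    for start in range(len(sequence) - n + 1):
        if sequence[start:start + n] == sub:
            return start, start + n - 1
    return None


# Shortened, illustrative token sequence in which part of the answer repeats.
tokens = ['[CLS]', 'who', '[SEP]', 'mom', '##iji', 'ni', '##shi', '##ya',
          'made', 'history', '.', 'ni', '##shi', '##ya', '[SEP]']
answer = ['mom', '##iji', 'ni', '##shi', '##ya']

span = find_sublist(tokens, answer)
print(span)
```

The same function works on token ids, so it could be applied to `indices` and `ground_truth_tokens` directly.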

Now let's make predictions using input, token type, position id and a default attention mask.

In [12]:
start_scores, end_scores = predict(input_ids, \
                                   token_type_ids=token_type_ids, \
                                   position_ids=position_ids, \
                                   attention_mask=attention_mask)


print('Question: ', question)
print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))

Question:  who made the history
Predicted Answer:  mom ##iji ni ##shi ##ya
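The answer above is extracted by taking the argmax of the start and end scores independently, which works here but yields an empty slice whenever the best end index falls before the best start index. A common post-processing alternative, sketched below with made-up scores, is to search for the best-scoring pair with `start <= end`. The helper `best_span` and its `max_len` parameter are assumptions for illustration and are not used elsewhere in this notebook.

```python
def best_span(start_scores, end_scores, max_len=30):
    """Return the (start, end) pair maximizing start+end score with start <= end."""
    best = None
    best_score = float('-inf')
    for s, s_score in enumerate(start_scores):
        # only consider end positions at or after the start, within max_len tokens
        for e in range(s, min(s + max_len, len(end_scores))):
            score = s_score + end_scores[e]
            if score > best_score:
                best_score = score
                best = (s, e)
    return best


# Made-up scores where the independent argmaxes (start=2, end=0) are inconsistent.
toy_start = [0.1, 0.2, 5.0, 0.3]
toy_end = [4.0, 0.1, 0.2, 3.0]

print(best_span(toy_start, toy_end))
```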


There are two different ways of computing the attributions for embedding layers. One option is to use `LayerIntegratedGradients` and compute the attributions with respect to the whole `BertEmbeddings` layer. The second option is to use `LayerIntegratedGradients` for each of `word_embeddings`, `token_type_embeddings` and `position_embeddings` and compute the attributions w.r.t. each embedding vector. The cells below follow the first option.


In [13]:
#torch.cuda.empty_cache()
lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)

In [14]:

attributions_start, delta_start = lig.attribute(inputs=input_ids,
                                  baselines=ref_input_ids,
                                  additional_forward_args=(token_type_ids, position_ids, attention_mask, 0),
                                  internal_batch_size=4,
                                  return_convergence_delta=True)


In [15]:

attributions_end, delta_end = lig.attribute(inputs=input_ids,
                                  baselines=ref_input_ids,
                                  additional_forward_args=(token_type_ids, position_ids, attention_mask, 1),
                                  internal_batch_size=4,
                                  return_convergence_delta=True)
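To give some intuition for what these calls compute, here is a minimal pure-Python sketch of Integrated Gradients on a toy function with an analytic gradient. It uses a midpoint Riemann sum along the straight path from the baseline to the input, which is an assumption made for simplicity and not necessarily the integration rule Captum uses; the function `f` and all values are made up. The convergence delta measures the gap between the sum of the attributions and `f(input) - f(baseline)`.

```python
def f(x):
    """Toy model: a scalar function of a 2-dimensional input."""
    return x[0] ** 2 + 3.0 * x[1]


def grad_f(x):
    """Analytic gradient of f."""
    return [2.0 * x[0], 3.0]


def integrated_gradients(x, baseline, n_steps=50):
    """Approximate IG with a midpoint Riemann sum over the straight-line path."""
    dim = len(x)
    avg_grad = [0.0] * dim
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        g = grad_f(point)
        for i in range(dim):
            avg_grad[i] += g[i] / n_steps
    # scale the averaged gradient by the input-baseline difference
    return [(xi - b) * gi for xi, b, gi in zip(x, baseline, avg_grad)]


x, baseline = [2.0, 1.0], [0.0, 0.0]
attributions = integrated_gradients(x, baseline)
delta = sum(attributions) - (f(x) - f(baseline))
print(attributions, delta)
```

A delta close to zero indicates that the numerical approximation of the path integral is accurate; a large delta usually suggests increasing the number of steps.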

A helper function to summarize attributions for each word token in the sequence.

In [16]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions
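As a small worked example of this summarization, the sketch below applies the same two steps (sum over the embedding dimension, then divide by the L2 norm) to a made-up attribution matrix stored as nested lists. `summarize_toy` is a hypothetical stand-in for the tensor version above.

```python
import math


def summarize_toy(attributions):
    """attributions: one list per token, one value per embedding dimension."""
    per_token = [sum(dims) for dims in attributions]   # sum over embedding dim
    norm = math.sqrt(sum(v * v for v in per_token))    # L2 norm over tokens
    return [v / norm for v in per_token]


# Three tokens with a 2-dimensional "embedding" each (made-up numbers).
toy_attr = [[1.0, 2.0], [-2.0, -2.0], [0.0, 0.0]]
scores = summarize_toy(toy_attr)
print(scores)
```

After normalization the per-token scores form a unit-length vector, so their signs and relative magnitudes are what carry meaning.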

In [17]:
attributions_start_sum = summarize_attributions(attributions_start)
attributions_end_sum = summarize_attributions(attributions_end)

In [18]:
# store the start- and end-position records for visualization purposes
start_position_vis = viz.VisualizationDataRecord(
                        attributions_start_sum,
                        torch.max(torch.softmax(start_scores[0], dim=0)),
                        torch.argmax(start_scores),
                        torch.argmax(start_scores),
                        str(ground_truth_start_ind),
                        attributions_start_sum.sum(),       
                        all_tokens,
                        delta_start)

end_position_vis = viz.VisualizationDataRecord(
                        attributions_end_sum,
                        torch.max(torch.softmax(end_scores[0], dim=0)),
                        torch.argmax(end_scores),
                        torch.argmax(end_scores),
                        str(ground_truth_end_ind),
                        attributions_end_sum.sum(),       
                        all_tokens,
                        delta_end)
print(all_tokens)
print('\033[1m', 'Visualizations For Start Position', '\033[0m')
viz.visualize_text([start_position_vis])

print('\033[1m', 'Visualizations For End Position', '\033[0m')
viz.visualize_text([end_position_vis])

['[CLS]', 'who', 'made', 'the', 'history', '[SEP]', 'at', 'the', 'age', 'of', 'just', '13', ',', 'japan', "'", 's', 'mom', '##iji', 'ni', '##shi', '##ya', 'made', 'history', 'on', 'monday', 'by', 'winning', 'the', 'first', '-', 'ever', 'olympic', 'gold', 'medal', 'in', 'women', "'", 's', 'street', 'skate', '##boarding', 'at', 'the', 'games', 'in', 'tokyo', '.', 'ni', '##shi', '##ya', 'topped', 'a', 'youthful', 'podium', 'with', 'rays', '##sa', 'lea', '##l', 'of', 'brazil', ',', 'also', '13', ',', 'taking', 'the', 'silver', 'medal', 'and', 'japan', "'", 's', 'fun', '##a', 'nak', '##aya', '##ma', ',', '16', ',', 'winning', 'bronze', '.', 'with', 'an', 'average', 'age', 'of', '14', 'years', 'and', '191', 'days', 'it', 'is', 'the', 'youngest', 'individual', 'podium', 'in', 'the', 'history', 'of', 'the', 'olympic', 'games', '.', '[SEP]']
Visualizations For Start Position


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
16.0,16 (0.82),16.0,2.4,"[CLS] who made the history [SEP] at the age of just 13 , japan ' s mom ##iji ni ##shi ##ya made history on monday by winning the first - ever olympic gold medal in women ' s street skate ##boarding at the games in tokyo . ni ##shi ##ya topped a youthful podium with rays ##sa lea ##l of brazil , also 13 , taking the silver medal and japan ' s fun ##a nak ##aya ##ma , 16 , winning bronze . with an average age of 14 years and 191 days it is the youngest individual podium in the history of the olympic games . [SEP]"


Visualizations For End Position


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
20.0,20 (0.93),20.0,2.74,"[CLS] who made the history [SEP] at the age of just 13 , japan ' s mom ##iji ni ##shi ##ya made history on monday by winning the first - ever olympic gold medal in women ' s street skate ##boarding at the games in tokyo . ni ##shi ##ya topped a youthful podium with rays ##sa lea ##l of brazil , also 13 , taking the silver medal and japan ' s fun ##a nak ##aya ##ma , 16 , winning bronze . with an average age of 14 years and 191 days it is the youngest individual podium in the history of the olympic games . [SEP]"


# Collect all positively attributed tokens
and the negatively attributed tokens

In [None]:
# split token indices by the sign of their start-position attribution
positive = []
negative = []
for i, score in enumerate(attributions_start_sum):
    if score > 0:
        positive.append(i)
    elif score < 0:
        negative.append(i)

# rebuild one text from the positive tokens and one from the negative tokens
s_pos = ' '.join(all_tokens[i] for i in positive)
s_neg = ' '.join(all_tokens[i] for i in negative)

print("positive :", s_pos)
print("negative :", s_neg)
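The strings built above still contain special tokens such as `[CLS]` and word-piece markers such as `##iji`, which may not re-tokenize to the same pieces when fed back to the tokenizer. One possible cleanup, sketched below, is to drop the special tokens and glue `##` pieces onto the preceding token. `merge_wordpieces` is a hypothetical helper, not part of the notebook.

```python
def merge_wordpieces(tokens, special=('[CLS]', '[SEP]', '[PAD]')):
    """Join word pieces into words, skipping special tokens."""
    words = []
    for tok in tokens:
        if tok in special:
            continue
        if tok.startswith('##'):
            piece = tok[2:]
            if words:
                words[-1] += piece      # continue the previous word
            else:
                words.append(piece)     # no previous word to attach to
        else:
            words.append(tok)
    return ' '.join(words)


merged = merge_wordpieces(
    ['[CLS]', 'mom', '##iji', 'ni', '##shi', '##ya', 'made', 'history', '[SEP]'])
print(merged)
```

One caveat: after filtering tokens by attribution sign, a `##` piece may end up next to a token it did not originally follow, so the merged words are only approximate.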

# Run the prediction again

In [None]:
question = "who made the history"
pos_text = s_pos
neg_text = s_neg
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, pos_text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

ground_truth = 'iji ni ##shi ##ya'

ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False)
ground_truth_end_ind = indices.index(ground_truth_tokens[-1])
ground_truth_start_ind = ground_truth_end_ind - len(ground_truth_tokens) + 1

start_scores, end_scores = predict(input_ids, \
                                   token_type_ids=token_type_ids, \
                                   position_ids=position_ids, \
                                   attention_mask=attention_mask)


print('Question: ', question)
print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))

In [None]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, pos_text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

start_scores, end_scores = predict(input_ids, \
                                   token_type_ids=token_type_ids, \
                                   position_ids=position_ids, \
                                   attention_mask=attention_mask)


print('Question: ', question)
print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))

# Analysis of the positive text

In [None]:
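lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)

attributions_start, delta_start = lig.attribute(inputs=input_ids,
                                  baselines=ref_input_ids,
                                  additional_forward_args=(token_type_ids, position_ids, attention_mask, 0),
                                  return_convergence_delta=True)
attributions_end, delta_end = lig.attribute(inputs=input_ids, baselines=ref_input_ids,
                                additional_forward_args=(token_type_ids, position_ids, attention_mask, 1),
                                return_convergence_delta=True)

# summarize the new attributions and rebuild the visualization records;
# reusing the records from In [18] would display the attributions of the original text
attributions_start_sum = summarize_attributions(attributions_start)
attributions_end_sum = summarize_attributions(attributions_end)

start_position_vis = viz.VisualizationDataRecord(
    attributions_start_sum, torch.max(torch.softmax(start_scores[0], dim=0)),
    torch.argmax(start_scores), torch.argmax(start_scores), str(ground_truth_start_ind),
    attributions_start_sum.sum(), all_tokens, delta_start)
end_position_vis = viz.VisualizationDataRecord(
    attributions_end_sum, torch.max(torch.softmax(end_scores[0], dim=0)),
    torch.argmax(end_scores), torch.argmax(end_scores), str(ground_truth_end_ind),
    attributions_end_sum.sum(), all_tokens, delta_end)

print(all_tokens)
print('\033[1m', 'Visualizations For Start Position', '\033[0m')
viz.visualize_text([start_position_vis])

print('\033[1m', 'Visualizations For End Position', '\033[0m')
viz.visualize_text([end_position_vis])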

# Prediction on the negative text

In [None]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, neg_text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

# no ground-truth answer is defined for the negative-only text, so skip the index
# lookup (indexing the empty token list with [-1] would raise an IndexError)
ground_truth = ''
ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False)
ground_truth_start_ind = ground_truth_end_ind = None

start_scores, end_scores = predict(input_ids, \
                                   token_type_ids=token_type_ids, \
                                   position_ids=position_ids, \
                                   attention_mask=attention_mask)


print('Question: ', question)
print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))

In [None]:
print("Question:  who made the history")
print("Predicted Answer:  ")