<a href="https://colab.research.google.com/github/NataliaDiaz/XAI-tutorials/blob/main/Bert_for_Sequence_classification_Interpretation_in_Captum.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interpreting BertForSequenceClassification with Captum

In this notebook we use Captum to interpret a BERT sentiment classifier fine-tuned on the IMDb dataset. The model is available at https://huggingface.co/lvwerra/bert-imdb.

In [None]:
# Install dependencies
!pip install transformers
!pip install captum



In [None]:
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
import torch
import matplotlib.pyplot as plt

In [None]:
# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Get model and config files from https://huggingface.co/lvwerra/bert-imdb
!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/config.json
!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/pytorch_model.bin
!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/special_tokens_map.json
!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/tokenizer_config.json
!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/training_args.bin
!wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/vocab.txt

--2020-04-13 08:27:20--  https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/config.json
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.131.101
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.131.101|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1220 (1.2K) [application/json]
Saving to: ‘./model/config.json.2’


2020-04-13 08:27:20 (80.5 MB/s) - ‘./model/config.json.2’ saved [1220/1220]

--2020-04-13 08:27:21--  https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/pytorch_model.bin
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.131.101
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.131.101|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1334420863 (1.2G) [application/octet-stream]
Saving to: ‘./model/pytorch_model.bin.2’


2020-04-13 08:27:47 (49.3 MB/s) - ‘./model/pytorch_model.bin.2’ saved [1334420863/1334420863]

--2020-04-13 08:27:48--  https://s3.amazonaws.

In [None]:
# load tokenizer (its files sit next to the weights in ./model)
tokenizer = BertTokenizer.from_pretrained('./model')

# load model: restore the fine-tuned classifier and move it to the selected device
model = BertForSequenceClassification.from_pretrained('./model').to(device)
model.eval()       # evaluation mode
model.zero_grad()  # start from cleared parameter gradients

In [None]:
def predict(inputs):
    """Run the notebook-level `model` on `inputs` and return the first output element.

    For the example sentence in this notebook that element is a tensor of two
    un-normalised class scores per input row; softmax is applied by the callers.
    """
    model_output = model(inputs)
    return model_output[0]

In [None]:
# Padding token id: fills every text position of the baseline (reference) input.
ref_token_id = tokenizer.pad_token_id
# Separator token id: appended after the text tokens.
sep_token_id = tokenizer.sep_token_id
# Classification token id: prepended before the text tokens.
cls_token_id = tokenizer.cls_token_id

In [None]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    """Build the token-id tensor for `text` together with a matching baseline.

    The input row is `[CLS] + text tokens + [SEP]`. The baseline keeps the same
    two special tokens but replaces every text token with `ref_token_id`.

    Args:
        text: String to encode with the notebook-level `tokenizer`.
        ref_token_id: Token id used for the text positions of the baseline.
        sep_token_id: Token id appended after the text tokens.
        cls_token_id: Token id prepended before the text tokens.

    Returns:
        Tuple `(input_ids, ref_input_ids, n_text_tokens)`: two tensors with a
        single row each, created on `device`, and the number of text tokens
        (special tokens not counted).
    """
    encoded = tokenizer.encode(text, add_special_tokens=False)
    n_text_tokens = len(encoded)

    input_row = [cls_token_id, *encoded, sep_token_id]
    baseline_row = [cls_token_id, *([ref_token_id] * n_text_tokens), sep_token_id]

    input_ids = torch.tensor([input_row], device=device)
    ref_input_ids = torch.tensor([baseline_row], device=device)
    return input_ids, ref_input_ids, n_text_tokens

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Build segment (token type) ids for the input and an all-zero baseline.

    Positions `0..sep_ind` (inclusive) are assigned segment 0 and every later
    position segment 1.

    Args:
        input_ids: Tensor whose second dimension is the sequence length.
        sep_ind: Last position that still belongs to segment 0.

    Returns:
        Tuple `(token_type_ids, ref_token_type_ids)`, each with one row of
        sequence length, created on `device`.
    """
    n_positions = input_ids.size(1)
    segment_row = [int(pos > sep_ind) for pos in range(n_positions)]
    token_type_ids = torch.tensor([segment_row], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    """Build position ids for the input and an all-zero baseline.

    Args:
        input_ids: Tensor whose second dimension is the sequence length.

    Returns:
        Tuple `(position_ids, ref_position_ids)`, both expanded to the shape of
        `input_ids`: the first counts `0, 1, 2, ...` along the sequence, the
        second is zero everywhere.
    """
    n_positions = input_ids.size(1)

    def as_batch(row):
        # Add a leading dimension and broadcast the row to the shape of `input_ids`.
        return row.unsqueeze(0).expand_as(input_ids)

    position_ids = as_batch(torch.arange(n_positions, dtype=torch.long, device=device))
    # A random permutation (`torch.randperm(n_positions, device=device)`) would be
    # another possible baseline.
    ref_position_ids = as_batch(torch.zeros(n_positions, dtype=torch.long, device=device))
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    """Return a mask of ones with the same shape, dtype and device as `input_ids`.

    Every position is marked as attended to; no padding is masked out.
    """
    attention_mask = torch.ones_like(input_ids)
    return attention_mask

In [None]:
def custom_forward(inputs):
    """Return the positive-class probability for every example in the batch.

    Args:
        inputs: Batch accepted by `predict` (token ids in this notebook).

    Returns:
        1-D tensor with one entry per input row: the softmax probability of
        class index 1 (reported as "positive" elsewhere in this notebook).
        For a single input row the result has shape `(1,)`.
    """
    preds = predict(inputs)
    # Take column 1 of *every* row. Indexing `[0][1]` would return only the first
    # example's probability, so when the attribution method pushes several
    # interpolation steps through one forward pass (internal batch size > 1) the
    # remaining steps would receive no gradient and the attributions would be wrong.
    return torch.softmax(preds, dim=1)[:, 1]

In [None]:
# Layer Integrated Gradients taken with respect to the output of BERT's embedding module.
lig = LayerIntegratedGradients(forward_func=custom_forward, layer=model.bert.embeddings)

In [None]:
# Sentence to explain. Try a couple of examples to check that the sentiment
# classifier is behaving sensibly.
text = "The movie was one of those amazing movies"
# Alternative (negative) example; uncomment to use it instead:
#text = "The movie was one of those crappy movies you can't forget."

In [None]:
# Token ids for the sentence and for its baseline. Note that `sep_id` receives the
# number of text tokens returned by `construct_input_ref_pair`.
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(
    text,
    ref_token_id=ref_token_id,
    sep_token_id=sep_token_id,
    cls_token_id=cls_token_id,
)

# Auxiliary inputs and their baselines.
attention_mask = construct_attention_mask(input_ids)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_ind=sep_id)

# Human-readable tokens, used later to label the visualization.
indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [None]:
# Check predict output: displays the un-normalised class scores for the example
# sentence (softmax is applied separately, in `custom_forward`).
predict(input_ids)

tensor([[-3.1333,  3.6520]], device='cuda:0', grad_fn=<AddmmBackward>)

In [None]:
# Check output of custom_forward: the softmax probability of class index 1 for the
# example sentence.
custom_forward(input_ids)

tensor([0.9989], device='cuda:0', grad_fn=<UnsqueezeBackward0>)

In [None]:
# Integrate gradients along the path from the baseline ids to the real ids.
# `delta` is the convergence error returned alongside the attributions.
attributions, delta = lig.attribute(
    input_ids,
    baselines=ref_input_ids,
    return_convergence_delta=True,
    n_steps=700,            # number of integration steps
    internal_batch_size=3,  # steps evaluated per forward/backward pass
)

In [None]:
score = predict(input_ids)

# Report the sentence, the arg-max class and the probability of class index 1.
# `str()` is applied explicitly so the numpy values are rendered exactly as before.
print('Sentence: ', text)
print(
    'Sentiment: ',
    str(torch.argmax(score[0]).cpu().numpy()),
    ', Probability positive: ',
    str(torch.softmax(score, dim = 1)[0][1].cpu().detach().numpy()),
    sep='',
)

Sentence:  The movie was one of those amazing movies
Sentiment: 1, Probability positive: 0.9988709


In [None]:
def summarize_attributions(attributions):
    """Collapse attributions over their last dimension and L2-normalise them.

    Args:
        attributions: Attribution tensor; its last dimension is summed out and
            a leading dimension of size 1 (if present) is dropped.

    Returns:
        The summed attributions divided by their overall norm.
    """
    collapsed = torch.sum(attributions, dim=-1).squeeze(0)
    scale = collapsed.norm()
    return collapsed / scale

In [None]:
# Collapse the last attribution dimension and L2-normalise the result.
attributions_sum = summarize_attributions(attributions)

In [None]:
# Bundle the attributions and prediction for this one sentence into a record that
# `viz.visualize_text` can render.
score_vis = viz.VisualizationDataRecord(attributions_sum,                                # per-token attributions
                                        torch.softmax(score, dim = 1)[0][1],             # probability of class index 1
                                        # Predicted class: softmax must run over the class
                                        # dimension (dim=1). Over dim=0 a (1, 2) score tensor
                                        # becomes all ones and argmax is an arbitrary tie.
                                        torch.argmax(torch.softmax(score, dim = 1)[0]),
                                        1,                                               # true label (hard-coded: positive)
                                        text,                                            # shown in the "Attribution Label" column
                                        attributions_sum.sum(),                          # total attribution score
                                        all_tokens,                                      # tokens to colour
                                        delta)                                           # convergence delta


In [None]:
# Print a heading wrapped in ANSI escape codes (\033[1m ... \033[0m), then render
# the visualization record built above.
print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (1.00),The movie was one of those amazing movies,0.26,[CLS] The movie was one of those amazing movies [SEP]
,,,,


The visualization is clearly meaningless! :(
