## Integrated gradients

Reference:

https://captum.ai/tutorials/Bert_SQUAD_Interpret

In [2]:
import numpy as np
import pandas as pd
# import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from transformers import BertTokenizer, BertConfig
from model import RegBertForQA 

from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [4]:
model_path = 'fine_tuned_model_registers_Nov17'


# load model
model = RegBertForQA.from_pretrained(model_path)
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = BertTokenizer.from_pretrained(model_path)

from regbertfor QA, num_reg= 50


Some weights of RegBert were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.embeddings.position_ids', 'bert.reg_pos', 'bert.reg_tokens']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    output = model(inputs, token_type_ids=token_type_ids,
                 position_ids=position_ids, attention_mask=attention_mask, )
    return output.start_logits, output.end_logits

In [6]:
def squad_pos_forward_func(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
    pred = predict(inputs,
                   token_type_ids=token_type_ids,
                   position_ids=position_ids,
                   attention_mask=attention_mask)
    pred = pred[position]
    return pred.max(1).values

In [7]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence


In [8]:
def construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id):
    question_ids = tokenizer.encode(question, add_special_tokens=False)
    text_ids = tokenizer.encode(text, add_special_tokens=False)

    # construct input token ids
    input_ids = [cls_token_id] + question_ids + [sep_token_id] + text_ids + [sep_token_id]

    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(question_ids) + [sep_token_id] + \
        [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(question_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

def construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                    token_type_ids=None, ref_token_type_ids=None, \
                                    position_ids=None, ref_position_ids=None):
    input_embeddings = model.bert.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = model.bert.embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)
    
    return input_embeddings, ref_input_embeddings

In [9]:
question, text = "How long is the Great Wall of China?", \
    "The Great Wall of China is an ancient wall in China. The wall was built to protect the northern borders of the Chinese Empire from invading nomadic tribes. The Great Wall stretches from Dandong in the east to Lop Lake in the west, covering about 13,000 miles."


In [10]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [11]:
ground_truth = 'about 13,000 miles.'

ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False)
ground_truth_end_ind = indices.index(ground_truth_tokens[-1])
ground_truth_start_ind = ground_truth_end_ind - len(ground_truth_tokens) + 1

In [12]:
start_scores, end_scores = predict(input_ids, \
                                   token_type_ids=token_type_ids, \
                                   position_ids=position_ids, \
                                   attention_mask=attention_mask)


print('Question: ', question)
print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))

Calculating here ...
inputs_embeds.shape:  torch.Size([1, 67, 768])
input_ids.shape:  torch.Size([1, 67])
position_embeddings.shape:  torch.Size([1, 67, 768])
token_type_embeddings.shape:  torch.Size([1, 67, 768])
Question:  How long is the Great Wall of China?
Predicted Answer:  13 , 000 miles


In [13]:
lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)

attributions_start, delta_start = lig.attribute(inputs=input_ids,
                                  baselines=ref_input_ids,
                                  additional_forward_args=(token_type_ids, position_ids, attention_mask, 0),
                                  return_convergence_delta=True)
attributions_end, delta_end = lig.attribute(inputs=input_ids, baselines=ref_input_ids,
                                additional_forward_args=(token_type_ids, position_ids, attention_mask, 1),
                                return_convergence_delta=True)

Calculating here ...
inputs_embeds.shape:  torch.Size([1, 67, 768])
input_ids.shape:  torch.Size([1, 67])
position_embeddings.shape:  torch.Size([1, 67, 768])
token_type_embeddings.shape:  torch.Size([1, 67, 768])
Calculating here ...
inputs_embeds.shape:  torch.Size([1, 67, 768])
input_ids.shape:  torch.Size([1, 67])
position_embeddings.shape:  torch.Size([1, 67, 768])
token_type_embeddings.shape:  torch.Size([1, 67, 768])
Calculating here ...
inputs_embeds.shape:  torch.Size([50, 67, 768])
input_ids.shape:  torch.Size([50, 67])
position_embeddings.shape:  torch.Size([50, 67, 768])
token_type_embeddings.shape:  torch.Size([50, 67, 768])
Calculating here ...
inputs_embeds.shape:  torch.Size([1, 67, 768])
input_ids.shape:  torch.Size([1, 67])
position_embeddings.shape:  torch.Size([1, 67, 768])
token_type_embeddings.shape:  torch.Size([1, 67, 768])
Calculating here ...
inputs_embeds.shape:  torch.Size([1, 67, 768])
input_ids.shape:  torch.Size([1, 67])
position_embeddings.shape:  torch.

In [14]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [15]:
attributions_start_sum = summarize_attributions(attributions_start)
attributions_end_sum = summarize_attributions(attributions_end)

In [16]:
# storing couple samples in an array for visualization purposes
start_position_vis = viz.VisualizationDataRecord(
                        attributions_start_sum,
                        torch.max(torch.softmax(start_scores[0], dim=0)),
                        torch.argmax(start_scores),
                        torch.argmax(start_scores),
                        str(ground_truth_start_ind),
                        attributions_start_sum.sum(),       
                        all_tokens,
                        delta_start)

end_position_vis = viz.VisualizationDataRecord(
                        attributions_end_sum,
                        torch.max(torch.softmax(end_scores[0], dim=0)),
                        torch.argmax(end_scores),
                        torch.argmax(end_scores),
                        str(ground_truth_end_ind),
                        attributions_end_sum.sum(),       
                        all_tokens,
                        delta_end)

print('\033[1m', 'Visualizations For Start Position', '\033[0m')
viz.visualize_text([start_position_vis])

print('\033[1m', 'Visualizations For End Position', '\033[0m')
viz.visualize_text([end_position_vis])

[1m Visualizations For Start Position [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
61.0,61 (0.85),17.0,1.98,"[CLS] how long is the great wall of china ? [SEP] the great wall of china is an ancient wall in china . the wall was built to protect the northern borders of the chinese empire from invading nomadic tribes . the great wall stretches from dan ##dong in the east to lo ##p lake in the west , covering about 13 , 000 miles . [SEP]"
,,,,


[1m Visualizations For End Position [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
64.0,64 (0.55),22.0,2.04,"[CLS] how long is the great wall of china ? [SEP] the great wall of china is an ancient wall in china . the wall was built to protect the northern borders of the chinese empire from invading nomadic tribes . the great wall stretches from dan ##dong in the east to lo ##p lake in the west , covering about 13 , 000 miles . [SEP]"
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
64.0,64 (0.55),22.0,2.04,"[CLS] how long is the great wall of china ? [SEP] the great wall of china is an ancient wall in china . the wall was built to protect the northern borders of the chinese empire from invading nomadic tribes . the great wall stretches from dan ##dong in the east to lo ##p lake in the west , covering about 13 , 000 miles . [SEP]"
,,,,
