## Integrated gradients

Reference:

https://captum.ai/tutorials/Bert_SQUAD_Interpret

In [1]:
import numpy as np
import pandas as pd
# import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from transformers import AutoTokenizer, BertConfig
from model import RegBertForQA 

from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [3]:
def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    output = model(inputs, token_type_ids=token_type_ids,
                 position_ids=position_ids, attention_mask=attention_mask, )
    return output.start_logits, output.end_logits

In [4]:
def squad_pos_forward_func(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
    pred = predict(inputs,
                   token_type_ids=token_type_ids,
                   position_ids=position_ids,
                   attention_mask=attention_mask)
    pred = pred[position]
    return pred.max(1).values

In [5]:
def construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id):
    question_ids = tokenizer.encode(question, add_special_tokens=False)
    text_ids = tokenizer.encode(text, add_special_tokens=False)

    # construct input token ids
    input_ids = [cls_token_id] + question_ids + [sep_token_id] + text_ids + [sep_token_id]

    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(question_ids) + [sep_token_id] + \
        [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(question_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

def construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                    token_type_ids=None, ref_token_type_ids=None, \
                                    position_ids=None, ref_position_ids=None):
    input_embeddings = model.bert.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = model.bert.embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)
    
    return input_embeddings, ref_input_embeddings

def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [6]:
question, text = "What is a way to increase your wound healing speed?", \
"Wound care encourages and speeds wound healing via cleaning and protection from reinjury or infection. Depending on each patient's needs, it can range from the simplest first aid to entire nursing specialties such as wound, ostomy, and continence nursing and burn center care."


## BERT without registers

In [7]:
# load model
model = RegBertForQA.from_pretrained("fine_tuned_model_orig", num_registers = 0).to("cuda")
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained("fine_tuned_model_orig")

from regbertfor QA, num_reg= 0


Some weights of RegBert were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.embeddings.position_ids', 'bert.reg_pos', 'bert.reg_tokens']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RegBertForQA were not initialized from the model checkpoint at fine_tuned_model_orig and are newly initialized: ['bert.embeddings.position_ids', 'bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'bert.reg_pos', 'bert.reg_tokens']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
ref_token_id = tokenizer.pad_token_id 
sep_token_id = tokenizer.sep_token_id 
cls_token_id = tokenizer.cls_token_id


In [9]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [10]:
ground_truth = 'Wound care'

ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False)
ground_truth_end_ind = indices.index(ground_truth_tokens[-1])
ground_truth_start_ind = ground_truth_end_ind - len(ground_truth_tokens) + 1

In [11]:
start_scores, end_scores = predict(input_ids, \
                                   token_type_ids=token_type_ids, \
                                   position_ids=position_ids, \
                                   attention_mask=attention_mask)


print('Question: ', question)
print(torch.argmax(start_scores), torch.argmax(end_scores))
print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))

Question:  What is a way to increase your wound healing speed?
tensor(15, device='cuda:0') tensor(29, device='cuda:0')
Predicted Answer:  encourages and speeds wound healing via cleaning and protection from rein ##ju ##ry or infection


In [12]:
lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)

attributions_start, delta_start = lig.attribute(inputs=input_ids,
                                  baselines=ref_input_ids,
                                  additional_forward_args=(token_type_ids, position_ids, attention_mask, 0),
                                  return_convergence_delta=True)
attributions_end, delta_end = lig.attribute(inputs=input_ids, baselines=ref_input_ids,
                                additional_forward_args=(token_type_ids, position_ids, attention_mask, 1),
                                return_convergence_delta=True)

In [14]:
attributions_start_sum = summarize_attributions(attributions_start)
attributions_end_sum = summarize_attributions(attributions_end)

In [15]:
# storing couple samples in an array for visualization purposes
start_position_vis = viz.VisualizationDataRecord(
                        attributions_start_sum,
                        torch.max(torch.softmax(start_scores[0], dim=0)),
                        torch.argmax(start_scores),
                        torch.argmax(start_scores),
                        str(ground_truth_start_ind),
                        attributions_start_sum.sum(),       
                        all_tokens,
                        delta_start)

end_position_vis = viz.VisualizationDataRecord(
                        attributions_end_sum,
                        torch.max(torch.softmax(end_scores[0], dim=0)),
                        torch.argmax(end_scores),
                        torch.argmax(end_scores),
                        str(ground_truth_end_ind),
                        attributions_end_sum.sum(),       
                        all_tokens,
                        delta_end)

print('\033[1m', 'Visualizations For Start Position', '\033[0m')
viz.visualize_text([start_position_vis])

print('\033[1m', 'Visualizations For End Position', '\033[0m')
viz.visualize_text([end_position_vis])

[1m Visualizations For Start Position [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
15.0,15 (0.44),13.0,1.44,"[CLS] what is a way to increase your wound healing speed ? [SEP] wound care encourages and speeds wound healing via cleaning and protection from rein ##ju ##ry or infection . depending on each patient ' s needs , it can range from the simplest first aid to entire nursing special ##ties such as wound , os ##tom ##y , and con ##tine ##nce nursing and burn center care . [SEP]"
,,,,


[1m Visualizations For End Position [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
29.0,29 (0.64),14.0,3.36,"[CLS] what is a way to increase your wound healing speed ? [SEP] wound care encourages and speeds wound healing via cleaning and protection from rein ##ju ##ry or infection . depending on each patient ' s needs , it can range from the simplest first aid to entire nursing special ##ties such as wound , os ##tom ##y , and con ##tine ##nce nursing and burn center care . [SEP]"
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
29.0,29 (0.64),14.0,3.36,"[CLS] what is a way to increase your wound healing speed ? [SEP] wound care encourages and speeds wound healing via cleaning and protection from rein ##ju ##ry or infection . depending on each patient ' s needs , it can range from the simplest first aid to entire nursing special ##ties such as wound , os ##tom ##y , and con ##tine ##nce nursing and burn center care . [SEP]"
,,,,


## BERT with registers

In [16]:
# model_path = 'fine_tuned_model_registers_Nov17'
# model_path = 'model_num_reg_50'
# model_path = 'Archive/model_num_reg_50'

# load model
model = RegBertForQA.from_pretrained("Archive/model_num_reg_50").to("cuda")
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Archive/tokenizer_num_reg_50")

from regbertfor QA, num_reg= 50


Some weights of RegBert were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.embeddings.position_ids', 'bert.reg_pos', 'bert.reg_tokens']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
ref_token_id = tokenizer.pad_token_id 
sep_token_id = tokenizer.sep_token_id 
cls_token_id = tokenizer.cls_token_id

input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

ground_truth = 'Wound care'

ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False)
ground_truth_end_ind = indices.index(ground_truth_tokens[-1])
ground_truth_start_ind = ground_truth_end_ind - len(ground_truth_tokens) + 1

start_scores, end_scores = predict(input_ids, \
                                   token_type_ids=token_type_ids, \
                                   position_ids=position_ids, \
                                   attention_mask=attention_mask)


print('Question: ', question)
print(torch.argmax(start_scores), torch.argmax(end_scores))
print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))

lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)

attributions_start, delta_start = lig.attribute(inputs=input_ids,
                                  baselines=ref_input_ids,
                                  additional_forward_args=(token_type_ids, position_ids, attention_mask, 0),
                                  return_convergence_delta=True)
attributions_end, delta_end = lig.attribute(inputs=input_ids, baselines=ref_input_ids,
                                additional_forward_args=(token_type_ids, position_ids, attention_mask, 1),
                                return_convergence_delta=True)

attributions_start_sum = summarize_attributions(attributions_start)
attributions_end_sum = summarize_attributions(attributions_end)


Question:  What is a way to increase your wound healing speed?
tensor(21, device='cuda:0') tensor(29, device='cuda:0')
Predicted Answer:  cleaning and protection from rein ##ju ##ry or infection


In [22]:
# storing couple samples in an array for visualization purposes
start_position_vis = viz.VisualizationDataRecord(
                        attributions_start_sum,
                        torch.max(torch.softmax(start_scores[0], dim=0)),
                        torch.argmax(start_scores),
                        torch.argmax(start_scores),
                        str(ground_truth_start_ind),
                        attributions_start_sum.sum(),       
                        all_tokens,
                        delta_start)

end_position_vis = viz.VisualizationDataRecord(
                        attributions_end_sum,
                        torch.max(torch.softmax(end_scores[0], dim=0)),
                        torch.argmax(end_scores),
                        torch.argmax(end_scores),
                        str(ground_truth_end_ind),
                        attributions_end_sum.sum(),       
                        all_tokens,
                        delta_end)

print('\033[1m', 'Visualizations For Start Position', '\033[0m')
viz.visualize_text([start_position_vis])

print('\033[1m', 'Visualizations For End Position', '\033[0m')
viz.visualize_text([end_position_vis])

[1m Visualizations For Start Position [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
21.0,21 (0.46),13.0,3.11,"[CLS] what is a way to increase your wound healing speed ? [SEP] wound care encourages and speeds wound healing via cleaning and protection from rein ##ju ##ry or infection . depending on each patient ' s needs , it can range from the simplest first aid to entire nursing special ##ties such as wound , os ##tom ##y , and con ##tine ##nce nursing and burn center care . [SEP]"
,,,,


[1m Visualizations For End Position [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
29.0,29 (0.85),14.0,3.02,"[CLS] what is a way to increase your wound healing speed ? [SEP] wound care encourages and speeds wound healing via cleaning and protection from rein ##ju ##ry or infection . depending on each patient ' s needs , it can range from the simplest first aid to entire nursing special ##ties such as wound , os ##tom ##y , and con ##tine ##nce nursing and burn center care . [SEP]"
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
29.0,29 (0.85),14.0,3.02,"[CLS] what is a way to increase your wound healing speed ? [SEP] wound care encourages and speeds wound healing via cleaning and protection from rein ##ju ##ry or infection . depending on each patient ' s needs , it can range from the simplest first aid to entire nursing special ##ties such as wound , os ##tom ##y , and con ##tine ##nce nursing and burn center care . [SEP]"
,,,,
