In [1]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from transformers import BertTokenizer, BertForQuestionAnswering, BertConfig

from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients, LayerActivation
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer


In [2]:
device = torch.device("cpu")

In [3]:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline

model_path = 'deepset/roberta-base-squad2'

# load model
model = AutoModelForQuestionAnswering.from_pretrained(model_path, output_attentions=True)
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)

Downloading:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/496M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/79.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/772 [00:00<?, ?B/s]

In [4]:
def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    output = model(inputs, token_type_ids=token_type_ids,
                 position_ids=position_ids, attention_mask=attention_mask, )
    return output.start_logits, output.end_logits, output.attentions

In [24]:
def squad_pos_forward_func(inputs, position=0):
    pred = model(inputs_embeds=inputs)
    pred = pred[position]
    return pred.max(1).values

In [6]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

ref_token_id, sep_token_id, cls_token_id

(1, 2, 0)

In [7]:
def construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id):
    question_ids = tokenizer.encode(question, add_special_tokens=False)
    text_ids = tokenizer.encode(text, add_special_tokens=False)

    # construct input token ids
    input_ids = [cls_token_id] + question_ids + [sep_token_id] + text_ids + [sep_token_id]

    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(question_ids) + [sep_token_id] + \
        [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(question_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)
    
def construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                    token_type_ids=None, ref_token_type_ids=None, \
                                    position_ids=None, ref_position_ids=None):
    input_embeddings = interpretable_embedding.indices_to_embeddings(input_ids)
    ref_input_embeddings = interpretable_embedding.indices_to_embeddings(ref_input_ids)
    
    return input_embeddings, ref_input_embeddings

In [8]:
question, text = "What is important to us?", "It is important to us to include, empower and support humans of all kinds."

In [9]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [10]:
ground_truth = 'to include, empower and support humans of all kinds'

ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False)
ground_truth_end_ind = indices.index(ground_truth_tokens[-1])
ground_truth_start_ind = ground_truth_end_ind - len(ground_truth_tokens) + 1

In [12]:
input_ids

tensor([[    0,  2264,    16,   505,     7,   201,   116,     2,   243,    16,
           505,     7,   201,     7,   680,     6, 15519,     8,   323,  5868,
             9,    70,  6134,     4,     2]])

In [13]:
start_scores, end_scores, output_attentions = predict(input_ids)
                                #    token_type_ids=token_type_ids, \
                                #    position_ids=position_ids, \
                                #    attention_mask=attention_mask)


print('Question: ', question)
print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))

Question:  What is important to us?
Predicted Answer:  Ġto Ġinclude , Ġempower Ġand Ġsupport Ġhumans Ġof Ġall Ġkinds


In [14]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / norm_fn(attributions)
    return attributions

In [28]:
model.roberta.embeddings.word_embeddings

InterpretableEmbeddingBase(
  (embedding): Embedding(50265, 768, padding_idx=1)
)

In [16]:
interpretable_embedding = configure_interpretable_embedding_layer(model, 'roberta.embeddings.word_embeddings')



In [17]:
layer_attrs_start = []
layer_attrs_end = []

layer_attn_mat_start = []
layer_attn_mat_end = []

input_embeddings, ref_input_embeddings = construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                         token_type_ids=token_type_ids, ref_token_type_ids=ref_token_type_ids, \
                                         position_ids=position_ids, ref_position_ids=ref_position_ids)

In [22]:
lc = LayerConductance(squad_pos_forward_func, model.roberta.encoder.layer[0])

In [26]:
layer_attributions = lc.attribute(inputs=input_embeddings, 
                                  baselines=ref_input_embeddings, 
                                  additional_forward_args=())

In [27]:
layer_attributions

(tensor([[[ 0.0023,  0.0008,  0.0015,  ...,  0.0003, -0.0010,  0.0003],
          [ 0.0089, -0.0074,  0.0034,  ...,  0.0053, -0.0002,  0.0022],
          [ 0.0001,  0.0010, -0.0055,  ...,  0.0068,  0.0019,  0.0001],
          ...,
          [ 0.0024, -0.0015,  0.0024,  ...,  0.0007,  0.0029,  0.0005],
          [ 0.0027, -0.0005, -0.0029,  ..., -0.0008,  0.0009,  0.0006],
          [-0.0004,  0.0002,  0.0022,  ..., -0.0005,  0.0011,  0.0010]]],
        grad_fn=<SumBackward1>),
 tensor([[[[-1.2346e-04, -1.1475e-05, -3.8416e-05,  ...,  4.8486e-07,
             1.7653e-06,  3.6897e-06],
           [ 1.0477e-03,  6.1881e-02, -1.6728e-03,  ...,  5.2565e-06,
             3.5546e-05,  1.7740e-04],
           [-1.2809e-04,  3.0367e-03,  4.0127e-04,  ...,  3.0098e-05,
             2.6549e-05,  1.9713e-05],
           ...,
           [-5.2128e-05,  2.3169e-04,  1.3344e-04,  ..., -1.9066e-04,
             2.7491e-04,  2.9395e-04],
           [-5.7210e-07, -8.0422e-05,  8.3225e-05,  ..., -3.6489e-