In [1]:
import os
import sys
import threading

import torch
import torch.nn as nn
from torch.nn.parallel.scatter_gather import scatter_kwargs, scatter

from transformers import BertTokenizer, BertForQuestionAnswering, BertConfig
from captum.attr import IntegratedGradients, NoiseTunnel, LayerIntegratedGradients, GradientShap, LayerConductance, GradientAttribution

from captum.attr import visualization as viz

from captum.attr._utils.approximation_methods import approximation_parameters
from captum.attr._utils.batching import _batched_operator
from captum.attr._utils.common import (
    _validate_input,
    _format_additional_forward_args,
    _format_attributions,
    _format_input_baseline,
    _reshape_and_sum,
    _expand_additional_forward_args,
    _expand_target,
)
#from  captum.attr._utils.batching import _sort_key_list

from captum.attr._utils.gradient import compute_gradients, _forward_layer_distributed_eval, _gather_distributed_tensors


To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


In [2]:
class BasicEmbeddingModel(nn.Module):
    r"""
    Implements basic model with nn.Embedding layer. This simple model
    will help us to test nested InterpretableEmbedding layers
    The model has the following structure:
    BasicEmbeddingModel(
      (embedding1): Embedding(30, 100)
      (embedding2): TextModule(
        (inner_embedding): Embedding(30, 100)
      )
      (linear1): Linear(in_features=100, out_features=256, bias=True)
      (relu): ReLU()
      (linear2): Linear(in_features=256, out_features=1, bias=True)
    )
    """

    def __init__(
        self, num_embeddings=30, embedding_dim=200, hidden_dim=256, output_dim=1
    ):
        super().__init__()
        self.embedding1 = nn.Embedding(num_embeddings, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, input):
        embedding1 = self.embedding1(input)
        return self.linear2(self.relu(self.linear1(embedding1))).squeeze(1)
        

#device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

# load model
model = BasicEmbeddingModel().cuda()
#model = nn.DataParallel(model)
model.eval()
model.zero_grad()

input = torch.tensor([[2],[1], [4]]).cuda()
ref = torch.tensor([[0], [3], [5]]).cuda()
model(input)

tensor([[0.2308],
        [0.3428],
        [0.3436]], device='cuda:0', grad_fn=<SqueezeBackward1>)

In [3]:
lig = LayerIntegratedGradients(model, model.embedding1)
attrs = lig.attribute(input, baselines=ref, return_convergence_delta=True)
attrs

output.shape:  torch.Size([150, 1])
output.shape:  torch.Size([150, 1])
output.shape:  torch.Size([150, 1])
output.shape:  torch.Size([3, 1])
output.shape:  torch.Size([3, 1])


(tensor([[[ 2.8434e-02,  2.4336e-03, -7.7239e-03,  2.9334e-02, -2.2649e-02,
            1.1067e-02,  9.4211e-02, -2.4433e-03,  6.1734e-03,  3.3512e-04,
            1.4698e-02,  4.3736e-03,  1.1171e-03, -1.0069e-02,  3.6302e-02,
           -5.7423e-02,  2.2184e-03,  1.6446e-03,  8.4808e-03,  3.1292e-03,
           -1.2136e-02,  4.6877e-04, -2.5499e-02,  6.4064e-03,  9.3984e-04,
           -6.3043e-03,  4.6368e-04, -1.8031e-03,  2.6459e-02, -8.8340e-03,
           -5.2754e-02,  5.9586e-03, -3.9220e-03,  1.7803e-02,  1.0983e-03,
            4.2014e-03,  1.0652e-02, -9.1668e-04, -1.0238e-02,  2.2246e-03,
           -4.6415e-03, -3.6420e-03, -1.9941e-03, -8.7421e-03,  2.6622e-03,
            3.3955e-03,  3.0525e-03, -2.4556e-03,  6.9845e-03, -1.7373e-02,
           -1.7561e-03,  1.4941e-02,  9.3275e-04,  1.1926e-02, -1.9599e-02,
           -2.4048e-03, -1.6024e-03, -5.4949e-02,  1.3060e-03, -3.2248e-03,
           -5.4806e-03,  5.2448e-03, -1.4473e-02, -2.2163e-02, -7.5677e-04,
            

In [4]:
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

interpret_layer = configure_interpretable_embedding_layer(model, 'embedding1')

        be replaced with an interpretable embedding layer which wraps the
        original embedding layer and takes word embedding vectors as inputs of
        the forward function. This allows to generate baselines for word
        embeddings and compute attributions for each embedding dimension.
        The original embedding layer must be set
        back by calling `remove_interpretable_embedding_layer` function
        after model interpretation is finished.
  after model interpretation is finished."""


In [5]:
embs = interpret_layer.indices_to_embeddings(input)
ref_emb = interpret_layer.indices_to_embeddings(ref)

In [6]:
ig = IntegratedGradients(model)
ig.attribute(embs, ref_emb, return_convergence_delta=True)

output.shape:  torch.Size([150, 1])
output.shape:  torch.Size([3, 1])
output.shape:  torch.Size([3, 1])


(tensor([[[ 2.8434e-02,  2.4336e-03, -7.7239e-03,  2.9334e-02, -2.2649e-02,
            1.1067e-02,  9.4211e-02, -2.4433e-03,  6.1734e-03,  3.3512e-04,
            1.4698e-02,  4.3736e-03,  1.1171e-03, -1.0069e-02,  3.6302e-02,
           -5.7423e-02,  2.2184e-03,  1.6446e-03,  8.4808e-03,  3.1292e-03,
           -1.2136e-02,  4.6877e-04, -2.5499e-02,  6.4064e-03,  9.3984e-04,
           -6.3043e-03,  4.6368e-04, -1.8031e-03,  2.6459e-02, -8.8340e-03,
           -5.2754e-02,  5.9586e-03, -3.9220e-03,  1.7803e-02,  1.0983e-03,
            4.2014e-03,  1.0652e-02, -9.1668e-04, -1.0238e-02,  2.2246e-03,
           -4.6415e-03, -3.6420e-03, -1.9941e-03, -8.7421e-03,  2.6622e-03,
            3.3955e-03,  3.0525e-03, -2.4556e-03,  6.9845e-03, -1.7373e-02,
           -1.7561e-03,  1.4941e-02,  9.3275e-04,  1.1926e-02, -1.9599e-02,
           -2.4048e-03, -1.6024e-03, -5.4949e-02,  1.3060e-03, -3.2248e-03,
           -5.4806e-03,  5.2448e-03, -1.4473e-02, -2.2163e-02, -7.5677e-04,
            

In [7]:
remove_interpretable_embedding_layer(model, interpret_layer)

In [9]:
model_path = '/home/narine/debug_squad' # <PATH-TO-SAVED-MODEL>

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load model
model = BertForQuestionAnswering.from_pretrained(model_path)
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = BertTokenizer.from_pretrained(model_path)


In [10]:
def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    return model(inputs, token_type_ids=token_type_ids,
                 position_ids=position_ids, attention_mask=attention_mask)
    
def squad_pos_forward_func(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
        pred_ = predict(inputs,
                        token_type_ids=token_type_ids,
                        position_ids=position_ids,
                        attention_mask=attention_mask)[position]
        return pred_.max(1).values


In [11]:
ref = "[PAD]"
sep = "[SEP]"
cls = "[CLS]"

ref_token_id = tokenizer.encode(ref)
sep_token_id = tokenizer.encode(sep)
cls_token_id = tokenizer.encode(cls)


In [12]:
def construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id):
    question_ids = tokenizer.encode(question)
    text_ids = tokenizer.encode(text)

    # construct input token ids
    input_ids = cls_token_id + question_ids + sep_token_id + text_ids + sep_token_id

    # construct reference token ids 
    ref_input_ids = cls_token_id + ref_token_id * len(question_ids) + sep_token_id + \
        ref_token_id * len(text_ids) + sep_token_id

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(question_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    input_ids_ = input_ids[0]
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(len(input_ids_))]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids)
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = torch.zeros_like(position_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)


In [13]:
question, text = "How many artworks did Van Gogh create?", \
                 "In just over a decade Vincent van Gogh created about 2,100 artworks, " \
                 "including around 860 oil paintings, most of which date from the last two years of his life."


In [14]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)
print(input_ids.dtype, input_ids)
print(token_type_ids, ref_token_type_ids)
print(position_ids, ref_position_ids)


torch.int64 tensor([[  101,  1293,  1242, 25466,  1225,  3498,  1301,  5084,  2561,   136,
           102,  1107,  1198,  1166,   170,  4967,   191,  1394,  8298,  3498,
          1301,  5084,  1687,  1164,   123,   117,  1620, 25466,   117,  1259,
          1213,  5942,  1568,  2949,  4694,   117,  1211,  1104,  1134,  2236,
          1121,  1103,  1314,  1160,  1201,  1104,  1117,  1297,   119,   102]],
       device='cuda:0')
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1]], device='cuda:0') tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0]], device='cuda:0')
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 

In [16]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu().detach().numpy()
    return attributions

lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)

#nt = NoiseTunnel(ig)

print('squad_pos_forward_func: ', squad_pos_forward_func(input_ids, token_type_ids, attention_mask=attention_mask))
model.zero_grad()
attributions_start = lig.attribute(inputs=(input_ids, token_type_ids, position_ids), baselines=(ref_input_ids, ref_token_type_ids, ref_position_ids), additional_forward_args=(attention_mask, 0))
attributions_end = lig.attribute(inputs=(input_ids, token_type_ids, position_ids), baselines=(ref_input_ids, ref_token_type_ids, ref_position_ids), additional_forward_args=(attention_mask, 1))

attributions_start_ = summarize_attributions(attributions_start[0])
attributions_end_ = summarize_attributions(attributions_end[0])

print('attributions_start: ', attributions_start_)
print('attributions_end: ', attributions_end_)
text = tokenizer.convert_ids_to_tokens(input_ids[0].detach().cpu().numpy())
print(text)


squad_pos_forward_func:  tensor([6.8525], device='cuda:0', grad_fn=<MaxBackward0>)
output.shape:  torch.Size([50])
output.shape:  torch.Size([50])
output.shape:  torch.Size([50])
output.shape:  torch.Size([50])
output.shape:  torch.Size([50])
output.shape:  torch.Size([50])
attributions_start:  [ 0.          0.54717004  0.3882186  -0.03181905  0.0175079   0.07919758
  0.15983962 -0.00674258  0.08660728  0.21139494 -0.07965035  0.0599708
 -0.00141058  0.0131349   0.10260752 -0.02915622  0.01871186  0.06346903
  0.00851475 -0.06987926  0.00704876 -0.0222609   0.02858604  0.12898968
  0.30991644  0.0919048   0.2707097  -0.02806081  0.02801514  0.02262776
  0.03283522 -0.0422463  -0.00726324  0.00105244 -0.02436556  0.01064256
 -0.02947477  0.06351329  0.10911785 -0.08753464  0.05641276  0.43600887
  0.03995525  0.02896365 -0.0030782   0.08050218 -0.00569606 -0.05158571
  0.0208116  -0.04891273]
attributions_end:  [ 0.          0.59831417  0.48700386  0.08439662 -0.01518923  0.08282123
  0

In [17]:
text = tokenizer.convert_ids_to_tokens(input_ids[0].detach().cpu().numpy())


In [18]:
vis_data_records = []
# storing couple samples in an array for visualization purposes
vis_data_records.append(viz.VisualizationDataRecord(
                        attributions_start_,
                        0.0, #torch.max(torch.softmax(start_scores[0], dim=0)),
                        0.0, #torch.argmax(start_scores),
                        '2100',
                        '2100',
                        attributions_start_.sum(),       
                        text,
                        -1.0))

vis_data_records.append(viz.VisualizationDataRecord(
                        attributions_end_,
                        0.0, #torch.max(torch.softmax(end_scores[0], dim=0)),
                        0.0, #torch.argmax(end_scores),
                        '2100',
                        '2100',
                        attributions_end_.sum(),       
                        text,
                        -1.0))

viz.visualize_text(vis_data_records)


Target Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
2100.0,0.0 (0.00),2100.0,2.95,"[CLS] how many artworks did van go ##gh create ? [SEP] in just over a decade v ##in ##cent van go ##gh created about 2 , 100 artworks , including around 86 ##0 oil paintings , most of which date from the last two years of his life . [SEP]"
,,,,
2100.0,0.0 (0.00),2100.0,3.54,"[CLS] how many artworks did van go ##gh create ? [SEP] in just over a decade v ##in ##cent van go ##gh created about 2 , 100 artworks , including around 86 ##0 oil paintings , most of which date from the last two years of his life . [SEP]"
,,,,
