In [104]:
import torch
from transformers import AutoTokenizer, DPRQuestionEncoder, DPRContextEncoder, DPRQuestionEncoderTokenizer
from captum.attr import LayerDeepLift   # you can try other methods too.

In [32]:
# Goal of this notebook: explain a document retrieved by a DPR model with DeepLift.
# Free-text search query used for the example.
query = "Szechwan dish food cuisine"

In [46]:
# Load the pretrained tokenizers: one for the question side, one for the context side.
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
    "facebook/dpr-question_encoder-multiset-base"
)
context_tokenizer = AutoTokenizer.from_pretrained(
    'facebook/dpr-ctx_encoder-multiset-base'
)

def tokenize(sentence_pairs):
    """Convert (query, document) text pairs into token-id tensors.

    Args:
        sentence_pairs: sequence of ``(query_text, doc_text)`` pairs.

    Returns:
        Tuple ``(query_ids, doc_ids)`` of tensors built from the ``input_ids``
        produced by ``question_tokenizer`` and ``context_tokenizer``.

    Note:
        No padding is applied, so a batch whose texts tokenize to different
        lengths cannot be stacked by ``torch.tensor``. Pass one pair at a
        time (as this notebook does) or pairs of equal tokenized length.
    """
    query_texts = [pair[0] for pair in sentence_pairs]
    doc_texts = [pair[1] for pair in sentence_pairs]
    # truncation=True caps every sequence at the tokenizer's maximum length so
    # an over-long document cannot overflow the encoder downstream.
    queries = question_tokenizer(query_texts, truncation=True)
    docs = context_tokenizer(doc_texts, truncation=True)
    return torch.tensor(queries['input_ids']), torch.tensor(docs['input_ids'])

In [66]:
# Two-tower model: relevance is the cosine similarity of the query and document embeddings.
class DPRRanker(torch.nn.Module):
    """Two-tower DPR ranker that scores a document against a query.

    Args:
        question_model_name: checkpoint handed to
            ``DPRQuestionEncoder.from_pretrained``.
        context_model_name: checkpoint handed to
            ``DPRContextEncoder.from_pretrained``.
    """

    def __init__(self,
                 question_model_name='facebook/dpr-question_encoder-multiset-base',
                 context_model_name='facebook/dpr-ctx_encoder-multiset-base'):
        super().__init__()
        self.question_encoder = DPRQuestionEncoder.from_pretrained(question_model_name)
        self.context_encoder = DPRContextEncoder.from_pretrained(context_model_name)

    def forward(self, docids, qids):
        """Return the cosine similarity between document and query embeddings.

        The document ids are the first argument and the query ids the second,
        to suit the feature attribution method used later.

        Args:
            docids: document token ids, fed to the context encoder.
            qids: query token ids, fed to the question encoder.

        Returns:
            Tensor of cosine similarities between the two ``pooler_output``
            embeddings (``torch.cosine_similarity`` with its default ``dim``).
        """
        embedding_q = self.question_encoder(qids).pooler_output
        embedding_doc = self.context_encoder(docids).pooler_output
        return torch.cosine_similarity(embedding_q, embedding_doc)


In [105]:
# One (query, document) example; both texts are converted to token-id tensors below.
example_query = 'Szechwan dish food cuisine'
example_doc = 'Cuisine A cuisine (/kwɪˈziːn/ kwi-ZEEN; from French [kɥizin], in turn from Latin coquere = "to cook") is a style of cooking characterized by distinctive ingredients, techniques and dishes, and usually associated with a specific culture or geographic region. A cuisine is primarily influenced by the ingredients that are available locally or through trade. Religious food laws, such as Hindu, Islamic and Jewish dietary laws, can also exercise a strong influence on cuisine.'
sentence_pair = [(example_query, example_doc)]
query_token_ids, doc_token_ids = tokenize(sentence_pair)

In [None]:
# Instantiate the two-tower ranker, then switch it to evaluation mode.
# `model.eval()` is the last expression of the cell, so its return value is
# also what the notebook displays for this cell.
model = DPRRanker()
model.eval()

In [69]:
# Set up Layer DeepLift on the embedding layer of the context (document) encoder.
doc_embedding_layer = model.context_encoder.ctx_encoder.bert_model.embeddings
LDF = LayerDeepLift(model, doc_embedding_layer)

In [60]:
# Build the reference (baseline) input: the query stays the same, and every
# ordinary document token is replaced by the [PAD] token.
def construct_ref_input(orig_input, special_token_ids=(100, 101, 102, 103), pad_token_id=0):
    """Return a copy of ``orig_input`` with non-special token ids replaced by padding.

    Args:
        orig_input: integer tensor of token ids; it is not modified.
        special_token_ids: ids that are kept unchanged in the reference.
            The defaults are presumably [UNK], [CLS], [SEP] and [MASK] of the
            BERT vocabulary used here -- verify against the tokenizer.
        pad_token_id: id written in place of every other token. The default
            of 0 matches the previous behaviour, which zeroed those positions.

    Returns:
        New tensor with the same shape and dtype as ``orig_input``.
    """
    special = torch.as_tensor(
        special_token_ids, dtype=orig_input.dtype, device=orig_input.device
    )
    # Compare every position against every special id at once; a position is
    # kept when it matches any of them.
    keep = (orig_input.unsqueeze(-1) == special).any(dim=-1)
    return torch.where(keep, orig_input, torch.full_like(orig_input, pad_token_id))

In [62]:
# Baseline for the document side: the document ids with everything except the
# special-token ids kept by construct_ref_input zeroed out.
reference_doc_ids = construct_ref_input(doc_token_ids)

In [72]:
# Attribute the score to the document tokens relative to the reference document.
# The query ids are forwarded unchanged as the model's second argument; the
# trailing comma makes this a real one-element tuple rather than a
# parenthesised expression.
importance = LDF.attribute(
    inputs=doc_token_ids,
    baselines=reference_doc_ids,
    additional_forward_args=(query_token_ids,),
).detach()

Input Tensor 0 has a dtype of torch.int64.
                Gradients cannot be activated
                for these data types.


In [76]:
# Reduce the attributions to one score per token by summing over the last
# dimension and dropping the leading dimension, then scale to unit L2 norm.
# (assumes the raw attribution has a leading batch dimension of size 1 -- TODO confirm)
importance = importance.sum(dim=-1).squeeze(0)
# Clamp the norm so an all-zero attribution vector stays all zeros instead of
# turning into NaN through a division by zero.
importance = importance / torch.norm(importance).clamp_min(1e-12)

In [106]:
# visualize importance scores for tokens. red: negative, green: positive.

from IPython.display import display, HTML  # public module; IPython.core.display is the deprecated location for display
def to_html(word, importance):
    """Render one token as an HTML ``<mark>`` whose background encodes its score.

    Scores above zero are drawn in green (hue 120), all other scores in red
    (hue 0). The score is clipped to [-1, 1]; the larger its magnitude, the
    lower the lightness of the background.

    Args:
        word: token text. It is HTML-escaped before being embedded, so tokens
            such as ``<``, ``&`` or ``"`` cannot break the markup.
        importance: attribution score; any value ``float()`` accepts
            (a Python number or a 0-d tensor).

    Returns:
        The HTML snippet as a string.
    """
    from html import escape  # stdlib; imported locally so the block is self-contained

    def _get_color(attr):
        # clip values to prevent CSS errors (Values should be from [-1,1])
        attr = max(-1.0, min(1.0, float(attr)))
        if attr > 0:
            hue, sat, lig = 120, 75, 100 - int(50 * attr)
        else:
            hue, sat, lig = 0, 75, 100 - int(-40 * attr)
        return "hsl({}, {}%, {}%)".format(hue, sat, lig)

    color = _get_color(importance)
    tag = '<mark style="background-color: {color}; opacity:1.0; \
                    line-height:1.75"><font color="black"> {word}\
                    </font></mark>'.format(
        color=color, word=escape(str(word)))
    return tag


In [116]:
# Map the document ids back to token strings, wrap each token in its coloured
# <mark>, and render the concatenated markup.
doc_tokens = context_tokenizer.convert_ids_to_tokens(doc_token_ids.squeeze(0))
marked_tokens = [to_html(word, att) for word, att in zip(doc_tokens, importance)]
res = ''.join(marked_tokens)
display(HTML(res))