In [2]:
import numpy as np
import torch
import torch.nn as nn
from IPython.display import display, HTML
from transformers import DistilBertModel, DistilBertTokenizer, logging
import matplotlib
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm as tqdm_notebook
import spacy
from spacy import displacy
import seaborn as sns
import pandas as pd
import numpy as np
import collections
import glob
import pickle
import re
from bs4 import BeautifulSoup
import requests
from sklearn.metrics.pairwise import cosine_similarity
# Load the transformer-based English spaCy pipeline once for the whole notebook.
nlp = spacy.load('en_core_web_trf')
# `logging` here is the transformers logging module imported above; only errors are shown.
logging.set_verbosity_error()

# Make the project's source folders importable so that the `build_features`
# and `predict_model` imports below resolve (paths are relative to the
# notebook's working directory).
import sys
sys.path.insert(0, '../../src/models/')
sys.path.insert(0, '../../src/features/')

from build_features import similarity_matrix as vector_values
from predict_model import load_CUB_Bert, load_simBERT

%matplotlib inline

In [3]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
# The tensor-building helpers below create their tensors on this device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
# Load the classifier through the project helper `load_CUB_Bert`
# (weights directory and checkpoint file name are passed as given).
model = load_CUB_Bert("../../models/", 'saved_weights_CUB_CHUNKS_1881.pt')
# The similarity model is currently not loaded in this notebook.
#SIMmodel = load_simBERT()
# Tokenizer of the stock uncased DistilBERT checkpoint.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

Local Success


In [9]:
ref_token_id = tokenizer.pad_token_id # [PAD] id; fills the text positions of the reference (baseline) sequence
sep_token_id = tokenizer.sep_token_id # [SEP] id; appended after the text
cls_token_id = tokenizer.cls_token_id # [CLS] id; prepended before the text

# Modify the prediction output and define a custom forward
def predict(inputs, attentions):
    """Run the global `model` and return element 0 of whatever it returns.

    `inputs` is passed as `input_ids` and `attentions` as `attention_mask`.
    """
    model_output = model(input_ids=inputs, attention_mask=attentions)
    first_element = model_output[0]
    return first_element

def custom_forward(inputs, attentions):
    """Forward function handed to the attribution method.

    Exponentiates the output of `predict` element-wise.
    (presumably `predict` yields log-probabilities, making this a
    probability vector — confirm against the model definition)
    """
    return torch.exp(predict(inputs, attentions))

# Tokenize functions
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    """Tokenize `text` and build the real and the reference id sequences.

    Both sequences have the layout ``[cls] + body + [sep]``; in the reference
    sequence every body position holds `ref_token_id`.

    Returns a 3-tuple: the input ids and the reference ids (each a tensor with
    a leading batch dimension of 1, created on the global `device`) and the
    number of body tokens (special tokens excluded).
    """
    body_ids = tokenizer.encode(text, add_special_tokens=False)
    body_len = len(body_ids)

    real_sequence = [cls_token_id, *body_ids, sep_token_id]
    reference_sequence = [cls_token_id, *([ref_token_id] * body_len), sep_token_id]

    real_tensor = torch.tensor([real_sequence], device=device)
    reference_tensor = torch.tensor([reference_sequence], device=device)
    return real_tensor, reference_tensor, body_len

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Build token-type ids and an all-zero reference with the same shape.

    Positions up to and including `sep_ind` get type 0, all later positions
    get type 1. The sequence length is read from dimension 1 of `input_ids`.
    Both results carry a leading batch dimension of 1 and live on the global
    `device`.
    """
    n_positions = input_ids.size(1)
    type_row = []
    for position in range(n_positions):
        type_row.append(1 if position > sep_ind else 0)

    token_type_ids = torch.tensor([type_row], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    """Build position ids (0, 1, 2, ...) and an all-zero reference.

    The length is taken from dimension 1 of `input_ids`; both results are
    expanded to the shape of `input_ids` and created on the global `device`.
    """
    n_positions = input_ids.size(1)

    ascending = torch.arange(n_positions, dtype=torch.long, device=device)
    # A random permutation (`torch.randperm`) would be an alternative reference.
    all_zero = torch.zeros(n_positions, dtype=torch.long, device=device)

    return (
        ascending.unsqueeze(0).expand_as(input_ids),
        all_zero.unsqueeze(0).expand_as(input_ids),
    )
    
def construct_attention_mask(input_ids):
    """Return a mask of ones with the same shape, dtype and device as `input_ids`."""
    mask = torch.full_like(input_ids, 1)
    return mask

# Summarize and vis functions
def summarize_attributions_ig(attributions):
    """Collapse attributions over the last dimension and L2-normalise them.

    Parameters
    ----------
    attributions : torch.Tensor
        Attribution tensor; the last dimension is summed away and a leading
        dimension of size 1 (if present) is squeezed out.

    Returns
    -------
    torch.Tensor
        The summed attributions divided by their L2 norm. When the norm is
        exactly zero (all summed attributions are zero) the zero vector is
        returned unchanged instead of dividing 0 by 0, which would give NaN
        for every token.
    """
    per_token = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(per_token)
    if norm == 0:
        # Nothing to scale: every entry is already zero.
        return per_token
    return per_token / norm

def summarize_attributions_occ(attributions):
    """Sum the attributions along their first dimension (axis 0)."""
    collapsed = attributions.sum(0)
    return collapsed

def token_to_words(attribution, tokens):
    """Merge WordPiece continuation tokens into words and sum their attributions.

    A token starting with ``'##'`` is treated as the continuation of the
    previous word: its text (without the two-character ``'##'`` prefix) is
    appended to that word and its attribution is added to the word's total.
    Every other token, including ``'[CLS]'`` and ``'[SEP]'``, starts a new word.

    Parameters
    ----------
    attribution : iterable of torch.Tensor
        One scalar tensor per token (e.g. a 1-D tensor); each is moved to the
        CPU, detached and converted to NumPy.
    tokens : sequence of str
        Token strings aligned with `attribution`.

    Returns
    -------
    (list, list of str)
        The summed attribution per word and the merged word strings.

    Notes
    -----
    Only the leading ``'##'`` marker is removed; ``str.strip('##')`` would
    also eat any further ``'#'`` characters at either end of the token.
    A continuation token with no preceding word is kept as a word of its own
    (marker included) rather than raising ``IndexError``.
    """
    words = []
    grouped_values = []

    for attribute, token in zip(attribution, tokens):
        value = attribute.cpu().detach().numpy()
        if token.startswith('##') and words:
            words[-1] += token[2:]
            grouped_values[-1].append(value)
        else:
            words.append(token)
            grouped_values.append([value])

    summed = [np.sum(values) for values in grouped_values]
    return summed, words

def colorize(attribution, tokens):
    """Render words as HTML ``<mark>`` elements shaded by their attribution.

    Parameters
    ----------
    attribution : array-like of float
        One value per word; values are mapped onto matplotlib's ``Greens``
        colormap via ``ScalarMappable.to_rgba``.
    tokens : sequence of str
        Words to render, aligned with `attribution`.

    Returns
    -------
    str
        Concatenated HTML, one ``<mark>`` per word. ``[CLS]`` and ``[SEP]``
        always get a white background. Word text is HTML-escaped so that
        characters such as ``<``, ``>`` and ``&`` are shown literally instead
        of being interpreted as markup.
    """
    import html  # local import: only needed here, keeps the block self-contained

    template = """  
    <mark class="entity" style=" background: {}; padding: 0.4em 0.0em; margin: 0.0em; line-height: 2; 
    border-radius: 0.0em; ">{}<span style=" font-size: 0.8em;  font-weight: bold;  line-height: 1; 
    border-radius: 0.0em; text-align-last:center; vertical-align: middle; margin-left: 0rem; "></span></mark>
    """

    rgba_per_word = matplotlib.cm.ScalarMappable(cmap=matplotlib.cm.Greens).to_rgba(attribution)

    pieces = []
    for word, rgba in zip(tokens, rgba_per_word):
        if word.strip() in ('[CLS]', '[SEP]'):
            # Special tokens carry no meaningful attribution; keep them uncoloured.
            color = '#ffffff'
        else:
            color = matplotlib.colors.rgb2hex(rgba[:3])
        # The trailing space separates consecutive words in the rendered output.
        pieces.append(template.format(color, html.escape(word + ' ')))

    return ''.join(pieces)

def explain(word, window=3, stride=2, target=None):
    """Compute word-level occlusion attributions for a piece of text.

    The text is tokenized, an occlusion attribution is computed with the
    global `occ` object (which must be defined before this function is
    called), the result is summed along its first axis and WordPiece tokens
    are merged back into words.

    Parameters
    ----------
    word : str
        The text to explain.
    window : int, optional
        Size of the sliding occlusion window in tokens (default 3). It is
        clamped to the sequence length so that very short inputs do not ask
        for a window larger than the input.
    stride : int, optional
        Step of the sliding window in tokens (default 2); clamped to the
        effective window size.
    target : optional
        Forwarded unchanged as `target` to `occ.attribute`. The default
        ``None`` keeps the previous behaviour.

    Returns
    -------
    list of (str, float-like)
        One ``(word, attribution)`` pair per merged word, special tokens
        included.

    Notes
    -----
    NOTE(review): with ``target=None`` the attributions are summed along the
    first axis by `summarize_attributions_occ`. The recorded example output
    is on the order of 1e-7, which looks like float32 rounding noise —
    presumably that axis is the output/class axis and `custom_forward`'s
    outputs sum to ~1, so the per-class changes cancel. Confirm, and pass an
    explicit `target` if per-class attributions are wanted.
    """
    input_ids, ref_input_ids, _ = construct_input_ref_pair(
        word, ref_token_id, sep_token_id, cls_token_id
    )
    attention_mask = construct_attention_mask(input_ids)
    all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    # Keep the window within the sequence and the stride within the window.
    window = min(window, input_ids.size(1))
    stride = min(stride, window)

    # The attention mask is its own baseline, so only the token ids are
    # actually replaced (by the reference ids) inside the occluded window.
    attributions = occ.attribute(
        inputs=(input_ids, attention_mask),
        sliding_window_shapes=((window,), (window,)),
        strides=(stride, stride),
        baselines=(ref_input_ids, attention_mask),
        target=target,
    )
    token_scores = summarize_attributions_occ(attributions[0])
    word_scores, words = token_to_words(token_scores, all_tokens)

    return list(zip(words, word_scores))

In [10]:
# Occlusion attribution wrapped around `custom_forward`.
# `explain` reads this global, so this cell must run before `explain` is called.
from captum.attr import Occlusion
occ     = Occlusion(custom_forward)

In [11]:
# Single example: word-level occlusion attributions for one short description.
# `data` is a list of (word, attribution) pairs, shown in the next cell.
string = 'Orange legs with long nails.'
data = explain(string)

In [12]:
data

[('[CLS]', 1.7881393e-07),
 ('orange', 1.7881393e-07),
 ('legs', 4.4703484e-08),
 ('with', -5.9604645e-08),
 ('long', -1.1920929e-07),
 ('nails', -2.0861626e-07),
 ('.', -2.3841858e-07),
 ('[SEP]', -2.0861626e-07)]