In [1]:
import numpy as np
import torch
import torch.nn as nn
from IPython.display import display, HTML
from transformers import DistilBertModel, DistilBertTokenizer, logging
import matplotlib
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm as tqdm_notebook
import spacy
from spacy import displacy
import seaborn as sns
import pandas as pd
import numpy as np
import collections
import glob
import pickle
import re
from bs4 import BeautifulSoup
import requests
from sklearn.metrics.pairwise import cosine_similarity
nlp = spacy.load('en_core_web_trf')
logging.set_verbosity_error()

import sys
sys.path.insert(0, '../../src/models/')
sys.path.insert(0, '../../src/features/')

from build_features import similarity_matrix as vector_values
from predict_model import load_PLANT_Bert, load_simBERT

%matplotlib inline

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
model = load_PLANT_Bert("../../models/", 'saved_weights_CUB_PLANTS_7584.pt')
#SIMmodel = load_simBERT()
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

Local Success


In [35]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

# Modify the prediction output and define a custom forward
def predict(inputs, attentions):
    return model(input_ids=inputs, attention_mask=attentions)[0]

def custom_forward(inputs, attentions):
    preds = predict(inputs, attentions)
    return torch.exp(preds)

# Tokenize functions
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

# Summarize and vis functions
def summarize_attributions_ig(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

def summarize_attributions_occ(attributions):
    return attributions.sum(axis=0)

def token_to_words(attribution, tokens):
    
    words = []
    attributes = []

    for attribute, word in zip(attribution, tokens):

        attribute = attribute.cpu().detach().numpy()
        if word == '[CLS]' or word == '[SEP]':
            words.append(word)
            attributes.append([attribute])
        elif not word.startswith('##'):
            words.append(word)
            attributes.append([attribute])
        elif word.startswith('##'):
            words[-1] += word.strip('##')
            attributes[-1] = np.append(attributes[-1], attribute)

    attribution = [np.sum(mean) for mean in attributes]
    return attribution, words

def colorize(attribution, tokens):
    
    template = """  
    <mark class="entity" style=" background: {}; padding: 0.4em 0.0em; margin: 0.0em; line-height: 2; 
    border-radius: 0.0em; ">{}<span style=" font-size: 0.8em;  font-weight: bold;  line-height: 1; 
    border-radius: 0.0em; text-align-last:center; vertical-align: middle; margin-left: 0rem; "></span></mark>
    """

    colored_string = ''
    normalized_and_mapped = matplotlib.cm.ScalarMappable(cmap=matplotlib.cm.Greens).to_rgba(attribution)
    for idx, (word, color) in enumerate(zip(tokens, normalized_and_mapped)):
        
        word = word + ' '
        color = matplotlib.colors.rgb2hex(color[:3])
        if word.strip() == '[CLS]' or word.strip() == '[SEP]': 
            color = '#ffffff'
        #print(color)
        colored_string += template.format(color, word)

    return colored_string

def explain(word):
    
    data = collections.defaultdict(list)
       
    # tokenize
    input_ids, ref_input_ids, sep_id = construct_input_ref_pair(word, ref_token_id, sep_token_id, cls_token_id)
    token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
    attention_mask = construct_attention_mask(input_ids)
    indices = input_ids[0].type(torch.LongTensor)
    all_tokens = tokenizer.convert_ids_to_tokens(indices)
    
    #### Occlusion maps (3, 2)
    attribution_occ2 = occ.attribute(inputs=(input_ids, attention_mask),
                                sliding_window_shapes=(tuple([3,]), tuple([3,])), 
                                strides=(2, 2), 
                                baselines=(ref_input_ids, attention_mask),)
    attribution_occ2_sum = summarize_attributions_occ(attribution_occ2[0])
    attributions_occ2_words, words = token_to_words(attribution_occ2_sum, all_tokens)
    data['Occlusion'] = attributions_occ2_words
    data['Words'] = words
    ####
    
    return data

In [36]:
from captum.attr import Occlusion
occ = Occlusion(custom_forward)

In [37]:
# Single Example
string = 'Green leaves with spikelets.'
data = explain(string)

In [39]:
data

defaultdict(list,
            {'Occlusion': [-3.5762787e-07,
              -3.5762787e-07,
              1.937151e-07,
              7.4505806e-08,
              -5.9604645e-07,
              -2.8312206e-07,
              -3.7252903e-07],
             'Words': ['[CLS]',
              'green',
              'leaves',
              'with',
              'spikelets',
              '.',
              '[SEP]']})

In [41]:
string = colorize(data['Occlusion'], words)
display(HTML(string))

In [19]:
# URL
URL = 'https://en.wikipedia.org/wiki/Glossary_of_botanical_terms'
# Get the page
page = requests.get(URL, timeout=5)
soup = BeautifulSoup(page.content, "lxml", from_encoding="iso-8859-1")   
# Find embedded glossary
glossaries = soup.find_all('dt', {'class': 'glossary'})
parts = [part.text.lower().strip() for part in glossaries]
# Get additional anchors ("also know as...")
glossaries_other = soup.find_all('span', {'class': 'anchor'})
parts_other = [part['id'].lower().strip() for part in glossaries_other]
# Append and drop duplicates
parts = list(set((parts + parts_other)))
# Replace underscore with space
parts = [part.replace('_', ' ') for part in parts]

In [23]:
parts.index('leaf')

1064

In [25]:
PLANTS_dict = pickle.load(open('../../data/description/04_TRAIN_0000000-0007584_PLANTS.pkl', 'rb'))

In [29]:
PLANTS_list = list(PLANTS_dict.keys())

In [32]:
PLANTS_dict['hiteochloa semitonsa'][0:10]

['Rhachilla internodes elongated below proximal fertile floret.',
 'Culms 30-60 cm long, 3-6 -noded.',
 'Ligule a fringe of hairs, 0.5 mm long.',
 'Fertile spikelets pedicelled, 2 in the cluster.',
 'Leaf-blades flat, or conduplicate, 6-11 cm long, 3-4 mm wide.',
 'Caryopsis with adherent pericarp, ellipsoid, 1.5-1.75 mm long, dark brown.',
 'Upper glume lateral veins ribbed.',
 'Upper glume oblong, 1 length of spikelet, membranous, without keels, 5-7 -veined.',
 'Lemma of lower sterile floret similar to upper glume, ovate, 1 length of spikelet, scarious, 5 -veined, sulcate, glabrous, or hispid, acuminate.',
 'Lower glume apex acute, muticous, or mucronate.']

In [43]:
attribution = collections.defaultdict(list)
idx = 0

# Loop over common birds
for plant in tqdm_notebook(PLANTS_list[0:3], desc='Plant'):
    # Get descriptions
    text_list = [data for data in PLANTS_dict[plant]]
    # Get attribution values
    for text in tqdm_notebook(text_list[0:5], desc='Sentences', leave=False):
        d = explain(text)
        d['Sentence'] = len(d['Words']) * [idx]
        d['Plant'] = len(d['Words']) * [plant]
        for key in d.keys():
            attribution[key] += d[key]
        
        idx += 1

Plant:   0%|          | 0/3 [00:00<?, ?it/s]

Sentences:   0%|          | 0/5 [00:00<?, ?it/s]

Sentences:   0%|          | 0/5 [00:00<?, ?it/s]

Sentences:   0%|          | 0/5 [00:00<?, ?it/s]

In [60]:
# Drop into a df
df_attribution = pd.DataFrame.from_dict(attribution)

data_random = []
# Extract highest attributions
for idx in tqdm_notebook(df_attribution['Sentence'].unique()):
#for idx in range(50, 51):
    #doc = nlp(text_list[idx])
    doc = nlp(' '.join(df_attribution[df_attribution['Sentence'] == idx]['Words']))
    if len(doc) <= 3:
        continue
    # Check single
    words = [chunk.root.lemma_.lower() for chunk in doc.noun_chunks] 
    # Check multiple
    words += [chunk.root.text.lower() for chunk in doc.noun_chunks]
    # Drop duplicate
    words = list(set(words))
    #print(words)
    #print(words)
    traits =  set(words) & set(parts)
    #print(traits)
    if traits:
        # Yield the traits
        trait_list = list(traits)
        #print(trait_list)
        
        for trait in trait_list[0:1]:
            #for column in df_attribution.columns[1:-2]:
                #print(column)
            index = df_attribution[df_attribution['Sentence'] == idx]['Occlusion'].sort_values(ascending=False)
            data_random.append((idx, df_attribution.iloc[index.index[0]].Words, trait, df_attribution.iloc[index.index[0]].Plant))

  0%|          | 0/15 [00:00<?, ?it/s]

In [62]:
data_random

[(0, 'internodes', 'floret', 'hiteochloa semitonsa'),
 (2, 'ligule', 'hair', 'hiteochloa semitonsa'),
 (3, 'pedicelled', 'spikelet', 'hiteochloa semitonsa'),
 (4, ',', 'blade', 'hiteochloa semitonsa'),
 (6, ',', 'glumes', 'Iseilema arguta'),
 (7, 'awn', 'column', 'Iseilema arguta'),
 (8, ',', 'spikelet', 'Iseilema arguta'),
 (9, 'elliptic', 'spikelet', 'Iseilema arguta'),
 (10, 'aristida', 'kingdom', 'Aristida pubescens'),
 (12, 'florets', 'floret', 'Aristida pubescens'),
 (13, 'scabrous', 'vein', 'Aristida pubescens'),
 (14, 'principal', 'awn', 'Aristida pubescens')]

In [63]:
df_random = pd.DataFrame(data_random, columns =['Sentence', 'Adjective', 'Part', 'Plant'])

In [67]:
df_attribution[df_attribution['Sentence'] == 13]

Unnamed: 0,Occlusion,Words,Sentence,Plant
215,8.381903e-09,[CLS],13,Aristida pubescens
216,8.381903e-09,upper,13,Aristida pubescens
217,5.587935e-09,glume,13,Aristida pubescens
218,1.229346e-07,primary,13,Aristida pubescens
219,2.328306e-07,vein,13,Aristida pubescens
220,1.016539e-06,scabrous,13,Aristida pubescens
221,1.86963e-07,.,13,Aristida pubescens
222,3.367313e-08,[SEP],13,Aristida pubescens


In [66]:
df_random

Unnamed: 0,Sentence,Adjective,Part,Plant
0,0,internodes,floret,hiteochloa semitonsa
1,2,ligule,hair,hiteochloa semitonsa
2,3,pedicelled,spikelet,hiteochloa semitonsa
3,4,",",blade,hiteochloa semitonsa
4,6,",",glumes,Iseilema arguta
5,7,awn,column,Iseilema arguta
6,8,",",spikelet,Iseilema arguta
7,9,elliptic,spikelet,Iseilema arguta
8,10,aristida,kingdom,Aristida pubescens
9,12,florets,floret,Aristida pubescens
