In [1]:
import torch
from torch import nn
from transformers import DistilBertModel
import numpy as np
from tqdm import tqdm
import pickle
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import DistilBertConfig, DistilBertTokenizer
from IPython.display import display, HTML

import sys
sys.path.insert(0, '../src/models/')
import predict_model

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [2]:
# Tokenizer for the `distilbert-base-uncased` checkpoint; the encoder below is
# loaded from the same checkpoint name, so vocabulary and model match.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
#model = predict_model.loadBERT("../models/", 'saved_weights_inf_FIXED_boot_beta80.pt')

In [3]:
# Pretrained DistilBERT encoder; it is wrapped by the BERT classifier defined
# below and later provides the embedding layer that the interpreter hooks into.
bert = DistilBertModel.from_pretrained('distilbert-base-uncased')

class BERT(nn.Module):
    """Encoder plus a two-layer classification head returning log-probabilities.

    The head takes the hidden state of the first token of the encoder output,
    applies Linear -> ReLU -> Dropout -> Linear and finishes with LogSoftmax
    over the class dimension.

    Parameters
    ----------
    bert : nn.Module
        Encoder whose output exposes ``last_hidden_state``.
    hidden_size : int, default 768
        Width of the encoder hidden state fed into the first linear layer.
    fc_size : int, default 512
        Width of the intermediate dense layer.
    num_classes : int, default 170
        Number of output classes.
    dropout : float, default 0.3
        Dropout probability applied between the two linear layers.

    The defaults reproduce the sizes that used to be hard-coded, and the
    sub-module names (``bert``, ``fc1``, ``fc2``, ...) are unchanged, so
    checkpoints saved with the previous definition still load.
    """

    def __init__(self, bert, hidden_size=768, fc_size=512, num_classes=170, dropout=0.3):
        super().__init__()

        # Encoder
        self.bert = bert
        # Classification head (registration order kept as before)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(hidden_size, fc_size)
        self.fc2 = nn.Linear(fc_size, num_classes)
        # Log-probabilities over the class dimension
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, **kwargs):
        """Run the encoder on ``kwargs`` and classify the first token.

        ``kwargs`` are forwarded untouched to the encoder (the notebook passes
        ``input_ids`` and ``attention_mask``).

        Returns
        -------
        torch.Tensor
            Log-probabilities with one row per input sequence and one column
            per class.
        """
        encoder_output = self.bert(**kwargs)
        hidden_state = encoder_output.last_hidden_state
        # Only the first position is used as the sequence representation.
        pooler = hidden_state[:, 0]

        x = self.fc1(pooler)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return self.softmax(x)
    
# Build the classifier around the pretrained encoder
model = BERT(bert)

# Load trained weights: try the Colab (Google Drive) checkpoint first, then
# fall back to the local checkpoint mapped onto the CPU.
# `except Exception` (instead of a bare `except:`) lets KeyboardInterrupt /
# SystemExit propagate, and the failure reason is reported instead of hidden.
try:
    try:
        model_save_name = 'saved_weights_NLP_test.pt'
        path = F"/content/gdrive/My Drive/{model_save_name}"
        model.load_state_dict(torch.load(path))
        print('Google Success')

    except Exception as drive_error:
        # A missing Drive file is the normal case outside Colab; anything else
        # (e.g. a checkpoint that does not fit the model) is worth surfacing.
        if not isinstance(drive_error, FileNotFoundError):
            print(f'Google Drive checkpoint could not be loaded: {drive_error!r}')
        model_save_name = 'saved_weights_NLP_subset.pt'
        path = "../models/" + model_save_name
        model.load_state_dict(torch.load(path,
                                         map_location=torch.device('cpu')))
        print('Local Success')
except Exception as local_error:
    print('No pretrained model found.')
    print(f'Reason: {local_error!r}')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Local Success


In [4]:
# Loss handed to the interpreter; it turns the model output into the scalar
# that is backpropagated to obtain the embedding gradients.
criterion = nn.CrossEntropyLoss()
# Number of sentences per DataLoader batch.
batch_size = 1

In [5]:
class BertVizDataset(Dataset):
    """Dataset that tokenizes raw sentences on access for the interpreter.

    Parameters
    ----------
    sentence_list : sequence of str
        Sentences to tokenize.
    tokenizer : callable
        Called as ``tokenizer(text, max_length=..., padding='max_length',
        truncation=True)`` and expected to return a mapping with
        ``'input_ids'`` and ``'attention_mask'``.
    max_length : int, default 512
        Every item is padded and truncated to this many tokens.
    """

    def __init__(self, sentence_list, tokenizer, max_length=512):
        self.sentence_list = sentence_list
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, idx):
        """Return ``input_ids`` and ``attention_mask`` tensors for sentence ``idx``."""
        # truncation=True keeps every item at exactly `max_length` tokens;
        # without it an over-long sentence produces a longer tensor that can
        # neither be stacked into a batch nor fed to the encoder.
        tokenized_text = self.tokenizer(self.sentence_list[idx],
                                        max_length=self.max_length,
                                        padding='max_length',
                                        truncation=True)

        return {'input_ids': torch.tensor(tokenized_text['input_ids']),
                'attention_mask': torch.tensor(tokenized_text['attention_mask'])}

    def __len__(self):
        """Number of sentences in the dataset."""
        return len(self.sentence_list)

In [6]:
import torch
from torch.nn.functional import softmax

import matplotlib
import matplotlib.pyplot as plt


class SaliencyInterpreter:
    """Base class for gradient-based token saliency on a text classifier.

    It knows how to locate the embedding layer, capture the gradient flowing
    out of it during a backward pass, and turn per-batch results into
    per-sentence dictionaries that can be rendered as HTML. Subclasses fill
    ``self.batch_output`` with ``[input_ids, outputs, grads]`` before calling
    ``update_output``.
    """

    def __init__(self,
                 model,
                 criterion,
                 tokenizer,
                 show_progress=True,
                 **kwargs):

        """
        Model         : PyTorch model.
        Criterion     : Loss function.
        Tokenizer     : Tokenizer used for the model.
        Show_progress : Whether to show a TQDM progress bar.
        Kwargs        : Pass ``encoder="<attribute name>"`` when the model has
                        no ``get_input_embeddings`` method; the embeddings are
                        then read from ``getattr(model, encoder).embeddings``.
        """

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        # Evaluation mode for the whole interpretation run.
        self.model.eval()
        self.criterion = criterion
        self.tokenizer = tokenizer
        self.show_progress = show_progress
        self.kwargs = kwargs
        # To save outputs in saliency_interpret
        self.batch_output = None

    def _get_gradients(self, batch):
        """Return the gradient captured at the embedding layer for ``batch``.

        All parameters temporarily get ``requires_grad = True``; one forward
        and backward pass is run with a backward hook on the embedding layer.
        The hook is removed and the original ``requires_grad`` flags are
        restored even when the forward or backward pass raises.
        """
        original_requires_grad = {}
        for param_name, param in self.model.named_parameters():
            original_requires_grad[param_name] = param.requires_grad
            param.requires_grad = True

        embedding_gradients = []
        hooks = []
        try:
            hooks = self._register_embedding_gradient_hooks(embedding_gradients)

            loss = self.forward_step(batch)

            self.model.zero_grad()
            loss.backward()
        finally:
            # Always undo the temporary changes so a failed pass cannot leave
            # a stray hook or altered parameters behind.
            for hook in hooks:
                hook.remove()
            for param_name, param in self.model.named_parameters():
                param.requires_grad = original_requires_grad[param_name]

        return embedding_gradients[0]

    def _register_embedding_gradient_hooks(self, embedding_gradients):

        """
        Registers a backward hook on the embedding layer that appends the
        gradient of the layer output (``grad_out[0]``) to
        ``embedding_gradients``. Returns the list of hook handles so that the
        caller can remove them.
        """
        # NOTE(review): `register_backward_hook` is deprecated in recent
        # PyTorch in favour of `register_full_backward_hook`. It is kept here
        # because IntegratedGradient scales the embedding output in place in
        # a forward hook -- confirm compatibility before migrating.

        def hook_layers(module, grad_in, grad_out):
            embedding_gradients.append(grad_out[0])

        backward_hooks = []
        embedding_layer = self.get_embeddings_layer()
        backward_hooks.append(embedding_layer.register_backward_hook(hook_layers))
        return backward_hooks

    def get_embeddings_layer(self):
        """Return the module whose output is treated as the input embedding.

        Uses ``model.get_input_embeddings()`` when available, otherwise
        ``getattr(model, kwargs['encoder']).embeddings``.
        """
        if hasattr(self.model, "get_input_embeddings"):
            embedding_layer = self.model.get_input_embeddings()
        else:
            encoder_attribute = self.kwargs.get("encoder")
            assert encoder_attribute, "Your model doesn't have 'get_input_embeddings' method, thus you " \
                "have provide 'encoder' key argument while initializing SaliencyInterpreter object"
            embedding_layer = getattr(self.model, encoder_attribute).embeddings
        return embedding_layer

    def colorize(self, instance, skip_special_tokens=False):
        """Render one instance from ``update_output`` as an HTML string.

        The first and last token (and their scores) are dropped, the remaining
        tokens are shaded with the Blues colormap according to their score,
        and the predicted label and its probability (shaded with Greens) are
        appended.

        instance            : dict with 'tokens', 'grad', 'label' and 'prob'.
        skip_special_tokens : also drop tokens listed in ``special_tokens``.
        """
        import html  # stdlib; local import keeps the cell self-contained

        special_tokens = self.special_tokens

        word_cmap = matplotlib.cm.Blues
        prob_cmap = matplotlib.cm.Greens
        # The inner border-radius used to read "0.Oem" (letter O), which is
        # not a valid CSS length.
        template = """<mark class="entity" style="background: {}; padding: 0.0em 0.0em; margin: 0 0.2em; line-height: 1; border-radius: 0.0em;">\n    {}    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.0em; vertical-align: middle; margin-left: 0.0rem"></span>\n</mark>"""

        colored_string = ''
        # Use a matplotlib normalizer in order to make clearer the difference between values
        normalized_and_mapped = matplotlib.cm.ScalarMappable(cmap=word_cmap).to_rgba(instance['grad'][1:-1])
        for word, color in zip(instance['tokens'][1:-1], normalized_and_mapped):
            if word in special_tokens and skip_special_tokens:
                continue
            # Handle wordpieces: continuation pieces are glued to the previous one
            word = word.replace("##", "") if "##" in word else ' ' + word
            # Tokens come from user text; escape them so characters such as
            # '<' or '&' cannot break the generated markup.
            word = html.escape(word)
            color = matplotlib.colors.rgb2hex(color[:3])
            colored_string += template.format(color, word)
        colored_string += template.format(0, "    Label: {} |".format(instance['label']))
        prob = instance['prob']
        color = matplotlib.colors.rgb2hex(prob_cmap(prob)[:3])
        colored_string += template.format(color, "{:.2f}%".format(instance['prob']*100)) + '|'
        return colored_string

    @property
    def special_tokens(self):

        """
        Some tokenizers don't have 'eos_token' and 'bos_token' attributes set.
        In that case they are derived from the special tokens the tokenizer
        adds around an empty input and stored on the tokenizer.
        Returns the tuple (eos_token, bos_token).
        """

        if self.tokenizer.bos_token is None or self.tokenizer.eos_token is None:
            special_tokens = self.tokenizer.build_inputs_with_special_tokens([])
            special_tokens_ids = self.tokenizer.convert_ids_to_tokens(special_tokens)
            self.tokenizer.bos_token, self.tokenizer.eos_token = special_tokens_ids

        special_tokens = self.tokenizer.eos_token, self.tokenizer.bos_token
        return special_tokens

    def forward_step(self, batch):

        """
        Runs the model on the batch and returns the loss of the model output
        against its own predicted label (argmax). This method can be
        overridden to match another model's forward function.
        Batch  : batch returned by dataloader, with 'input_ids' and
                 'attention_mask' tensors.
        Side effect: stores [input_ids, outputs] in ``self.batch_output``.
        """

        input_ids = batch.get('input_ids').to(self.device)
        attention_mask = batch.get("attention_mask").to(self.device)
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        # The explanation targets the class the model itself predicts.
        label = torch.argmax(outputs, dim=1)
        batch_losses = self.criterion(outputs, label)
        loss = torch.mean(batch_losses)

        self.batch_output = [input_ids, outputs]

        return loss

    def update_output(self):

        """
        Converts ``self.batch_output`` ([input_ids, outputs, grads]) into a
        list with one dict per sentence:
            'tokens' : tokens with padding removed
            'grad'   : per-token score (absolute summed gradient, L1-normalized
                       per sentence), cut to the number of kept tokens
            'label'  : predicted class index
            'prob'   : probability of the predicted class
        """

        input_ids, outputs, grads = self.batch_output

        probs = softmax(outputs, dim=-1)
        probs, labels = torch.max(probs, dim=-1)

        tokens = [
            self.tokenizer.convert_ids_to_tokens(input_ids_)
            for input_ids_ in input_ids
        ]

        # Collapse the embedding dimension to one score per token.
        embedding_grads = grads.sum(dim=2)
        # L1 norm for each sequence
        norms = torch.norm(embedding_grads, dim=1, p=1)
        # A sequence whose scores are all zero has norm 0; divide by 1 there
        # so the result is zeros instead of NaN.
        safe_norms = torch.where(norms == 0, torch.ones_like(norms), norms)
        embedding_grads = torch.abs(embedding_grads) / safe_norms.unsqueeze(1)

        batch_output = []

        iterator = zip(tokens, probs, embedding_grads, labels)

        for example_tokens, example_prob, example_grad, example_label in iterator:
            example_dict = dict()
            # Remove the batch padding
            example_tokens = [t for t in example_tokens if t != self.tokenizer.pad_token]
            example_dict['tokens'] = example_tokens
            # assumes padding sits at the end of the sequence so that the
            # first len(example_tokens) scores line up -- TODO confirm for
            # tokenizers that pad on the left
            example_dict['grad'] = example_grad.cpu().tolist()[:len(example_tokens)]
            example_dict['label'] = example_label.item()
            example_dict['prob'] = example_prob.item()
            batch_output.append(example_dict)
        return batch_output
    
class IntegratedGradient(SaliencyInterpreter):

    """
    Interprets the prediction using Integrated Gradients (https://arxiv.org/abs/1703.01365)
    Registered as a `SaliencyInterpreter` with name "integrated-gradient".

    The embedding output is scaled by ``num_steps`` evenly spaced factors in
    [0, 1), the gradients at the embedding layer are averaged over those steps
    and multiplied element-wise by the unscaled embedding output.
    """

    def __init__(self,
                 model,
                 criterion,
                 tokenizer,
                 num_steps=20,
                 show_progress=True,
                 **kwargs):
        """
        num_steps : number of scaling steps in the approximation (>= 1).
        Other arguments are passed on to ``SaliencyInterpreter``.

        Raises ValueError when ``num_steps`` is smaller than 1, since no
        gradient would be accumulated in that case.
        """
        if num_steps < 1:
            raise ValueError("num_steps must be at least 1, got {}".format(num_steps))
        super().__init__(model, criterion, tokenizer, show_progress, **kwargs)
        # Hyperparameters
        self.num_steps = num_steps

    def saliency_interpret(self, test_dataloader):
        """Interpret every batch of ``test_dataloader``.

        Returns a flat list with one dict per sentence, as produced by
        ``update_output``.
        """

        instances_with_grads = []
        iterator = tqdm(test_dataloader) if self.show_progress else test_dataloader

        for batch in iterator:

            # Batch outputs (input ids, model outputs, gradients) are produced
            # in different methods, so they are collected on the instance.
            self.batch_output = []
            self._integrate_gradients(batch)
            batch_output = self.update_output()
            instances_with_grads.extend(batch_output)

        return instances_with_grads

    def _register_forward_hook(self, alpha, embeddings_list):

        """
        Register a forward hook on the embedding layer which scales the embeddings by alpha. Used
        for one term in the Integrated Gradients sum.
        We store the embedding output into the embeddings_list when alpha is zero.  This is used
        later to element-wise multiply the input by the averaged gradients.
        Returns the hook handle; the caller is responsible for removing it.
        """

        def forward_hook(module, inputs, output):
            # Save the unscaled output for later use. Only done for the first
            # step (alpha == 0), before the in-place scaling below.
            if alpha == 0:
                embeddings_list.append(output.squeeze(0).clone().detach())

            # Scale the embedding by alpha (in place)
            output.mul_(alpha)

        embedding_layer = self.get_embeddings_layer()
        handle = embedding_layer.register_forward_hook(forward_hook)
        return handle

    def _integrate_gradients(self, batch):
        """Compute the integrated gradients for ``batch``.

        Appends the result to ``self.batch_output`` (which already holds
        ``[input_ids, outputs]`` from the last forward step).
        """

        ig_grads = None

        # List of Embedding inputs
        embeddings_list = []

        # Exclude the endpoint because we do a left point integral approximation
        for alpha in np.linspace(0, 1.0, num=self.num_steps, endpoint=False):
            # Hook for modifying embedding value
            handle = self._register_forward_hook(alpha, embeddings_list)
            try:
                grads = self._get_gradients(batch)
            finally:
                # Remove the scaling hook even if the pass fails; otherwise
                # the model would keep scaling its embeddings on every
                # subsequent forward call.
                handle.remove()

            # Running sum of gradients
            if ig_grads is None:
                ig_grads = grads
            else:
                ig_grads = ig_grads + grads

        # Average of each gradient term
        ig_grads /= self.num_steps

        # Gradients come back in the reverse order that they were sent into the network
        embeddings_list.reverse()

        # Element-wise multiply average gradient by the input.
        ig_grads *= embeddings_list[0]

        self.batch_output.append(ig_grads)

In [59]:
# Sentences to explain; each one becomes one item of the dataset.
data_list = ['The plant has 7 antheriferous stamens.']

# Tokenize on access and serve the sentences in their original order.
dataset = BertVizDataset(data_list, tokenizer)
test_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [60]:
# Build the Integrated Gradients interpreter with its default number of steps.
# `encoder="bert"` names the attribute of `model` that holds the encoder; the
# interpreter hooks `model.bert.embeddings` because the BERT wrapper has no
# `get_input_embeddings` method.
integrated_grad = IntegratedGradient(
    model, 
    criterion, 
    tokenizer, 
    show_progress=True,
    encoder="bert"
)
# One forward/backward pass per step and per batch; returns one dict per
# sentence with 'tokens', 'grad', 'label' and 'prob'.
instances = integrated_grad.saliency_interpret(test_dataloader)

100%|█████████████████████████████████████████████| 1/1 [00:40<00:00, 40.12s/it]


In [61]:
#colored_string = integrated_grad.colorize(instances[0])
#display(HTML(colored_string))

In [62]:
def colorize(instance):
    """Render one interpreted sentence as HTML with saliency-shaded tokens.

    Every token of ``instance['tokens']`` is wrapped in a ``<mark>`` element
    whose background is taken from the Greens colormap according to the
    matching entry of ``instance['grad']``. Wordpiece markers ("##") are shown
    as "-". The predicted label and its probability are printed, not embedded
    in the HTML.

    Parameters
    ----------
    instance : dict
        Needs the keys 'tokens', 'grad', 'label' and 'prob'.

    Returns
    -------
    str
        HTML markup, suitable for ``display(HTML(...))``.
    """
    import html  # stdlib; local import keeps the cell self-contained

    word_cmap = matplotlib.cm.Greens
    template = """  
    <mark class="entity" style="
    background: {}; 
    padding: 0.4em 0.0em; 
    margin: 0.0em; 
    line-height: 2; 
    border-radius: 0.0em;
    ">{}<span style="
    font-size: 0.8em; 
    font-weight: bold; 
    line-height: 1; 
    border-radius: 0.0em;   
    text-align-last:center;
    vertical-align: middle;
    margin-left: 0rem;
    "></span></mark>
    """

    # Use a matplotlib normalizer in order to make clearer the difference between values
    token_colors = matplotlib.cm.ScalarMappable(cmap=word_cmap).to_rgba(instance['grad'])

    pieces = []
    for word, color in zip(instance['tokens'], token_colors):
        # Tokens come from user text; escape them so characters such as '<'
        # or '&' cannot break the generated markup.
        word = html.escape(word.replace("##", "-")) + ' '
        pieces.append(template.format(matplotlib.colors.rgb2hex(color[:3]), word))

    print(f"Label:{instance['label']} -- {instance['prob']*100:.2f}%")
    return ''.join(pieces)

In [63]:
# Render the first interpreted sentence; `colorize` prints the predicted label
# and probability and returns the HTML that is displayed here.
string = colorize(instances[0])
display(HTML(string))

Label:35 -- 71.81%


In [51]:
# NOTE(review): `data_list` above holds a single sentence, so `instances[1]`
# only exists if `instances` came from an earlier run with more sentences
# (this cell's execution count, 51, predates the interpreter run at 60).
# TODO confirm / re-run with the intended sentences.
string = colorize(instances[1])
display(HTML(string))

Label:92 -- 66.47%


In [52]:
# NOTE(review): `data_list` above holds a single sentence, so `instances[2]`
# only exists if `instances` came from an earlier run with more sentences
# (this cell's execution count, 52, predates the interpreter run at 60).
# TODO confirm / re-run with the intended sentences.
string = colorize(instances[2])
display(HTML(string))

Label:16 -- 94.26%


In [53]:
# NOTE(review): `data_list` above holds a single sentence, so `instances[3]`
# only exists if `instances` came from an earlier run with more sentences
# (this cell's execution count, 53, predates the interpreter run at 60).
# TODO confirm / re-run with the intended sentences.
string = colorize(instances[3])
display(HTML(string))

Label:36 -- 69.90%


In [57]:
# Tokenize the example sentence directly (no padding arguments) to inspect
# the raw ids.
test = tokenizer('The plant has 7 antheriferous stamens.')

In [58]:
# Show the tokenizer output: `input_ids` and `attention_mask`.
test

{'input_ids': [101, 1996, 3269, 2038, 1021, 14405, 5886, 23930, 2358, 27245, 2015, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [66]:
# Wordpiece tokens of the first interpreted sentence (padding already removed
# by `update_output`), joined for reading.
' '.join(instances[0]['tokens'])

'[CLS] the plant has 7 ant ##her ##iferous st ##amen ##s . [SEP]'