In [2]:
# replace <PATH-TO-SAVED-MODEL> with the real path of the saved model
model_path = 'Hate-speech-CNERG/bert-base-uncased-hatexplain'

import torch
import torch.nn as nn

from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification, AutoConfig

from captum.attr import IntegratedGradients
from captum.attr import InterpretableEmbeddingBase, TokenReferenceBase
from captum.attr import visualization
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)


# We need to split forward pass into two part:
# 1) embeddings computation
# 2) classification

def compute_bert_outputs(model_bert, embedding_output, attention_mask=None, head_mask=None):
    if attention_mask is None:
        attention_mask = torch.ones(embedding_output.shape[0], embedding_output.shape[1]).to(embedding_output)

    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

    extended_attention_mask = extended_attention_mask.to(
        dtype=next(model_bert.parameters()).dtype)  # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

    if head_mask is not None:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(model_bert.config.num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
        head_mask = head_mask.to(
            dtype=next(model_bert.parameters()).dtype)  # switch to fload if need + fp16 compatibility
    else:
        head_mask = [None] * model_bert.config.num_hidden_layers

    encoder_outputs = model_bert.encoder(embedding_output,
                                         extended_attention_mask,
                                         head_mask=head_mask)
    sequence_output = encoder_outputs[0]
    pooled_output = model_bert.pooler(sequence_output)
    outputs = (sequence_output, pooled_output,) + encoder_outputs[
                                                  1:]  # add hidden_states and attentions if they are here
    return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)


class BertModelWrapper(nn.Module):

    def __init__(self, model):
        super(BertModelWrapper, self).__init__()
        self.model = model

    def forward(self, embeddings):
        outputs = compute_bert_outputs(self.model.bert, embeddings)
        pooled_output = outputs[1]
        logits = self.model.classifier(pooled_output)
        return logits


bert_model_wrapper = BertModelWrapper(model)
ig = IntegratedGradients(bert_model_wrapper)

# accumalate couple samples in this array for visualization purposes
vis_data_records_ig = []


def interpret_sentence(model_wrapper, sentence, label=1):
    model_wrapper.eval()
    model_wrapper.zero_grad()

    input_ids = torch.tensor([tokenizer.encode(sentence, add_special_tokens=True)])
    attn_mask = torch.ones(input_ids.shape[0], input_ids.shape[1]).to(input_ids)
    number_of_class = len(list(model.config.id2label.values()))

    batch_size = input_ids.shape[0]
    input_ids = input_ids.unsqueeze(1).expand(-1, number_of_class, -1).reshape(-1, input_ids.shape[1])
    attn_mask = attn_mask.unsqueeze(1).expand(-1, number_of_class, -1).reshape(-1, input_ids.shape[1])
    all_classes = torch.arange(number_of_class).unsqueeze(0).expand(batch_size, -1).flatten()
    #print(input_ids.shape, attn_mask.shape, all_classes.shape)
    input_embedding = model_wrapper.model.bert.embeddings(input_ids)

    baseline = torch.full(input_ids.shape, tokenizer.pad_token_id).long()
    baseline[:, 0] = tokenizer.cls_token_id
    sep_token_locs = torch.nonzero(input_ids == tokenizer.sep_token_id)
    baseline[sep_token_locs[:, 0], sep_token_locs[:, 1]] = tokenizer.sep_token_id
    baseline_embeds = model_wrapper.model.bert.embeddings(baseline)
  
    # compute attributions and approximation delta using integrated gradients
    attributions_ig = ig.attribute(inputs=input_embedding.requires_grad_(),
                                   target=all_classes,
                                   baselines=baseline_embeds,
                                   n_steps=500,
                                   )
    
#     pred = model_wrapper(input_embedding).item()
#     pred_ind = round(pred)
    attributions_ig = torch.sum(attributions_ig, dim=-1)
    attributions_ig = attributions_ig * attn_mask
    print(attributions_ig)

    #print('pred: ', 0, '(', '%.2f'% pred, ')', ', delta: ', abs(delta))

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].numpy().tolist())
    add_attributions_to_visualizer(attributions_ig, tokens, 0, 0, label, 0, vis_data_records_ig)


def add_attributions_to_visualizer(attributions, tokens, pred, pred_ind, label, delta, vis_data_records):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.detach().numpy()
    
    print(attributions)
    
    # storing couple samples in an array for visualization purposes
    vis_data_records.append(visualization.VisualizationDataRecord(
        attributions,
        pred,
        list(model.config.id2label.values())[pred_ind],
        list(model.config.id2label.values())[label],
        "label",
        attributions.sum(),
        tokens[:len(attributions)],
        delta))


interpret_sentence(bert_model_wrapper, sentence="all muslims are terrorists and need to be deported from this country", label=0)
visualization.visualize_text(vis_data_records_ig)
print(vis_data_records_ig[0].word_attributions)


tensor([[ 0.0000, -0.0034,  0.5786, -0.0594,  0.3033,  0.0268,  0.0473,  0.0849,
          0.0083,  0.0706,  0.0644, -0.0404,  0.0910,  0.0000],
        [ 0.0000,  0.1954, -0.5991, -0.0430, -0.3898, -0.1109, -0.0233, -0.0698,
         -0.0522, -0.1291, -0.0948,  0.0439, -0.1306,  0.0000],
        [ 0.0000, -0.1245, -0.0532,  0.1059,  0.0812,  0.0143,  0.0236,  0.0288,
          0.0193,  0.0526,  0.0358,  0.0012,  0.0088,  0.0000]],
       dtype=torch.float64, grad_fn=<MulBackward0>)
[ 0.6375017  -0.76321051  0.10536265]


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
hate speech,hate speech (0.00),label,-0.02,[CLS] all muslims
,,,,


[ 0.6375017  -0.76321051  0.10536265]


In [4]:
m = nn.Softmax(dim=0)

In [32]:
a = torch.tensor(vis_data_records_ig[0].word_attributions[1:-1])

In [33]:
m(a)

tensor([0.0977, 0.0798, 0.0757, 0.0931, 0.0895, 0.0865, 0.0875, 0.0484, 0.0845,
        0.0858, 0.0837, 0.0878], dtype=torch.float64)