Recently, I encountered this great library, which can visualize the decision making process of BERT. It is very like the Grad-CAM approach in computer vision, which is a powerful debugging tool.

This notebook is modified from the official COLAB notebook: https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb#scrollTo=VakYjrkC6C3S

It seems there are some problems with the original notebook, while this should work fine.

In [1]:
!git clone https://github.com/hila-chefer/Transformer-Explainability.git

import os
os.chdir(f'./Transformer-Explainability')

# !pip install -r requirements.txt

Cloning into 'Transformer-Explainability'...
remote: Enumerating objects: 386, done.[K
remote: Counting objects: 100% (5/5), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 386 (delta 3), reused 2 (delta 2), pack-reused 381 (from 2)[K
Receiving objects: 100% (386/386), 3.85 MiB | 10.83 MiB/s, done.
Resolving deltas: 100% (194/194), done.


In [2]:
!pip install captum

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting captum
  Downloading captum-0.7.0-py3-none-any.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 10.4 MB/s eta 0:00:01
Installing collected packages: captum
Successfully installed captum-0.7.0
You should consider upgrading via the '/usr/bin/python -m pip install --upgrade pip' command.[0m


In [3]:
# !pip install transformers==3.5.1

In [4]:
from transformers import BertTokenizer
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
from transformers import BertTokenizer
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from transformers import AutoTokenizer

from captum.attr import (
    visualization
)
import torch

In [5]:
from transformers.modeling_outputs import SequenceClassifierOutput

In [6]:
from transformers import BertPreTrainedModel
from BERT_explainability.modules.layers_ours import *
from BERT_explainability.modules.BERT.BERT import BertModel
from torch.nn import CrossEntropyLoss, MSELoss

class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.classifier = Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def relprop(self, cam=None, **kwargs):
        cam = self.classifier.relprop(cam, **kwargs)
        cam = self.dropout.relprop(cam, **kwargs)
        cam = self.bert.relprop(cam, **kwargs)
        # print("conservation: ", cam.sum())
        return cam

In [23]:
model = BertForSequenceClassification.from_pretrained('bert-base-uncased').cuda()
model.eval()
tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
# initialize the explanations generator
explanations = Generator(model)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.embeddings.position_ids', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
classifications = ["1", "2", "3", "4", "5"]

In [25]:
import pandas as pd
data = pd.read_pickle('/data2/cehou/LLM_safety/img_text_data/finished/dataset_30_male_HongKong_traffic accident_GSV_2000_all_4712.pkl')
data

Unnamed: 0,GSV_idx,panoid,age,gender,location,event,text_description_all,text_description_age,text_description_gender,text_description_location
0,0,pLuKvPlHrMPC7HbiwdWJsQ,30,male,HongKong,traffic accident,"[[INST] \n[""Please design a car accident-focu...","For a 30-year-old individual, several factors ...",When evaluating the safety perception of speci...,The built environment in Hong Kong is characte...
1,1,9N3G-AjC5a_k_cjbYWTuXQ,30,male,HongKong,traffic accident,"[[INST] \n[""Please design a car accident-focu...","For a 30-year-old, several factors in the imag...",When evaluating the safety perception of speci...,Hong Kong is a densely populated urban area wi...
2,2,vBpRyVLbY1WpqvrM11TCaw,30,male,HongKong,traffic accident,"[[INST] \n[""Please design a car accident-focu...","For a 30-year-old, several factors in the imag...",When evaluating the safety perception of speci...,The built environment in Hong Kong is characte...
3,3,ksvoYZaPS1tOj9flVw2rPg,30,male,HongKong,traffic accident,"[[INST] \n[""Please design a car accident-focu...","For a 30-year-old individual, several factors ...",When evaluating the safety perception of speci...,Hong Kong is a densely populated urban area wi...
4,4,WkHeZe_Exxcwry_UbS7bZw,30,male,HongKong,traffic accident,"[[INST] \n[""Please design a car accident-focu...","For a 30-year-old, several factors in the imag...",When evaluating the safety perception of speci...,Hong Kong is a densely populated urban area wi...
...,...,...,...,...,...,...,...,...,...,...
4707,4984,B6jXBp8HViBTNANUascVbg,30,male,HongKong,traffic accident,"[[INST] \n[""Please design a car accident-focu...","For 30, the factors in the image that could im...",When evaluating the safety perception of speci...,When discussing the safety perception in Hong ...
4708,4985,E1hFilX1FmwMioq1o-hgKw,30,male,HongKong,traffic accident,"[[INST] \n[""Please design a car accident-focu...","For a 30-year-old individual, several factors ...",When evaluating the safety perception of speci...,Hong Kong is a densely populated urban area wi...
4709,4986,dqEW3eJiJmJJgIIuFD2PjQ,30,male,HongKong,traffic accident,"[[INST] \n[""Please design a car accident-focu...","For a 30-year-old, several factors in the imag...",When evaluating the safety perception of speci...,The built environment in Hong Kong is characte...
4710,4987,UzH6unzGLtuSsTtOYYvfXA,30,male,HongKong,traffic accident,"[[INST] \n[""Please design a car accident-focu...","For a 30-year-old, several factors in the imag...",When evaluating the safety perception of speci...,The built environment in Hong Kong is characte...


In [26]:
data.iloc[0]['text_description_gender'].split('\n')[2]

'1. **Traffic Signs and Signals**: Males may be more attentive to traffic signs and signals, as they are more likely to be drivers. In the image, there are traffic lights and signs that provide guidance to drivers and pedestrians.'

In [27]:
# encode a sentence
text_batch = [data.iloc[0]['text_description_gender'].split('\n')[2]]
encoding = tokenizer(text_batch, return_tensors='pt')
input_ids = encoding['input_ids'].cuda()
attention_mask = encoding['attention_mask'].cuda()

# 4 is most positive
true_class = '4'

# generate an explanation for the input
expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]
# normalize scores
expl = (expl - expl.min()) / (expl.max() - expl.min())

# get the model classification
output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
classification = output.argmax(dim=-1).item()
# get class name
class_name = classifications[classification]
# if the classification is negative, higher explanation scores are more negative
# flip for visualization
# if class_name == "NEGATIVE":
#   expl *= (-1)

tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
print([(tokens[i], expl[i].item()) for i in range(len(tokens))])
vis_data_records = [visualization.VisualizationDataRecord(
                                expl,
                                output[0][classification],
                                classification,
                                true_class,
                                true_class,
                                1,       
                                tokens,
                                1)]
visualization.visualize_text(vis_data_records)



/opt/pytorch/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1151: indexSelectLargeIndex: block: [313,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/pytorch/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1151: indexSelectLargeIndex: block: [313,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/pytorch/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1151: indexSelectLargeIndex: block: [313,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/pytorch/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1151: indexSelectLargeIndex: block: [313,0,0], thread: [99,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/pytorch/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1151: indexSelectLargeIndex: block: [313,0,0], thread: [100,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/pytorch/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1151: indexSelectLargeIndex: block: [313,0,0], thread: [101,0,0] Assertion `srcIndex 

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [10]:
# encode a sentence
text_batch = ["This movie was the worst movie I have ever seen! some scenes were ridiculous."]
encoding = tokenizer(text_batch, return_tensors='pt')
input_ids = encoding['input_ids'].cuda()
attention_mask = encoding['attention_mask'].cuda()

# 0 is least positive
true_class = 0

# generate an explanation for the input
expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]
# normalize scores
expl = (expl - expl.min()) / (expl.max() - expl.min())

# get the model classification
output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
classification = output.argmax(dim=-1).item()
# get class name
class_name = classifications[classification]
# if the classification is negative, higher explanation scores are more negative
# flip for visualization
# if class_name == "NEGATIVE":
#   expl *= (-1)

tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
print([(tokens[i], expl[i].item()) for i in range(len(tokens))])
vis_data_records = [visualization.VisualizationDataRecord(
                                expl,
                                output[0][classification],
                                classification,
                                true_class,
                                true_class,
                                1,       
                                tokens,
                                1)]
visualization.visualize_text(vis_data_records)

[('[CLS]', 0.0), ('this', 0.16134262084960938), ('movie', 0.05559374764561653), ('was', 0.10608766973018646), ('the', 0.20024555921554565), ('worst', 1.0), ('movie', 0.050500787794589996), ('i', 0.03956344351172447), ('have', 0.02918350324034691), ('ever', 0.1223004162311554), ('seen', 0.013092913664877415), ('!', 0.19819246232509613), ('some', 0.053029805421829224), ('scenes', 0.0202292799949646), ('were', 0.016616055741906166), ('rid', 0.06373898684978485), ('##icu', 0.011113693937659264), ('##lou', 0.007081188261508942), ('##s', 0.008961187675595284), ('.', 0.041823938488960266), ('[SEP]', 0.014206083491444588)]


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.94),0.0,1.0,[CLS] this movie was the worst movie i have ever seen ! some scenes were rid ##icu ##lou ##s . [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.94),0.0,1.0,[CLS] this movie was the worst movie i have ever seen ! some scenes were rid ##icu ##lou ##s . [SEP]
,,,,




Hope this helps :)

As a Z by HP Global Data Science Ambassador, Yuanhao's content is sponsored and he was provided with HP products.