# Feature Explanation for Bert-based Sentiment Classification Model

Captum offers a wide range of feature attribution methods for neural network models, including simple gradients (saliency), integrated gradients, DeepLift and Shapley value sampling. We use these implementations directly to compute importance scores (or attribution scores), which describe how much each input dimension contributes to a specific output. In NLP the inputs are words or tokens, so each token's score is aggregated over all of its embedding dimensions.

Install Captum with ```pip install captum```.


Load a pretrained sentiment classification model based on BERT-base.

In [1]:
import torch
from networks import BertBase
from captum.attr import Saliency, IntegratedGradients, ShapleyValueSampling, DeepLift, visualization 

In [2]:
model_dir = '/home/lyu/robustness/SST/model/Bert_classifier.pt'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Bert_classifier = BertBase(1).to(device)
# map_location moves the saved weights onto whichever device is available (CPU or GPU)
Bert_classifier.load_state_dict(torch.load(model_dir, map_location=device))

Bert_classifier.eval()

BertBase(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
          

Let's classify a sentence with the pretrained BERT model.

In [3]:
sent = 'this movie is really good.'
tokens = Bert_classifier.tokenizer.tokenize(sent)
sent_idx = Bert_classifier.tokenizer.convert_tokens_to_ids(tokens)
sent_idx = Bert_classifier.tokenizer.add_special_tokens_single_sentence(sent_idx)   # add [CLS] and [SEP].

out = Bert_classifier(torch.tensor([sent_idx]).to(device))
confidence = torch.sigmoid(out).item()
label = round(confidence)
print(confidence, label)

0.9929807186126709 1


Now we want to know which inputs contribute to the classification. We first try Captum's simple gradient method (```Saliency```). Gradients cannot be taken with respect to integer token indices. When the text input is token indices, Captum therefore uses ```configure_interpretable_embedding_layer()``` to mark the embedding layer, so that gradients are returned with respect to the embeddings. This becomes annoying if the process is interrupted before ```remove_interpretable_embedding_layer()``` is called, because the model is left with the patched embedding layer. I therefore prefer to pass the embeddings to the model directly. This requires access to the model definition, so that you can modify its ```forward``` function to accept embeddings as input.

In [4]:
# convert the input sentence to embeddings
embeddings = Bert_classifier.bert.embeddings(torch.tensor([sent_idx]).to(device)).detach()
print(embeddings.shape)

torch.Size([1, 8, 768])


In [5]:
SA = Saliency(Bert_classifier)
attribute = SA.attribute(embeddings, target=None, abs=False).detach()  # target=None: the model has a single output logit (binary classification); abs=False keeps the sign of the gradients instead of returning absolute values
print(attribute.shape)

torch.Size([1, 8, 768])



Now we have a gradient (importance score) for every embedding dimension of every input token. We aggregate these to one score per token by summing over the embedding dimension. Other options include taking the L2 norm, or multiplying the gradients with the embedding values before summing.

In [6]:
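importance = torch.sum(attribute, -1).squeeze(0)   # sum over embedding dimensions; squeeze removes the batch dimension for a single instance
importance = importance/torch.norm(importance)     # L2-normalise the token scores
importance = importance.cpu().numpy()              # move to CPU first so this also works when running on GPU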
print(importance)

[-0.8838346  -0.18882875  0.21356963  0.07026832 -0.00701645 -0.25199187
 -0.25687048  0.05566759]
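The three aggregation options mentioned above can be compared side by side. This is a small NumPy sketch. It assumes attributions of shape (seq_len, emb_dim) with the batch dimension already removed. ```aggregate_token_scores``` is a hypothetical helper, and its ```sum``` mode matches the In [6] cell.

```python
import numpy as np


def aggregate_token_scores(attr, emb=None, mode="sum"):
    """Collapse (seq_len, emb_dim) attributions to one score per token, then L2-normalise."""
    if mode == "sum":             # signed sum over embedding dimensions (as in In [6])
        scores = attr.sum(axis=-1)
    elif mode == "l2":            # magnitude only: the sign information is lost
        scores = np.linalg.norm(attr, axis=-1)
    elif mode == "grad_x_input":  # weight each gradient by its embedding value, then sum
        if emb is None:
            raise ValueError("grad_x_input needs the embeddings")
        scores = (attr * emb).sum(axis=-1)
    else:
        raise ValueError(f"unknown mode: {mode}")
    norm = np.linalg.norm(scores)
    return scores / norm if norm > 0 else scores


# toy attributions and embeddings for 2 tokens with 2 embedding dimensions
attr = np.array([[1.0, -3.0], [2.0, 2.0]])
emb = np.array([[1.0, 1.0], [0.5, -1.0]])
for mode in ("sum", "l2", "grad_x_input"):
    print(mode, aggregate_token_scores(attr, emb, mode))
```

In this toy case the second token is positive under ```sum``` (2 + 2 = 4) but negative under ```grad_x_input``` (2·0.5 + 2·(−1) = −1). The ```l2``` mode only says how much each token matters, not in which direction. The choice of aggregation alone can therefore change the picture.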


We can visualize the scores with Captum's ```visualization``` module.

In [7]:
vis_data_record = []
vis_data_record.append(visualization.VisualizationDataRecord(importance, confidence, label, '1',
                                          'pos', importance.sum(), ['[CLS]'] + tokens + ['[SEP]'], 0))
visualization.visualize_text(vis_data_record)

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 1.0 | 1 (0.99) | pos | -1.25 | [CLS] this movie is really good . [SEP] |


These results look odd to me: we would expect ```really good``` to contribute positively to the prediction, yet both tokens get negative scores. Let's try some other methods, which may give different results. For simplicity, let's wrap all methods in one helper function.

In [8]:

def get_saliency_for_bert(model, input_embedding, target, method):
    """Return attributions with the same shape as input_embedding: (batch, seq_len, emb_dim)."""
    if method == 'simple':
        sa = Saliency(model)
        # plain gradient of the output w.r.t. the embeddings, keeping its sign
        attr_sa = sa.attribute(input_embedding, target=target, abs=False).detach()
        return attr_sa
    elif method == 'integrated_gradient':
        IG = IntegratedGradients(model)
        # integrates gradients from an all-zero baseline (Captum's default) to the input
        attr_ig = IG.attribute(input_embedding, target=target, n_steps=10, return_convergence_delta=False).detach()
        return attr_ig
    elif method == 'deeplift':
        DL = DeepLift(model)
        attr_df = DL.attribute(input_embedding, target=target, return_convergence_delta=False).detach()
        return attr_df
    elif method == 'shapley':
        Shap = ShapleyValueSampling(model)
        # row t of the mask is filled with t, so all embedding dimensions of a token form one feature;
        # the mask is created on the same device as the input
        seq_len, emb_dim = input_embedding.shape[1], input_embedding.shape[2]
        feature_mask = torch.arange(seq_len, device=input_embedding.device).repeat_interleave(emb_dim).reshape(seq_len, -1)
        attr_shap = Shap.attribute(input_embedding, target=target, feature_mask=feature_mask).detach()
        return attr_shap
    else:
        raise NotImplementedError('Only simple, integrated_gradient, deeplift and shapley are supported for now.')


def visulize_features(attribute):
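    # uses the globals confidence, label and tokens from In [3]
    importance = torch.sum(attribute, -1).squeeze(0)   # one score per token
    importance = importance/torch.norm(importance)     # L2-normalise
    importance = importance.cpu().numpy()              # move to CPU before converting to NumPy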
    vis_data_record = []
    vis_data_record.append(visualization.VisualizationDataRecord(importance, confidence, label, '1',
                                          'pos', importance.sum(), ['[CLS]'] + tokens + ['[SEP]'], 0))
    visualization.visualize_text(vis_data_record)
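The ```feature_mask``` in the ```shapley``` branch gives the same id to all 768 dimensions of a token. Shapley value sampling therefore switches whole tokens on and off instead of single dimensions. The NumPy sketch below rebuilds that mask and estimates grouped Shapley values by averaging marginal contributions over random permutations. It assumes an all-zero baseline and uses hypothetical toy models. It illustrates the idea and is not Captum's implementation. As far as I know, Captum writes each token's value into every one of its dimensions. Summing over the 768 dimensions in ```visulize_features``` then scales all tokens by the same factor, and the normalisation cancels it.

```python
import numpy as np


def token_feature_mask(seq_len, emb_dim):
    # Same layout as the notebook's mask: row t is filled with t, so all emb_dim
    # entries of token t share one feature id and are switched on/off together.
    return np.repeat(np.arange(seq_len), emb_dim).reshape(seq_len, emb_dim)


def shapley_value_sampling(f, x, mask, n_samples=200, seed=0):
    """Monte Carlo Shapley values for grouped features (sketch, zero baseline assumed).

    For each random permutation of the groups, the groups are copied from x onto an
    all-zero baseline one at a time. Each group is credited with the change in f that
    its addition causes. Averaging over permutations estimates its Shapley value.
    """
    rng = np.random.default_rng(seed)
    groups = np.unique(mask)          # here 0..seq_len-1, so ids double as indices
    phi = np.zeros(len(groups))
    for _ in range(n_samples):
        current = np.zeros_like(x)
        prev = f(current)
        for g in rng.permutation(groups):
            current[mask == g] = x[mask == g]
            out = f(current)
            phi[g] += out - prev
            prev = out
    return phi / n_samples


rng = np.random.default_rng(1)
x = rng.normal(size=(4, 3))   # 4 tokens, 3 embedding dimensions
W = rng.normal(size=(4, 3))   # hypothetical model weights
mask = token_feature_mask(4, 3)


def linear(e):
    return float((e * W).sum())


def nonlinear(e):
    return float(np.tanh((e * W).sum()))


phi = shapley_value_sampling(linear, x, mask, n_samples=20)
print(np.allclose(phi, (x * W).sum(axis=1)))  # True: additive model, exact per-token credit
phi_nl = shapley_value_sampling(nonlinear, x, mask)
print(np.isclose(phi_nl.sum(), nonlinear(x) - nonlinear(np.zeros_like(x))))  # True
```

The second check holds for any model, because each permutation's contributions telescope to f(input) - f(baseline). The individual token values, however, are only estimates that depend on how many permutations are sampled.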

In [9]:
# try integrated gradient method
method = 'integrated_gradient'
attribute_ig = get_saliency_for_bert(Bert_classifier, embeddings, None, method)
visulize_features(attribute_ig)

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 1.0 | 1 (0.99) | pos | 0.25 | [CLS] this movie is really good . [SEP] |

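Integrated gradients averages the gradient along a straight path from a baseline (all-zero embeddings by default in Captum) to the input. It then multiplies that average by the input-minus-baseline difference. The resulting attributions approximately sum to f(input) - f(baseline). Below is a NumPy sketch on a hypothetical toy model, logit = w · tanh(a @ E) with made-up weights ```a``` and ```w```, using a right Riemann sum. Captum's default integration rule is Gauss-Legendre, so its numbers will not match this sketch exactly.

```python
import numpy as np


def toy_logit(emb, a, w):
    """Hypothetical scalar model: logit = w . tanh(a @ emb), with per-token weights a."""
    return float(w @ np.tanh(a @ emb))


def toy_grad(emb, a, w):
    """Analytic d logit / d emb, shape (seq_len, emb_dim)."""
    return np.outer(a, w * (1.0 - np.tanh(a @ emb) ** 2))


def integrated_gradients(emb, a, w, baseline=None, n_steps=10):
    """Right-Riemann approximation of integrated gradients from `baseline` to `emb`."""
    if baseline is None:
        baseline = np.zeros_like(emb)   # all-zero baseline
    diff = emb - baseline
    total = np.zeros_like(emb)
    for k in range(1, n_steps + 1):
        total += toy_grad(baseline + (k / n_steps) * diff, a, w)
    return diff * total / n_steps       # (input - baseline) * average path gradient


rng = np.random.default_rng(0)
emb = rng.normal(size=(5, 3))           # 5 tokens, 3 embedding dimensions
a, w = rng.normal(size=5), rng.normal(size=3)
target_gap = toy_logit(emb, a, w) - toy_logit(np.zeros_like(emb), a, w)
for steps in (10, 1000):
    ig = integrated_gradients(emb, a, w, n_steps=steps)
    print(steps, abs(ig.sum() - target_gap))  # completeness error for this step count
```

In Captum, passing ```return_convergence_delta=True``` returns this completeness gap. It is a cheap way to check whether ```n_steps=10``` is enough for the BERT model.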

In [11]:
# try deeplift method
method = 'deeplift'
attribute_dl = get_saliency_for_bert(Bert_classifier, embeddings, None, method)
visulize_features(attribute_dl)

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 1.0 | 1 (0.99) | pos | -0.07 | [CLS] this movie is really good . [SEP] |


In [12]:
# try shapley method
method = 'shapley'
attribute_sh = get_saliency_for_bert(Bert_classifier, embeddings, None, method)
visulize_features(attribute_sh)

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 1.0 | 1 (0.99) | pos | 1.53 | [CLS] this movie is really good . [SEP] |


We can see that the four methods give quite different importance distributions. There are many possible reasons, such as how the per-dimension scores are aggregated and how the inputs are perturbed. BERT is a particularly hard case because the ```[CLS]``` representation is used directly for classification. Methods that work well on vision tasks are probably not directly applicable to text data, and interpreting BERT models remains challenging.
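To put a number on how much two methods disagree, one option is the Spearman rank correlation between their token scores. This is a suggestion, not part of the analysis above. Below is a NumPy sketch with hypothetical helper names. The example vector is the saliency importance printed in In [6]. To compare against another method, compute its importance vector the same way and pass both vectors to ```spearman```.

```python
import numpy as np


def rank(v):
    """Ranks 0..n-1 (ties broken by position, fine for continuous scores)."""
    order = np.argsort(v)
    ranks = np.empty(len(v), dtype=float)
    ranks[order] = np.arange(len(v))
    return ranks


def spearman(u, v):
    """Spearman rank correlation: the Pearson correlation of the two rank vectors."""
    ru, rv = rank(np.asarray(u)), rank(np.asarray(v))
    return float(np.corrcoef(ru, rv)[0, 1])


# saliency token scores from In [6]: [CLS] this movie is really good . [SEP]
saliency_scores = np.array([-0.8838346, -0.18882875, 0.21356963, 0.07026832,
                            -0.00701645, -0.25199187, -0.25687048, 0.05566759])
print(spearman(saliency_scores, saliency_scores))   # identical ranking
print(spearman(saliency_scores, -saliency_scores))  # fully reversed ranking
```

A value near 1 means two methods order the tokens the same way even if their scales differ. Values near 0 or below indicate the kind of disagreement seen in the visualizations above.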