# Explain BERT models using CAptum

Captum is a PyTorch library to explain neural networks
Here we show a minimal example using Captum to explain BERT models from TextAttack

In [1]:
import torch
from copy import deepcopy

In [2]:
from textattack.datasets import HuggingFaceDataset
from textattack.models.tokenizers import AutoTokenizer
from textattack.models.wrappers import HuggingFaceModelWrapper
from textattack.models.wrappers import ModelWrapper
from transformers import AutoModelForSequenceClassification



In [3]:
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients, LayerDeepLiftShap, InternalInfluence, LayerGradientXActivation
from captum.attr import visualization as viz

In [4]:
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
print(device)
torch.cuda.set_device(device) 

cuda:2


In [5]:
dataset = HuggingFaceDataset("ag_news", None, "train")
original_model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-ag-news")
original_tokenizer = AutoTokenizer("textattack/bert-base-uncased-ag-news")
model = HuggingFaceModelWrapper(original_model,original_tokenizer)

Using custom data configuration default
Reusing dataset ag_news (/u/ss7mu/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
[34;1mtextattack[0m: Loading [94mdatasets[0m dataset [94mag_news[0m, split [94mtrain[0m.


In [42]:
def captum_form(encoded):
    input_dict = {k: [_dict[k] for _dict in encoded] for k in encoded[0]}
    batch_encoded = { k: torch.tensor(v).to(device) for k, v in input_dict.items()}
    return batch_encoded

def get_text(tokenizer,input_ids,token_type_ids,attention_mask):
    list_of_text = []
    number = input_ids.size()[0]
    for i in range(number):
        ii = input_ids[i,].cpu().numpy()
        tt = token_type_ids[i,]
        am = attention_mask[i,]
        txt = tokenizer.decode(ii, skip_special_tokens=True)
        list_of_text.append(txt)
    return list_of_text
    
sel =2
encoded = model.tokenizer.batch_encode([dataset[i][0]['text'] for i in range(sel)])
labels = [dataset[i][1] for i in range(sel)]

batch_encoded = captum_form(encoded)

clone = deepcopy(model)
clone.model.to(device)

def calculate(input_ids,token_type_ids,attention_mask):
    #convert back to list of text
    return clone.model(input_ids,token_type_ids,attention_mask)[0]
    
# x = calculate(**batch_encoded)    

lig = LayerIntegratedGradients(calculate, clone.model.bert.embeddings)
# lig = InternalInfluence(calculate, clone.model.bert.embeddings)
# lig = LayerGradientXActivation(calculate, clone.model.bert.embeddings)

bsl = torch.zeros(batch_encoded['input_ids'].size()).type(torch.LongTensor).to(device)
labels = torch.tensor(labels).to(device)

attributions,delta = lig.attribute(inputs=batch_encoded['input_ids'],
                              baselines=bsl,
                              additional_forward_args=(batch_encoded['token_type_ids'], batch_encoded['attention_mask']),
                              n_steps = 10,
                              target = labels,
                              return_convergence_delta=True
                              )
atts = attributions.sum(dim=-1).squeeze(0)
atts = atts / torch.norm(atts)

In [53]:
# print(attributions.size())
atts = attributions.sum(dim=-1).squeeze(0)
atts = atts / torch.norm(atts)

In [54]:
from textattack.attack_recipes import PWWSRen2019
attack = PWWSRen2019.build(model)

[34;1mtextattack[0m: Unknown if model of class <class 'textattack.models.wrappers.huggingface_model_wrapper.HuggingFaceModelWrapper'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.


In [58]:
results_iterable = attack.attack_dataset(dataset, indices=range(10))

viz_list = []

for n,result in enumerate(results_iterable):
    orig = result.original_text()
    pert = result.perturbed_text()
    
    encoded = model.tokenizer.batch_encode([orig])
    batch_encoded = captum_form(encoded)
    x = calculate(**batch_encoded) 
    print(x)
    print(dataset[n][1])
    
    pert_encoded = model.tokenizer.batch_encode([pert])
    pert_batch_encoded = captum_form(pert_encoded)
    x_pert = calculate(**pert_batch_encoded) 


    attributions,delta = lig.attribute(inputs=batch_encoded['input_ids'],
#                               baselines=bsl,
                              additional_forward_args=(batch_encoded['token_type_ids'], batch_encoded['attention_mask']),
                              n_steps = 10,
                              target = torch.argmax(calculate(**batch_encoded)).item(),
                              return_convergence_delta=True
                              )
    attributions_pert,delta_pert = lig.attribute(inputs=pert_batch_encoded['input_ids'],
#                               baselines=bsl,
                              additional_forward_args=(pert_batch_encoded['token_type_ids'], pert_batch_encoded['attention_mask']),
                              n_steps = 10,
                              target = torch.argmax(calculate(**pert_batch_encoded)).item(),
                              return_convergence_delta=True
                              )
    
    orig = original_tokenizer.tokenizer.tokenize(orig)
    pert = original_tokenizer.tokenizer.tokenize(pert)
    
    atts = attributions.sum(dim=-1).squeeze(0)
    atts = atts / torch.norm(atts)

    
    
    atts_pert = attributions_pert.sum(dim=-1).squeeze(0)
    atts_pert = atts_pert / torch.norm(atts)
    
    
    
    all_tokens = original_tokenizer.tokenizer.convert_ids_to_tokens(batch_encoded['input_ids'][0])
    all_tokens_pert = original_tokenizer.tokenizer.convert_ids_to_tokens(pert_batch_encoded['input_ids'][0])
    
        
    v = viz.VisualizationDataRecord(
                    atts[:45].detach().cpu(),
                    torch.max(x).item(),
                    torch.argmax(x,dim=1).item(),
                    dataset[n][1],
                    2,
                    atts.sum().detach(), 
                    all_tokens[:45],
                    delta)
    
    v_pert = viz.VisualizationDataRecord(
                    atts_pert[:45].detach().cpu(),
                    torch.max(x_pert).item(),
                    torch.argmax(x_pert,dim=1).item(),
                    dataset[n][1],
                    2,
                    atts_pert.sum().detach(), 
                    all_tokens_pert[:45],
                    delta_pert)
    viz_list.append(v)
    viz_list.append(v_pert)

#     print(result.perturbed_text())
    print(result.__str__(color_method='ansi'))



tensor([[-3.3016, -2.1467,  2.0953,  2.7272]], device='cuda:2',
       grad_fn=<AddmmBackward>)
2
[94mBusiness (96%)[0m --> [35mSci/tech (68%)[0m

Wall St. [94mBears[0m Claw Back Into the [94mBlack[0m (Reuters) Reuters - Short-sellers, Wall Street's [94mdwindling[0m\[94mband[0m of ultra-cynics, are [94mseeing[0m [94mgreen[0m again.

Wall St. [35msuffer[0m Claw Back Into the [35mlightlessness[0m (Reuters) Reuters - Short-sellers, Wall Street's [35mdwindle[0m\[35misthmus[0m of ultra-cynics, are [35mexamine[0m [35mgreenish[0m again.
tensor([[-4.8619, -0.8199,  3.9263,  1.1104]], device='cuda:2',
       grad_fn=<AddmmBackward>)
2
[94mBusiness (100%)[0m --> [35mSci/tech (50%)[0m

Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private [94minvestment[0m firm Carlyle Group,\which has a reputation for [94mmaking[0m well-timed and occasionally\controversial plays in the [94mdefense[0m industry, has quietly [94mplaced[0m\its bets on another p

In [60]:
print('\033[1m', 'Visualizations For AG NEWS', '\033[0m')
viz.visualize_text(viz_list)

[1m Visualizations For AG NEWS [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
2.0,3 (2.73),2.0,0.2,"[CLS] wall st . bears claw back into the black ( reuters ) reuters - short - sellers , wall street ' s d ##wind ##ling \ band of ultra - cy ##nic ##s , are seeing green again . [SEP] [PAD] [PAD] [PAD] [PAD]"
,,,,
2.0,3 (5.33),2.0,2.42,"[CLS] wall st . suffer claw back into the light ##lessness ( reuters ) reuters - short - sellers , wall street ' s d ##wind ##le \ ist ##hm ##us of ultra - cy ##nic ##s , are examine greenish again . [SEP] [PAD]"
,,,,
2.0,2 (3.93),2.0,1.28,"[CLS] carly ##le looks toward commercial aerospace ( reuters ) reuters - private investment firm carly ##le group , \ which has a reputation for making well - timed and occasionally \ controversial plays in the defense industry , has quietly placed \ its bets"
,,,,
2.0,3 (4.50),2.0,16.14,"[CLS] carly ##le looks toward commercial aerospace ( reuters ) reuters - private invest ##it ##ure firm carly ##le group , \ which has a reputation for ca - ca well - timed and occasionally \ controversial plays in the denial industry , has quietly"
,,,,
2.0,2 (4.61),2.0,1.43,[CLS] oil and economy cloud stocks ' outlook ( reuters ) reuters - soaring crude prices plus worries \ about the economy and the outlook for earnings are expected to \ hang over the stock market next week during the depth of the \ summer
,,,,


In [None]:
# reference for viz datarecord
# def __init__(
#         self,
#         word_attributions,
#         pred_prob,
#         pred_class,
#         true_class,
#         attr_class,
#         attr_score,
#         raw_input,
#         convergence_score,
#     ):