# Explain Attacking BERT models using CAptum

Captum is a PyTorch library to explain neural networks.

Here we show a minimal example using Captum to explain BERT models from TextAttack

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/QData/TextAttack/blob/master/docs/2notebook/Example_5_Explain_BERT.ipynb)

[![View Source on GitHub](https://img.shields.io/badge/github-view%20source-black.svg)](https://github.com/QData/TextAttack/blob/master/docs/2notebook/Example_5_Explain_BERT.ipynb)

## Let's import some packages

In [1]:
from captum.attr import visualization as viz
from textattack.datasets import HuggingFaceDataset
from textattack.models.tokenizers import AutoTokenizer
from textattack.models.wrappers import ModelWrapper, HuggingFaceModelWrapper
from transformers import AutoModelForSequenceClassification
from IPython.display import display, HTML

import torch

## Make GPU available

In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(device)
print(device)

cuda:0


## Load our model and dataset

In [None]:
dataset = HuggingFaceDataset("ag_news", None, "train")
original_model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-ag-news")
original_tokenizer = AutoTokenizer("textattack/bert-base-uncased-ag-news")
model_wrapper = HuggingFaceModelWrapper(original_model,original_tokenizer)

model_wrapper.model.to(device)


## Define some useful functions

In [4]:
def captum_form(encoded):
    input_dict = {k: [_dict[k] for _dict in encoded] for k in encoded[0]}
    batch_encoded = {k: torch.tensor(v).to(device) for k, v in input_dict.items()}
    return batch_encoded


def calculate(input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
    return model_wrapper.model(
        input_ids,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
    )[0]


def display_html(html_str):
    display(HTML(html_str))

## Pick an Attribution Algorithm

In [5]:
from captum.attr import LayerIntegratedGradients

# more algorithms are avaliable at:
# https://github.com/pytorch/captum/blob/master/docs/algorithms_comparison_matrix.md

lig = LayerIntegratedGradients(calculate, model_wrapper.model.bert.embeddings)

## Pick an Attack Algorithm

In [7]:
from textattack.attack_recipes import PWWSRen2019
attack = PWWSRen2019.build(model_wrapper)

textattack: Unknown if model of class <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.


In [9]:
example_num = 5
results_iterable = attack.attack_dataset(dataset, indices=range(example_num))

viz_list = []

for n, result in enumerate(results_iterable):
    # get text
    orig = result.original_text()
    pert = result.perturbed_text()

    # get prediction
    encoded = model_wrapper.tokenizer.batch_encode([orig])
    batch_encoded = captum_form(encoded)
    logit = calculate(**batch_encoded)

    pert_encoded = model_wrapper.tokenizer.batch_encode([pert])
    pert_batch_encoded = captum_form(pert_encoded)
    logit_pert = calculate(**pert_batch_encoded)

    # attribute
    attributions, delta = lig.attribute(
        inputs=batch_encoded["input_ids"],
        additional_forward_args=(
            batch_encoded["token_type_ids"],
            batch_encoded["attention_mask"],
        ),
        n_steps=10,
        target=torch.argmax(logit, dim=1).item(),
        return_convergence_delta=True,
    )
    attributions_pert, delta_pert = lig.attribute(
        inputs=pert_batch_encoded["input_ids"],
        additional_forward_args=(
            pert_batch_encoded["token_type_ids"],
            pert_batch_encoded["attention_mask"],
        ),
        n_steps=10,
        target=torch.argmax(logit_pert, dim=1).item(),
        return_convergence_delta=True,
    )

    orig = original_tokenizer.tokenizer.tokenize(orig)
    pert = original_tokenizer.tokenizer.tokenize(pert)

    atts = attributions.sum(dim=-1).squeeze(0)
    atts = atts / torch.norm(atts)

    atts_pert = attributions_pert.sum(dim=-1).squeeze(0)
    atts_pert = atts_pert / torch.norm(atts_pert)
    
    # Visualization
    all_tokens = original_tokenizer.tokenizer.convert_ids_to_tokens(
        batch_encoded["input_ids"][0]
    )
    all_tokens_pert = original_tokenizer.tokenizer.convert_ids_to_tokens(
        pert_batch_encoded["input_ids"][0]
    )

    text_length = torch.sum(batch_encoded["attention_mask"]).detach().cpu().numpy()
    text_length_pert = (
        torch.sum(pert_batch_encoded["attention_mask"]).detach().cpu().numpy()
    )

    v = viz.VisualizationDataRecord(
        word_attributions=atts[:text_length].detach().cpu(),
        pred_prob=torch.max(logit).item(),
        pred_class=torch.argmax(logit, dim=1).item(),
        true_class=dataset[n][1],
        attr_class=dataset[n][1],
        attr_score=atts.sum().detach(),
        raw_input=all_tokens[:text_length],
        convergence_score=delta,
    )

    v_pert = viz.VisualizationDataRecord(
        word_attributions=atts_pert[:text_length_pert].detach().cpu(),
        pred_prob=torch.max(logit_pert).item(),
        pred_class=torch.argmax(logit_pert, dim=1).item(),
        true_class=dataset[n][1],
        attr_class=dataset[n][1],
        attr_score=atts_pert.sum().detach(),
        raw_input=all_tokens_pert[:text_length_pert],
        convergence_score=delta_pert,
    )

    viz_list.append(v)
    viz_list.append(v_pert)

    result_html_str = result.__str__(color_method="html").replace("\n\n", "<br>")

    display_html(result_html_str)

In [10]:
print('Visualizations For AG NEWS')
vis_table = viz.visualize_text(viz_list)

Visualizations For AG NEWS


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
2.0,2 (4.86),2.0,1.44,"[CLS] wall st . bears claw back into the black ( reuters ) reuters - short - sellers , wall street ' s d ##wind ##ling \ band of ultra - cy ##nic ##s , are seeing green again . [SEP]"
,,,,
2.0,3 (2.52),2.0,1.67,"[CLS] wall st . suffer claw back into the light ##lessness ( reuters ) reuters - short - sellers , wall street ' s d ##wind ##le \ ist ##hm ##us of ultra - cy ##nic ##s , are examine greenish again . [SEP]"
,,,,
2.0,2 (6.59),2.0,1.29,"[CLS] carly ##le looks toward commercial aerospace ( reuters ) reuters - private investment firm carly ##le group , \ which has a reputation for making well - timed and occasionally \ controversial plays in the defense industry , has quietly placed \ its bets on another part of the market . [SEP]"
,,,,
2.0,3 (3.29),2.0,3.05,"[CLS] carly ##le looks toward commercial aerospace ( reuters ) reuters - private invest ##it ##ure firm carly ##le group , \ which has a reputation for ca - ca well - timed and occasionally \ controversial plays in the denial industry , has quietly site \ its bets on another part of the market . [SEP]"
,,,,
2.0,2 (6.27),2.0,2.82,[CLS] oil and economy cloud stocks ' outlook ( reuters ) reuters - soaring crude prices plus worries \ about the economy and the outlook for earnings are expected to \ hang over the stock market next week during the depth of the \ summer do ##ld ##rum ##s . [SEP]
,,,,
