We perform word attribution for the embedding layer of a BertForMultipleChoice-model on the FCI using Captum.

Overview:
1) set up model/tokenizer
2) define custom forward pass, returning a vector with length (choices)
3) 


In [None]:
# Question and choices
question = "What is the capital of France?"
choices = ["Berlin", "Madrid", "Paris"]
ground_truth_idx = 2

In [None]:
# Imports
import torch
from transformers import RobertaTokenizer, RobertaForMultipleChoice
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz
from dataclasses import dataclass

# Settings
torch.manual_seed(42)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Config (only for debug atm)
@dataclass
class Config:
    debug: bool = True
cfg = Config()

In [None]:
# Load model and tokenizer
model_name = 'jonastokoliu/multi_choice_bert-base-uncased_swag_finetune'
model = RobertaForMultipleChoice.from_pretrained(model_name)
tokenizer = RobertaTokenizer.from_pretrained(model_name)

# Set to evaluation mode
model.to(device)
model.eval()
model.zero_grad()
# model # Uncomment to print architecture

In [4]:
def forward_func(input_ids, attention_mask, token_type_ids):
    """Custom forward pass for Captum Integrated Gradients. Captum wants [batch size, seq_len] while mulitple choice/classification
    models expect [batch size, num_choices, seq_len]. We return the logits for each choice."""

    input_ids = input_ids.unsqueeze(0)
    attention_mask = attention_mask.unsqueeze(0)
    token_type_ids = token_type_ids.unsqueeze(0)

    logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids
    ).logits

    if cfg.debug: print("Logits from forward pass:", logits.shape)
    return logits

In [20]:
# Get input ids, attn masks, token type ids and baseline for Captum attribution methods.

# Tokenize for multiple choice and get input ids, attention masks, and token type ids
encoding = tokenizer(
    [question] * len(choices),
    choices,
    return_tensors="pt",
    padding=True,
    truncation=True
)

input_ids      = encoding["input_ids"]      # shape: [choices, seq_len]
attention_mask = encoding["attention_mask"]
token_type_ids = encoding["token_type_ids"]


# Compute model prediction and get choice index
logits = forward_func(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
choice_idx = torch.argmax(logits).item()

print('-'*50)
print('Question:', question)
print('Predicted Answer:', f'{choice_idx})', choices[choice_idx])

if cfg.debug: print(logits)


# Create baseline input: [CLS] [PAD] ... [PAD] [SEP] [PAD] ... [PAD] [SEP] for choice index expanded as input_ids
ref_token_id = tokenizer.pad_token_id       # padding
sep_token_id = tokenizer.sep_token_id       # sepatator
cls_token_id = tokenizer.cls_token_id       # start of sequence

ref_tokens = [cls_token_id]
for i, token in enumerate(input_ids[choice_idx, 1:]):
    if token == sep_token_id:
        ref_tokens += [sep_token_id]
    else:
        ref_tokens += [ref_token_id]

ref_input_ids = torch.tensor(ref_tokens, dtype=torch.long).expand_as(input_ids)

Logits from forward pass: torch.Size([1, 5])
--------------------------------------------------
Question: Two metal balls are the same size but one weighs twice as much as the other. The balls are dropped from the roof of a single story building at the same instant of time. The time it takes the balls to reach the ground below will be:
Predicted Answer: 1) About half as long for the lighter ball as for the heavier one
tensor([[8.5428, 8.5783, 3.6178, 7.4995, 7.3461]], grad_fn=<ViewBackward0>)


In [None]:
if cfg.debug: print(input_ids.shape, attention_mask.shape, token_type_ids.shape)

# LayerIntegratedGradients for attribution
layer = model.bert.embeddings
lig = LayerIntegratedGradients(forward_func, layer)

# Compute attributions for chosen index (Captum wants [choices, seq_len] and gives attr shape [choices, seq_len, embedding_dim])
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=ref_input_ids,
    additional_forward_args=(attention_mask, token_type_ids),
    target=choice_idx,  # Target the chosen answer, uses [0,target]
    n_steps=50,  # Number of steps for approximation
    return_convergence_delta=True
)

# Sum across embedding dimensions to get token-level importance
token_attributions = attributions.sum(dim=-1).squeeze(0)  # shape: [num_choices, seq_len]
token_attributions = token_attributions / torch.norm(token_attributions)  # Normalize
if cfg.debug: print('Token attributions:', token_attributions.shape)

print('Attributions per token at choice_idx:', token_attributions[choice_idx])

torch.Size([5, 66]) torch.Size([5, 66]) torch.Size([5, 66])
Logits from forward pass: torch.Size([1, 5])
Logits from forward pass: torch.Size([1, 5])
Logits from forward pass: torch.Size([1, 250])
Logits from forward pass: torch.Size([1, 5])
Logits from forward pass: torch.Size([1, 5])
Token attributions: torch.Size([5, 66])
Attributions per token at choice_idx: tensor([ 0.1777, -0.0664,  0.1701,  0.2068, -0.0117,  0.1075,  0.1174,  0.0669,
         0.0917, -0.0483,  0.0640,  0.1771,  0.0833,  0.1283,  0.0704, -0.0792,
        -0.0006, -0.0668,  0.0262,  0.0692, -0.0239,  0.0951,  0.0349,  0.0808,
        -0.1655, -0.0013, -0.0053,  0.1474, -0.0500, -0.1920, -0.0729, -0.1197,
         0.0561,  0.2201, -0.0320,  0.1257,  0.0912,  0.0641,  0.0221,  0.0254,
         0.1329, -0.0117,  0.1142,  0.0886, -0.0128,  0.0718, -0.0201, -0.1662,
         0.0271,  0.0223,  0.1301, -0.0011,  0.2600,  0.2209,  0.0982,  0.1799,
        -0.0386, -0.1950,  0.3507,  0.0265, -0.0633, -0.2807, -0.1878,  0.1

In [None]:
choice_attributions = token_attributions[choice_idx]
tokens = tokenizer.convert_ids_to_tokens(input_ids[choice_idx])

vis = viz.VisualizationDataRecord(
                        choice_attributions,                        # word attributions
                        torch.max(torch.softmax(logits, dim=1)),    # prediction probability
                        torch.argmax(logits),                       # predicted class
                        ground_truth_idx,                           # ground truth class
                        str(choice_idx),                            # attributing to this class
                        attributions.sum(),                         # summed attribution score
                        tokens,                                     # tokens for the question and choice
                        delta,                                      # convergence delta
)

visualisation = viz.visualize_text([vis]) # get return object to avoid passing the vis object to the ipynb

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
2.0,1 (0.38),1.0,0.01,[CLS] two metal balls are the same size but one weighs twice as much as the other . the balls are dropped from the roof of a single story building at the same instant of time . the time it takes the balls to reach the ground below will be : [SEP] about half as long for the lighter ball as for the heavier one [SEP]
,,,,
