We perform word attribution for the embedding layer of a RobertaForMultipleChoice model on the FCI (Force Concept Inventory) using Captum.

Overview:
1) set up model/tokenizer
2) define custom forward pass, returning a vector with length (choices)
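3) tokenize the question/choice pairs, predict an answer and build a padding baseline
4) compute Layer Integrated Gradients attributions for the embedding layer
5) visualize the token attributions of the predicted choice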


In [None]:
# Toy example: the question, its candidate answers and the index of the correct one
question, choices, ground_truth_idx = (
    "What is the capital of France?",
    ["Berlin", "Madrid", "Paris"],
    2,
)

# Alternative example (uncomment to overwrite the toy example above)
# question = "Two metal balls are the same size but one weighs twice as much as the other. The balls \
# are dropped from the roof of a single story building at the same instant of time. The time it takes \
# the balls to reach the ground below will be:"

# choices = ["About half as long for the heavier ball as for the lighter one",
#            "About half as long for the lighter ball as for the heavier one",
#            "About the same for both balls",
#            "Considerably less for the heavier ball, but not necessarily half as long",
#            "Considerably less for the lighter ball, but not necessarily half as long"]

# ground_truth_idx = 2

In [None]:
# Imports
import torch
from transformers import AutoTokenizer, RobertaForMultipleChoice
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz
from dataclasses import dataclass

# Settings: fixed seed for reproducibility, GPU if one is available
torch.manual_seed(42)
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Config (only for debug atm)
@dataclass
class Config:
    """Notebook-wide settings."""
    debug: bool = True  # when True, cells print intermediate shapes and logits

cfg = Config()

In [3]:
# Load model and tokenizer
model_name = "LIAMF-USP/roberta-large-finetuned-race"
model = RobertaForMultipleChoice.from_pretrained(model_name)
# Use AutoTokenizer: it is the tokenizer class this notebook imports
# (RobertaTokenizer is never imported and would raise a NameError).
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Move to the selected device and switch to evaluation mode (disables dropout)
model.to(device)
model.eval()
model.zero_grad()
# model # Uncomment to print architecture

In [4]:
def forward_func(input_ids, attention_mask, token_type_ids=None, num_choices=None):
    """Custom forward pass for Captum Integrated Gradients.

    Captum works with tensors of shape [rows, seq_len], while multiple choice models
    expect [batch size, num_choices, seq_len]. The rows are therefore regrouped into
    blocks of `num_choices` before the model is called.

    During attribution Captum repeats the inputs along the first dimension once per
    integration step, so the function may receive n_steps * num_choices rows. Simply
    adding a leading batch dimension would present all of them as choices of a single
    question; regrouping keeps one set of choices per integration step.

    Args:
        input_ids: token ids, shape [k * num_choices, seq_len].
        attention_mask: attention mask with the same shape as `input_ids`.
        token_type_ids: optional token type ids with the same shape as `input_ids`.
        num_choices: number of answer choices per question; defaults to
            `len(choices)` of the notebook's current question.

    Returns:
        Logits of shape [k, num_choices]. If the number of rows is not a multiple of
        `num_choices`, all rows are treated as the choices of one question and the
        result has shape [1, rows].
    """
    if num_choices is None: num_choices = len(choices)

    def _group_by_question(tensor):
        # Reshape [rows, seq_len] -> [rows / num_choices, num_choices, seq_len].
        if tensor is None:
            return None
        if num_choices and tensor.shape[0] % num_choices == 0:
            return tensor.reshape(-1, num_choices, tensor.shape[-1])
        # Fallback: treat every row as a choice of a single question.
        return tensor.unsqueeze(0)

    logits = model(
        input_ids=_group_by_question(input_ids),
        attention_mask=_group_by_question(attention_mask),
        token_type_ids=_group_by_question(token_type_ids)
    ).logits

    if cfg.debug: print("Logits from forward pass:", logits.shape)
    return logits

In [5]:
# Get input ids, attn masks, token type ids and baseline for Captum attribution methods.

# Tokenize one (question, choice) pair per choice and move the tensors to the model's device
encoding = tokenizer(
    [question] * len(choices),
    choices,
    return_tensors="pt",
    padding=True,
    truncation=True
).to(device)

input_ids      = encoding["input_ids"]        # shape: [choices, seq_len]
attention_mask = encoding["attention_mask"]
# token_type_ids = encoding["token_type_ids"] # No token type ids for RoBERTa


# Compute model prediction and get choice index (no gradients needed for the prediction)
with torch.no_grad():
    logits = forward_func(input_ids, attention_mask=attention_mask)
choice_idx = torch.argmax(logits).item()

print('-'*50)
print('Question:', question)
print('Predicted Answer:', f'{choice_idx})', choices[choice_idx])

if cfg.debug: print(logits)

# Note that RoBERTa has two separators between question and choice
# Create baseline input: [CLS] [PAD] ... [PAD] [SEP] [SEP] [PAD] ... [PAD] [SEP]
ref_token_id = tokenizer.pad_token_id       # padding
sep_token_id = tokenizer.sep_token_id       # separator
cls_token_id = tokenizer.cls_token_id       # start of sequence

# Build the baseline per row: keep the separators where they occur in that row,
# put the start token first and replace every other token by padding.
ref_input_ids = torch.full_like(input_ids, ref_token_id)
ref_input_ids[input_ids == sep_token_id] = sep_token_id
ref_input_ids[:, 0] = cls_token_id

Logits from forward pass: torch.Size([1, 5])
--------------------------------------------------
Question: Two metal balls are the same size but one weighs twice as much as the other. The balls are dropped from the roof of a single story building at the same instant of time. The time it takes the balls to reach the ground below will be:
Predicted Answer: 3) Considerably less for the heavier ball, but not necessarily half as long
tensor([[-0.7417,  1.9860, -3.9501,  5.7018,  5.5120]],
       grad_fn=<ViewBackward0>)


In [None]:
if cfg.debug: print(input_ids.shape, attention_mask.shape)

# LayerIntegratedGradients for attribution with respect to the embedding layer
layer = model.roberta.embeddings
lig = LayerIntegratedGradients(forward_func, layer)

# Compute attributions for chosen index (Captum wants [choices, seq_len] and gives attr shape [choices, seq_len, embedding_dim])
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=ref_input_ids,
    additional_forward_args=(attention_mask,),  # trailing comma: pass a real tuple
    target=choice_idx,  # Target the logit of the chosen answer
    n_steps=50,  # Number of steps for approximation
    return_convergence_delta=True
)

# Sum across embedding dimensions to get token-level importance.
# The choice axis is kept even for a single choice so that indexing by choice_idx stays valid.
token_attributions = attributions.sum(dim=-1)  # shape: [num_choices, seq_len]
attribution_norm = torch.norm(token_attributions)
if attribution_norm > 0:  # avoid NaNs when all attributions are zero
    token_attributions = token_attributions / attribution_norm  # Normalize
if cfg.debug: print('Token attributions:', token_attributions.shape)

print('Attributions per token at choice_idx:', token_attributions[choice_idx])

torch.Size([5, 68]) torch.Size([5, 68])
Logits from forward pass: torch.Size([1, 5])
Logits from forward pass: torch.Size([1, 5])


In [None]:
# The RoBERTa tokenizer returns ugly word-tokens - #s is the start-token, the #/s is a separator and the dotted G is a space

# Attributions and tokens of the predicted choice only
choice_attributions = token_attributions[choice_idx]
tokens = tokenizer.convert_ids_to_tokens(input_ids[choice_idx])

# Values shown in the summary columns of the visualization
predicted_probability = torch.max(torch.softmax(logits, dim=1))
predicted_class = torch.argmax(logits)
total_attribution = token_attributions.sum()

vis = viz.VisualizationDataRecord(
    choice_attributions,      # word attributions
    predicted_probability,    # prediction probability
    predicted_class,          # predicted class
    ground_truth_idx,         # ground truth class
    str(choice_idx),          # attributing to this class
    total_attribution,        # summed attribution score
    tokens,                   # tokens for the question and choice
    delta,                    # convergence delta
)

# Keep the return value so the notebook does not echo the object as cell output
visualisation = viz.visualize_text([vis])

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
2.0,2 (0.99),2.0,0.68,#s What Ġis Ġthe Ġcapital Ġof ĠFrance ? #/s #/s Paris #/s #pad
,,,,
