Given a multiple choice question, we seek to determine which words are most important for the BERT-model's answer selection.

Model: RoBERTa adapted for Multiple Choice (non-causal attention). If you want to specialize your model, you can fine-tune it (eg. specifically for physics questions) as described [here](https://huggingface.co/docs/transformers/en/tasks/multiple_choice).

In [108]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel, utils

from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients

torch.manual_seed(42)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [55]:
from transformers import BertForMultipleChoice, BertTokenizer

model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForMultipleChoice.from_pretrained(model_name)
model.to(device)
model.eval()
model.zero_grad()
# model

Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [56]:
def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    output = model(inputs, token_type_ids=token_type_ids,
                 position_ids=position_ids, attention_mask=attention_mask, )
    return output.logits

def squad_pos_forward_func(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
    logits = predict(inputs,
                   token_type_ids=token_type_ids,
                   position_ids=position_ids,
                   attention_mask=attention_mask)[position]
    print(logits.max(1).values.shape)
    return logits.max(1).values

In [61]:
question = ["Two metal balls are the same size but one weighs twice as much as the other. The balls \
are dropped from the roof of a single story building at the same instant of time. The time it takes \
the balls to reach the ground below will be:"]

choices = ["About half as long for the heavier ball as for the lighter one",
           "About half as long for the lighter ball as for the heavier one",
           "About the same for both balls",
           "Considerably less for the heavier ball, but not necessarily half as long",
           "Considerably less for the lighter ball, but not necessarily half as long"]

ground_truth = "About the same for both balls"

# Tokenizing: [CLS] question [SEP] [SEP] choice [SEP] for each choice
encoding = tokenizer(
    question * len(choices),  # Repeat context for each choice
    choices,
    padding=True,
    truncation=True,
    return_tensors="pt"
) # returns dictionary with input_ids and corresp. token_type_ids and attention_mask (ensuring that padding tokens are not attended to)

input_ids = encoding["input_ids"].unsqueeze(0)  # shape: (1, num_choices, seq_length)

token_type_ids = encoding["token_type_ids"].unsqueeze(0)  # shape: (1, num_choices, seq_length)
attention_mask = encoding["attention_mask"].unsqueeze(0) # attention mask to ignore padding tokens

logits = predict(input_ids, attention_mask=attention_mask)
choice_idx = torch.argmax(logits, dim=1).item()
choice = choices[choice_idx]

print('Question:', question[0])
print('Predicted Answer:', choice)
print('Ground Truth:', ground_truth)

print('Logits:', logits) # Too gauge conviction strength

Question: Two metal balls are the same size but one weighs twice as much as the other. The balls are dropped from the roof of a single story building at the same instant of time. The time it takes the balls to reach the ground below will be:
Predicted Answer: About half as long for the lighter ball as for the heavier one
Ground Truth: About the same for both balls
Logits: tensor([[0.4318, 0.4362, 0.4330, 0.4017, 0.4140]], grad_fn=<ViewBackward0>)


In [62]:
# Helper
def print_tokens(input_ids):
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    print("Tokens:", tokens)

cls_token_id = tokenizer.cls_token_id # Start
ref_token_id = tokenizer.pad_token_id # Padding
sep_token_id = tokenizer.sep_token_id # Separator
# print(cls_token_id, ref_token_id, sep_token_id)

# The exact setup depends slightly on tokenizer
# Reference tokens: [CLS] [PAD] ... [PAD] [SEP] ([SEP]) [PAD] ... [PAD] [SEP] [PAD] ..., for each choice
# Type tokens: [0] for question, [1] for choice, for each choice
ref_tokens_all = []

for i in range(len(choices)):
    ref_tokens = [cls_token_id]

    for token in input_ids[0,i,1:]:

        if token == sep_token_id:
            ref_tokens += [token]
        else:
            ref_tokens += [ref_token_id]

    ref_tokens_all.append(ref_tokens)
        
ref_input_ids = torch.tensor(ref_tokens_all, dtype=torch.long).unsqueeze(0)

# print_tokens(input_ids[0,0,:])
# print(input_ids[:,0,:], ref_input_ids[:,0,:], token_type_ids[:,0,:],)
# print(ref_input_ids.shape, token_type_ids.shape, input_ids.shape)

# Position ids: [0, 1, 2, ...]
seq_length = input_ids.size(0)
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).expand_as(input_ids)

In [63]:
lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)
# ref_input_ids = ref_input_ids[:, 2, :]
# input_ids = input_ids[:, 2, :]
# token_type_ids = token_type_ids[:, 2, :]
# attention_mask = attention_mask[:, 2, :]


print(ref_input_ids.shape, token_type_ids.shape, input_ids.shape, attention_mask.shape)


attributions, delta = lig.attribute(inputs=input_ids, # Here only attribution for the chosen input
                                  baselines=ref_input_ids,
                                  additional_forward_args=(token_type_ids, position_ids, attention_mask, 2),
                                  target=2,
                                  return_convergence_delta=True)

torch.Size([1, 5, 66]) torch.Size([1, 5, 66]) torch.Size([1, 5, 66]) torch.Size([1, 5, 66])


IndexError: index 2 is out of bounds for dimension 0 with size 1

In [None]:
question = ["A large truck breaks down out on the road and receives a push back into town by a small compact car. \
While the car, still pushing the truck, is speeding up to get up to cruising speed:"]
choices = [
"The amount of force with which the car pushes on the truck is equal to that with which the truck pushes back on the car.",
"The amount of force with which the car pushes on the truck is smaller than that with which the truck pushes back on the car.",
"The amount of force with which the car pushes on the truck is greater than that with which the truck pushes back on the car.",
"The car's engine is running so the car pushes against the truck, but the truck's engine is not running so the truck cannot \
push back against the car. The truck is pushed forward simply because it is in the way of the car.",
"Neither the car nor the truck exert any force on the other. The truck is pushed forward simply because it is in the way of the car."]

encoding = tokenizer(
    question * len(choices),  # Repeat context for each choice
    choices,
    padding=True,
    truncation=True,
    return_tensors="pt"
)

# Add batch dimension
input_ids = encoding["input_ids"].unsqueeze(0)  # shape: (1, num_choices, seq_length)
attention_mask = encoding["attention_mask"].unsqueeze(0)

outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits  # shape: (batch_size, num_choices)

# Predicted answer
predicted_choice = torch.argmax(logits, dim=1).item()
print("Predicted:", choices[predicted_choice])

Predicted: Neither the car nor the truck exert any force on the other. The truck is pushed forward simply because it is in the way of the car.


In [None]:
question = ["A large truck breaks down out on the road and receives a push back into town by a small compact car. \
After the car reaches the constant cruising speed at which its driver wishes to push the truck:"]

choices = [
"The amount of force with which the car pushes on the truck is equal to that with which the truck pushes back on the car.",
"The amount of force with which the car pushes on the truck is smaller than that with which the truck pushes back on the car.",
"The amount of force with which the car pushes on the truck is greater than that with which the truck pushes back on the car.",
"The car's engine is running so the car pushes against the truck, but the truck's engine is not running so the truck cannot \
push back against the car. The truck is pushed forward simply because it is in the way of the car.",
"Neither the car nor the truck exert any force on the other. The truck is pushed forward simply because it is in the way of the car."
]

encoding = tokenizer(
    question * len(choices),  # Repeat context for each choice
    choices,
    padding=True,
    truncation=True,
    return_tensors="pt"
)

# Add batch dimension
input_ids = encoding["input_ids"].unsqueeze(0)  # shape: (1, num_choices, seq_length)
attention_mask = encoding["attention_mask"].unsqueeze(0)

outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits  # shape: (batch_size, num_choices)

# Predicted answer
predicted_choice = torch.argmax(logits, dim=1).item()
print("Predicted:", choices[predicted_choice])

Predicted: The amount of force with which the car pushes on the truck is smaller than that with which the truck pushes back on the car.


In [None]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [71]:
import torch
from transformers import BertTokenizer, BertForMultipleChoice
from captum.attr import LayerIntegratedGradients

# Load model and tokenizer
model_name = "bert-base-uncased"
model = BertForMultipleChoice.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
model.eval()

# Sample question with two options
question = "What is the capital of France?"
choices = ["Paris is the capital of France.", "Berlin is the capital of France."]

# Tokenize for multiple choice
encoding = tokenizer(
    [question] * len(choices),
    choices,
    return_tensors="pt",
    padding=True,
    truncation=True
)

# Model expects input_ids, attention_mask, token_type_ids in a batch of choices
input_ids = encoding["input_ids"]       # shape: [num_choices, seq_len]
attention_mask = encoding["attention_mask"]
token_type_ids = encoding["token_type_ids"]

# Prepare baseline 
with torch.no_grad():
    baseline_input_ids = torch.ones_like(input_ids) * tokenizer.pad_token_id  # Use pad token as baseline

# Choose the embedding layer
embedding_layer = model.bert.embeddings

# Forward function: outputs logits for each choice
def forward_func(input_ids, attention_mask, token_type_ids):
    logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids
    ).logits
    print(logits.shape)
    return logits

# Create a wrapper for LayerIntegratedGradients
lig = LayerIntegratedGradients(forward_func, embedding_layer)

# Preparing shape for LayerIntegratedGradients
input_ids = input_ids.unsqueeze(0)  # Add batch dimension
attention_mask = attention_mask.unsqueeze(0)  # Add batch dimension 
token_type_ids = token_type_ids.unsqueeze(0)  # Add batch dimension
baseline_input_ids = baseline_input_ids.unsqueeze(0)  # Add batch dimension

print(input_ids.shape, attention_mask.shape, token_type_ids.shape)

# Compute attributions for the correct answer (index 0 = Paris)
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=baseline_input_ids,
    additional_forward_args=(attention_mask, token_type_ids),
    target=0,  # Target the first choice (Paris)    
    return_convergence_delta=True
)

Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([1, 2, 17]) torch.Size([1, 2, 17]) torch.Size([1, 2, 17])
torch.Size([1, 2])
torch.Size([1, 2])
torch.Size([50, 2])


AssertionError: Attributions tensor and the end_point must match on the first dimension but found attribution: 2 and end_point: 1

In [106]:
# trying instead flat input

import torch
from transformers import BertTokenizer, BertForMultipleChoice
from captum.attr import LayerIntegratedGradients

# Load model and tokenizer
model_name = 'jonastokoliu/multi_choice_bert-base-uncased_swag_finetune'
model = BertForMultipleChoice.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
model.eval()

# Sample question with two options
question = "What is the capital of France?"
choices = ["Paris is the capital of France.", "Berlin is the capital of France."]

# Tokenize for multiple choice
encoding = tokenizer(
    [question] * len(choices),
    choices,
    return_tensors="pt",
    padding=True,
    truncation=True
)

# Model expects input_ids, attention_mask, token_type_ids in a batch of choices
input_ids = encoding["input_ids"]       # shape: [num_choices, seq_len]
attention_mask = encoding["attention_mask"]
token_type_ids = encoding["token_type_ids"]

# Prepare baseline 
with torch.no_grad():
    baseline_input_ids = torch.ones_like(input_ids) * tokenizer.pad_token_id  # Use pad token as baseline

# Preparing shape for LayerIntegratedGradients
input_ids = input_ids.unsqueeze(1)  # Add batch dimension
attention_mask = attention_mask.unsqueeze(1)  # Add batch dimension 
token_type_ids = token_type_ids.unsqueeze(1)  # Add batch dimension
baseline_input_ids = baseline_input_ids.unsqueeze(1)  # Add batch dimension


# Compute model prediction
scores = predict(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
print(scores)
print('Question: ', question)
print('Predicted Answer: ', choices[torch.argmax(scores).item()])

# Choose the embedding layer
embedding_layer = model.bert.embeddings

# Forward function: outputs logits for each choice
def forward_func(input_ids, attention_mask, token_type_ids):
    logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids
    ).logits
    print(logits.shape)
    return logits

# Create a wrapper for LayerIntegratedGradients
lig = LayerIntegratedGradients(forward_func, embedding_layer)

print(input_ids.shape, attention_mask.shape, token_type_ids.shape)

# Compute attributions for the correct answer (index 0 = Paris)
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=baseline_input_ids,
    additional_forward_args=(attention_mask, token_type_ids),
    target=0,  # Target the first choice (Paris)    
    return_convergence_delta=True
)

# Sum across embedding dimensions to get token-level importance
token_attributions = attributions.sum(dim=-1).squeeze(0)  # shape: [num_choices, seq_len]
token_attributions = token_attributions / torch.norm(token_attributions)  # Normalize

print("Attributions per token (choice 0):")
print(token_attributions[0][0])  # First sample, first choice



config.json:   0%|          | 0.00/670 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/314 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

tensor([[6.9128],
        [6.4859]], grad_fn=<ViewBackward0>)
Question:  What is the capital of France?
Predicted Answer:  Paris is the capital of France.
torch.Size([2, 1, 17]) torch.Size([2, 1, 17]) torch.Size([2, 1, 17])
torch.Size([2, 1])
torch.Size([2, 1])
torch.Size([100, 1])
torch.Size([2, 1])
torch.Size([2, 1])
Attributions per token (choice 0):
tensor(-0.0021)


In [107]:
from captum.attr import visualization as viz

ground_truth = choices[0]  # Assuming first choice is correct

choice_idx = torch.argmax(scores).item()
indices = input_ids[choice_idx,0].detach().tolist()  # batch=0, choice=choice_idx
all_tokens = tokenizer.convert_ids_to_tokens(indices)

token_attributions_for_choice = token_attributions[choice_idx]

vis = viz.VisualizationDataRecord(
    word_attributions=token_attributions_for_choice,
    pred_prob=torch.max(torch.softmax(scores[0], dim=0)).item(),
    pred_class=choice_idx,
    true_class=ground_truth,
    attr_class=str(ground_truth),
    attr_score=token_attributions_for_choice.sum().item(),
    raw_input_ids=all_tokens,
    convergence_score=delta[choice_idx].item()
)

print('\033[1m', 'Visualizations For Start Position', '\033[0m')
viz.visualize_text([vis])


[1m Visualizations For Start Position [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Paris is the capital of France.,0 (1.00),Paris is the capital of France.,-0.06,[CLS] what is the capital of france ? [SEP] paris is the capital of france . [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Paris is the capital of France.,0 (1.00),Paris is the capital of France.,-0.06,[CLS] what is the capital of france ? [SEP] paris is the capital of france . [SEP]
,,,,


In [None]:
from captum.attr import visualization as viz

print(input_ids.shape)
choice_idx = 0
indices = input_ids[choice_idx,0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

ground_truth = choices[0]  # Assuming the first choice is the ground truth

# storing couple samples in an array for visualization purposes
vis = viz.VisualizationDataRecord(
                        token_attributions, 
                        torch.max(torch.softmax(scores[0], dim=0)),
                        torch.argmax(scores),
                        torch.argmax(scores),
                        str(ground_truth),
                        token_attributions.sum(),       
                        all_tokens,
                        delta)


print('\033[1m', 'Visualizations For Start Position', '\033[0m')
viz.visualize_text([vis])

torch.Size([2, 1, 17])
[1m Visualizations For Start Position [0m


AssertionError: 