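Integrated Gradients (IG) on the cosine similarity between a prompt and its own embedding (equivalently, on 1 − cosine distance). That is, we run the model once to obtain the original (reference) embedding, then use IG to attribute the cosine similarity between the mean-pooled model output and this reference vector back to the input tokens.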

Alternative approach for paired inputs (e.g. question–answer pairs): use the similarity (or distance) to the question's embedding as the attribution target instead.

In [14]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModel, AutoTokenizer

from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz

from dataclasses import dataclass

# Settings: fixed seed for reproducibility; use the first GPU when one is present
torch.manual_seed(42)
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


@dataclass
class Config:
    """Notebook-wide switches; currently only toggles debug output."""

    debug: bool = True  # when True, cells print intermediate tensor shapes


cfg = Config()

In [2]:
# Load the sentence-transformer checkpoint and its matching tokenizer
model_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Inference only: place on the selected device, switch to eval mode, clear stale grads
model.to(device).eval()
model.zero_grad()
# model  # Uncomment to print architecture

ConnectionError: (MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /api/models/sentence-transformers/multi-qa-MiniLM-L6-cos-v1/tree/main/additional_chat_templates?recursive=False&expand=False (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x00000164DD8976E0>: Failed to resolve \'huggingface.co\' ([Errno 11001] getaddrinfo failed)"))'), '(Request ID: 7c0765bb-ad06-4593-b166-0fbc83f4cc10)')

In [3]:
# Mean Pooling - To compute embeddings (from huggingface)
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings over dim 1, counting only positions where the mask is nonzero.

    Args:
        model_output: indexable model output whose element 0 holds the token embeddings
            (presumably [batch, seq_len, dim] — the mask is broadcast over the last dim).
        attention_mask: mask shaped like the token embeddings without their last dim.

    Returns:
        The masked mean over dim 1. The token count is clamped to 1e-9, so a row with an
        all-zero mask yields zeros (for finite embeddings) instead of dividing by zero.
    """
    hidden_states = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    summed = (hidden_states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts

def embed_tokens(input_ids, attention_mask=None):  # Uses global model
    """Embed token ids with the global `model`, mean-pool the token outputs and L2-normalize.

    Args:
        input_ids: batched token ids.
        attention_mask: positions with a nonzero entry are included in the mean; zeros
            (padding) are excluded. If None, every position is treated as a real token.
            Previously None crashed inside `mean_pooling`.

    Returns:
        Unit-norm sentence embeddings, one row per input (normalized along dim 1).
    """
    if attention_mask is None:
        # Pooling needs an explicit mask; "no mask" means "attend to everything"
        attention_mask = torch.ones_like(input_ids)
    # The mask only matters for padded batches, but pass it for consistency with pooling
    model_output = model(input_ids, attention_mask=attention_mask)
    embedding = mean_pooling(model_output, attention_mask)
    return F.normalize(embedding, p=2, dim=1)


# Tokenize the input and compute its reference embedding
text = ['This is a test sentence.']
target_idx = 0  # only one input sequence
encoding = tokenizer(text, padding=True, truncation=True, return_tensors='pt')

# Model inputs, moved to the model's device (the model was moved with model.to(device);
# leaving these on the CPU fails with a device mismatch when device is CUDA)
input_ids = encoding['input_ids'].to(device)  # shape: [batch_size, seq_len]
attention_masks = encoding['attention_mask'].to(device)  # same shape as input_ids

# IG baseline: every position (including [CLS]/[SEP]) replaced by the pad token.
# NOTE(review): Captum's BERT tutorial keeps [CLS]/[SEP] in the baseline — consider doing the same.
baseline_token_id = tokenizer.pad_token_id
baseline_ids = torch.full_like(input_ids, baseline_token_id)

# Reference embedding the similarity is measured against. It is a fixed target,
# so compute it without building an autograd graph.
with torch.no_grad():
    reference_embeddings = embed_tokens(input_ids, attention_masks)  # shape: [batch_size, embedding_dim]

if cfg.debug: print(input_ids.shape, attention_masks.shape, baseline_ids.shape, reference_embeddings.shape)  # Debug shapes

torch.Size([1, 8]) torch.Size([1, 8]) torch.Size([1, 8]) torch.Size([1, 384])


In [11]:
def forward_func_cos_sim(input_ids, attention_mask=None, reference_embeddings=reference_embeddings):
    """Forward pass for Captum: cosine similarity of each input's embedding to the reference.

    Args:
        input_ids: batched token ids (during attribution the batch holds one row per
            integration step).
        attention_mask: forwarded unchanged to `embed_tokens`.
        reference_embeddings: vector(s) compared against along dim 1. The default is the
            module-level `reference_embeddings` as it existed when this function was defined.

    Returns:
        Tensor of shape [batch_size] holding cosine *similarity* (not distance).
    """
    similarity = F.cosine_similarity(reference_embeddings, embed_tokens(input_ids, attention_mask), dim=1)
    if cfg.debug:
        print("Similarity:", similarity.shape)
    return similarity

In [12]:
# Sanity check: the input's similarity to its own reference embedding should be ~1.0.
# (`start_dist` holds a cosine *similarity*; the name is kept in case later cells use it.)
start_dist = forward_func_cos_sim(input_ids, attention_masks, reference_embeddings)
print("Original similarity:", start_dist.item())

# Attribute with respect to the output of the model's embedding module
layer = model.embeddings
lig = LayerIntegratedGradients(forward_func_cos_sim, layer)

# forward_func_cos_sim returns one scalar per example, so no `target` is passed.
# Attributions come back as [batch_size, seq_len, embedding_dim].
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=baseline_ids,
    additional_forward_args=(attention_masks, reference_embeddings),
    n_steps=50,  # number of integration steps for the path-integral approximation
    return_convergence_delta=True,
)

Similarity: torch.Size([1])
Original distance: 1.0
Similarity: torch.Size([1])
Similarity: torch.Size([1])
Similarity: torch.Size([50])
Similarity: torch.Size([1])
Similarity: torch.Size([1])


In [None]:
# Token strings of the (single) input sequence, aligned with the attribution rows
tokens = tokenizer.convert_ids_to_tokens(input_ids[target_idx])

# Collapse the embedding dimension: [1, seq_len, emb_dim] -> [seq_len]
token_attributions = attributions.sum(dim=-1).squeeze(0)

# Plain-text listing, one line per token
print("Token-wise attributions:")
for token, score in zip(tokens, token_attributions):
    print(f"{token:>12} : {score.item():.4f}")

# Captum HTML table. The probability/class fields do not apply to a similarity target,
# so they are filled with 0.
vis = viz.VisualizationDataRecord(
    token_attributions,          # word_attributions
    0,                           # pred_prob
    0,                           # pred_class
    0,                           # true_class
    0,                           # attr_class
    token_attributions.sum(),    # attr_score: total attribution
    tokens,                      # raw_input_ids: token strings to display
    delta,                       # convergence_score
)

# Bind the return value so the notebook does not also echo it as cell output
visualisation = viz.visualize_text([vis])

Token-wise attributions:
       [CLS] : 0.0366
        this : 0.0988
          is : 0.1103
           a : 0.1100
        test : 0.2474
    sentence : 0.1514
           . : 0.0873
       [SEP] : 0.0725


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,0.91,[CLS] this is a test sentence . [SEP]
,,,,


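Next, we want attributions not with respect to a single layer but for the whole model. This may be as simple as using Captum's `IntegratedGradients` directly. However, we first have to compute the input embeddings ourselves (pre-embed), because IG interpolates between continuous inputs rather than token ids: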

In [None]:
# to be modified:

# From demo, to get input_embs and ref_input_embs (running ids through emb layer)
def construct_whole_bert_embeddings(input_ids, ref_input_ids,
                                    token_type_ids=None, ref_token_type_ids=None,
                                    position_ids=None, ref_position_ids=None):
    """Run the input ids and the reference (baseline) ids through the global model's embedding module.

    Adapted from the Captum BERT demo, whose task model exposed the encoder as `model.bert`.
    The model loaded with `AutoModel` in this notebook exposes `model.embeddings` directly,
    so `model.bert` raised AttributeError. Use `.bert` when present and fall back to the
    model itself otherwise.

    Returns:
        (input_embeddings, ref_input_embeddings) as produced by the embedding module.
    """
    backbone = getattr(model, "bert", model)
    embedding_module = backbone.embeddings
    input_embeddings = embedding_module(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = embedding_module(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)

    return input_embeddings, ref_input_embeddings

In [None]:
from captum.attr import IntegratedGradients # Not tested yet

# IntegratedGradients
def get_token_attributions_for_model(layer, forward_func, choice_idx, input_embs, ref_input_embs, attention_mask, token_type_ids=None, n_steps=50):
    """Attribute `forward_func`'s output to each input token with (plain) Integrated Gradients.

    Unlike the layer-wise variant, this integrates directly over the supplied embeddings,
    so `input_embs` / `ref_input_embs` must already be embedded.

    Args:
        layer: unused; kept so the signature matches the layer-wise variant.
        forward_func: called as forward_func(embeddings, attention_mask, token_type_ids).
        choice_idx: passed to Captum as `target`. It may be None when `forward_func` returns
            a single scalar per example (e.g. a cosine similarity).
        input_embs: input embeddings; attributions have the same shape.
        ref_input_embs: baseline embeddings, same shape as `input_embs`.
        attention_mask, token_type_ids: forwarded unchanged as additional forward args.
        n_steps: number of integration steps.

    Returns:
        (token_attributions, delta). `token_attributions` is summed over the last
        (embedding) dimension, with a leading size-1 dimension squeezed away, and scaled to
        unit L2 norm (left unscaled if it is all zeros). `delta` is Captum's convergence delta.
    """
    if cfg.debug: print(input_embs.shape)

    # Plain IntegratedGradients over the pre-computed embeddings (`layer` is not used here)
    ig = IntegratedGradients(forward_func)

    attributions, delta = ig.attribute(
        inputs=input_embs,
        baselines=ref_input_embs,
        additional_forward_args=(attention_mask, token_type_ids),
        target=choice_idx,  # output column to attribute; None for scalar-per-example outputs
        n_steps=n_steps,  # number of integration steps
        return_convergence_delta=True,
    )

    # Sum across embedding dims: [batch, seq_len]; squeeze(0) turns a batch of one into [seq_len]
    token_attributions = attributions.sum(dim=-1).squeeze(0)

    # Scale to unit L2 norm; skip when the norm is zero to avoid producing NaNs
    norm = torch.norm(token_attributions)
    if norm > 0:
        token_attributions = token_attributions / norm

    if cfg.debug:
        print('Token attributions:', token_attributions.shape)
        if token_attributions.dim() > 1:
            # Rows still index examples/choices, so selecting choice_idx is meaningful
            print('Attributions per token at choice_idx:', token_attributions[choice_idx])
        else:
            # Already 1-D (one score per token): indexing by choice_idx would pick a token
            print('Attributions per token:', token_attributions)

    return token_attributions, delta