In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel, utils

from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients

torch.manual_seed(42)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Load model
model_name = "sentence-transformers/all-MiniLM-L6-v2" 
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval() # Evaluation mode
model.zero_grad() # Clear summed gradients

# Mean Pooling - To compute embeddings (from huggingface)
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

In [None]:
def embed(text, model=model, tokenizer=tokenizer):
    encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors='pt') # Tokenize

    with torch.no_grad(): # Compute token embeddings
        model_output = model(**encoded_input)

    # Perform pooling and normalize
    embedding = mean_pooling(model_output, encoded_input['attention_mask'])
    return F.normalize(embedding, p=2, dim=1)

# Compute baseline embedding
baseline = ['Dette er et eksempel'] # Baseline input
baseline_embedding = embed(baseline)


In [5]:
# Distance from baseline embedding
def predict_baseline_distance(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    with torch.no_grad():
        output = model(inputs, token_type_ids=token_type_ids, position_ids=position_ids, attention_mask=attention_mask)

    embedding = mean_pooling(output, inputs['attention_mask'])
    embedding = F.normalize(embedding, p=2, dim=1)

    return F.cosine_similarity(embedding, baseline_embedding)

In [None]:
def construct_attention_mask(input_ids): # we don't need masking, so just 1's
    return torch.ones_like(input_ids)

ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
ref_input_ids = 

input_ids = tokenizer.encode(baseline[0])
input_ids = torch.tensor([input_ids], device=device)

attention_mask = construct_attention_mask(input_ids)

In [13]:
lig = LayerIntegratedGradients(predict_baseline_distance, model.embeddings)

attributions, delta = lig.attribute(inputs=input_ids,
                                    additional_forward_args=(attention_mask,),
                                    return_convergence_delta=True)

# Convert attributions to list for visualization
attribution_scores = attributions.sum(dim=-1).squeeze(0)  # Sum over embedding dim
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

# Visualization
viz_data_records = [viz.VisualizationDataRecord(
    word_attributions=attribution_scores.cpu().detach().numpy(),
    pred_prob=0.0,
    pred_class="",
    true_class="",
    attr_class="",
    attr_score=attribution_scores.sum().item(),
    raw_input=tokens,
    convergence_score=delta.item()
)]

viz.visualize_text(viz_data_records)

IndexError: too many indices for tensor of dimension 2