## Attention visualisation with BertViz

Interactive visualisation of the attention weights of a Jina AI embedding model, starting with the model view:

In [None]:
from transformers import AutoTokenizer, AutoModel, utils # torch >= 2.6
from bertviz import model_view, head_view
utils.logging.set_verbosity_error()  # Suppress standard warnings

model_name = "jinaai/jina-embeddings-v2-base-en"  # Find popular HuggingFace models here: https://huggingface.co/models
input_text = "The cat sat on the mat"  
model = AutoModel.from_pretrained(model_name, output_attentions=True)  # Configure model to return attention values
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer.encode(input_text, return_tensors='pt')  # Tokenize input text
outputs = model(inputs)  # Run model
attention = outputs[-1]  # Retrieve attention from model outputs
tokens = tokenizer.convert_ids_to_tokens(inputs[0])  # Convert input ids to token strings
model_view(attention, tokens)  # Display model view


We switch to the small model, which has fewer layers and therefore gives more readable plots:

In [3]:
model_name = "jinaai/jina-embeddings-v2-small-en"  # Find popular HuggingFace models here: https://huggingface.co/models

model = AutoModel.from_pretrained(model_name, output_attentions=True)  # Configure model to return attention values
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer.encode(input_text, return_tensors='pt')  # Tokenize input text
outputs = model(inputs)  # Run model
attention = outputs[-1]  # Retrieve attention from model outputs
tokens = tokenizer.convert_ids_to_tokens(inputs[0])  # Convert input ids to token strings
model_view(attention, tokens)  # Display model view


We can also visualise the heads separately in the head view. Choose a layer; each colour denotes one head. Click a colour to hide that head, and hover over a token to see its attention weights:

In [4]:
head_view(attention, tokens)


BertViz also offers a neuron view, which traces the computations behind the attention weights. The values it needs are not returned by the Hugging Face API, so we would have to customise it for Jina AI if we want to use it.
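As a rough illustration of what such a view would have to expose, the sketch below computes, for a single query token, the element-wise query-key products, the scaled dot-product scores and the resulting softmax weights. It is dependency-free Python on made-up vectors; `attention_row`, `query` and `keys` are hypothetical names, and the learned projections and any positional bias of a real model are left out.

```python
import math

def attention_row(query, keys):
    """Return element-wise products, scaled scores and softmax weights
    for one query vector against a list of key vectors."""
    # Element-wise products q * k, one list per key
    products = [[q * k for q, k in zip(query, key)] for key in keys]
    # Scaled dot product: sum of the products divided by sqrt(dimension)
    scale = math.sqrt(len(query))
    scores = [sum(p) / scale for p in products]
    # Numerically stable softmax over the scores
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return products, scores, weights

# Toy vectors (illustrative values only, not taken from any model)
query = [1.0, 0.0, -1.0, 0.5]
keys = [
    [1.0, 0.0, -1.0, 0.5],    # aligned with the query
    [0.0, 1.0, 0.0, 0.0],     # orthogonal to the query
    [-1.0, 0.0, 1.0, -0.5],   # opposite to the query
]
products, scores, weights = attention_row(query, keys)
print("scores :", scores)
print("weights:", weights)
```

The weights form one row of the attention matrix that the model view and head view display; the intermediate `products` and `scores` are the part that is not available from the standard model outputs.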

## Attribution with Captum

We had a lot of issues with the environment here, so we first run the ToyModel test to check that the kernel survives.

In [7]:
import numpy as np

import torch
import torch.nn as nn

from captum.attr import IntegratedGradients  # requires `pip install captum` and numpy<2.0

torch.manual_seed(123)
np.random.seed(123)

In [9]:
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(3, 3)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(3, 2)

        # initialize weights and biases
        self.lin1.weight = nn.Parameter(torch.arange(-4.0, 5.0).view(3, 3))
        self.lin1.bias = nn.Parameter(torch.zeros(1,3))
        self.lin2.weight = nn.Parameter(torch.arange(-3.0, 3.0).view(2, 3))
        self.lin2.bias = nn.Parameter(torch.ones(1,2))

    def forward(self, input):
        return self.lin2(self.relu(self.lin1(input)))


model = ToyModel()
model.eval()

ToyModel(
  (lin1): Linear(in_features=3, out_features=3, bias=True)
  (relu): ReLU()
  (lin2): Linear(in_features=3, out_features=2, bias=True)
)

In [None]:
inp = torch.rand(2, 3)
baseline = torch.zeros(2, 3)

ig = IntegratedGradients(model)
attributions, delta = ig.attribute(inp, baseline, target=0, return_convergence_delta=True)

print('IG Attributions:', attributions)
print('Convergence Delta:', delta)
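To make clear what the attributions and the convergence delta mean, here is a dependency-free sketch of integrated gradients for the same toy network. It copies the weights from `ToyModel`, attributes to output 0, and approximates the path integral with a midpoint Riemann sum. This is an assumption made for simplicity and not necessarily the integration rule Captum uses. The names `forward`, `grad` and `integrated_gradients` are hypothetical helpers.

```python
# Weights copied from ToyModel above; only output row 0 (target=0) is used.
W1 = [[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]]  # lin1.weight, bias is zero
W2_ROW0 = [-3.0, -2.0, -1.0]                                   # lin2.weight[0]
B2_0 = 1.0                                                     # lin2.bias[0]

def pre_activations(x):
    """lin1(x) with zero bias."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W1]

def forward(x):
    """Output 0 of the toy network: lin2(relu(lin1(x)))[0]."""
    z = pre_activations(x)
    return sum(w * max(zi, 0.0) for w, zi in zip(W2_ROW0, z)) + B2_0

def grad(x):
    """Analytic gradient of output 0 with respect to the three inputs."""
    z = pre_activations(x)
    active = [1.0 if zi > 0 else 0.0 for zi in z]  # derivative of ReLU
    return [
        sum(W2_ROW0[i] * active[i] * W1[i][j] for i in range(3))
        for j in range(3)
    ]

def integrated_gradients(x, baseline, n_steps=50):
    """(x - baseline) times the average gradient along the straight path."""
    total = [0.0, 0.0, 0.0]
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        total = [t + g for t, g in zip(total, grad(point))]
    return [(xi - b) * t / n_steps for xi, b, t in zip(x, baseline, total)]

x = [0.2, 0.5, 0.9]          # illustrative input, not the torch.rand sample above
baseline = [0.0, 0.0, 0.0]
attributions = integrated_gradients(x, baseline)
# Completeness check: attributions should sum to forward(x) - forward(baseline)
gap = sum(attributions) - (forward(x) - forward(baseline))
print("attributions:", attributions)
print("completeness gap:", gap)
```

A gap close to zero indicates that the numerical integration has converged, which is the property the convergence delta reports.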

## Attribution visualisations for Encoders with Captum

Good. We have a demo for a BERT QA model, which we will try to adapt for Jina AI. First we modify the BERT demo slightly to load the Jina AI checkpoint. We do not want the QA setup in the final version, so we do not train the QA head here; the attributions are therefore essentially random and only serve to check that everything runs. For a non-QA Jina model we need a scalar other than the QA logits to attribute to; this might not require a major change, though.
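One candidate for such a scalar, sketched here purely as an assumption and not as the approach the final version takes, is the cosine similarity between the pooled sentence embedding and a fixed reference embedding. The dependency-free sketch below assumes mean pooling over token embeddings; `mean_pool`, `cosine_similarity` and `similarity_forward` are hypothetical names, and the vectors are made up.

```python
import math

def mean_pool(token_embeddings):
    """Average the token vectors into a single sentence vector."""
    n_tokens = len(token_embeddings)
    dim = len(token_embeddings[0])
    return [sum(tok[d] for tok in token_embeddings) / n_tokens for d in range(dim)]

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def similarity_forward(token_embeddings, reference_embedding):
    """Scalar target: similarity of the pooled sentence to a reference."""
    return cosine_similarity(mean_pool(token_embeddings), reference_embedding)

# Toy token embeddings for a three-token sentence (illustrative values)
sentence = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
reference = mean_pool(sentence)          # identical sentence, so similarity 1
other_reference = [1.0, -1.0]            # orthogonal direction, so similarity 0
same_score = similarity_forward(sentence, reference)
other_score = similarity_forward(sentence, other_reference)
print(same_score, other_score)
```

A function of this shape returns one number per input, which is what an attribution method needs in place of the start and end logits.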

In [40]:
import torch
import torch.nn as nn

from transformers import AutoTokenizer, AutoModel, utils 
from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_name = "jinaai/jina-embeddings-v2-small-en"
# model = AutoModel.from_pretrained(model_name, output_attentions=True)  # Configure model to return attention values
# print(model.eval()) # fix parameters, print to see architecture
# model.zero_grad()   # clearing (potential) accumulated gradients

# tokenizer = AutoTokenizer.from_pretrained(model_name)



from transformers import BertTokenizer, BertForQuestionAnswering, BertConfig

model = BertForQuestionAnswering.from_pretrained(model_name)
model.to(device)
print(model.eval())
model.zero_grad()

# load tokenizer
tokenizer = BertTokenizer.from_pretrained(model_name)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at jinaai/jina-embeddings-v2-small-en and are newly initialized: ['bert.embeddings.position_embeddings.weight', 'bert.encoder.layer.0.intermediate.dense.bias', 'bert.encoder.layer.0.intermediate.dense.weight', 'bert.encoder.layer.0.output.LayerNorm.bias', 'bert.encoder.layer.0.output.LayerNorm.weight', 'bert.encoder.layer.0.output.dense.bias', 'bert.encoder.layer.0.output.dense.weight', 'bert.encoder.layer.1.intermediate.dense.bias', 'bert.encoder.layer.1.intermediate.dense.weight', 'bert.encoder.layer.1.output.LayerNorm.bias', 'bert.encoder.layer.1.output.LayerNorm.weight', 'bert.encoder.layer.1.output.dense.bias', 'bert.encoder.layer.1.output.dense.weight', 'bert.encoder.layer.2.intermediate.dense.bias', 'bert.encoder.layer.2.intermediate.dense.weight', 'bert.encoder.layer.2.output.LayerNorm.bias', 'bert.encoder.layer.2.output.LayerNorm.weight', 'bert.encoder.layer.2.output.dense.bias', 'bert.enco

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30528, 512, padding_idx=0)
      (position_embeddings): Embedding(8192, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, e

In [None]:
# Of course, with Jina AI we don't have logits:
# 'BaseModelOutputWithPoolingAndCrossAttentions' object has no attribute 'start_logits' 

def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    output = model(inputs, token_type_ids=token_type_ids,
                 position_ids=position_ids, attention_mask=attention_mask, )
    return output.start_logits, output.end_logits

def squad_pos_forward_func(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
    pred = predict(inputs,
                   token_type_ids=token_type_ids,
                   position_ids=position_ids,
                   attention_mask=attention_mask)
    pred = pred[position]
    return pred.max(1).values

In [None]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

In [36]:
def construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id):
    question_ids = tokenizer.encode(question, add_special_tokens=False)
    text_ids = tokenizer.encode(text, add_special_tokens=False)

    # construct input token ids
    input_ids = [cls_token_id] + question_ids + [sep_token_id] + text_ids + [sep_token_id]

    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(question_ids) + [sep_token_id] + \
        [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(question_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids): # we don't need masking, so just 1's
    return torch.ones_like(input_ids)

def construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                    token_type_ids=None, ref_token_type_ids=None, \
                                    position_ids=None, ref_position_ids=None):
    input_embeddings = model.bert.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = model.bert.embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)
    
    return input_embeddings, ref_input_embeddings
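The following dependency-free sketch shows what the input and reference sequences built by these helpers look like. The token ids are hypothetical placeholders (the real ones come from the tokenizer), and `build_pair` and `token_types` are illustrative stand-ins that mirror the list logic of `construct_input_ref_pair` and `construct_input_ref_token_type_pair` without tensors.

```python
# Hypothetical special-token ids; the notebook reads the real ones from the tokenizer.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0

def build_pair(question_ids, text_ids):
    """Build the input sequence and a reference with every content token padded out."""
    input_ids = [CLS_ID] + question_ids + [SEP_ID] + text_ids + [SEP_ID]
    ref_ids = (
        [CLS_ID] + [PAD_ID] * len(question_ids) + [SEP_ID]
        + [PAD_ID] * len(text_ids) + [SEP_ID]
    )
    return input_ids, ref_ids

def token_types(seq_len, sep_ind):
    """Same rule as the notebook: positions up to sep_ind get type 0, the rest type 1."""
    return [0 if i <= sep_ind else 1 for i in range(seq_len)]

question_ids = [7, 8, 9]   # made-up ids for a three-token question
text_ids = [20, 21]        # made-up ids for a two-token text
input_ids, ref_ids = build_pair(question_ids, text_ids)
types = token_types(len(input_ids), len(question_ids))

print("input    :", input_ids)
print("reference:", ref_ids)
print("types    :", types)
```

The reference keeps the special tokens in place and replaces everything else with the padding id, so the attribution measures the contribution of the content tokens relative to an "empty" sequence of the same length.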

In [37]:
# Of course, jina does not predict an answer

question, text = "What is important to us?", "It is important to us to include, empower and support humans of all kinds."


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [38]:
ground_truth = 'to include, empower and support humans of all kinds'

ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False)
ground_truth_end_ind = indices.index(ground_truth_tokens[-1])
ground_truth_start_ind = ground_truth_end_ind - len(ground_truth_tokens) + 1

In [39]:
start_scores, end_scores = predict(input_ids, \
                                   token_type_ids=token_type_ids, \
                                   position_ids=position_ids, \
                                   attention_mask=attention_mask)


print('Question: ', question)
print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))

tensor([[-0.5290, -0.5660, -0.5650, -0.5285, -0.5705, -0.5309, -0.5193, -0.5392,
         -0.4925, -0.5685, -0.5213, -0.5776, -0.5341, -0.5776, -0.5948, -0.5635,
         -0.4584, -0.6614, -0.5666, -0.5082, -0.4793, -0.6049, -0.4562, -0.5345,
         -0.5142, -0.5392]], grad_fn=<CloneBackward0>) tensor([[-0.0385, -0.0958, -0.0753, -0.0801, -0.0424, -0.0667, -0.0286, -0.0656,
         -0.0445, -0.1022, -0.0946, -0.0565, -0.0786, -0.0565, -0.0282, -0.0681,
         -0.0108, -0.1197, -0.0813, -0.0712, -0.0399, -0.1079, -0.1078, -0.0787,
         -0.0562, -0.0656]], grad_fn=<CloneBackward0>)
Question:  What is important to us?
Predicted Answer:  
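The empty answer follows directly from the scores printed above: the untrained QA head places the best start position after the best end position, so the token slice is empty. The dependency-free sketch below replays this with the printed values; `argmax` is a small hypothetical helper, and the token list is the tokenised question and text used in this notebook.

```python
def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])

# Start and end scores copied from the printed tensors above
start_scores = [
    -0.5290, -0.5660, -0.5650, -0.5285, -0.5705, -0.5309, -0.5193, -0.5392,
    -0.4925, -0.5685, -0.5213, -0.5776, -0.5341, -0.5776, -0.5948, -0.5635,
    -0.4584, -0.6614, -0.5666, -0.5082, -0.4793, -0.6049, -0.4562, -0.5345,
    -0.5142, -0.5392,
]
end_scores = [
    -0.0385, -0.0958, -0.0753, -0.0801, -0.0424, -0.0667, -0.0286, -0.0656,
    -0.0445, -0.1022, -0.0946, -0.0565, -0.0786, -0.0565, -0.0282, -0.0681,
    -0.0108, -0.1197, -0.0813, -0.0712, -0.0399, -0.1079, -0.1078, -0.0787,
    -0.0562, -0.0656,
]
tokens = (
    "[CLS] what is important to us ? [SEP] it is important to us to include , "
    "em ##power and support humans of all kinds . [SEP]"
).split()

start = argmax(start_scores)
end = argmax(end_scores)
# Same slice as in the notebook cell: start index up to and including end index
answer = " ".join(tokens[start:end + 1])
print(start, end, repr(answer))
```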


In [30]:
lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)

attributions_start, delta_start = lig.attribute(inputs=input_ids,
                                  baselines=ref_input_ids,
                                  additional_forward_args=(token_type_ids, position_ids, attention_mask, 0),
                                  return_convergence_delta=True)
attributions_end, delta_end = lig.attribute(inputs=input_ids, baselines=ref_input_ids,
                                additional_forward_args=(token_type_ids, position_ids, attention_mask, 1),
                                return_convergence_delta=True)

In [31]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

attributions_start_sum = summarize_attributions(attributions_start)
attributions_end_sum = summarize_attributions(attributions_end)
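In plain terms, `summarize_attributions` collapses the embedding dimension to one score per token and then rescales the scores to unit L2 norm. A dependency-free sketch of the same two steps on made-up numbers (`summarize` is a hypothetical stand-in for the tensor version):

```python
import math

def summarize(token_attributions):
    """Sum each token's attribution vector, then divide by the L2 norm of the sums."""
    per_token = [sum(vector) for vector in token_attributions]
    norm = math.sqrt(sum(score * score for score in per_token))
    return [score / norm for score in per_token]

# Three tokens with two embedding dimensions each (illustrative values)
toy_attributions = [[0.5, 0.5], [-1.0, -1.0], [0.0, 2.0]]
scores = summarize(toy_attributions)
print(scores)
```

The sign of each score is preserved, so negative values still mark tokens that push the target down; only the overall scale is normalised, which keeps the colour intensities comparable between the start and end visualisations.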

In [32]:
# store a couple of samples in an array for visualization purposes
start_position_vis = viz.VisualizationDataRecord(
                        attributions_start_sum,
                        torch.max(torch.softmax(start_scores[0], dim=0)),
                        torch.argmax(start_scores),
                        torch.argmax(start_scores),
                        str(ground_truth_start_ind),
                        attributions_start_sum.sum(),       
                        all_tokens,
                        delta_start)

end_position_vis = viz.VisualizationDataRecord(
                        attributions_end_sum,
                        torch.max(torch.softmax(end_scores[0], dim=0)),
                        torch.argmax(end_scores),
                        torch.argmax(end_scores),
                        str(ground_truth_end_ind),
                        attributions_end_sum.sum(),       
                        all_tokens,
                        delta_end)

print('\033[1m', 'Visualizations For Start Position', '\033[0m')
viz.visualize_text([start_position_vis])

print('\033[1m', 'Visualizations For End Position', '\033[0m')
viz.visualize_text([end_position_vis])

Visualizations For Start Position


| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 21.0 | 21 (0.04) | 13.0 | -0.02 | [CLS] what is important to us ? [SEP] it is important to us to include , em ##power and support humans of all kinds . [SEP] |


Visualizations For End Position


| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 1.0 | 1 (0.04) | 23.0 | 0.15 | [CLS] what is important to us ? [SEP] it is important to us to include , em ##power and support humans of all kinds . [SEP] |


