Loading the model: 

In [63]:
from transformers import AutoModel, AutoTokenizer
import torch
import numpy as np

tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code = True, use_fast = False)

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code = True)

prompt = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."

input = tokenizer(prompt, return_tensors = "pt")


Then we output some model specific parameter values. 


In [64]:
num_layers = len(model.encoder.layer)
emb_dim = model.config.hidden_size
num_heads = model.config.num_attention_heads
head_dim = emb_dim // num_heads

Then we need some code to "hook" specific activations from the model (such as the output from MLPs, attention heads). We collect the hooks in dictionaries. 

In [None]:
#Hook attention heads

attn_act = {}
def hook_atthead(layer_index,head_index): 
    def hook(module, inp, out):
        batch, seq_len, embed_dim = out.shape
        out = out.view(batch, seq_len, num_heads, head_dim)
        attn_act[f"{layer_index}.{head_index}"] = out[:, :, j, :].detach()
    return hook

for i in range(num_layers):
    for j in range(num_heads): 
        model.encoder.layer[i].attention.output.dense.register_forward_hook(hook_atthead(i,j))

embeddings = model(**input, output_hidden_states=False)

print(attn_act["1.2"])

mlp_act = {}

def hook_mlp(layer_index):
    def hook(module, input, output):
        mlp_act[f"{layer_index}"] = output.detach()  # [batch, seq_len, hidden_dim]
    return hook

#Hook MLPs





tensor([[[-0.0560,  0.0559,  0.0365,  ..., -0.0143,  0.0070, -0.0210],
         [-0.1482,  0.1395,  0.1118,  ..., -0.0241,  0.1141,  0.0900],
         [ 0.4656,  0.1625, -0.1174,  ..., -0.0746, -0.1491,  0.0180],
         ...,
         [ 0.5019, -0.1147, -0.1381,  ..., -0.0634, -0.0950, -0.0970],
         [ 0.1408,  0.0327,  0.0639,  ...,  0.1274, -0.0340, -0.1094],
         [ 0.1683,  0.0364,  0.0613,  ...,  0.1993, -0.0167, -0.0698]]])
