Loading the model: 

In [17]:
from transformers import AutoModel, AutoTokenizer
import torch
import numpy as np

tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code = True, use_fast = False)

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code = True)

prompt = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."

input = tokenizer(prompt, return_tensors = "pt")


Then we output some model specific parameter values. 


In [18]:
num_layers = len(model.encoder.layer)
emb_dim = model.config.hidden_size
num_heads = model.config.num_attention_heads
head_dim = emb_dim // num_heads

Then we need some code to "hook" specific activations from the model (such as the output from MLPs, attention heads). We collect the hooks in dictionaries. 

In [19]:
#Hook attention heads

attn_act = {}
def hook_atthead(layer_index,head_index): 
    def hook(module, inp, out):
        batch, seq_len, embed_dim = out.shape
        out = out.view(batch, seq_len, num_heads, head_dim)
        attn_act[f"{layer_index}.{head_index}"] = out[:, :, j, :].detach()
    return hook

for i in range(num_layers):
    for j in range(num_heads): 
        model.encoder.layer[i].attention.output.dense.register_forward_hook(hook_atthead(i,j))

#Hook MLPs

mlp_act = {}

def hook_mlp(layer_index):
    def hook(module, inp, out):
        mlp_act[f"{layer_index}"] = out.detach()
    return hook

for i in range(num_layers): 
    model.encoder.layer[i].mlp.wo.register_forward_hook(hook_mlp(i))

embeddings = model(**input, output_hidden_states=False)

print(attn_act)

print(mlp_act)
                                                                           





{'0.0': tensor([[[ 0.0688, -0.1184, -0.0274,  ..., -0.3140, -0.1124,  0.0262],
         [-0.7773, -0.2132, -0.0038,  ...,  0.0340, -0.0912,  0.0093],
         [ 0.0143,  0.2995,  0.0009,  ..., -0.2675, -0.0612, -0.1207],
         ...,
         [ 0.2530,  0.0853,  0.0781,  ...,  0.0029,  0.0707, -0.1558],
         [ 0.0585, -0.0440,  0.0403,  ..., -0.0829, -0.0058, -0.0536],
         [ 0.0915,  0.0103,  0.0395,  ..., -0.0457, -0.0447, -0.0594]]]), '0.1': tensor([[[ 0.0688, -0.1184, -0.0274,  ..., -0.3140, -0.1124,  0.0262],
         [-0.7773, -0.2132, -0.0038,  ...,  0.0340, -0.0912,  0.0093],
         [ 0.0143,  0.2995,  0.0009,  ..., -0.2675, -0.0612, -0.1207],
         ...,
         [ 0.2530,  0.0853,  0.0781,  ...,  0.0029,  0.0707, -0.1558],
         [ 0.0585, -0.0440,  0.0403,  ..., -0.0829, -0.0058, -0.0536],
         [ 0.0915,  0.0103,  0.0395,  ..., -0.0457, -0.0447, -0.0594]]]), '0.2': tensor([[[ 0.0688, -0.1184, -0.0274,  ..., -0.3140, -0.1124,  0.0262],
         [-0.7773, -0