In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- Setup ---
model_name = "meta-llama/Llama-2-7b-chat-hf" 
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, 
    dtype=torch.float16, 
    device_map="auto"          
)

# Force the 'eager' implementation
model.set_attn_implementation('eager')

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.36s/it]
Some parameters are on the meta device because they were offloaded to the cpu.


In [12]:
text = "The currency of the United States is called the"
inputs = tokenizer(text, return_tensors='pt').to(model.device)
original_tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
num_original_tokens = len(original_tokens)

In [13]:
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

In [14]:
attention = outputs.attentions
last_layer_attention = attention[-1].cpu().float()
tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

In [15]:
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
   

In [16]:
last_token_scores = last_layer_attention[0, :, -1, :]
s_token_scores = last_token_scores[:, 0]
best_head_index = s_token_scores.argmin()
specialist_scores = last_token_scores[best_head_index]
probabilities = torch.softmax(specialist_scores, dim=0)

In [17]:
print(f"Found 'Specialist' Head: #{best_head_index.item()}\n")
print(f"Influence on predicting the token *after* '{tokens[-1]}':\n")
print("Token".ljust(12) + "Influence Score")
print("-" * 30)

for token, score in zip(tokens, probabilities):
    print(f"{token.ljust(12)}: {score.item():.4f} ({(score.item() * 100):.2f}%)")

Found 'Specialist' Head: #24

Influence on predicting the token *after* '▁the':

Token       Influence Score
------------------------------
<s>         : 0.0903 (9.03%)
▁The        : 0.0961 (9.61%)
▁currency   : 0.0990 (9.90%)
▁of         : 0.0961 (9.61%)
▁the        : 0.1004 (10.04%)
▁United     : 0.0962 (9.62%)
▁States     : 0.1022 (10.22%)
▁is         : 0.0957 (9.57%)
▁called     : 0.1051 (10.51%)
▁the        : 0.1188 (11.88%)


In [None]:
# head_view(
#     [a.cpu() for a in attention], # Move tensors to CPU for viz
#     tokens
# )

NameError: name 'head_view' is not defined