In [1]:
import torch
import torch.nn as nn
from transformers import LlamaForCausalLM, AutoTokenizer
import random
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import math

# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
# Local snapshot directory of the Llama-2-7b-hf checkpoint.
llama_dir = '/mntcephfs/data/ruoyusun/liziniu/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/8a0442e81540efaeb1a0fe3e95477b5e0edfd423'

# "eager" attention is requested explicitly when loading; the model is then
# moved to the selected device in the same expression.
llama = LlamaForCausalLM.from_pretrained(llama_dir, attn_implementation="eager").to(device)

# Lower-triangular boolean mask, shaped (1, 1, max_positions, max_positions),
# that attention_score_wo_rotary slices to build its causal mask.
max_positions = 4096
attn_bias = (
    torch.ones((max_positions, max_positions), dtype=torch.bool)
    .tril()
    .view(1, 1, max_positions, max_positions)
    .to(device)
)
def attention_score_wo_rotary(layer_idx, hidden_states, num_heads=32, head_dim = 128):
    """Return causally masked, pre-softmax attention logits of one layer.

    The hidden states are passed through the layer's input layernorm and its
    query/key projections, and the scaled dot product ``q @ k^T / sqrt(head_dim)``
    is computed directly. No rotary position embedding is applied to the
    queries or keys, and no softmax is taken.

    Args:
        layer_idx: Index into ``llama.model.layers`` selecting the decoder layer.
        hidden_states: Tensor unpacked as ``(batch, seq_len, hidden)``; it is
            moved to ``device`` before use.
        num_heads: Number of heads the projections are split into.
        head_dim: Per-head dimension; also sets the ``1/sqrt(head_dim)`` scale.

    Returns:
        Logits of the first batch element only, shaped
        ``(num_heads, seq_len, seq_len)``. Entries above the diagonal hold the
        most negative finite value of the logits' dtype.
    """
    hidden_states = hidden_states.to(device)
    bsz, q_len, _ = hidden_states.size()

    layer = llama.model.layers[layer_idx]
    normed = layer.input_layernorm(hidden_states)

    def _split_heads(projected):
        # (batch, seq, heads * dim) -> (batch, heads, seq, dim)
        return projected.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

    # Only queries and keys are needed for the logits; the value projection
    # was computed and discarded before, so it is no longer evaluated.
    query_states = _split_heads(layer.self_attn.q_proj(normed))
    key_states = _split_heads(layer.self_attn.k_proj(normed))

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)

    # Slice the precomputed lower-triangular mask down to the current lengths.
    query_length, key_length = attn_weights.shape[-2], attn_weights.shape[-1]
    causal_mask = attn_bias[:, :, key_length - query_length : key_length, :key_length]

    # Masked positions get the dtype's most negative finite value.
    mask_value = torch.full(
        [],
        torch.finfo(attn_weights.dtype).min,
        dtype=attn_weights.dtype,
        device=attn_weights.device,
    )
    attn_weights = torch.where(causal_mask, attn_weights, mask_value)
    return attn_weights[0]

# Load the tokenizer from the same directory as the model checkpoint.
tokenizer = AutoTokenizer.from_pretrained(llama_dir)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [36]:
def filter(query3_number, M = 10, layer_idx = 0, head_idx = 0, candidate_token_list = torch.arange(llama.vocab_size)): # find the M tokens closest to the query token
    """Select the M candidate tokens with the highest attention logit against a token.

    For every candidate ``m`` the two-token sequence ``[query3_number, m]`` is
    run through ``llama`` (all position ids set to 1), the hidden states at
    index ``layer_idx`` are fed to ``attention_score_wo_rotary``, and the logit
    at row 1 / column 0 of head ``head_idx`` is used as the candidate's score.

    Only tokens from ``candidate_token_list`` are considered. ``query3_number``
    itself and repeated candidates are skipped, so the result never contains
    the same token twice. Scoring runs under ``torch.no_grad()`` so no autograd
    graph is kept alive across the (potentially vocabulary-sized) loop.

    NOTE: the name shadows the builtin ``filter``; it is kept because other
    cells call it by this name.

    Args:
        query3_number: Token id placed at the first position of each pair.
        M: Maximum number of tokens to return.
        layer_idx: Index of the hidden state / decoder layer to score with.
        head_idx: Attention head whose logit is used as the score.
        candidate_token_list: Iterable of candidate token ids (ints or 0-d tensors).

    Returns:
        dict mapping score (0-d tensor) -> token id (int), ordered from the
        highest score to the lowest. Holds ``min(M, number of usable
        candidates)`` entries; empty when ``M <= 0`` or nothing is usable.
    """
    if M <= 0:
        return {}

    query_token = int(query3_number)
    seen = {query_token}  # pre-seeding with the query token makes the loop skip it
    tokens = []
    scores = []
    with torch.no_grad():
        for m in tqdm(candidate_token_list):
            token = int(m)
            if token in seen:
                continue
            seen.add(token)
            input_ids = torch.tensor([[query_token, token]], dtype=torch.long, device=device)
            h_s = llama(input_ids, output_hidden_states = True,
                        position_ids = torch.ones_like(input_ids).long()).hidden_states[layer_idx] # type: ignore
            score = attention_score_wo_rotary(layer_idx=layer_idx, hidden_states=h_s)[head_idx][1, 0]
            scores.append(score)
            tokens.append(token)

    if not tokens:
        return {}

    # Rank once at the end; topk returns indices sorted from largest to smallest.
    all_scores = torch.stack(scores)
    top_indices = torch.topk(all_scores, k=min(M, len(tokens))).indices.tolist()
    # 0-d tensors hash by identity, so tied scores do not overwrite each other.
    return {all_scores[i]: tokens[i] for i in top_indices}

In [74]:
def testing(N = 10, M = 10, inputs = None,layer_idx = 0, head_idx = 0, threshold = 1.0, candidate_strategy = 'all', candidate_num = 1000): # run N experiments, each selecting the top-M tokens
    """Check whether swapping the last token for a similar one keeps the key preference.

    Each accepted experiment takes a 3-token prompt, computes the attention
    logits of head ``head_idx`` via ``attention_score_wo_rotary`` and looks at
    the last row: the gap between the logit for key 0 and for key 1. If the
    absolute gap exceeds ``threshold``, the last token is replaced in turn by
    each of the tokens returned by ``filter`` and the sign of the new gap is
    compared with the sign of the original one.

    When ``inputs`` is None a fresh random prompt is drawn on every attempt
    (attempts that do not pass the threshold are discarded and retried, so a
    very high threshold can take a long time). When ``inputs`` is given, that
    prompt is used for every experiment.

    Args:
        N: Number of accepted experiments to collect.
        M: Number of replacement tokens requested from ``filter``.
        inputs: Optional tensor of token ids shaped (1, 3).
        layer_idx: Hidden-state / decoder-layer index used for scoring.
        head_idx: Attention head to inspect.
        threshold: Minimum absolute gap between the two key logits.
        candidate_strategy: 'all' (whole vocabulary) or 'random'
            (``candidate_num`` tokens sampled without replacement).
        candidate_num: Sample size for the 'random' strategy.

    Returns:
        ``(compare_list, word_list)``: dicts keyed by experiment index.
        ``compare_list[n]`` is a list of 0/1 flags (1 = same sign as the
        original gap); ``word_list[n]`` is the prompt tensor used.

    Raises:
        ValueError: If ``candidate_strategy`` is unknown, or if a caller-supplied
            ``inputs`` does not pass the threshold (retrying the same prompt
            could never succeed).
    """
    if candidate_strategy not in ('all', 'random'):
        raise ValueError(f"unknown candidate_strategy: {candidate_strategy!r}")

    vocab_size = llama.vocab_size
    compare_list = {}
    word_list = {}
    n = 0
    with torch.no_grad():  # inference only; avoids building autograd graphs
        while n < N:
            # Draw a new random prompt per attempt unless the caller fixed one.
            if inputs is None:
                trial_inputs = torch.randint(0, vocab_size, (1, 3)).long().to(device)
            else:
                trial_inputs = inputs.to(device)

            # Extract the hidden states from llama first.
            position_ids = torch.ones_like(trial_inputs).long().to(device)
            outputs = llama(trial_inputs, output_hidden_states = True, position_ids = position_ids) # type: ignore
            hidden_inputs = outputs.hidden_states[layer_idx]
            attention_score = attention_score_wo_rotary(layer_idx, hidden_states=hidden_inputs)[head_idx]
            base_gap = attention_score[-1, 0] - attention_score[-1, 1]

            # Only keep prompts where one key clearly dominates the other.
            if not abs(base_gap) > threshold:
                if inputs is not None:
                    raise ValueError(
                        "the supplied inputs do not exceed the threshold; "
                        "retrying the same prompt cannot succeed"
                    )
                continue

            word_list[n] = trial_inputs
            token1 = int(trial_inputs[0, 0])
            token2 = int(trial_inputs[0, 1])
            query3_number = int(trial_inputs[0, -1])

            if candidate_strategy == 'all':
                candidate_token_list = list(range(vocab_size))
            else:
                candidate_token_list = random.sample(range(vocab_size), candidate_num)

            # Tokens closest to query3 according to filter's score.
            top_token_list = filter(query3_number,
                                    M, layer_idx, head_idx, candidate_token_list=candidate_token_list)

            compare = []
            for token in top_token_list.values():
                input_new = torch.tensor([[token1, token2, int(token)]]).to(device)
                out_new = llama(input_new, output_hidden_states = True, position_ids = torch.ones_like(input_new).long()) # type: ignore
                h_s = out_new.hidden_states[layer_idx]
                attention_new = attention_score_wo_rotary(layer_idx,
                                                hidden_states=h_s)[head_idx]
                new_gap = attention_new[-1, 0] - attention_new[-1, 1]
                # 1 when the replacement keeps the same preferred key.
                compare.append(1 if new_gap * base_gap >= 0 else 0)
            compare_list[n] = compare
            n += 1
    return compare_list, word_list
            

In [84]:
# Run one experiment on a fixed 3-token prompt against the whole vocabulary.
# NOTE: this sample is not passed to testing(), which builds its own candidate list.
candidate_token_list = random.sample(range(32000),5000)
# Use the notebook-wide `device` so the cell also runs without a GPU.
inputs = torch.tensor([[   78, 27557, 15044]], device=device)
out = testing(N = 1,candidate_strategy = 'all', candidate_num = 5000, inputs = inputs)

100%|████████████████████████████████████████████████████████████████| 32000/32000 [18:58<00:00, 28.12it/s]


In [87]:
# Inspect head 0 of layer 0 for the fixed prompt `inputs`.
# Position ids are all zeros here (testing() uses all ones).
position_ids = torch.zeros_like(inputs).reshape(1,-1).long()
# hidden_states[0] is the first entry of the model's hidden-state tuple.
outputs = llama(inputs,output_hidden_states = True, position_ids = position_ids).hidden_states[0]
# Last expression of the cell: the masked logits of head 0 are displayed.
attention_score_wo_rotary(layer_idx=0, hidden_states=outputs)[0]

tensor([[-2.9971e-07, -3.4028e+38, -3.4028e+38],
        [ 1.7915e-04, -1.0058e+00, -3.4028e+38],
        [ 3.4570e-04, -1.5929e+00,  1.5746e-01]], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [94]:
# Same inspection with the last token replaced by a copy of the first one.
# Uses the notebook-wide `device` and builds its own all-zero position ids, so
# the cell does not depend on a `position_ids` left over from another cell.
test_inputs = torch.tensor([[   78, 27557, 78]], device=device)
test_outputs = llama(test_inputs,output_hidden_states = True, position_ids = torch.zeros_like(test_inputs).long()).hidden_states[0]
attention_score_wo_rotary(layer_idx=0, hidden_states=test_outputs)[0]

tensor([[-2.9971e-07, -3.4028e+38, -3.4028e+38],
        [ 1.7915e-04, -1.0058e+00, -3.4028e+38],
        [-2.9971e-07,  6.6722e-04, -2.9971e-07]], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [93]:
# Load results from out.npy and display the first element.
# NOTE(review): no visible cell writes out.npy; confirm where it is produced.
# allow_pickle=True unpickles arbitrary Python objects; only load trusted files.
np.load('out.npy',allow_pickle=True)[0]

{0: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 1: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 2: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 3: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 4: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 5: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 6: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 7: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 8: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 9: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 10: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 11: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 12: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 13: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 14: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 15: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 16: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 17: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 18: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 19: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 20: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 21: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 22: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 23: [0, 1, 1, 1, 1, 0, 0, 0, 0, 1],
 24: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 25: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 26: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 27: [1, 1,