In [3]:
# Capture the query/key (Q/K) tensors of one attention layer for the 4k and 64k prompts, both before and after RoPE is applied
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from types import MethodType
import json

model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Tokenizer and model are loaded from the same checkpoint name; device
# placement and dtype are both left to the "auto" settings.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="auto",
)

# The JSON file is indexed below by the length labels "64k" and "4k".
with open("/home/azzhang/streaming-llm/output/wikitext2_prompts_llama3.json", "r", encoding="utf-8") as f:
    prompts = json.load(f)

target_length_64k, target_length_4k = "64k", "4k"
prompt_64k, prompt_4k = prompts[target_length_64k], prompts[target_length_4k]

# Tokenize each prompt and move the resulting tensors to the model's device.
inputs_64k = tokenizer(prompt_64k, return_tensors="pt").to(model.device)
inputs_4k = tokenizer(prompt_4k, return_tensors="pt").to(model.device)

# Token counts of the two prompts (second dimension of input_ids).
# patched_forward compares the incoming sequence length against seq_len_4.
seq_len_64 = inputs_64k["input_ids"].size(1)
seq_len_4 = inputs_4k["input_ids"].size(1)

# Dicts filled by patched_forward, and the index of the decoder layer to patch.
cache_4k, cache_64k = {}, {}
target_layer = 0

def patched_forward(self, hidden_states, position_embeddings=None, *args, **kwargs):
    """Record Q/K before and after RoPE, then defer to the original forward.

    Meant to be bound in place of an attention module's ``forward``. It
    recomputes the query and key projections from ``hidden_states``, stores
    CPU copies in one of the module-level dicts ``cache_4k`` / ``cache_64k``
    under the keys ``"q_raw"``, ``"k_raw"``, ``"q_rope"`` and ``"k_rope"``,
    and finally calls ``self._orig_forward`` so the module's real output is
    produced by the unmodified implementation.

    Args:
        hidden_states: Input of the attention module. Its projections are
            viewed as ``(batch, seq_len, heads, head_dim)``.
        position_embeddings: ``(cos, sin)`` pair handed to
            ``apply_rotary_pos_emb``. Although it defaults to ``None``, it
            must be provided: it is unpacked unconditionally.
        *args: Extra positional arguments, forwarded unchanged.
        **kwargs: Extra keyword arguments, forwarded unchanged.

    Returns:
        Whatever ``self._orig_forward`` returns.
    """
    q = self.q_proj(hidden_states)
    k = self.k_proj(hidden_states)

    bsz, seqlen, _ = q.shape
    head_dim = self.head_dim

    # Route the capture by sequence length: an input exactly as long as the
    # 4k prompt goes to cache_4k, anything else goes to cache_64k.
    if seqlen == seq_len_4:
        print(seq_len_4)
        cache = cache_4k
    else:
        cache = cache_64k

    # Split the projection width into heads, giving
    # (batch, heads, seq_len, head_dim). Q and K have their own head counts.
    q = q.view(bsz, seqlen, self.config.num_attention_heads, head_dim).transpose(1, 2)
    k = k.view(bsz, seqlen, self.config.num_key_value_heads, head_dim).transpose(1, 2)

    # Q and K before RoPE.
    cache["q_raw"] = q.detach().cpu()
    cache["k_raw"] = k.detach().cpu()

    cos, sin = position_embeddings
    q_rope, k_rope = apply_rotary_pos_emb(q, k, cos, sin)

    # Q and K after RoPE.
    cache["q_rope"] = q_rope.detach().cpu()
    cache["k_rope"] = k_rope.detach().cpu()

    # The tensors above are only observed; the actual attention computation
    # is left entirely to the original forward.
    return self._orig_forward(hidden_states, position_embeddings, *args, **kwargs)

# Install the capture patch on the attention module of the chosen layer.
attn_layer = model.model.layers[target_layer].self_attn
# Save the original forward only once. If this code runs again on a module
# that is already patched, re-saving would make _orig_forward point at the
# patch itself and the patched call would recurse forever.
if not hasattr(attn_layer, "_orig_forward"):
    attn_layer._orig_forward = attn_layer.forward
attn_layer.forward = MethodType(patched_forward, attn_layer)

# Run both prompts once with autograd disabled; the patched forward fills
# cache_4k / cache_64k as a side effect.
with torch.no_grad():
    # The 4k result is not bound, so it is not kept alive during the
    # much larger 64k pass.
    model(**inputs_4k)
    outputs = model(**inputs_64k)

# Drop the leading batch dimension of every captured tensor.
# NOTE: squeeze(0) only removes it when the batch size is 1.
Q_4k = cache_4k["q_raw"].squeeze(0)  # (num_q_heads, seq_len, head_dim)
K_4k = cache_4k["k_raw"].squeeze(0)  # (num_kv_heads, seq_len, head_dim)
Q_4k_rope = cache_4k["q_rope"].squeeze(0)  # (num_q_heads, seq_len, head_dim)
K_4k_rope = cache_4k["k_rope"].squeeze(0)  # (num_kv_heads, seq_len, head_dim)

Q_64k = cache_64k["q_raw"].squeeze(0)  # (num_q_heads, seq_len, head_dim)
K_64k = cache_64k["k_raw"].squeeze(0)  # (num_kv_heads, seq_len, head_dim)
Q_64k_rope = cache_64k["q_rope"].squeeze(0)  # (num_q_heads, seq_len, head_dim)
K_64k_rope = cache_64k["k_rope"].squeeze(0)  # (num_kv_heads, seq_len, head_dim)

print(Q_64k_rope.shape)

# Repeat every K head so that K has one head per Q head: K head i is paired
# with Q heads i * n_rep ... (i + 1) * n_rep - 1. The sizes are read from the
# tensors themselves instead of being hard-coded.
num_kv_heads, kv_len, kv_head_dim = K_4k_rope.shape
n_rep = Q_4k_rope.shape[0] // num_kv_heads
K_4k_rope = K_4k_rope[:, None, :, :].expand(num_kv_heads, n_rep, kv_len, kv_head_dim)
K_4k_rope = K_4k_rope.reshape(num_kv_heads * n_rep, kv_len, kv_head_dim)

# Raw per-head Q @ K^T scores, shape (num_q_heads, seq_len, seq_len).
# No scaling, masking or softmax is applied here.
QK_production = torch.matmul(Q_4k_rope, K_4k_rope.transpose(1, 2))
print(QK_production.shape)
# target_head = 0

# K_head_4k_before, K_head_64k_before = K_4k[target_head].float(), K_64k[target_head].float()
# K_head_4k_rope, K_head_64k_rope = K_4k_rope[target_head].float(), K_64k_rope[target_head].float()
# Q_4k_before_list, Q_64k_before_list, Q_4k_rope_list, Q_64k_rope_list = [], [], [], []

# for i in range(4):
#     Q_4k_before_list.append(Q_4k[4*target_head+i].float())
#     Q_64k_before_list.append(Q_64k[4*target_head+i].float())
#     Q_4k_rope_list.append(Q_4k_rope[4*target_head+i].float())
#     Q_64k_rope_list.append(Q_64k_rope[4*target_head+i].float())

# print(K_head_4k_rope.shape)
# print(len(Q_4k_before_list))
# print(Q_4k_before_list[0].shape)

Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.36s/it]


3964
torch.Size([32, 65406, 128])
torch.Size([32, 3964, 3964])
