In [None]:
# Capture the Q/K projections for the 4k and 64k prompts, both before and after RoPE is applied
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from types import MethodType
import json

# Checkpoint whose attention projections are inspected in this notebook.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto")

# NOTE(review): absolute, user-specific path — consider a configurable data directory so the
# notebook can be re-run on another machine.
with open("/home/azzhang/streaming-llm/output/wikitext2_prompts_llama3.json", "r", encoding="utf-8") as f:
    prompts = json.load(f)
# Keys used to look up the two prompts in the loaded JSON.
target_length_64k = "64k"
target_length_4k = "4k"

prompt_64k = prompts[target_length_64k]
prompt_4k = prompts[target_length_4k]
inputs_64k = tokenizer(prompt_64k, return_tensors="pt").to(model.device)
inputs_4k = tokenizer(prompt_4k, return_tensors="pt").to(model.device)
# Token counts of the two prompts; patched_forward compares the incoming sequence length
# against seq_len_4 to decide which cache to fill.
seq_len_64 = inputs_64k["input_ids"].shape[1]
seq_len_4 = inputs_4k["input_ids"].shape[1]

# Filled by patched_forward with the keys "q_raw", "k_raw", "q_rope", "k_rope".
cache_4k = {}
cache_64k = {}
# Index of the decoder layer whose self-attention gets patched.
target_layer = 0

def patched_forward(self, hidden_states, position_embeddings=None, *args, **kwargs):
    """
    Drop-in replacement for an attention module's forward that records Q/K.

    Recomputes the query and key projections of `hidden_states`, stores them
    (moved to CPU) before and after rotary embedding, and then delegates to the
    original forward saved on the module as `_orig_forward`, so the model output
    is produced by the unmodified implementation.

    The destination is chosen by sequence length: inputs whose length equals the
    global `seq_len_4` go to `cache_4k`; every other length goes to `cache_64k`.
    Each call overwrites the keys "q_raw", "k_raw", "q_rope" and "k_rope".

    Args:
        hidden_states: input to the attention module; dim 1 is read as the sequence length.
        position_embeddings: a `(cos, sin)` pair handed to `apply_rotary_pos_emb`.
            Must not be None — it is unpacked unconditionally.
        *args, **kwargs: forwarded untouched to the original forward.

    Returns:
        Whatever the original forward returns.
    """
    q = self.q_proj(hidden_states)
    k = self.k_proj(hidden_states)
    # The value projection is not needed for this analysis; the original forward
    # computes it itself, so it is no longer recomputed here.

    bsz, seqlen, _ = q.shape
    head_dim = self.head_dim

    # Both lengths are captured identically; only the destination dict differs.
    if seqlen == seq_len_4:
        print(seq_len_4)
        cache = cache_4k
    else:
        cache = cache_64k

    # (bsz, seqlen, heads * head_dim) -> (bsz, heads, seqlen, head_dim)
    q = q.view(bsz, seqlen, self.config.num_attention_heads, head_dim).transpose(1, 2)
    k = k.view(bsz, seqlen, self.config.num_key_value_heads, head_dim).transpose(1, 2)

    # Q, K before RoPE
    cache["q_raw"] = q.detach().cpu()
    cache["k_raw"] = k.detach().cpu()

    cos, sin = position_embeddings
    q_rope, k_rope = apply_rotary_pos_emb(q, k, cos, sin)

    # Q, K after RoPE
    cache["q_rope"] = q_rope.detach().cpu()
    cache["k_rope"] = k_rope.detach().cpu()

    return self._orig_forward(hidden_states, position_embeddings, *args, **kwargs)

# insert patch
# Save the module's bound forward under `_orig_forward` (patched_forward delegates to it),
# then bind patched_forward to this single attention module.
# NOTE(review): the patch is never removed in this cell, so any later call to `model` keeps
# overwriting the caches (every length other than seq_len_4 lands in cache_64k).
attn_layer = model.model.layers[target_layer].self_attn
attn_layer._orig_forward = attn_layer.forward
attn_layer.forward = MethodType(patched_forward, attn_layer)

# One forward pass per prompt; the passes exist to trigger the capture in patched_forward.
# `outputs` ends up holding the result of the 64k pass only.
with torch.no_grad():
    outputs = model(**inputs_4k)
with torch.no_grad():
    outputs = model(**inputs_64k)

# squeeze(0) drops the leading batch dimension (only when the batch size is 1), leaving
# (heads, seq_len, head_dim) as laid out by the view/transpose in patched_forward.
# The shapes printed below show head_dim = 128 for this model.
Q_4k = cache_4k["q_raw"].squeeze(0)  # (num_q_heads, seq_len, head_dim), before RoPE
K_4k = cache_4k["k_raw"].squeeze(0)  # (num_kv_heads, seq_len, head_dim), before RoPE
Q_4k_rope = cache_4k["q_rope"].squeeze(0)  # (num_q_heads, seq_len, head_dim), after RoPE
K_4k_rope = cache_4k["k_rope"].squeeze(0)  # (num_kv_heads, seq_len, head_dim), after RoPE

Q_64k = cache_64k["q_raw"].squeeze(0)  # (num_q_heads, seq_len, head_dim), before RoPE
K_64k = cache_64k["k_raw"].squeeze(0)  # (num_kv_heads, seq_len, head_dim), before RoPE
Q_64k_rope = cache_64k["q_rope"].squeeze(0)  # (num_q_heads, seq_len, head_dim), after RoPE
K_64k_rope = cache_64k["k_rope"].squeeze(0)  # (num_kv_heads, seq_len, head_dim), after RoPE

# Index of the key/value head to analyse.
target_head = 0

# With grouped-query attention several query heads share one KV head. Derive the group
# size from the model config instead of hard-coding 4 so the cell follows the checkpoint.
# Query heads [g * target_head, g * (target_head + 1)) are taken as the ones paired with
# KV head `target_head` (HF Llama maps query head h to KV head h // g — confirm for other
# architectures).
q_heads_per_kv = model.config.num_attention_heads // model.config.num_key_value_heads
q_head_ids = range(q_heads_per_kv * target_head, q_heads_per_kv * (target_head + 1))

# Upcast to float32 for the linear-algebra steps that follow.
K_head_4k_before, K_head_64k_before = K_4k[target_head].float(), K_64k[target_head].float()
K_head_4k_rope, K_head_64k_rope = K_4k_rope[target_head].float(), K_64k_rope[target_head].float()

# One [seq_len, head_dim] tensor per query head in the group.
Q_4k_before_list = [Q_4k[h].float() for h in q_head_ids]
Q_64k_before_list = [Q_64k[h].float() for h in q_head_ids]
Q_4k_rope_list = [Q_4k_rope[h].float() for h in q_head_ids]
Q_64k_rope_list = [Q_64k_rope[h].float() for h in q_head_ids]

print(K_head_4k_rope.shape)
print(len(Q_4k_before_list))
print(Q_4k_before_list[0].shape)

Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.21s/it]


3964
torch.Size([3964, 128])
4
torch.Size([3964, 128])
tensor([[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
        [ 0.5391,  0.6875,  0.7891,  ...,  1.0000,  1.0000,  1.0000],
        [-0.4160, -0.0583,  0.2412,  ...,  1.0000,  1.0000,  1.0000],
        ...,
        [-0.8516, -0.9609, -0.5508,  ...,  1.0000,  1.0000,  1.0000],
        [-0.8984, -0.4551, -0.9492,  ...,  1.0000,  1.0000,  1.0000],
        [-0.1187,  0.3359, -0.9414,  ...,  1.0000,  1.0000,  1.0000]],
       device='cuda:0', dtype=torch.bfloat16)


In [None]:
def build_rope_matrix(seq_len: int, head_dim: int, dtype=torch.float32, base: float = 100000.0):
    """
    Build the RoPE cos/sin feature matrix for positions 0 .. seq_len - 1.

    Row p holds cos(p * f_i) for i in [0, head_dim // 2) followed by
    sin(p * f_i) for the same i, where f_i = base ** (-2 * i / head_dim).

    Args:
        seq_len: number of positions (rows).
        head_dim: attention head dimension; must be even.
        dtype: floating dtype used for all intermediate and returned tensors.
        base: RoPE frequency base. The default (100000) is the value this
            notebook has always used and is kept for backward compatibility.
            NOTE(review): it is neither the classic RoPE base (10000) nor,
            as far as I can tell, the `rope_theta` of Meta-Llama-3-8B (500000);
            pass `base=model.config.rope_theta` to match the model — confirm.

    Returns:
        Tensor of shape [seq_len, head_dim]: the first head_dim // 2 columns are
        the cosines, the last head_dim // 2 columns the sines.
    """
    assert head_dim % 2 == 0, "head_dim must be even"
    dim = head_dim // 2
    position = torch.arange(seq_len).unsqueeze(1).to(dtype)  # [seq_len, 1]
    dim_indices = torch.arange(dim).unsqueeze(0).to(dtype)   # [1, dim]
    freq = base ** (-2 * dim_indices / head_dim)             # [1, dim]
    angles = position * freq                                 # [seq_len, dim]
    cos_part = torch.cos(angles)
    sin_part = torch.sin(angles)
    rope_matrix = torch.cat([cos_part, sin_part], dim=1)     # [seq_len, 2*dim]
    return rope_matrix


def get_rope_null_space(cloud: torch.Tensor, tol: float = 1e-4):
    """
    Extract a basis of the (numerical) null space of a point cloud.

    Eigendecomposes the uncentered second-moment matrix `cloud.T @ cloud / seq_len`
    and treats every direction whose square-rooted eigenvalue is below `tol` as
    a null direction.

    Args:
        cloud: 2-D tensor of shape [seq_len, full_dim], one row per point.
        tol: threshold applied to the square roots of the eigenvalues.

    Returns:
        dict with
          "eigvals":    square roots of the eigenvalues, ascending, shape [full_dim]
          "null_rank":  number of directions below `tol`
          "total_rank": full_dim
          "ratio":      null_rank / full_dim
          "null_basis": tensor [null_rank, full_dim], one null direction per row
    """
    seq_len, full_dim = cloud.shape
    cov = cloud.T @ cloud / seq_len  # [full_dim, full_dim]
    eigvals, eigvecs = torch.linalg.eigh(cov)  # eigvecs: one eigenvector per column
    # The matrix is positive semi-definite in exact arithmetic, but round-off can
    # return tiny negative eigenvalues for rank-deficient input. sqrt() of those is
    # NaN and `NaN < tol` is False, which would silently drop exactly the null
    # directions being looked for — clamp at zero first.
    eigvals_sqrt = eigvals.clamp(min=0).sqrt()
    null_mask = eigvals_sqrt < tol
    null_rank = null_mask.sum().item()
    null_basis = eigvecs[:, null_mask].T  # shape: [null_rank, full_dim]

    return {
        "eigvals": eigvals_sqrt,
        "null_rank": null_rank,
        "total_rank": full_dim,
        "ratio": null_rank / full_dim,
        "null_basis": null_basis
    }


# RoPE feature matrices for the two prompt lengths and the null-space basis of each.
# NOTE(review): head_dim=128 is hard-coded; it matches the head_dim printed by the capture
# cell for this model — keep it in sync if the checkpoint changes.
rope_4k = build_rope_matrix(seq_len=seq_len_4, head_dim=128)
rope_64k = build_rope_matrix(seq_len=seq_len_64, head_dim=128)
null_basis_4k = get_rope_null_space(rope_4k)["null_basis"]  # shape: [r, d]
null_basis_64k = get_rope_null_space(rope_64k)["null_basis"]  # shape: [r, d]

torch.Size([3964, 256])


In [None]:
# Project to null space
def apply_null_projection(q_list, k_tensor, null_basis):
    """
    Project query tensors and a key tensor onto the span of `null_basis`.

    Args:
        q_list: iterable of tensors whose last dimension matches null_basis.shape[1].
        k_tensor: tensor whose last dimension matches null_basis.shape[1].
        null_basis: [r, d] tensor, one basis vector per row.

    Returns:
        (list of projected query tensors, projected key tensor), each with the
        same shape as its input.
    """
    # d x d projector built from the row basis.
    projector = null_basis.T @ null_basis
    projected_queries = []
    for query in q_list:
        projected_queries.append(query @ projector)
    projected_key = k_tensor @ projector
    return projected_queries, projected_key

# Project every (Q list, K) pair onto the RoPE null space built for its own sequence length.
Q_proj_4k_before_list, K_proj_4k_before = apply_null_projection(Q_4k_before_list, K_head_4k_before, null_basis_4k)
Q_proj_4k_rope_list, K_proj_4k_rope = apply_null_projection(Q_4k_rope_list, K_head_4k_rope, null_basis_4k)
Q_proj_64k_list, K_proj_64k = apply_null_projection(Q_64k_before_list, K_head_64k_before, null_basis_64k)
# Fix: this call used null_basis_4k, so the 64k post-RoPE features were projected onto the
# 4k null space while the 64k pre-RoPE features above use the 64k one. Both bases have the
# same width, so the mismatch raised no error.
Q_proj_64k_rope_list, K_proj_64k_rope = apply_null_projection(Q_64k_rope_list, K_head_64k_rope, null_basis_64k)

print(K_proj_4k_before.shape)

torch.Size([3964, 128])


In [19]:
def energy_ratio(x, x_proj):
    """
    Share of the total squared norm of `x` that is retained in `x_proj`.

    Both arguments are reduced over their last dimension and then summed over
    all remaining entries; the result is a 0-dim tensor.
    """
    projected_energy = x_proj.norm(dim=-1).pow(2).sum()
    total_energy = x.norm(dim=-1).pow(2).sum()
    return projected_energy / total_energy

# Fraction of the post-RoPE 4k key energy that survives projection onto the 4k null space.
print("K null space ratio:", energy_ratio(K_head_4k_rope, K_proj_4k_rope))

# NOTE(review): the disabled loop below refers to q_proj_null_list / q_rope_list, which are
# not defined in the cells shown here — rename them before re-enabling.
# for i, q_proj in enumerate(q_proj_null_list):
#     print(f"Q{i} null space ratio:", energy_ratio(q_rope_list[i], q_proj))

def null_energy_per_token(x, null_basis):
    """
    Squared norm of each token after projection onto the span of `null_basis`.

    x: [seq_len, d] feature matrix, one token per row
    null_basis: [r, d], one basis vector per row
    return: [seq_len] tensor with the projected energy of every token
    """
    projector = null_basis.T @ null_basis  # [d, d]
    return (x @ projector).norm(dim=-1).pow(2)

def count_null_tokens(x, null_basis, threshold=0.3):
    """
    Fraction of tokens that fall heavily into the null space.

    A token counts when the share of its squared norm lying in the span of
    `null_basis` is strictly greater than `threshold`. Returns a Python float.
    """
    projector = null_basis.T @ null_basis
    null_energy = (x @ projector).norm(dim=-1).pow(2)
    total_energy = x.norm(dim=-1).pow(2)
    # The small constant keeps all-zero tokens from dividing by zero.
    per_token_share = null_energy / (total_energy + 1e-8)

    # Share of tokens whose null-space ratio exceeds the threshold.
    exceeds = per_token_share > threshold
    return exceeds.float().mean().item()

def null_basis_activation(x, null_basis):
    """
    Per-feature variance, taken across rows, of `x` projected onto the span of `null_basis`.

    Returns a tensor with one entry per column of `x`.
    """
    projector = null_basis.T @ null_basis
    projected = x @ projector  # same shape as x
    return projected.var(dim=0)

def null_direction_strength(x, null_basis):
    """
    Variance, across rows of `x`, of the coefficient along each null-basis vector.

    Indicates how much each null direction is actually used. Returns a tensor
    with one entry per row of `null_basis`.
    """
    coefficients = x.matmul(null_basis.T)  # [seq_len, r]
    return coefficients.var(dim=0)


# Per-feature variance of the post-RoPE 4k keys after projection onto the 4k null space.
print(null_basis_activation(K_head_4k_rope, null_basis_4k))


K null space ratio: tensor(0.0172)
tensor([0.0000e+00, 5.9341e-16, 7.2622e-16, 2.7578e-16, 1.5009e-15, 7.2654e-17,
        3.0101e-15, 2.0377e-15, 6.1580e-17, 1.0541e-15, 7.1247e-15, 2.1657e-14,
        1.3627e-16, 1.8053e-15, 6.2129e-15, 1.8232e-14, 3.4423e-14, 1.5561e-14,
        1.2528e-13, 1.6727e-13, 6.2550e-13, 8.8453e-14, 1.2905e-12, 4.1495e-13,
        1.2280e-11, 1.6710e-10, 1.8557e-09, 4.4164e-08, 1.6287e-06, 2.9889e-05,
        6.5884e-05, 6.8869e-04, 4.6110e-04, 2.8782e-04, 9.4966e-04, 5.1015e-04,
        3.6210e-04, 1.5103e-03, 4.1202e-05, 1.1422e-03, 4.0343e-04, 4.5132e-05,
        6.7421e-04, 1.1828e-03, 1.5680e-03, 5.3019e-04, 3.5367e-04, 7.7921e-04,
        1.4466e-04, 6.1196e-05, 2.4771e-04, 8.6770e-04, 3.0380e-04, 1.0692e-03,
        1.1446e-04, 2.7890e-04, 4.0228e-04, 2.6823e-04, 4.0093e-04, 1.0691e-03,
        4.7966e-04, 3.2873e-04, 2.2035e-04, 3.7001e-04, 2.6244e-16, 9.1444e-17,
        4.5153e-16, 1.1485e-15, 9.9567e-16, 4.5744e-15, 7.0398e-15, 9.8230e-16,
     

In [None]:
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import numpy as np

def analyze_cluster_separation(x_proj, num_clusters=4, random_state=0):
    """
    Cluster the rows of `x_proj` with KMeans and return the silhouette score.

    Args:
        x_proj: 2-D array-like, one sample per row.
        num_clusters: number of KMeans clusters.
        random_state: seed for the KMeans initialisation. KMeans starts from
            random centroids, so without a seed the score changes between runs;
            the fixed default makes the result reproducible. Pass None to get
            the previous unseeded behaviour.

    Returns:
        Silhouette score of the resulting clustering.
    """
    kmeans = KMeans(n_clusters=num_clusters, random_state=random_state)
    labels = kmeans.fit_predict(x_proj)
    return silhouette_score(x_proj, labels)

from sklearn.metrics import pairwise_distances

def qk_separation(q_list, k_tensor, method="cosine"):
    """
    Mean pairwise distance between every query head and the key head.

    For each tensor in `q_list`, the full row-by-row distance matrix against
    `k_tensor` is computed with the given metric and averaged.

    return: list with one average distance per entry of `q_list`
    """
    keys = k_tensor.numpy()
    return [
        pairwise_distances(queries.numpy(), keys, metric=method).mean()
        for queries in q_list
    ]

def qq_cluster_separation(q_list, n_clusters=4):
    """
    Silhouette score of the query heads, using the head index as the cluster label.

    The tensors in `q_list` are stacked row-wise and each row is labelled with
    the position of its tensor in the list. No clustering is fitted here, and
    `n_clusters` is accepted but not used.
    """
    stacked = torch.cat(q_list, dim=0).numpy()  # [sum of rows, head_dim]
    rows_per_head = [len(q) for q in q_list]
    head_labels = np.repeat(np.arange(len(q_list)), rows_per_head)  # 0,0,...1,1,...
    return silhouette_score(stacked, head_labels)

# result_4k_rope = qk_separation(Q_4k_rope_list, K_head_4k_rope)
# print(result_4k_rope)
# result_64k_rope = qk_separation(Q_64k_rope_list, K_head_64k_rope)
# # result_4k_proj_before = qk_separation(Q_proj_4k_before_list, K_proj_4k_before)
# # result_64k_proj_rope = qk_separation(Q_proj_64k_rope_list, K_proj_64k_rope)
# print(result_64k_rope)

[1.0871209, 1.1279013, 1.1306926, 1.1371487]
[1.0340893, 1.0736986, 1.0501916, 1.0837443]
