In [2]:
import torch

# Wq, Wk, Wv是nn.Linear的权重，默认是行优先的，需要转置
cos, sin, Wq, Wk, Wv = torch.load('key2value_demo.pt')
Wq = Wq.t()
Wk = Wk.t()
Wv = Wv.t()

# Restore value from key

$$k_{cache} = x W_k R_m, v_{cache} = x W_v$$  
$$v_{cache} = k_{cache} R_m^{-1} W_k^{-1} W_v$$

## Reverse RoPE
$$
\text{Rot}(m \theta_0) = 
\begin{bmatrix}
    \text{cos}(m \theta_0) & -\text{sin}(m \theta_0) \\
    \text{sin}(m \theta_0) & \text{cos}(m \theta_0) 
\end{bmatrix}
$$

$$
\text{Rot}^{-1}(m \theta_0) = \text{Rot}(-m \theta_0) =
\begin{bmatrix}
    \text{cos}(m \theta_0) & \text{sin}(m \theta_0) \\
    -\text{sin}(m \theta_0) & \text{cos}(m \theta_0) 
\end{bmatrix}
$$

In [3]:
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

In [4]:
bsz = 1
q_len = 4096
num_heads = 32
head_dim = 128

torch.manual_seed(0)
x = torch.randn(bsz, q_len, num_heads*head_dim)

query_states = (x @ Wq).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states   = (x @ Wk).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
value_states = (x @ Wv).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

In [5]:
query_rot, key_rot = apply_rotary_pos_emb(query_states, key_states, cos, sin)

In [6]:
query_rerot, key_rerot = apply_rotary_pos_emb(query_rot, key_rot, cos, -sin)

In [7]:
torch.allclose(key_states, key_rerot, atol=1e-5)

True

## Reverse nn.Linear (weights)

In [8]:
dtype = torch.double
w = Wk.to(dtype)
torch.allclose(w @ torch.inverse(w), torch.eye(q_len, dtype=dtype), atol=1e-9)

True

In [9]:
dtype = torch.float
w = Wk.to(dtype)
torch.allclose(w @ torch.inverse(w), torch.eye(q_len, dtype=dtype), atol=1e-5)

False

In [10]:
import numpy as np
dtype = np.float128
size = 128
M = np.random.randn(size, size).astype(dtype)

res = M @ np.linalg.inv(M)
print(np.abs(res - np.eye(size)).max())

TypeError: array type float128 is unsupported in linalg

好像为不同平台适配128位很困难：https://github.com/pytorch/pytorch/issues/48338

## Restore v from k

In [11]:
dtype = torch.float

# to dtype
key_rot = key_rot.to(dtype)
query_rot = query_rot.to(dtype)
cos = cos.to(dtype)
sin = sin.to(dtype)

W_k_double = Wk.to(torch.double)
W_k_double_inv = torch.inverse(W_k_double).to(dtype)
W_v = Wv.to(dtype)

_, key_rerot = apply_rotary_pos_emb(query_rot, key_rot, cos, -sin)
key_rerot_full_head = key_rerot.transpose(1, 2).contiguous().view(bsz, q_len, num_heads*head_dim)


transition_matrix = (W_k_double_inv @ W_v).to(dtype)

value_restored = (key_rerot_full_head @ transition_matrix).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)



In [12]:
torch.allclose(value_restored, value_states.to(dtype), atol=1e-1)

True

In [13]:
torch.allclose(key_rerot, key_states.to(dtype), atol=1e-5)

True

In [14]:
torch.allclose(value_restored, value_states.to(dtype), atol=1e-1)

True

In [15]:
torch.allclose(value_restored, value_states.to(dtype), atol=1e-2)


False

# Reconstruct vectors from IndexPQ

## 使用IndexFlatIP确保正确

In [16]:
import faiss
index = faiss.IndexFlatIP(128)
index.add(key_states[0, 0].cpu().detach().numpy())
q = query_states[0, 0, :2, :].cpu().detach().numpy() 
D, I = index.search(q, 4096)
# index.make_direct_map()
k = index.reconstruct(int(I[0, 0]))

In [17]:
torch.allclose(torch.Tensor(k), key_states[0, 0, I[0, 0]].cpu(), atol=1e-5)

True

In [18]:
import faiss
from faiss import read_index
def restore_index(layer_idx: int, head_idx: int, toGPU=False):
    """
    Restore a FAISS IndexPQ from disk.
    
    Args:
    layer_idx (int): The index of the layer in the model.
    head_idx (int): The index of the attention head within the layer.

    Returns:
    faiss.IndexPQ: The restored FAISS index.
    """
    index_filename = f"../llama_pqindex/PTB/key_{layer_idx}_{head_idx}.ivf" # TODO: use ivfpq with nlist=1 to move to GPU
    # index_filename = f"../pq_index/pq_{layer_idx}_{head_idx}.index"
    idx = read_index(index_filename)
    if toGPU is True:
        # move the index to gpu
        res = faiss.StandardGpuResources()
        res.noTempMemory() #TODO: install nightly build https://github.com/facebookresearch/faiss/issues/3259
        idx = faiss.index_cpu_to_gpu(res, 0, idx)
        # print(f"pq_{layer_idx}_{head_idx}.index moved to GPU.")

        # copy to ivf index
        
        
    return idx

In [19]:
index = restore_index(0, 0)

In [20]:
index.reconstruct_batch

<bound method handle_Index.<locals>.replacement_reconstruct_batch of <faiss.swigfaiss_avx2.IndexIVFPQ; proxy of <Swig Object of type 'faiss::IndexIVFPQ *' at 0x7f91ac14fab0> >>

In [21]:
key_states.shape

torch.Size([1, 32, 4096, 128])

In [22]:
index.add(key_states[0, 0].cpu().detach().numpy())

In [23]:
q = query_states[0, 0, :2, :].cpu().detach().numpy() 

In [24]:
q.shape

(2, 128)

In [25]:
D, I = index.search(q, 4096)

In [26]:
D

array([[ 152.4641  ,  147.69005 ,  145.22713 , ..., -170.49046 ,
        -170.6556  , -185.49026 ],
       [ 124.03221 ,  119.00385 ,  118.868866, ...,  -99.34771 ,
         -99.79555 ,  -99.87963 ]], dtype=float32)

In [27]:
I

array([[ 848, 2971, 3658, ..., 3024,  661, 2814],
       [1287,  145, 2400, ..., 3311, 1976, 3658]])

In [37]:

I = torch.tensor(I)

  I = torch.tensor(I)


In [42]:
I = I[:1]

In [43]:
query_len = 1
seq_len = 4096
attn_bias = torch.zeros(query_len, seq_len, device=query_states.device, dtype=query_states.dtype)


In [47]:
attn_bias[I]

IndexError: index 1076 is out of bounds for dimension 0 with size 1

In [29]:
torch.softmax(torch.tensor(D), dim=1)

tensor([[9.8966e-01, 8.3588e-03, 7.1205e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [9.8152e-01, 6.4285e-03, 5.6167e-03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])

In [None]:
index.make_direct_map()

In [None]:
k = index.reconstruct(int(I[0, 0]))

In [None]:
key_states[0, 0, int(I[0, 0])]

tensor([ 1.1897, -2.4531,  2.2896,  1.5974,  1.0974, -1.7216,  2.5351, -0.2188,
        -2.1048,  2.4843, -3.3355,  1.8666, -1.3950, -0.3860,  2.2833,  0.6918,
         1.1429,  0.8756, -0.3545,  0.6246, -0.3108, -0.1326, -0.4309,  0.7417,
         1.7682,  1.5424,  0.9124, -1.8929,  2.0146,  2.9387,  1.1566, -1.3198,
         1.3583, -1.7361,  1.3166, -0.1700,  1.6137,  1.4473,  1.6422, -0.1327,
         1.4049, -1.7469, -0.4068, -2.0375, -2.3081,  1.4111,  0.7240, -0.4021,
         0.5501, -0.8176,  0.6696,  0.3275,  0.9077, -2.2053,  2.0409,  0.6902,
         0.6504, -1.5217, -2.3723,  0.7258, -2.2717, -1.5013, -1.7850, -0.5082,
        -2.3895,  1.4159,  2.4559, -0.0972,  2.2484, -2.9504, -0.7898,  0.2306,
        -2.1079,  2.4772, -0.4761,  1.9000, -1.1578, -2.1209,  1.7029,  0.0284,
        -0.2022, -1.7919,  1.5018, -0.9045, -1.5731,  1.2282,  2.0642, -0.1675,
        -0.5248, -0.8281,  1.7748,  1.3238,  0.0307,  2.6912, -0.9722, -1.6677,
         2.0722,  2.1624, -1.9113,  0.97

In [None]:
k

array([ 1.81850350e+00, -1.57141471e+00,  1.34008694e+00,  1.25246370e+00,
        1.24328887e+00, -9.89186093e-02, -1.01044849e-01,  6.85416758e-02,
        3.40758497e-03,  9.34282601e-01, -2.71085715e+00,  5.44420294e-02,
       -5.53655505e-01, -6.08758986e-01,  8.17654654e-02, -2.80177921e-01,
        4.26508397e-01,  1.44758210e-01, -2.11644962e-01,  5.20126879e-01,
       -6.05017960e-01, -4.58945110e-02, -3.70356232e-01,  5.84389210e-01,
        1.61656463e+00,  5.91018379e-01,  1.13304138e+00,  1.18021257e-01,
        1.13816106e+00,  1.33691120e+00,  1.61989617e+00, -1.60063279e+00,
        1.70022321e+00, -1.38676214e+00, -3.43358099e-01,  7.92919993e-02,
       -1.57923505e-01,  9.59423482e-01,  4.93834257e-01, -3.71007442e-01,
        1.37167108e+00, -6.86888814e-01, -5.55884600e-01, -1.34201860e+00,
       -1.16221464e+00,  5.61712543e-03,  5.36624551e-01,  1.16127884e+00,
        7.61983693e-01, -5.10609627e-01, -6.17755726e-02,  1.04267396e-01,
        9.02703926e-02, -