In [1]:
from modeling_llamadb import LlamaForCausalLMDB, apply_rotary_pos_emb

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = LlamaForCausalLMDB.from_pretrained("../llama-2-7b-hf")

Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.47it/s]


In [3]:
attn = model.model.layers[0].self_attn

In [4]:
attn.rotary_emb

LlamaRotaryEmbedding()

# Restore value from key

$$k_{cache} = x W_k R_m, v_{cache} = x W_v$$  
$$v_{cache} = k_{cache} R_m^{-1} W_k^{-1} W_v$$

## Reverse RoPE
$$
\text{Rot}(m \theta_0) = 
\begin{bmatrix}
    \text{cos}(m \theta_0) & -\text{sin}(m \theta_0) \\
    \text{sin}(m \theta_0) & \text{cos}(m \theta_0) 
\end{bmatrix}
$$

$$
\text{Rot}^{-1}(m \theta_0) = \text{Rot}(-m \theta_0) =
\begin{bmatrix}
    \text{cos}(m \theta_0) & \text{sin}(m \theta_0) \\
    -\text{sin}(m \theta_0) & \text{cos}(m \theta_0) 
\end{bmatrix}
$$

In [5]:
import torch

class RecursiveNamespace:
    def __init__(self):
        self.__dict__['_attributes'] = {}

    def __getattr__(self, name):
        if name not in self._attributes:
            self._attributes[name] = RecursiveNamespace()
        return self._attributes[name]

    def __setattr__(self, name, value):
        if isinstance(value, RecursiveNamespace):
            self._attributes[name] = value
        else:
            current = self
            *parts, last = name.split('.')
            for part in parts:
                current = getattr(current, part)
            current._attributes[last] = value

    def __repr__(self):
        return repr(self._attributes)
    
A = RecursiveNamespace()
A.device.type = 'cpu'
A.dtype = torch.float32

In [6]:
max_position_embeddings = 4096
position_ids = torch.arange(max_position_embeddings).view(1, max_position_embeddings)
cos, sin = attn.rotary_emb(A, position_ids)

In [58]:
bsz = 1
q_len = 4096
num_heads = 32
head_dim = 128

# set random seed
torch.manual_seed(0)
x = torch.randn(bsz, q_len, num_heads*head_dim)

key_states   = attn.k_proj(x).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
query_states = attn.q_proj(x).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
value_states = attn.v_proj(x).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)


In [59]:
query_rot, key_rot = apply_rotary_pos_emb(query_states, key_states, cos, sin)

In [60]:
query_rerot, key_rerot = apply_rotary_pos_emb(query_rot, key_rot, cos, -sin)

In [61]:
torch.allclose(key_states, key_rerot, atol=1e-6)

True

In [83]:
cos_cpu = cos.cpu()
sin_cpu = sin.cpu()
Wq_cpu = attn.q_proj.weight.cpu()
Wk_cpu = attn.k_proj.weight.cpu()
Wv_cpu = attn.v_proj.weight.cpu()

torch.save(
    (
        cos_cpu,
        sin_cpu,
        Wq_cpu,
        Wk_cpu,
        Wv_cpu,
        # x
    ),
    "key2value_demo.pt"
)

## Reverse nn.Linear

In [63]:
print(attn.k_proj.bias)

None


In [64]:
print(attn.k_proj.weight)

Parameter containing:
tensor([[-0.0162,  0.0079, -0.0013,  ...,  0.0166, -0.0099, -0.0135],
        [ 0.0192,  0.0015,  0.0036,  ..., -0.0211,  0.0152,  0.0234],
        [-0.0236, -0.0217,  0.0017,  ...,  0.0150, -0.0165, -0.0118],
        ...,
        [ 0.0128, -0.0007, -0.0008,  ...,  0.0002,  0.0031,  0.0081],
        [-0.0056,  0.0173, -0.0032,  ..., -0.0032,  0.0115, -0.0110],
        [ 0.0037, -0.0021,  0.0013,  ...,  0.0070, -0.0115,  0.0095]],
       requires_grad=True)


In [65]:
attn.k_proj.weight @ torch.pinverse(attn.k_proj.weight)

tensor([[ 1.0000e+00, -1.9431e-05, -5.2676e-05,  ...,  5.3465e-05,
         -1.1873e-04,  4.1932e-05],
        [-1.2130e-05,  1.0001e+00,  7.7486e-06,  ..., -2.2125e-04,
          3.1185e-04,  2.8515e-04],
        [-4.2040e-05,  1.5593e-04,  9.9993e-01,  ..., -2.3746e-04,
          5.5456e-04,  6.5470e-04],
        ...,
        [ 6.8635e-05,  6.8426e-05,  4.5431e-04,  ...,  9.9842e-01,
         -3.8624e-05, -4.5717e-04],
        [-6.1132e-06,  1.0967e-04,  1.2416e-04,  ..., -4.0436e-04,
          1.0003e+00,  9.1076e-05],
        [-2.5466e-05, -1.7881e-06,  9.0897e-05,  ..., -6.6257e-04,
          7.4387e-05,  9.9996e-01]], grad_fn=<MmBackward0>)

In [66]:
attn.k_proj.weight @ torch.inverse(attn.k_proj.weight)

tensor([[ 1.0000e+00,  1.1921e-07, -5.7220e-06,  ...,  4.3631e-05,
          2.8610e-06,  1.0014e-05],
        [-8.3447e-07,  1.0000e+00,  6.1989e-06,  ...,  4.3869e-05,
         -1.0014e-05,  1.5259e-05],
        [ 1.7881e-07,  4.7684e-07,  1.0000e+00,  ...,  2.0981e-05,
          2.4080e-05,  3.0696e-06],
        ...,
        [ 4.7982e-06,  7.4267e-05, -4.1872e-05,  ...,  9.9957e-01,
         -2.8133e-04, -6.8893e-04],
        [-7.2937e-05, -1.8716e-05, -9.3579e-04,  ..., -1.7776e-03,
          9.9938e-01, -5.7840e-04],
        [-6.6906e-06, -5.9485e-05,  3.6937e-04,  ...,  2.2829e-03,
         -8.4162e-05,  1.0001e+00]], grad_fn=<MmBackward0>)

In [67]:
w = attn.k_proj.weight.to(torch.double)
tmp = w @ torch.inverse(w)
tmp.shape

torch.Size([4096, 4096])

In [68]:
torch.allclose(w @ torch.inverse(w), torch.eye(4096, dtype=torch.double), atol=1e-9)

True

In [69]:
dtype = torch.float
size = 1024

# M = torch.randn(128, 128)
M = torch.randn(size, size, dtype=dtype)

res = M @ torch.inverse(M)

print(torch.abs(res - torch.eye(size, dtype=dtype)).max())

tensor(9.7997e-05)


In [70]:
import numpy as np
dtype = np.float128

M = np.random.randn(size, size).astype(dtype)

res = M @ np.linalg.inv(M)
print(np.abs(res - np.eye(size)).max())

TypeError: array type float128 is unsupported in linalg

## Combination of `inverse(k_proj)` and `v_proj`

In [71]:
dtype = torch.float

# to dtype
key_rot = key_rot.to(dtype)
query_rot = query_rot.to(dtype)
cos = cos.to(dtype)
sin = sin.to(dtype)
W_k_double = attn.k_proj.weight.T.to(torch.double)
W_v_double = attn.v_proj.weight.T.to(torch.double)

W_k_double_inv = torch.inverse(W_k_double).to(dtype)
W_v = W_v_double.to(dtype)

_, key_rerot = apply_rotary_pos_emb(query_rot, key_rot, cos, -sin)
key_rerot_full_head = key_rerot.transpose(1, 2).contiguous().view(bsz, q_len, num_heads*head_dim)


transition_matrix = (W_k_double_inv @ W_v).to(dtype)

value_restored = (key_rerot_full_head @ transition_matrix).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)



In [72]:
torch.allclose(key_rerot_full_head, x.to(dtype), atol=1e-3)

False

In [73]:
key_rerot_full_head

tensor([[[-0.6250,  0.9579,  0.4295,  ..., -2.0997,  1.0471,  0.1994],
         [-2.7079,  3.2356, -2.7173,  ...,  0.1224, -0.8601, -0.6023],
         [-0.3822, -0.0718,  0.2622,  ...,  0.3680,  0.2994,  1.1320],
         ...,
         [-1.9004,  2.7978, -2.1939,  ..., -0.8557,  0.2071,  0.1726],
         [ 1.0422, -1.9806,  2.1556,  ..., -0.0361,  0.3327, -0.5052],
         [-0.6820,  0.9117, -1.3218,  ...,  1.0008, -1.4883,  1.3174]]],
       grad_fn=<ViewBackward0>)

In [74]:
torch.allclose(key_rerot, key_states.to(dtype), atol=1e-6)

True

In [75]:
torch.allclose(value_restored, value_states.to(dtype), atol=1e-1)

True

In [76]:
torch.abs(value_restored - value_states.to(dtype)).max()

tensor(0.0154, grad_fn=<MaxBackward1>)

In [77]:
torch.abs(value_restored - value_states.to(dtype)).mean()

tensor(0.0007, grad_fn=<MeanBackward0>)

In [None]:
torch.count_nonzero(torch.abs(value_restored - value_states) > 1e-3) / value_states.numel()

tensor(0.2260)

In [None]:
value_states[0, 0, 0, :10]

tensor([ 0.0058,  0.2702, -0.3478, -0.6390, -1.2680, -0.7903, -0.4243, -0.1293,
         0.1100, -1.5165], grad_fn=<SliceBackward0>)

In [None]:
value_restored[0, 0, 0, :10]

tensor([-1.2167,  0.5526,  0.0596, -0.3388, -0.1308,  0.3450,  0.0022, -0.4988,
        -0.6386,  0.6053], dtype=torch.float64, grad_fn=<SliceBackward0>)

In [None]:
key_rerot.transpose(1, 2).contiguous().view(bsz, q_len, num_heads*head_dim)

tensor([[[-2.9145,  2.6446, -1.7584,  ..., -1.0911,  0.8427, -0.7621],
         [-0.5875,  0.6003, -0.1183,  ..., -0.3745, -0.6578,  0.5973],
         [ 0.6806, -0.7473,  0.4930,  ...,  0.0648,  2.8428, -0.5266],
         ...,
         [-0.6588, -0.4816,  0.4644,  ..., -1.5955, -1.0929,  1.5297],
         [ 1.3331, -0.6762,  0.8802,  ...,  1.5387, -1.0052,  0.2580],
         [-0.2668,  0.3794, -0.6736,  ..., -0.6867, -2.0124, -0.3171]]],
       dtype=torch.float64, grad_fn=<ViewBackward0>)

In [None]:
key_rerot_full_head

tensor([[[-2.9145,  2.6446, -1.7584,  ..., -1.0911,  0.8427, -0.7621],
         [-0.5875,  0.6003, -0.1183,  ..., -0.3745, -0.6578,  0.5973],
         [ 0.6806, -0.7473,  0.4930,  ...,  0.0648,  2.8428, -0.5266],
         ...,
         [-0.6588, -0.4816,  0.4644,  ..., -1.5955, -1.0929,  1.5297],
         [ 1.3331, -0.6762,  0.8802,  ...,  1.5387, -1.0052,  0.2580],
         [-0.2668,  0.3794, -0.6736,  ..., -0.6867, -2.0124, -0.3171]]],
       dtype=torch.float64, grad_fn=<ViewBackward0>)

In [None]:
x

tensor([[[ 7.2282e-01,  1.0803e+00, -4.5230e-01,  ...,  1.6503e-01,
          -9.3736e-01, -5.0019e-01],
         [-1.6404e+00,  4.6956e-01, -4.5691e-01,  ...,  1.0660e+00,
           9.1425e-02, -1.9353e-03],
         [ 1.3318e+00, -6.9224e-01, -8.6427e-02,  ...,  1.0561e-01,
          -2.0532e+00, -9.0696e-01],
         ...,
         [-1.0923e+00, -8.0692e-01, -2.3702e-01,  ..., -4.9561e-01,
          -3.8218e-01, -2.7865e-01],
         [-1.0283e+00,  1.6916e+00, -1.3834e+00,  ...,  1.0330e+00,
          -3.3758e-01,  9.6891e-02],
         [ 3.3261e-01,  1.6135e+00, -2.3792e-01,  ..., -4.8955e-02,
          -7.5147e-01, -1.0646e-01]]])

In [None]:
import numpy as np

52 * np.log(2) / np.log(10)

15.65355977452702

In [None]:
23 * np.log(2) / np.log(10)

6.923689900271567

In [None]:
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def gauss_jordan_inverse(matrix):
    n = matrix.shape[0]
    # Create an augmented matrix with the identity matrix on the right
    augmented_matrix = np.hstack((matrix, np.eye(n)))
    
    # Apply Gauss-Jordan elimination
    for i in range(n):
        # Make the diagonal contain all ones
        diag_element = augmented_matrix[i, i]
        augmented_matrix[i] = augmented_matrix[i] / diag_element
        
        # Make the other elements in the current column zero
        for j in range(n):
            if i != j:
                row_factor = augmented_matrix[j, i]
                augmented_matrix[j] = augmented_matrix[j] - row_factor * augmented_matrix[i]
    
    # The right half of the augmented matrix is now the inverse
    inverse_matrix = augmented_matrix[:, n:]
    return inverse_matrix

def gauss_jordan_inverse_parallel(matrix):
    n = matrix.shape[0]
    augmented_matrix = np.hstack((matrix, np.eye(n)))

    def process_row(i):
        diag_element = augmented_matrix[i, i]
        augmented_matrix[i] = augmented_matrix[i] / diag_element

        def zero_out_column(j):
            if i != j:
                row_factor = augmented_matrix[j, i]
                augmented_matrix[j] = augmented_matrix[j] - row_factor * augmented_matrix[i]

        with ThreadPoolExecutor() as executor:
            executor.map(zero_out_column, range(n))

    for i in range(n):
        process_row(i)

    inverse_matrix = augmented_matrix[:, n:]
    return inverse_matrix

# Example usage
size = 1024
dtype = np.float128
A = np.random.randn(size, size).astype(dtype)
A_inv = gauss_jordan_inverse_parallel(A)

# print("Original matrix:")
# print(A)
# print("Inverse matrix:")
# print(A_inv)

res = np.matmul(A, A_inv)
print("max error")
print(np.abs(res - np.eye(size)).max())


max error
2.962434713770639716e-14
