In [1]:
import torch
from torch.nn import Linear, Module
from torch import softmax
import time

## Basic Self Attention


- input X, shape is (bs, seqlen, hidden_size)
- weight Wq, Wk, Wv, shape is (hidden_size， hidden_size) 
- Q,K,V shape is (bs, seqlen, hidden_size)
- attention_score shape is (bs, seqlen, seqlen), second dim is each token, third dim is the score of all tokens against this token 
- attention_output shape is (bs, seqlen, hidden_size)

note: 
`torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)`
Applies an affine linear transformation to the incoming data: 
```y=xA^T+b```

to calculate Q K^T, need:
- transpose K on last 2 dims (seqlen and hidden_size)
- use matmul (@)

note:
`torch.matmul(input, other, *, out=None)`
The behavior depends on the dimensionality of the tensors as follows:
- If both tensors are 1-dimensional, the dot product (scalar) is returned.
- If both arguments are 2-dimensional, the matrix-matrix product is returned.
- If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.
- If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.
- If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a **batched matrix multiply** is returned. 




In [2]:
class SelfAttention(Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.Wq = Linear(hidden_size, hidden_size, bias=False)
        self.Wk = Linear(hidden_size, hidden_size, bias=False)
        self.Wv = Linear(hidden_size, hidden_size, bias=False)
        self.Wo = Linear(hidden_size, hidden_size, bias=False)
        self.scale = hidden_size ** (1/2)
    
    def forward(self, x):
        Q = self.Wq(x)
        K = self.Wk(x)
        V = self.Wv(x)
        
        # some model apply position embedding here
        # cos, sin = self.rotary_emb(V, seq_len=seq_len)
        # Q, K = apply_rotary_pos_emb(Q, K, cos, sin, position_ids)
        
        attention_score = softmax(Q @ K.transpose(-1, -2) / self.scale, dim=-1)
        return self.Wo(attention_score @ V)

In [3]:
layer = SelfAttention(512)

In [4]:
layer(torch.rand(2, 32, 512))

tensor([[[-0.1180,  0.2132,  0.1698,  ..., -0.2668, -0.2701, -0.0347],
         [-0.1179,  0.2134,  0.1706,  ..., -0.2672, -0.2701, -0.0357],
         [-0.1184,  0.2139,  0.1715,  ..., -0.2665, -0.2702, -0.0363],
         ...,
         [-0.1179,  0.2135,  0.1718,  ..., -0.2665, -0.2703, -0.0350],
         [-0.1187,  0.2144,  0.1706,  ..., -0.2672, -0.2704, -0.0350],
         [-0.1178,  0.2132,  0.1705,  ..., -0.2668, -0.2698, -0.0352]],

        [[-0.1145,  0.1760,  0.1555,  ..., -0.2653, -0.2918, -0.0416],
         [-0.1145,  0.1761,  0.1544,  ..., -0.2650, -0.2915, -0.0424],
         [-0.1147,  0.1752,  0.1550,  ..., -0.2648, -0.2924, -0.0427],
         ...,
         [-0.1136,  0.1760,  0.1560,  ..., -0.2648, -0.2920, -0.0419],
         [-0.1148,  0.1762,  0.1554,  ..., -0.2646, -0.2922, -0.0423],
         [-0.1141,  0.1754,  0.1560,  ..., -0.2648, -0.2918, -0.0421]]],
       grad_fn=<UnsafeViewBackward0>)

## MHA


- the Wq Qk Wv are still the shape (hidden_size, hidden_size) on init, but will reshape to (hidden_size, num_heads, head_dim) when calling forward 
- Q K V shape will be (batch_size, seqlen, num_heads, head_dim)
- Q K^T needs output of (batch_size, num_heads, seqlen, seqlen), so transpose seqlen and num_heads
- scale use head_dim, not hidden_size
- concat all heads after attention_score @ V (batch_size, num_heads, seqlen, head_dim) -> (batch_size, seqlen, num_heads * head_dim)


In [5]:
class MultiHeadSelfAttention(SelfAttention):
    def __init__(self, hidden_size, num_heads):
        super().__init__(hidden_size)
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** (1/2)
        
    def forward(self, x):
        bs, seqlen, _ = x.shape
        Q = self.Wq(x).view(bs, seqlen, self.num_heads, self.head_dim)
        K = self.Wk(x).view(bs, seqlen, self.num_heads, self.head_dim)
        V = self.Wv(x).view(bs, seqlen, self.num_heads, self.head_dim)
        
        attention_score = softmax(Q.transpose(1,2) @ K.transpose(1,2).transpose(-1, -2) / self.scale, dim=-1)
        attention_output = (attention_score @ V.transpose(1,2)).transpose(1,2).contiguous().view(bs, seqlen, -1)
        return self.Wo(attention_output)

In [6]:
layer = MultiHeadSelfAttention(512, 8)

In [7]:
layer(torch.rand(2, 32, 512))

tensor([[[ 0.1221,  0.0139, -0.1489,  ..., -0.1508,  0.2944, -0.0723],
         [ 0.1221,  0.0140, -0.1486,  ..., -0.1510,  0.2942, -0.0720],
         [ 0.1227,  0.0129, -0.1486,  ..., -0.1503,  0.2948, -0.0726],
         ...,
         [ 0.1217,  0.0135, -0.1492,  ..., -0.1502,  0.2945, -0.0721],
         [ 0.1222,  0.0131, -0.1487,  ..., -0.1500,  0.2934, -0.0715],
         [ 0.1216,  0.0141, -0.1493,  ..., -0.1506,  0.2947, -0.0721]],

        [[ 0.1366,  0.0156, -0.1249,  ..., -0.1414,  0.2669, -0.0616],
         [ 0.1363,  0.0158, -0.1251,  ..., -0.1407,  0.2666, -0.0614],
         [ 0.1361,  0.0157, -0.1247,  ..., -0.1406,  0.2664, -0.0623],
         ...,
         [ 0.1355,  0.0171, -0.1241,  ..., -0.1416,  0.2660, -0.0616],
         [ 0.1362,  0.0158, -0.1249,  ..., -0.1411,  0.2662, -0.0621],
         [ 0.1356,  0.0159, -0.1242,  ..., -0.1404,  0.2664, -0.0619]]],
       grad_fn=<UnsafeViewBackward0>)

## GQA

- the Wq is the original size, Wk and Wv should be smaller
- Wq is reshaped as (hidden_size, num_heads, head_dim), Wk and Wv is reshaped as (hidden_size, num_kv_heads, head_dim)
- Q is (bs, seqlen, num_heads, head_dim), KV is  (bs, seqlen, num_kv_heads, head_dim), so need to interleave KV for group_size times, after repeating, KV is (bs, seqlen, num_heads, head_dim)
- attention_score is still (bs, num_heads, seqlen, seqlen), after matmul V^T, is (bs, num_heads, seqlen, head_dim)

note:
`torch.repeat_interleave(input, repeats, dim=None, *, output_size=None)`

- input (Tensor) – the input tensor.
- repeats (Tensor or int) – The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
- dim (int, optional) – The dimension along which to repeat values. By default, use the flattened input array, and return a flat output array.

```
>>> torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])
```


In [8]:
class GroupedQueryAttention(Module):
    def __init__(self, hidden_size, num_heads, num_kv_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** (1/2)
        self.num_kv_heads = num_kv_heads
        self.group_size = num_heads // num_kv_heads
        
        self.Wq = Linear(hidden_size, hidden_size, bias=False)
        self.Wk = Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
        self.Wv = Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
        self.Wo = Linear(hidden_size, hidden_size, bias=False)
    
    def forward(self, x):
        bs, seqlen, _ = x.shape
        Q = self.Wq(x).view(bs, seqlen, self.num_heads, self.head_dim)
        K = self.Wk(x).view(bs, seqlen, self.num_kv_heads, self.head_dim)
        V = self.Wv(x).view(bs, seqlen, self.num_kv_heads, self.head_dim)
        
        K = K.repeat_interleave(self.group_size, dim=2)
        V = V.repeat_interleave(self.group_size, dim=2)
        
        attention_score = (Q.transpose(1,2) @ K.transpose(1,2).transpose(-1,-2) / self.scale)
        attention_output = (attention_score @ V.transpose(1,2)).transpose(1,2).contiguous().view(bs, seqlen, -1)
        return self.Wo(attention_output)
        

In [9]:
layer = GroupedQueryAttention(512, 32, 4)

In [10]:
layer(torch.rand(2, 32, 512))

tensor([[[ 1.0130e-01, -2.2926e-03, -5.7200e-01,  ...,  2.5513e-01,
           7.5714e-01, -3.3590e-01],
         [ 7.3070e-01,  4.1067e-01,  3.1527e-01,  ...,  6.7676e-01,
           5.1814e-01, -5.2487e-02],
         [ 1.6133e-01,  2.3641e-01, -2.5082e-03,  ...,  4.0818e-01,
           5.3284e-01, -5.2782e-02],
         ...,
         [ 3.2272e-01,  8.3971e-01,  2.9630e-02,  ...,  6.1965e-01,
          -1.9962e-02,  1.3014e-01],
         [ 2.2647e-01,  6.8219e-01, -4.7240e-01,  ...,  5.8010e-01,
           3.3260e-01, -1.9878e-03],
         [-4.8943e-02,  2.1508e-01, -3.7732e-01,  ...,  2.0345e-01,
           3.0802e-01, -1.1944e-01]],

        [[ 5.9472e-01,  4.0098e-01, -1.4606e-01,  ..., -1.6119e-01,
           6.3656e-01, -4.6469e-04],
         [-3.2492e-03,  5.7523e-01, -8.7522e-02,  ...,  3.2010e-01,
           5.9767e-01,  1.5955e-01],
         [ 3.6261e-01,  5.3366e-01, -8.5597e-02,  ...,  5.9380e-01,
           7.7183e-01,  1.5350e-01],
         ...,
         [ 5.9433e-01,  3

## KV Cache

- the forward just need to handle 1 token at a time, means seqlen==1, so the input X shape is (bs, 1, hidden_size)
- past_kv = (past_key, past_value), both in shape of (bs, seqlen, hidden_size)
- for new kv, concat past and this token's on dim=1, so now the Q is (bs, 1, hidden_size), new KV is (bs, seqlen+1, hidden_size) 

In [11]:
class SelfAttentionWithKVCache(SelfAttention):
    
    def forward(self, x, cached_kv=None):
        bs, _, hidden_size = x.shape
        if cached_kv:
            cached_key, cached_value = cached_kv
        else:
            cached_key = torch.zeros(bs, 0, hidden_size)
            cached_value = torch.zeros(bs, 0, hidden_size)
        
        Q = self.Wq(x)
        K = self.Wk(x)
        V = self.Wv(x)
        
        cached_key = torch.cat([cached_key, K], dim=1)
        cached_value = torch.cat([cached_value, V], dim=1)
        
        attention_score = softmax(Q @ cached_key.transpose(-1, -2) / self.scale, dim=-1)
        output = self.Wo(attention_score @ cached_value)
        
        return output, (cached_key, cached_value)

In [12]:
layer = SelfAttentionWithKVCache(512)

### Time with KV Cache

In [13]:
start_time = time.time()
output, kv_cache = layer(torch.rand(2, 4096, 512))
prefilling_time_cost = time.time() - start_time
print(prefilling_time_cost)

0.1573340892791748


In [14]:
for _ in range(5):
    start_time = time.time()
    output, kv_cache = layer(torch.rand(2, 1, 512), kv_cache)
    decoding_time_cost = time.time() - start_time
    print(decoding_time_cost)
    print(output.shape, kv_cache[0].shape, kv_cache[1].shape)

0.011071920394897461
torch.Size([2, 1, 512]) torch.Size([2, 4097, 512]) torch.Size([2, 4097, 512])
0.005232810974121094
torch.Size([2, 1, 512]) torch.Size([2, 4098, 512]) torch.Size([2, 4098, 512])
0.0035789012908935547
torch.Size([2, 1, 512]) torch.Size([2, 4099, 512]) torch.Size([2, 4099, 512])
0.003389120101928711
torch.Size([2, 1, 512]) torch.Size([2, 4100, 512]) torch.Size([2, 4100, 512])
0.006175041198730469
torch.Size([2, 1, 512]) torch.Size([2, 4101, 512]) torch.Size([2, 4101, 512])


### Time without KV Cache


In [15]:
for _ in range(5):
    start_time = time.time()
    output, _ = layer(torch.rand(2, 4096 + _, 512), None)
    decoding_time_cost = time.time() - start_time
    print(decoding_time_cost)
    print(output.shape)

0.13533520698547363
torch.Size([2, 4096, 512])
0.1895592212677002
torch.Size([2, 4097, 512])
0.18368291854858398
torch.Size([2, 4098, 512])
0.14872288703918457
torch.Size([2, 4099, 512])
0.14966773986816406
torch.Size([2, 4100, 512])


## Heterogeneous Attention


- the code demo is based on basic self attention, instead of the original implementation on GQA
- the prefilled K and V both in shape (bs, seqlen, hidden_size)

note:

`torch.index_select(input, dim, index, *, out=None) → Tensor`

- Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a LongTensor.
- The returned tensor has the same number of dimensions as the original tensor (input). The dimth dimension has the same size as the length of index; other dimensions have the same size as in the original tensor.




In [16]:
class HeterogeneousKVCache:
    def __init__(self, layer_idx):
        self.layer_idx = layer_idx
        
        self.sink_token_count = 128
        self.recent_token_count = 128
        
        self.sink_kv = None
        self.context_kv = None
        self.recent_kv = None
        
        self.is_prefilled = False
        
    
    def offload(self, prefilled_kv):
        assert self.is_prefilled is False
        prefilled_keys, prefilled_values = prefilled_kv # assume the seqlen is long enough, e.g. 100K
        
        self.sink_kv = (prefilled_keys[:, :self.sink_token_count, :], prefilled_values[:, :self.sink_token_count, :])
        self.context_kv = (prefilled_keys[:, self.sink_token_count: -self.sink_token_count, :], prefilled_values[:, self.sink_token_count: -self.sink_token_count, :])
        self.recent_kv = (prefilled_keys[:, -self.sink_token_count:, :], prefilled_values[:, -self.sink_token_count:, :])
            
        
        if self.layer_idx >= 2 and self.layer_idx < 31: # assume this is a 32-layer model
            context_keys, context_values = self.context_kv
            # context_keys = context_keys.cpu()
            # context_keys = build_vector_store(context_keys)
        
        self.is_prefilled = True
            
    
    def update(self, new_recent_kv):
        assert self.is_prefilled is True
        new_recent_keys, new_recent_values = new_recent_kv
        self.recent_kv = (
            torch.cat([self.recent_kv[0], new_recent_keys], dim=1), 
            torch.cat([self.recent_kv[1], new_recent_values], dim=1)
        )
    
    
    def search(self, query, top_k):
        assert self.is_prefilled is True
        
        # selected_indices = search_vector_store(query, top_k)
        
        selected_keys = torch.index_select(self.context_kv[0], 1, selected_indices).cuda()
        selected_values = torch.index_select(self.context_kv[1], 1, selected_indices)
        
        return (
            torch.cat([self.sink_kv[0], selected_keys, self.recent_kv[0]], dim=1), 
            torch.cat([self.sink_kv[1], selected_values, self.recent_kv[1]], dim=1)
        )
    
    
class HeterogeneousAttention(SelfAttentionWithKVCache): # or: SelfAttentionWithHeterogeneousKVCache
    def __init__(self, hidden_size, layer_idx):
        super().__init__(hidden_size)
        self.heterogeneous_kv_cache = HeterogeneousKVCache(layer_idx)
        
    
    def forward(self, x):
        Q = self.Wq(x)
        if not self.heterogeneous_kv_cache.is_prefilled:
            K = self.Wk(x)
            V = self.Wk(x)
        else:
            (K, V) = self.heterogeneous_kv_cache.search(Q, top_k=16)
            
        
        attention_score = softmax(Q @ K.transpose(-1, -2) / self.scale, dim=-1)
        output = self.Wo(attention_score @ V)
        
        if not self.heterogeneous_kv_cache.is_prefilled:
            self.heterogeneous_kv_cache.offload((K, V))
        else:
            self.heterogeneous_kv_cache.update((K[:, -1, :], V[:, -1, :]))

    
    