In [19]:
import torch
from torch.nn import Linear, Module
from torch import softmax
import time

## Basic Self Attention


- input X, shape is (bs, seqlen, hidden_size)
- weight Wq, Wk, Wv, shape is (hidden_size， hidden_size) 
- Q,K,V shape is (bs, seqlen, hidden_size)
- attention_score shape is (bs, seqlen, seqlen), second dim is each token, third dim is the score of all tokens against this token 
- attention_output shape is (bs, seqlen, hidden_size)

note: 
`torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)`
Applies an affine linear transformation to the incoming data: 
```y=xA^T+b```

to calculate Q K^T, need:
- transpose K on last 2 dims (seqlen and hidden_size)
- use matmul (@)

note:
`torch.matmul(input, other, *, out=None)`
The behavior depends on the dimensionality of the tensors as follows:
- If both tensors are 1-dimensional, the dot product (scalar) is returned.
- If both arguments are 2-dimensional, the matrix-matrix product is returned.
- If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.
- If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.
- If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a **batched matrix multiply** is returned. 




In [6]:
class SelfAttention(Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.Wq = Linear(hidden_size, hidden_size, bias=False)
        self.Wk = Linear(hidden_size, hidden_size, bias=False)
        self.Wv = Linear(hidden_size, hidden_size, bias=False)
        self.Wo = Linear(hidden_size, hidden_size, bias=False)
        self.scale = hidden_size ** (1/2)
    
    def forward(self, x):
        Q = self.Wq(x)
        K = self.Wk(x)
        V = self.Wv(x)
        
        attention_score = softmax(Q @ K.transpose(-1, -2) / self.scale, dim=-1)
        return self.Wo(attention_score @ V)

In [7]:
layer = SelfAttention(512)

In [8]:
layer(torch.rand(2, 32, 512))

tensor([[[-0.0395, -0.0129, -0.1174,  ...,  0.3310, -0.0522,  0.2033],
         [-0.0397, -0.0138, -0.1185,  ...,  0.3312, -0.0536,  0.2036],
         [-0.0392, -0.0130, -0.1186,  ...,  0.3306, -0.0531,  0.2035],
         ...,
         [-0.0401, -0.0143, -0.1181,  ...,  0.3309, -0.0520,  0.2033],
         [-0.0388, -0.0135, -0.1184,  ...,  0.3307, -0.0543,  0.2032],
         [-0.0397, -0.0141, -0.1185,  ...,  0.3302, -0.0529,  0.2038]],

        [[-0.0096, -0.0171, -0.1130,  ...,  0.3027, -0.0272,  0.2473],
         [-0.0110, -0.0163, -0.1134,  ...,  0.3029, -0.0267,  0.2478],
         [-0.0087, -0.0172, -0.1129,  ...,  0.3021, -0.0270,  0.2475],
         ...,
         [-0.0105, -0.0171, -0.1128,  ...,  0.3028, -0.0264,  0.2474],
         [-0.0108, -0.0169, -0.1136,  ...,  0.3027, -0.0270,  0.2481],
         [-0.0103, -0.0157, -0.1135,  ...,  0.3021, -0.0265,  0.2476]]],
       grad_fn=<UnsafeViewBackward0>)

## MHA


- the Wq Qk Wv are still the shape (hidden_size, hidden_size) on init, but will reshape to (hidden_size, num_heads, head_dim) when calling forward 
- Q K V shape will be (batch_size, seqlen, num_heads, head_dim)
- Q K^T needs output of (batch_size, num_heads, seqlen, seqlen), so transpose seqlen and num_heads
- scale use head_dim, not hidden_size
- concat all heads after attention_score @ V (batch_size, num_heads, seqlen, head_dim) -> (batch_size, seqlen, num_heads * head_dim)


In [9]:
class MultiHeadSelfAttention(SelfAttention):
    def __init__(self, hidden_size, num_heads):
        super().__init__(hidden_size)
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** (1/2)
        
    def forward(self, x):
        bs, seqlen, _ = x.shape
        Q = self.Wq(x).view(bs, seqlen, self.num_heads, self.head_dim)
        K = self.Wk(x).view(bs, seqlen, self.num_heads, self.head_dim)
        V = self.Wv(x).view(bs, seqlen, self.num_heads, self.head_dim)
        
        attention_score = softmax(Q.transpose(1,2) @ K.transpose(1,2).transpose(-1, -2) / self.scale, dim=-1)
        attention_output = (attention_score @ V.transpose(1,2)).transpose(1,2).contiguous().view(bs, seqlen, -1)
        return self.Wo(attention_output)

In [10]:
layer = MultiHeadSelfAttention(512, 8)

In [11]:
layer(torch.rand(2, 32, 512))

tensor([[[-0.2776,  0.0994,  0.1848,  ..., -0.0187,  0.1510,  0.0623],
         [-0.2770,  0.0991,  0.1851,  ..., -0.0192,  0.1498,  0.0625],
         [-0.2776,  0.0996,  0.1854,  ..., -0.0183,  0.1508,  0.0613],
         ...,
         [-0.2774,  0.0986,  0.1857,  ..., -0.0178,  0.1503,  0.0627],
         [-0.2767,  0.0993,  0.1852,  ..., -0.0181,  0.1508,  0.0625],
         [-0.2773,  0.0991,  0.1843,  ..., -0.0180,  0.1513,  0.0621]],

        [[-0.2816,  0.1062,  0.1920,  ...,  0.0167,  0.1328,  0.0427],
         [-0.2811,  0.1061,  0.1924,  ...,  0.0166,  0.1315,  0.0422],
         [-0.2813,  0.1051,  0.1923,  ...,  0.0174,  0.1335,  0.0436],
         ...,
         [-0.2812,  0.1064,  0.1926,  ...,  0.0161,  0.1320,  0.0439],
         [-0.2806,  0.1056,  0.1930,  ...,  0.0171,  0.1321,  0.0426],
         [-0.2810,  0.1063,  0.1923,  ...,  0.0168,  0.1323,  0.0431]]],
       grad_fn=<UnsafeViewBackward0>)

## GQA

- the Wq is the original size, Wk and Wv should be smaller
- Wq is reshaped as (hidden_size, num_heads, head_dim), Wk and Wv is reshaped as (hidden_size, num_kv_heads, head_dim)
- Q is (bs, seqlen, num_heads, head_dim), KV is  (bs, seqlen, num_kv_heads, head_dim), so need to interleave KV for group_size times, after repeating, KV is (bs, seqlen, num_heads, head_dim)
- attention_score is still (bs, num_heads, seqlen, seqlen), after matmul V^T, is (bs, num_heads, seqlen, head_dim)

note:
`torch.repeat_interleave(input, repeats, dim=None, *, output_size=None)`

- input (Tensor) – the input tensor.
- repeats (Tensor or int) – The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
- dim (int, optional) – The dimension along which to repeat values. By default, use the flattened input array, and return a flat output array.

```
>>> torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])
```


In [12]:
class GroupedQueryAttention(Module):
    def __init__(self, hidden_size, num_heads, num_kv_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** (1/2)
        self.num_kv_heads = num_kv_heads
        self.group_size = num_heads // num_kv_heads
        
        self.Wq = Linear(hidden_size, hidden_size, bias=False)
        self.Wk = Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
        self.Wv = Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
        self.Wo = Linear(hidden_size, hidden_size, bias=False)
    
    def forward(self, x):
        bs, seqlen, _ = x.shape
        Q = self.Wq(x).view(bs, seqlen, self.num_heads, self.head_dim)
        K = self.Wk(x).view(bs, seqlen, self.num_kv_heads, self.head_dim)
        V = self.Wv(x).view(bs, seqlen, self.num_kv_heads, self.head_dim)
        
        K = K.repeat_interleave(self.group_size, dim=2)
        V = V.repeat_interleave(self.group_size, dim=2)
        
        attention_score = (Q.transpose(1,2) @ K.transpose(1,2).transpose(-1,-2) / self.scale)
        attention_output = (attention_score @ V.transpose(1,2)).transpose(1,2).contiguous().view(bs, seqlen, -1)
        return self.Wo(attention_output)
        

In [13]:
layer = GroupedQueryAttention(512, 32, 4)

In [14]:
layer(torch.rand(2, 32, 512))

tensor([[[ 0.2398, -0.2169,  0.0244,  ...,  0.5646, -0.1241,  0.1939],
         [-0.0762, -0.0269,  0.1034,  ...,  0.3262, -0.0654,  0.0542],
         [-0.0948,  0.1091,  0.4631,  ...,  0.3258, -0.3812,  0.3509],
         ...,
         [-0.0408, -0.3334,  0.0882,  ...,  0.8266, -0.1313,  0.3886],
         [-0.0590,  0.1535,  0.2616,  ...,  0.0895, -0.4499,  0.1498],
         [ 0.1171,  0.0685,  0.0594,  ...,  0.7256, -0.0245, -0.1301]],

        [[-0.1654,  0.1116,  0.8133,  ...,  0.2510, -0.5537,  0.3507],
         [ 0.3660, -0.1932,  0.4336,  ...,  0.7051,  0.1167,  0.2321],
         [ 0.1229,  0.0439,  0.5826,  ...,  0.4699, -0.0516,  0.3244],
         ...,
         [ 0.3033,  0.1550,  0.3179,  ...,  0.5955, -0.0863,  0.2315],
         [ 0.3834,  0.2093, -0.5854,  ..., -0.0066, -0.1727,  0.3564],
         [ 0.3823, -0.1722,  0.2567,  ...,  0.3105, -0.3939,  0.4147]]],
       grad_fn=<UnsafeViewBackward0>)

## KV Cache

- the forward just need to handle 1 token at a time, means seqlen==1, so the input X shape is (bs, 1, hidden_size)
- past_kv = (past_key, past_value), both in shape of (bs, seqlen, hidden_size)
- for new kv, concat past and this token's on dim=1, so now the Q is (bs, 1, hidden_size), new KV is (bs, seqlen+1, hidden_size) 

In [15]:
class SelfAttentionWithKVCache(SelfAttention):
    
    def forward(self, x, cached_kv=None):
        bs, _, hidden_size = x.shape
        if cached_kv:
            cached_key, cached_value = cached_kv
        else:
            cached_key = torch.zeros(bs, 0, hidden_size)
            cached_value = torch.zeros(bs, 0, hidden_size)
        
        Q = self.Wq(x)
        K = self.Wk(x)
        V = self.Wv(x)
        
        cached_key = torch.cat([cached_key, K], dim=1)
        cached_value = torch.cat([cached_value, V], dim=1)
        
        attention_score = softmax(Q @ cached_key.transpose(-1, -2) / self.scale, dim=-1)
        output = self.Wo(attention_score @ cached_value)
        
        return output, (cached_key, cached_value)

In [16]:
layer = SelfAttentionWithKVCache(512)

### Time with KV Cache

In [24]:
start_time = time.time()
output, kv_cache = layer(torch.rand(2, 4096, 512))
prefilling_time_cost = time.time() - start_time
print(prefilling_time_cost)

0.2263050079345703


In [25]:
for _ in range(5):
    start_time = time.time()
    output, kv_cache = layer(torch.rand(2, 1, 512), kv_cache)
    decoding_time_cost = time.time() - start_time
    print(decoding_time_cost)
    print(output.shape, kv_cache[0].shape, kv_cache[1].shape)

0.021359920501708984
torch.Size([2, 1, 512]) torch.Size([2, 4097, 512]) torch.Size([2, 4097, 512])
0.013353347778320312
torch.Size([2, 1, 512]) torch.Size([2, 4098, 512]) torch.Size([2, 4098, 512])
0.004659891128540039
torch.Size([2, 1, 512]) torch.Size([2, 4099, 512]) torch.Size([2, 4099, 512])
0.004498958587646484
torch.Size([2, 1, 512]) torch.Size([2, 4100, 512]) torch.Size([2, 4100, 512])
0.006295204162597656
torch.Size([2, 1, 512]) torch.Size([2, 4101, 512]) torch.Size([2, 4101, 512])


### Time without KV Cache


In [28]:
for _ in range(5):
    start_time = time.time()
    output, _ = layer(torch.rand(2, 4096 + _, 512), None)
    decoding_time_cost = time.time() - start_time
    print(decoding_time_cost)
    print(output.shape)

0.1712038516998291
torch.Size([2, 4096, 512])
0.15407514572143555
torch.Size([2, 4097, 512])
0.154710054397583
torch.Size([2, 4098, 512])
0.14727115631103516
torch.Size([2, 4099, 512])
0.14984130859375
torch.Size([2, 4100, 512])


## Heterogeneous Attention


- the code demo is based on basic self attention, instead of the original implementation on GQA


In [None]:
class HeterogeneousKVCache:
    def __init__(self, layer_idx):
        self.layer_idx = layer_idx
        self.sink = None
        self.context = None
        self.recency = None
        
    
    def build(self):
        pass
    
    def offload(self):
        pass
    
    def search(self, q_states, top_k):
        # ... mock some output here
        selected_indices = list(range(top_k))
        return selected_indices
    
    
class HeterogeneousAttention(SelfAttentionWithKVCache): # or: SelfAttentionWithHeterogeneousKVCache
    def __init__(self, hidden_size, layer_idx):
        super().__init__(hidden_size)
        self.heterogeneous_kv_cache = HeterogeneousKVCache(layer_idx)
        
    
    def forward(self, x):
        if
    
    