In [37]:
from torch.nn import Linear, Module
from torch import softmax

## Basic Self Attention


- input X, shape is (bs, seqlen, hidden_size)
- weight Wq, Wk, Wv, shape is (hidden_size， hidden_size) 
- Q,K,V shape is (bs, seqlen, hidden_size)
- attention_score shape is (bs, seqlen, seqlen), second dim is each token, third dim is the score of all tokens against this token 
- attention_output shape is (bs, seqlen, hidden_size)

note: 
`torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)`
Applies an affine linear transformation to the incoming data: 
```y=xA^T+b```

to calculate Q K^T, need:
- transpose K on last 2 dims (seqlen and hidden_size)
- use matmul (@)

note:
`torch.matmul(input, other, *, out=None)`
The behavior depends on the dimensionality of the tensors as follows:
- If both tensors are 1-dimensional, the dot product (scalar) is returned.
- If both arguments are 2-dimensional, the matrix-matrix product is returned.
- If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.
- If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.
- If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a **batched matrix multiply** is returned. 




In [112]:
class SelfAttention(Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.Wq = Linear(hidden_size, hidden_size, bias=False)
        self.Wk = Linear(hidden_size, hidden_size, bias=False)
        self.Wv = Linear(hidden_size, hidden_size, bias=False)
        self.Wo = Linear(hidden_size, hidden_size, bias=False)
        self.scale = hidden_size ** (1/2)
    
    def forward(self, x):
        Q = self.Wq(x)
        K = self.Wk(x)
        V = self.Wv(x)
        
        attention_score = softmax(Q @ K.transpose(-1, -2) / self.scale, dim=-1)
        return self.Wo(attention_score @ V)

In [52]:
layer = SelfAttention(512)

In [53]:
layer(torch.rand(2, 32, 512))

tensor([[[ 1.3673e-01,  1.7024e-01, -1.0274e-02,  ...,  4.1232e-02,
          -9.1783e-02, -4.5934e-06],
         [ 1.3637e-01,  1.7086e-01, -1.0617e-02,  ...,  4.1984e-02,
          -9.0352e-02,  5.8779e-04],
         [ 1.3568e-01,  1.7027e-01, -1.0546e-02,  ...,  4.1324e-02,
          -9.0808e-02,  2.9186e-05],
         ...,
         [ 1.3546e-01,  1.6903e-01, -9.2836e-03,  ...,  4.1253e-02,
          -9.0734e-02, -1.0160e-03],
         [ 1.3652e-01,  1.6946e-01, -1.0524e-02,  ...,  4.1064e-02,
          -9.0825e-02, -6.5265e-04],
         [ 1.3561e-01,  1.6883e-01, -9.5603e-03,  ...,  4.0815e-02,
          -9.0803e-02, -9.8485e-04]],

        [[ 1.1687e-01,  2.2417e-01,  3.0953e-02,  ...,  2.4701e-02,
          -7.5829e-02, -1.6887e-02],
         [ 1.1527e-01,  2.2471e-01,  3.1047e-02,  ...,  2.4696e-02,
          -7.6358e-02, -1.7282e-02],
         [ 1.1688e-01,  2.2457e-01,  3.1124e-02,  ...,  2.5412e-02,
          -7.5537e-02, -1.6291e-02],
         ...,
         [ 1.1584e-01,  2

## MHA


- the Wq Qk Wv are still the shape (hidden_size, hidden_size) on init, but will reshape to (hidden_size, num_heads, head_dim) when calling forward 
- Q K V shape will be (batch_size, seqlen, num_heads, head_dim)
- Q K^T needs output of (batch_size, num_heads, seqlen, seqlen), so transpose seqlen and num_heads
- scale use head_dim, not hidden_size
- concat all heads after attention_score @ V (batch_size, num_heads, seqlen, head_dim) -> (batch_size, seqlen, num_heads * head_dim)


In [99]:
class MultiHeadSelfAttention(SelfAttention):
    def __init__(self, hidden_size, num_heads):
        super().__init__(hidden_size)
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** (1/2)
        
    def forward(self, x):
        bs, seqlen, _ = x.shape
        Q = self.Wq(x).view(bs, seqlen, self.num_heads, self.head_dim)
        K = self.Wk(x).view(bs, seqlen, self.num_heads, self.head_dim)
        V = self.Wv(x).view(bs, seqlen, self.num_heads, self.head_dim)
        
        attention_score = softmax(Q.transpose(1,2) @ K.transpose(1,2).transpose(-1, -2) / self.scale, dim=-1)
        attention_output = (attention_score @ V.transpose(1,2)).transpose(1,2).contiguous().view(bs, seqlen, -1)
        return self.Wo(attention_output)

In [73]:
layer = MultiHeadSelfAttention(512, 8)

In [76]:
layer(torch.rand(2, 32, 512))

tensor([[[-0.0442, -0.0179, -0.0077,  ..., -0.1054, -0.1257, -0.0280],
         [-0.0439, -0.0180, -0.0084,  ..., -0.1055, -0.1254, -0.0277],
         [-0.0448, -0.0184, -0.0087,  ..., -0.1060, -0.1256, -0.0284],
         ...,
         [-0.0456, -0.0191, -0.0094,  ..., -0.1060, -0.1261, -0.0274],
         [-0.0442, -0.0183, -0.0090,  ..., -0.1062, -0.1251, -0.0286],
         [-0.0442, -0.0196, -0.0081,  ..., -0.1061, -0.1250, -0.0280]],

        [[-0.0682, -0.0174, -0.0067,  ..., -0.1155, -0.1311, -0.0311],
         [-0.0677, -0.0164, -0.0067,  ..., -0.1149, -0.1304, -0.0315],
         [-0.0682, -0.0174, -0.0074,  ..., -0.1163, -0.1311, -0.0319],
         ...,
         [-0.0688, -0.0165, -0.0064,  ..., -0.1146, -0.1304, -0.0309],
         [-0.0678, -0.0162, -0.0065,  ..., -0.1152, -0.1309, -0.0316],
         [-0.0681, -0.0164, -0.0068,  ..., -0.1150, -0.1300, -0.0312]]],
       grad_fn=<UnsafeViewBackward0>)

## GQA

- the Wq is the original size, Wk and Wv should be smaller
- Wq is reshaped as (hidden_size, num_heads, head_dim), Wk and Wv is reshaped as (hidden_size, num_kv_heads, head_dim)
- Q is (bs, seqlen, num_heads, head_dim), KV is  (bs, seqlen, num_kv_heads, head_dim), so need to interleave KV for group_size times, after repeating, KV is (bs, seqlen, num_heads, head_dim)
- attention_score is still (bs, num_heads, seqlen, seqlen), after matmul V^T, is (bs, num_heads, seqlen, head_dim)

note:
`torch.repeat_interleave(input, repeats, dim=None, *, output_size=None)`

- input (Tensor) – the input tensor.
- repeats (Tensor or int) – The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
- dim (int, optional) – The dimension along which to repeat values. By default, use the flattened input array, and return a flat output array.

```
>>> torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])
```


In [92]:
class GroupedQueryAttention(Module):
    def __init__(self, hidden_size, num_heads, num_kv_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** (1/2)
        self.num_kv_heads = num_kv_heads
        self.group_size = num_heads // num_kv_heads
        
        self.Wq = Linear(hidden_size, hidden_size, bias=False)
        self.Wk = Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
        self.Wv = Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
        self.Wo = Linear(hidden_size, hidden_size, bias=False)
    
    def forward(self, x):
        bs, seqlen, _ = x.shape
        Q = self.Wq(x).view(bs, seqlen, self.num_heads, self.head_dim)
        K = self.Wk(x).view(bs, seqlen, self.num_kv_heads, self.head_dim)
        V = self.Wv(x).view(bs, seqlen, self.num_kv_heads, self.head_dim)
        
        K = K.repeat_interleave(self.group_size, dim=2)
        V = V.repeat_interleave(self.group_size, dim=2)
        
        attention_score = (Q.transpose(1,2) @ K.transpose(1,2).transpose(-1,-2) / self.scale)
        attention_output = (attention_score @ V.transpose(1,2)).transpose(1,2).contiguous().view(bs, seqlen, -1)
        return self.Wo(attention_output)
        

In [93]:
layer = GroupedQueryAttention(512, 32, 4)

In [98]:
layer(torch.rand(2, 32, 512))

tensor([[[ 0.2463, -0.3342, -0.3047,  ...,  0.2547, -0.0646,  1.3266],
         [-0.0288, -0.7129,  0.0258,  ...,  0.1388, -0.1351,  1.5268],
         [-0.0497, -0.6961,  0.1443,  ...,  0.3307, -0.3921,  1.3867],
         ...,
         [ 0.0371, -1.0292, -0.0592,  ...,  0.0646,  0.2182,  1.7566],
         [ 0.0039, -0.7354, -0.1269,  ..., -0.2023, -0.1142,  1.2810],
         [-0.0516, -0.8042, -0.5209,  ...,  0.1557, -0.6026,  1.0158]],

        [[ 0.0509, -1.1555,  0.1331,  ...,  0.5010, -0.3044,  1.6570],
         [-0.4592, -1.1015, -0.4584,  ...,  0.9297, -0.1316,  1.2649],
         [ 0.4306, -0.8634,  0.0858,  ...,  0.3441, -0.1953,  1.2412],
         ...,
         [-0.2646, -0.9186,  0.0988,  ...,  0.4202, -0.0237,  1.4110],
         [-0.3137, -0.5333, -0.4952,  ...,  0.4895, -0.0046,  1.2702],
         [ 0.0539, -0.9006,  0.4146,  ...,  0.3690, -0.0928,  1.2824]]],
       grad_fn=<UnsafeViewBackward0>)

## KV Cache

- the forward just need to handle 1 token at a time, means seqlen==1, so the input X shape is (bs, 1, hidden_size)
- past_kv = (past_key, past_value), both in shape of (bs, seqlen, hidden_size)
- for new kv, concat past and this token's on dim=1, so now the Q is (bs, 1, hidden_size), new KV is (bs, seqlen+1, hidden_size) 

In [113]:
class SelfAttentionWithKVCache(SelfAttention):
    
    def forward(self, x, cached_kv=None):
        bs, _, hidden_size = x.shape
        if cached_kv:
            cached_key, cached_value = cached_kv
        else:
            cached_key = torch.zeros(bs, 0, hidden_size)
            cached_value = torch.zeros(bs, 0, hidden_size)
        
        Q = self.Wq(x)
        K = self.Wk(x)
        V = self.Wv(x)
        
        cached_key = torch.cat([cached_key, K], dim=1)
        cached_value = torch.cat([cached_value, V], dim=1)
        
        attention_score = softmax(Q @ cached_key.transpose(-1, -2) / self.scale, dim=-1)
        output = self.Wo(attention_score @ cached_value)
        
        return output, (cached_key, cached_value)

In [114]:
layer = SelfAttentionWithKVCache(512)

In [116]:
output, kv_cache = layer(torch.rand(2, 32, 512))

In [117]:
output.shape

torch.Size([2, 32, 512])

In [120]:
kv_cache[0].shape, kv_cache[1].shape

(torch.Size([2, 32, 512]), torch.Size([2, 32, 512]))

In [121]:
output, kv_cache = layer(torch.rand(2, 1, 512), kv_cache)

In [122]:
output.shape

torch.Size([2, 1, 512])

In [123]:
kv_cache[0].shape, kv_cache[1].shape

(torch.Size([2, 33, 512]), torch.Size([2, 33, 512]))