In [107]:
import math
import torch
from torch import nn, Tensor
from torch.nn import functional as F

# Grouped-Query Attention (GQA): `n_kv_heads < n_heads`

In [84]:
# Toy problem sizes used throughout the walkthrough.
batch = 10
seq_len = 100    # length of the new input x along its sequence axis
cache_len = 134  # length of the pre-existing k/v cache along its sequence axis
dim = 2048       # feature width of x and of the projection inputs
n_heads = 32     # q head number
n_kv_heads = 8   # kv head number

# The floor divisions below would silently truncate on an inconsistent
# configuration, so fail loudly here instead of producing mismatched shapes later.
assert dim % n_heads == 0, "dim must be divisible by n_heads"
assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"

head_dim = dim // n_heads           # width of a single attention head
n_repeats = n_heads // n_kv_heads   # how many query heads share one kv head
head_dim, n_repeats

(64, 4)

In [85]:
# Seed before building the layers: nn.Linear draws its initial weights from the
# global RNG, so without this the projections (and every value computed from
# them below) change on each fresh kernel run.
torch.manual_seed(0)

# Queries are projected to n_heads heads; keys and values only to n_kv_heads heads.
q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
# Output projection maps the concatenated heads back to the model width.
o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
q_proj, k_proj, v_proj, o_proj

(Linear(in_features=2048, out_features=2048, bias=False),
 Linear(in_features=2048, out_features=512, bias=False),
 Linear(in_features=2048, out_features=512, bias=False),
 Linear(in_features=2048, out_features=2048, bias=False))

In [86]:
# Fix the RNG so the random input is the same on every run.
torch.manual_seed(0)
x = torch.randn((batch, seq_len, dim))
x.shape

torch.Size([10, 100, 2048])

# x -> qkv

In [87]:
# Project the same input three ways; k and v come out narrower than q because
# their layers were built with n_kv_heads instead of n_heads.
q, k, v = q_proj(x), k_proj(x), v_proj(x)
q.shape, k.shape, v.shape

(torch.Size([10, 100, 2048]),
 torch.Size([10, 100, 512]),
 torch.Size([10, 100, 512]))

# reshape & transpose

In [88]:
# Guard against re-running this cell: q/k/v are overwritten in place, and a
# second reshape of the already-transposed 4-D tensors would succeed (same
# element count) while silently scrambling heads and positions.
assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3, \
    "q/k/v are already split into heads; re-run the projection cell first"

# (batch, seq_len, heads * head_dim) -> (batch, heads, seq_len, head_dim)
q = q.reshape(batch, seq_len, n_heads, head_dim).transpose(1, 2)
k = k.reshape(batch, seq_len, n_kv_heads, head_dim).transpose(1, 2)
v = v.reshape(batch, seq_len, n_kv_heads, head_dim).transpose(1, 2)
q.shape, k.shape, v.shape

(torch.Size([10, 32, 100, 64]),
 torch.Size([10, 8, 100, 64]),
 torch.Size([10, 8, 100, 64]))

# apply_rotary_emb

In [89]:
def apply_rotary_emb(x: Tensor) -> Tensor:
    """Placeholder for the rotary-embedding step: returns ``x`` unchanged.

    No rotation is applied here; the function only marks where that step sits
    in the attention pipeline.
    """
    return x

In [90]:
# Only queries and keys pass through the rotary step; values are left alone.
q, k = apply_rotary_emb(q), apply_rotary_emb(k)

# Concatenate the KV cache

In [91]:
# Random stand-ins for a KV cache, laid out like k and v after the head split:
# (batch, n_kv_heads, cache_len, head_dim). Re-seed so they are reproducible.
torch.manual_seed(0)
k_cache = torch.randn((batch, n_kv_heads, cache_len, head_dim))
v_cache = torch.randn((batch, n_kv_heads, cache_len, head_dim))
k_cache.shape, v_cache.shape

(torch.Size([10, 8, 134, 64]), torch.Size([10, 8, 134, 64]))

In [92]:
# Put the cached entries in front of the new ones along the sequence axis
# (second-to-last axis of the 4-D tensors).
k_with_cache = torch.cat([k_cache, k], dim=-2)
v_with_cache = torch.cat([v_cache, v], dim=-2)
k_with_cache.shape, v_with_cache.shape

(torch.Size([10, 8, 234, 64]), torch.Size([10, 8, 234, 64]))

# Repeat the KV heads to match `n_heads`

In [93]:
def repeat_kv(x: Tensor, n_repeats: int) -> Tensor:
    """Repeat every head of ``x`` ``n_repeats`` times along the head axis.

    Produces the same values as
    ``torch.repeat_interleave(x, repeats=n_repeats, dim=1)``.

    Args:
        x: 4-D tensor whose axis 1 is the head axis.
        n_repeats: number of consecutive copies of each head.

    Returns:
        ``x`` itself when ``n_repeats == 1``; otherwise a tensor whose axis 1
        is ``n_repeats`` times larger, with the copies of each head adjacent.
    """
    if n_repeats == 1:
        return x

    bsz, n_kv, length, width = x.shape
    # Insert a singleton axis right after the head axis and broadcast it to
    # n_repeats (no copy yet); the reshape then folds it into the head axis.
    expanded = x.unsqueeze(2).expand(-1, -1, n_repeats, -1, -1)
    return expanded.reshape(bsz, n_kv * n_repeats, length, width)

# Exactly identical to the built-in torch.repeat_interleave
(repeat_kv(k_with_cache, n_repeats) == torch.repeat_interleave(k_with_cache, repeats=n_repeats, dim=1)).all()

tensor(True)

In [94]:
# Expand the kv heads so keys/values have as many heads as the queries.
# NOTE: both names are overwritten, so running this cell twice repeats the heads again.
k_with_cache, v_with_cache = repeat_kv(k_with_cache, n_repeats), repeat_kv(v_with_cache, n_repeats)
k_with_cache.shape, v_with_cache.shape

(torch.Size([10, 32, 234, 64]), torch.Size([10, 32, 234, 64]))

# attn

In [95]:
# Scaled dot-product scores: one row per query position, one column per key
# position (cached + new), divided by sqrt(head_dim).
attn = (q @ k_with_cache.transpose(-2, -1)) / math.sqrt(head_dim)
attn.shape

torch.Size([10, 32, 100, 234])

## create mask

In [96]:
# Start from a square (seq_len x seq_len) matrix filled entirely with -inf.
mask = torch.full(size=(seq_len, seq_len), fill_value=-torch.inf)
print(mask.shape)
mask

torch.Size([100, 100])


tensor([[-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        ...,
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf]])

In [98]:
# Keep only the entries strictly above the main diagonal (the upper-right
# triangle); the diagonal and everything below it become 0.
mask = mask.triu(diagonal=1)
mask

tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [99]:
# The cache was concatenated in front of the new keys, so prepend cache_len
# all-zero columns: those key columns receive no masking at all.
mask = torch.cat([torch.zeros(seq_len, cache_len), mask], dim=-1)
print(mask.shape)
mask

torch.Size([100, 234])


tensor([[0., 0., 0.,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])

## add mask

In [100]:
# Additive mask: 0 leaves a score untouched, -inf removes it after softmax.
# The 2-D mask broadcasts over the leading batch and head axes.
attn = attn if mask is None else attn + mask
attn.shape

torch.Size([10, 32, 100, 234])

In [101]:
# Sanity-check the masked scores for batch 0 / head 0, looking only at the
# columns after the first cache_len ones (the non-cache part of the mask).
print(attn[0, 0, ..., cache_len:].shape)
attn[0, 0, ..., cache_len:]

torch.Size([100, 100])


tensor([[-0.4858,    -inf,    -inf,  ...,    -inf,    -inf,    -inf],
        [-0.5556, -0.3688,    -inf,  ...,    -inf,    -inf,    -inf],
        [-0.7119,  0.0268,  0.0745,  ...,    -inf,    -inf,    -inf],
        ...,
        [-0.5759, -0.0379, -0.2373,  ..., -0.3782,    -inf,    -inf],
        [ 0.4594, -0.2608,  0.1086,  ..., -0.0264, -0.1705,    -inf],
        [-0.2287, -0.2571, -0.4704,  ..., -0.4710, -0.3822, -0.1035]],
       grad_fn=<SliceBackward0>)

In [103]:
# Normalise each query row over all key positions.
attn = torch.softmax(attn, dim=-1)
attn.shape

torch.Size([10, 32, 100, 234])

In [104]:
# Weighted sum of the values using the attention weights.
y = attn @ v_with_cache
y.shape

torch.Size([10, 32, 100, 64])

# proj

In [105]:
# (batch, n_heads, seq_len, head_dim) -> (batch, seq_len, n_heads * head_dim).
# The last size is spelled n_heads * head_dim (not dim) so it always matches
# the in_features that o_proj was constructed with.
y = y.transpose(1, 2).reshape(batch, seq_len, n_heads * head_dim)
y.shape

torch.Size([10, 100, 2048])

In [106]:
# Final linear projection of the merged heads (o_proj has no bias).
y = o_proj(y)
y.shape

torch.Size([10, 100, 2048])