In [44]:
import torch
from torch import nn, Tensor
import math

In [45]:
batch = 10
seq_len = 1000
dim = 2048
n_heads = 32     # number of query heads
n_kv_heads = 8   # number of key/value heads (each shared by a group of query heads)
# `//` silently truncates, so check the divisibility the rest of the notebook relies on.
assert dim % n_heads == 0, "dim must be divisible by n_heads"
assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
head_dim = dim // n_heads
n_repeats = n_heads // n_kv_heads  # query heads per kv head
# Seed before the projection layers are created so their random init is reproducible.
torch.manual_seed(0)
head_dim, n_repeats

(64, 4)

# Grouped-Query Attention (GQA)

## init

In [46]:
# The query projection yields n_heads slices of size head_dim. The key/value
# projections yield only n_kv_heads slices; those are later repeated so that
# each group of n_repeats query heads shares one kv head.
q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
# The output projection maps the concatenated heads (n_heads * head_dim) back to
# the model dimension, so its in/out features are the reverse of q_proj's.
o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
q_proj, k_proj, v_proj, o_proj

(Linear(in_features=2048, out_features=2048, bias=False),
 Linear(in_features=2048, out_features=512, bias=False),
 Linear(in_features=2048, out_features=512, bias=False),
 Linear(in_features=2048, out_features=2048, bias=False))

## x -> qkv

In [47]:
# Fixed seed so the random input tensor is reproducible.
torch.manual_seed(0)
x = torch.randn((batch, seq_len, dim))
x.shape

torch.Size([10, 1000, 2048])

In [48]:
# Project the same input into query, key and value spaces; k and v are
# narrower because they only have n_kv_heads heads.
q, k, v = [proj(x) for proj in (q_proj, k_proj, v_proj)]
q.shape, k.shape, v.shape

(torch.Size([10, 1000, 2048]),
 torch.Size([10, 1000, 512]),
 torch.Size([10, 1000, 512]))

## reshape & transpose

In [49]:
# Split the flat projection into heads: (batch, seq_len, heads, head_dim).
q, k, v = (
    q.reshape(batch, seq_len, n_heads, head_dim),
    k.reshape(batch, seq_len, n_kv_heads, head_dim),
    v.reshape(batch, seq_len, n_kv_heads, head_dim),
)
q.shape, k.shape, v.shape

(torch.Size([10, 1000, 32, 64]),
 torch.Size([10, 1000, 8, 64]),
 torch.Size([10, 1000, 8, 64]))

In [50]:
# Move the head axis ahead of the sequence axis: (batch, heads, seq_len, head_dim).
q, k, v = [t.transpose(1, 2) for t in (q, k, v)]
q.shape, k.shape, v.shape

(torch.Size([10, 32, 1000, 64]),
 torch.Size([10, 8, 1000, 64]),
 torch.Size([10, 8, 1000, 64]))

## apply_rotary_emb

In [51]:
def apply_rotary_emb(x: Tensor) -> Tensor:
    """Placeholder for applying rotary position embeddings (per the name).

    Currently the identity: it returns ``x`` itself, unchanged, so attention in
    this notebook runs without any positional information.
    TODO: implement the actual rotation.
    """
    rotated = x  # no rotation applied yet
    return rotated

In [52]:
# Positional encoding is applied to queries and keys only, never to values.
q, k = apply_rotary_emb(q), apply_rotary_emb(k)

## repeat kv

In [53]:
# repeat_interleave repeats each element adjacently (h0, h0, h1, h1, ...),
# whereas Tensor.repeat tiles the whole sequence (h0, h1, ..., h0, h1, ...);
# the two are not interchangeable.
torch.repeat_interleave(k, n_repeats, dim=1).shape

torch.Size([10, 32, 1000, 64])

In [54]:
def repeat_kv(x: Tensor, repeats: int) -> Tensor:
    """Repeat each key/value head ``repeats`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(x, repeats=repeats, dim=1)``: heads
    are repeated adjacently (h0, h0, ..., h1, h1, ...), so query head ``i`` is
    paired with kv head ``i // repeats``.

    Args:
        x: 4-D tensor laid out as (batch, n_kv_heads, seq_len, head_dim).
        repeats: number of query heads that share each kv head.

    Returns:
        Tensor of shape (batch, n_kv_heads * repeats, seq_len, head_dim).
    """
    if repeats == 1:
        # Nothing to repeat; skip the expand/reshape round-trip.
        return x
    # Local names deliberately differ from the notebook globals to avoid shadowing.
    bs, n_kv, slen, hd = x.shape
    # expand() adds a zero-stride axis for the copies; reshape() then
    # materializes it so the repeated heads end up adjacent.
    expanded = x[:, :, None, :, :].expand(bs, n_kv, repeats, slen, hd)
    return expanded.reshape(bs, n_kv * repeats, slen, hd)

In [55]:
# Sanity check: the expand/reshape implementation matches repeat_interleave exactly.
(repeat_kv(k, n_repeats) == torch.repeat_interleave(k, repeats=n_repeats, dim=1)).all()

tensor(True)

In [56]:
# Broadcast the n_kv_heads key/value heads up to n_heads so they line up with q.
k, v = repeat_kv(k, n_repeats), repeat_kv(v, n_repeats)
k.shape, v.shape

(torch.Size([10, 32, 1000, 64]), torch.Size([10, 32, 1000, 64]))

## attn

In [57]:
# Scaled dot-product scores, one (seq_len x seq_len) matrix per head.
# No causal mask is applied here.
attn = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
attn.shape

torch.Size([10, 32, 1000, 1000])

In [58]:
# Normalize over the key axis so each query's attention weights sum to 1.
attn = torch.softmax(attn, dim=-1)
attn.shape

torch.Size([10, 32, 1000, 1000])

In [59]:
# Attention-weighted sum of the value vectors for every head.
y = attn @ v
y.shape

torch.Size([10, 32, 1000, 64])

## proj

In [60]:
# Swap heads and sequence back: (batch, seq_len, n_heads, head_dim).
y = y.permute(0, 2, 1, 3)
y.shape

torch.Size([10, 1000, 32, 64])

In [61]:
# Concatenate the heads: (batch, seq_len, n_heads, head_dim) -> (batch, seq_len, n_heads * head_dim).
# n_heads * head_dim is the true size of the concatenated heads; it equals `dim`
# only when dim divides evenly by n_heads, so don't reshape to `dim` directly.
y = y.reshape(batch, seq_len, n_heads * head_dim)
y.shape

torch.Size([10, 1000, 2048])

In [62]:
# Final output projection mixing information across the concatenated heads.
y = o_proj(y)
y.shape

torch.Size([10, 1000, 2048])