In [162]:
import torch
from torch import nn, Tensor
import math

In [163]:
# Seed torch first so the random inputs and the freshly initialised layer
# weights below are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Demo dimensions shared by every cell below.
batch = 10         # leading (batch) dimension of x and of both caches
seq_len = 34       # sequence length of the new input x
cache_len = 200    # sequence length already held in k_cache / v_cache
num_heads = 32     # number of query heads
num_kv_heads = 4   # number of key/value heads
head_dim = 64      # channels per head

# init x, cache_k, cache_v

In [164]:
# Random stand-ins for the layer input and for the already-cached keys/values.
# The draw order (x, then k_cache, then v_cache) is kept so the RNG stream is unchanged.
x: Tensor = torch.randn(
    (batch, seq_len, num_heads * head_dim)
)
k_cache: Tensor = torch.randn(
    (batch, num_kv_heads, cache_len, head_dim)
)
v_cache: Tensor = torch.randn(
    (batch, num_kv_heads, cache_len, head_dim)
)
x.shape, k_cache.shape, v_cache.shape

(torch.Size([10, 34, 2048]),
 torch.Size([10, 4, 200, 64]),
 torch.Size([10, 4, 200, 64]))

# init linear

In [165]:
# Bias-free projections. Queries and the output use num_heads * head_dim
# features; keys and values only use num_kv_heads * head_dim.
q_proj: nn.Module = nn.Linear(
    in_features=num_heads * head_dim, out_features=num_heads * head_dim, bias=False
)
k_proj: nn.Module = nn.Linear(
    in_features=num_heads * head_dim, out_features=num_kv_heads * head_dim, bias=False
)
v_proj: nn.Module = nn.Linear(
    in_features=num_heads * head_dim, out_features=num_kv_heads * head_dim, bias=False
)
o_proj: nn.Module = nn.Linear(
    in_features=num_heads * head_dim, out_features=num_heads * head_dim, bias=False
)
q_proj, k_proj, v_proj, o_proj

(Linear(in_features=2048, out_features=2048, bias=False),
 Linear(in_features=2048, out_features=256, bias=False),
 Linear(in_features=2048, out_features=256, bias=False),
 Linear(in_features=2048, out_features=2048, bias=False))

# x -> qkv

In [166]:
# Run the same input through the three projections, in q, k, v order.
q, k, v = (proj(x) for proj in (q_proj, k_proj, v_proj))
q.shape, k.shape, v.shape

(torch.Size([10, 34, 2048]),
 torch.Size([10, 34, 256]),
 torch.Size([10, 34, 256]))

# reshape

In [167]:
# Split the feature axis into (heads, head_dim), then move the head axis in
# front of the sequence axis: (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim).
q = q.reshape(batch, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
k = k.reshape(batch, seq_len, num_kv_heads, head_dim).permute(0, 2, 1, 3)
v = v.reshape(batch, seq_len, num_kv_heads, head_dim).permute(0, 2, 1, 3)
q.shape, k.shape, v.shape

(torch.Size([10, 32, 34, 64]),
 torch.Size([10, 4, 34, 64]),
 torch.Size([10, 4, 34, 64]))

# apply_rotary_emb

In [168]:
def apply_rotary_emb(x: Tensor) -> Tensor:
    """Placeholder for the rotary embedding step: returns ``x`` unchanged."""
    return x

In [169]:
# The rotary step is applied to queries and keys only; values are left as they are.
q, k = apply_rotary_emb(q), apply_rotary_emb(k)

# cat cache kv

In [170]:
# Append the new keys/values after the cached ones along the sequence axis
# (second-to-last axis of the 4-D tensors).
k_cat = torch.cat((k_cache, k), dim=-2)
v_cat = torch.cat((v_cache, v), dim=-2)
k_cat.shape, v_cat.shape

(torch.Size([10, 4, 234, 64]), torch.Size([10, 4, 234, 64]))

# repeat kv

In [171]:
# Number of query heads per key/value head; used below as the repeat count for repeat_kv.
n_repeats = num_heads // num_kv_heads
n_repeats

8

In [172]:
def repeat_kv(x: Tensor, n_repeats: int) -> Tensor:
    """Repeat every head of ``x`` ``n_repeats`` times along the head axis.

    Gives the same values as
    ``torch.repeat_interleave(x, repeats=n_repeats, dim=1)``.

    Args:
        x: 4-D tensor laid out as (batch, heads, seq_len, head_dim).
        n_repeats: how many consecutive copies of each head to produce.

    Returns:
        ``x`` itself when ``n_repeats == 1``; otherwise a tensor of shape
        (batch, heads * n_repeats, seq_len, head_dim).
    """
    if n_repeats != 1:
        bsz, kv_heads, length, width = x.shape
        # Add a singleton axis right after the head axis, broadcast it to
        # n_repeats, then fold it into the head axis.
        tiled = x[:, :, None, :, :].expand(bsz, kv_heads, n_repeats, length, width)
        return tiled.reshape(bsz, kv_heads * n_repeats, length, width)
    return x

# Sanity check: repeat_kv gives exactly the same values as torch.repeat_interleave along dim=1.
torch.all(repeat_kv(x=k_cat, n_repeats=n_repeats)==torch.repeat_interleave(input=k_cat, repeats=n_repeats, dim=1))

tensor(True)

In [173]:
# Expand keys and values so their head count matches the query head count.
k_cat, v_cat = (
    repeat_kv(x=k_cat, n_repeats=n_repeats),
    repeat_kv(x=v_cat, n_repeats=n_repeats),
)
q.shape, k_cat.shape, v_cat.shape

(torch.Size([10, 32, 34, 64]),
 torch.Size([10, 32, 234, 64]),
 torch.Size([10, 32, 234, 64]))

# attn

In [174]:
# Scaled dot-product scores: each query position against every cached + new key position.
attn: Tensor = torch.matmul(q, k_cat.transpose(-2, -1)) / math.sqrt(head_dim)
attn.shape

torch.Size([10, 32, 34, 234])

# add mask

In [175]:
# Start from a square (seq_len x seq_len) matrix filled entirely with -inf.
mask: Tensor = torch.full(size=(seq_len, seq_len), fill_value=-torch.inf)
mask

tensor([[-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        ...,
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf]])

In [176]:
# Keep -inf only strictly above the diagonal and zero the diagonal and the
# lower-left triangle, so that adding the mask leaves the lower-left scores
# untouched (each position sees itself and earlier positions).
mask = mask.triu(diagonal=1)
mask

tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [177]:
# Prepend cache_len all-zero columns: every new position may attend to the whole cache.
mask = torch.cat((torch.zeros(seq_len, cache_len), mask), dim=-1)
print(mask.shape)
mask

torch.Size([34, 234])


tensor([[0., 0., 0.,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [178]:
# Scores for the first batch element and first head, before the mask is added.
attn[0, 0]

tensor([[ 0.5908,  0.3918,  0.1409,  ..., -0.1939,  0.1460,  0.3004],
        [ 0.5672,  0.4411, -0.1403,  ..., -0.3908, -0.2452,  0.3271],
        [-1.0523,  0.7536,  0.0435,  ..., -0.4631,  0.3748, -0.4729],
        ...,
        [ 0.4367, -0.4639,  1.0559,  ...,  0.0061, -0.0058,  0.1531],
        [-0.1250,  0.3146, -1.2364,  ...,  0.7806,  0.1324,  0.5016],
        [-1.3604,  0.2627, -1.5977,  ..., -0.2114, -0.0799, -0.0356]],
       grad_fn=<SelectBackward0>)

In [179]:
# In-place add; the 2-D mask broadcasts over the batch and head axes.
attn.add_(mask)
attn[0, 0]

tensor([[ 0.5908,  0.3918,  0.1409,  ...,    -inf,    -inf,    -inf],
        [ 0.5672,  0.4411, -0.1403,  ...,    -inf,    -inf,    -inf],
        [-1.0523,  0.7536,  0.0435,  ...,    -inf,    -inf,    -inf],
        ...,
        [ 0.4367, -0.4639,  1.0559,  ...,  0.0061,    -inf,    -inf],
        [-0.1250,  0.3146, -1.2364,  ...,  0.7806,  0.1324,    -inf],
        [-1.3604,  0.2627, -1.5977,  ..., -0.2114, -0.0799, -0.0356]],
       grad_fn=<SelectBackward0>)

# softmax

In [180]:
# Normalise over the key axis; positions masked with -inf end up with weight 0.
attn = attn.softmax(dim=-1)
attn[0, 0]

tensor([[0.0075, 0.0062, 0.0048,  ..., 0.0000, 0.0000, 0.0000],
        [0.0072, 0.0064, 0.0036,  ..., 0.0000, 0.0000, 0.0000],
        [0.0014, 0.0086, 0.0042,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0056, 0.0023, 0.0105,  ..., 0.0037, 0.0000, 0.0000],
        [0.0033, 0.0051, 0.0011,  ..., 0.0082, 0.0043, 0.0000],
        [0.0010, 0.0050, 0.0008,  ..., 0.0031, 0.0035, 0.0037]],
       grad_fn=<SelectBackward0>)

# apply attn

In [181]:
# Weighted sum of the values with the attention weights.
x = torch.matmul(attn, v_cat)
x.shape

torch.Size([10, 32, 34, 64])

# transpose

In [182]:
# Move the head axis back behind the sequence axis and merge (heads, head_dim)
# into a single feature axis.
x = x.permute(0, 2, 1, 3).reshape(batch, seq_len, num_heads * head_dim)
x.shape

torch.Size([10, 34, 2048])

# o_proj

In [183]:
# Apply the output projection to the merged heads.
x = o_proj(x)
x.shape

torch.Size([10, 34, 2048])

# Wrap the steps above into an nn.Module

In [184]:
class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        max_batch_size: int,
        max_seq_len: int,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim 必须被 num_heads 整除"
        assert num_heads % num_kv_heads == 0, "num_heads 必须被 num_kv_heads 整除"
        self.dim = dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        self.n_repeats = num_heads // num_kv_heads
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len

        self.q_proj = nn.Linear(dim, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, dim, bias=False)

        self.k_cache = torch.zeros(max_batch_size, num_kv_heads, max_seq_len, head_dim)
        self.v_cache = torch.zeros(max_batch_size, num_kv_heads, max_seq_len, head_dim)

    def forward(
        self,
        x: Tensor,
        start_pos: int,
        mask: Tensor | None = None,
    ) -> Tensor:
        batch, seq_len, _ = x.shape
        assert start_pos + seq_len <= self.max_seq_len, "当前对话长度超过缓存长度"

        # x -> qkv
        q: Tensor = self.q_proj(x)
        k: Tensor = self.k_proj(x)
        v: Tensor = self.v_proj(x)

        # reshape and transpose
        q = q.reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(batch, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(batch, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # rope
        q = apply_rotary_emb(q)
        k = apply_rotary_emb(k)

        # update cache kv
        self.k_cache[:batch, :, start_pos:start_pos + seq_len] = k
        self.v_cache[:batch, :, start_pos:start_pos + seq_len] = v

        # cat kv
        k = torch.cat([self.k_cache[:batch, :, :start_pos], k], dim=2)
        v = torch.cat([self.v_cache[:batch, :, :start_pos], v], dim=2)

        # repeat kv
        k = repeat_kv(k, self.n_repeats)
        v = repeat_kv(v, self.n_repeats)
        # print(q.shape, k.shape, v.shape)

        # attn
        attn: Tensor = q @ k.transpose(-1, -2) / math.sqrt(self.head_dim)
        if mask is not None:
            attn += mask

        # softmax
        attn = torch.softmax(attn, dim=-1)

        # apply attn
        x = attn @ v
        x = x.transpose(1, 2).reshape(batch, seq_len, self.dim)
        x = self.o_proj(x)

        return x

In [185]:
# Build the module with the same dimensions as the walkthrough above.
# max_batch_size and max_seq_len set the size of the two zero-filled caches
# allocated in __init__.
attention = Attention(
    dim=num_heads * head_dim,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    max_batch_size=128,
    max_seq_len=8192
)
attention

Attention(
  (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
  (k_proj): Linear(in_features=2048, out_features=256, bias=False)
  (v_proj): Linear(in_features=2048, out_features=256, bias=False)
  (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
)

In [186]:
# x currently holds the result of the o_proj cell above; it is reused as the module input below.
x.shape

torch.Size([10, 34, 2048])

In [187]:
# Cache offset for the new tokens: the module writes them at [start_pos, start_pos + seq_len).
start_pos = 1000

In [188]:
# Rebuild the mask for the module call: start from a (seq_len x seq_len) matrix of -inf.
mask: Tensor = torch.full(size=(seq_len, seq_len), fill_value=-torch.inf)
mask

tensor([[-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        ...,
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf]])

In [189]:
# Keep -inf only strictly above the diagonal; the diagonal and everything below become 0.
mask = mask.triu(diagonal=1)
mask

tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [190]:
# Prepend start_pos all-zero columns so every new position may attend to all cached positions.
mask = torch.cat((torch.zeros(seq_len, start_pos), mask), dim=-1)
print(mask.shape)
mask

torch.Size([34, 1034])


tensor([[0., 0., 0.,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [191]:
# Run the module on x with the cache offset and the padded causal mask built above.
y: Tensor = attention(x, start_pos=start_pos, mask=mask)
print(y.shape)

torch.Size([10, 34, 2048])
