总结一下就是：
* 参数多了num_key_value_heads，代表有多少kv头
* k_proj和v_proj的输出维度不再是hidden_size（num_heads * head_size），而是num_key_value_heads * head_size
* k，v做.view()的时候也要注意，需要用num_key_value_heads而不是num_heads
* 因为k、v的头数（num_key_value_heads）少于q的头数（num_heads），k，v需要在dim=1（即头维度）上做repeat_interleave，把每个kv头重复 num_heads // num_key_value_heads 次，使同一组内的q头共享同一个kv头
* 最后合并多头时，transpose(1, 2)之后的张量在内存中不连续：用view需要先调用.contiguous()使内存连续化，而reshape会在必要时自动拷贝，因此不需要

In [1]:
import torch
import torch.nn as nn
import math

In [10]:
# Grouped-query attention (GQA). With num_key_value_heads == 1 this is
# multi-query attention (MQA); with num_key_value_heads == num_heads it is
# ordinary multi-head attention (MHA).
# Attention dropout is not implemented.
class GroupQueryAttention(nn.Module):
    """Self-attention in which groups of query heads share key/value heads.

    Args:
        hidden_size: Model width; last dimension of the input and the output.
        num_heads: Number of query heads. Must divide ``hidden_size``.
        num_key_value_heads: Number of key/value heads. Must divide
            ``num_heads``; each K/V head is shared by
            ``num_heads // num_key_value_heads`` consecutive query heads.
    """

    def __init__(self, hidden_size, num_heads, num_key_value_heads):
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        assert num_heads % num_key_value_heads == 0, "num_heads must be divisible by num_key_value_heads"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.num_key_value_heads = num_key_value_heads

        # Q keeps all num_heads heads; K and V are projected to only
        # num_key_value_heads heads (num_key_value_heads * head_size features).
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_size)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        """Apply grouped-query self-attention.

        Args:
            hidden_states: Tensor of shape (batch, seq_len, hidden_size).
            attention_mask: Optional tensor broadcastable to
                (batch, num_heads, seq_len, seq_len). Positions where it
                equals 0 are excluded from attention; any other value keeps
                the position.
            head_mask: Optional tensor broadcastable to
                (batch, num_heads, seq_len, seq_len) that is multiplied into
                the attention probabilities (e.g. 0 disables a head).

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).
        """
        batch_size, seq_len, _ = hidden_states.size()

        q_states = self.q_proj(hidden_states)
        k_states = self.k_proj(hidden_states)
        v_states = self.v_proj(hidden_states)

        # Split the feature dim into heads -> (batch, heads, seq_len, head_size).
        # K and V use num_key_value_heads, not num_heads.
        q_states = q_states.view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        k_states = k_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_size).transpose(1, 2)
        v_states = v_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_size).transpose(1, 2)

        # matmul only broadcasts dimensions of size 1, so each K/V head is
        # repeated once per query head in its group (repeat_interleave keeps
        # the copies adjacent: kv head j serves query heads j*g .. j*g+g-1).
        group_size = self.num_heads // self.num_key_value_heads
        k_states = k_states.repeat_interleave(group_size, dim=1)
        v_states = v_states.repeat_interleave(group_size, dim=1)

        # Scaled dot-product scores: (batch, num_heads, seq_len, seq_len).
        scores = torch.matmul(q_states, k_states.transpose(-1, -2)) / math.sqrt(self.head_size)
        if attention_mask is not None:
            # Use the most negative finite value rather than -inf: a query row
            # whose keys are all masked then softmaxes to a uniform
            # distribution instead of NaN, which would otherwise propagate
            # into the output and the gradients.
            scores = scores.masked_fill(attention_mask == 0, torch.finfo(scores.dtype).min)
        probs = torch.softmax(scores, dim=-1)
        if head_mask is not None:
            probs = probs * head_mask
        context = torch.matmul(probs, v_states)

        # transpose(1, 2) leaves the tensor non-contiguous, so .view() would
        # require a .contiguous() first; reshape() copies only when needed.
        output = self.o_proj(context.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size))

        return output

In [11]:
# Smoke test: batch=3, seq_len=2, hidden_size=128, 8 query heads sharing 4 K/V heads.
x = torch.randn(3, 2, 128)
net = GroupQueryAttention(128, 8, 4)
out = net(x)
# Attention must preserve the input shape (batch, seq_len, hidden_size).
assert out.shape == x.shape, f"unexpected output shape {tuple(out.shape)}"
out.shape

torch.Size([3, 2, 128])