GQA相比于MHA，减少了K、V头的数量：多个Q头共享同一组K、V头来计算注意力。  
代码实现的本质是只生成少量的K、V头，然后在head维度上将其复制到与Q的头数相同，再按MHA的方式进行计算。  
也就是要保证head_num可以被kv_head_num整除（即 head_num % kv_head_num == 0）。 

MQA就是GQA在kv_head_num = 1时的特殊情况。

In [13]:
import torch
from  torch import nn
from torch import functional as F
import math
# 这个代码省略了mask和dropout
class GruopAttention(nn.Module):
    """Grouped-query attention (GQA) without masking or dropout.

    Queries use ``head_num`` heads while keys and values are projected to
    only ``kv_head_num`` heads; each key/value head is then repeated so that
    ``head_num // kv_head_num`` consecutive query heads share it.
    ``kv_head_num == head_num`` is standard multi-head attention and
    ``kv_head_num == 1`` is multi-query attention.

    Args:
        hidden_dim: Size of the last dimension of the input and the output.
        head_num: Number of query heads; must divide ``hidden_dim``.
        kv_head_num: Number of key/value heads; must divide ``head_num``.

    Raises:
        AssertionError: If either divisibility requirement is violated.
    """

    def __init__(self, hidden_dim, head_num, kv_head_num):
        super().__init__()
        # Both divisions below must be exact for the head reshapes to work.
        assert hidden_dim % head_num == 0, \
            "hidden_dim must be divisible by head_num"
        assert head_num % kv_head_num == 0, \
            "head_num must be divisible by kv_head_num"

        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.kv_head_num = kv_head_num
        self.head_dim = hidden_dim // head_num
        # Number of query heads that share one key/value head.
        self.group_size = head_num // kv_head_num

        # Q: (batch, seq, hidden_dim); K/V: (batch, seq, kv_head_num * head_dim)
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, kv_head_num * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, kv_head_num * self.head_dim)
        # Final projection keeps the output the same width as the input.
        self.output = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X):
        """Apply grouped-query self-attention to ``X``.

        Args:
            X: Tensor of shape ``(batch, seq, hidden_dim)``.

        Returns:
            Tensor of shape ``(batch, seq, hidden_dim)``.
        """
        batch, seq, _ = X.size()

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q = (self.q_proj(X)
             .view(batch, seq, self.head_num, self.head_dim)
             .transpose(1, 2))
        k = (self.k_proj(X)
             .view(batch, seq, self.kv_head_num, self.head_dim)
             .transpose(1, 2))
        v = (self.v_proj(X)
             .view(batch, seq, self.kv_head_num, self.head_dim)
             .transpose(1, 2))

        # Repeat every K/V head along the head axis so it lines up with the
        # group of query heads that share it:
        # (batch, kv_head_num, seq, head_dim) -> (batch, head_num, seq, head_dim)
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)

        # From here on this is ordinary multi-head attention. The logits are
        # scaled by sqrt(head_dim), the per-head key dimension, not by the
        # number of heads.
        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        # Normalize over the key positions (last axis).
        attention_mid = torch.softmax(scores, dim=-1)

        # Dropout on the attention weights is intentionally omitted.

        # (batch, head_num, seq, head_dim) -> (batch, seq, head_num * head_dim)
        output_mid = (attention_mid @ v).transpose(1, 2).contiguous()
        output_mid = output_mid.view(batch, seq, self.hidden_dim)

        return self.output(output_mid)

# Smoke test: run a random (batch=3, seq=4, hidden=128) input through a
# module with 8 query heads sharing 4 key/value heads.
X = torch.rand(3, 4, 128)
net = GruopAttention(hidden_dim=128, head_num=8, kv_head_num=4)
output = net(X)
# The output shape should match the input shape.
print(output.shape)
# Bare expression so the notebook displays the tensor as the cell result.
output



torch.Size([3, 4, 128])


tensor([[[-0.1449,  0.0688, -0.0461,  ...,  0.0805, -0.3760,  0.0840],
         [-0.1447,  0.0678, -0.0487,  ...,  0.0787, -0.3763,  0.0834],
         [-0.1440,  0.0697, -0.0493,  ...,  0.0791, -0.3751,  0.0828],
         [-0.1427,  0.0668, -0.0483,  ...,  0.0785, -0.3745,  0.0823]],

        [[-0.1132,  0.1124,  0.0269,  ...,  0.0236, -0.3805,  0.0164],
         [-0.1157,  0.1130,  0.0260,  ...,  0.0229, -0.3821,  0.0152],
         [-0.1129,  0.1101,  0.0265,  ...,  0.0241, -0.3835,  0.0160],
         [-0.1141,  0.1158,  0.0227,  ...,  0.0215, -0.3825,  0.0189]],

        [[-0.0658,  0.1407,  0.0395,  ...,  0.1017, -0.4091,  0.0375],
         [-0.0658,  0.1381,  0.0383,  ...,  0.1015, -0.4099,  0.0401],
         [-0.0691,  0.1389,  0.0382,  ...,  0.1002, -0.4114,  0.0376],
         [-0.0638,  0.1386,  0.0380,  ...,  0.1002, -0.4097,  0.0365]]],
       grad_fn=<ViewBackward0>)