In [19]:
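# GQA (grouped-query attention) shares each key/value head across a group of query heads.
# Its purpose is to shrink the KV cache and the memory traffic spent reading it; it is a
# trade-off between multi-head attention (one K/V head per query head) and multi-query
# attention (a single K/V head shared by all query heads).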
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [20]:
class GroupQueryAttention(nn.Module):
    """Grouped-query attention (GQA).

    ``num_heads`` query heads share ``nums_key_value_head`` key/value heads: each K/V
    head serves a contiguous group of ``num_heads // nums_key_value_head`` query heads.

    Args:
        hidden_dim: input/output feature size; must be divisible by ``num_heads``.
        num_heads: number of query heads.
        nums_key_value_head: number of key/value heads; must divide ``num_heads``.
        drop_rate: dropout probability applied to the attention weights.
        bias: whether the q/k/v projections carry a bias term.
        reflect_matrix: if True, apply a final ``hidden_dim -> hidden_dim`` output
            projection; otherwise return the concatenated head outputs directly.
    """

    def __init__(self, hidden_dim, num_heads, nums_key_value_head, drop_rate=0.1, bias=False, reflect_matrix=True):
        super().__init__()
        # Validate before deriving any sizes from these values.
        assert num_heads > 0 and hidden_dim % num_heads == 0, (
            f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
        )
        assert nums_key_value_head > 0 and num_heads % nums_key_value_head == 0, (
            f"num_heads ({num_heads}) must be divisible by nums_key_value_head ({nums_key_value_head})"
        )

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads  # number of query heads
        self.nums_key_value_head = nums_key_value_head  # number of key/value heads
        self.drop_rate = drop_rate
        self.heads_dim = hidden_dim // num_heads  # per-head feature size
        self.bias = bias

        self.query = nn.Linear(hidden_dim, self.num_heads * self.heads_dim, bias=self.bias)
        # K/V projections are narrower than Q: only nums_key_value_head heads are produced.
        self.key = nn.Linear(hidden_dim, self.nums_key_value_head * self.heads_dim, bias=self.bias)
        self.value = nn.Linear(hidden_dim, self.nums_key_value_head * self.heads_dim, bias=self.bias)

        self.att_dropout = nn.Dropout(self.drop_rate)
        self.reflect_matrix = reflect_matrix
        # NOTE: uses nn.Linear's default bias=True (not self.bias), unlike q/k/v; kept as-is
        # so existing state_dicts still load.
        self.outputs = nn.Linear(hidden_dim, hidden_dim) if reflect_matrix else None

    def forward(self, x, mask=None):
        """Apply grouped-query self-attention.

        Args:
            x: tensor of shape (b, s, hidden_dim).
            mask: optional tensor broadcastable to (b, num_heads, s, s); positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (b, s, hidden_dim).
        """
        b, s, _ = x.size()

        # (b, s, hidden_dim) -> (b, s, num_heads, heads_dim) -> (b, num_heads, s, heads_dim)
        q = self.query(x).view(b, s, self.num_heads, self.heads_dim).transpose(1, 2)
        # (b, s, kv*heads_dim) -> (b, s, nums_key_value_head, heads_dim) -> (b, nums_key_value_head, s, heads_dim)
        k = self.key(x).view(b, s, self.nums_key_value_head, self.heads_dim).transpose(1, 2)
        v = self.value(x).view(b, s, self.nums_key_value_head, self.heads_dim).transpose(1, 2)

        # Expand each K/V head to its group of query heads so the shapes line up with q:
        # (b, nums_key_value_head, s, heads_dim) -> (b, num_heads, s, heads_dim).
        # repeat_interleave keeps copies of one K/V head adjacent, so query heads
        # [g*group_size, (g+1)*group_size) all read K/V head g.
        group_size = self.num_heads // self.nums_key_value_head
        k = k.repeat_interleave(group_size, dim=1)
        v = v.repeat_interleave(group_size, dim=1)

        # (b, num_heads, s, heads_dim) @ (b, num_heads, heads_dim, s) -> (b, num_heads, s, s)
        attention_score = q @ k.transpose(-1, -2) / math.sqrt(self.heads_dim)
        if mask is not None:
            # Use the dtype's most negative finite value instead of -inf: when every key in a
            # row is masked, softmax over all -inf yields NaN (which then poisons the output
            # and gradients), whereas equal finite scores give a finite uniform distribution.
            # For partially masked rows exp(min - max) underflows to exactly 0, as with -inf.
            attention_score = attention_score.masked_fill(mask == 0, torch.finfo(attention_score.dtype).min)

        attention_weight = torch.softmax(attention_score, dim=-1)
        attention_weight = self.att_dropout(attention_weight)

        # (b, num_heads, s, s) @ (b, num_heads, s, heads_dim) -> (b, num_heads, s, heads_dim)
        context = attention_weight @ v
        # -> (b, s, num_heads, heads_dim) -> (b, s, num_heads * heads_dim)
        context = context.transpose(1, 2).contiguous().view(b, s, -1)

        if self.reflect_matrix:
            return self.outputs(context)
        return context
        

In [21]:
if __name__ == "__main__":
    # Smoke test: run GQA on a small random batch with a key-padding mask and print the output shape.
    torch.manual_seed(0)  # reproducible inputs and weights

    batch_size, seq_len, hidden_dim = 3, 2, 128
    num_heads, num_kv_heads = 8, 4

    # Per-sample key mask (1 = attend, 0 = masked); sample 1 masks every key.
    key_mask = torch.tensor(
        [
            [0, 1],
            [0, 0],
            [1, 0],
        ]
    )
    # (b, s) -> (b, 1, 1, s): broadcasts over heads and query positions, so the mask
    # does not need to know num_heads.
    attention_mask = key_mask[:, None, None, :]

    x = torch.rand(batch_size, seq_len, hidden_dim)
    net = GroupQueryAttention(hidden_dim=hidden_dim, num_heads=num_heads, nums_key_value_head=num_kv_heads)
    net.eval()  # disable attention dropout for an inference-style check
    with torch.no_grad():
        out = net(x, attention_mask)
    print(out.shape)

torch.Size([3, 2, 128])
