# 实现一个 GQA（Grouped-Query Attention，分组查询注意力）

In [1]:
import torch
import torch.nn as nn
import math

In [None]:
class GroupQueryAttention(nn.Module):
    """Causal grouped-query attention (GQA).

    ``head_nums`` query heads share ``key_value_nums`` key/value heads.
    Each consecutive run of ``head_nums // key_value_nums`` query heads attends
    with the same key/value head. The K/V projections therefore output only
    ``key_value_nums * head_dim`` features instead of ``hidden_dim``.

    Args:
        hidden_dim: Size of the last dimension of the input and the output.
        head_nums: Number of query heads; must be positive and divide
            ``hidden_dim``.
        key_value_nums: Number of key/value heads; must be positive and
            divide ``head_nums``.
        block_size: Longest sequence length the precomputed causal mask covers.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If a head count is not positive or one of the
            divisibility requirements above is violated.
    """

    def __init__(
        self,
        hidden_dim: int,
        head_nums: int,
        key_value_nums: int,
        block_size: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if head_nums <= 0 or key_value_nums <= 0:
            raise ValueError(
                f"head_nums ({head_nums}) and key_value_nums ({key_value_nums}) "
                "must be positive"
            )
        if hidden_dim % head_nums != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by head_nums ({head_nums})"
            )
        if head_nums % key_value_nums != 0:
            raise ValueError(
                f"head_nums ({head_nums}) must be divisible by "
                f"key_value_nums ({key_value_nums})"
            )

        # Per-head feature size, shared by query and key/value heads.
        self.head_dim = hidden_dim // head_nums
        self.head_nums = head_nums
        self.key_value_nums = key_value_nums

        # Q has one projection slice per query head; K and V have one per group.
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, key_value_nums * self.head_dim)
        self.value = nn.Linear(hidden_dim, key_value_nums * self.head_dim)

        self.att_dropout = nn.Dropout(dropout)

        # Lower-triangular causal mask (non-zero = may attend). It is kept as a
        # float buffer under the original name so existing state_dicts load.
        self.register_buffer(
            'attention_mask',
            torch.tril(torch.ones(block_size, block_size))
        )

        # Output projection back to hidden_dim.
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X, att_mask=None):
        """Apply causal grouped-query attention to ``X``.

        Args:
            X: Input of shape ``(batch, seq_len, hidden_dim)``.
            att_mask: Optional mask broadcastable to
                ``(batch, head_nums, seq_len, seq_len)``. Positions where it is
                0/False are excluded from attention, in addition to the
                causal mask. For a ``(batch, seq_len)`` key-padding mask, pass
                ``att_mask[:, None, None, :]``.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_dim)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``block_size`` used at
                construction.

        Note:
            If ``att_mask`` removes every allowed key for some query row,
            softmax over an all ``-inf`` row yields NaN for that row.
        """
        batch, seq_len, _ = X.size()

        # The causal buffer only covers block_size positions. Slicing past it
        # would otherwise fail later with an opaque broadcasting error.
        max_len = self.attention_mask.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {max_len}"
            )

        # Project and split into heads: (batch, heads, seq_len, head_dim).
        q_state = self.query(X).view(
            batch, seq_len, self.head_nums, self.head_dim
        ).transpose(1, 2)
        k_state = self.key(X).view(
            batch, seq_len, self.key_value_nums, self.head_dim
        ).transpose(1, 2)
        v_state = self.value(X).view(
            batch, seq_len, self.key_value_nums, self.head_dim
        ).transpose(1, 2)

        # Duplicate each K/V head for every query head in its group.
        # repeat_interleave keeps copies adjacent, so query head i uses kv head
        # i // group_size.
        group_size = self.head_nums // self.key_value_nums
        k_state = k_state.repeat_interleave(group_size, dim=1)
        v_state = v_state.repeat_interleave(group_size, dim=1)

        attention_weight = q_state @ k_state.transpose(-1, -2) / math.sqrt(self.head_dim)

        # Combine the causal mask with the optional caller mask.
        keep = self.attention_mask[:seq_len, :seq_len] != 0
        if att_mask is not None:
            keep = keep & (att_mask != 0)
        attention_weight = attention_weight.masked_fill(~keep, float('-inf'))

        attention_weight = torch.softmax(attention_weight, dim=-1)
        attention_weight = self.att_dropout(attention_weight)

        output = attention_weight @ v_state
        # Merge heads back: (batch, seq_len, head_nums * head_dim).
        output = output.transpose(1, 2).contiguous().view(
            batch, seq_len, self.head_nums * self.head_dim
        )

        return self.proj(output)


# Smoke test: a (batch=3, seq_len=2, hidden_dim=128) input through 8 query
# heads sharing 4 key/value heads should come back with the same shape.
x = torch.rand((3, 2, 128))
net = GroupQueryAttention(hidden_dim=128, head_nums=8, key_value_nums=4)
net(x).size()




torch.Size([3, 2, 128])