# GQA的手撕
话不多说，直接开始



In [2]:
# 每日一导
import torch
import torch.nn as nn
import torch.nn.functional as F
import math 

In [9]:
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention (GQA) layer.

    Queries use ``n_heads`` heads while keys and values use only
    ``n_kv_heads`` heads; each key/value head is shared by
    ``n_heads // n_kv_heads`` consecutive query heads.

    Args:
        d_model: Size of the last dimension of the input and output.
        n_heads: Number of query heads; must divide ``d_model``.
        n_kv_heads: Number of key/value heads; must divide ``n_heads``.
    """

    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        assert n_heads % n_kv_heads == 0, (
            "n_heads must be divisible by n_kv_heads"
        )

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = d_model // n_heads
        # Number of query heads that share one key/value head.
        self.group_size = n_heads // n_kv_heads

        # Four projections: K and V are narrower than Q because they only
        # produce n_kv_heads heads.
        self.w_q = nn.Linear(d_model, n_heads * self.head_dim)
        self.w_k = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.w_v = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.w_o = nn.Linear(n_heads * self.head_dim, d_model)

    def repeat_kv(self, x, n_rep):
        """Repeat every key/value head ``n_rep`` times along the head axis.

        Args:
            x: Tensor of shape ``(batch, n_kv_heads, seq_len, head_dim)``.
            n_rep: Number of copies of each head.

        Returns:
            Tensor of shape ``(batch, n_kv_heads * n_rep, seq_len, head_dim)``
            in which the copies of one head are adjacent; ``x`` itself when
            ``n_rep == 1``.
        """
        if n_rep == 1:
            return x
        batch_size, n_kv_heads, seq_len, head_dim = x.shape
        x = x[:, :, None, :, :]  # (B, n_kv_heads, 1, T, D)
        x = x.expand(batch_size, n_kv_heads, n_rep, seq_len, head_dim)
        return x.reshape(batch_size, n_kv_heads * n_rep, seq_len, head_dim)

    def forward(self, x, mask=None):
        """Apply grouped-query self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.
            mask: Optional tensor that must broadcast against the score
                tensor of shape ``(batch, n_heads, seq_len, seq_len)``.
                Positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        batch_size, seq_len, _ = x.shape

        # 1. Project, then split into heads: (B, heads, T, head_dim).
        q = self.w_q(x).view(
            batch_size, seq_len, self.n_heads, self.head_dim
        ).transpose(1, 2)
        k = self.w_k(x).view(
            batch_size, seq_len, self.n_kv_heads, self.head_dim
        ).transpose(1, 2)
        v = self.w_v(x).view(
            batch_size, seq_len, self.n_kv_heads, self.head_dim
        ).transpose(1, 2)

        # 2. Expand K/V so their head count matches the query head count.
        k = self.repeat_kv(k, self.group_size)
        v = self.repeat_kv(v, self.group_size)

        # 3. Scaled dot-product attention scores: (B, n_heads, T, T).
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
            self.head_dim
        )

        # 4. Push masked positions to a large negative value so that softmax
        #    assigns them (numerically) zero weight.
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        # 5. Normalize over the key dimension.
        attn_prob = F.softmax(attn_scores, dim=-1)

        # 6. Weighted sum of the values.
        output = torch.matmul(attn_prob, v)

        # 7. Merge heads back into one feature dimension; contiguous() is
        #    needed because view() cannot operate on the transposed layout.
        output = output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, -1
        )

        # 8. Final output projection.
        return self.w_o(output)
    




In [10]:
# Smoke test: build a small GQA layer (8 query heads sharing 2 key/value
# heads) and confirm the output shape matches the input shape.
model = GroupedQueryAttention(d_model=128, n_heads=8, n_kv_heads=2)
x = torch.randn(1, 4, 128)
_ = model(x)
print(model)
print(x.shape)
print(model(x).shape)

GroupedQueryAttention(
  (w_q): Linear(in_features=128, out_features=128, bias=True)
  (w_k): Linear(in_features=128, out_features=32, bias=True)
  (w_v): Linear(in_features=128, out_features=32, bias=True)
  (w_o): Linear(in_features=128, out_features=128, bias=True)
)
torch.Size([1, 4, 128])
torch.Size([1, 4, 128])
