# GQA (Grouped-Query Attention)

In [None]:
import torch
from torch import nn
from xxx import RoPEEmbedding # hypothetical RoPE module (placeholder import)

class GroupedQueryAttention(nn.Module):
    """Grouped-query attention (GQA) with rotary position embeddings.

    Queries are projected to ``num_heads`` heads, while keys and values are
    projected to only ``num_groups`` heads. Each key/value head is then
    repeated so that it is shared by ``num_heads // num_groups`` consecutive
    query heads.

    Args:
        hidden_dim: Size of the input and output feature dimension.
        num_heads: Number of query heads; must divide ``hidden_dim``.
        num_groups: Number of key/value heads; must divide ``num_heads``.
        max_seq_len: Forwarded to ``RoPEEmbedding`` together with the
            per-head dimension.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        AssertionError: If either divisibility requirement is violated.
    """

    def __init__(self, hidden_dim, num_heads, num_groups, max_seq_len, dropout=0.0):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.num_groups = num_groups
        # Number of query heads that share one key/value head.
        self.num_q_per_kv = num_heads // num_groups
        # Standard 1/sqrt(head_dim) scaling of the dot-product scores.
        self.scale = self.head_dim ** -0.5

        # Q keeps the full width; K and V are narrowed to num_groups heads.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, self.head_dim * num_groups)
        self.v_proj = nn.Linear(hidden_dim, self.head_dim * num_groups)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.rope = RoPEEmbedding(self.head_dim, max_seq_len)

    def forward(self, x, mask=None):
        """Run grouped-query self-attention over ``x``.

        Args:
            x: Input tensor of shape ``(batch, seq_len, hidden_dim)``.
            mask: Optional boolean tensor broadcastable to
                ``(batch, num_heads, seq_len, seq_len)``. Positions where it
                is ``True`` are filled with ``-inf`` before the softmax, i.e.
                they are excluded from attention. A query row that is fully
                masked yields NaN weights.

        Returns:
            A tuple ``(output, attn_weights)`` where ``output`` has shape
            ``(batch, seq_len, hidden_dim)`` and ``attn_weights`` has shape
            ``(batch, num_heads, seq_len, seq_len)`` (post-softmax, after
            dropout).
        """
        batch_size = x.shape[0]

        # Split the projections into heads: (batch, heads, seq_len, head_dim).
        Q = self.q_proj(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, -1, self.num_groups, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, -1, self.num_groups, self.head_dim).transpose(1, 2)

        # Expand K and V from num_groups to num_heads heads so that every
        # query head has a matching key/value head.
        K = K.repeat_interleave(self.num_q_per_kv, dim=1)
        V = V.repeat_interleave(self.num_q_per_kv, dim=1)

        # NOTE(review): assumes the RoPE module takes and returns tensors of
        # shape (batch, heads, seq_len, head_dim) - confirm against its API.
        Q = self.rope(Q)
        K = self.rope(K)

        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        # Compare against None explicitly: the truth value of a
        # multi-element tensor is ambiguous and raises a RuntimeError.
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask, float('-inf'))
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        output = torch.matmul(attn_weights, V)
        # Merge the heads back: (batch, seq_len, hidden_dim).
        output = output.transpose(1, 2).reshape(batch_size, -1, self.hidden_dim)
        output = self.o_proj(output)

        return output, attn_weights