In [None]:
from dataclasses import dataclass 
from torch import nn 
import torch.nn.functional as F 
import math, torch

In [None]:
@dataclass
class ModelArgs:
    """Hyper-parameters consumed by the attention module in this file.

    Attributes:
        n_heads: Number of query heads.
        dim: Width of the model's hidden representation.
        n_kv_heads: Number of key/value heads.
    """

    n_heads: int = 32
    dim: int = 4096
    n_kv_heads: int = 8
    

In [None]:
def repeat_kv(x, reps):
    """Duplicate every slice along dimension 2 of ``x`` ``reps`` times.

    ``x`` is unpacked as a 4-D tensor ``(bsz, seqlen, heads, dims)``. The
    result has shape ``(bsz, seqlen, heads * reps, dims)`` and the copies of
    each original slice sit next to each other. When ``reps`` is 1 the input
    tensor itself is returned.
    """
    batch, length, heads, width = x.shape
    if reps != 1:
        # Insert a length-1 axis after the head axis, broadcast it to `reps`
        # and fold it back into the head axis.
        widened = x.unsqueeze(3).expand(batch, length, heads, reps, width)
        return widened.reshape(batch, length, heads * reps, width)
    return x

def apply_rope(k, v):
    """Return both arguments untouched as a ``(k, v)`` tuple.

    NOTE(review): judging by the name this is a stub for rotary position
    embeddings; no rotation is performed here yet.
    """
    unchanged = (k, v)
    return unchanged

In [None]:
class GQA(nn.Module):
    """Grouped-query attention.

    Queries use ``args.n_heads`` heads while keys and values use the smaller
    ``args.n_kv_heads``; every key/value head is shared by
    ``n_heads // n_kv_heads`` query heads.
    """

    def __init__(self, args: ModelArgs):
        """Build the four bias-free projection layers.

        Raises:
            ValueError: If ``args.n_heads`` is not a multiple of
                ``args.n_kv_heads`` (the heads could not be grouped evenly).
        """
        super().__init__()
        if args.n_heads % args.n_kv_heads != 0:
            raise ValueError(
                "n_heads must be divisible by n_kv_heads, got "
                f"{args.n_heads} and {args.n_kv_heads}"
            )
        self.args = args
        head_dim = args.dim // args.n_heads
        self.head_dim = head_dim
        # Real linear layers: the forward pass calls these as functions.
        self.wq = nn.Linear(args.dim, args.n_heads * head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * head_dim, bias=False)
        # The attention output carries one slice per *query* head.
        self.wo = nn.Linear(args.n_heads * head_dim, args.dim, bias=False)
        self.reps = args.n_heads // args.n_kv_heads

    def forward(self, x, mask=None):
        """Run attention over ``x``.

        Args:
            x: Tensor of shape ``(bsz, seqlen, dim)``.
            mask: Optional additive mask, added to the attention logits
                before the softmax; it must broadcast against
                ``(bsz, n_heads, seqlen, seqlen)``.

        Returns:
            Tensor of shape ``(bsz, seqlen, dim)``.
        """
        bsz, seqlen, _ = x.shape
        n_heads = self.args.n_heads
        n_kv_heads = self.args.n_kv_heads

        # Keep the (bsz, seqlen, heads, head_dim) layout for now: that is the
        # layout repeat_kv expects (it repeats along dimension 2).
        xq = self.wq(x).reshape(bsz, seqlen, n_heads, self.head_dim)
        xk = self.wk(x).reshape(bsz, seqlen, n_kv_heads, self.head_dim)
        xv = self.wv(x).reshape(bsz, seqlen, n_kv_heads, self.head_dim)

        # Positional rotation applies to queries and keys, not to values.
        xq, xk = apply_rope(xq, xk)

        # Expand the key/value heads so each query head has a partner.
        keys = repeat_kv(xk, self.reps)
        values = repeat_kv(xv, self.reps)

        # Move heads ahead of the sequence axis for batched matmul.
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention; the scaling must happen on
        # the logits, before the softmax.
        score = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            score = score + mask
        score = F.softmax(score, dim=-1)

        output = torch.matmul(score, values)
        output = (
            output.transpose(1, 2)
            .contiguous()
            .view(bsz, seqlen, n_heads * self.head_dim)
        )
        return self.wo(output)


