In [1]:
import torch
import torch.nn as nn
import math

In [2]:
class GQA(nn.Module):
    """Grouped-query attention (GQA).

    ``query_heads`` query heads share ``groups`` key/value heads. Consecutive
    blocks of ``query_heads // groups`` query heads read the same K/V head:
    query head ``h`` uses K/V head ``h // (query_heads // groups)``.
    With ``groups == query_heads`` every query head has its own K/V head
    (standard multi-head attention); with ``groups == 1`` all query heads
    share a single K/V head (multi-query attention).

    Args:
        query_heads: number of query heads.
        groups: number of key/value heads; must be positive and divide ``query_heads``.
        head_dim: feature size of each head.
        embed_dim: feature size of the input and of the output.
    """

    def __init__(self, query_heads, groups, head_dim, embed_dim):
        super().__init__()

        # Checked first so groups == 0 fails with a clear message, not ZeroDivisionError below.
        assert groups > 0, "groups must be positive"
        assert query_heads % groups == 0, "query heads must be multiple of groups"

        self.query_heads = query_heads
        self.groups = groups
        self.head_dim = head_dim
        self.embed_dim = embed_dim
        # Number of query heads served by each K/V head.
        self.query_per_group = self.query_heads // self.groups

        self.q_proj = nn.Linear(self.embed_dim, self.query_heads * self.head_dim, bias=False)
        # K/V projections only produce `groups` heads -- this is where GQA saves parameters.
        self.k_proj = nn.Linear(self.embed_dim, self.groups * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.embed_dim, self.groups * self.head_dim, bias=False)

        self.output_proj = nn.Linear(self.query_heads * self.head_dim, self.embed_dim, bias=False)

    def forward(self, x, kv_x = None, attention_mask = None):
        """Attend from ``x`` to ``kv_x`` (self-attention when ``kv_x`` is None).

        Args:
            x: query input of shape ``(batch, seq_len, embed_dim)``.
            kv_x: optional key/value input of shape ``(kv_batch, kv_seq_len, embed_dim)``;
                defaults to ``x``. ``kv_batch`` must equal ``batch`` or be 1, in which
                case the same K/V context is broadcast over the query batch.
            attention_mask: optional mask broadcastable to
                ``(batch, query_heads, seq_len, kv_seq_len)``. A floating-point mask is
                added to the scaled scores (use ``-inf`` to block a position). A boolean
                mask keeps positions where it is True and blocks the rest. A query row
                with every position blocked produces NaN after the softmax.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_dim)``.
        """
        batch, seq_len, _ = x.size()

        if kv_x is None:
            kv_x = x
        kv_batch, kv_seq_len, _ = kv_x.size()

        # (B, T, H*D) -> (B, H, T, D)
        q = self.q_proj(x).view(batch, seq_len, self.query_heads, self.head_dim).transpose(1, 2)
        # K/V are reshaped with their own batch size; a kv_batch of 1 then broadcasts in matmul.
        k = self.k_proj(kv_x).view(kv_batch, kv_seq_len, self.groups, self.head_dim).transpose(1, 2)
        v = self.v_proj(kv_x).view(kv_batch, kv_seq_len, self.groups, self.head_dim).transpose(1, 2)

        # GROUPING: (B, G, S, D) -> (B, G * query_per_group, S, D) = (B, H, S, D).
        # repeat_interleave keeps copies of a group adjacent, so head h reads group h // query_per_group.
        k_expanded = k.repeat_interleave(self.query_per_group, dim=1)
        v_expanded = v.repeat_interleave(self.query_per_group, dim=1)

        att_scores = torch.matmul(q, k_expanded.transpose(-1, -2)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            if attention_mask.dtype == torch.bool:
                # Adding a bool mask would add 0/1 to the scores; treat it as a keep-mask instead.
                att_scores = att_scores.masked_fill(~attention_mask, float("-inf"))
            else:
                att_scores = att_scores + attention_mask

        att_probs = torch.softmax(att_scores, dim=-1)

        # (B, H, T, D) -> (B, T, H, D) -> (B, T, H*D)
        context_layer = torch.matmul(att_probs, v_expanded).transpose(1, 2)
        context_layer = context_layer.reshape(batch, seq_len, self.query_heads * self.head_dim)

        return self.output_proj(context_layer)
        

In [4]:
# Smoke test: 8 query heads sharing 4 K/V groups (2 query heads per group).
torch.manual_seed(0)  # reproducible weights and inputs on Restart & Run All

batch_size = 2
seq_len = 16
embed_dim = 64
head_dim = 8
query_heads = 8
groups = 4

x = torch.randn(batch_size, seq_len, embed_dim)

gqa = GQA(query_heads=query_heads, groups=groups, head_dim=head_dim, embed_dim=embed_dim)

output = gqa(x)

# Attention must preserve the (batch, seq_len, embed_dim) shape of its input.
assert output.shape == (batch_size, seq_len, embed_dim), output.shape

print("Input shape:", x.shape)
print("Output shape:", output.shape)

Input shape: torch.Size([2, 16, 64])
Output shape: torch.Size([2, 16, 64])
