In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from ignite import RMSNorm, RoPE, SwiGLU
from einops import rearrange, repeat
import math

In [51]:
class GroupedQueryAttn(nn.Module):
    """Grouped-query self-attention (unmasked).

    The ``n_query_heads`` query heads are partitioned into ``n_query_groups``
    groups of consecutive heads; every head of a group attends with the same
    key/value head.  Queries, keys and values come from one fused linear
    projection of the input.

    Args:
        model_dim: Width of the input and output features. Must be divisible
            by ``n_query_heads``.
        n_query_heads: Number of query heads. Must be divisible by
            ``n_query_groups``.
        n_query_groups: Number of key/value heads shared by the query heads.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.

    Attributes:
        kv_heads_per_q_head: ``n_query_heads // n_query_groups``. Despite the
            name, this is the number of query heads sharing one key/value
            head.
        q_head_dim: Per-head feature width, ``model_dim // n_query_heads``.
        kv_head_dim: ``model_dim // kv_heads_per_q_head``. Despite the name,
            this is the *total* width of the key (and of the value)
            projection, i.e. ``n_query_groups * q_head_dim``.
    """

    # TODO: add masking
    def __init__(
        self,
        model_dim: int,
        n_query_heads: int,
        n_query_groups: int,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        assert (
            model_dim % n_query_heads == 0
        ), f"{model_dim=} is not divisible by {n_query_heads=}"
        assert (
            n_query_heads % n_query_groups == 0
        ), f"{n_query_heads=} is not divisible by {n_query_groups=}"

        self.model_dim = model_dim
        self.n_query_heads = n_query_heads
        self.n_query_groups = n_query_groups
        # Number of query heads served by each key/value head (see class docstring).
        self.kv_heads_per_q_head = n_query_heads // n_query_groups

        self.q_head_dim = model_dim // n_query_heads
        # Total K (and V) projection width; equals n_query_groups * q_head_dim
        # because both divisibility checks above passed.
        self.kv_head_dim = model_dim // self.kv_heads_per_q_head

        # One matmul yields [Q | K | V] concatenated along the feature axis.
        self.fused_qkv = nn.Linear(model_dim, model_dim + self.kv_head_dim * 2)
        self.out_proj = nn.Linear(model_dim, model_dim)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise the fused QKV weight with Xavier-uniform.

        Only ``fused_qkv.weight`` is touched; its bias and ``out_proj`` keep
        whatever initialisation ``nn.Linear`` gave them.
        """
        init.xavier_uniform_(self.fused_qkv.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply unmasked grouped-query self-attention.

        Args:
            x: Tensor of shape ``(batch, seq_len, model_dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, model_dim)``.
        """
        batch, seq_len, _ = x.shape
        n_groups = self.n_query_groups
        group_size = self.kv_heads_per_q_head
        head_dim = self.q_head_dim

        q, k, v = torch.split(
            self.fused_qkv(x),
            [self.model_dim, self.kv_head_dim, self.kv_head_dim],
            dim=-1,
        )

        # Query head h = g * group_size + r belongs to group g.  Instead of
        # materialising one K/V copy per query head, stack the group_size
        # query heads of each group along the row axis so that a single K/V
        # head serves all of them in one batched matmul.
        # q: (B, S, G*R*HD) -> (B, G, R*S, HD)
        q = (
            q.reshape(batch, seq_len, n_groups, group_size, head_dim)
            .permute(0, 2, 3, 1, 4)
            .reshape(batch, n_groups, group_size * seq_len, head_dim)
        )
        # k: (B, S, G*HD) -> (B, G, HD, S)   (already transposed for q @ k)
        k = k.reshape(batch, seq_len, n_groups, head_dim).permute(0, 2, 3, 1)
        # v: (B, S, G*HD) -> (B, G, S, HD)
        v = v.reshape(batch, seq_len, n_groups, head_dim).permute(0, 2, 1, 3)

        # (B, G, R*S, S): each row is one (query head, query position) pair.
        sim = q @ k / math.sqrt(head_dim)
        attn_scores = torch.softmax(sim, dim=-1)

        # (B, G, R*S, HD) -> (B, S, G*R*HD), restoring the original head order.
        attn_outputs = (
            (attn_scores @ v)
            .reshape(batch, n_groups, group_size, seq_len, head_dim)
            .permute(0, 3, 1, 2, 4)
            .reshape(batch, seq_len, self.model_dim)
        )
        return self.out_proj(attn_outputs)


# Smoke test: push one random batch through the attention layer.
# 8 query heads are spread over 4 key/value groups (2 query heads per group).
n_query_heads, n_query_groups = 8, 4
B, S, D = 16, 1024, 2048  # batch size, sequence length, model width
x = torch.randn(B, S, D)

attn = GroupedQueryAttn(
    model_dim=D,
    n_query_heads=n_query_heads,
    n_query_groups=n_query_groups,
)
a = attn(x)