In [50]:
# GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
# Link to paper -> https://arxiv.org/pdf/2305.13245

In [31]:
import torch
from torch import nn
import math

class GroupedQueryAttention(nn.Module):
    """Grouped-query self-attention (GQA).

    Queries use ``n_heads`` heads while keys and values use only
    ``num_kv_heads`` heads; each key/value head is shared by
    ``n_heads // num_kv_heads`` consecutive query heads.  The K and V
    projections are therefore ``num_kv_heads / n_heads`` the size of the Q
    projection.  ``num_kv_heads == n_heads`` gives standard multi-head
    attention and ``num_kv_heads == 1`` gives multi-query attention.

    Args:
        d_model: Embedding (model) dimension of the input and output.
        n_heads: Number of query heads. Must divide ``d_model``.
        num_kv_heads: Number of key/value heads. Must divide ``n_heads``.

    Raises:
        AssertionError: If a head count is not positive or the divisibility
            constraints above are violated.
    """

    def __init__(self, d_model, n_heads, num_kv_heads):
        super().__init__()
        # Check positivity first so that a zero head count produces a clear
        # message instead of a ZeroDivisionError from the modulo checks below.
        assert n_heads > 0 and num_kv_heads > 0, "n_heads and num_kv_heads must be positive"
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        assert n_heads % num_kv_heads == 0, "n_heads must be divisible by num_kv_heads"

        self.d_model = d_model            # embedding (model) dimension
        self.n_heads = n_heads            # number of query heads
        self.num_kv_heads = num_kv_heads  # number of key/value heads

        self.head_dim = d_model // n_heads
        self.num_queries_per_kv = n_heads // num_kv_heads  # query heads sharing one KV head

        # Total width of the K (and V) projection.
        # e.g. d_model=512, n_heads=32, num_kv_heads=8 -> head_dim=16, kv_dim=128
        self.kv_dim = num_kv_heads * self.head_dim

        self.q_proj = nn.Linear(d_model, d_model)      # (b, s, d_model) -> (b, s, d_model)
        self.k_proj = nn.Linear(d_model, self.kv_dim)  # (b, s, d_model) -> (b, s, kv_dim)
        self.v_proj = nn.Linear(d_model, self.kv_dim)  # (b, s, d_model) -> (b, s, kv_dim)
        self.out_proj = nn.Linear(d_model, d_model)

        # Divisor applied to the raw dot products (scaled dot-product attention).
        self.scale = math.sqrt(self.head_dim)

    def forward(self, x, mask=None):
        """Apply grouped-query self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch_size, n_heads, seq_len, seq_len)``.  Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_model)``.  A query
            position whose keys are all masked receives zero attention
            output (so only the output-projection bias remains) instead of
            NaN.
        """
        batch_size, seq_len, _ = x.shape

        Q = self.q_proj(x)  # (b, s, d_model)
        K = self.k_proj(x)  # (b, s, kv_dim), i.e. num_kv_heads / n_heads the width of Q
        V = self.v_proj(x)  # (b, s, kv_dim)

        # Split the feature dimension into heads and move heads before seq_len.
        Q = Q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        # (b, n_heads, s, head_dim)
        K = K.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # (b, num_kv_heads, s, head_dim)

        # Repeat every KV head so that query head h is paired with KV head
        # h // num_queries_per_kv.  repeat_interleave materialises copies, so
        # skip it when each query head already has its own KV head.
        if self.num_queries_per_kv > 1:
            K = K.repeat_interleave(self.num_queries_per_kv, dim=1)
            V = V.repeat_interleave(self.num_queries_per_kv, dim=1)

        # From here on this is standard multi-head attention.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # (b, n_heads, s, s)

        fully_masked = None
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
            # A row that is entirely -inf would make softmax return NaN.
            # Give such rows finite scores for the softmax and zero their
            # weights afterwards, so neither forward nor backward sees NaN.
            fully_masked = (scores == float('-inf')).all(dim=-1, keepdim=True)
            scores = scores.masked_fill(fully_masked, 0.0)

        attn_weights = torch.softmax(scores, dim=-1)
        if fully_masked is not None:
            attn_weights = attn_weights.masked_fill(fully_masked, 0.0)

        attn_output = torch.matmul(attn_weights, V)  # (b, n_heads, s, head_dim)

        # Concatenate heads back into a single d_model-wide feature dimension.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)

        # Apply output projection.
        return self.out_proj(attn_output)
        

In [32]:
# Toy configuration for the demo: 32 query heads share 8 key/value heads,
# i.e. 4 query heads per KV head.
batch_size, seq_len = 2, 10
d_model, n_heads, num_kv_heads = 512, 32, 8

# Random activations standing in for a batch of token embeddings.
x = torch.randn((batch_size, seq_len, d_model))
x.size()

torch.Size([2, 10, 512])

In [33]:
# Build the attention layer from the demo hyper-parameters.
gqa = GroupedQueryAttention(
    d_model=d_model,
    n_heads=n_heads,
    num_kv_heads=num_kv_heads,
)

In [46]:
# Show the module repr: k_proj / v_proj have fewer output features than q_proj.
gqa

GroupedQueryAttention(
  (q_proj): Linear(in_features=512, out_features=512, bias=True)
  (k_proj): Linear(in_features=512, out_features=128, bias=True)
  (v_proj): Linear(in_features=512, out_features=128, bias=True)
  (out_proj): Linear(in_features=512, out_features=512, bias=True)
)

In [39]:
# Sanity check: the layer preserves the input shape, and the K projection
# holds kv_dim * d_model weights plus kv_dim biases.
print(f"Input shape: {x.size()}")
print(f"Output shape: {gqa(x).size()}")
print(
    f"Number of parameters in K projection: {sum(map(torch.Tensor.numel, gqa.k_proj.parameters()))}"
)

Input shape: torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])
Number of parameters in K projection: 65664


In [40]:
!pip install torchsummary



In [48]:
from torchsummary import summary

# input_size excludes the batch dimension: (seq_len, d_model).
# torchsummary defaults to device="cuda"; pass the device the model actually
# lives on so the probe input is not created on a different device than the
# model's parameters.
summary(
    gqa,
    input_size=(128, gqa.d_model),
    device=next(gqa.parameters()).device.type,
)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1             [-1, 128, 512]         262,656
            Linear-2             [-1, 128, 128]          65,664
            Linear-3             [-1, 128, 128]          65,664
            Linear-4             [-1, 128, 512]         262,656
Total params: 656,640
Trainable params: 656,640
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.25
Forward/backward pass size (MB): 1.25
Params size (MB): 2.50
Estimated Total Size (MB): 4.00
----------------------------------------------------------------


In [49]:
# Standard Multi-Head Attention -> 1,050,624 parameters
# Grouped Query Attention -> 656,640

# 656,640 / 1,050,624 = 0.625, i.e. 37.5% fewer parameters