# PART 1: The Memory Path

|Configuration|KV heads|Memory usage (MB)|Reduction Factor|
|-----------|-----------|-----------|-----------|
|MHA|32|1 batch * 32 heads * 128 head dimension * 4096 tokens * 2 bytes per element * 2 for K and V = 67,108,864 bytes ≈ 67.1 MB|1x|
|MQA|1|1 batch * 1 head * 128 head dimension * 4096 tokens * 2 bytes per element * 2 for K and V = 2,097,152 bytes ≈ 2.1 MB|32x|
|GQA|8|1 batch * 8 heads * 128 head dimension * 4096 tokens * 2 bytes per element * 2 for K and V = 16,777,216 bytes ≈ 16.8 MB|4x|

# PART 2: the `repeat_kv` Utility

In [14]:
import torch

def repeat_kv(x: torch.Tensor, num_repeats: int) -> torch.Tensor:
    """
    Duplicate every key/value head so the head count lines up with the queries.

    Args:
        x: Tensor laid out as (B, n_kv_heads, T, head_dim)
        num_repeats: Number of back-to-back copies to make of each KV head

    Returns:
        Tensor laid out as (B, n_kv_heads * num_repeats, T, head_dim).
        When num_repeats is 1 the input tensor itself is handed back.
    """
    if num_repeats != 1:
        # Copies of one head stay adjacent along the head axis (dim=1),
        # giving [H1, H1, H2, H2, ...] rather than [H1, H2, H1, H2, ...]
        return torch.repeat_interleave(x, num_repeats, dim=1)

    # A single "repeat" means there is nothing to duplicate
    return x



In [15]:
# --- Test Case ---
# Input shape: [Batch, KV_Heads, Seq_Len, Head_Dim]
input_tensor = torch.randn(1, 2, 4, 8)
num_repeats = 2

output_tensor = repeat_kv(input_tensor, num_repeats)

print(f"Input shape:  {list(input_tensor.shape)}")   # [1, 2, 4, 8]
print(f"Num repeats:  {num_repeats}")
print(f"Output shape: {list(output_tensor.shape)}")  # [1, 4, 4, 8]

# Verification
assert output_tensor.shape == (1, 4, 4, 8), "Shape mismatch!"

# The shape check alone cannot tell [H1, H1, H2, H2] apart from
# [H1, H2, H1, H2]; GQA needs the former so that consecutive query heads
# share one KV head. Compare against the input heads picked in that order.
assert torch.equal(output_tensor, input_tensor[:, [0, 0, 1, 1]]), "Head order mismatch!"

Input shape:  [1, 2, 4, 8]
Num repeats:  2
Output shape: [1, 4, 4, 8]


# PART 3: Implementing the GQA module

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class GroupedQueryAttention(nn.Module):
    """
    Grouped-query self-attention: num_heads query heads share num_kv_groups
    key/value heads.

    Queries are projected to d_out features (num_heads * head_dim), while
    keys and values are projected to only num_kv_groups * head_dim features.
    Each KV head is then repeated num_heads // num_kv_groups times so it can
    be matched against its group of query heads.

    No attention mask and no dropout are applied: every position attends to
    every other position in the sequence.

    Args:
        d_in: Feature size of the input tokens
        d_out: Feature size of the output; must be divisible by num_heads
        num_heads: Number of query heads; must be divisible by num_kv_groups
        num_kv_groups: Number of key/value heads
    """

    def __init__(self, d_in, d_out, num_heads, num_kv_groups):
        super().__init__()

        # Ensure num_heads is divisible by num_kv_groups
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
        # Without this, head_dim below would be silently truncated and the
        # failure would only show up later as a reshape error in forward()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_in = d_in
        self.d_out = d_out
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.head_dim = d_out // num_heads
        # How many query heads share a single KV head
        self.num_repeats = num_heads // num_kv_groups

        # Linear layers for Q, K, and V projections.
        # K and V are narrower than Q: they only carry num_kv_groups heads.
        self.W_q = nn.Linear(d_in, d_out, bias=False)
        kv_dim = self.head_dim * num_kv_groups
        self.W_k = nn.Linear(d_in, kv_dim, bias=False)
        self.W_v = nn.Linear(d_in, kv_dim, bias=False)

        self.out_proj = nn.Linear(d_out, d_out, bias=False)

    def forward(self, x):
        """
        Apply grouped-query attention.

        Args:
            x: Tensor of shape (B, T, d_in)

        Returns:
            Tensor of shape (B, T, d_out)
        """
        # Batch size and sequence length; the feature size is not needed here
        B, T, _ = x.shape

        # Project Input -> Q, K, V
        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)

        # (B, T, H, d_h) to (B, H, T, d_h)
        # where H is num_heads for q and num_kv_groups for k and v
        q = q.view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.view(B, T, self.num_kv_groups, self.head_dim).permute(0, 2, 1, 3)
        v = v.view(B, T, self.num_kv_groups, self.head_dim).permute(0, 2, 1, 3)

        # Repeat each KV head along dim=1 so k and v also have num_heads heads
        k = repeat_kv(k, self.num_repeats)
        v = repeat_kv(v, self.num_repeats)

        # Scaled Dot-Product Attention: scores are (B, H, T, T)
        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = F.softmax(attn_scores, dim=-1)

        context = attn_weights @ v

        # Concatenate and project output
        # Permute back to (B, T, H, d_h) before reshaping to (B, T, d_out);
        # contiguous() is required because view() cannot merge permuted dims
        context = context.permute(0, 2, 1, 3).contiguous().view(B, T, self.d_out)
        return self.out_proj(context)

In [17]:
# --- Test Case ---
d_in = 512
num_heads = 8
num_kv_groups = 2
gqa_module = GroupedQueryAttention(d_in, d_in, num_heads, num_kv_groups)

# Pass random tensor (Batch=1, Seq=16, Dim=512)
x = torch.randn(1, 16, d_in)

# calls the  forward function
output = gqa_module(x)

print(f"Input shape:  {list(x.shape)}")
print(f"Output shape: {list(output.shape)}") # Should match

# Verification: d_out equals d_in here, so the module must preserve the shape
assert output.shape == x.shape, "Shape mismatch!"

Input shape:  [1, 16, 512]
Output shape: [1, 16, 512]


# PART 4: The Model Surgery

In [18]:
import torch

def convert_mha_to_gqa(mha_weight, groups, head_dim=1):
    """
    Converts an MHA Key/Value weight matrix to GQA using mean-pooling.

    The columns of mha_weight are read as consecutive heads of head_dim
    columns each. Every run of `groups` neighbouring heads is averaged
    element-wise into a single GQA head.

    Args:
        mha_weight: Tensor of shape (D_in, D_out), where
            D_out = num_heads * head_dim
        groups: Number of MHA heads to average into a single GQA head
        head_dim: Number of columns that make up one head. The default of 1
            treats every column as its own head (the behaviour before this
            parameter existed). Pass the real head dimension for actual
            model weights, otherwise adjacent columns *inside* a head are
            averaged instead of whole heads.

    Returns:
        Compressed tensor of shape (D_in, D_out / groups)

    Raises:
        ValueError: If mha_weight is not 2-dimensional (from the unpacking).
        RuntimeError: If D_out is not divisible by groups * head_dim
            (raised by the reshape).

    Note:
        torch.nn.Linear stores its weight as (out_features, in_features), so
        a weight taken from such a layer has to be transposed to the
        (D_in, D_out) layout expected here.
    """
    # Unpacking also enforces that the weight is a 2-D matrix
    d_in, _ = mha_weight.shape

    # Reshape to (D_in, Num_GQA_Groups, Heads_per_Group, D_head).
    # -1 lets PyTorch work out the number of GQA groups, e.g. an input of
    # shape (1, 4) with groups=2, head_dim=1 becomes (1, 2, 2, 1).
    grouped = mha_weight.reshape(d_in, -1, groups, head_dim)

    # Average over the heads of each group (dim=2), keeping the per-head
    # features separate: result is (D_in, Num_GQA_Groups, D_head)
    pooled = grouped.mean(dim=2)

    # Flatten the group and head-feature axes back into one column axis,
    # giving a 2D matrix of shape (D_in, D_out / groups)
    return pooled.reshape(d_in, -1)



In [19]:
# --- Test Case Verification ---
# Toy MHA weight with a single input row: H=4 heads of D_head=1 (D_out = 4),
# so each column is one whole head holding the values 1, 3, 5 and 7
groups = 2
mha_weight = torch.tensor([[1.0, 3.0, 5.0, 7.0]])

# What mean-pooling two heads per group has to produce:
#   Group 1 (Heads 1 & 2): (1 + 3) / 2 = 2.0
#   Group 2 (Heads 3 & 4): (5 + 7) / 2 = 6.0
expected = torch.tensor([[2.0, 6.0]])

# Compress the four MHA heads into two GQA heads
gqa_weight = convert_mha_to_gqa(mha_weight, groups)

print(f"Original MHA weights: {mha_weight}")
print(f"Compressed GQA weights: {gqa_weight}")

# Fails loudly if the pooled values differ from the hand-computed means
assert torch.allclose(gqa_weight, expected), "Mean-pooling calculation failed!"
print("Verification Success: The dummy weights averaged correctly.")

Original MHA weights: tensor([[1., 3., 5., 7.]])
Compressed GQA weights: tensor([[2., 6.]])
Verification Success: The dummy weights averaged correctly.
