# Part 2: Grouped-Query Attention (GQA)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple

class PositionEmbeding(nn.Module):
    """Placeholder positional-embedding hook that leaves Q and K untouched.

    The module stores ``max_seq_length`` but owns no parameters; its
    ``forward`` is an identity pass-through for the query/key pair.
    """

    def __init__(self, max_seq_length):
        super(PositionEmbeding, self).__init__()
        # Kept only as configuration; nothing in this class reads it back.
        self.max_seq_length = max_seq_length

    def forward(self, q, k, positions=None):
        """Return ``(q, k)`` unchanged; ``positions`` is accepted but ignored."""
        queries, keys = q, k
        return queries, keys
    
class PositionWiseFeedForward(nn.Module):
    """
    Two-layer feed-forward network applied along the last dimension.

    Computes ``linear2(dropout(relu(linear1(x))))``: the features are
    projected from ``d_model`` up to ``d_ff`` and back down to ``d_model``.
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Expansion, regularisation and contraction layers, then the activation.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map ``x`` of shape (..., d_model) to an output of the same shape."""
        hidden = self.relu(self.linear1(x))          # (..., d_ff)
        return self.linear2(self.dropout(hidden))    # (..., d_model)

## Multi-Head Attention (MHA)

In [7]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention mechanism module with KV Cache support.

    Inputs are projected with learned linear layers, split into
    ``num_heads`` heads of size ``d_model // num_heads``, attended with
    scaled dot-product attention, recombined and passed through an output
    projection.
    """
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Define linear transformation layers for Q, K, V and output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Compute ``softmax(Q K^T / sqrt(head_dim)) V``.

        Args:
            Q, K, V: Per-head tensors whose last dimension is ``head_dim``.
            mask: Optional tensor broadcastable to the score tensor;
                positions where ``mask == 0`` are excluded from attention.
        Returns:
            The attention-weighted sum of ``V``.
        """
        # Calculate attention scores (QK^T)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Apply mask (if provided)
        if mask is not None:
            # Fill with the most negative value the score dtype can hold so
            # the masked positions vanish after softmax. A hard-coded -1e9
            # is not representable in float16 and makes masked_fill raise.
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)

        # Calculate attention weights (Softmax)
        attn_probs = torch.softmax(attn_scores, dim=-1)

        # Weighted sum (weights * V)
        return torch.matmul(attn_probs, V)

    def split_heads(self, x):
        """Reshape (batch_size, seq_length, d_model) to
        (batch_size, num_heads, seq_length, head_dim)."""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

    def combine_heads(self, x):
        """Reshape (batch_size, num_heads, seq_length, head_dim) back to
        (batch_size, seq_length, d_model)."""
        batch_size, _, seq_length, _ = x.size()
        # transpose() yields a non-contiguous view; view() needs contiguous memory.
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K=None, V=None, mask=None,
                past_key_value=None, use_cache=False):
        """
        Args:
            Q: Query tensor of shape (batch_size, seq_len, d_model)
            K: Key tensor (if None, use Q as K)
            V: Value tensor (if None, use Q as V)
            mask: Attention mask; positions where it equals 0 are ignored
            past_key_value: Tuple of (past_key, past_value) from previous steps,
                each of shape (batch_size, num_heads, past_len, head_dim)
            use_cache: Whether to cache KV for next steps
        Returns:
            output: Attention output of shape (batch_size, seq_len, d_model)
            present_key_value: Tuple of (key, value) for caching if
                use_cache=True, otherwise None
        """
        # Self-attention shorthand: missing K / V fall back to the raw query
        # input. Resolve both before Q is overwritten by its projection.
        if K is None:
            K = Q
        if V is None:
            V = Q

        # Perform linear transformations on Q, K, V
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # Apply KV Cache if provided
        if past_key_value is not None:
            past_key, past_value = past_key_value
            # Concatenate past and current KV along sequence dimension
            K = torch.cat([past_key, K], dim=2)
            V = torch.cat([past_value, V], dim=2)

        # Prepare present KV for caching
        present_key_value = (K, V) if use_cache else None

        # Calculate scaled dot-product attention
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)

        # Combine multi-head outputs and perform final linear transformation
        output = self.W_o(self.combine_heads(attn_output))
        return output, present_key_value

## Grouped-Query Attention (GQA)

In [8]:
class GroupedQueryAttention(nn.Module):
    """
    Grouped query attention mechanism module with KV Cache support.

    Queries use ``num_heads`` heads while keys and values use only
    ``num_kv_heads`` heads; each KV head is shared by
    ``num_heads // num_kv_heads`` consecutive query heads.

    When num_kv_heads == 1, it equals Multi-query attention (MQA)
    When num_kv_heads == num_heads, it equals Multi-head attention (MHA)
    """
    def __init__(self, d_model, num_heads, num_kv_heads):
        super(GroupedQueryAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = d_model // num_heads
        assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
        # Number of query heads that share one KV head.
        self.num_groups = num_heads // num_kv_heads

        # Define linear transformation layers for Q, K, V and output.
        # K and V project to the smaller num_kv_heads * head_dim width.
        self.W_q = nn.Linear(d_model, num_heads * self.head_dim, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.W_o = nn.Linear(num_heads * self.head_dim, d_model, bias=False)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Compute ``softmax(Q K^T / sqrt(head_dim)) V``.

        Args:
            Q, K, V: Per-head tensors whose last dimension is ``head_dim``;
                K and V must already be expanded to the query head count.
            mask: Optional tensor broadcastable to the score tensor;
                positions where ``mask == 0`` are excluded from attention.
        Returns:
            The attention-weighted sum of ``V``.
        """
        # Calculate attention scores (QK^T)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Apply mask (if provided)
        if mask is not None:
            # Fill with the most negative value the score dtype can hold so
            # the masked positions vanish after softmax. A hard-coded -1e9
            # is not representable in float16 and makes masked_fill raise.
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)

        # Calculate attention weights (Softmax)
        attn_probs = torch.softmax(attn_scores, dim=-1)

        # Weighted sum (weights * V)
        return torch.matmul(attn_probs, V)

    def split_heads(self, x, num_heads):
        '''Transform input x shape from (batch_size, seq_length, num_heads * head_dim)
        to (batch_size, num_heads, seq_length, head_dim)'''
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, num_heads, self.head_dim).transpose(1, 2)

    def combine_heads(self, x):
        '''Transform input x shape from (batch_size, num_heads, seq_length, head_dim)
        back to (batch_size, seq_length, d_model)'''
        batch_size, _, seq_length, _ = x.size()
        # transpose() yields a non-contiguous view; view() needs contiguous memory.
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def expand_kv(self, K, V):
        '''Repeat every KV head num_groups times so K and V match the query head count.

        forward() calls this after the cache tuple has been built, so the
        cache keeps only the compact num_kv_heads representation.
        '''
        # K, V shape: (batch_size, num_kv_heads, seq_length, head_dim)
        if self.num_groups > 1:
            # repeat_interleave keeps copies adjacent, so query head h is
            # paired with KV head h // num_groups.
            K = K.repeat_interleave(self.num_groups, dim=1)
            V = V.repeat_interleave(self.num_groups, dim=1)
        return K, V

    def forward(self, Q, K=None, V=None, mask=None,
                past_key_value=None, use_cache=False):
        """
        Args:
            Q: Query tensor of shape (batch_size, seq_len, d_model)
            K: Key tensor (optional; if None, Q is used)
            V: Value tensor (optional; if None, Q is used)
            mask: Attention mask; positions where it equals 0 are ignored
            past_key_value: Tuple of (past_key, past_value) from previous steps,
                each of shape (batch_size, num_kv_heads, past_len, head_dim)
            use_cache: Whether to cache KV for next steps
        Returns:
            output: Attention output of shape (batch_size, seq_len, d_model)
            present_key_value: Tuple of (key, value) for caching if
                use_cache=True, otherwise None
        """
        # Self-attention shorthand: missing K / V fall back to the raw query
        # input. Resolve both before Q is overwritten by its projection.
        if K is None:
            K = Q
        if V is None:
            V = Q

        Q = self.split_heads(self.W_q(Q), self.num_heads)       # (batch_size, num_heads, seq_len, head_dim)
        K = self.split_heads(self.W_k(K), self.num_kv_heads)    # (batch_size, num_kv_heads, seq_len, head_dim)
        V = self.split_heads(self.W_v(V), self.num_kv_heads)    # (batch_size, num_kv_heads, seq_len, head_dim)

        # Prepend cached keys/values along the sequence dimension.
        if past_key_value is not None:
            past_key, past_value = past_key_value
            K = torch.cat([past_key, K], dim=2)
            V = torch.cat([past_value, V], dim=2)

        # Cache the un-expanded K/V: only num_kv_heads heads are stored.
        present_key_value = (K, V) if use_cache else None

        K, V = self.expand_kv(K, V)
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output, present_key_value

In [9]:
def test_gqa():
    """Smoke-test GroupedQueryAttention with and without a past KV cache.

    Runs a 10-token pass that creates a cache, then a single-token step that
    reuses it. Output and cache shapes are asserted as well as printed, so a
    regression fails instead of merely printing different numbers.
    """
    batch_size = 2
    seq_len = 10
    d_model = 512
    num_heads = 8
    num_kv_heads = 2  # GQA: 8 query heads share 2 KV heads
    head_dim = d_model // num_heads

    gqa = GroupedQueryAttention(d_model, num_heads, num_kv_heads)

    # Test inputs
    x = torch.randn(batch_size, seq_len, d_model)

    # First pass: no past cache is supplied, but one is requested for later.
    output1, cache = gqa(x, x, x, use_cache=True)
    assert output1.shape == (batch_size, seq_len, d_model)
    # The cache holds only the num_kv_heads un-expanded heads.
    assert cache[0].shape == (batch_size, num_kv_heads, seq_len, head_dim)
    assert cache[1].shape == cache[0].shape
    print(f"Output shape without cache: {output1.shape}")
    print(f"Cache keys shape: {cache[0].shape}")  # (2, 2, 10, 64)

    # Test with cache (autoregressive step)
    next_token = torch.randn(batch_size, 1, d_model)
    output2, new_cache = gqa(
        next_token, next_token, next_token,
        past_key_value=cache,
        use_cache=True
    )
    assert output2.shape == (batch_size, 1, d_model)
    # The cache grows by exactly one position along the sequence dimension.
    assert new_cache[0].shape == (batch_size, num_kv_heads, seq_len + 1, head_dim)
    assert new_cache[1].shape == new_cache[0].shape
    print(f"Output shape with cache: {output2.shape}")
    print(f"New cache keys shape: {new_cache[0].shape}")  # (2, 2, 11, 64)

    # Report the KV sharing ratio and the reduction in cached KV heads.
    print(f"\nKV sharing ratio: {num_heads}:{num_kv_heads} = {num_heads/num_kv_heads:.0f}x")
    print(f"Memory savings: {(1 - num_kv_heads/num_heads)*100:.1f}%")

# Run the GQA smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test_gqa()

Output shape without cache: torch.Size([2, 10, 512])
Cache keys shape: torch.Size([2, 2, 10, 64])
Output shape with cache: torch.Size([2, 1, 512])
New cache keys shape: torch.Size([2, 2, 11, 64])

KV sharing ratio: 8:2 = 4x
Memory savings: 75.0%
