In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
from typing import Optional, Tuple


In [None]:


# ============================================================================
# 1. BASIC GROUPED QUERY ATTENTION
# ============================================================================

class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA).

    The query heads are split into ``num_kv_heads`` contiguous groups of
    ``num_query_heads // num_kv_heads`` heads, and every query head in a group
    attends with the same key/value head. Compared with standard multi-head
    attention this shrinks the K/V projections by that group factor.
    """

    def __init__(
        self,
        embed_dim: int,
        num_query_heads: int,
        num_kv_heads: int,
        dropout: float = 0.1
    ):
        """
        Args:
            embed_dim: Model (input and output) dimension.
            num_query_heads: Number of query heads; must divide ``embed_dim``.
            num_kv_heads: Number of key/value heads; must divide ``num_query_heads``.
            dropout: Dropout probability applied to the attention weights.
        """
        super().__init__()

        assert embed_dim % num_query_heads == 0, \
            "embed_dim must be divisible by num_query_heads"
        assert num_query_heads % num_kv_heads == 0, \
            "num_query_heads must be divisible by num_kv_heads"

        head_dim = embed_dim // num_query_heads

        self.embed_dim = embed_dim
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.group_size = num_query_heads // num_kv_heads

        q_width = num_query_heads * head_dim
        kv_width = num_kv_heads * head_dim

        # Creation order (Q, K, V, O) fixes both the parameter registration
        # order and the sequence of random initialisations.
        self.q_proj = nn.Linear(embed_dim, q_width)
        self.k_proj = nn.Linear(embed_dim, kv_width)
        self.v_proj = nn.Linear(embed_dim, kv_width)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(head_dim)

    def _split_heads(self, t: torch.Tensor, n_heads: int) -> torch.Tensor:
        """[batch, seq, n_heads * head_dim] -> [batch, n_heads, seq, head_dim]."""
        b, s, _ = t.shape
        return t.view(b, s, n_heads, self.head_dim).transpose(1, 2)

    def _repeat_kv(self, t: torch.Tensor) -> torch.Tensor:
        """[batch, num_kv_heads, seq, head_dim] -> [batch, num_query_heads, seq, head_dim].

        KV head ``j`` is repeated ``group_size`` times in a row, so query head
        ``i`` ends up paired with KV head ``i // group_size``.
        """
        b, _, s, d = t.shape
        return (
            t[:, :, None]
            .expand(b, self.num_kv_heads, self.group_size, s, d)
            .reshape(b, self.num_query_heads, s, d)
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_attention: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            x: Input tensor [batch, seq_len, embed_dim]
            mask: Optional attention mask; positions where ``mask == 0`` get
                -inf scores before the softmax.
            return_attention: Whether to also return the attention weights
                (taken after dropout).

        Returns:
            output: [batch, seq_len, embed_dim]
            attention_weights: [batch, num_query_heads, seq_len, seq_len] or None
        """
        batch_size, seq_len, _ = x.size()

        q = self._split_heads(self.q_proj(x), self.num_query_heads)
        k = self._repeat_kv(self._split_heads(self.k_proj(x), self.num_kv_heads))
        v = self._repeat_kv(self._split_heads(self.v_proj(x), self.num_kv_heads))

        # Scaled dot-product scores: [batch, num_query_heads, seq_len, seq_len]
        scores = q @ k.transpose(-2, -1) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))

        # Merge heads back: [batch, seq_len, embed_dim]
        context = (weights @ v).transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.embed_dim
        )
        output = self.out_proj(context)

        return output, (weights if return_attention else None)


# ============================================================================
# 2. OPTIMIZED GQA WITH KV CACHE
# ============================================================================

class CachedGroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention with KV caching for efficient autoregressive generation.

    During generation, computed keys and values are cached so past tokens are
    not re-projected. The cache holds only ``num_kv_heads`` heads, which makes
    it ``num_query_heads / num_kv_heads`` times smaller than an MHA cache; the
    heads are expanded to ``num_query_heads`` only transiently for attention.
    """

    def __init__(
        self,
        embed_dim: int,
        num_query_heads: int,
        num_kv_heads: int,
        dropout: float = 0.1
    ):
        """
        Args:
            embed_dim: Model dimension; must be divisible by ``num_query_heads``.
            num_query_heads: Number of query heads.
            num_kv_heads: Number of key/value heads; must divide ``num_query_heads``.
            dropout: Dropout probability applied to the attention weights.
        """
        super(CachedGroupedQueryAttention, self).__init__()

        assert embed_dim % num_query_heads == 0, \
            "embed_dim must be divisible by num_query_heads"
        assert num_query_heads % num_kv_heads == 0, \
            "num_query_heads must be divisible by num_kv_heads"

        self.embed_dim = embed_dim
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_query_heads
        self.group_size = num_query_heads // num_kv_heads

        self.q_proj = nn.Linear(embed_dim, num_query_heads * self.head_dim)
        self.k_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(
        self,
        x: torch.Tensor,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        mask: Optional[torch.Tensor] = None,
        is_causal: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            x: Input [batch, seq_len, embed_dim] (the new tokens only).
            past_kv: Cached (K, V) from previous steps, each
                [batch, num_kv_heads, past_len, head_dim].
            use_cache: Whether to return the updated cache.
            mask: Optional attention mask, broadcastable to the score shape
                [batch, num_query_heads, seq_len, past_len + seq_len];
                positions where ``mask == 0`` are blocked.
            is_causal: If True, each new token may attend only to cached
                tokens and to new tokens at or before its own position. This
                is needed when feeding more than one token at once (prompt
                prefill); with one token per step it has no effect.

        Returns:
            output: [batch, seq_len, embed_dim]
            present_kv: Updated (K, V) cache if use_cache=True, else None
        """
        batch_size, seq_len, _ = x.size()

        # Project queries: [batch, num_query_heads, seq_len, head_dim]
        Q = self.q_proj(x)
        Q = Q.view(batch_size, seq_len, self.num_query_heads, self.head_dim).transpose(1, 2)

        # Project keys and values: [batch, num_kv_heads, seq_len, head_dim]
        K = self.k_proj(x)
        V = self.v_proj(x)
        K = K.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Concatenate with past KV along the sequence dimension
        if past_kv is not None:
            past_K, past_V = past_kv
            K = torch.cat([past_K, K], dim=2)
            V = torch.cat([past_V, V], dim=2)

        # Cache the un-expanded (num_kv_heads) tensors: this is the GQA memory saving
        present_kv = (K, V) if use_cache else None

        total_seq_len = K.size(2)
        past_len = total_seq_len - seq_len

        # Expand each KV head to the group_size query heads that share it
        K_expanded = K.unsqueeze(2).expand(
            batch_size, self.num_kv_heads, self.group_size, total_seq_len, self.head_dim
        ).reshape(batch_size, self.num_query_heads, total_seq_len, self.head_dim)

        V_expanded = V.unsqueeze(2).expand(
            batch_size, self.num_kv_heads, self.group_size, total_seq_len, self.head_dim
        ).reshape(batch_size, self.num_query_heads, total_seq_len, self.head_dim)

        # Scores: [batch, num_query_heads, seq_len, total_seq_len]
        scores = torch.matmul(Q, K_expanded.transpose(-2, -1)) / self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        if is_causal:
            # New token i sits at absolute position past_len + i, so any key
            # with a larger absolute position is in its future.
            q_pos = torch.arange(seq_len, device=x.device).unsqueeze(1) + past_len
            k_pos = torch.arange(total_seq_len, device=x.device).unsqueeze(0)
            scores = scores.masked_fill(k_pos > q_pos, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, V_expanded)
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.embed_dim
        )

        output = self.out_proj(attn_output)

        return output, present_kv


# ============================================================================
# 3. CONVERSION UTILITIES: MHA → GQA
# ============================================================================

class MultiHeadAttention(nn.Module):
    """Standard Multi-Head Attention baseline: every head has its own K and V projection."""

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()

        assert embed_dim % num_heads == 0

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        def square() -> nn.Linear:
            # All four projections map embed_dim -> embed_dim
            return nn.Linear(embed_dim, embed_dim)

        # Registered (and randomly initialised) in Q, K, V, O order
        self.q_proj = square()
        self.k_proj = square()
        self.v_proj = square()
        self.out_proj = square()

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch, seq_len, embed_dim]
            mask: Optional mask; positions where ``mask == 0`` get -inf scores.

        Returns:
            [batch, seq_len, embed_dim]
        """
        b, n, _ = x.size()

        def heads(proj: nn.Linear) -> torch.Tensor:
            # [b, n, embed_dim] -> [b, num_heads, n, head_dim]
            return proj(x).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = heads(self.q_proj), heads(self.k_proj), heads(self.v_proj)

        logits = q @ k.transpose(-2, -1) / self.scale
        if mask is not None:
            logits = logits.masked_fill(mask == 0, float('-inf'))

        probs = self.dropout(F.softmax(logits, dim=-1))

        merged = (probs @ v).transpose(1, 2).contiguous().view(b, n, self.embed_dim)
        return self.out_proj(merged)


def convert_mha_to_gqa(
    mha_module: MultiHeadAttention,
    num_kv_heads: int,
    strategy: str = 'mean'
) -> GroupedQueryAttention:
    """
    Convert a trained Multi-Head Attention module to Grouped Query Attention.

    The query and output projections are copied unchanged. The MHA key/value
    heads are split into ``num_kv_heads`` contiguous groups of
    ``num_heads // num_kv_heads`` heads, which is the grouping GQA uses when it
    expands K/V. Each group is then collapsed into a single KV head.

    Args:
        mha_module: Trained MHA module (all four projections must have biases).
        num_kv_heads: Number of KV heads in the resulting GQA module.
        strategy: How to collapse a group of heads:
            'mean'  - average the group's weights and biases;
            'first' - keep the first head of each group.

    Returns:
        gqa_module: New GroupedQueryAttention with the same dropout probability,
        parameter device/dtype and train/eval mode as ``mha_module``.

    Raises:
        AssertionError: If ``num_heads`` is not divisible by ``num_kv_heads``.
        ValueError: If ``strategy`` is not 'mean' or 'first'.
    """
    assert mha_module.num_heads % num_kv_heads == 0, \
        "num_heads must be divisible by num_kv_heads"
    # Validate before building anything, so a bad strategy has no side effects
    if strategy not in ('mean', 'first'):
        raise ValueError(f"Unknown strategy: {strategy}")

    num_heads = mha_module.num_heads
    embed_dim = mha_module.embed_dim
    head_dim = mha_module.head_dim
    group_size = num_heads // num_kv_heads

    reference = mha_module.q_proj.weight
    gqa = GroupedQueryAttention(
        embed_dim=embed_dim,
        num_query_heads=num_heads,
        num_kv_heads=num_kv_heads,
        dropout=mha_module.dropout.p,  # keep the source module's dropout rate
    ).to(device=reference.device, dtype=reference.dtype)
    gqa.train(mha_module.training)

    def collapse(weight: torch.Tensor, bias: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Projection rows are laid out head by head:
        # [num_heads * head_dim, embed_dim] -> [num_kv_heads, group_size, head_dim, embed_dim]
        w = weight.reshape(num_kv_heads, group_size, head_dim, embed_dim)
        b = bias.reshape(num_kv_heads, group_size, head_dim)
        if strategy == 'mean':
            w, b = w.mean(dim=1), b.mean(dim=1)
        else:  # 'first'
            w, b = w[:, 0], b[:, 0]
        return (
            w.reshape(num_kv_heads * head_dim, embed_dim),
            b.reshape(num_kv_heads * head_dim),
        )

    with torch.no_grad():
        # Query and output projections are unchanged by the conversion
        gqa.q_proj.weight.copy_(mha_module.q_proj.weight)
        gqa.q_proj.bias.copy_(mha_module.q_proj.bias)
        gqa.out_proj.weight.copy_(mha_module.out_proj.weight)
        gqa.out_proj.bias.copy_(mha_module.out_proj.bias)

        # Key and value projections are collapsed group by group
        for src, dst in ((mha_module.k_proj, gqa.k_proj), (mha_module.v_proj, gqa.v_proj)):
            new_w, new_b = collapse(src.weight, src.bias)
            dst.weight.copy_(new_w)
            dst.bias.copy_(new_b)

    return gqa


# ============================================================================
# 4. COMPLETE MODEL EXAMPLE
# ============================================================================

class GQATransformerBlock(nn.Module):
    """
    Pre-LayerNorm transformer block that uses GroupedQueryAttention for self-attention.

    Computes ``h = x + Attn(LN1(x))`` and then ``h + FFN(LN2(h))``. The FFN is
    Linear -> GELU -> Linear -> Dropout with a default hidden width of
    ``4 * embed_dim``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_query_heads: int,
        num_kv_heads: int,
        ff_dim: Optional[int] = None,
        dropout: float = 0.1
    ):
        super().__init__()

        hidden = 4 * embed_dim if ff_dim is None else ff_dim

        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

        self.attn = GroupedQueryAttention(
            embed_dim=embed_dim,
            num_query_heads=num_query_heads,
            num_kv_heads=num_kv_heads,
            dropout=dropout
        )

        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, mask=None):
        """Apply attention and feed-forward sub-layers, each with a residual connection."""
        attn_out, _ = self.attn(self.ln1(x), mask=mask)
        h = x + attn_out
        return h + self.ffn(self.ln2(h))


# ============================================================================
# 5. PERFORMANCE BENCHMARKING
# ============================================================================

class AttentionBenchmark:
    """Benchmark MHA against GQA: peak forward memory, forward latency and KV-cache size."""

    @staticmethod
    def _peak_forward_memory_mb(module: nn.Module, x: torch.Tensor) -> float:
        """
        Run one forward pass of ``module`` on ``x`` and return the peak CUDA
        memory allocated while it ran, in MB.

        The peak counter is reset right before the call. The value therefore
        covers everything resident at that moment (e.g. ``x`` and the module
        weights) plus the forward pass's transient allocations. On non-CUDA
        devices the forward still runs, but 0.0 is returned.
        """
        on_cuda = x.device.type == 'cuda'
        if on_cuda:
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats()

        out = module(x)

        peak_mb = 0.0
        if on_cuda:
            torch.cuda.synchronize()
            peak_mb = torch.cuda.max_memory_allocated() / 1024**2
        # Release the output (and the autograd graph it keeps alive) so it
        # cannot be counted in a later measurement.
        del out
        return peak_mb

    @staticmethod
    def _mean_forward_seconds(
        module: nn.Module,
        x: torch.Tensor,
        num_iterations: int,
        warmup: int = 10
    ) -> float:
        """
        Mean wall-clock seconds per forward pass of ``module`` in eval mode
        under ``torch.no_grad()``, after ``warmup`` untimed passes.

        Raises:
            ValueError: If ``num_iterations`` < 1.
        """
        if num_iterations < 1:
            raise ValueError("num_iterations must be >= 1")

        on_cuda = x.device.type == 'cuda'
        module.eval()
        with torch.no_grad():
            for _ in range(warmup):
                module(x)
            if on_cuda:
                torch.cuda.synchronize()

            # perf_counter is monotonic and high-resolution, unlike time.time
            start = time.perf_counter()
            for _ in range(num_iterations):
                module(x)
            if on_cuda:
                # CUDA kernels are asynchronous; wait for them before stopping the clock
                torch.cuda.synchronize()
            elapsed = time.perf_counter() - start

        return elapsed / num_iterations

    @staticmethod
    def benchmark_memory(
        batch_size: int,
        seq_len: int,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int
    ):
        """
        Measure the peak CUDA memory of one (training-mode) forward pass of MHA vs GQA.

        Each module is measured in isolation. The MHA output, its autograd
        graph and the MHA weights are all released before GQA is built, so
        none of them are counted in the GQA peak.

        Returns:
            dict with memory statistics (all zero on CPU)
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Create input
        x = torch.randn(batch_size, seq_len, embed_dim, device=device)

        # Multi-Head Attention
        mha = MultiHeadAttention(embed_dim, num_heads).to(device)
        mha_memory = AttentionBenchmark._peak_forward_memory_mb(mha, x)
        del mha

        # Grouped Query Attention
        gqa = GroupedQueryAttention(embed_dim, num_heads, num_kv_heads).to(device)
        gqa_memory = AttentionBenchmark._peak_forward_memory_mb(gqa, x)
        del gqa

        return {
            'mha_memory_mb': mha_memory,
            'gqa_memory_mb': gqa_memory,
            'reduction': (mha_memory - gqa_memory) / mha_memory if mha_memory > 0 else 0
        }

    @staticmethod
    def benchmark_speed(
        batch_size: int,
        seq_len: int,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        num_iterations: int = 100
    ):
        """
        Measure inference speed (eval mode, no_grad) of MHA vs GQA.

        Returns:
            dict with timing statistics
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        x = torch.randn(batch_size, seq_len, embed_dim, device=device)

        mha = MultiHeadAttention(embed_dim, num_heads).to(device)
        mha_time = AttentionBenchmark._mean_forward_seconds(mha, x, num_iterations)

        gqa = GroupedQueryAttention(embed_dim, num_heads, num_kv_heads).to(device)
        gqa_time = AttentionBenchmark._mean_forward_seconds(gqa, x, num_iterations)

        return {
            'mha_time_ms': mha_time * 1000,
            'gqa_time_ms': gqa_time * 1000,
            'speedup': mha_time / gqa_time if gqa_time > 0 else float('inf')
        }

    @staticmethod
    def benchmark_kv_cache_size(
        seq_len: int,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        num_layers: int = 32,
        batch_size: int = 1,
        bytes_per_element: int = 2
    ):
        """
        Calculate the KV cache size for MHA vs GQA.

        Args:
            seq_len: Number of cached tokens per sequence.
            embed_dim: Model dimension (head_dim = embed_dim // num_heads).
            num_heads: Query heads (= KV heads for MHA).
            num_kv_heads: KV heads for GQA.
            num_layers: Number of transformer layers.
            batch_size: Number of sequences cached at once (default 1).
            bytes_per_element: Storage size per element (default 2, i.e. FP16/BF16).

        Returns:
            dict with cache size information
        """
        head_dim = embed_dim // num_heads

        # Bytes cached per KV head across all layers: one K and one V tensor
        # of shape [batch_size, seq_len, head_dim] per layer.
        bytes_per_kv_head = 2 * num_layers * head_dim * seq_len * batch_size * bytes_per_element

        mha_cache = bytes_per_kv_head * num_heads
        gqa_cache = bytes_per_kv_head * num_kv_heads

        return {
            'mha_cache_mb': mha_cache / (1024**2),
            'gqa_cache_mb': gqa_cache / (1024**2),
            'reduction_factor': num_heads / num_kv_heads,
            'memory_saved_mb': (mha_cache - gqa_cache) / (1024**2)
        }


# ============================================================================
# 6. DEMONSTRATION EXAMPLES
# ============================================================================

def demo_basic_gqa():
    """Demonstrate basic GQA usage: output shapes and parameter count versus MHA."""
    print("="*80)
    print("DEMO 1: Basic Grouped Query Attention")
    print("="*80)

    batch_size = 2
    seq_len = 10
    embed_dim = 256
    num_query_heads = 16
    num_kv_heads = 4

    # Create GQA layer
    gqa = GroupedQueryAttention(embed_dim, num_query_heads, num_kv_heads)

    # Create input
    x = torch.randn(batch_size, seq_len, embed_dim)

    # Forward pass (default training mode; only shapes are inspected below)
    output, attn_weights = gqa(x, return_attention=True)

    print(f"\nConfiguration:")
    print(f"  Input shape: {x.shape}")
    print(f"  Query heads: {num_query_heads}")
    print(f"  KV heads: {num_kv_heads}")
    print(f"  Group size: {num_query_heads // num_kv_heads}")
    print(f"  Head dimension: {embed_dim // num_query_heads}")

    print(f"\nOutput:")
    print(f"  Shape: {output.shape}")
    print(f"  Attention weights shape: {attn_weights.shape}")

    # Count parameters. gqa.parameters() includes the bias vectors, so the MHA
    # reference must include them too: four embed_dim x embed_dim weight
    # matrices (Q, K, V, O) plus four embed_dim-sized biases.
    mha_params = 4 * (embed_dim * embed_dim + embed_dim)
    gqa_params = sum(p.numel() for p in gqa.parameters())

    print(f"\nParameters:")
    print(f"  GQA: {gqa_params:,}")
    print(f"  MHA (equivalent): {mha_params:,}")
    print(f"  Reduction: {(1 - gqa_params/mha_params)*100:.1f}%")


def demo_mha_to_gqa_conversion():
    """Demonstrate converting an MHA module to GQA and comparing their outputs."""
    print("\n" + "="*80)
    print("DEMO 2: MHA to GQA Conversion")
    print("="*80)

    embed_dim = 128
    num_heads = 8
    num_kv_heads = 4

    # Stand-in for a trained MHA: the weights are just the random initialisation
    # (no training is performed in this demo).
    mha = MultiHeadAttention(embed_dim, num_heads)
    x = torch.randn(4, 20, embed_dim)

    print(f"\nOriginal MHA:")
    print(f"  Num heads: {num_heads}")
    mha_params = sum(p.numel() for p in mha.parameters())
    print(f"  Parameters: {mha_params:,}")

    # Convert to GQA
    gqa = convert_mha_to_gqa(mha, num_kv_heads, strategy='mean')

    print(f"\nConverted GQA:")
    print(f"  Query heads: {gqa.num_query_heads}")
    print(f"  KV heads: {gqa.num_kv_heads}")
    gqa_params = sum(p.numel() for p in gqa.parameters())
    print(f"  Parameters: {gqa_params:,}")
    print(f"  Reduction: {(1 - gqa_params/mha_params)*100:.1f}%")

    # Compare in eval mode. Dropout is active in training mode (MHA uses the
    # default p=0.1) and would add random noise to the outputs, inflating the
    # measured difference.
    mha.eval()
    gqa.eval()
    with torch.no_grad():
        mha_out = mha(x)
        gqa_out, _ = gqa(x)

        # Not identical: each group of K/V heads was replaced by its average
        diff = (mha_out - gqa_out).abs().mean()
        print(f"\nOutput difference: {diff:.6f}")
        print(f"  (Non-zero expected due to weight averaging)")


def demo_kv_cache():
    """
    Walk through four steps of token-by-token decoding with
    CachedGroupedQueryAttention, printing how the K/V cache grows.
    """
    print("\n" + "="*80)
    print("DEMO 3: KV Caching with GQA")
    print("="*80)

    embed_dim, num_query_heads, num_kv_heads = 128, 8, 4

    gqa = CachedGroupedQueryAttention(embed_dim, num_query_heads, num_kv_heads)
    gqa.eval()

    print("\nSimulating token-by-token generation:")

    kv_cache = None  # nothing cached before the first token
    for step in range(1, 5):
        token = torch.randn(1, 1, embed_dim)
        _, kv_cache = gqa(token, past_kv=kv_cache, use_cache=True)

        print(f"\nStep {step}: Process token {step}")
        print(f"  Input shape: {token.shape}")
        if step == 1:
            print(f"  KV cache created: K shape {kv_cache[0].shape}, V shape {kv_cache[1].shape}")
        else:
            print(f"  KV cache grows: K shape {kv_cache[0].shape}")

    print(f"\nFinal cache contains {kv_cache[0].size(2)} tokens")
    print(f"Only computed new K, V for each new token (not recomputed past)")


def demo_performance_comparison():
    """
    Benchmark MHA against several GQA ratios. Prints forward latency relative
    to the first (MHA) configuration and the KV-cache size of a 32-layer model.
    """
    print("\n" + "="*80)
    print("DEMO 4: Performance Comparison")
    print("="*80)

    # (batch, seq_len, embed_dim, query_heads, kv_heads, label)
    configs = [
        (4, 512, 256, 16, 16, "MHA (baseline)"),
        (4, 512, 256, 16, 8, "GQA (2:1)"),
        (4, 512, 256, 16, 4, "GQA (4:1)"),
        (4, 512, 256, 16, 2, "GQA (8:1)"),
    ]

    print("\nSpeed Benchmark:")
    print(f"{'Config':<20} {'Time (ms)':<12} {'Speedup':<10}")
    print("-" * 45)

    reference_ms = None
    for batch, seq, dim, qh, kvh, label in configs:
        timings = AttentionBenchmark.benchmark_speed(
            batch, seq, dim, qh, kvh, num_iterations=50
        )
        # Equal head counts means plain MHA, so that row reports the MHA timing
        elapsed_ms = timings['gqa_time_ms'] if kvh < qh else timings['mha_time_ms']

        if reference_ms is None:
            reference_ms, ratio = elapsed_ms, 1.0
        else:
            ratio = reference_ms / elapsed_ms

        print(f"{label:<20} {elapsed_ms:>10.2f}  {ratio:>8.2f}x")

    print("\nKV Cache Size Comparison:")
    print(f"{'Config':<20} {'Cache (MB)':<12} {'Reduction':<12}")
    print("-" * 50)

    for _, seq, dim, qh, kvh, label in configs:
        sizes = AttentionBenchmark.benchmark_kv_cache_size(
            seq_len=seq, embed_dim=dim, num_heads=qh,
            num_kv_heads=kvh, num_layers=32
        )
        if kvh < qh:
            cache_mb, reduction = sizes['gqa_cache_mb'], f"{qh/kvh:.0f}x"
        else:
            cache_mb, reduction = sizes['mha_cache_mb'], "-"

        print(f"{label:<20} {cache_mb:>10.1f}  {reduction:>10}")


# ============================================================================
# MAIN EXECUTION
# ============================================================================

if __name__ == "__main__":
    print("\n" + "="*80)
    print("GROUPED QUERY ATTENTION: COMPLETE IMPLEMENTATION")
    print("="*80)

    # One global seed makes weights, inputs and benchmark tensors reproducible
    torch.manual_seed(42)

    # Demos run in a fixed order because they share the global RNG stream
    for _demo in (
        demo_basic_gqa,
        demo_mha_to_gqa_conversion,
        demo_kv_cache,
        demo_performance_comparison,
    ):
        _demo()

    print("\n" + "="*80)
    print("All demonstrations completed successfully!")
    print("="*80)


# ============================================================================
# ADDITIONAL UTILITIES
# ============================================================================

def visualize_grouping_pattern(
    num_query_heads: int,
    num_kv_heads: int,
    save_path: Optional[str] = None
) -> str:
    """
    Draw which query heads share which KV head and save the figure as a PNG.

    Query head ``i`` is connected to KV head ``i // group_size``. This is the
    contiguous grouping produced by the expand/reshape in GroupedQueryAttention.

    Args:
        num_query_heads: Number of query heads (top row).
        num_kv_heads: Number of key/value heads (bottom row); must divide
            ``num_query_heads``.
        save_path: Output file. Defaults to ``gqa_grouping.png`` in the
            system temporary directory (``/tmp`` on typical Linux systems).

    Returns:
        The path the figure was written to.

    Raises:
        ValueError: If a head count is not positive or they do not divide evenly.
    """
    import os
    import tempfile

    import matplotlib.pyplot as plt
    import matplotlib.patches as patches

    if num_query_heads <= 0 or num_kv_heads <= 0 or num_query_heads % num_kv_heads != 0:
        raise ValueError(
            "num_query_heads and num_kv_heads must be positive and "
            "num_query_heads must be divisible by num_kv_heads"
        )
    if save_path is None:
        save_path = os.path.join(tempfile.gettempdir(), 'gqa_grouping.png')

    group_size = num_query_heads // num_kv_heads

    def kv_center(j: int) -> float:
        # Midpoint of the query boxes in group j (query box i spans x in [i, i + 0.8])
        return j * group_size + (group_size - 1) / 2 + 0.4

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Query heads (top row), coloured by the KV group they belong to
        for i in range(num_query_heads):
            color = plt.cm.tab10((i // group_size) % 10)
            ax.add_patch(patches.Rectangle((i, 2), 0.8, 0.8,
                                           linewidth=2, edgecolor='black',
                                           facecolor=color, alpha=0.7))
            ax.text(i + 0.4, 2.4, f'Q{i}', ha='center', va='center', fontsize=8)

        # KV heads (bottom row), each centred under its group of query heads
        for j in range(num_kv_heads):
            color = plt.cm.tab10(j % 10)
            cx = kv_center(j)
            ax.add_patch(patches.Rectangle((cx - 0.4, 0.5), 0.8, 0.8,
                                           linewidth=2, edgecolor='black',
                                           facecolor=color, alpha=0.7))
            ax.text(cx, 0.9, f'KV{j}',
                    ha='center', va='center', fontsize=10, fontweight='bold')

        # Connect each query box's bottom edge (y=2) to its KV box's top edge (y=1.3)
        for i in range(num_query_heads):
            ax.plot([i + 0.4, kv_center(i // group_size)], [2, 1.3],
                    'k--', alpha=0.3, linewidth=1)

        ax.set_xlim(-0.5, num_query_heads + 0.5)
        ax.set_ylim(0, 3.5)
        ax.set_aspect('equal')
        ax.axis('off')

        # Centre the title over the query row, which spans [0, num_query_heads - 0.2]
        ax.text((num_query_heads - 0.2) / 2, 3.2,
                f'Grouped Query Attention: {num_query_heads} Query Heads, {num_kv_heads} KV Heads',
                ha='center', fontsize=14, fontweight='bold')

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Grouping visualization saved to {save_path}")
    finally:
        # Always release the figure, even if saving fails
        plt.close(fig)

    return save_path


if __name__ == "__main__":
    # Render the 8-query-head / 4-KV-head grouping example and save it as a PNG
    visualize_grouping_pattern(8, 4)

In [None]:
# GQA layer: 32 query heads share 8 KV heads (group size 4), so the K/V
# projections are 4x smaller than in a 32-head MHA (4x memory reduction).
gqa = GroupedQueryAttention(512, 32, 8)

# One forward pass over an input of shape [batch=16, seq=100, dim=512]
x = torch.randn(16, 100, 512)
output, _ = gqa(x)

With KV Caching

In [None]:
# Token-by-token decoding with a KV cache over `x` from the previous cell.
gqa = CachedGroupedQueryAttention(
    embed_dim=512,
    num_query_heads=32,
    num_kv_heads=8
)
gqa.eval()  # disable attention dropout for inference

# Derive the length from the input; `seq_len` is not defined at notebook scope
seq_len = x.size(1)

with torch.no_grad():
    # First token: nothing cached yet
    output, cache = gqa(x[:, 0:1], use_cache=True)

    # Later tokens: only the new token's K/V are projected and appended to the cache
    for i in range(1, seq_len):
        output, cache = gqa(x[:, i:i+1], past_kv=cache, use_cache=True)

Convert Existing MHA

In [None]:
# Convert a (notionally pretrained) 32-head MHA into GQA with 8 KV heads.
# Query/output projections are copied; each group of 4 K/V heads is averaged.
mha = MultiHeadAttention(embed_dim=512, num_heads=32)
gqa_from_mha = convert_mha_to_gqa(mha, num_kv_heads=8, strategy='mean')

assert gqa_from_mha.num_kv_heads == 8
assert gqa_from_mha.k_proj.out_features == 8 * (512 // 32)

K and V Expansion

In [None]:
# Expand the KV heads so every query head gets the K head of its group.
# Example sizes; any values with num_query_heads = num_kv_heads * group_size work.
batch, num_kv_heads, group_size, seq_len, head_dim = 2, 4, 2, 10, 16
num_query_heads = num_kv_heads * group_size
K = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# unsqueeze + expand is a zero-copy broadcast view; the final reshape merges
# the (num_kv_heads, group_size) axes and in general materialises a copy.
K_expanded = K.unsqueeze(2).expand(
    batch, num_kv_heads, group_size, seq_len, head_dim
).reshape(batch, num_query_heads, seq_len, head_dim)

# Same result as repeating each KV head group_size times along the head axis
assert torch.equal(K_expanded, K.repeat_interleave(group_size, dim=1))

Memory Layout

In [None]:
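# Store the KV cache with the *un-expanded* KV heads:
#   Shape: [batch, num_kv_heads, seq_len, head_dim]
#   Not:   [batch, num_query_heads, seq_len, head_dim]
# This shrinks the cache by a factor of num_query_heads / num_kv_heads;
# the heads are expanded only temporarily when attention is computed.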

Group Assignment

In [None]:
# Query head i uses KV head j = i // group_size (heads are grouped contiguously)
num_query_heads, num_kv_heads = 8, 4
group_size = num_query_heads // num_kv_heads
kv_head_for_query = {i: i // group_size for i in range(num_query_heads)}

# Example: 8 query heads, 4 KV heads
# Q0, Q1 → KV0
# Q2, Q3 → KV1
# Q4, Q5 → KV2
# Q6, Q7 → KV3
assert kv_head_for_query == {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2, 6: 3, 7: 3}
kv_head_for_query