实现多头注意力机制（MHA）、多查询注意力机制（MQA）、分组查询注意力机制（GQA）和内存高效注意力机制（MLA），并对比它们的KV Cache内存占用

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        assert self.head_dim * num_heads == d_model, "d_model must be divisible by num_heads"
        
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask=None):
        batch_size, seq_len, d_model = q.size()
        
        # Linear projections
        q = self.wq(q).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.wk(k).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.wv(v).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Scaled Dot-Product Attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, v)
        
        # Concatenate heads
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        
        # Final linear projection
        output = self.wo(output)
        
        return output

In [4]:
class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiQueryAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        assert self.head_dim * num_heads == d_model, "d_model must be divisible by num_heads"
        
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, self.head_dim)
        self.wv = nn.Linear(d_model, self.head_dim)
        self.wo = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask=None):
        batch_size, seq_len, d_model = q.size()
        
        # Linear projections
        q = self.wq(q).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.wk(k).view(batch_size, seq_len, self.head_dim).unsqueeze(1)
        v = self.wv(v).view(batch_size, seq_len, self.head_dim).unsqueeze(1)
        
        # Scaled Dot-Product Attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, v)
        
        # Concatenate heads
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        
        # Final linear projection
        output = self.wo(output)
        
        return output

In [5]:
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_groups):
        super(GroupedQueryAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = d_model // num_heads
        
        assert self.head_dim * num_heads == d_model, "d_model must be divisible by num_heads"
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
        
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model // num_groups)
        self.wv = nn.Linear(d_model, d_model // num_groups)
        self.wo = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask=None):
        batch_size, seq_len, d_model = q.size()
        
        # Linear projections
        q = self.wq(q).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.wk(k).view(batch_size, seq_len, self.num_groups, self.head_dim).unsqueeze(2)
        v = self.wv(v).view(batch_size, seq_len, self.num_groups, self.head_dim).unsqueeze(2)
        
        # Scaled Dot-Product Attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, v)
        
        # Concatenate heads
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        
        # Final linear projection
        output = self.wo(output)
        
        return output

In [6]:
def compare_kv_cache_memory(d_model, num_heads, seq_len, batch_size):
    # Initialize models
    mha = MultiHeadAttention(d_model, num_heads)
    mqa = MultiQueryAttention(d_model, num_heads)
    gqa = GroupedQueryAttention(d_model, num_heads, num_groups=2)
    
    # Generate random input
    q = torch.randn(batch_size, seq_len, d_model)
    k = torch.randn(batch_size, seq_len, d_model)
    v = torch.randn(batch_size, seq_len, d_model)
    
    # Calculate KV Cache memory for MHA
    mha_kv_cache = (k.size(1) * k.size(2) + v.size(1) * v.size(2)) * batch_size * 4  # 4 bytes per float32
    print(f"MHA KV Cache Memory: {mha_kv_cache / 1e6:.2f} MB")
    
    # Calculate KV Cache memory for MQA
    mqa_kv_cache = (k.size(1) * (k.size(2) // num_heads) + v.size(1) * (v.size(2) // num_heads)) * batch_size * 4
    print(f"MQA KV Cache Memory: {mqa_kv_cache / 1e6:.2f} MB")
    
    # Calculate KV Cache memory for GQA
    gqa_kv_cache = (k.size(1) * (k.size(2) // 2) + v.size(1) * (v.size(2) // 2)) * batch_size * 4
    print(f"GQA KV Cache Memory: {gqa_kv_cache / 1e6:.2f} MB")
    


# Example usage
d_model = 512
num_heads = 8
seq_len = 1024
batch_size = 32

compare_kv_cache_memory(d_model, num_heads, seq_len, batch_size)

MHA KV Cache Memory: 134.22 MB
MQA KV Cache Memory: 16.78 MB
GQA KV Cache Memory: 67.11 MB
