In [14]:
import torch
import torch.nn as nn
import math
import os
from thop import profile
from contextlib import redirect_stdout

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout, bias, device):
        """
        多头注意力机制的实现。 
        Args:
        hidden_size (int): 输入特征的维度，也即 hidden_state 的最后一维。
        num_heads (int): 注意力头的数量。
        dropout (float): dropout 的概率，默认为 0.0。 
        """
        super(MultiHeadAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size 必须能被 num_heads 整除"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads # 每个头的维度
        
        # 定义线性变换层，用于生成Q、K、V
        self.query = nn.Linear(hidden_size, hidden_size, bias, device)
        self.key = nn.Linear(hidden_size, hidden_size, bias, device)
        self.value = nn.Linear(hidden_size, hidden_size, bias, device)
        
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(hidden_size, hidden_size, bias, device)
        self.out_projection = nn.Linear(hidden_size, hidden_size, bias, device)
        
    def forward(self, hidden_state, attention_mask):
        """
        前向传播函数。
        Args:
            hidden_state (torch.Tensor): 输入的 hidden_state，形状为 [batch_size, seq_len, hidden_size]。
            attention_mask (torch.Tensor, optional): 注意力掩码，用于屏蔽某些位置，形状为 [batch_size, seq_len]。默认为 None。
        Returns:
             torch.Tensor: 注意力输出，形状为 [batch_size, seq_len, hidden_size]。
        """
        batch_size, seq_len, _ = hidden_state.size()
        
        # 1. 通过线性层得到 Q, K, V
        query = self.query(hidden_state) # [batch_size, seq_len, hidden_size]
        key = self.key(hidden_state) # [batch_size, seq_len, hidden_size]
        value = self.value(hidden_state) # [batch_size, seq_len, hidden_size]
        
        # 2. 将 Q, K, V 拆分成多头
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [batch_size, num_heads, seq_len, head_dim]
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [batch_size, num_heads, seq_len, head_dim]
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # [batch_size, num_heads, seq_len, head_dim]
        
        # 3. 计算注意力权重
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5) # [batch_size, num_heads， seq_len, seq_len]
        
        # 应用 attention mask
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf')) # attention_mask[:, None, None, :] 将掩码从 [batch_size, seq_len] 扩展为 [batch_size, 1, 1, seq_len]
        attention_weights = torch.softmax(attention_weights, dim=-1)  # [batch_size, num_heads, seq_len, seq_len]
        attention_weights = self.dropout(attention_weights)
        
        # 4. 计算上下文向量
        context = torch.matmul(attention_weights, value)  # [batch_size, num_heads, seq_len, head_dim]
        
        # 5. 将多头合并
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)  # [batch_size, seq_len, hidden_size]，contiguous()确保内存布局是连续的，为后续的view操作做准备
        
        # 6. 通过输出线性层
        output = self.out_projection(context)  # [batch_size, seq_len, hidden_size]
        return output

In [19]:
class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, group_size, dropout, bias, device):
        """
        Grouped Query Attention 实现。
        Args:
            hidden_size (int): 输入特征的维度
            num_heads (int): 查询头的数量。
            group_size (int): 每个组中包含的查询头数量。
            dropout (float): dropout 的概率。
        """
        super(GroupedQueryAttention, self).__init__()
        
        assert hidden_size % num_heads == 0, "hidden_size 必须能被 num_heads 整除"
        assert num_heads % group_size == 0, "num_heads 必须能被 group_size 整除"
        
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.group_size = group_size
        self.group_num = num_heads // group_size
        self.head_dim = hidden_size // num_heads
        
        # 查询头
        self.query = nn.Linear(hidden_size, hidden_size, bias, device)
        
        # 键和值头（分组共享）
        self.key = nn.Linear(hidden_size, self.group_num * self.head_dim, bias, device)
        self.value = nn.Linear(hidden_size, self.group_num * self.head_dim, bias, device)
        
        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size, bias, device)
        
    def forward(self, hidden_state, attention_mask=None):
        """
        前向传播函数。
        Args:
            hidden_state (torch.Tensor): 输入张量，形状为 [batch_size, seq_len, hidden_size]。
            attention_mask (torch.Tensor, optional): 注意力掩码，形状为 [batch_size, seq_len]。
        Returns:
            torch.Tensor: 注意力输出，形状为 [batch_size, seq_len, hidden_size]。
        """
        batch_size, seq_len, _ = hidden_state.size()
        
        # 1. 通过线性层得到 Q, K, V
        query = self.query(hidden_state) # [batch_size, seq_len, hidden_size]
        key = self.key(hidden_state) # [batch_size, seq_len, group_num * head_dim]
        value = self.value(hidden_state) # [batch_size, seq_len, group_num * head_dim]
        
        # 2. 将 Q, K, V 拆分成多头
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [batch_size, num_heads, seq_len, head_dim]
        
        # 3. K 和 V 扩展到 num_heads 个头
        key = key.view(batch_size, seq_len, self.group_num, self.head_dim).transpose(1, 2) # [batch_size, group_num, seq_len, head_dim]
        key = key.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(batch_size, -1, seq_len, self.head_dim) # [batch_size, num_heads, seq_len, head_dim]
        value = value.view(batch_size, seq_len, self.group_num, self.head_dim).transpose(1, 2) # [batch_size, group_num, seq_len, head_dim]
        value = value.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(batch_size, -1, seq_len, self.head_dim) # [batch_size, num_heads, seq_len, head_dim]
        
        # 4. 计算注意力权重
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5) # [batch_size, num_heads, seq_len, seq_len]
        
        # 5. attention mask
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
            
        attention_weights = torch.softmax(attention_weights, dim=-1) # [batch_size, num_heads, seq_len, seq_len]
        
        attention_weights = self.dropout(attention_weights)
        
        # 6. 计算上下文向量
        context = torch.matmul(attention_weights, value) # [batch_size, num_heads, seq_len, head_dim]
        
        # 7. 将多头合并
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size) # [batch_size, seq_len, hidden_size]
         
        # 8. 通过输出线性层
        output = self.out_projection(context) # [batch_size, seq_len, hidden_size]
        
        return output   

In [7]:
class RotaryEmbedding(nn.Module):
    def __init__(self, hidden_size, num_heads, base, max_len):
        """
        RoPE位置编码模块。
        Args:
            hidden_size (int): 模型维度
            num_heads (int): 注意力头数量
            base (int): 频率基值
            max_len (int): 最大序列长度
        """
        super(RotaryEmbedding, self).__init__()
        
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.base = base
        self.max_len = max_len
        self.head_dim = hidden_size // num_heads
        self.cos_pos_cache, self.sin_pos_cache = self._compute_pos_emb()
    
    def _compute_pos_emb(self):
        """
        计算位置编码的余弦和正弦值。
        Returns:
            cos_pos (Tensor): 余弦位置编码
            sin_pos (Tensor): 正弦位置编码
        """
        theta_i = 1. / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        positions = torch.arange(self.max_len)
        pos_emb = positions.unsqueeze(1) * theta_i.unsqueeze(0)
        
        cos_pos = pos_emb.sin().repeat_interleave(2, dim=-1)
        sin_pos = pos_emb.cos().repeat_interleave(2, dim=-1)
        
        return cos_pos, sin_pos
     
    def forward(self, q):
        """
        RoPE位置编码应用。
        Args:
            q (torch.Tensor): 输入张量 [bs, num_heads, seq_len, head_dim]
        Returns:
            torch.Tensor: 应用位置编码后的张量
        """
        bs, seq_len = q.shape[0], q.shape[2]
        cos_pos = self.cos_pos_cache[:seq_len].to(q.device) # [seq_len, head_dim]
        sin_pos = self.sin_pos_cache[:seq_len].to(q.device) # [seq_len, head_dim]
        
        # 扩展维度以匹配batch和head维度
        cos_pos = cos_pos.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, head_dim]
        sin_pos = sin_pos.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, head_dim]
        
        # RoPE变换
        q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1) # 奇偶交替
        q2 = q2.reshape(q.shape).contiguous()
        
        return q * cos_pos + q2 * sin_pos
    

class MultiHeadLatentAttention(nn.Module):
    def __init__(self, hidden_size, down_dim, up_dim, num_heads, rope_head_dim, base, max_len, dropout, bias, device):      
        """
        Multi-Head Latent Attention 实现。
        Args:
            hidden_size (int): 输入特征维度。
            down_dim (int): 降维后的维度。
            up_dim (int): 升维后的维度。
            num_heads (int): 注意力头数量。
            rope_head_dim (int): RoPE编码的头维度。
            dropout (float): ddropout概率。
            bias (bool): 是否使用偏置。
            device (str): 设备类型（'cpu'或'cuda'）。
        """
        super(MultiHeadLatentAttention, self).__init__()
        
        self.hidden_size = hidden_size
        self.down_dim = down_dim
        self.up_dim = up_dim
        self.num_heads = num_heads
        self.rope_head_dim = rope_head_dim
        self.head_dim = hidden_size // num_heads
        self.v_head_dim = up_dim // num_heads
        
        # 降维投影
        self.down_proj_kv = nn.Linear(hidden_size, down_dim, bias, device)
        self.down_proj_q = nn.Linear(hidden_size, down_dim, bias, device)
        
        # 升维投影
        self.up_proj_k = nn.Linear(down_dim, up_dim, bias, device)
        self.up_proj_v = nn.Linear(down_dim, up_dim, bias, device)
        self.up_proj_q = nn.Linear(down_dim, up_dim, bias, device)
        
        # 解耦Q/K投影
        self.proj_qr = nn.Linear(down_dim, rope_head_dim * num_heads, bias, device)
        self.proj_kr = nn.Linear(hidden_size, rope_head_dim, bias, device)
        
        # RoPE位置编码
        self.rope_q = RotaryEmbedding(rope_head_dim * num_heads, num_heads, base, max_len)
        self.rope_k = RotaryEmbedding(rope_head_dim, 1, base, max_len)
        
        # 输出层
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(num_heads * self.v_head_dim, hidden_size, bias, device)
        self.res_dropout = nn.Dropout(dropout)
        
    def forward(self, hidden_state, mask=None):
        """
        前向传播。
        Args:
            h (torch.Tensor): 输入张量 [batch_size, seq_len, hidden_size]
            mask (torch.Tensor): 注意力掩码 [batch_size, seq_len]
        Returns:
            torch.Tensor: 输出张量 [bs, seq_len, d_model]
        """
        batch_size, seq_len, _ = hidden_state.size()
        
        # 1. 低秩转换
        c_t_kv = self.down_proj_kv(hidden_state) # [batch_size, seq_len, down_dim]
        k_t_c = self.up_proj_k(c_t_kv) # [batch_size, seq_len, up_dim]
        v_t_c = self.up_proj_v(c_t_kv) # [batch_size, seq_len, up_dim]
        c_t_q = self.down_proj_q(hidden_state) # [batch_size, seq_len, down_dim]
        q_t_c = self.up_proj_q(c_t_q) # [batch_size, seq_len, up_dim]
        
        # 2. 解耦Q/K处理
        # RoPE投影处理
        q_t_r = self.proj_qr(c_t_q) # [batch_size, seq_len, rope_head_dim * num_heads]
        q_t_r = q_t_r.view(batch_size, seq_len, self.num_heads, self.rope_head_dim).transpose(1, 2) # [batch_size, num_heads, seq_len, rope_head_dim]
        q_t_r = self.rope_q(q_t_r) # 应用RoPE编码
        
        k_t_r = self.proj_kr(hidden_state) # [batch_size, seq_len, rope_head_dim]
        k_t_r = k_t_r.unsqueeze(1) # [batch_size, 1, seq_len, rope_head_dim]
        k_t_r = self.rope_k(k_t_r) # 应用RoPE编码
        
        # 3. 注意力计算
        # Q/K/V维度调整
        q_t_c = q_t_c.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) # [batch_size, num_heads, seq_len, v_head_dim]
        q = torch.cat([q_t_c, q_t_r], dim=-1) # [batch_size, num_heads, seq_len, (up_dim+rope_head_dim)/num_heads]
        
        k_t_c = k_t_c.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) # [batch_size, num_heads, seq_len, v_head_dim]
        k_t_r = k_t_r.expand(batch_size, self.num_heads, seq_len, -1) # [batch_size, num_heads, seq_len, rope_head_dim]
        k = torch.cat([k_t_c, k_t_r], dim=-1) # [batch_size, num_heads, seq_len, (up_dim+rope_head_dim)/num_heads]
        
        # 4. 计算注意力权重
        scores = torch.matmul(q, k.transpose(-1, -2)) # [batch_size, num_heads, seq_len, seq_len]
        scores = scores / (math.sqrt(self.head_dim) + math.sqrt(self.rope_head_dim))
        
        # 5. attention mask
        if attention_mask is not None:
            scores = scores.masked_fill(mask[:, None, None, :] == 0, float('-inf')) # [batch_size, num_heads, seq_len, seq_len]
        
        attention_weights = torch.softmax(scores, dim=-1) # [batch_size, num_heads, seq_len, seq_len]
        attention_weights = self.dropout(attention_weights)
        
        # 6. V维度调整
        v_t_c = v_t_c.view(batch_size, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2) # [batch_size, num_heads, seq_len, v_head_dim]
        
        # 7. 计算上下文向量
        context = torch.matmul(attention_weights, v_t_c) # [batch_size, num_heads, seq_len, v_head_dim]
        
        # 8. 合并多头
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) # [batch_size, seq_len, num_heads * v_head_dim]
        
        # 9. 输出投影
        output = self.fc(context) # [batch_size, seq_len, hidden_size]
        output = self.res_dropout(output)
            
        return output

In [12]:
def count_params_and_flops(module: nn.Module, input_shape: tuple, attention_mask: bool = False, device: str = 'cuda'):
    """
    统计指定模型模块的参数量和计算量(FLOPs)。
    Args:
        module: PyTorch 模块对象。
        input_shape: 输入张量的形状 (元组形式, 不包含 batch 维度)。
    Returns:
        params_total: 总参数量。
        flops_total: 总计算量。
    """
    # 构造示例输入
    dummy_input = torch.randn(2, *input_shape, device=device) # 添加 batch 维度
    
    # 计算参数量（单位：k）
    params_total = sum(p.numel() for p in module.parameters())
    
    # 计算计算量（单位：GFLOPs）
    with redirect_stdout(open(os.devnull, "w")):
        flops_total, _ = profile(module, inputs=(dummy_input, attention_mask))
        
    return params_total, flops_total

In [None]:
if __name__ == '__main__':
    # 示例
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8
    group_size = 2
    base=10000
    max_len=512
    down_dim=64
    up_dim=128
    rope_head_dim=26
    dropout = 0.1
    bias = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # 创建一个随机的 hidden_state
    hidden_state = torch.randn(batch_size, seq_len, hidden_size, device=device)
    
    # 创建一个 attention mask (可选)
    attention_mask = torch.ones(batch_size, seq_len, device=device)
    attention_mask[:, 5:] = 0
    
    print("==" * 5, " Attention  Test ", "==" * 5)
    
    # 创建一个 MHA 实例
    mha = MultiHeadAttention(hidden_size, num_heads, dropout, bias, device)
    
    # 通过 MHA 层
    output = mha(hidden_state, attention_mask)
    
    # 打印输出形状
    print("MHA Output Shape:", output.shape)
    
    # 统计参数量和计算量
    mha_params, mha_flops = count_params_and_flops(mha, (seq_len, hidden_size), attention_mask, device)
    print(f"MHA Params: {mha_params}, FLOPs: {mha_flops}")
    
    print("===" * 13)

MHA Output Shape: torch.Size([2, 10, 256])
MHA Params: 327680, FLOPs: 5242880.0


In [20]:
if __name__ == '__main__':
    # 示例
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8
    group_size = 2
    dropout = 0.1
    bias = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # 创建一个随机的 hidden_state
    hidden_state = torch.randn(batch_size, seq_len, hidden_size, device=device)
    
    # 创建一个 attention mask (可选)
    attention_mask = torch.ones(batch_size, seq_len, device=device)
    attention_mask[:, 5:] = 0
    
    print("==" * 5, " Attention  Test ", "==" * 5)
    
    # 创建一个 GQA 实例
    gqa = GroupedQueryAttention(hidden_size, num_heads, group_size, dropout, bias, device)
    
    # 通过 GQA 层
    output = gqa(hidden_state, attention_mask)
    
    # 打印输出形状
    print("GQA Output Shape:", output.shape)
    
    # 统计参数量和计算量
    gqa_params, gqa_flops = count_params_and_flops(gqa, (seq_len, hidden_size), attention_mask, device)
    print(f"GQA Params: {gqa_params}, FLOPs: {gqa_flops}")
    
    print("===" * 13)

GQA Output Shape: torch.Size([2, 10, 256])
GQA Params: 196608, FLOPs: 3932160.0


In [None]:
if __name__ == '__main__':
    # 示例
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8
    base=10000
    max_len=512
    down_dim=64
    up_dim=128
    rope_head_dim=26
    dropout = 0.1
    bias = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # 创建一个随机的 hidden_state
    hidden_state = torch.randn(batch_size, seq_len, hidden_size, device=device)
    
    # 创建一个 attention mask (可选)
    attention_mask = torch.ones(batch_size, seq_len, device=device)
    attention_mask[:, 5:] = 0
    
    print("==" * 5, " Attention  Test ", "==" * 5)
    
    # 创建一个 MLA 实例
    mla = MultiHeadLatentAttention(hidden_size, down_dim, up_dim, num_heads, rope_head_dim, base, max_len, dropout, bias, device)
    
    # 通过 MLA 层
    output = mla(hidden_state, attention_mask)
    
    # 打印输出形状
    print("MLA Output Shape:", output.shape)
    
    # 统计参数量和计算量
    mla_params, mla_flops = count_params_and_flops(mla, (seq_len, hidden_size), attention_mask, device)
    print(f"MlA Params: {mla_params}, FLOPs: {mla_flops}")
    
    print("===" * 13)

MLA Output Shape: torch.Size([2, 10, 256])
MlA Params: 110080, FLOPs: 2201600.0
