# MHA (Multi-Head Attention)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with an optional key-padding mask.

    Queries, keys and values come from one fused linear layer (``m_attn``);
    the concatenated head outputs go through ``m_proj``. Attention is
    bidirectional (no causal mask).

    Args:
        embed_dim: Size of the last dimension of the input; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability used on the attention weights and on the
            projected output.
        bias: Whether the two linear layers carry a bias term.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.2, bias: bool = False):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.dropout = dropout

        # Fused Q/K/V projection followed by the output projection.
        self.m_attn = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.m_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.atten_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # F.scaled_dot_product_attention only exists in PyTorch >= 2.0.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``(B, T, embed_dim)``.
            attention_mask: Optional ``(B, T)`` tensor of any dtype; entries
                equal to 0 mark key positions that must not be attended to,
                every other entry marks a visible position.

        Returns:
            Tensor of shape ``(B, T, embed_dim)``.
        """
        B, T, C = x.shape
        qkv = self.m_attn(x)
        # (B, T, 3 * C) -> (3, B, num_heads, T, head_dim)
        qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Additive bias in the query dtype: 0 where visible, -inf where masked.
        # Building it from a boolean keeps int/bool/float masks all working.
        attn_bias = None
        if attention_mask is not None:
            blocked = (attention_mask == 0)[:, None, None, :]  # (B, 1, 1, T)
            attn_bias = q.new_zeros(blocked.shape).masked_fill(blocked, float("-inf"))

        if self.flash:
            # SDPA selects a fused kernel itself when one is eligible and,
            # unlike a direct flash_attn_func call, honours the padding mask.
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_bias,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=False,
            )
        else:
            # Reference implementation: softmax(Q K^T / sqrt(head_dim) + bias) V
            att = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
            if attn_bias is not None:
                att = att + attn_bias
            att = self.atten_dropout(F.softmax(att, dim=-1))
            y = att @ v

        # (B, num_heads, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.m_proj(y))
        return y
    
if __name__ == '__main__':
    # Smoke test: push one random batch through MultiHeadAttention and print
    # the shape of the result.
    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'

    batch, seq_len, width = 2, 5, 16
    x = torch.randn(batch, seq_len, width).to(device)

    # 1 marks a real token, 0 marks padding; the second sequence has two
    # padded positions at the end.
    keep_rows = [
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
    ]
    attention_mask = torch.tensor(keep_rows, dtype=torch.float32).to(device)  # (batch, seq_len)

    mha = MultiHeadAttention(embed_dim=width, num_heads=4, dropout=0.1).to(device)
    output = mha(x, attention_mask)
    print("输出形状:", output.shape)

输出形状: torch.Size([2, 5, 16])


# MQA (Multi-Query Attention)


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiQueryAttention(nn.Module):
    """Multi-query self-attention: many query heads share one key/value head.

    ``q_proj`` produces ``num_heads`` query heads, while ``k_proj`` and
    ``v_proj`` each produce a single head of width ``head_dim`` that every
    query head attends over. Attention is bidirectional (no causal mask).

    Args:
        embed_dim: Size of the last dimension of the input; must be divisible
            by ``num_heads``.
        num_heads: Number of query heads.
        dropout: Dropout probability used on the attention weights and on the
            projected output.
        bias: Whether the linear layers carry a bias term.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.2, bias: bool = False):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.dropout = dropout

        # MQA: a full-width query projection, single-head key and value.
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)      # multi-head Q
        self.k_proj = nn.Linear(embed_dim, self.head_dim, bias=bias)  # single-head K
        self.v_proj = nn.Linear(embed_dim, self.head_dim, bias=bias)  # single-head V
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.atten_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # F.scaled_dot_product_attention only exists in PyTorch >= 2.0.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        """Apply multi-query self-attention to ``x``.

        Args:
            x: Input of shape ``(B, T, embed_dim)``.
            attention_mask: Optional ``(B, T)`` tensor of any dtype; entries
                equal to 0 mark key positions that must not be attended to,
                every other entry marks a visible position.

        Returns:
            Tensor of shape ``(B, T, embed_dim)``.
        """
        B, T, C = x.shape

        # (B, T, embed_dim) -> (B, num_heads, T, head_dim)
        q = self.q_proj(x).reshape(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # Share the single K/V head across all query heads. expand() creates a
        # view, so nothing is copied, and every tensor handed to the attention
        # kernel has the same explicit head count.
        k = self.k_proj(x).unsqueeze(1).expand(B, self.num_heads, T, self.head_dim)
        v = self.v_proj(x).unsqueeze(1).expand(B, self.num_heads, T, self.head_dim)

        # Additive bias in the query dtype: 0 where visible, -inf where masked.
        # Building it from a boolean keeps int/bool/float masks all working.
        attn_bias = None
        if attention_mask is not None:
            blocked = (attention_mask == 0)[:, None, None, :]  # (B, 1, 1, T)
            attn_bias = q.new_zeros(blocked.shape).masked_fill(blocked, float("-inf"))

        if self.flash:
            # SDPA selects a fused kernel itself when one is eligible and,
            # unlike a direct flash_attn_func call, honours the padding mask.
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_bias,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=False,
            )
        else:
            # Reference implementation: softmax(Q K^T / sqrt(head_dim) + bias) V
            att = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
            if attn_bias is not None:
                att = att + attn_bias
            att = self.atten_dropout(F.softmax(att, dim=-1))
            y = att @ v  # (B, num_heads, T, head_dim)

        # Merge the heads back and project.
        y = y.transpose(1, 2).contiguous().view(B, T, self.embed_dim)
        y = self.resid_dropout(self.out_proj(y))
        return y

if __name__ == '__main__':
    # Smoke test: push one random batch through MultiQueryAttention and print
    # the shape of the result.
    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'

    batch, seq_len, width = 2, 5, 16
    x = torch.randn(batch, seq_len, width).to(device)

    # 1 marks a real token, 0 marks padding; the second sequence has two
    # padded positions at the end.
    keep_rows = [
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
    ]
    attention_mask = torch.tensor(keep_rows, dtype=torch.float32).to(device)  # (batch, seq_len)

    mqa = MultiQueryAttention(embed_dim=width, num_heads=4, dropout=0.1).to(device)
    output = mqa(x, attention_mask)
    print("输出形状:", output.shape)

输出形状: torch.Size([2, 5, 16])


# GQA (Grouped-Query Attention)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """Grouped-query self-attention.

    ``num_heads`` query heads are split into ``num_kv_heads`` groups of
    ``group_size`` heads each; all query heads in a group attend over the same
    key/value head. Attention is bidirectional (no causal mask).

    Args:
        embed_dim: Size of the last dimension of the input; must be divisible
            by ``num_heads``.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads; must divide ``num_heads``.
        dropout: Dropout probability used on the attention weights and on the
            projected output.
        bias: Whether the linear layers carry a bias term.
    """

    def __init__(self, embed_dim: int, num_heads: int, num_kv_heads: int, dropout: float = 0.1, bias: bool = False):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_heads
        self.group_size = num_heads // num_kv_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)                          # -> (B, T, H * D)
        self.kv_proj = nn.Linear(embed_dim, 2 * self.head_dim * num_kv_heads, bias=bias)  # -> (B, T, 2 * h_kv * D)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.atten_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.dropout = dropout
        # F.scaled_dot_product_attention only exists in PyTorch >= 2.0.
        self.flash = hasattr(F, 'scaled_dot_product_attention')

    def forward(self, x, attention_mask=None):
        """Apply grouped-query self-attention to ``x``.

        Args:
            x: Input of shape ``(B, T, embed_dim)``.
            attention_mask: Optional ``(B, T)`` tensor of any dtype; entries
                equal to 0 mark key positions that must not be attended to,
                every other entry marks a visible position.

        Returns:
            Tensor of shape ``(B, T, embed_dim)``.
        """
        B, T, C = x.shape

        q = self.q_proj(x).reshape(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # (B, nh, T, hd)
        # (B, T, 2 * n_kv * hd) -> (2, B, n_kv, T, hd)
        kv = self.kv_proj(x).reshape(B, T, self.num_kv_heads, 2, self.head_dim).permute(3, 0, 2, 1, 4)
        k, v = kv[0], kv[1]  # each (B, n_kv, T, hd)

        # Repeat every KV head so each query head in a group sees its own copy.
        k = k.repeat_interleave(self.group_size, dim=1)  # (B, nh, T, hd)
        v = v.repeat_interleave(self.group_size, dim=1)

        # Additive bias in the query dtype: 0 where visible, -inf where masked.
        # Building it from a boolean keeps int/bool/float masks all working.
        attn_bias = None
        if attention_mask is not None:
            blocked = (attention_mask == 0)[:, None, None, :]  # (B, 1, 1, T)
            attn_bias = q.new_zeros(blocked.shape).masked_fill(blocked, float("-inf"))

        if self.flash:
            # SDPA selects a fused kernel itself when one is eligible and,
            # unlike a direct flash_attn_func call, honours the padding mask.
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_bias,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=False,
            )
        else:
            # Reference implementation: softmax(Q K^T / sqrt(head_dim) + bias) V
            att = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
            if attn_bias is not None:
                att = att + attn_bias
            att = self.atten_dropout(F.softmax(att, dim=-1))
            y = att @ v

        # (B, nh, T, hd) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.out_proj(y))
        return y
    
if __name__ == '__main__':
    # Smoke test: push one random batch through GroupedQueryAttention
    # (4 query heads sharing 2 key/value heads) and print the result shape.
    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'

    batch, seq_len, width = 2, 5, 16
    num_heads, num_kv_heads = 4, 2
    x = torch.randn(batch, seq_len, width).to(device)

    # 1 marks a real token, 0 marks padding; the second sequence has two
    # padded positions at the end.
    keep_rows = [
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
    ]
    attention_mask = torch.tensor(keep_rows, dtype=torch.float32).to(device)  # (batch, seq_len)

    gqa = GroupedQueryAttention(
        embed_dim=width,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        dropout=0.1,
    ).to(device)
    output = gqa(x, attention_mask)
    print("输出形状:", output.shape)

输出形状: torch.Size([2, 5, 16])


# Sparse Attention


In [None]:
class SparseAttention(nn.Module):
    """Local (sliding-window) self-attention built on MultiHeadAttention's layers.

    Query ``i`` may attend to key ``j`` only when ``|i - j| <= window_size``.
    An optional padding mask removes further key positions.

    The wrapped module's ``forward`` takes a per-key ``(B, T)`` mask, which
    cannot express a per-query window. The attention is therefore computed
    here, reusing the wrapped module's projections (``m_attn``, ``m_proj``)
    and dropout layers.

    Args:
        d_model: Model width passed to ``MultiHeadAttention``.
        num_heads: Number of attention heads.
        window_size: Half-width of the attention window, in tokens.
    """

    def __init__(self, d_model=512, num_heads=8, window_size=256):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.window_size = window_size

    def forward(self, x, mask=None):
        """Apply windowed self-attention to ``x``.

        Args:
            x: Input of shape ``(B, N, d_model)``.
            mask: Optional ``(B, N)`` tensor; entries equal to 0 mark key
                positions (e.g. padding) that must not be attended to.

        Returns:
            Tensor of shape ``(B, N, d_model)``. A query with no permitted key
            at all gets an all-zero attention row instead of NaN.
        """
        B, N, C = x.shape
        mha = self.mha

        # Band mask, True where attention is permitted: |i - j| <= window_size.
        pos = torch.arange(N, device=x.device)
        allowed = (pos[None, :] - pos[:, None]).abs() <= self.window_size  # (N, N)
        allowed = allowed[None, None, :, :]                                # (1, 1, N, N)
        if mask is not None:
            # The padding mask applies along the key axis and is broadcast
            # over heads and queries.
            allowed = allowed & (mask != 0)[:, None, None, :]              # (B, 1, N, N)

        # (B, N, 3 * C) -> (3, B, num_heads, N, head_dim)
        qkv = mha.m_attn(x).reshape(B, N, 3, mha.num_heads, mha.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = (q @ k.transpose(-2, -1)) * (mha.head_dim ** -0.5)
        scores = scores.masked_fill(~allowed, float("-inf"))
        att = torch.softmax(scores, dim=-1)
        # Softmax over a row that is entirely -inf yields NaN; zero every
        # disallowed entry so such rows contribute nothing.
        att = att.masked_fill(~allowed, 0.0)
        att = mha.atten_dropout(att)

        y = (att @ v).transpose(1, 2).reshape(B, N, C)
        return mha.resid_dropout(mha.m_proj(y))

# MLA (Multi-Head Latent Attention)

In [3]:
import torch
import torch.nn as nn
import math

class MultiHeadLatentAttention(nn.Module):
    """Self-attention whose scores are computed in a compressed latent space.

    The input is down-projected into a query latent and a key/value latent.
    Per-head queries are expanded from the query latent and scored directly
    against the shared key/value latent (there is no per-head key
    projection). Per-head values are up-projected from that same latent.

    Args:
        d_model: Input and output width.
        num_heads: Number of attention heads; ``head_dim`` is
            ``d_model // num_heads``.
        q_latent_dim: Width of the compressed query representation.
        kv_latent_dim: Width of the compressed key/value representation.
    """

    def __init__(self, d_model=4096, num_heads=32, q_latent_dim=512, kv_latent_dim=128):
        super().__init__()
        self.d_model, self.num_heads = d_model, num_heads
        self.head_dim = d_model // num_heads
        self.q_latent_dim, self.kv_latent_dim = q_latent_dim, kv_latent_dim

        # Step 1: compress the input into low-dimensional latents.
        self.Wq_d = nn.Linear(d_model, q_latent_dim)    # query latent
        self.Wkv_d = nn.Linear(d_model, kv_latent_dim)  # key/value latent, shared by all heads

        # Step 2: layers used for attention in the latent space.
        self.W_qk = nn.Linear(q_latent_dim, num_heads * kv_latent_dim)   # per-head queries
        self.Wv_u = nn.Linear(kv_latent_dim, num_heads * self.head_dim)  # per-head values

        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, N, d_model)``; returns the same shape."""
        batch, seq_len, _ = x.shape
        heads = self.num_heads

        latent_q = self.Wq_d(x)    # (B, N, q_latent_dim)
        latent_kv = self.Wkv_d(x)  # (B, N, kv_latent_dim), the compressed K/V

        # Per-head queries living in the key/value latent space.
        head_queries = self.W_qk(latent_q).view(batch, seq_len, heads, self.kv_latent_dim)
        head_queries = head_queries.transpose(1, 2)              # (B, H, N, kv_latent_dim)
        # The key/value latent acts as the keys for every head.
        shared_keys_t = latent_kv.transpose(-2, -1).unsqueeze(1)  # (B, 1, kv_latent_dim, N)

        scores = torch.matmul(head_queries, shared_keys_t) / math.sqrt(self.kv_latent_dim)
        weights = scores.softmax(dim=-1)                          # (B, H, N, N)

        # Up-project the values, mix them with the weights, then merge heads.
        head_values = self.Wv_u(latent_kv).view(batch, seq_len, heads, self.head_dim)
        mixed = torch.matmul(weights, head_values.transpose(1, 2))  # (B, H, N, head_dim)
        merged = mixed.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.out(merged)