# SDPA 缩放点积注意力

In [None]:
import torch
from torch import nn

class ScaledDotProductAttention(nn.Module):
    """Parameter-free scaled dot-product attention.

    Computes ``softmax(Q @ K^T / sqrt(d)) @ V`` where ``d`` is the size of the
    last dimension of ``query``. Works for 3-D inputs
    ``(batch, seq_len, hidden)`` as well as inputs with extra dimensions
    between batch and sequence (e.g. ``(batch, heads, seq_len, head_dim)``).
    """

    def __init__(self):
        super().__init__()

    def forward(self, query, key, value, causal_mask=None, padding_mask=None):
        """Apply attention.

        Args:
            query, key, value: tensors whose last two dimensions are
                ``(seq_len, hidden)``; the first dimension is the batch.
            causal_mask: optional mask broadcastable to the score tensor
                ``(..., q_len, k_len)``; nonzero/True marks blocked positions.
            padding_mask: optional 2-D mask ``(batch, k_len)``; nonzero/True
                marks key positions to ignore.

        Returns:
            Tensor with the same leading dimensions as ``query`` and the last
            dimension of ``value``.
        """
        hidden_size = query.size(-1)
        # (..., q_len, hidden) @ (..., hidden, k_len) -> (..., q_len, k_len)
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / hidden_size ** 0.5

        # `if causal_mask:` would be ambiguous for a multi-element tensor.
        if causal_mask is not None:
            # In-place add keeps the dtype of the scores (e.g. half precision).
            attention_scores += causal_mask * -1e9

        if padding_mask is not None:
            # Insert singleton dims between batch and key so the mask lines up
            # with the scores whatever their rank: (b, k) -> (b, 1, k) for 3-D
            # scores, (b, 1, 1, k) for 4-D scores. A fixed double unsqueeze
            # breaks the in-place add for 3-D scores.
            batch_size, key_len = padding_mask.shape
            mask_shape = (batch_size,) + (1,) * (attention_scores.dim() - 2) + (key_len,)
            attention_scores += padding_mask.reshape(mask_shape) * -1e9

        attention_probs = torch.softmax(attention_scores, dim=-1)  # (..., q_len, k_len)
        attention_output = torch.matmul(attention_probs, value)  # (..., q_len, hidden)

        return attention_output
    
def test_attn():
    """Smoke-test ScaledDotProductAttention on random tensors and print the shapes."""
    dims = (128, 512, 1024)  # (batch_size, seq_len, hidden_size)

    # Drawn in query, key, value order.
    query, key, value = (torch.randn(*dims) for _ in range(3))

    attention = ScaledDotProductAttention()
    output = attention(query, key, value)

    for label, tensor in (
        ("Query shape:", query),
        ("Key shape:", key),
        ("Value shape:", value),
        ("Output shape:", output),
    ):
        print(label, tensor.shape)
    print("Output value:", output)

# Run the scaled dot-product attention smoke test when executed as the main module.
if __name__ == "__main__":
    test_attn()

# MHA 多头注意力

In [None]:
import torch
from torch import nn

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with optional causal and padding masks.

    Args:
        hidden_size: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        if hidden_size % num_heads != 0:
            # Otherwise head_dim * num_heads != hidden_size and the head
            # split/merge in forward() fails with an obscure shape error.
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        self.o_linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, causal_mask=None, padding_mask=None):
        """Run self-attention over ``hidden_state``.

        Args:
            hidden_state: tensor of shape ``(batch_size, seq_len, hidden_size)``.
            causal_mask: optional mask broadcastable to
                ``(batch_size, num_heads, seq_len, seq_len)``; nonzero/True
                marks blocked positions.
            padding_mask: optional mask of shape ``(batch_size, seq_len)``;
                nonzero/True marks key positions to ignore.

        Returns:
            Tensor of shape ``(batch_size, seq_len, hidden_size)``.
        """
        batch_size = hidden_state.size(0)

        # Project the input into queries, keys and values: (b, seq_len, hidden_size).
        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)

        # (b, seq_len, hidden_size) -> (b, seq_len, num_heads, head_dim)
        #                           -> (b, num_heads, seq_len, head_dim)
        query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # (b, num_heads, seq_len, seq_len); a Python scalar avoids allocating
        # a tensor for the scale on every call.
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.head_dim ** 0.5

        if causal_mask is not None:
            # In-place add keeps the dtype of the scores.
            attention_scores += causal_mask * -1e9

        if padding_mask is not None:
            # (b, seq_len) -> (b, 1, 1, seq_len): broadcast over heads and queries.
            padding_mask = padding_mask.unsqueeze(1).unsqueeze(1)
            attention_scores += padding_mask * -1e9

        attention_probs = torch.softmax(attention_scores, dim=-1)  # (b, num_heads, seq_len, seq_len)

        output = torch.matmul(attention_probs, value)  # (b, num_heads, seq_len, head_dim)

        # The transposed tensor is not contiguous; view() requires contiguous memory.
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, -1, self.head_dim * self.num_heads)  # (b, seq_len, hidden_size)

        output = self.o_linear(output)  # (b, seq_len, hidden_size)

        return output
    
def test_MHA():
    """Smoke-test MultiHeadAttention with a causal mask and print input/output shapes."""
    batch_size, seq_len = 128, 512
    hidden_size, num_heads = 1024, 8

    # Input is drawn before the module is built, as in the original ordering.
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    # Strict upper triangle is True: position i may not attend to j > i.
    causal_mask = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()

    attention = MultiHeadAttention(hidden_size, num_heads)
    output = attention(hidden_state, causal_mask=causal_mask)

    for label, tensor in (("Input shape:", hidden_state), ("Output shape:", output)):
        print(label, tensor.shape)

# Run the multi-head attention smoke test when executed as the main module.
if __name__ == "__main__":
    test_MHA()

    

# MHA with KV

In [None]:
import torch
from torch import nn

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with an optional key/value cache.

    Args:
        hidden_size: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        if hidden_size % num_heads != 0:
            # Otherwise head_dim * num_heads != hidden_size and the head
            # split/merge in forward() fails with an obscure shape error.
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        self.o_linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, causal_mask=None, past_key_value=None, use_cache=False):
        """Run self-attention, optionally extending a key/value cache.

        Args:
            hidden_state: tensor of shape ``(batch_size, q_len, hidden_size)``
                (``q_len`` is 1 when decoding one token at a time).
            causal_mask: optional mask broadcastable to
                ``(batch_size, num_heads, q_len, kv_len)``; nonzero/True marks
                blocked positions.
            past_key_value: optional ``(key, value)`` tuple from a previous
                call, each of shape ``(batch_size, num_heads, past_len, head_dim)``.
            use_cache: when True, also return the updated ``(key, value)`` tuple.

        Returns:
            ``output`` of shape ``(batch_size, q_len, hidden_size)``, or
            ``(output, (key, value))`` when ``use_cache`` is True.
        """
        batch_size = hidden_state.size(0)

        # Project the new tokens: (b, q_len, hidden_size).
        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)

        # (b, q_len, hidden_size) -> (b, q_len, num_heads, head_dim)
        #                         -> (b, num_heads, q_len, head_dim)
        query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            # Prepend cached keys/values along the sequence axis:
            # (b, num_heads, past_len + q_len, head_dim).
            past_key, past_value = past_key_value
            key = torch.cat([past_key, key], dim=2)
            value = torch.cat([past_value, value], dim=2)

        new_past_key_value = (key, value) if use_cache else None

        # (b, num_heads, q_len, kv_len); a Python scalar avoids allocating a
        # tensor for the scale on every call.
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.head_dim ** 0.5

        if causal_mask is not None:
            # In-place add keeps the dtype of the scores.
            attention_scores += causal_mask * -1e9

        attention_probs = torch.softmax(attention_scores, dim=-1)  # (b, num_heads, q_len, kv_len)

        output = torch.matmul(attention_probs, value)  # (b, num_heads, q_len, head_dim)

        # Merge heads; contiguous() is required before view() after transpose.
        output = output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.head_dim * self.num_heads
        )  # (b, q_len, hidden_size)

        output = self.o_linear(output)  # (b, q_len, hidden_size)

        return (output, new_past_key_value) if use_cache else output
    
def test_MHA_with_cache():
    """Decode a random sequence one token at a time through the KV cache and print shapes."""
    batch_size, seq_len = 128, 512
    hidden_size, num_heads = 1024, 8

    # Input is drawn before the module is built, as in the original ordering.
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    causal_mask = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()
    mha = MultiHeadAttention(hidden_size, num_heads)

    cache = None
    step_outputs = []
    # Each chunk is one token of shape (batch_size, 1, hidden_size).
    for step, token in enumerate(hidden_state.split(1, dim=1)):
        # Mask row for this query against the step + 1 keys seen so far.
        mask_row = causal_mask[step:step + 1, :step + 1]
        step_output, cache = mha(
            token,
            causal_mask=mask_row,
            past_key_value=cache,
            use_cache=True,
        )
        step_outputs.append(step_output)

    output = torch.cat(step_outputs, dim=1)

    for label, tensor in (("Input shape:", hidden_state), ("Output shape:", output)):
        print(label, tensor.shape)

# Run the KV-cache decoding smoke test when executed as the main module.
if __name__ == "__main__":
    test_MHA_with_cache()

    

# MQA 多查询注意力

In [None]:
import torch
from torch import nn

class MultiQueryAttention(nn.Module):
    """Multi-query attention: ``num_heads`` query heads share one key/value head.

    Args:
        hidden_size: model width; must be divisible by ``num_heads``.
        num_heads: number of query heads.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.q_linear = nn.Linear(hidden_size, hidden_size)
        # Keys and values are projected to a single head shared by all queries.
        self.k_linear = nn.Linear(hidden_size, self.head_dim)
        self.v_linear = nn.Linear(hidden_size, self.head_dim)
        self.o_linear = nn.Linear(hidden_size, hidden_size)

    def split(self, x, num_heads=None):
        """Reshape ``(b, seq_len, num_heads * head_dim)`` to ``(b, num_heads, seq_len, head_dim)``.

        ``num_heads`` defaults to ``self.num_heads``.
        """
        batch_size = x.size(0)
        if num_heads is None:
            num_heads = self.num_heads

        return x.reshape(batch_size, -1, num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_state, causal_mask=None, padding_mask=None):
        """Run multi-query self-attention.

        Args:
            hidden_state: tensor of shape ``(batch_size, seq_len, hidden_size)``.
            causal_mask: optional mask broadcastable to
                ``(batch_size, num_heads, seq_len, seq_len)``; nonzero/True
                marks blocked positions.
            padding_mask: optional mask of shape ``(batch_size, seq_len)``;
                nonzero/True marks key positions to ignore.

        Returns:
            Tensor of shape ``(batch_size, seq_len, hidden_size)``.
        """
        batch_size = hidden_state.size(0)
        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)

        query = self.split(query, self.num_heads)  # (b, num_heads, seq_len, head_dim)
        key = self.split(key, 1)  # (b, 1, seq_len, head_dim)
        value = self.split(value, 1)  # (b, 1, seq_len, head_dim)

        # The single key head broadcasts across all query heads:
        # (b, num_heads, seq_len, seq_len).
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.head_dim ** 0.5

        if causal_mask is not None:
            attention_scores += causal_mask * -1e9

        if padding_mask is not None:
            # (b, seq_len) -> (b, 1, 1, seq_len)
            padding_mask = padding_mask.unsqueeze(1).unsqueeze(1)
            attention_scores += padding_mask * -1e9

        attention_probs = torch.softmax(attention_scores, dim=-1)

        output = torch.matmul(attention_probs, value)  # (b, num_heads, seq_len, head_dim)
        # reshape() copies when needed; the transposed tensor is not contiguous.
        output = output.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        output = self.o_linear(output)

        return output

# GQA 分组查询注意力

In [None]:
import torch
from torch import nn

class MultiQueryAttention(nn.Module):
    """Grouped-query attention: query heads are partitioned into groups that share key/value heads.

    NOTE: despite its name this class takes ``num_groups`` and implements
    grouped-query attention; the name is kept so existing callers keep working.

    Args:
        hidden_size: model width; must be divisible by ``num_heads``.
        num_heads: number of query heads; must be divisible by ``num_groups``.
        num_groups: number of key/value heads.

    Raises:
        ValueError: if either divisibility requirement is violated.
    """

    def __init__(self, hidden_size, num_heads, num_groups):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        if num_heads % num_groups != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by num_groups ({num_groups})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.num_groups = num_groups

        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, self.num_groups * self.head_dim)
        self.v_linear = nn.Linear(hidden_size, self.num_groups * self.head_dim)
        self.o_linear = nn.Linear(hidden_size, hidden_size)

    def split(self, x, num_groups=None):
        """Split the last dimension into heads and expand groups to ``num_heads``.

        Args:
            x: tensor of shape ``(b, seq_len, num_groups * head_dim)``.
            num_groups: number of heads stored in ``x``; defaults to
                ``self.num_heads`` (no expansion).

        Returns:
            Tensor of shape ``(b, num_heads, seq_len, head_dim)``. Each stored
            head is repeated ``num_heads // num_groups`` times consecutively.
        """
        batch_size = x.size(0)
        seq_len = x.size(1)
        if num_groups is None:
            num_groups = self.num_heads

        # (b, seq_len, g * d) -> (b, g, seq_len, d)
        x = x.reshape(batch_size, seq_len, num_groups, self.head_dim).transpose(1, 2)

        heads_per_group = self.num_heads // num_groups
        if heads_per_group == 1:
            return x

        # (b, g, 1, s, d) -> (b, g, heads_per_group, s, d) -> (b, num_heads, s, d).
        # reshape() is needed because the expanded tensor cannot be viewed.
        x = x[:, :, None, :, :].expand(
            batch_size, num_groups, heads_per_group, seq_len, self.head_dim
        )
        return x.reshape(batch_size, self.num_heads, seq_len, self.head_dim)

    def forward(self, hidden_state, causal_mask=None, padding_mask=None):
        """Run grouped-query self-attention.

        Args:
            hidden_state: tensor of shape ``(batch_size, seq_len, hidden_size)``.
            causal_mask: optional mask broadcastable to
                ``(batch_size, num_heads, seq_len, seq_len)``; nonzero/True
                marks blocked positions.
            padding_mask: optional mask of shape ``(batch_size, seq_len)``;
                nonzero/True marks key positions to ignore.

        Returns:
            Tensor of shape ``(batch_size, seq_len, hidden_size)``.
        """
        batch_size = hidden_state.size(0)
        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)

        # All three end up as (b, num_heads, seq_len, head_dim).
        query = self.split(query, self.num_heads)
        key = self.split(key, self.num_groups)
        value = self.split(value, self.num_groups)

        # (b, num_heads, seq_len, seq_len)
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.head_dim ** 0.5

        if causal_mask is not None:
            attention_scores += causal_mask * -1e9

        if padding_mask is not None:
            # (b, seq_len) -> (b, 1, 1, seq_len)
            padding_mask = padding_mask.unsqueeze(1).unsqueeze(1)
            attention_scores += padding_mask * -1e9

        attention_probs = torch.softmax(attention_scores, dim=-1)

        output = torch.matmul(attention_probs, value)  # (b, num_heads, seq_len, head_dim)
        # reshape() copies when needed; the transposed tensor is not contiguous.
        output = output.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        output = self.o_linear(output)

        return output

# SelfAttention 自注意力（单头）

In [None]:
import torch
from torch import nn

class SelfAttention(nn.Module):
    """Single-head self-attention with learned Q/K/V and output projections."""

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Query, key and value projections.
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)

        # Output projection.
        self.o_linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, causal_mask=None, padding_mask=None):
        """
        Forward pass of self-attention.

        Args:
            hidden_state: input tensor of shape (batch_size, seq_len, hidden_size).
            causal_mask: optional causal mask; nonzero/True entries are blocked.
                A 2-D mask gets a leading batch dimension added.
            padding_mask: optional (batch_size, seq_len) mask; nonzero/True
                marks key positions to ignore.

        Returns:
            Output tensor with the same shape as the input.
        """
        _, seq_len, _ = hidden_state.shape

        # Project the input: each result is (b, s, h).
        q = self.q_linear(hidden_state)
        k = self.k_linear(hidden_state)
        v = self.v_linear(hidden_state)

        # Scaled dot-product scores, (b, s, s).
        scale = torch.sqrt(torch.tensor(self.hidden_size, dtype=torch.float32))
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale

        if causal_mask is not None:
            # A 2-D mask is given a leading batch axis so it lines up with the scores.
            mask = causal_mask.unsqueeze(0) if causal_mask.dim() == 2 else causal_mask
            scores += mask * -1e9

        if padding_mask is not None:
            # (b, s) -> (b, 1, s) -> (b, s, s): the same key mask for every query row.
            key_mask = padding_mask[:, None, :].expand(-1, seq_len, -1)
            scores += key_mask * -1e9

        # Attention weights over keys, (b, s, s).
        weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values followed by the output projection, (b, s, h).
        return self.o_linear(torch.matmul(weights, v))

def test_self_attention():
    """Smoke-test SelfAttention with causal and padding masks and print shapes."""
    batch_size, seq_len, hidden_size = 128, 512, 1024

    # Random input, drawn before the module is built.
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)

    # Causal mask: strict upper triangle is True (future positions are blocked).
    causal_mask = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()

    # Example padding mask: treat the first 10 positions of every sequence as padding.
    padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    padding_mask[:, :10] = True

    layer = SelfAttention(hidden_size)
    output = layer(hidden_state, causal_mask=causal_mask, padding_mask=padding_mask)

    for label, tensor in (("Input shape:", hidden_state), ("Output shape:", output)):
        print(label, tensor.shape)
    print("Output requires_grad:", output.requires_grad)  # expected to be True

# Run the single-head self-attention smoke test when executed as the main module.
if __name__ == "__main__":
    test_self_attention()