# In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BertSelfAttention(nn.Module):
    """Multi-head self-attention followed by an output projection.

    ``hidden_states`` are linearly projected to queries, keys and values,
    split into ``num_attention_heads`` heads of
    ``hidden_size // num_attention_heads`` features each, combined with
    scaled dot-product attention per head, merged back together and
    projected to ``hidden_size`` again.

    Args:
        config: object exposing ``hidden_size`` and ``num_attention_heads``.

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of
            ``num_attention_heads`` (the projections would otherwise silently
            drop ``hidden_size % num_attention_heads`` features).
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_attention_heads})"
            )

        # Features per head (e.g. 768 // 12 = 64).
        self.attention_head_size = self.hidden_size // self.num_attention_heads
        # Kept for backward compatibility: despite its name, this attribute
        # holds the PER-HEAD size, not num_heads * head_size.
        self.all_head_size = self.attention_head_size
        # Total width of all heads together; equals hidden_size after the check above.
        projected_size = self.num_attention_heads * self.attention_head_size

        # Q/K/V projections: hidden_size -> hidden_size (e.g. 768 -> 768).
        self.query = nn.Linear(config.hidden_size, projected_size)
        self.key = nn.Linear(config.hidden_size, projected_size)
        self.value = nn.Linear(config.hidden_size, projected_size)

        # Projects the merged heads back to hidden_size.
        self.output = nn.Linear(projected_size, config.hidden_size)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis forward.

        Args:
            x: tensor of shape ``(batch_size, seq_len, num_heads * head_size)``.

        Returns:
            Tensor of shape ``(batch_size, num_heads, seq_len, head_size)``.
        """
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(new_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states: torch.Tensor, attention_mask=None) -> torch.Tensor:
        """Apply multi-head self-attention to ``hidden_states``.

        Args:
            hidden_states: tensor of shape ``(batch_size, seq_len, hidden_size)``.
            attention_mask: optional ADDITIVE mask that is added to the raw
                attention scores before the softmax (0 = attend, a large
                negative value = masked). Either ``(batch_size, seq_len)``,
                which masks key positions per example, or any shape
                broadcastable to ``(batch_size, num_heads, seq_len, seq_len)``,
                such as ``(batch_size, 1, 1, seq_len)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, hidden_size)``.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Q @ K^T: (batch, heads, seq, head_size) x (batch, heads, head_size, seq)
        # -> (batch, heads, seq, seq). Scaling by sqrt(head_size) keeps the
        # logits' magnitude independent of the head width.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / (self.attention_head_size ** 0.5)

        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # A raw (batch, seq_len) mask would broadcast against the
                # trailing (seq, seq) dims and apply example i's mask to query
                # row i. Lift it so it broadcasts over heads and query rows
                # and masks key positions instead.
                attention_mask = attention_mask[:, None, None, :]
            attention_scores = attention_scores + attention_mask

        attention_probs = F.softmax(attention_scores, dim=-1)

        # (batch, heads, seq, head_size) -> (batch, seq, heads, head_size)
        # -> (batch, seq, heads * head_size); flatten copies if non-contiguous.
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).flatten(start_dim=2)

        return self.output(context_layer)


def test_bert_self_attention():
    """Smoke test: run BertSelfAttention on random input and check the output shape."""

    # Minimal config object exposing the two fields BertSelfAttention reads.
    class Config:
        def __init__(self, hidden_size=768, num_attention_heads=12):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads

    config = Config(hidden_size=768, num_attention_heads=12)
    self_attention = BertSelfAttention(config)

    # Random input: batch_size=2, seq_len=16, hidden_size=768.
    batch_size = 2
    seq_len = 16
    hidden_states = torch.randn(batch_size, seq_len, config.hidden_size)

    output = self_attention(hidden_states)

    # The layer must preserve the (batch, seq, hidden) shape.
    print("Output shape:", output.shape)
    assert output.shape == (batch_size, seq_len, config.hidden_size), (
        "输出形状应为 (batch_size, seq_len, hidden_size)，"
        f"但得到 {output.shape}"
    )

    print("BertSelfAttention 测试通过！")

def _run_smoke_test():
    """Entry point for direct execution: run the BertSelfAttention smoke test."""
    test_bert_self_attention()


# Run the test when executed directly (a Jupyter kernel also runs cells with
# __name__ == "__main__").
if __name__ == "__main__":
    _run_smoke_test()

# Recorded cell output:
# Output shape: torch.Size([2, 16, 768])
# BertSelfAttention 测试通过！  ("BertSelfAttention test passed!")
