# MHA

In [None]:
import torch
from torch import nn
from xxx import RoPEEmbedding # 假设的RoPE模块


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a positional transform applied to Q and K.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``hidden_dim // num_heads``, passed through
    scaled dot-product attention, and projected back to ``hidden_dim``.
    """

    def __init__(self, hidden_dim: int, num_heads: int, max_seq_len: int, dropout: float = 0.1):
        """Build the projection layers.

        Args:
            hidden_dim: Size of the last dimension of the input and output.
            num_heads: Number of attention heads; must divide ``hidden_dim``.
            max_seq_len: Stored on the module and forwarded to ``RoPEEmbedding``.
            dropout: Dropout probability applied to the attention weights.

        Raises:
            AssertionError: If ``hidden_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        # 1 / sqrt(head_dim): the usual scaled dot-product attention factor.
        self.scale = self.head_dim ** -0.5
        self.max_seq_len = max_seq_len

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # NOTE(review): RoPEEmbedding is defined outside this file; it is
        # assumed to rotate a (batch, num_heads, seq_len, head_dim) tensor
        # and return a tensor of the same shape -- confirm.
        self.rope = RoPEEmbedding(self.head_dim, max_seq_len)

    def forward(self, x: torch.Tensor, mask=None):
        """Run self-attention over ``x``.

        Args:
            x: Input whose first dimension is the batch and whose last
                dimension is ``hidden_dim`` (presumably
                ``(batch, seq_len, hidden_dim)``).
            mask: Optional boolean tensor broadcastable to
                ``(batch, num_heads, seq_len, seq_len)``. Positions where it
                is ``True`` are filled with ``-inf`` before the softmax, i.e.
                they are excluded from attention.

        Returns:
            A tuple ``(output, attn_weights)``. ``output`` is reshaped to
            ``(batch, -1, hidden_dim)`` and passed through ``o_proj``;
            ``attn_weights`` are the softmax weights *after* dropout, shaped
            ``(batch, num_heads, seq_len, seq_len)``.
        """
        batch_size = x.shape[0]

        # Project, then split the hidden dimension into heads:
        # (batch, seq_len, hidden) -> (batch, num_heads, seq_len, head_dim)
        Q = self.q_proj(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Only queries and keys go through the positional transform.
        Q = self.rope(Q)
        K = self.rope(K)

        attn_scores = Q @ K.transpose(-2, -1) * self.scale
        # Compare against None explicitly: the truth value of a tensor with
        # more than one element is ambiguous and raises a RuntimeError.
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask, float('-inf'))
        attn_scores = torch.softmax(attn_scores, dim=-1)
        attn_scores = self.dropout(attn_scores)

        # Merge heads back: (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, hidden)
        output = (attn_scores @ V).transpose(1, 2).reshape(batch_size, -1, self.hidden_dim)
        output = self.o_proj(output)
        return output, attn_scores

    # Backward-compatible alias for the original misspelled method name.
    forwward = forward