# 手写多头注意力机制(MHA)

In [None]:
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with separate linear
    layers, split into ``head_num`` heads of size ``hidden_dim // head_num``,
    attended per head, merged back and passed through an output projection.

    Args:
        hidden_dim: Size of the last dimension of the input and the output.
        head_num: Number of attention heads; must be a positive divisor of
            ``hidden_dim``.
        dropout_rate: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``head_num`` is not a positive divisor of ``hidden_dim``.
    """

    def __init__(self, hidden_dim, head_num, dropout_rate: float = 0.1):
        super().__init__()

        # Fail fast with a clear message instead of a ZeroDivisionError or an
        # opaque shape error from `view` inside forward().
        if head_num <= 0 or hidden_dim % head_num != 0:
            raise ValueError(
                f"head_num ({head_num}) must be a positive divisor of "
                f"hidden_dim ({hidden_dim})"
            )

        self.head_num = head_num
        self.head_dim = hidden_dim // head_num

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, attention_mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, hidden_dim]``.
            attention_mask: Optional tensor broadcastable to
                ``[batch_size, head_num, seq_len, seq_len]``. Key positions
                where the mask equals 0 (or ``False``) are not attended to.

        Returns:
            Tensor of shape ``[batch_size, seq_len, hidden_dim]``.
        """
        batch_size, seq_len, _ = x.size()

        # Q/K/V: [batch_size, seq_len, hidden_dim] -> [batch_size, head_num, seq_len, head_dim]
        Q = self.q_proj(x).view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)

        # attention_score: [batch_size, head_num, seq_len, seq_len]
        attention_score = Q @ K.transpose(-1, -2) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            # Use the most negative finite value of the dtype rather than -inf:
            # when every key in a row is masked, a row of -inf makes softmax
            # produce NaN, while a finite fill yields uniform, finite weights.
            # Partially masked rows are unaffected (exp underflows to 0).
            attention_score = attention_score.masked_fill(
                attention_mask == 0, torch.finfo(attention_score.dtype).min
            )

        attention_weight = torch.softmax(attention_score, dim=-1)
        attention_weight = self.dropout(attention_weight)

        # [batch_size, head_num, seq_len, head_dim] -> [batch_size, seq_len, hidden_dim]
        output = attention_weight @ V
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.o_proj(output)

# Demo sizes, named once so the mask, the input and the module stay consistent.
BATCH_SIZE, SEQ_LEN, HIDDEN_DIM, HEAD_NUM = 3, 2, 128, 8

# Padding mask of shape [batch_size, seq_len]: 1 = key may be attended, 0 = masked.
# The second sample masks every key position.
attention_mask = (
    torch.tensor(
        [
            [0, 1],
            [0, 0],
            [1, 0],
        ]
    )
    .unsqueeze(1)
    .unsqueeze(2)
    # [batch_size, 1, 1, seq_len] -> [batch_size, head_num, seq_len, seq_len]
    .expand(BATCH_SIZE, HEAD_NUM, SEQ_LEN, SEQ_LEN)
)

x = torch.rand(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM)
net = MultiHeadAttention(hidden_dim=HIDDEN_DIM, head_num=HEAD_NUM)
output = net(x, attention_mask)
print(output.shape)




In [None]:
# Rebuild the [batch_size, seq_len] padding mask and inspect the shapes that
# unsqueeze/expand produce when preparing it for the attention scores.
a = torch.tensor(
    [
        [0, 1],
        [0, 0],
        [1, 0],
    ]
).unsqueeze(1).unsqueeze(2)

# [3, 1, 1, 2] -- a bare expression mid-cell is never displayed, so print it.
print(a.shape)

# [3, 8, 1, 2] -- the size-1 head axis is expanded to 8 heads.
a.expand(3, 8, 1, 2).shape

torch.Size([3, 8, 1, 2])