In [1]:
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout_rate=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads 
        self.attention_dropout = nn.Dropout(dropout_rate)

        self.q_proj = nn.Linear(hidden_size, hidden_size)  # (hidden_size, head_size * num_heads)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, attention_mask=None):
        # x: [batch_size, seq_len, hidden_size]
        batch_size, seq_len, hidden_size = x.shape
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x) # [batch_size, seq_len, hidden_size]

        # 期望[batch_size, num_heads, seq_len, head_size]
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)

        attention_value = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_size)
        # attention_value: [batch_size, num_heads, seq_len, seq_len]

        if attention_mask is not None:
            attention_value = attention_value.masked_fill(attention_mask == 0, float("-inf"))
        print(attention_value.shape)

        attention_weight = torch.softmax(attention_value, dim=-1)

        attention_weight = self.attention_dropout(attention_weight)

        attention_result = torch.matmul(attention_weight, V)
        # attention_result: [batch_size, num_heads, seq_len, head_size]
        # 期望concat得到[batch_size, seq_len, hidden_size]
        attention_result = attention_result.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)

        output = self.o_proj(attention_result)

        return output
        

In [2]:
x = torch.randn(3, 8, 2)
# x : (batch_size, seq_len, hidden_dim)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 0, 0, 0]])
# mask : (batch_size, seq_len)
# 要和attention_weight的shape保持一致：(batch_size, seq_len, seq_len)
# 扩维
mask = mask.unsqueeze(dim=1).repeat(1, x.size(1), 1)
# unsqueeze在指定的维度（dim=1）上增加一个新的维度，变成(batch_size, 1, seq_len)
# repeat沿着指定的维度重复张量，变成(batch_size, seq_len, seq_len)
print(mask)

net = MultiHeadAttention(2, 2)
net(x)

tensor([[[1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 0]],

        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 0, 0]],

        [[1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0]]])
torch.Size([3, 2, 8, 8])


tensor([[[-0.2988, -0.3074],
         [-0.3125, -0.5643],
         [-0.3005, -0.3734],
         [-0.2980, -0.3352],
         [-0.2746, -0.5080],
         [-0.2857, -0.6067],
         [-0.3137, -0.5252],
         [-0.3322, -0.4907]],

        [[-0.3547, -0.5050],
         [-0.3829, -0.5032],
         [-0.2951, -0.5374],
         [-0.3427, -0.5328],
         [-0.3521, -0.4849],
         [-0.2906, -0.4808],
         [-0.3897, -0.5004],
         [-0.3585, -0.4875]],

        [[-0.4109, -0.4072],
         [-0.4842, -0.4756],
         [-0.3380, -0.3474],
         [-0.4199, -0.5146],
         [-0.4077, -0.3894],
         [-0.4800, -0.5030],
         [-0.4513, -0.5060],
         [-0.4914, -0.5046]]], grad_fn=<ViewBackward0>)