In [1]:
import math
import torch
from torch import nn
from d2l import torch as d2l

# 实现

In [2]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Projects queries, keys and values to ``num_hiddens`` features, splits
    those features into ``num_heads`` equal-width heads, runs
    ``d2l.DotProductAttention`` on all heads in parallel (heads are folded
    into the batch axis), then concatenates the heads and mixes them with a
    final linear layer.

    Note: ``forward`` calls ``transpose_qkv`` / ``transpose_output``, which
    are defined in a later notebook cell; they must be defined before
    ``forward`` is first called.

    Args:
        key_size: feature size (last axis) of the ``keys`` input.
        query_size: feature size (last axis) of the ``queries`` input.
        value_size: feature size (last axis) of the ``values`` input.
        num_hiddens: total projected width; must be divisible by
            ``num_heads`` (each head gets ``num_hiddens // num_heads``).
        num_heads: number of attention heads (must be positive).
        dropout: dropout probability passed to ``d2l.DotProductAttention``.
        bias: whether the four projection layers use a bias term.
        **kwargs: forwarded to ``nn.Module.__init__``.

    Raises:
        ValueError: if ``num_heads`` is not positive or ``num_hiddens`` is
            not divisible by ``num_heads``.
    """
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        # Fail fast at construction: otherwise a bad head configuration only
        # surfaces later as an opaque reshape error inside transpose_qkv.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if num_hiddens % num_heads != 0:
            raise ValueError(
                f"num_hiddens ({num_hiddens}) must be divisible by "
                f"num_heads ({num_heads})")
        self.num_heads = num_heads
        # Every head uses scaled dot-product attention (could be swapped for
        # another attention scoring mechanism).
        self.attention = d2l.DotProductAttention(dropout)
        # Project all inputs to num_hiddens (= num_heads * per-head width).
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        # Mix the concatenated head outputs back into num_hiddens features.
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        """Apply multi-head attention.

        Args:
            queries: (batch_size, num_queries, query_size).
            keys: (batch_size, num_kvpairs, key_size).
            values: (batch_size, num_kvpairs, value_size).
            valid_lens: None, or a (batch_size,) or (batch_size, num_queries)
                tensor of lengths handed to ``d2l.DotProductAttention`` for
                masking.

        Returns:
            Tensor of shape (batch_size, num_queries, num_hiddens).
        """
        # 1. Project Q/K/V to num_hiddens, then split into heads folded into
        #    the batch axis: (batch_size * num_heads, seq_len, head_dim).
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # transpose_qkv lays heads out example-major (all heads of example
            # 0, then all heads of example 1, ...), so each example's lengths
            # are repeated num_heads times consecutively to line up.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # 2. Every head attends independently, computed in one batched call.
        # output: (batch_size * num_heads, num_queries, head_dim)
        output = self.attention(queries, keys, values, valid_lens)

        # 3. Merge the heads back: (batch_size, num_queries, num_hiddens).
        output_concat = transpose_output(output, self.num_heads)
        # 4. Linear layer that mixes information across heads.
        return self.W_o(output_concat)


In [3]:
def transpose_qkv(X, num_heads):
    """Split the feature axis into heads and fold the heads into the batch axis.

    Assumes ``X`` is (batch_size, seq_len, num_hiddens) with ``num_hiddens``
    divisible by ``num_heads``. Returns a tensor of shape
    (batch_size * num_heads, seq_len, num_hiddens // num_heads), ordered
    example-major: all heads of example 0 come first, then example 1, etc.
    This layout lets one batched attention call process every head at once.
    """
    batch, seq = X.shape[:2]
    # Carve the last axis into heads: (batch, seq, num_heads, head_dim).
    heads = X.reshape(batch, seq, num_heads, -1)
    # Move the head axis ahead of the sequence: (batch, num_heads, seq, head_dim).
    heads = heads.transpose(1, 2)
    # Collapse batch and heads together: (batch * num_heads, seq, head_dim).
    return heads.reshape(-1, *heads.shape[2:])

def transpose_output(X, num_heads):
    """Undo ``transpose_qkv``: unfold heads from the batch axis and concatenate them.

    Assumes ``X`` is (batch_size * num_heads, seq_len, head_dim) in
    example-major head order. Returns (batch_size, seq_len,
    num_heads * head_dim), with head ``h`` occupying feature columns
    ``[h * head_dim, (h + 1) * head_dim)``.
    """
    seq, head_dim = X.shape[1], X.shape[2]
    # Recover the head axis and put it after the sequence axis:
    # (batch, num_heads, seq, head_dim) -> (batch, seq, num_heads, head_dim).
    merged = X.reshape(-1, num_heads, seq, head_dim).transpose(1, 2)
    # Concatenate the heads along the feature axis.
    return merged.reshape(*merged.shape[:2], -1)

In [4]:
# Demo configuration: 100 hidden units split across 5 heads (20 per head).
num_hiddens = 100
num_heads = 5
attention = MultiHeadAttention(
    key_size=num_hiddens, query_size=num_hiddens, value_size=num_hiddens,
    num_hiddens=num_hiddens, num_heads=num_heads, dropout=0.5)
# eval() switches the module (including its dropout layer) to inference mode
# and returns the module itself, so the cell displays the layer structure.
attention.eval()

MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)

In [5]:
# Toy inputs: 2 examples, 4 queries each, attending over 6 key-value pairs.
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])  # one valid length per example
X = torch.ones((batch_size, num_queries, num_hiddens))  # queries
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))  # keys and values
attn_out = attention(X, Y, Y, valid_lens)
# One output row per query, projected back to num_hiddens features.
assert attn_out.shape == (batch_size, num_queries, num_hiddens)
attn_out.shape

torch.Size([2, 4, 100])