## 10.5. Multi-Head Attention
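Multi-head attention lets the attention mechanism jointly use different representation subspaces of the queries, keys, and values. The queries, keys, and values are first transformed by $h$ independently learned linear projections, and the projected results are fed into attention pooling in parallel. The $h$ attention pooling outputs, each called a head, are then concatenated and transformed by another learned linear projection to produce the final output.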
### 10.5.1. Model
Each attention head $\mathbf{h}_i$ ($i = 1, \ldots, h$) is computed as
$$\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},$$
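where $f$ is the attention pooling function and $\mathbf W_i^{(q)}$, $\mathbf W_i^{(k)}$, $\mathbf W_i^{(v)}$ are the learned projection matrices of head $i$. The multi-head attention output is a linear transformation, with learnable parameters $\mathbf W_o\in\mathbb R^{p_o\times h p_v}$, of the concatenation of the $h$ heads:
$$\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.$$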
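
To make the two formulas concrete, here is a minimal sketch for a single query, written in plain NumPy rather than MXNet. It assumes that $f$ is scaled dot-product attention over `m` key-value pairs. The input dimensions `d_q`, `d_k`, `d_v`, the sizes chosen, and the random weights are all hypothetical and serve only to check shapes.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()


def attention_pooling(q, K, V):
    """Scaled dot-product attention of one query against m key-value pairs."""
    scores = K @ q / np.sqrt(q.shape[0])  # shape (m,)
    return softmax(scores) @ V            # shape (p_v,)


rng = np.random.default_rng(0)
d_q = d_k = d_v = 8          # input dimensions (assumed for illustration)
p_q = p_k = p_v = 4          # per-head projection sizes
h, p_o, m = 3, 6, 5          # number of heads, output size, key-value pairs

q = rng.normal(size=d_q)
K = rng.normal(size=(m, d_k))
V = rng.normal(size=(m, d_v))

# One projection matrix per head for queries, keys, and values
W_q = rng.normal(size=(h, p_q, d_q))
W_k = rng.normal(size=(h, p_k, d_k))
W_v = rng.normal(size=(h, p_v, d_v))
W_o = rng.normal(size=(p_o, h * p_v))

# h_i = f(W_i^(q) q, W_i^(k) k, W_i^(v) v), one vector of size p_v per head
heads = [attention_pooling(W_q[i] @ q, K @ W_k[i].T, V @ W_v[i].T)
         for i in range(h)]

concat = np.concatenate(heads)  # shape (h * p_v,)
output = W_o @ concat           # shape (p_o,)
print(concat.shape, output.shape)
```

The loop over heads mirrors the index $i$ in the formula. An efficient implementation would avoid this loop and compute all heads in one batched operation.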

In [1]:
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

### 10.5.2. Implementation
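In this implementation, each head of the multi-head attention uses scaled dot-product attention. We set $p_q = p_k = p_v = p_o / h$. If the number of outputs of the linear transformations for the queries, keys, and values is set to $p_q h = p_k h = p_v h = p_o$, the $h$ heads can be computed in parallel. In the code below, $p_o$ is given by the argument `num_hiddens`.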
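
The reason a single projection of size $p_o = h p_v$ suffices can be checked numerically. The sketch below, in plain NumPy with hypothetical sizes and random weights, compares one fused projection followed by a reshape against $h$ separate projections that use the corresponding row blocks of the same matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
h, p_v, d = 5, 20, 100           # p_o = h * p_v = 100 (sizes are illustrative)
x = rng.normal(size=d)           # a single input vector

# One fused projection with p_o = h * p_v outputs
W = rng.normal(size=(h * p_v, d))
fused = (W @ x).reshape(h, p_v)  # row i holds the projection for head i

# h separate projections, each using one block of p_v rows of W
separate = np.stack([W[i * p_v:(i + 1) * p_v] @ x for i in range(h)])

match = np.allclose(fused, separate)
print(match)
```

Under these assumptions the two results agree, so one dense layer with `num_hiddens` outputs can stand in for $h$ smaller per-head layers.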

In [2]:
class MultiHeadAttention(nn.Block):
    def __init__(self, num_hiddens, num_heads, dropout, use_bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of 'queries':
        # ('batch_size', no. of queries or key-value pairs, 'num_hiddens')
        # Shape of 'valid_lens': ('batch_size',) or ('batch_size', no. of queries)
        # After the transformation, shape of 'queries', 'keys', and 'values':
        # ('batch_size' * 'num_heads', no. of queries or key-value pairs,
        # 'num_hiddens' / 'num_heads')
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) 'num_heads'
            # times, then copy the next item, and so on
            valid_lens = valid_lens.repeat(self.num_heads, axis=0)

        # Shape of 'output':
        # ('batch_size' * 'num_heads', no. of queries, 'num_hiddens' / 'num_heads')
        output = self.attention(queries, keys, values, valid_lens)
        # Shape of 'output_concat': ('batch_size', no. of queries, 'num_hiddens')
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
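
The ordering produced by `repeat` along axis 0 matters, because the reshaped queries, keys, and values are laid out with all heads of the first batch element first, then all heads of the second, and so on. The following sketch uses plain NumPy, whose `repeat` is assumed here to behave like the MXNet `np` call above, and contrasts it with `np.tile`, which would give the wrong order. The values are illustrative.

```python
import numpy as np

num_heads = 3
valid_lens = np.array([3, 2])  # one valid length per batch element

# Each entry is copied num_heads times before moving to the next entry
repeated = np.repeat(valid_lens, num_heads, axis=0)
print(repeated)  # [3 3 3 2 2 2]

# Tiling cycles through the whole array instead, which would pair
# lengths with the wrong batch element after the heads are split
tiled = np.tile(valid_lens, num_heads)
print(tiled)     # [3 2 3 2 3 2]

# The same rule applies to per-query lengths of shape (batch_size, no. of queries)
valid_lens_2d = np.array([[1, 2], [3, 4]])
repeated_2d = np.repeat(valid_lens_2d, num_heads, axis=0)
print(repeated_2d.shape)  # (6, 2)
```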

In [4]:
# Two transposition functions:
# transpose_output reverses the operation of transpose_qkv
def transpose_qkv(X, num_heads):
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    X = X.transpose(0, 2, 1, 3)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """Reverse the operation of 'transpose_qkv'."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.transpose(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
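
As a sanity check on the two functions, the sketch below repeats them in plain NumPy (assumed to match the MXNet `np` reshape and transpose semantics used here) and verifies two properties on a small, hypothetical tensor: head `i` of batch element `b` ends up at index `b * num_heads + i`, and `transpose_output` restores the original tensor.

```python
import numpy as np


def transpose_qkv(X, num_heads):
    # (batch, steps, hiddens) -> (batch, steps, heads, hiddens / heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # -> (batch, heads, steps, hiddens / heads)
    X = X.transpose(0, 2, 1, 3)
    # -> (batch * heads, steps, hiddens / heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    # (batch * heads, steps, hiddens / heads) -> (batch, heads, steps, ...)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.transpose(0, 2, 1, 3)
    # -> (batch, steps, hiddens)
    return X.reshape(X.shape[0], X.shape[1], -1)


batch_size, num_steps, num_hiddens, num_heads = 2, 4, 6, 3
head_size = num_hiddens // num_heads
X = np.arange(batch_size * num_steps * num_hiddens, dtype=float)
X = X.reshape(batch_size, num_steps, num_hiddens)

split = transpose_qkv(X, num_heads)
print(split.shape)  # (6, 4, 2)

# Head i of batch element b is a contiguous slice of the hidden dimension
b, i = 1, 2
same_slice = np.array_equal(
    split[b * num_heads + i],
    X[b, :, i * head_size:(i + 1) * head_size])

round_trip = np.array_equal(transpose_output(split, num_heads), X)
print(same_slice, round_trip)
```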

In [5]:
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()

batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
Y = np.ones((batch_size, num_kvpairs, num_hiddens))
print(attention(X, Y, Y, valid_lens).shape)

(2, 4, 100)
