# Step 1. Transformer

对这个任务，我有一处疑问：为什么 $\mathrm{embed\_size}$ 必须是 $\mathrm{num\_heads}$ 的整数倍？按照我的理解，$\mathrm{embed\_size}$、$\mathrm{dim\_qk}$、$\mathrm{dim\_v}$ 和 $\mathrm{num\_heads}$ 应该都是可自由调整的超参数，原论文（Attention Is All You Need）取 $\mathrm{dim\_qk} = \mathrm{dim\_v} = \mathrm{embed\_size} / \mathrm{num\_heads}$ 只是出于方便。

In [1]:
import numpy as np

# Fix the global NumPy RNG seed so the random weights and inputs drawn below
# (np.random.randn calls) are reproducible across runs.
np.random.seed(114514)

def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating. Softmax is
    shift-invariant, so this does not change the result for finite inputs, but
    it stops ``np.exp`` from overflowing to ``inf`` (and the quotient from
    becoming ``nan``) when the scores are large, e.g. unscaled attention scores.

    Args:
        x: array-like of scores.
        axis: axis along which the returned probabilities sum to 1.

    Returns:
        An array with the same shape as ``x``.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would reject an empty reduction.
        return np.exp(x)
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V, mask=None, verbose=False):
    """Apply scaled dot-product attention over the last two axes.

    Computes ``softmax(Q @ K^T / sqrt(d)) @ V``, where ``d`` is the size of
    the last axis of ``Q``; any leading axes are handled by ``@``.

    Args:
        Q, K, V: query, key and value arrays.
        mask: must be ``None``; masking is not implemented here.
        verbose: if true, print the shape of ``Q`` and of both results.

    Returns:
        A tuple ``(output, attn_weights)``.
    """
    # Masking is not supported yet; anything other than None is rejected.
    assert mask is None
    if verbose:
        print(f'QKV.shape={Q.shape}')
    K_T = K.swapaxes(-1, -2)
    scale = np.sqrt(Q.shape[-1])
    attn_weights = softmax((Q @ K_T) / scale)
    # The names `output` / `attn_weights` are printed by the f-string below.
    output = attn_weights @ V
    if verbose:
        print(f'{output.shape=} {attn_weights.shape=}')
    return output, attn_weights

def multi_head_attention(embed_size, num_heads, input, mask=None, verbose=False,
                         *, dim_qk=None, dim_v=None):
    """Run one multi-head self-attention layer with freshly drawn random weights.

    QKV projections have no bias term; the output projection has a zero bias.

    Args:
        embed_size: size of the last axis of ``input`` and of the output.
        num_heads: number of attention heads.
        input: array of shape ``(..., seq_len, embed_size)``; any number of
            leading batch axes (including none) is accepted.
        mask: forwarded to ``scaled_dot_product_attention``, which currently
            only accepts ``None``.
        verbose: if true, print the intermediate shapes.
        dim_qk: per-head query/key size. Defaults to ``embed_size // num_heads``.
        dim_v: per-head value size. Defaults to ``embed_size // num_heads``.
            Neither needs to relate to ``embed_size``: the output projection
            maps the concatenated ``num_heads * dim_v`` features back to
            ``embed_size``. Divisibility is only required when a default is
            used (that default mirrors the original paper's choice).

    Returns:
        ``(output, weights)`` with shapes ``(..., seq_len, embed_size)`` and
        ``(..., num_heads, seq_len, seq_len)``.

    Raises:
        ValueError: if a default head size is needed but ``embed_size`` is not
            a multiple of ``num_heads``.
    """
    if dim_qk is None or dim_v is None:
        if embed_size % num_heads != 0:
            raise ValueError(
                f'embed_size ({embed_size}) must be a multiple of num_heads '
                f'({num_heads}) unless dim_qk and dim_v are given explicitly')
        default_dim = embed_size // num_heads
        dim_qk = default_dim if dim_qk is None else dim_qk
        dim_v = default_dim if dim_v is None else dim_v
    # NOTE: weights come from a standard normal without fan-in scaling, so the
    # attention scores can be large and the softmax close to one-hot.
    # Draw order and shapes match the original for the default head sizes.
    Wq = np.random.randn(num_heads, embed_size, dim_qk)
    Wk = np.random.randn(num_heads, embed_size, dim_qk)
    Wv = np.random.randn(num_heads, embed_size, dim_v)
    Wo = np.random.randn(num_heads * dim_v, embed_size)
    Bo = np.zeros(embed_size)
    if verbose:
        print(f'W_qkv.shape={Wq.shape}')
        if Wv.shape != Wq.shape:
            print(f'{Wv.shape=}')
        print(f'{Wo.shape=} {Bo.shape=}')
    # Insert a head axis so each (seq_len, embed_size) slice broadcasts against
    # every head's (embed_size, dim) projection matrix.
    input = input[..., None, :, :]
    if verbose:
        print(f'{input.shape=}')
    output, weights = scaled_dot_product_attention(
        input @ Wq, input @ Wk, input @ Wv, mask, verbose)
    # (..., heads, seq, dim_v) -> (..., seq, heads, dim_v) -> (..., seq, heads*dim_v)
    output = output.swapaxes(-2, -3)
    output = output.reshape([*output.shape[:-2], -1])
    output = output @ Wo + Bo
    if verbose:
        print(f'{output.shape=} {weights.shape=}')
    return output, weights

测试 `multi_head_attention` 的功能。输入的 batch 大小为 3，序列长度为 11；模型共 5 个 head，每个 head 的 qkv 大小为 7。选取这几个数是为了方便观察 shape 的变化规律。

In [2]:
# Deliberately distinct sizes (batch 3, seq 11, 5 heads of width 7) so every
# axis is easy to identify in the printed shapes.
batch_size, seq_len = 3, 11
num_heads = 5
embed_size = num_heads * 7
input = np.random.randn(batch_size, seq_len, embed_size)
output, weights = multi_head_attention(embed_size, num_heads, input, verbose=True)
# Sanity checks: output/weight shapes, and each attention row is a distribution.
assert output.shape == (batch_size, seq_len, embed_size)
assert weights.shape == (batch_size, num_heads, seq_len, seq_len)
assert np.allclose(weights.sum(axis=-1), 1.0)
print(output[0, 0])
print(weights[0, 0, 0])

W_qkv.shape=(5, 35, 7)
Wo.shape=(35, 35) Bo.shape=(35,)
input.shape=(3, 1, 11, 35)
QKV.shape=(3, 5, 11, 7)
output.shape=(3, 5, 11, 7) attn_weights.shape=(3, 5, 11, 11)
output.shape=(3, 11, 35) weights.shape=(3, 5, 11, 11)
[ -1.00275258 -25.66227608  42.57650594   7.97341477  -2.09239899
  22.53574569 -32.31421119 -19.31954746  35.94738272   5.09795971
 -34.47604002   0.86513501  50.51554347  21.8124433   35.35536458
 -30.79651531   0.38839876   6.82163086 -14.5239423  -50.32858852
  20.92636831 -11.40505511  34.35585814  -8.64440007  17.03970826
 -46.23846407   0.86446847  27.91816735  -6.19561116 -11.2085796
  -0.52242257 -86.61101946 -23.54598171 -26.04331552 -26.03110728]
[1.29420131e-12 1.81028363e-33 4.99676145e-31 5.48498138e-21
 3.03060036e-26 1.09915871e-16 3.71961110e-10 1.56721677e-26
 1.97962592e-25 1.00000000e+00 2.35854129e-25]
