In [17]:
import numpy as np
np.random.seed(114514)

def scaled_dot_product_attention(Q, K, V, mask=None):
  '''
  Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

  Steps:
  1. Transpose the last two axes of K so it lines up with Q's last axis.
  2. Compute the attention scores and scale them by sqrt(d_k).
  3. Optionally suppress masked positions, then apply softmax over the last axis.
  4. Use the attention weights to average V.

  Args:
    Q: queries, shape (..., seq_q, d_k).
    K: keys, shape (..., seq_k, d_k).
    V: values, shape (..., seq_k, d_v).
    mask: optional array broadcastable to (..., seq_q, seq_k). Truthy
      (non-zero) entries mark positions that may be attended to; falsy
      entries are excluded. E.g. a lower-triangular matrix gives the causal
      mask used in decoder blocks. None disables masking.

  Returns:
    (output, attention_weights): output has shape (..., seq_q, d_v) and
    attention_weights has shape (..., seq_q, seq_k). Each weight row sums
    to 1, except rows whose positions are all masked, which are all zeros
    (and therefore produce a zero output row).
  '''
  d_k = K.shape[-1]
  scores = np.matmul(Q, np.swapaxes(K, -1, -2)) / np.sqrt(d_k)

  if mask is not None:
    # Masked positions get -inf so they contribute exactly 0 after exp().
    scores = np.where(np.asarray(mask, dtype=bool), scores, -np.inf)

  # Numerically stable softmax: shifting by the row maximum leaves the result
  # unchanged mathematically but keeps exp() from overflowing. Clipping the
  # scores instead would distort the weights whenever scores exceed the clip.
  row_max = np.max(scores, axis=-1, keepdims=True, initial=-np.inf)
  # A fully masked (or empty) row has max == -inf; shift by 0 to avoid inf - inf.
  row_max = np.where(np.isfinite(row_max), row_max, 0.0)
  exp_scores = np.exp(scores - row_max)
  denom = np.sum(exp_scores, axis=-1, keepdims=True)
  # Guard against 0/0 for fully masked rows; their weights stay all zero.
  denom = np.where(denom == 0, 1.0, denom)
  attention_weights = exp_scores / denom

  output = np.matmul(attention_weights, V)
  return output, attention_weights

def multi_head_attention(embed_size, num_heads, input, mask=None):
  '''
  Run multi-head attention over `input` using freshly sampled weights.

  Steps:
  1. Require embed_size to split evenly across num_heads (AssertionError
     otherwise).
  2. Sample Wq, Wk, Wv (one (embed_size, head_dim) matrix per head) and Wo
     from a standard normal distribution, then project `input` to q, k, v.
  3. Feed the projections to scaled_dot_product_attention().
  4. Join the per-head results along the feature axis and multiply by Wo.

  Args:
    embed_size: size of the last axis of `input`; must be divisible by
      num_heads.
    num_heads: number of attention heads.
    input: 3-D array indexed as (batch, seq_len, embed_size).
    mask: forwarded unchanged to scaled_dot_product_attention().

  Returns:
    (output, weights): output has shape (batch, seq_len, embed_size);
    weights is whatever scaled_dot_product_attention() returns for the
    (batch, num_heads, seq_len, head_dim) projections.
  '''
  assert embed_size % num_heads == 0
  head_dim = int(embed_size / num_heads)

  # Draw order (Wq, Wk, Wv, then Wo) is fixed so seeded runs stay reproducible.
  per_head_shape = (num_heads, embed_size, head_dim)
  Wq, Wk, Wv = (np.random.normal(0, 1, per_head_shape) for _ in range(3))
  Wo = np.random.normal(0, 1, (embed_size, embed_size))

  # Insert a head axis so matmul broadcasts the input against every head.
  expanded = input[:, np.newaxis, :, :]
  q, k, v = (np.matmul(expanded, W) for W in (Wq, Wk, Wv))

  heads, weights = scaled_dot_product_attention(q, k, v, mask)

  # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim), i.e. the
  # heads laid side by side along the feature axis.
  batch, _, seq_len, _ = heads.shape
  merged = heads.transpose(0, 2, 1, 3).reshape(batch, seq_len, num_heads * head_dim)

  return np.matmul(merged, Wo), weights

# Smoke test: push one random batch through multi-head attention.
embed_size, num_heads = 128, 8
input = np.random.randn(10, 20, embed_size)  # 10 sequences of 20 tokens each
output, weights = multi_head_attention(embed_size, num_heads, input)
print(output.shape, weights.shape)
# Trailing expression (displayed by the notebook): the first ten values of
# one output row and of one attention-weight row.
output[0, 0, :10], weights[0, 0, 0, :10]

(10, 20, 128) (10, 8, 20, 20)


(array([-143.92030352,  -30.61859787,  -75.5384955 ,  154.50956891,
         -24.23250515,   42.76191746,   58.35442329,  168.83597452,
         215.34228635,  321.11911673]),
 array([8.87333531e-059, 1.01785873e-132, 4.91141969e-049, 1.63399334e-103,
        3.84843402e-084, 1.00000000e+000, 4.67253028e-114, 2.80462952e-085,
        1.69103767e-140, 2.67974303e-186]))

```python
   (10, 20, 128) (10, 8, 20, 20)
   (array([-91.96555916, -19.40983534, -32.99740866, 113.35786088,
           138.22610441,  81.21040905, -30.81003178,  90.7098463 ,
           162.38724319, -40.72173619]),
    array([1.94810489e-189, 3.21476597e-151, 3.61314239e-103, 4.96644350e-219,
           3.90604112e-173, 3.46437823e-131, 4.72245009e-077, 2.66307289e-194,
           1.00000000e+000, 5.17103825e-098]))
```