## Attention Mechanism
***
Time: 2020-09-16<br>
Author: dsy<br>
Notes: *Neural Networks and Deep Learning* (《神经网络与深度学习》)
***

### Self-Attention

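Suppose the input sequence is $X=[x_1,\cdots,x_N] \in \mathbb{R}^{d_1 \times N}$ and the output sequence is $H=[h_1,\cdots,h_N] \in \mathbb{R}^{d_2 \times N}$. We first apply linear transformations to obtain three sequences of vectors: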
$$
\begin{aligned}
Q & = W_Q X \in \mathbb{R}^{d_3 \times N},\\
K & = W_K X \in \mathbb{R}^{d_3 \times N},\\
V & = W_V X \in \mathbb{R}^{d_2 \times N},
\end{aligned}
$$
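where $Q$, $K$, $V$ are the query, key, and value vector sequences, respectively, and $W_Q \in \mathbb{R}^{d_3 \times d_1}$, $W_K \in \mathbb{R}^{d_3 \times d_1}$, $W_V \in \mathbb{R}^{d_2 \times d_1}$ are learnable parameter matrices. Queries and keys share the dimension $d_3$ so that they can be scored against each other, while values have dimension $d_2$ because each output $h_i$ is a weighted sum of values.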
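The three projections are plain matrix products in the column-vector convention used here (each column of $X$ is one $x_i$). A minimal sketch, assuming small hypothetical sizes `d1, d2, d3, N` and random matrices standing in for the learned $W_Q, W_K, W_V$:

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes for illustration: d1 = input dim, d2 = output/value dim,
# d3 = query/key dim, N = sequence length
d1, d2, d3, N = 5, 4, 3, 6

X = torch.randn(d1, N)  # columns are x_1, ..., x_N

# Stand-ins for the learnable projection matrices (random here, trained in practice)
W_Q = torch.randn(d3, d1)
W_K = torch.randn(d3, d1)
W_V = torch.randn(d2, d1)

Q = W_Q @ X  # (d3, N): column i is q_i
K = W_K @ X  # (d3, N): column j is k_j
V = W_V @ X  # (d2, N): column j is v_j

print(Q.shape, K.shape, V.shape)
```

Note that $V$ has $d_2$ rows, matching the dimension of the outputs $h_i$.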

Applying key-value attention with the query $q_i$ gives the output vector $h_i$:
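$$
\begin{aligned}
h_i & = \operatorname{att} \Big( (K,V), q_i \Big) \\
    & = \sum_{j=1}^{N} \alpha_{ij} v_j \\
    & = \sum_{j=1}^{N} \operatorname{softmax}_j \Big( s(k_j, q_i) \Big) v_j ,
\end{aligned}
$$

where the softmax is taken over $j$, i.e. $\alpha_{ij} = \dfrac{\exp\big(s(k_j,q_i)\big)}{\sum_{z=1}^{N} \exp\big(s(k_z,q_i)\big)}$, and $s(\cdot,\cdot)$ is an attention scoring function.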
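The formula can be transcribed directly, one output position at a time, and then checked against its matrix form $H = V\,\operatorname{softmax}(K^T Q)$, where the softmax runs down each column (over $j$). A minimal sketch, assuming the plain dot-product score $s(k_j, q_i) = k_j^T q_i$ from the scoring functions listed below and hypothetical sizes:

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes and random stand-ins for the learned projections
d1, d2, d3, N = 5, 4, 3, 6
X = torch.randn(d1, N)
W_Q, W_K, W_V = torch.randn(d3, d1), torch.randn(d3, d1), torch.randn(d2, d1)
Q, K, V = W_Q @ X, W_K @ X, W_V @ X


def s(k, q):
    """Dot-product score s(k_j, q_i) = k_j^T q_i (assumed choice of scoring function)."""
    return k @ q


# Direct transcription: h_i = sum_j softmax_j(s(k_j, q_i)) v_j
H_loop = torch.empty(d2, N)
for i in range(N):
    scores = torch.stack([s(K[:, j], Q[:, i]) for j in range(N)])  # (N,)
    alpha = torch.softmax(scores, dim=0)  # alpha_i1, ..., alpha_iN, sums to 1
    H_loop[:, i] = V @ alpha              # sum_j alpha_ij v_j

# Matrix form: A[j, i] = alpha_ij, so the softmax is over dim 0 (the key index j)
A = torch.softmax(K.T @ Q, dim=0)  # (N, N)
H = V @ A                          # (d2, N): column i is h_i

print(torch.allclose(H, H_loop, atol=1e-5))
```

Because every query $q_i$ comes from the same sequence $X$ as the keys and values, this is self-attention; the loop and the matrix form agree up to floating-point rounding.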


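Common attention scoring functions, where $x_i$ plays the role of the key $k_j$ above, $W$, $U$, $v$ are learnable parameters, and $d$ is the dimension of the input vectors:
$$\begin{array}{cll}
\text{Additive} & s(x_i,q) & = v^T \tanh(W x_i + U q) \\
\text{Dot-product} & s(x_i,q) & = x_i^T q \\
\text{Scaled dot-product} & s(x_i,q) & = \frac{x_i^T q}{\sqrt{d}} \\
\text{Bilinear} & s(x_i,q) & = x_i^T W q
\end{array}$$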
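Each scoring function maps a pair of vectors to a single scalar. A minimal sketch of the four variants, assuming a hypothetical input dimension `d = 4`, a hypothetical hidden size `d_hidden = 3` for the additive model, and random tensors in place of the learned $W$, $U$, $v$:

```python
import math

import torch

torch.manual_seed(0)

d = 4          # dimension of x_i and q (hypothetical)
d_hidden = 3   # hidden size of the additive model (hypothetical)

x = torch.randn(d)
q = torch.randn(d)

# Random stand-ins for learnable parameters
W_add = torch.randn(d_hidden, d)
U_add = torch.randn(d_hidden, d)
v_add = torch.randn(d_hidden)
W_bil = torch.randn(d, d)


def additive(x, q):
    # v^T tanh(W x + U q)
    return v_add @ torch.tanh(W_add @ x + U_add @ q)


def dot(x, q):
    # x^T q
    return x @ q


def scaled_dot(x, q):
    # x^T q / sqrt(d)
    return x @ q / math.sqrt(x.shape[0])


def bilinear(x, q):
    # x^T W q
    return x @ W_bil @ q


for name, fn in [("additive", additive), ("dot", dot),
                 ("scaled dot", scaled_dot), ("bilinear", bilinear)]:
    print(f"{name:>10}: {fn(x, q).item():.4f}")
```

The bilinear model reduces to the dot-product model when $W$ is the identity. The $\sqrt{d}$ in the scaled variant keeps the score's variance near 1 when the entries of $x_i$ and $q$ are independent with zero mean and unit variance (the unscaled $x_i^T q$ then has variance $d$), which stops the softmax from saturating for large $d$.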

```python
import torch
import torch.nn as nn

# Fix the random seed so the random tensors below are reproducible
seed = 0
torch.manual_seed(seed)
```

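```python
# L: target (query) sequence length, N: batch size,
# E: embedding dimension, S: source (key/value) sequence length
L, N, E, S = 3, 4, 10, 6
```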

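```python
# embed_dim must be divisible by num_heads: here each of the 10 heads has dimension E // 10 = 1.
# Inputs use the default (sequence, batch, embedding) layout.
multiheadAttention = nn.MultiheadAttention(
    embed_dim=E,
    num_heads=10,
    dropout=0.0,
    bias=True,
    add_bias_kv=False,
    add_zero_attn=False,
    kdim=None,  # None means kdim = embed_dim
    vdim=None,  # None means vdim = embed_dim
)
```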
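The module's parameters show how the projections $W_Q, W_K, W_V$ from the formulas above are stored. A small sketch, assuming the same `E = 10` and `num_heads = 10` as the cell above and default arguments otherwise; `mha` is a hypothetical name, and `head_dim` is the attribute the module sets in its constructor:

```python
import torch.nn as nn

E, num_heads = 10, 10
mha = nn.MultiheadAttention(embed_dim=E, num_heads=num_heads)

# Each head works on E // num_heads dimensions
print(mha.head_dim)               # 1
# When kdim == vdim == embed_dim, W_Q, W_K, W_V are packed into one (3E, E) matrix
print(mha.in_proj_weight.shape)   # torch.Size([30, 10])
print(mha.in_proj_bias.shape)     # torch.Size([30])
# Output projection applied after the heads are concatenated
print(mha.out_proj.weight.shape)  # torch.Size([10, 10])
```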

```python
query = torch.rand((L, N, E))  # (L, N, E)
query
```

```text
tensor([[[0.2159, 0.4216, 0.9246, 0.5207, 0.1464, 0.3329, 0.3643, 0.4035,
          0.5479, 0.9624],
         [0.5268, 0.1913, 0.5256, 0.7397, 0.7480, 0.0430, 0.4105, 0.1284,
          0.2867, 0.6801],
         [0.1449, 0.6859, 0.9244, 0.5328, 0.1668, 0.3209, 0.6092, 0.1188,
          0.7484, 0.0461],
         [0.0194, 0.0142, 0.3986, 0.8362, 0.0268, 0.9156, 0.3000, 0.6464,
          0.5228, 0.0491]],

        [[0.9147, 0.7692, 0.9970, 0.7526, 0.1700, 0.9173, 0.5269, 0.7371,
          0.0991, 0.3562],
         [0.0091, 0.3053, 0.6079, 0.1074, 0.6594, 0.7684, 0.5697, 0.1655,
          0.1123, 0.3457],
         [0.7195, 0.9932, 0.7875, 0.4437, 0.6753, 0.0095, 0.0729, 0.7333,
          0.2168, 0.7405],
         [0.1470, 0.2523, 0.0882, 0.7609, 0.4491, 0.8848, 0.8094, 0.7767,
          0.5161, 0.3454]],

        [[0.3913, 0.5665, 0.7479, 0.1497, 0.9196, 0.4456, 0.0810, 0.2295,
          0.9424, 0.9573],
         [0.0369, 0.8526, 0.7506, 0.7960, 0.9233, 0.2305, 0.6579, 0.7046,
          0.3
```

```python
key = torch.rand((S, N, E))  # (S, N, E)
key
```

```text
tensor([[[0.9991, 0.9883, 0.1229, 0.0947, 0.1210, 0.4976, 0.3725, 0.1727,
          0.3207, 0.5945],
         [0.2388, 0.6108, 0.3853, 0.2577, 0.5687, 0.9111, 0.1620, 0.5232,
          0.3156, 0.9907],
         [0.0256, 0.0207, 0.9927, 0.1837, 0.5959, 0.4568, 0.3947, 0.3883,
          0.8177, 0.5239],
         [0.0132, 0.2048, 0.3295, 0.7516, 0.1764, 0.9715, 0.3886, 0.4102,
          0.8918, 0.7513]],

        [[0.9241, 0.7892, 0.3483, 0.1683, 0.4628, 0.9138, 0.3322, 0.0363,
          0.7050, 0.9867],
         [0.3577, 0.0860, 0.0465, 0.6253, 0.4621, 0.2475, 0.6011, 0.6899,
          0.8977, 0.8882],
         [0.4252, 0.0591, 0.0482, 0.9668, 0.7210, 0.7180, 0.0674, 0.9630,
          0.9737, 0.9514],
         [0.0782, 0.3113, 0.1561, 0.9735, 0.2852, 0.2717, 0.7620, 0.2687,
          0.2537, 0.4563]],

        [[0.4519, 0.1105, 0.9168, 0.2794, 0.6774, 0.9349, 0.7522, 0.5708,
          0.9254, 0.5672],
         [0.2687, 0.9730, 0.6183, 0.0122, 0.3577, 0.1594, 0.9384, 0.4174,
          0.0
```

```python
value = torch.rand((S, N, E))  # (S, N, E)
value
```

```text
tensor([[[6.6399e-01, 6.1957e-02, 7.7410e-01, 7.6027e-01, 8.1010e-01,
          1.8123e-01, 9.9800e-01, 2.0362e-01, 9.9917e-01, 2.0155e-02],
         [5.4515e-02, 8.0710e-01, 5.5226e-01, 5.2884e-01, 2.2312e-01,
          2.9026e-01, 3.5381e-01, 1.2922e-02, 5.2598e-01, 5.8843e-01],
         [4.9958e-01, 6.6147e-01, 9.7443e-01, 6.3294e-01, 3.1696e-01,
          2.9423e-01, 1.8010e-01, 1.5339e-01, 4.1948e-01, 4.1157e-01],
         [7.2243e-01, 2.8628e-01, 8.9860e-01, 1.4916e-01, 5.0142e-01,
          9.4946e-01, 9.9719e-01, 2.1037e-01, 5.8906e-01, 5.5906e-01]],

        [[2.6557e-01, 3.2725e-01, 6.3543e-01, 1.5232e-01, 5.8250e-01,
          7.1636e-01, 3.0296e-01, 9.1532e-01, 4.6709e-01, 7.2686e-01],
         [9.9515e-01, 3.4717e-01, 7.7170e-01, 3.5699e-01, 4.2696e-01,
          4.1526e-01, 4.9689e-01, 3.1112e-01, 6.1719e-01, 5.1884e-01],
         [8.1694e-01, 3.9880e-01, 5.5014e-01, 3.1400e-01, 8.1273e-02,
          7.0233e-01, 5.6398e-01, 2.9976e-01, 3.3095e-01, 6.3076e-01],
         [4
```

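```python
attn_output, attn_output_weights = multiheadAttention(
    query,
    key,
    value,
    key_padding_mask=None,  # no padded key positions to ignore
    need_weights=True,      # also return the attention weights
    attn_mask=None,         # no mask: every query position attends to all S key positions
)
```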
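To connect `nn.MultiheadAttention` with the formulas above, the forward pass can be re-derived by hand: project with $W_Q, W_K, W_V$, split the embedding into heads, apply scaled dot-product attention per head (softmax over the $S$ key positions), concatenate the heads, and apply the output projection. This is a sketch for checking shapes and values, assuming the packed `in_proj_weight` stores the Q, K, V rows in that order and the default (sequence, batch, embedding) layout; it is not PyTorch's internal code path, and `split_heads` is a hypothetical helper:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
L, N, E, S, num_heads = 3, 4, 10, 6, 10
head_dim = E // num_heads

mha = nn.MultiheadAttention(embed_dim=E, num_heads=num_heads)
query, key, value = torch.rand(L, N, E), torch.rand(S, N, E), torch.rand(S, N, E)
out_ref, w_ref = mha(query, key, value)  # weights are averaged over heads by default

# 1. Linear projections: in_proj_weight stacks W_Q, W_K, W_V along dim 0
W_q, W_k, W_v = mha.in_proj_weight.chunk(3, dim=0)
b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)
q = F.linear(query, W_q, b_q)  # (L, N, E)
k = F.linear(key, W_k, b_k)    # (S, N, E)
v = F.linear(value, W_v, b_v)  # (S, N, E)


def split_heads(t):
    """(seq, N, E) -> (N, num_heads, seq, head_dim)."""
    return t.reshape(t.shape[0], N, num_heads, head_dim).permute(1, 2, 0, 3)


q, k, v = split_heads(q), split_heads(k), split_heads(v)

# 2. Scaled dot-product attention in every head; softmax over the S key positions
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)  # (N, heads, L, S)
weights = torch.softmax(scores, dim=-1)
ctx = weights @ v                                         # (N, heads, L, head_dim)

# 3. Concatenate the heads and apply the output projection
ctx = ctx.permute(2, 0, 1, 3).reshape(L, N, E)            # (L, N, E)
out = mha.out_proj(ctx)

print(torch.allclose(out, out_ref, atol=1e-5))
print(torch.allclose(weights.mean(dim=1), w_ref, atol=1e-5))
```

Each head is an instance of the scaled dot-product model with $d$ equal to `head_dim`; the returned attention weights are the per-head weights averaged over the heads.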

```python
attn_output.shape  # (L, N, E)
```

```text
torch.Size([3, 4, 10])
```

```python
attn_output_weights.shape  # (N, L, S), averaged over the heads
```

```text
torch.Size([4, 3, 6])
```
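In the cells above the key/value sequence (length $S = 6$) differs from the query sequence (length $L = 3$), so strictly speaking they compute cross-attention. Self-attention as defined at the start of this section corresponds to passing the same tensor as query, key, and value. A minimal sketch, assuming the same hyperparameters as above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
L, N, E = 3, 4, 10
mha = nn.MultiheadAttention(embed_dim=E, num_heads=10)

x = torch.rand(L, N, E)  # one sequence serves as query, key and value
out, w = mha(x, x, x)

print(out.shape)  # torch.Size([3, 4, 10])
print(w.shape)    # torch.Size([4, 3, 3]), i.e. (N, L, L)
# Every row of weights is a softmax distribution over the L positions
print(torch.allclose(w.sum(dim=-1), torch.ones(N, L)))
```

The weight matrix is now square, $L \times L$ per batch element, because every position attends to every position of the same sequence.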