## Scaled dot product attention
![](https://www.tensorflow.org/images/tutorials/transformer/scaled_attention.png)

The inputs are Q (query), K (key), and V (value); the output is:
$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}_k\left(\dfrac{QK^T}{\sqrt{d_k}}\right)V$$

- Q: [batch, q_len, d_model]
- K: [batch, kv_len, d_model]
- V: [batch, kv_len, d_model]

> The dot-product attention is scaled by a factor of square root of the depth. This is done because for large values of depth, the dot product grows large in magnitude pushing the softmax function where it has small gradients resulting in a very hard softmax

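The division by $\sqrt{d_k}$ (the square root of the key depth, not `d_model` itself) keeps the dot products from growing too large in magnitude. Without it, the variance of the logits grows with the depth, so the values fed into softmax are far apart. Softmax then saturates into a nearly one-hot distribution (a hard softmax), where the gradients are very small and the attention weights barely change during training.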
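The saturation effect can be seen numerically. The following is a minimal NumPy sketch, not part of the tutorial code: the `softmax` helper, the example scores, and the factor of 20 are illustrative assumptions. The same scores are passed through softmax at two magnitudes, and the larger one gives a nearly one-hot distribution with much smaller gradient terms.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over the last axis (illustrative helper)."""
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

logits = np.array([1.0, 2.0, 3.0])   # hypothetical small-magnitude scores
small = softmax(logits)              # smooth distribution
large = softmax(logits * 20.0)       # same scores at a much larger magnitude

# Diagonal of the softmax Jacobian, p_i * (1 - p_i). It shrinks toward zero
# as the distribution approaches one-hot.
grad_small = small * (1.0 - small)
grad_large = large * (1.0 - large)

print(small, grad_small.max())
print(large, grad_large.max())
```

With the larger logits almost all of the probability mass lands on a single entry, which is the regime the scaling is meant to avoid.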

> For example, consider that Q and K have a mean of 0 and variance of 1. Their matrix multiplication will have a mean of 0 and variance of dk. Hence, square root of dk is used for scaling (and not any other number) because the matmul of Q and K should have a mean of 0 and variance of 1, so that we get a gentler softmax.

$$q\cdot k=\sum_{i=1}^{d_k}q_ik_i$$
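If the components $q_i$ and $k_i$ are independent with mean 0 and variance 1, each product $q_ik_i$ has mean 0 and variance 1, so the sum of $d_k$ such terms has variance $d_k$. The sketch below checks this empirically with NumPy. The depth `d_k = 64`, the sample count, and the Gaussian distribution are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)   # fixed seed so the sketch is repeatable
d_k = 64                         # hypothetical key depth
n = 20_000                       # number of sampled (q, k) pairs

q = rng.standard_normal((n, d_k))   # components with mean 0, variance 1
k = rng.standard_normal((n, d_k))

dots = np.sum(q * k, axis=-1)       # q . k = sum_i q_i k_i, one value per pair
scaled = dots / np.sqrt(d_k)        # the scaling used in the attention formula

print("var(q.k)             ~", dots.var())     # expected to be close to d_k
print("var(q.k / sqrt(d_k)) ~", scaled.var())   # expected to be close to 1
```

Dividing by $\sqrt{d_k}$ divides the variance by $d_k$, which is why this particular factor brings the logits back to unit variance.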

> The mask is multiplied with -1e9 (close to negative infinity). This is done because the mask is summed with the scaled matrix multiplication of Q and K and is applied immediately before a softmax. The goal is to zero out these cells, and large negative inputs to softmax are near zero in the output.
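A small NumPy sketch of this masking step follows. It assumes the convention implied by the quote, where a mask value of 1 marks a position to hide. The `softmax` helper and the example scores are illustrative and not taken from the tutorial.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over the last axis (illustrative helper)."""
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

logits = np.array([[2.0, 1.0, 0.5, 3.0]])   # hypothetical scaled scores, [q_len=1, kv_len=4]
mask = np.array([[0.0, 0.0, 1.0, 1.0]])     # assumed convention: 1 marks a position to hide

# Masked positions receive a very large negative logit before the softmax.
masked_logits = logits + mask * -1e9
weights = softmax(masked_logits)

print(weights)   # the last two weights are zero; the first two still sum to 1
```

Note that the sign and magnitude matter: multiplying the mask by a tiny positive number such as `1e-9` would leave the logits essentially unchanged and mask nothing.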

In [3]:
import tensorflow as tf
import numpy as np

def scaled_dot_product_attention(q, k, v, mask):
    """Calculate the attention weights.
    q, k, v must have matching leading dimensions.
    The mask has different shapes depending on its type(padding or look ahead)
    but it must be broadcastable for addition.

    :param q: query shape == [..., q_len, d_model]
    :param k: key shape == [..., kv_len, d_model]
    :param v: value shape == [..., kv_len, d_model]
    :param mask: Float tensor with shape broadcastable to [..., q_len, kv_len]
    :return:
        output, attention_weights.
    """
    matmul_qk = tf.matmul(q, k, transpose_b=True)  # [..., q_len, kv_len]

    # scale matmul_qk
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.sqrt(dk)

    # add the mask to the scaled tensor
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)

    # softmax is normalized on the last axis (kv_len) so that the scores add up to 1.
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # [..., q_len, kv_len]

    output = tf.matmul(attention_weights, v)  # [..., q_len, d_model]

    return output, attention_weights

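One might expect the output shape to be `[..., kv_len, d_model]`, but it is `[..., q_len, d_model]`: `attention_weights` has shape `[..., q_len, kv_len]` and `v` has shape `[..., kv_len, d_model]`, so the matmul contracts over `kv_len` and leaves one output row per query.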
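The output shape can be checked directly with a small NumPy sketch. All sizes below are hypothetical and deliberately chosen to differ, so that each axis is easy to identify. The computation mirrors the formula above rather than reusing the TensorFlow function.

```python
import numpy as np

q_len, kv_len, d_k, d_v = 3, 4, 5, 2   # hypothetical sizes, all different

rng = np.random.default_rng(0)
q = rng.standard_normal((q_len, d_k))
k = rng.standard_normal((kv_len, d_k))
v = rng.standard_normal((kv_len, d_v))

scores = q @ k.T / np.sqrt(d_k)                   # [q_len, kv_len]
scores -= scores.max(axis=-1, keepdims=True)      # shift for numerical stability
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)    # softmax over kv_len
output = weights @ v                              # [q_len, d_v]

print(scores.shape, weights.shape, output.shape)  # (3, 4) (3, 4) (3, 2)
```

Each output row is a weighted average of the rows of `v`, so there is one row per query and the last dimension comes from `v`.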

In [5]:
def print_out(q, k, v):
    temp_out, temp_attn = scaled_dot_product_attention(q, k, v, None)
    print('Attention weights are:')
    print(temp_attn)     # [..., q_len, kv_len]
    print('Output is:')
    print(temp_out)      # [..., q_len, d_model]


np.set_printoptions(suppress=True)
temp_k = tf.constant([[10, 0, 0],
                      [0, 10, 0],
                      [0, 0, 10],
                      [0, 0, 10]], dtype=tf.float32)  # (4, 3)
temp_v = tf.constant([[1, 0],
                      [10, 0],
                      [100, 5],
                      [1000, 6]], dtype=tf.float32)  # (4, 2)

In [7]:
# This query aligns with a repeated key (third and fourth),
# so all associated values get averaged.
temp_q = tf.constant([[0, 0, 10]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)

Attention weights are:
tf.Tensor([[0.  0.  0.5 0.5]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[550.    5.5]], shape=(1, 2), dtype=float32)


In [6]:
# This `query` aligns with the second `key`,
# so the second `value` is returned.
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)

Attention weights are:
tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[10.  0.]], shape=(1, 2), dtype=float32)


In [8]:
# This query aligns equally with the first and second key, 
# so their values get averaged.
temp_q = tf.constant([[10, 10, 0]], dtype=tf.float32)  # (1, 3)
print_out(temp_q, temp_k, temp_v)

Attention weights are:
tf.Tensor([[0.5 0.5 0.  0. ]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[5.5 0. ]], shape=(1, 2), dtype=float32)


In [9]:
temp_q = tf.constant([[0, 0, 10], [0, 10, 0], [10, 10, 0]], dtype=tf.float32)  # (3, 3)
print_out(temp_q, temp_k, temp_v)

Attention weights are:
tf.Tensor(
[[0.  0.  0.5 0.5]
 [0.  1.  0.  0. ]
 [0.5 0.5 0.  0. ]], shape=(3, 4), dtype=float32)
Output is:
tf.Tensor(
[[550.    5.5]
 [ 10.    0. ]
 [  5.5   0. ]], shape=(3, 2), dtype=float32)


### Multi-head attention

![](https://www.tensorflow.org/images/tutorials/transformer/multi_head_attention.png)

**Multi-head attention consists of four steps:**

- Linear layers and split into heads.  
- Scaled dot-product attention.  
- Concatenation of heads.  
- Final linear layer.

**Why use multiple heads?**  

Instead of one single attention head, Q, K, and V are split into multiple heads because it allows the model to jointly attend to information at different positions from different representational spaces. After the split each head has a reduced dimensionality, so the total computation cost is the same as a single head attention with full dimensionality.
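The split and the cost argument can be sketched with plain NumPy reshapes. The sizes are hypothetical, and the cost figures count only the multiply-adds of the $QK^T$ product for the case `q_len == kv_len == seq_len`; this is an illustration of the shapes involved, not the layer implementation.

```python
import numpy as np

batch, seq_len, d_model, num_heads = 2, 5, 8, 4   # hypothetical sizes
depth = d_model // num_heads                      # per-head dimensionality

x = np.arange(batch * seq_len * d_model, dtype=np.float32)
x = x.reshape(batch, seq_len, d_model)

# Split: [batch, seq_len, d_model] -> [batch, num_heads, seq_len, depth]
heads = x.reshape(batch, seq_len, num_heads, depth).transpose(0, 2, 1, 3)

# Merge (concatenate heads): undo the transpose, then flatten the last two axes.
merged = heads.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)

print(heads.shape)                 # (2, 4, 5, 2)
print(np.array_equal(merged, x))   # True: split followed by merge is lossless

# Multiply-adds for QK^T: one full-width head versus num_heads narrow heads.
single_head_cost = seq_len * seq_len * d_model
multi_head_cost = num_heads * seq_len * seq_len * depth
print(single_head_cost, multi_head_cost)   # 200 200
```

Because `num_heads * depth == d_model`, the two counts match, which is the sense in which the total cost stays the same.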

In [10]:
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % num_heads == 0

        self.depth = self.d_model // self.num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)

        :param x: [batch_size, seq_len, d_model]
        :param batch_size: scalar batch size, used for the reshape
        :return: tensor of shape [batch_size, num_heads, seq_len, depth]
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])   # [batch, num_heads, -1, depth]

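    def call(self, q, k, v, mask):
        """Apply multi-head attention.

        :param q:  [batch, q_len, d_model]
        :param k:  [batch, kv_len, d_model]
        :param v:  [batch, kv_len, d_model]
        :param mask: padding mask or look ahead mask, broadcastable to [batch, num_heads, q_len, kv_len]
        :return: output [batch, q_len, d_model], attention_weights [batch, num_heads, q_len, kv_len]
        """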
        batch_size = tf.shape(q)[0]
        
        # Linear layers and split into heads
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        split_q = self.split_heads(q, batch_size)
        split_k = self.split_heads(k, batch_size)
        split_v = self.split_heads(v, batch_size)

        # Scaled dot-product attention
        # scaled_attention.shape == [batch, num_heads, q_len, depth]
        # attention_weights.shape == [batch, num_heads, q_len, kv_len]
        scaled_attention, attention_weights = scaled_dot_product_attention(
            split_q, split_k, split_v, mask)
        scaled_attention = tf.transpose(scaled_attention, [0, 2, 1, 3]) # [batch, q_len, num_heads, depth]
        
        # concatenation
        concat_attention = tf.reshape(scaled_attention,
                                      shape=(batch_size, -1, self.d_model)) # [batch, q_len, d_model]
        
        #  Final linear layer
        output = self.dense(concat_attention) # [batch, q_len, d_model]
        return output, attention_weights

In [11]:
temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
q = tf.random.uniform((1, 62, 512))
k = v = tf.random.uniform((1, 60, 512))
out, attn = temp_mha(q, k, v, mask=None)
print(out.shape, attn.shape)

(1, 62, 512) (1, 8, 62, 60)
