In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer

Dense是对全连接层的专门实现（其他的还有LSTM、Conv等），继承自Layer
dense表示全连接层的权重矩阵是“密集的”——输入层的每个神经元都与输出层的每个神经元相连（而卷积层是稀疏的，输出层的神经元只和“感受野”内的输入层神经元相连）
另外全连接层在学术上称为 fully connected layer / dense layer，Keras 用 Dense 这个简洁的名字来命名全连接层

In [3]:
# Toy dimensions for the attention demo below.
batch_size = 32
q_seq_len = 10
kv_seq_len = 20
d_model = 512  # q, k and v all share the same feature dimension here
num_heads = 8

# Queries may have a different sequence length than keys/values,
# but keys and values must agree with each other.
q_shape = (batch_size, q_seq_len, d_model)
kv_shape = (batch_size, kv_seq_len, d_model)

q = tf.random.normal(q_shape)
k = tf.random.normal(kv_shape)
v = tf.random.normal(kv_shape)

# Last expression of the cell: display the three shapes.
(q.shape, k.shape, v.shape)

(TensorShape([32, 10, 512]),
 TensorShape([32, 20, 512]),
 TensorShape([32, 20, 512]))

scaled_attention_logits命名原理：
scaled_attention表示 $QK^T/\sqrt{d_k}$ 这一数学过程（点积之后除以 $\sqrt{d_k}$ 进行缩放）
在机器学习中，未经过softmax归一化的分数是logits，softmax将logits转为概率

In [12]:
def scaled_dot_product_attention(q, k, v):
    """Compute scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.

    Args:
        q: Query tensor of shape [..., seq_q, depth].
        k: Key tensor of shape [..., seq_kv, depth]; its last dimension is d_k.
        v: Value tensor of shape [..., seq_kv, depth_v].

    Returns:
        The attention output of shape [..., seq_q, depth_v]: for every query,
        a softmax-weighted sum of the value vectors.
    """
    # Transposing k aligns the feature axes of q and k, so each row of
    # matmul_qk holds one query's dot products with every key in the sequence.
    matmul_qk = tf.matmul(q, k, transpose_b=True)

    # Compute sqrt(d_k) in float32, then cast it to the dtype of the logits.
    # A hard-coded float32 scale makes the division below fail with a dtype
    # mismatch when the inputs are float16/bfloat16/float64.
    d_k = tf.cast(tf.shape(k)[-1], tf.float32)
    sqrt_dk = tf.cast(tf.math.sqrt(d_k), matmul_qk.dtype)

    scaled_attention_logits = matmul_qk / sqrt_dk
    # Normalize over the key axis so each query's weights sum to 1.
    weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    output = tf.matmul(weights, v)
    return output
    

In [13]:
'''
class 子类名(父类名):
    def __init__(self):
        super(子类名, self).__init__()
'''
class MultiHeadAttention(Layer):
    """Multi-head attention built from four Dense projections.

    The inputs are projected to ``d_model`` features, split into ``num_heads``
    heads of size ``d_model // num_heads``, attended per head with
    ``scaled_dot_product_attention``, merged back and passed through a final
    Dense projection.

    Args:
        d_model: Feature size of the projections and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
        **kwargs: Standard Keras ``Layer`` options (``name``, ``dtype``, ...).
    """

    def __init__(self, d_model, num_heads, **kwargs):
        # Forward the standard Layer options so the layer can be named,
        # given a dtype policy, and rebuilt from its config.
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.depth = self.d_model // self.num_heads
        # Query / key / value projections and the final output projection.
        self.wq = Dense(d_model)
        self.wk = Dense(d_model)
        self.wv = Dense(d_model)
        self.w = Dense(d_model)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super(MultiHeadAttention, self).get_config()
        config.update({"d_model": self.d_model, "num_heads": self.num_heads})
        return config

    def split_heads(self, x):
        """Split the last axis into (num_heads, depth) and move heads forward.

        [bs, seq, dim] -> [bs, seq, heads, subdim] -> [bs, heads, seq, subdim]
        """
        # Use the dynamic shape: the static x.shape[0] is None when the batch
        # size is unknown at trace time, which would break tf.reshape.
        batch_size = tf.shape(x)[0]
        x = tf.reshape(x, [batch_size, -1, self.num_heads, self.depth])
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, q, k, v):
        """Apply multi-head attention and return a [bs, seq_q, d_model] tensor."""
        # The per-head projections are kept on the instance, as before, so they
        # remain available for inspection after a call.
        # NOTE(review): if nothing reads them, prefer locals - holding tensors
        # on a layer across calls is fragile under tf.function.
        self.q = self.split_heads(self.wq(q))
        self.k = self.split_heads(self.wk(k))
        self.v = self.split_heads(self.wv(v))

        per_head_output = scaled_dot_product_attention(self.q, self.k, self.v)

        # Dynamic batch size for the same reason as in split_heads.
        batch_size = tf.shape(q)[0]
        # [bs, heads, seq, subdim] -> [bs, seq, heads, subdim]
        merged_heads = tf.transpose(per_head_output, perm=[0, 2, 1, 3])
        # Concatenate the heads back into a single d_model-wide feature axis.
        concatenated = tf.reshape(merged_heads, [batch_size, -1, self.d_model])
        return self.w(concatenated)



In [14]:
# Build the layer from the demo hyperparameters and run it once on q, k, v;
# the result (shape [batch_size, q_seq_len, d_model]) is shown as cell output.
mha = MultiHeadAttention(num_heads=num_heads, d_model=d_model)
mha(q, k, v)

<tf.Tensor: shape=(32, 10, 512), dtype=float32, numpy=
array([[[-0.00364481, -0.18792947,  0.24575943, ..., -0.3051786 ,
          0.05009094,  0.49072775],
        [-0.3751133 , -0.16968033,  0.63417685, ..., -0.42369777,
         -0.28306895,  0.23018757],
        [ 0.02423792, -0.34091505,  0.26673567, ..., -0.394813  ,
          0.3934161 ,  0.4681937 ],
        ...,
        [ 0.51455647, -0.3019976 ,  0.1784493 , ...,  0.05739626,
          0.11175066,  0.2889152 ],
        [-0.5169923 , -0.14276664, -0.5098453 , ..., -0.6805406 ,
         -0.2078146 , -0.01302366],
        [-0.26270923, -0.25190753,  0.17749542, ..., -0.2725378 ,
         -0.27025217,  0.3742838 ]],

       [[ 0.12412558,  0.26256436, -0.33459312, ..., -0.4053634 ,
         -0.04337247,  0.7351228 ],
        [-0.1413969 ,  0.05620397,  0.17358118, ..., -0.70191425,
          0.270257  ,  0.62721497],
        [ 0.32794735,  0.27703953,  0.39264506, ..., -0.35090202,
         -0.14802665,  0.248271  ],
        ...,