In [None]:
import tensorflow as tf
from Dive_into_deep_learning.d2l import tensorflow as d2l

In [None]:
class MultiHeadAttention(d2l.Module):
    """Multi-head attention built on ``d2l.DotProductAttention``.

    Queries, keys and values are each projected to ``num_hiddens`` features,
    split into ``num_heads`` heads of width ``num_hiddens // num_heads``,
    attended over in parallel, then the heads are concatenated and projected
    back to ``num_hiddens`` features by ``W_o``.

    :param key_size: accepted for API compatibility; unused (the Dense
        projections infer their input width on first call)
    :param query_size: accepted for API compatibility; unused
    :param value_size: accepted for API compatibility; unused
    :param num_hiddens: output width of every projection; must be divisible
        by ``num_heads``
    :param num_heads: number of attention heads
    :param dropout: dropout rate handed to ``d2l.DotProductAttention``
    :param bias: whether the four projection layers use a bias term
    :param kwargs: ignored
    :raises ValueError: if ``num_hiddens`` is not divisible by ``num_heads``
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        # Fail fast here instead of with an opaque reshape error on first call.
        if num_hiddens % num_heads != 0:
            raise ValueError(
                f"num_hiddens ({num_hiddens}) must be divisible by "
                f"num_heads ({num_heads})")
        super().__init__()
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_v = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_o = tf.keras.layers.Dense(num_hiddens, use_bias=bias)

    def call(self, queries, keys, values, valid_lens, **kwargs):
        """Run multi-head attention.

        :param queries: (batch_size, no. of queries, query features)
        :param keys: (batch_size, no. of key-value pairs, key features)
        :param values: (batch_size, no. of key-value pairs, value features)
        :param valid_lens: None, or a tensor whose first axis is batch_size
            (e.g. shape (batch_size,)); handed to the attention module
        :param kwargs: forwarded to the attention module (e.g. ``training``)
        :return: (batch_size, no. of queries, num_hiddens)
        """
        # Each becomes (batch_size * num_heads, seq_len, num_hiddens / num_heads).
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # transpose_qkv puts head h of example b at row b * num_heads + h,
            # and tf.repeat(axis=0) repeats each entry num_heads times in a
            # row, so every head of an example gets that example's lengths.
            valid_lens = tf.repeat(valid_lens, repeats=self.num_heads, axis=0)

        output = self.attention(queries, keys, values, valid_lens, **kwargs)

        # (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)

    @staticmethod
    def _shape_list(X):
        """Return the dimensions of ``X`` as a list.

        Dimensions known statically are returned as Python ints. Dimensions
        unknown at trace time (``None`` in ``X.shape``, e.g. the batch axis
        inside a ``tf.function``) are returned as scalar tensors taken from
        ``tf.shape(X)``. Passing ``X.shape[i]`` straight to ``tf.reshape``
        would hand it ``None`` in that case and fail.
        """
        static = X.shape.as_list()
        dynamic = tf.shape(X)
        return [dim if dim is not None else dynamic[i]
                for i, dim in enumerate(static)]

    def transpose_qkv(self, X):
        """Split the feature axis into heads and fold the heads into the batch axis.

        :param X: (batch_size, no. of queries or key-value pairs, num_hiddens)
        :return: (batch_size * num_heads, no. of queries or key-value pairs,
            num_hiddens / num_heads)
        """
        _, seq_len, num_hiddens = self._shape_list(X)
        # Explicit width (not -1) keeps the trailing dimension static in graphs.
        head_dim = num_hiddens // self.num_heads
        # (batch_size, seq_len, num_heads, head_dim)
        X = tf.reshape(X, (-1, seq_len, self.num_heads, head_dim))
        # (batch_size, num_heads, seq_len, head_dim)
        X = tf.transpose(X, (0, 2, 1, 3))
        # (batch_size * num_heads, seq_len, head_dim); row b * num_heads + h
        # holds head h of example b.
        return tf.reshape(X, (-1, seq_len, head_dim))

    def transpose_output(self, X):
        """Invert ``transpose_qkv``: unfold the heads and concatenate them.

        :param X: (batch_size * num_heads, no. of queries or key-value pairs,
            num_hiddens / num_heads)
        :return: (batch_size, no. of queries or key-value pairs, num_hiddens)
        """
        _, seq_len, head_dim = self._shape_list(X)
        # (batch_size, num_heads, seq_len, head_dim)
        X = tf.reshape(X, (-1, self.num_heads, seq_len, head_dim))
        # (batch_size, seq_len, num_heads, head_dim)
        X = tf.transpose(X, (0, 2, 1, 3))
        # Concatenate the heads along the feature axis. The explicit width
        # keeps the last dimension static, which the Dense layer W_o needs.
        return tf.reshape(X, (-1, seq_len, self.num_heads * head_dim))

In [None]:
# Shape check: 5 heads over 100 hidden units; a batch of 2 examples, each
# with 4 queries attending over 6 key-value pairs.
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(
    key_size=num_hiddens, query_size=num_hiddens, value_size=num_hiddens,
    num_hiddens=num_hiddens, num_heads=num_heads, dropout=0.5)

batch_size, num_queries, num_kvpairs = 2, 4, 6
# One valid length per example: 3 for the first, 2 for the second.
valid_lens = tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
Y = tf.ones((batch_size, num_kvpairs, num_hiddens))
# Y serves as both keys and values; training=False is passed through to the
# layer call. The result must hold one num_hiddens-wide vector per query.
d2l.check_shape(
    attention(X, Y, Y, valid_lens, training=False),
    (batch_size, num_queries, num_hiddens),
)