In [5]:
import tensorflow as tf

In [6]:
def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    The attention logits are ``q @ k^T`` divided by the square root of the
    size of the last axis of ``k``. An optional additive mask is applied to
    the logits before a softmax over the last axis, and the resulting
    weights are used to combine ``v``.

    Args:
        q: Query tensor; its last axis must match the last axis of ``k``.
        k: Key tensor.
        v: Value tensor; its second-to-last axis must match that of ``k``.
        mask: Optional tensor added to the scaled logits (for example large
            negative values at positions that should receive ~zero weight).
            Must be broadcastable to the logits. ``None`` disables masking.

    Returns:
        A tuple ``(output, attention_weights)`` where ``attention_weights``
        is the softmax of the (masked) scaled logits and ``output`` is
        ``attention_weights @ v``.
    """
    logits = tf.matmul(q, k, transpose_b=True)

    # Cast the key depth to the logits' own dtype instead of a hard-coded
    # float32: TensorFlow does not promote dtypes implicitly, so a float32
    # scale would make the division fail for float16/bfloat16/float64 inputs.
    d_k = tf.cast(tf.shape(k)[-1], logits.dtype)
    scaled_qk = logits / tf.math.sqrt(d_k)

    if mask is not None:
        scaled_qk += mask

    # Normalise over the key axis so each query's weights sum to one.
    attention_weights = tf.nn.softmax(scaled_qk, axis=-1)
    output = tf.matmul(attention_weights, v)
    return output, attention_weights

In [7]:
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head self-attention layer.

    A single bias-free dense layer projects the input to queries, keys and
    values in one pass. The projections are split into ``num_heads`` heads,
    attended with ``scaled_dot_product``, merged back together and passed
    through a final linear projection of width ``d_model``.

    Args:
        d_model: Width of the model; must be a positive multiple of
            ``num_heads``.
        num_heads: Number of attention heads.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``d_model`` or ``num_heads`` is not positive, or if
            ``num_heads`` does not divide ``d_model``.
    """

    def __init__(self, d_model, num_heads, **kwargs):
        super().__init__(**kwargs)
        # Without this check a non-divisible pair would floor ``head_dim`` and
        # the reshapes below could silently produce a wrong layout.
        if d_model <= 0 or num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of "
                f"num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # One fused projection producing q, k and v side by side.
        self.qkv_layer = tf.keras.layers.Dense(3 * d_model, use_bias=False)
        # Output projection. Kept linear (no activation): a ReLU here would
        # clamp the attention output to non-negative values.
        self.linear_layer = tf.keras.layers.Dense(d_model)

    def get_config(self):
        """Return the constructor arguments so Keras can re-create the layer."""
        config = super().get_config()
        config.update({"d_model": self.d_model, "num_heads": self.num_heads})
        return config

    def split_heads(self, x, batch_size):
        """Reshape ``x`` to ``(batch_size, num_heads, seq_len, head_dim)``.

        Args:
            x: Tensor whose elements can be viewed as
                ``(batch_size, seq_len, num_heads * head_dim)``.
            batch_size: Size of the leading axis (Python int or scalar tensor).

        Returns:
            The tensor with the head axis moved in front of the sequence axis.
        """
        # The reshape only depends on element order, so inputs of any rank
        # (including 2-D) are handled without adding unit axes first.
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.head_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Rank-3 input tensor ``(batch_size, seq_len, d_model)``.
            mask: Optional additive mask passed to ``scaled_dot_product``;
                it is added to logits of shape
                ``(batch_size, num_heads, seq_len, seq_len)`` and must
                broadcast against them.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_model)``.
        """
        # Use the dynamic shape: the static batch dimension is ``None`` when
        # the layer is traced with a variable batch size.
        batch_size = tf.shape(x)[0]

        qkv = self.qkv_layer(x)
        q, k, v = tf.split(qkv, 3, axis=-1)
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        values, _ = scaled_dot_product(q, k, v, mask)

        # Undo split_heads: (batch, heads, seq, head_dim) -> (batch, seq, d_model).
        values = tf.transpose(values, perm=[0, 2, 1, 3])
        values = tf.reshape(values, (batch_size, -1, self.d_model))
        out = self.linear_layer(values)
        # NOTE(review): debugging aid that prints on every eager call; kept so
        # the notebook's recorded output stays valid. Consider removing.
        print("MultiHeadAttention output shape is ",out.shape)
        return out

# Example usage: run the layer on a random batch and inspect the output shape.
# The input is laid out as (batch_size, seq_len, d_model) = (32, 30, 64).
# Queries, keys and values are all derived from this one tensor inside the layer.
demo_input_shape = (32, 30, 64)
q = tf.random.normal(demo_input_shape)

# Build the layer: d_model equals the last input axis and is split over 8 heads.
multihead_attention = MultiHeadAttention(d_model=demo_input_shape[-1], num_heads=8)

# Forward pass with no attention mask.
output = multihead_attention(q, mask=None)

# Show the shape produced by the layer for this input.
print("Output shape:", output.shape)


MultiHeadAttention output shape is  (32, 30, 64)
Output shape: (32, 30, 64)
