In [1]:
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
import math

In [2]:
batch_size = 1
max_seqlen = 200
d_model = 512
n_heads = 8
x = tf.random.normal((batch_size, max_seqlen, d_model))

In [3]:
def scaled_dot_product(q, k, v, mask = None):
  scaled_dotproduct = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) / math.sqrt(q.shape[-1])
  attention_weights = tf.nn.softmax(scaled_dotproduct, axis = -1)
  if mask is not None:
    attention_weights += mask
  values = tf.matmul(attention_weights, v)
  return values, attention_weights

class MultiHeadAttention(tf.keras.Model):
  def __init__(self, d_model, n_heads):
    super().__init__()
    self.n_heads = n_heads
    self.head_dim = d_model // n_heads
    self.qkv_dense = Dense(d_model*3)
    self.out_dense = Dense(d_model)

  def call(self, inputs, mask=None):
    batch_size, max_seqlen, d_model = inputs.shape
    qkv = self.qkv_dense(inputs)
    qkv = tf.reshape(qkv, (batch_size, max_seqlen, self.n_heads, 3*self.head_dim))
    qkv = tf.transpose(qkv, perm=[0, 2, 1, 3])
    q, k, v = tf.split(qkv, 3, axis = -1)
    values, attention_weights = scaled_dot_product(q, k, v, mask)
    out = self.out_dense(values)
    return out

In [4]:
mha = MultiHeadAttention(d_model, n_heads)
mha(x)

<tf.Tensor: shape=(1, 8, 200, 512), dtype=float32, numpy=
array([[[[-0.03157103,  0.00248095,  0.03175261, ...,  0.01633478,
           0.03105789,  0.00479053],
         [-0.00023916,  0.00502656,  0.0308854 , ...,  0.00684066,
           0.0341939 ,  0.03702442],
         [ 0.00504123, -0.01155712,  0.02114235, ...,  0.01511279,
           0.00699577,  0.00282884],
         ...,
         [-0.01668626, -0.02183168,  0.04970204, ...,  0.0194302 ,
           0.01919152, -0.01221314],
         [-0.01511907, -0.00300023,  0.01551432, ...,  0.01366398,
           0.03087438,  0.01827713],
         [-0.03672407, -0.00609656,  0.03056387, ...,  0.01352291,
           0.0137173 , -0.0319193 ]],

        [[-0.02130445, -0.01039708,  0.01162839, ...,  0.00250054,
          -0.01752175,  0.01471678],
         [-0.0239347 , -0.01950839,  0.02158425, ...,  0.01301685,
          -0.01967132,  0.02036212],
         [-0.02280619, -0.01825266,  0.01244256, ...,  0.03562924,
          -0.00111773,  0.0