In [1]:
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
import math

In [2]:
# Seed TensorFlow's global generator so the random input below is the same
# on every Restart & Run All.
tf.random.set_seed(42)

# Dimensions of the toy input batch fed to the attention layer.
batch_size = 1    # first axis of x
max_seqlen = 200  # second axis of x
d_model = 512     # last axis of x; also the layer's projection width
n_heads = 8       # number of heads d_model is split across
x = tf.random.normal((batch_size, max_seqlen, d_model))

In [3]:
def scaled_dot_product(q, k, v, mask = None):
  """Compute scaled dot-product attention.

  Scores are `q @ k^T / sqrt(depth)`, where `depth` is the size of the last
  axis of `q`; a softmax over the last axis turns them into weights that are
  then used to average `v`.

  Args:
    q: Query tensor whose last two axes are (query positions, depth).
    k: Key tensor whose last two axes are (key positions, depth).
    v: Value tensor whose second-to-last axis matches the key positions.
    mask: Optional tensor added to the scaled scores before the softmax.
      Presumably 0 for visible positions and a large negative value for
      hidden ones -- confirm against callers.

  Returns:
    A `(values, attention_weights)` tuple.
  """
  # transpose_b swaps only the last two axes of k, so any number of leading
  # batch/head axes is accepted instead of exactly rank 4.
  scores = tf.matmul(q, k, transpose_b=True) / math.sqrt(q.shape[-1])
  if mask is not None:
    scores = scores + mask
  attention_weights = tf.nn.softmax(scores, axis=-1)
  values = tf.matmul(attention_weights, v)
  return values, attention_weights

class MultiHeadAttention(tf.keras.layers.Layer):
  """Multi-head self-attention layer.

  The input is projected once into concatenated queries, keys and values,
  split into `n_heads` heads, passed through `scaled_dot_product`, and the
  per-head results are merged again and sent through an output projection.
  """

  def __init__(self, d_model, n_heads):
    """Create the projection layers.

    Args:
      d_model: Width of the output projection; the fused q/k/v projection is
        three times this wide.
      n_heads: Number of attention heads; must divide `d_model` evenly.

    Raises:
      ValueError: If `d_model` is not divisible by `n_heads`.
    """
    if d_model % n_heads != 0:
      # Without this check the reshape in call() fails with an opaque
      # element-count mismatch.
      raise ValueError(
          f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
    super().__init__()
    self.n_heads = n_heads
    self.head_dim = d_model // n_heads
    self.qkv_dense = Dense(d_model*3)
    self.out_dense = Dense(d_model)

  def call(self, inputs, mask=None):
    """Apply self-attention to `inputs`.

    Args:
      inputs: Rank-3 tensor; the code treats the axes as
        (batch, sequence, features).
      mask: Optional tensor forwarded unchanged to `scaled_dot_product`.

    Returns:
      Tensor of shape (batch, sequence, d_model).
    """
    # Read the sizes at run time so an unknown (None) batch or sequence
    # length does not break the reshapes below.
    input_shape = tf.shape(inputs)
    batch_size, seq_len = input_shape[0], input_shape[1]

    qkv = self.qkv_dense(inputs)
    qkv = tf.reshape(qkv, (batch_size, seq_len, self.n_heads, 3*self.head_dim))
    # (batch, seq, heads, 3*head_dim) -> (batch, heads, seq, 3*head_dim)
    qkv = tf.transpose(qkv, perm=[0, 2, 1, 3])
    q, k, v = tf.split(qkv, 3, axis = -1)

    values, _ = scaled_dot_product(q, k, v, mask)

    # Move heads back next to head_dim BEFORE merging them; reshaping the
    # (batch, heads, seq, head_dim) tensor directly would mix sequence
    # positions and heads in the merged features.
    values = tf.transpose(values, perm=[0, 2, 1, 3])
    values = tf.reshape(values, (batch_size, seq_len, self.n_heads*self.head_dim))
    return self.out_dense(values)

In [4]:
# Build the layer and run it once on the random batch.
mha = MultiHeadAttention(d_model, n_heads)
mha_out = mha(x)
# The layer ends in a d_model-wide projection, so the output shape must equal
# the input shape.
assert tuple(mha_out.shape) == (batch_size, max_seqlen, d_model)
mha_out

<tf.Tensor: shape=(1, 200, 512), dtype=float32, numpy=
array([[[ 0.02877445, -0.01475262,  0.10030845, ...,  0.00272664,
          0.00338816, -0.02589648],
        [-0.00837563,  0.06653754,  0.09807103, ..., -0.05678674,
         -0.03421912,  0.01303486],
        [-0.07058281,  0.0352192 ,  0.03819446, ..., -0.06132259,
         -0.01706199,  0.03019537],
        ...,
        [ 0.00886711,  0.00980199, -0.03666598, ..., -0.0303993 ,
         -0.03986033, -0.0179459 ],
        [ 0.00486959, -0.08574107,  0.00299693, ..., -0.04498035,
         -0.02993053,  0.01326033],
        [-0.00330408, -0.02367062, -0.01247103, ..., -0.00755949,
         -0.0645581 ,  0.03299144]]], dtype=float32)>