In [1]:
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Dropout, LayerNormalization
import math

In [2]:
# Hyperparameters shared by the demo input tensor and the encoder built below.
batch_size = 1  # first dimension of the random input tensor `x`
max_seqlen = 200  # second dimension of `x` (number of sequence positions)
d_model = 512  # last dimension of `x`; also the output width of the attention and feed-forward layers
n_heads = 8  # attention heads per MultiHeadAttention (per-head size is d_model // n_heads)
n_blocks = 5  # number of EncoderLayer instances stacked inside Encoder
hidden = 1024  # units of the first (ReLU) Dense layer inside FeedForward
drop_prob = 0.1  # rate passed to every Dropout layer in the encoder

In [3]:
# Dummy encoder input of shape (batch_size, max_seqlen, d_model), sampled from a
# standard normal distribution (the arguments below are the function's defaults).
x = tf.random.normal(
    shape=(batch_size, max_seqlen, d_model),
    mean=0.0,
    stddev=1.0,
    dtype=tf.float32,
)

In [4]:
def scaled_dot_product(q, k, v, mask=None):
  """Scaled dot-product attention.

  Computes ``softmax(q @ k^T / sqrt(d_k) + mask) @ v`` where ``d_k`` is the
  size of the last axis of `q`. Only the last two axes take part in the
  matrix products, so any number of leading (batch / head) axes is accepted.

  Args:
    q: Query tensor whose last two axes are ``(seq_q, d_k)``.
    k: Key tensor whose last two axes are ``(seq_k, d_k)``.
    v: Value tensor whose last two axes are ``(seq_k, d_v)``.
    mask: Optional tensor that is *added* to the scaled scores before the
      softmax; it must broadcast against ``(..., seq_q, seq_k)``.

  Returns:
    A tuple ``(values, attention_weights)``: the attention-weighted sum of
    `v` and the softmax weights taken over the key axis.
  """
  # transpose_b swaps only the two innermost axes of k, so the rank of the
  # inputs is not hard-coded (the 4-D case gives the same result as before).
  scores = tf.matmul(q, k, transpose_b=True) / math.sqrt(q.shape[-1])
  if mask is not None:
    scores += mask
  attention_weights = tf.nn.softmax(scores, axis=-1)
  values = tf.matmul(attention_weights, v)
  return values, attention_weights

class MultiHeadAttention(tf.keras.layers.Layer):
  """Multi-head self-attention with a fused query/key/value projection.

  Args:
    d_model: Size of the last axis of the layer's input and output.
    n_heads: Number of attention heads; must divide `d_model` evenly.

  Raises:
    ValueError: If `d_model` is not a multiple of `n_heads`.
  """

  def __init__(self, d_model, n_heads):
    super().__init__()
    if d_model % n_heads != 0:
      # Otherwise the reshape in call() fails later with an opaque shape error.
      raise ValueError(
          f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
    self.n_heads = n_heads
    self.head_dim = d_model // n_heads
    self.qkv_dense = Dense(d_model*3)
    self.out_dense = Dense(d_model)

  def call(self, inputs, mask=None):
    """Applies self-attention to `inputs`.

    Args:
      inputs: Rank-3 tensor ``(batch, seq_len, d_model)``.
      mask: Optional tensor forwarded to `scaled_dot_product`, which adds it
        to the attention scores before the softmax.

    Returns:
      Tensor of shape ``(batch, seq_len, d_model)``.
    """
    # Read the leading dimensions dynamically so that an unknown (None) batch
    # size or sequence length does not break the reshapes below.
    input_shape = tf.shape(inputs)
    batch_size, seq_len = input_shape[0], input_shape[1]

    qkv = self.qkv_dense(inputs)
    qkv = tf.reshape(qkv, (batch_size, seq_len, self.n_heads, 3*self.head_dim))
    # -> (batch, n_heads, seq_len, 3*head_dim)
    qkv = tf.transpose(qkv, perm=[0, 2, 1, 3])
    q, k, v = tf.split(qkv, 3, axis=-1)

    # values: (batch, n_heads, seq_len, head_dim)
    values, _ = scaled_dot_product(q, k, v, mask)

    # Move the head axis back next to head_dim BEFORE flattening; reshaping the
    # (batch, n_heads, seq_len, head_dim) tensor directly would interleave
    # different sequence positions and heads in the output rows.
    values = tf.transpose(values, perm=[0, 2, 1, 3])
    values = tf.reshape(values, (batch_size, seq_len, self.n_heads*self.head_dim))
    return self.out_dense(values)

In [5]:
class PositionalEncoding:
  """Sinusoidal positional-encoding table.

  For position ``pos`` and even feature index ``2i``::

      PE[pos, 2i]     = sin(pos / 10000 ** (2i / d_model))
      PE[pos, 2i + 1] = cos(pos / 10000 ** (2i / d_model))

  Args:
    max_seqlen: Number of positions (rows) in the table.
    d_model: Number of features (columns) in the table.
  """

  def __init__(self, max_seqlen, d_model):
    self.max_seqlen = max_seqlen
    self.d_model = d_model

  def call(self):
    """Builds the table.

    Returns:
      A float32 tensor of shape ``(max_seqlen, d_model)``.
    """
    # even_indices already holds the values 2i (0, 2, 4, ...), so it must not
    # be doubled again when forming the exponent 2i / d_model.
    even_indices = tf.range(0, self.d_model, 2, dtype=tf.float32)
    den = tf.math.pow(10000.0, even_indices / self.d_model)
    pos = tf.reshape(tf.range(0, self.max_seqlen, 1, dtype=tf.float32), (self.max_seqlen, 1))
    angles = pos / den  # (max_seqlen, ceil(d_model / 2))
    # Stacking on a new last axis and flattening interleaves the columns as
    # sin, cos, sin, cos, ...
    encoding = tf.stack([tf.math.sin(angles), tf.math.cos(angles)], axis=2)
    encoding = tf.reshape(encoding, (self.max_seqlen, -1))
    # For odd d_model the interleave yields one surplus cosine column; drop it.
    # For even d_model this slice keeps every column.
    return encoding[:, :self.d_model]

In [6]:
class FeedForward(tf.keras.layers.Layer):
  """Two-layer feed-forward block: Dense(ReLU) -> Dropout -> Dense.

  Args:
    d_model: Number of units of the second (output) Dense layer.
    hidden: Number of units of the first, ReLU-activated Dense layer.
    drop_prob: Rate of the Dropout applied between the two Dense layers.
  """

  def __init__(self, d_model, hidden, drop_prob=0.1):
    super().__init__()
    self.layer1 = Dense(units=hidden, activation='relu')
    self.layer2 = Dense(units=d_model)
    self.dropout = Dropout(rate=drop_prob)

  def call(self, x):
    """Runs `x` through layer1, dropout and layer2, in that order."""
    expanded = self.layer1(x)
    return self.layer2(self.dropout(expanded))

In [7]:
class EncoderLayer(tf.keras.layers.Layer):
  """One encoder block: self-attention and feed-forward sub-layers.

  Each sub-layer output goes through dropout, is added to the sub-layer's
  input (residual connection) and is then layer-normalised.

  Args:
    d_model: Feature size handed to the attention and feed-forward sub-layers.
    n_heads: Number of attention heads.
    hidden: Hidden width of the feed-forward sub-layer.
    drop_prob: Dropout rate used by both dropouts and the feed-forward block.
  """

  def __init__(self, d_model, n_heads, hidden, drop_prob):
    super().__init__()
    # Attention sub-layer and its dropout / normalisation.
    self.attention = MultiHeadAttention(d_model, n_heads)
    self.norm1 = LayerNormalization()
    self.dropout1 = Dropout(drop_prob)
    # Feed-forward sub-layer and its dropout / normalisation.
    self.ffn = FeedForward(d_model, hidden, drop_prob)
    self.norm2 = LayerNormalization()
    self.dropout2 = Dropout(drop_prob)

  def call(self, x):
    """Applies both sub-layers to `x` and returns the normalised result."""
    # Sub-layer 1: unmasked self-attention -> dropout -> add input -> norm.
    attended = self.dropout1(self.attention(x, mask=None))
    x = self.norm1(attended + x)

    # Sub-layer 2: feed-forward -> dropout -> add input -> norm.
    transformed = self.dropout2(self.ffn(x))
    return self.norm2(transformed + x)

In [8]:
class Encoder(tf.keras.layers.Layer):
  """Stack of `n_blocks` EncoderLayer instances applied one after another.

  Args:
    n_blocks: How many EncoderLayer blocks to chain.
    d_model: Feature size forwarded to every block.
    n_heads: Number of attention heads forwarded to every block.
    hidden: Feed-forward hidden width forwarded to every block.
    drop_prob: Dropout rate forwarded to every block.
  """

  def __init__(self, n_blocks, d_model, n_heads, hidden, drop_prob):
    super().__init__()
    self.n_blocks = n_blocks
    blocks = []
    for _ in range(n_blocks):
      blocks.append(EncoderLayer(d_model, n_heads, hidden, drop_prob))
    # Sequential feeds each block's output into the next block.
    self.encoders = Sequential(blocks)

  def call(self, inputs):
    """Passes `inputs` through every block in order."""
    return self.encoders(inputs)

In [9]:
# Build the encoder from the notebook's hyperparameters and run the dummy
# batch through it; the bare last expression displays the resulting tensor.
en = Encoder(
    n_blocks=n_blocks,
    d_model=d_model,
    n_heads=n_heads,
    hidden=hidden,
    drop_prob=drop_prob,
)
en(x)

<tf.Tensor: shape=(1, 200, 512), dtype=float32, numpy=
array([[[ 0.8720597 ,  0.72465485,  1.0577979 , ...,  2.0518732 ,
         -0.11343463, -0.4436577 ],
        [ 1.4336817 , -0.54340327,  1.1054131 , ...,  0.2823553 ,
         -0.9729718 , -0.12151792],
        [ 0.0448431 ,  1.3235494 ,  1.6045357 , ...,  1.3869084 ,
          0.45061648, -0.51091754],
        ...,
        [-0.63664484,  1.7898706 ,  1.0812165 , ..., -0.05531621,
         -0.4358906 , -0.18574098],
        [ 0.9070335 , -0.18310636,  2.0024514 , ..., -0.6505261 ,
         -1.6465125 , -0.5812339 ],
        [-1.4111273 ,  1.0861882 ,  1.5656976 , ...,  1.3039907 ,
         -1.1497822 ,  0.87188476]]], dtype=float32)>