In [1]:
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Dropout, LayerNormalization
import math

In [2]:
# Model hyper-parameters for this demo.
batch_size = 1
max_seqlen = 200    # sequence length of the dummy input
d_model = 512       # model width; also the last dim of the dummy input
n_heads = 8         # attention heads; d_model must be divisible by this
n_blocks = 5        # number of stacked encoder layers
hidden = 1024       # inner width of the feed-forward network
drop_prob = 0.1     # dropout rate
seed = 42           # global TF seed so the random input and weight init are reproducible

# head_dim = d_model // n_heads would silently truncate otherwise.
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
tf.random.set_seed(seed)

In [3]:
x = tf.random.normal(shape=(batch_size, max_seqlen, d_model))  # random stand-in input batch

In [4]:
def scaled_dot_product(q, k, v, mask=None):
  """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k) + mask) @ v.

  Args:
    q, k, v: tensors whose last two axes are (seq, depth); any leading axes
      are treated as batch axes (here presumably (batch, heads)).
    mask: optional additive mask broadcastable to the attention logits
      (..., seq_q, seq_k). Put large negative values (e.g. -1e9) at positions
      that must not be attended to.

  Returns:
    (values, attention_weights): the weighted sum of v and the softmax
    attention weights.
  """
  d_k = tf.cast(tf.shape(q)[-1], q.dtype)
  # transpose_b swaps only the last two axes of k, for any batch rank.
  logits = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(d_k)
  if mask is not None:
    # The mask must shift the logits *before* the softmax so masked positions
    # get ~zero probability and each row of weights still sums to 1.
    logits += mask
  attention_weights = tf.nn.softmax(logits, axis=-1)
  values = tf.matmul(attention_weights, v)
  return values, attention_weights

class MultiHeadAttention(tf.keras.Model):
  """Multi-head self-attention over an input of shape (batch, seq, d_model).

  A single Dense projects the input to 3*d_model features. These are split
  per head into query/key/value chunks of head_dim each, attended with
  scaled dot-product attention, merged back and projected to d_model.
  """

  def __init__(self, d_model, n_heads):
    super().__init__()
    if d_model % n_heads != 0:
      raise ValueError(
          f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
    self.n_heads = n_heads
    self.head_dim = d_model // n_heads
    self.qkv_dense = Dense(d_model*3)
    self.out_dense = Dense(d_model)

  def call(self, inputs, mask=None):
    """Apply self-attention to `inputs`.

    Args:
      inputs: tensor of shape (batch, seq, d_model).
      mask: optional additive mask passed to scaled_dot_product.

    Returns:
      Tensor of shape (batch, seq, d_model).
    """
    # Dynamic shape so batch/seq may be unknown statically (tf.function, etc.).
    dyn_shape = tf.shape(inputs)
    batch_size, max_seqlen = dyn_shape[0], dyn_shape[1]
    qkv = self.qkv_dense(inputs)
    qkv = tf.reshape(qkv, (batch_size, max_seqlen, self.n_heads, 3*self.head_dim))
    qkv = tf.transpose(qkv, perm=[0, 2, 1, 3])  # (batch, heads, seq, 3*head_dim)
    q, k, v = tf.split(qkv, 3, axis=-1)
    values, _ = scaled_dot_product(q, k, v, mask)
    # Bring the sequence axis back in front of the heads before merging them.
    # Reshaping (batch, heads, seq, head_dim) directly would mix positions
    # from different heads.
    values = tf.transpose(values, perm=[0, 2, 1, 3])  # (batch, seq, heads, head_dim)
    values = tf.reshape(values, (batch_size, max_seqlen, self.n_heads*self.head_dim))
    out = self.out_dense(values)
    return out

In [5]:
class PositionalEncoding:
  """Sinusoidal positional-encoding table.

  PE[pos, 2i]   = sin(pos / 10000**(2i / d_model))
  PE[pos, 2i+1] = cos(pos / 10000**(2i / d_model))
  """

  def __init__(self, max_seqlen, d_model):
    self.max_seqlen = max_seqlen
    self.d_model = d_model

  def call(self):
    """Return the (max_seqlen, d_model) float32 encoding table.

    Sin values occupy even feature indices and cos values odd ones. For an
    odd d_model the final (cos) column is dropped.
    """
    # even_indices already holds 2i (0, 2, 4, ...), so the exponent is
    # even_indices / d_model, not 2 * even_indices / d_model.
    even_indices = tf.range(0, self.d_model, 2, dtype=tf.float32)
    den = tf.math.pow(10000.0, even_indices / self.d_model)
    pos = tf.reshape(tf.range(0, self.max_seqlen, 1, dtype=tf.float32), (self.max_seqlen, 1))
    even_pos = tf.math.sin(pos / den)
    odd_pos = tf.math.cos(pos / den)
    # (max_seqlen, n_freq, 2) -> (max_seqlen, 2*n_freq) interleaves sin/cos.
    n_freq = (self.d_model + 1) // 2
    table = tf.stack([even_pos, odd_pos], axis=2)
    table = tf.reshape(table, (self.max_seqlen, 2 * n_freq))
    return table[:, :self.d_model]

In [6]:
class FeedForward(tf.keras.Model):
  """Position-wise feed-forward network: Dense(relu) -> Dropout -> Dense.

  Args:
    d_model: width of the output projection.
    hidden: width of the inner ReLU layer.
    drop_prob: dropout rate applied to the hidden activations.
  """

  def __init__(self, d_model, hidden, drop_prob=0.1):
    super().__init__()
    self.layer1 = Dense(hidden, activation='relu')
    self.layer2 = Dense(d_model)
    self.dropout = Dropout(drop_prob)

  def call(self, x):
    """Expand to `hidden`, apply dropout, project back to `d_model`."""
    hidden_act = self.dropout(self.layer1(x))
    return self.layer2(hidden_act)

In [7]:
class EncoderLayer(tf.keras.Model):
  """One post-norm Transformer encoder block.

  x -> attention -> dropout -> add(x) -> LayerNorm
    -> feed-forward -> dropout -> add -> LayerNorm
  """

  def __init__(self, d_model, n_heads, hidden, drop_prob):
    super().__init__()
    self.attention = MultiHeadAttention(d_model, n_heads)
    self.norm1 = LayerNormalization()
    self.dropout1 = Dropout(drop_prob)
    self.ffn = FeedForward(d_model, hidden, drop_prob)
    self.norm2 = LayerNormalization()
    self.dropout2 = Dropout(drop_prob)

  def call(self, x, mask=None):
    """Run the block on `x`.

    Args:
      x: input tensor; the residual additions need its last dim == d_model.
      mask: optional additive attention mask forwarded to the attention
        sub-layer. The default None means no masking, as before.
    """
    attn_out = self.dropout1(self.attention(x, mask=mask))
    x = self.norm1(x + attn_out)
    ffn_out = self.dropout2(self.ffn(x))
    return self.norm2(x + ffn_out)

In [8]:
class Encoder(tf.keras.Model):
  """Stack of `n_blocks` EncoderLayer blocks applied one after another."""

  def __init__(self, n_blocks, d_model, n_heads, hidden, drop_prob):
    super().__init__()
    self.n_blocks = n_blocks
    blocks = []
    for _ in range(n_blocks):
      blocks.append(EncoderLayer(d_model, n_heads, hidden, drop_prob))
    self.encoders = Sequential(blocks)

  def call(self, inputs):
    """Pass `inputs` through every encoder block in order."""
    return self.encoders(inputs)

In [9]:
en = Encoder(n_blocks=n_blocks, d_model=d_model, n_heads=n_heads, hidden=hidden, drop_prob=drop_prob)

In [10]:
# Forward pass on the dummy batch; the result has the same shape as x.
en(x)

<tf.Tensor: shape=(1, 200, 512), dtype=float32, numpy=
array([[[ 8.2368737e-01,  1.7308272e-01, -4.9270141e-01, ...,
          1.3057206e+00, -1.0925290e-01,  2.8918761e-01],
        [-4.2751050e-01, -3.4096795e-01, -4.3196020e-01, ...,
          1.0392998e+00,  1.0989662e+00,  8.3607346e-01],
        [ 4.4611350e-01,  5.9474868e-01, -3.6346868e-01, ...,
          2.7086868e+00,  1.4857063e+00,  1.1411437e+00],
        ...,
        [ 9.0194619e-01,  1.2233398e+00, -7.3240137e-01, ...,
         -3.7522590e-01,  4.2617995e-01, -2.2641669e-03],
        [-3.9516670e-01,  7.2224128e-01, -1.2510946e+00, ...,
          1.2202094e+00,  1.1120749e+00, -2.4744701e-01],
        [ 2.4680458e-01,  4.2367247e-01, -4.4969967e-01, ...,
          4.3569800e-01,  9.1172850e-01, -8.6346075e-02]]], dtype=float32)>