In [1]:
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Dropout, LayerNormalization
import math

In [2]:
batch_size = 1
max_seqlen = 200
d_model = 512
n_heads = 8
n_blocks = 5
hidden = 1024
drop_prob = 0.1

In [3]:
x = tf.random.normal((batch_size, max_seqlen, d_model))

In [4]:
def scaled_dot_product(q, k, v, mask = None):
  scaled_dotproduct = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) / math.sqrt(q.shape[-1])
  if mask is not None:
    scaled_dotproduct += mask
  attention_weights = tf.nn.softmax(scaled_dotproduct, axis = -1)
  values = tf.matmul(attention_weights, v)
  return values, attention_weights

class MultiHeadAttention(tf.keras.layers.Layer):
  def __init__(self, d_model, n_heads):
    super().__init__()
    self.n_heads = n_heads
    self.head_dim = d_model // n_heads
    self.qkv_dense = Dense(d_model*3)
    self.out_dense = Dense(d_model)

  def call(self, inputs, mask=None):
    batch_size, max_seqlen, d_model = inputs.shape
    qkv = self.qkv_dense(inputs)
    qkv = tf.reshape(qkv, (batch_size, max_seqlen, self.n_heads, 3*self.head_dim))
    qkv = tf.transpose(qkv, perm=[0, 2, 1, 3])
    q, k, v = tf.split(qkv, 3, axis = -1)
    values, attention_weights = scaled_dot_product(q, k, v, mask)
    values = tf.reshape(values, (batch_size, max_seqlen, self.n_heads*self.head_dim))
    out = self.out_dense(values)
    return out

In [5]:
class PositionalEncoding:
  def __init__(self, max_seqlen, d_model):
    self.max_seqlen = max_seqlen
    self.d_model = d_model

  def call(self):
    even_indices = tf.range(0, self.d_model, 2, dtype=tf.float32)
    den = tf.math.pow(10000, 2*even_indices/self.d_model)
    pos = tf.reshape(tf.range(0, self.max_seqlen, 1, dtype=tf.float32), (self.max_seqlen, 1))
    even_pos = tf.math.sin(pos/den)
    odd_pos = tf.math.cos(pos/den)
    pos = tf.stack([even_pos, odd_pos], axis=2)
    pos = tf.reshape(pos, (self.max_seqlen, self.d_model))
    return pos

In [6]:
class FeedForward(tf.keras.layers.Layer):
  def __init__(self, d_model, hidden, drop_prob=0.1):
    super().__init__()
    self.layer1 = Dense(hidden, activation='relu')
    self.layer2 = Dense(d_model)
    self.dropout = Dropout(drop_prob)
        
  def call(self, x):
    x = self.layer1(x)
    x = self.dropout(x)
    x = self.layer2(x)
    return x

In [7]:
class MultiHeadCrossAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.kv_layer = Dense(2*d_model)
        self.q_layer = Dense(d_model)
        self.linear = Dense(d_model)
        
    def call(self, x, y, mask=None):
        batch_size, max_seqlen, d_model = x.shape
        kv = self.kv_layer(x)
        q = self.q_layer(y)
        kv = tf.reshape(kv, (batch_size, max_seqlen, self.n_heads, 2*self.head_dim))
        q = tf.reshape(q, (batch_size, max_seqlen, self.n_heads, self.head_dim))
        kv = tf.transpose(kv, perm=[0, 2, 1, 3])
        q = tf.transpose(q, perm=[0, 2, 1, 3])
        k, v = tf.split(kv, 2, axis = -1)
        values, attention = scaled_dot_product(q, k, v, mask)
        values = tf.reshape(values, (batch_size, max_seqlen, self.n_heads*self.head_dim))
        out = self.linear(values)
        return out
    
    
class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, ffn_hidden, n_heads, drop_prob):
        super().__init__()
        self.selfmha = MultiHeadAttention(d_model, n_heads)
        self.norm1 = LayerNormalization()
        self.dropout1 = Dropout(drop_prob)
        self.crossmha = MultiHeadCrossAttention(d_model, n_heads)
        self.norm2 = LayerNormalization()
        self.dropout2 = Dropout(drop_prob)
        self.ffn = FeedForward(d_model, ffn_hidden, drop_prob)
        self.norm3 = LayerNormalization()
        self.dropout3 = Dropout(drop_prob)
        
    def call(self, x, y, decoder_mask):
        _y = y
        y = self.selfmha(x)
        y = self.dropout1(y)
        y = self.norm1(y + _y)
        _y = y
        y = self.crossmha(x, y, decoder_mask)
        y = self.dropout2(y)
        y = self.norm2(_y+y)
        _y = y
        y = self.ffn(y)
        y = self.dropout3(y)
        y = self.norm3(y+_y)
        
        return y
    
    
class Decoder(tf.keras.layers.Layer):
    def __init__(self, d_model, ffn_hidden, n_heads, drop_prob, n_blocks):
        super().__init__()
        self.decoders = [DecoderLayer(d_model, ffn_hidden, n_heads, drop_prob)\
                                         for _ in range(n_blocks)]
    def call(self, x, y, mask):
        for decoder in self.decoders:
          y = decoder(x, y, mask)
        return y

In [8]:
d_model = 512
num_heads = 8
drop_prob = 0.1
batch_size = 30
max_sequence_length = 200
ffn_hidden = 2048
num_layers = 5

x = tf.random.normal( (batch_size, max_sequence_length, d_model) )
y = tf.random.normal( (batch_size, max_sequence_length, d_model) )
mask = tf.fill([max_sequence_length, max_sequence_length] , float('-inf'))
mask = tf.experimental.numpy.triu(mask, k=1)
decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers)
out = decoder(x, y, mask)

In [9]:
out

<tf.Tensor: shape=(30, 200, 512), dtype=float32, numpy=
array([[[ 0.13932322,  0.13584025, -0.4594161 , ...,  0.24291672,
         -0.5508995 ,  1.515657  ],
        [ 0.33804357,  0.41118777, -0.9975303 , ..., -0.9210208 ,
          0.29964545,  0.32628384],
        [-0.70708025, -0.21153395, -0.5459612 , ...,  0.27436   ,
         -0.51453066,  1.382289  ],
        ...,
        [ 0.60298365,  0.00809366,  0.40203333, ..., -1.1408042 ,
         -1.3849393 ,  1.0232967 ],
        [ 0.05200108, -0.4694514 , -3.3843205 , ..., -1.2156051 ,
         -0.49520615,  0.03540848],
        [ 0.23965476,  0.40160403, -1.7326179 , ..., -1.0629336 ,
         -1.5942286 ,  0.01689267]],

       [[ 0.75725156, -2.4038734 , -0.80469906, ..., -1.6158617 ,
         -0.16448896,  1.3181194 ],
        [-0.0853722 , -0.7652514 , -0.64315945, ...,  1.2671533 ,
         -0.86594146,  0.204378  ],
        [-2.1040978 ,  0.16972065,  0.1381669 , ...,  0.73474175,
         -1.8635955 ,  0.03357337],
        ...