In [1]:
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers

In [2]:
class PositionalEmbedding(layers.Layer):
    """Sum of a token embedding and a learned position embedding.

    Args:
        sequence_length: Number of distinct positions the position table
            holds (rows of the position embedding).
        input_dim: Size of the token vocabulary (rows of the token embedding).
        output_dim: Width of both embedding tables.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)

        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.token_embeddings = layers.Embedding(input_dim=input_dim, output_dim=output_dim)
        # The position table is looked up with position indices (see `call`),
        # so it must be sized by the sequence length, not by the vocabulary.
        self.position_embeddings = layers.Embedding(input_dim=sequence_length, output_dim=output_dim)

    def call(self, inputs):
        """Embed ``inputs`` and add the embedding of each position index.

        The positions are ``0 .. length-1`` where ``length`` is the size of
        the last axis of ``inputs``; the position embeddings are added to the
        token embeddings by broadcasting.
        """
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)

        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)

        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        """Return a boolean mask that is False wherever ``inputs`` equals 0."""
        return tf.math.not_equal(inputs, 0)

    def get_config(self):
        """Return the base config extended with this layer's constructor args."""
        config = super().get_config()

        # Keys must match the __init__ parameter names so that
        # `from_config` can rebuild the layer.
        config.update(
            {
                'sequence_length': self.sequence_length,
                'input_dim': self.input_dim,
                'output_dim': self.output_dim
            }
        )
        return config

In [3]:
class TransformerEncoder(layers.Layer):
    """Self-attention block followed by a two-layer feed-forward block.

    Each sub-block is wrapped in a residual connection and layer
    normalization.

    Args:
        embed_dim: Used as ``key_dim`` of the attention layer and as the
            output width of the feed-forward block.
        dense_dim: Width of the inner feed-forward layer.
        num_heads: Number of attention heads.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)

        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads

        # Let Keras carry the incoming mask through this layer, matching
        # TransformerDecoder.
        self.supports_masking = True

        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

        self.attention = layers.MultiHeadAttention(key_dim=embed_dim, num_heads=num_heads)

        # A non-linearity is required between the two projections; without
        # it the pair of Dense layers is equivalent to a single linear map.
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation='relu'),
                layers.Dense(embed_dim)
            ]
        )

    def call(self, inputs, mask=None):
        """Apply self-attention and the feed-forward block to ``inputs``.

        Args:
            inputs: Tensor used as both query and value of the attention.
            mask: Optional mask; when given, a new axis is inserted at
                position 1 before it is passed as ``attention_mask``.
                Presumably a (batch, sequence) padding mask — TODO confirm.
        """
        if mask is not None:
            mask = mask[:, tf.newaxis, :]

        attention_output = self.attention(inputs, inputs, attention_mask=mask)

        proj_input = self.layernorm_1(inputs + attention_output)

        proj_output = self.dense_proj(proj_input)

        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        """Return the base config extended with this layer's constructor args."""
        config = super().get_config()

        # Keys must match the __init__ parameter names so that
        # `from_config` can rebuild the layer.
        config.update(
            {
                'embed_dim': self.embed_dim,
                'dense_dim': self.dense_dim,
                'num_heads': self.num_heads
            }
        )
        return config

In [5]:
class TransformerDecoder(layers.Layer):
    """Causal self-attention, cross-attention over encoder output, then a
    two-layer feed-forward block.

    Each of the three sub-blocks is wrapped in a residual connection and
    layer normalization.

    Args:
        embed_dim: Used as ``key_dim`` of both attention layers and as the
            output width of the feed-forward block.
        dense_dim: Width of the inner feed-forward layer.
        num_heads: Number of attention heads.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)

        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads

        self.supports_masking = True

        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()

        self.attention_1 = layers.MultiHeadAttention(key_dim=embed_dim, num_heads=num_heads)
        self.attention_2 = layers.MultiHeadAttention(key_dim=embed_dim, num_heads=num_heads)

        # A non-linearity is required between the two projections; without
        # it the pair of Dense layers is equivalent to a single linear map.
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation='relu'),
                layers.Dense(embed_dim)
            ]
        )

    def get_causal_attention_mask(self, inputs):
        """Build an int32 lower-triangular mask, one copy per batch element.

        The result has shape (batch_size, sequence_length, sequence_length),
        taken from the first two dimensions of ``inputs``; entry [b, i, j]
        is 1 when i >= j and 0 otherwise.
        """
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]

        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)

        mask = tf.cast(i >= j, dtype='int32')
        mask = tf.reshape(mask, (1, sequence_length, sequence_length))

        # Tile multiples are [batch_size, 1, 1]: repeat only along the batch axis.
        mult = tf.concat(
            [
                tf.expand_dims(batch_size, -1),
                tf.constant([1, 1], dtype=tf.int32)
            ], axis=0
        )

        return tf.tile(mask, mult)

    def call(self, inputs, encoder_output, mask=None):
        """Run the decoder block.

        Args:
            inputs: Tensor used as query, key and value of the first
                (causal) attention.
            encoder_output: Tensor used as key and value of the second
                attention.
            mask: Optional mask; when given it is combined with the causal
                mask and applied to the second attention.
        """
        causal_mask = self.get_causal_attention_mask(inputs)

        if mask is not None:
            # Insert an axis at position 1, then keep only entries allowed by
            # both the supplied mask and the causal mask.
            padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype='int32')
            padding_mask = tf.minimum(padding_mask, causal_mask)
        else:
            padding_mask = mask

        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=causal_mask
        )

        attention_output_1 = self.layernorm_1(inputs + attention_output_1)

        # NOTE(review): padding_mask is derived from the mask for `inputs`
        # and sized by its sequence length, while the keys here come from
        # `encoder_output` — confirm both sequences share the same length.
        attention_output_2 = self.attention_2(
            query=attention_output_1,
            value=encoder_output,
            key=encoder_output,
            attention_mask=padding_mask
        )

        attention_output_2 = self.layernorm_2(attention_output_1 + attention_output_2)

        proj_output = self.dense_proj(attention_output_2)

        return self.layernorm_3(attention_output_2 + proj_output)

    def get_config(self):
        """Return the base config extended with this layer's constructor args."""
        config = super().get_config()

        # Keys must match the __init__ parameter names so that
        # `from_config` can rebuild the layer.
        config.update(
            {
                'embed_dim': self.embed_dim,
                'dense_dim': self.dense_dim,
                'num_heads': self.num_heads
            }
        )
        return config