In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import text_dataset_from_directory
from tensorflow.keras.utils import image_dataset_from_directory

In [2]:
class PositionalEmbedding(layers.Layer):
    """Embed token ids and add a learned embedding of each token's position.

    Two `layers.Embedding` tables are held: one indexed by token id and one
    indexed by position along the last axis of the input. `call` returns
    their element-wise sum.

    Args:
        sequence_length: Number of rows in the position table, i.e. the
            largest last-axis length of `inputs` that can be embedded.
        input_dim: Number of rows in the token table.
        output_dim: Width of the vectors produced by both tables.
        **kwargs: Forwarded unchanged to `layers.Layer.__init__`.
    """

    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)

        # Kept on the instance so get_config() can report them.
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.token_embeddings = layers.Embedding(input_dim=input_dim, output_dim=output_dim)
        # Looked up with position indices (see call), so the table is sized by
        # the sequence length rather than by the token vocabulary.
        self.position_embeddings = layers.Embedding(input_dim=sequence_length, output_dim=output_dim)

    def call(self, inputs):
        """Return token embeddings plus position embeddings for `inputs`.

        Args:
            inputs: Integer tensor of token ids; positions are counted along
                its last axis (assumes (batch, sequence) layout - TODO confirm
                against the caller).

        Returns:
            The sum of the token embeddings and the position embeddings; the
            position term is the same for every leading index and broadcasts
            across them.
        """
        # tf.shape returns the full shape vector; tf.range needs a scalar
        # limit, so take only the size of the last axis.
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)

        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)

        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        """Return a boolean tensor that is True wherever `inputs` is not 0.

        The incoming `mask` argument is ignored.
        """
        return tf.math.not_equal(inputs, 0)

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        config = super().get_config()

        config.update(
            {
                'sequence_length': self.sequence_length,
                'input_dim': self.input_dim,
                'output_dim': self.output_dim
            }
        )
        return config

In [3]:
class TransformerEndoder(layers.Layer):
    """Self-attention block followed by a two-layer dense projection.

    `call` applies multi-head attention of the input over itself, adds the
    result back to the input and layer-normalizes it, runs that through a
    `Dense(dense_dim, relu) -> Dense(embed_dim)` stack, then adds and
    layer-normalizes again.

    Args:
        embed_dim: Passed as `key_dim` to the attention layer and used as the
            output width of the dense projection.
        dense_dim: Width of the hidden (ReLU) dense layer.
        num_heads: Number of attention heads.
        **kwargs: Forwarded unchanged to `layers.Layer.__init__`.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)

        # Kept on the instance so get_config() can report them.
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads

        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

        # The keyword is `num_heads`; `num_head` is rejected by MultiHeadAttention.
        self.attention = layers.MultiHeadAttention(key_dim=embed_dim, num_heads=num_heads)

        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation='relu'),
                layers.Dense(embed_dim)
            ]
        )

    def call(self, inputs, mask=None):
        """Run the attention and projection sub-blocks on `inputs`.

        Args:
            inputs: Tensor used as both query and value of the attention
                layer (assumes (batch, sequence, embed_dim) - TODO confirm).
            mask: Optional mask; when given, an axis is inserted at position 1
                before it is passed as `attention_mask` (presumably a
                (batch, sequence) padding mask - verify against the caller).

        Returns:
            The layer-normalized sum of the projection input and output.
        """
        if mask is not None:
            # Insert an axis so the mask can broadcast over the query positions.
            mask = mask[:, tf.newaxis, :]

        attention_output = self.attention(inputs, inputs, attention_mask=mask)

        # Residual connection around attention, then normalization.
        proj_input = self.layernorm_1(inputs + attention_output)

        proj_output = self.dense_proj(proj_input)

        # Residual connection around the dense projection, then normalization.
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        config = super().get_config()

        config.update(
            {
                'embed_dim': self.embed_dim,
                'dense_dim': self.dense_dim,
                'num_heads': self.num_heads
            }
        )
        return config


# NOTE(review): the class name above looks like a misspelling of "Encoder".
# It is kept so existing references still resolve; this alias makes the
# correctly spelled name available as well.
TransformerEncoder = TransformerEndoder