In [124]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import pad_sequences


In [125]:
class TransformerEncoder(keras.layers.Layer):
    """Transformer encoder block: multi-head self-attention plus a dense projection.

    Each of the two sub-layers is wrapped in a residual addition followed by
    layer normalization. Because of those additions, ``inputs`` must have
    ``embed_dim`` features on the last axis.

    Args:
        embed_dim: Feature size of the inputs/outputs; also used as the
            ``key_dim`` of the attention layer.
        dense_dim: Width of the hidden ReLU layer in the dense projection.
        num_heads: Number of attention heads.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads

        self.attention = keras.layers.MultiHeadAttention(
            num_heads=self.num_heads, key_dim=self.embed_dim
        )
        # Feed-forward part: expand to dense_dim, then project back to embed_dim
        # so the result can be added to its own input.
        self.dense_proj = keras.Sequential([
            keras.layers.Dense(self.dense_dim, activation='relu'),
            keras.layers.Dense(self.embed_dim)
        ])

        self.layernorm1 = keras.layers.LayerNormalization()
        self.layernorm2 = keras.layers.LayerNormalization()

    def call(self, inputs, mask=None):
        """Apply self-attention and the dense projection to ``inputs``.

        Args:
            inputs: Tensor whose last axis has ``embed_dim`` features.
            mask: Optional mask from the previous layer. Assumed to be a
                (batch, sequence) boolean tensor -- TODO confirm against the
                upstream layer's ``compute_mask``.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        if mask is not None:
            # Insert a query axis so the per-key mask broadcasts over all
            # query positions.
            mask = mask[:, tf.newaxis, :]
        # The mask must actually be handed to the attention layer; otherwise
        # masked (padding) positions still receive attention weight.
        attention_output = self.attention(inputs, inputs, attention_mask=mask)
        proj_input = self.layernorm1(attention_output + inputs)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm2(proj_output + proj_input)

    def get_config(self):
        """Return the base config extended with this layer's constructor arguments."""
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "dense_dim": self.dense_dim,
            "num_heads": self.num_heads
        })
        return config
        

In [126]:
class PositionalEmbedding(keras.layers.Layer):
    """Sum of a learned token embedding and a learned position embedding.

    Args:
        sequence_length: Number of distinct positions the position embedding
            can represent (its ``input_dim``).
        input_dim: Vocabulary size of the token embedding.
        output_dim: Embedding width shared by both embeddings.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.token_embedding = keras.layers.Embedding(input_dim, output_dim)
        self.position_embedding = keras.layers.Embedding(sequence_length, output_dim)

    def call(self, inputs, value=None):
        """Embed token ids and add the embedding of each position index.

        Args:
            inputs: Integer tensor of token ids; the last axis is the
                sequence axis.
            value: Unused; kept so the existing signature is unchanged.

        Returns:
            Token embeddings plus position embeddings (broadcast over the
            leading axes).
        """
        length = tf.shape(inputs)[-1]
        # Positions 0..length-1; must stay below sequence_length, which is the
        # size of the position embedding table.
        position = tf.range(start=0, limit=length, delta=1)
        embedded_token = self.token_embedding(inputs)
        embedded_position = self.position_embedding(position)
        return embedded_token + embedded_position

    def compute_mask(self, inputs, masks=None):
        """Return a boolean mask that is False wherever the token id is 0."""
        return tf.math.not_equal(inputs, 0)

    def get_config(self):
        """Return the base config extended with this layer's constructor arguments."""
        # `self` was missing from the signature, which made every instance
        # call (e.g. during model saving) fail with a TypeError.
        config = super().get_config()
        config.update({
            "sequence_length": self.sequence_length,
            "input_dim": self.input_dim,
            "output_dim": self.output_dim
        })
        return config
        

In [127]:
# Data-related sizes.
vocab_size = 20_000     # passed as num_words to the IMDB loader and as the token-embedding input_dim
sequence_length = 600   # maxlen used for padding; also the size of the position embedding

# Encoder sizes.
embed_dim = 256         # embedding width, also the attention key_dim
dense_dim = 32          # hidden width of the encoder's dense projection
num_heads = 2           # number of attention heads

In [128]:
# Binary classifier: token ids -> positional embedding -> one encoder block
# -> global max pooling -> dropout -> single sigmoid unit.
inputs = keras.Input(shape=(None,), dtype="int64")
x = PositionalEmbedding(
    sequence_length=sequence_length,
    input_dim=vocab_size,
    output_dim=embed_dim,
)(inputs)
x = TransformerEncoder(embed_dim, dense_dim, num_heads)(x)
x = keras.layers.GlobalMaxPool1D()(x)
x = keras.layers.Dropout(rate=0.5)(x)
outputs = keras.layers.Dense(units=1, activation="sigmoid")(x)
model = keras.models.Model(inputs=inputs, outputs=outputs)


In [129]:
from tensorflow.keras.datasets import imdb

# Load the IMDB train/test splits, with num_words set to vocab_size.
(x_train, y_train), (x_test, y_test) = imdb.load_data(
    num_words=vocab_size,
)

In [131]:
# Bring every training sequence to exactly sequence_length entries.
# padding/truncating are the pad_sequences defaults, spelled out: pad and cut
# at the front of each sequence.
x_train = pad_sequences(
    x_train,
    maxlen=sequence_length,
    padding="pre",
    truncating="pre",
)

In [132]:
# Bring every test sequence to exactly sequence_length entries.
# padding/truncating are the pad_sequences defaults, spelled out: pad and cut
# at the front of each sequence.
x_test = pad_sequences(
    x_test,
    maxlen=sequence_length,
    padding="pre",
    truncating="pre",
)

In [133]:
# Convert the training inputs and labels to int64 tensors (inputs first, then labels).
x_train, y_train = (
    tf.convert_to_tensor(array, dtype=tf.int64) for array in (x_train, y_train)
)

In [134]:
# Hold out the first 500 training examples for validation.
# NOTE: this cell rebinds x_train/y_train to the remainder, so running it
# twice drops a further 500 examples; re-run the cells above first.
x_val = x_train[:500]
# Validation labels must come from y_train; taking them from x_train
# would validate against the input token ids.
y_val = y_train[:500]
x_train = x_train[500:]
y_train = y_train[500:]

In [135]:
# RMSprop optimizer, binary cross-entropy loss, accuracy reported as "acc".
model.compile(optimizer="rmsprop", loss="binary_crossentropy", metrics=["acc"])

In [None]:
# Train for 20 epochs in batches of 128, evaluating on the held-out
# validation pair after each epoch; keep the returned History object.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=128,
    epochs=20,
    validation_data=(x_val, y_val),
)

Epoch 1/20
