In [68]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import pad_sequences

In [69]:
class TransformerEncoder(layers.Layer):
    """Single Transformer encoder block (post-norm).

    Applies multi-head self-attention, adds a residual connection and
    layer-normalizes. It then applies a two-layer feed-forward projection,
    adds a second residual connection and layer-normalizes again.

    Args:
        embed_dim: Width of the model. The feed-forward projection outputs
            this many features and its result is added to its input, so the
            last axis of ``inputs`` must equal ``embed_dim``. Also used as the
            per-head ``key_dim`` of the attention layer.
        dense_dim: Width of the hidden ReLU layer of the feed-forward projection.
        num_heads: Number of attention heads.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads

        self.attention = layers.MultiHeadAttention(
            num_heads=self.num_heads, key_dim=self.embed_dim
        )
        # Position-wise feed-forward network: expand to dense_dim, project back.
        self.dense_proj = keras.Sequential([
            layers.Dense(self.dense_dim, activation="relu"),
            layers.Dense(self.embed_dim),
        ])

        self.layernorm1 = layers.LayerNormalization()
        self.layernorm2 = layers.LayerNormalization()

    def call(self, inputs, mask=None):
        """Run the encoder block.

        Args:
            inputs: Sequence features. Assumed to be (batch, seq_len, embed_dim);
                TODO confirm against the caller.
            mask: Optional padding mask, presumably (batch, seq_len). It is
                expanded with a new axis 1 so it broadcasts across query
                positions in the attention mask.

        Returns:
            A tensor with the same shape as ``inputs``.
        """
        if mask is not None:
            # Bug fix: this previously read `maks[...]`, which raised NameError
            # whenever a mask was passed in.
            mask = mask[:, tf.newaxis, :]
        attention_output = self.attention(inputs, inputs, attention_mask=mask)
        proj_input = self.layernorm1(attention_output + inputs)  # residual 1
        proj_output = self.dense_proj(proj_input)
        return self.layernorm2(proj_output + proj_input)  # residual 2

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "dense_dim": self.dense_dim,
            "num_heads": self.num_heads
        })
        return config
        

In [70]:
# --- Data ---
vocab_size = 10_000  # passed to imdb.load_data(num_words=...) and Embedding(vocab_size, ...)
max_words = 50       # every review is padded/truncated to this many token ids

# --- Transformer encoder ---
embed_dim = 256      # embedding width; also the encoder's model width
dense_dim = 32       # hidden width of the encoder's feed-forward projection
num_heads = 2        # number of attention heads

In [71]:
# Variable-length int64 token ids -> embeddings -> one encoder block
# -> max over the sequence axis -> dropout -> single sigmoid unit.
inputs = keras.Input(shape=(None,), dtype='int64')
embedded = layers.Embedding(vocab_size, embed_dim)(inputs)
encoded = TransformerEncoder(
    embed_dim=embed_dim,
    dense_dim=dense_dim,
    num_heads=num_heads,
)(embedded)
pooled = layers.GlobalMaxPool1D()(encoded)
regularized = layers.Dropout(0.5)(pooled)
outputs = layers.Dense(1, activation='sigmoid')(regularized)

model = keras.models.Model(inputs, outputs)

In [72]:
# Print each layer's output shape and parameter count.
model.summary()

Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_7 (InputLayer)        [(None, None)]            0         
                                                                 
 embedding_6 (Embedding)     (None, None, 256)         2560000   
                                                                 
 transformer_encoder_5 (Tran  (None, None, 256)        543776    
 sformerEncoder)                                                 
                                                                 
 global_max_pooling1d_2 (Glo  (None, 256)              0         
 balMaxPooling1D)                                                
                                                                 
 dropout_2 (Dropout)         (None, 256)               0         
                                                                 
 dense_14 (Dense)            (None, 1)                 257 

In [73]:
from tensorflow.keras.datasets import imdb

In [74]:

# num_words=vocab_size keeps the dataset's word ids consistent with the
# Embedding(vocab_size, ...) layer in the model.
train_split, test_split = imdb.load_data(num_words=vocab_size)
x_train, y_train = train_split
x_test, y_test = test_split
del train_split, test_split

In [75]:
# Pad/truncate every review to exactly `max_words` ids so each split is a rectangular array.
x_train, x_test = [pad_sequences(reviews, maxlen=max_words) for reviews in (x_train, x_test)]

In [76]:
# Convert the padded id matrices to int64 tensors, matching the model's
# keras.Input(dtype='int64'). Uses tf.int64 instead of np.int64: numpy is not
# imported in the notebook's import cell, so `np` raised NameError on a fresh kernel.
x_train = tf.convert_to_tensor(x_train, tf.int64)
x_test = tf.convert_to_tensor(x_test, tf.int64)

In [77]:
# Labels as int64 tensors. Uses tf.int64 because numpy is not imported as `np`
# in the notebook's import cell. y_test is converted too so the test split
# matches x_test.
y_train = tf.convert_to_tensor(y_train, tf.int64)
y_test = tf.convert_to_tensor(y_test, tf.int64)

In [78]:
# Hold out the first 5,000 training reviews as the validation set.
val_size = 5000
x_val, y_val = x_train[:val_size], y_train[:val_size]

In [79]:
# Drop the held-out validation rows from the training set.
# This cell rebinds x_train / y_train, so a second run would silently discard
# another 5,000 samples. Guard against that: before the split, the first
# 5,000 rows of x_train are exactly x_val.
if not bool(tf.reduce_all(x_train[:5000] == x_val)):
    raise RuntimeError(
        "x_train appears to be split already; re-run from the data-loading cell"
    )
x_train = x_train[5000:]
y_train = y_train[5000:]

In [80]:
# Save only the epoch with the best validation accuracy.
# Bug fix: the original passed `metrics='val_acc'`, which is not a
# ModelCheckpoint argument, so the callback never tracked val_acc.
# The argument is `monitor`. 'val_acc' matches the metric name from
# compile(metrics=['acc']).
callback_list = [
    keras.callbacks.ModelCheckpoint(
        "imdb_transformed_model.h5",
        monitor="val_acc",
        mode="max",  # higher accuracy is better
        save_best_only=True,
    )
]

In [82]:
# Single sigmoid output, so train with binary cross-entropy and report accuracy.
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])

history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=128,
    epochs=10,
    validation_data=(x_val, y_val),
    callbacks=callback_list,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
