# Text Classification with Transformers

In [1]:
import os

# Pick the JAX backend. This has to run before the `import keras` in the next
# cell, because Keras selects its backend from this variable when it is imported.
os.environ.update(KERAS_BACKEND="jax")

In [2]:
import keras

## Download and prepare dataset

In [3]:
vocab_size = 20000  # Vocabulary cap: keep only the 20k top-ranked words
maxlen = 200  # Fixed token length every review is padded/truncated to

# Load the IMDB reviews already encoded as lists of word indices; words that
# fall outside the `vocab_size` cap are not kept as distinct tokens.
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(
    num_words=vocab_size
)
print(len(x_train), "Training sequences")
print(len(x_val), "Validation sequences")

# Bring both splits to exactly `maxlen` tokens per review. With the default
# arguments pad_sequences pads AND truncates at the front, so reviews longer
# than `maxlen` keep their last `maxlen` tokens.
x_train, x_val = (
    keras.utils.pad_sequences(sequences, maxlen=maxlen)
    for sequences in (x_train, x_val)
)

25000 Training sequences
25000 Validation sequences


## Create classifier model using transformer layer

In [4]:
import keras_mml

In [5]:
embed_dim = 32  # Embedding size for each token
num_heads = 2  # Number of attention heads
ff_dim = 32  # Hidden layer size in feed forward network inside transformer
seed = 42  # Fixed RNG seed so a fresh Restart & Run All starts from the same state

# Seed Python, NumPy and the backend RNG before the first layer is created.
# Without this, weight initialisation, dropout masks and batch shuffling differ
# on every run and the reported metrics cannot be reproduced.
# NOTE(review): assumes the keras_mml layers draw their randomness through
# Keras' seeded generators -- confirm.
keras.utils.set_random_seed(seed)

# Token ids -> embeddings -> one transformer block -> average over the
# sequence axis -> small dense head with a 2-way softmax output.
model = keras.models.Sequential(
    layers=[
        keras.layers.Input(shape=(maxlen,)),
        keras_mml.layers.TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim),
        keras_mml.layers.TransformerBlockMML(embed_dim, num_heads, ff_dim),
        keras.layers.GlobalAveragePooling1D(),
        keras.layers.Dropout(0.1),
        keras.layers.Dense(20, activation="relu"),
        keras.layers.Dropout(0.1),
        keras.layers.Dense(2, activation="softmax")
    ]
)
model.summary()

In [6]:
# The model ends in a 2-way softmax; the sparse cross-entropy loss takes
# integer class ids as targets rather than one-hot vectors.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

# Two epochs in mini-batches of 32. The validation split is scored at the end
# of each epoch and the per-epoch metrics are returned in `history`.
history = model.fit(
    x=x_train,
    y=y_train,
    validation_data=(x_val, y_val),
    epochs=2,
    batch_size=32,
)

Epoch 1/2
782/782 ━━━━━━━━━━━━━━━━━━━━ 57s 70ms/step - accuracy: 0.7251 - loss: 0.5067 - val_accuracy: 0.8397 - val_loss: 0.3628
Epoch 2/2
782/782 ━━━━━━━━━━━━━━━━━━━━ 51s 63ms/step - accuracy: 0.9394 - loss: 0.1715 - val_accuracy: 0.8597 - val_loss: 0.3520
