<h5/>In this script we will built custom transformer for sentiment classifer</h5>

In [None]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"

import keras_nlp
import keras
import tensorflow as tf

# Use mixed precision to speed up all training in this guide.
keras.mixed_precision.set_global_policy("mixed_float16")

In [15]:
BATCH_SIZE = 16
imdb_train = keras.utils.text_dataset_from_directory(
    "aclImdb/train",
    batch_size=BATCH_SIZE,
)
imdb_test = keras.utils.text_dataset_from_directory(
    "aclImdb/test",
    batch_size=BATCH_SIZE,
)



Found 25000 files belonging to 2 classes.
Found 25000 files belonging to 2 classes.


### Train custom vocabulary from IMDB data

In [16]:
vocab = keras_nlp.tokenizers.compute_word_piece_vocabulary(
    imdb_train.map(lambda x, y: x),
    vocabulary_size=20_000,
    lowercase=True,
    strip_accents=True,
    reserved_tokens=["[PAD]", "[START]", "[END]", "[MASK]", "[UNK]"],
)
tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=vocab,
    lowercase=True,
    strip_accents=True,
    oov_token="[UNK]",
)

### Preprocess data with a custom tokenizer

In [17]:
packer = keras_nlp.layers.StartEndPacker(
    start_value=tokenizer.token_to_id("[START]"),
    end_value=tokenizer.token_to_id("[END]"),
    pad_value=tokenizer.token_to_id("[PAD]"),
    sequence_length=512,
)


def preprocess(x, y):
    token_ids = packer(tokenizer(x))
    return token_ids, y


imdb_preproc_train_ds = imdb_train.map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)
imdb_preproc_val_ds = imdb_test.map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)

# print(imdb_preproc_train_ds.unbatch().take(1).get_single_element())

### Design a tiny transformer

In [27]:
token_id_input = keras.Input(
    shape=(None,),
    dtype="int32",
    name="token_ids",
)
outputs = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=len(vocab),
    sequence_length=packer.sequence_length,
    embedding_dim=64,
)(token_id_input)
outputs = keras_nlp.layers.TransformerEncoder(
    num_heads=4,
    intermediate_dim=128,
    dropout=0.1,
)(outputs)
# Use "[START]" token to classify
outputs = keras.layers.Dense(2)(outputs[:, 0, :])
model = keras.Model(
    inputs=token_id_input,
    outputs=outputs,
)

model.summary()

### Train the transformer directly on the classification objective

In [28]:
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.AdamW(5e-5),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    jit_compile=True,
)
model.fit(
    imdb_preproc_train_ds,
    validation_data=imdb_preproc_val_ds,
    epochs=15,
)

Epoch 1/15
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 20ms/step - loss: 0.6998 - sparse_categorical_accuracy: 0.5317 - val_loss: 0.4893 - val_sparse_categorical_accuracy: 0.7834
Epoch 2/15
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 19ms/step - loss: 0.3964 - sparse_categorical_accuracy: 0.8275 - val_loss: 0.3191 - val_sparse_categorical_accuracy: 0.8700
Epoch 3/15
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 15ms/step - loss: 0.2619 - sparse_categorical_accuracy: 0.8997 - val_loss: 0.3055 - val_sparse_categorical_accuracy: 0.8741
Epoch 4/15
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 17ms/step - loss: 0.2167 - sparse_categorical_accuracy: 0.9168 - val_loss: 0.2987 - val_sparse_categorical_accuracy: 0.8789
Epoch 5/15
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 17ms/step - loss: 0.1799 - sparse_categorical_accuracy: 0.9313 - val_loss: 0.3331 - val_sparse_categoric

<keras.src.callbacks.history.History at 0x7e0d850bee00>

In [29]:
test_loss, test_accuracy = model.evaluate(imdb_preproc_val_ds)
print(f"Test Loss: {test_loss}, Test Accuracy: {test_accuracy}")

[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 7ms/step - loss: 0.6493 - sparse_categorical_accuracy: 0.8572
Test Loss: 0.6534755825996399, Test Accuracy: 0.8560400009155273


Final test accuracy is 85.605% without any cleaning of data