<a href="https://colab.research.google.com/github/akaliutau/tensorflow-grimoire/blob/main/notebooks/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## The Transformer architecture

In [None]:
import os, pathlib, shutil, random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

### The Transformer encoder

**Use the sample data (small imdb reviews archive)**

In [None]:
!wget -nv -O aclImdb_small.zip "https://github.com/akaliutau/tensorflow-grimoire/blob/main/data/aclImdb_small.zip?raw=true"
!unzip -oq aclImdb_small.zip

2025-01-05 23:53:01 URL:https://raw.githubusercontent.com/akaliutau/tensorflow-grimoire/refs/heads/main/data/aclImdb_small.zip [1632806/1632806] -> "aclImdb_small.zip" [1]


**Preparing the data**

In [None]:
batch_size = 32
base_dir = pathlib.Path("aclImdb")
val_dir = base_dir / "val"
train_dir = base_dir / "train"
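# Note: re-running this cell moves a further 20% of the remaining training files.
for category in ("neg", "pos"):
    os.makedirs(val_dir / category, exist_ok=True)
    files = os.listdir(train_dir / category)
    # Fixed seed: the same directory listing always yields the same shuffle.
    random.Random(1337).shuffle(files)
    # Hold out the last 20% of the shuffled training files for validation.
    num_val_samples = int(0.2 * len(files))
    val_files = files[-num_val_samples:]
    for fname in val_files:
        shutil.move(train_dir / category / fname,
                    val_dir / category / fname)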

train_ds = keras.utils.text_dataset_from_directory(
    "aclImdb/train", batch_size=batch_size
)
val_ds = keras.utils.text_dataset_from_directory(
    "aclImdb/val", batch_size=batch_size
)
test_ds = keras.utils.text_dataset_from_directory(
    "aclImdb/test", batch_size=batch_size
)
text_only_train_ds = train_ds.map(lambda x, y: x)

Found 800 files belonging to 2 classes.
Found 200 files belonging to 2 classes.
Found 1000 files belonging to 2 classes.


**Vectorizing the data**

In [None]:
max_length = 300
max_tokens = 10000
text_vectorization = layers.TextVectorization(
    max_tokens=max_tokens,
    output_mode="int",
    output_sequence_length=max_length,
)
text_vectorization.adapt(text_only_train_ds)

int_train_ds = train_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=4)
int_val_ds = val_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=4)
int_test_ds = test_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=4)
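
To make the integer encoding concrete, here is a minimal pure-Python sketch of what an `int`-mode vectorizer does conceptually: lowercase the text and strip punctuation, split on whitespace, map each word to an index, then pad or truncate to a fixed length. The helpers `standardize`, `build_vocab`, and `vectorize` are hypothetical names, and this is not the `TextVectorization` implementation; it only assumes the layer's documented conventions (index 0 for padding, index 1 for out-of-vocabulary words, vocabulary ordered by frequency).

```python
import string
from collections import Counter

def standardize(text):
    # Lowercase and remove ASCII punctuation.
    return text.lower().translate(str.maketrans("", "", string.punctuation))

def build_vocab(texts, max_tokens):
    counts = Counter(word for t in texts for word in standardize(t).split())
    # Index 0 is reserved for padding, index 1 for unknown words.
    vocab = ["", "[UNK]"] + [w for w, _ in counts.most_common(max_tokens - 2)]
    return {word: i for i, word in enumerate(vocab)}

def vectorize(text, word_index, max_length):
    ids = [word_index.get(w, 1) for w in standardize(text).split()]
    ids = ids[:max_length]                        # truncate long sequences
    return ids + [0] * (max_length - len(ids))    # pad short ones with zeros

word_index = build_vocab(["A great movie!", "A dull movie."], max_tokens=10)
encoded = vectorize("A great, great film", word_index, max_length=6)
print(encoded)
```

The word "film" never appeared when the vocabulary was built, so it maps to the unknown-word index, and the two trailing zeros are padding.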

**Transformer encoder implemented as a subclassed `Layer`**

Note that in the line `proj_input = self.layernorm_1(inputs + attention_output)`, the tensors `inputs` and `attention_output` are added element-wise; they are NOT concatenated. This is a classic residual connection, where the original input is added to the output of a sub-layer. For this to work, both tensors must have compatible shapes; here both are `(batch_size, sequence_length, embed_dim)`.

To summarize: the addition here does not merge different types of data. It is the core of a residual connection, which gives gradients a direct path through the network during training. This helps mitigate vanishing gradients and makes deeper networks trainable.
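
The difference between adding and concatenating is easiest to see in the shapes. The following NumPy sketch uses toy dimensions and a hypothetical `layer_norm` helper that omits the learned scale and offset of `LayerNormalization`; it is an illustration, not the Keras implementation.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each token vector over its feature axis (the last one).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
inputs = rng.normal(size=(2, 5, 8))            # (batch, sequence, embed_dim)
attention_output = rng.normal(size=(2, 5, 8))  # stand-in for the sub-layer output

residual = inputs + attention_output           # element-wise sum
concatenated = np.concatenate([inputs, attention_output], axis=-1)

print(residual.shape)      # (2, 5, 8)
print(concatenated.shape)  # (2, 5, 16)

normalized = layer_norm(residual)              # same shape as residual
```

Because the sum keeps the `(batch, sequence, embed_dim)` shape, the block's output can feed another identical block, which is what makes stacking encoders possible.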

`inputs` acts as the query, key, and value for self-attention: the call passes it as both query and value, and the key defaults to the value.

`attention_mask` tells the attention layer to ignore padding tokens.
`attention_output` is the result of self-attention and has the same shape as `inputs`.
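
A minimal single-head sketch of masked self-attention in NumPy may help here. It is a simplification under stated assumptions: the function `self_attention` is a hypothetical name, and it skips the learned query, key, and value projections that `MultiHeadAttention` applies, so the input is used directly in all three roles.

```python
import numpy as np

def self_attention(x, padding_mask):
    # x: (batch, seq, dim); padding_mask: (batch, seq), True for real tokens.
    dim = x.shape[-1]
    scores = x @ x.transpose(0, 2, 1) / np.sqrt(dim)   # (batch, seq, seq)
    # Broadcast the mask over the query axis, like mask[:, tf.newaxis, :].
    key_mask = padding_mask[:, np.newaxis, :]
    scores = np.where(key_mask, scores, -1e9)          # padded keys get ~zero weight
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ x, weights

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 4, 8))
padding_mask = np.array([[True, True, True, False]])   # last position is padding

attention_output, weights = self_attention(x, padding_mask)
print(attention_output.shape)  # same shape as x
print(weights[0, :, 3])        # attention paid to the padded position
```

Every query assigns zero weight to the padded position, and the output keeps the shape of the input.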


The self-attention output `attention_output` is meant to capture the relationships between the words in the input. Adding `inputs` and `attention_output` retains the information in the original input tokens while incorporating the learned relationships.

The feedforward block (`dense_proj`) learns non-linear combinations of the input features, which is essential for capturing more complex patterns in the data. Adding its input to its output is a second residual connection, which again eases gradient flow.
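
The feedforward step can be sketched in NumPy as two matrix multiplications with a ReLU in between. The weights here are random stand-ins for the two `Dense` layers, and the toy sizes are assumptions chosen for readability.

```python
import numpy as np

rng = np.random.default_rng(0)
embed_dim, dense_dim = 8, 4

# Toy stand-ins for the weights of the two Dense layers.
w1, b1 = rng.normal(size=(embed_dim, dense_dim)), np.zeros(dense_dim)
w2, b2 = rng.normal(size=(dense_dim, embed_dim)), np.zeros(embed_dim)

def dense_proj(x):
    hidden = np.maximum(x @ w1 + b1, 0.0)  # Dense(dense_dim, activation="relu")
    return hidden @ w2 + b2                # Dense(embed_dim), no activation

proj_input = rng.normal(size=(2, 5, embed_dim))  # (batch, sequence, embed_dim)
proj_output = dense_proj(proj_input)
block_output = proj_input + proj_output          # second residual connection

# The same weights are applied to every position independently.
single_token = dense_proj(proj_input[0, 3])
print(np.allclose(single_token, proj_output[0, 3]))
```

The final check illustrates that the feedforward block never mixes information across positions; only the attention step does that.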


In [None]:
class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim)
        self.dense_proj = keras.Sequential(
            [layers.Dense(dense_dim, activation="relu"),
             layers.Dense(embed_dim),]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

    def call(self, inputs, mask=None):
        if mask is not None:
            mask = mask[:, tf.newaxis, :]
        attention_output = self.attention(
            inputs, inputs, attention_mask=mask)
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "dense_dim": self.dense_dim,
        })
        return config

**Using the Transformer encoder for text classification**

In [None]:
vocab_size = 10000
embed_dim = 256
num_heads = 2
dense_dim = 32

inputs = keras.Input(shape=(None,), dtype="int64")
x = layers.Embedding(vocab_size, embed_dim)(inputs)
x = TransformerEncoder(embed_dim, dense_dim, num_heads)(x)
x = layers.GlobalMaxPooling1D()(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer="rmsprop",
              loss="binary_crossentropy",
              metrics=["accuracy"])
model.summary()
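
As a cross-check on what `model.summary()` reports, the parameter counts can be worked out by hand. This sketch assumes the usual Keras weight layout for `MultiHeadAttention` (a separate projection with bias for query, key, value, and output) and two weights per feature for `LayerNormalization`. Note that `key_dim=embed_dim` makes every head 256 wide, rather than `embed_dim / num_heads`.

```python
vocab_size, embed_dim, num_heads, dense_dim = 10000, 256, 2, 32
key_dim = embed_dim  # as passed to MultiHeadAttention in TransformerEncoder

embedding = vocab_size * embed_dim
# Query, key, and value projections: embed_dim -> (num_heads, key_dim), plus bias.
qkv = 3 * (embed_dim * num_heads * key_dim + num_heads * key_dim)
# Output projection: (num_heads, key_dim) -> embed_dim, plus bias.
attn_out = num_heads * key_dim * embed_dim + embed_dim
feed_forward = (embed_dim * dense_dim + dense_dim) + (dense_dim * embed_dim + embed_dim)
layer_norms = 2 * (2 * embed_dim)  # scale and offset for each LayerNormalization
classifier = embed_dim + 1         # final Dense(1)

encoder = qkv + attn_out + feed_forward + layer_norms
total = embedding + encoder + classifier
print(embedding, encoder, classifier, total)
```

Under these assumptions the embedding table dominates the model, with the attention projections a distant second.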

**Training and evaluating the Transformer encoder based model**

In [None]:
callbacks = [
    keras.callbacks.ModelCheckpoint("transformer_encoder.keras",
                                    save_best_only=True)
]
model.fit(int_train_ds, validation_data=int_val_ds, epochs=3, callbacks=callbacks)
model = keras.models.load_model(
    "transformer_encoder.keras",
    custom_objects={"TransformerEncoder": TransformerEncoder})
print(f"Test acc: {model.evaluate(int_test_ds)[1]:.3f}")

Epoch 1/3
25/25 ━━━━━━━━━━━━━━━━━━━━ 11s 156ms/step - accuracy: 0.4990 - loss: 1.8905 - val_accuracy: 0.4850 - val_loss: 0.8459
Epoch 2/3
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 42ms/step - accuracy: 0.5492 - loss: 0.9205 - val_accuracy: 0.5250 - val_loss: 0.6610
Epoch 3/3
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 36ms/step - accuracy: 0.5854 - loss: 0.7447 - val_accuracy: 0.5450 - val_loss: 0.6723

32/32 ━━━━━━━━━━━━━━━━━━━━ 2s 29ms/step - accuracy: 0.5789 - loss: 0.6517
Test acc: 0.558
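
With `save_best_only=True`, `ModelCheckpoint` monitors `val_loss` by default and overwrites the file only when that value improves. The reloaded model is therefore not necessarily the one from the last epoch. A small sketch, using the `val_loss` values copied from the three-epoch log above, shows which epoch the checkpoint should correspond to:

```python
# val_loss values copied from the three-epoch training log.
val_losses = [0.8459, 0.6610, 0.6723]

# Epochs are numbered from 1 in the log.
best_epoch = val_losses.index(min(val_losses)) + 1
print(best_epoch)  # 2
```

So the test accuracy printed after reloading most likely reflects the epoch 2 weights, even though validation accuracy was slightly higher at epoch 3.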


#### Using positional encoding to re-inject order information

**Implementing positional embedding as a subclassed layer**

The idea is to add a position vector to each word embedding. These `position_embeddings` represent the position of a token in the sequence.

Raw position indices behave like one-hot vectors: they are mutually orthogonal, and their dimensionality equals `sequence_length`. An `Embedding` layer maps each index to a learned dense vector of size `output_dim`, so positions that are related can end up with related vectors (for example, positions 0 and 2 may be related in question sequences such as "do you love me?").
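
The link between one-hot position indices and an embedding lookup can be checked directly. In this NumPy sketch the table `position_table` is a random stand-in for weights that would be learned in practice, and the sizes are toy assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
sequence_length, output_dim = 6, 4
position_table = rng.normal(size=(sequence_length, output_dim))  # learned in practice

positions = np.arange(sequence_length)
one_hot = np.eye(sequence_length)[positions]   # orthogonal index vectors
via_matmul = one_hot @ position_table          # (sequence_length, output_dim)
via_lookup = position_table[positions]         # the row lookup an embedding performs

embedded_tokens = rng.normal(size=(2, sequence_length, output_dim))  # (batch, seq, dim)
# The position vectors broadcast over the batch axis when added.
combined = embedded_tokens + via_lookup
print(np.allclose(via_matmul, via_lookup), combined.shape)
```

Multiplying a one-hot vector by the table selects one row, so the lookup is just a cheaper way to compute the same product. Every sequence in the batch receives the same position vectors.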

In [None]:
class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=input_dim, output_dim=output_dim, mask_zero=True)
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim

    def call(self, inputs):
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        # Element-wise sum; the position vectors broadcast over the batch axis.
        return embedded_tokens + embedded_positions

    #def compute_mask(self, inputs, mask=None):
    #    return tf.keras.backend.not_equal(inputs, 0)

    def get_config(self):
        config = super().get_config()
        config.update({
            "output_dim": self.output_dim,
            "sequence_length": self.sequence_length,
            "input_dim": self.input_dim,
        })
        return config

**Combining the Transformer encoder with positional embedding**

In [None]:
vocab_size = 10000
sequence_length = 300
embed_dim = 256
num_heads = 2
dense_dim = 32

inputs = keras.Input(shape=(None,), dtype="int64")
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(inputs)
x = TransformerEncoder(embed_dim, dense_dim, num_heads)(x)
x = layers.GlobalMaxPooling1D()(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer="rmsprop",
              loss="binary_crossentropy",
              metrics=["accuracy"])
model.summary()

callbacks = [
    keras.callbacks.ModelCheckpoint("full_transformer_encoder.keras",
                                    save_best_only=True)
]
model.fit(int_train_ds, validation_data=int_val_ds, epochs=5, callbacks=callbacks)
model = keras.models.load_model(
    "full_transformer_encoder.keras",
    custom_objects={"TransformerEncoder": TransformerEncoder,
                    "PositionalEmbedding": PositionalEmbedding})
print(f"Test acc: {model.evaluate(int_test_ds)[1]:.3f}")

Epoch 1/5
25/25 ━━━━━━━━━━━━━━━━━━━━ 10s 161ms/step - accuracy: 0.5105 - loss: 1.7603 - val_accuracy: 0.5000 - val_loss: 0.7329
Epoch 2/5
25/25 ━━━━━━━━━━━━━━━━━━━━ 4s 45ms/step - accuracy: 0.5702 - loss: 0.8116 - val_accuracy: 0.5000 - val_loss: 0.6786
Epoch 3/5
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 39ms/step - accuracy: 0.5742 - loss: 0.7452 - val_accuracy: 0.6750 - val_loss: 0.6027
Epoch 4/5
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 39ms/step - accuracy: 0.7209 - loss: 0.5239 - val_accuracy: 0.7100 - val_loss: 0.5636
Epoch 5/5
25/25 ━━━━━━━━━━━━━━━━━━━━ 1s 38ms/step - accuracy: 0.8222 - loss: 0.4068 - val_accuracy: 0.7550 - val_loss: 0.5115

32/32 ━━━━━━━━━━━━━━━━━━━━ 2s 27ms/step - accuracy: 0.7271 - loss: 0.5273
Test acc: 0.733
