The default Keras version on Google Colab does not support `keras.ops`, so we upgrade it here. If you are running this locally and already have a recent Keras 3 release installed, you can skip this cell.

In [1]:
!pip install keras --upgrade

Collecting keras
  Downloading keras-3.4.1-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
Collecting namex (from keras)
  Downloading namex-0.0.8-py3-none-any.whl (5.8 kB)
Collecting optree (from keras)
  Downloading optree-0.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (347 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m347.7/347.7 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: namex, optree, keras
  Attempting uninstall: keras
    Found existing installation: keras 2.15.0
    Uninstalling keras-2.15.0:
      Successfully uninstalled keras-2.15.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.4.1 which is incompatible.[0m

The cell below defines the Super Attention layer variant in which a single W_a matrix is created and managed by the Multi-Head Attention module and handed to each of its attention layers. Run it if you want to use that version of Super Attention (all other attention layer types are the same in both cells).

In [2]:
import keras
import tensorflow as tf
from keras import ops
from keras import layers

"""
## Attention Layers (Shared W_a)
"""
class AttentionLayer(keras.layers.Layer):
    """One attention head supporting the SDPA, Optimised, Efficient and Super variants.

    In this version the Super-attention ``W_a`` layer is supplied by the
    caller rather than created by the head itself.

    The variants differ in which learned projections they apply:

    * ``'SDPA'``      - learned query, key and value projections.
    * ``'Optimised'`` - learned query and key projections; the value is this
      head's slice of the last axis of ``inp_v``.
    * ``'Efficient'`` - learned query projection only; key and value are this
      head's slices of the last axis of ``inp_k`` / ``inp_v``.
    * ``'Super'``     - as ``'Efficient'``, but the value slice is first
      left-multiplied by ``W_a.kernel``.

    Any other ``layer_type`` is dispatched to the SDPA forward pass (whose
    key/value projections are only created for ``'SDPA'``).

    Args:
        d_model: Size of the last input axis; used to build the projections.
        d_q: Output width of the query projection; also sets the score scale.
        d_k: Output width of the key projection, and width of the key slice
            read by the Efficient/Super variants.
        d_v: Output width of the value projection, and width of the value
            slice read by the Optimised/Efficient/Super variants.
        W_a: ``Dense`` layer whose ``kernel`` is used by the Super variant.
            Must already be built when the Super forward pass runs.
        layer_type: One of ``'SDPA'``, ``'Optimised'``, ``'Efficient'``,
            ``'Super'``.
        idx: Index of this head; selects which feature slice it reads.
        max_len: Stored on the layer; not otherwise used by this class.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self,
                 d_model: int,
                 d_q: int,
                 d_k: int,
                 d_v: int,
                 W_a: keras.layers.Dense = None,
                 layer_type: str = 'SDPA',
                 idx: int = 0,
                 max_len: int = 32,
                 **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.d_q = d_q
        self.d_k = d_k
        self.d_v = d_v
        self.layer_type = layer_type
        self.idx = idx
        self.max_len = max_len
        self.W_a = W_a

    def build(self, input_shape):
        """Create and build only the projections the selected variant needs."""
        # Each projection is built eagerly so its weights exist as soon as the
        # head is built, rather than on the first forward pass.
        self.W_q = layers.Dense(self.d_q)
        self.W_q.build((None, self.d_model))
        if self.layer_type in ['SDPA', 'Optimised']:
            self.W_k = layers.Dense(self.d_k)
            self.W_k.build((None, self.d_model))
        if self.layer_type == 'SDPA':
            self.W_v = layers.Dense(self.d_v)
            self.W_v.build((None, self.d_model))
        super().build(input_shape)

    def call(self, inputs):
        """Run the forward pass for the configured variant.

        Args:
            inputs: Sequence of three tensors ``(inp_q, inp_k, inp_v)``.

        Returns:
            The head output ``S @ V`` for the selected variant.
        """
        inp_q, inp_k, inp_v = inputs
        if self.layer_type == 'Optimised':
            return self._forward_optimised(inp_q, inp_k, inp_v)
        elif self.layer_type == 'Efficient':
            return self._forward_efficient(inp_q, inp_k, inp_v)
        elif self.layer_type == 'Super':
            return self._forward_super(inp_q, inp_k, inp_v)
        else:
            return self._forward_SDPA(inp_q, inp_k, inp_v)

    def _score_weights(self, Q, K):
        """Return the softmax-normalised, scaled ``Q @ K^T`` scores."""
        K_t = tf.transpose(K, perm=[0, 2, 1])
        scale = tf.math.sqrt(tf.cast(self.d_q, tf.float32))
        # NOTE(review): the softmax runs over axis=1, i.e. the axis that comes
        # from Q in the (batch, q, k) score tensor. Conventional scaled
        # dot-product attention normalises over the key axis (-1). Kept as-is
        # to preserve existing results - confirm this is intended.
        return tf.nn.softmax((Q @ K_t) / scale, axis=1)

    def _slice_for_head(self, tensor, width):
        """Return this head's ``width``-wide slice of ``tensor``'s last axis."""
        start = self.idx * width
        return tensor[:, :, start:start + width]

    def _forward_SDPA(self, inp_q, inp_k, inp_v):
        """Attention with learned query, key and value projections."""
        S = self._score_weights(self.W_q(inp_q), self.W_k(inp_k))
        return S @ self.W_v(inp_v)

    def _forward_optimised(self, inp_q, inp_k, inp_v):
        """Attention with learned query/key projections and a sliced value."""
        S = self._score_weights(self.W_q(inp_q), self.W_k(inp_k))
        return S @ self._slice_for_head(inp_v, self.d_v)

    def _forward_efficient(self, inp_q, inp_k, inp_v):
        """Attention with a learned query projection and sliced key/value."""
        S = self._score_weights(self.W_q(inp_q),
                                self._slice_for_head(inp_k, self.d_k))
        # The value slice is d_v wide (not d_k), matching the Optimised
        # variant; the two coincide when d_k == d_v.
        return S @ self._slice_for_head(inp_v, self.d_v)

    def _forward_super(self, inp_q, inp_k, inp_v):
        """Efficient attention whose value slice is pre-multiplied by ``W_a``."""
        S = self._score_weights(self.W_q(inp_q),
                                self._slice_for_head(inp_k, self.d_k))
        # Only the kernel of W_a is used; it left-multiplies the d_v-wide
        # value slice, so it acts along the second axis of inp_v.
        V = self.W_a.kernel @ self._slice_for_head(inp_v, self.d_v)
        return S @ V

class MultiHeadAttention(keras.layers.Layer):
    """Multi-head self-attention built from ``AttentionLayer`` heads.

    For the ``'Super'`` layer type this module owns a single ``W_a`` Dense
    layer and passes the same instance to every head.

    Args:
        n_heads: Number of ``AttentionLayer`` heads to create.
        d_model: Output width of the final projection; also passed to the heads.
        d_k: Passed to each head as both its query and key width.
        d_v: Passed to each head as its value width; ``n_heads * d_v`` is the
            input width the output projection is built for.
        max_len: Width of the shared ``W_a`` layer; also passed to the heads.
        layer_type: Attention variant forwarded to every head.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, n_heads, d_model, d_k, d_v, max_len, layer_type, **kwargs):
        super().__init__(**kwargs)
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.max_len = max_len
        self.layer_type = layer_type

    def build(self, input_shape):
        """Create the shared ``W_a`` (Super only), the heads and ``W_o``."""
        if self.layer_type == 'Super':
            # One square (max_len x max_len) kernel shared by all heads.
            # NOTE(review): the heads only read W_a.kernel; the Dense bias is
            # created but does not appear to be used.
            self.W_a = layers.Dense(self.max_len)
            self.W_a.build((None, self.max_len))
        else:
            self.W_a = None

        self.attention_layers = [
            AttentionLayer(self.d_model, self.d_k, self.d_k, self.d_v,
                           self.W_a, self.layer_type, idx=i, max_len=self.max_len)
            for i in range(self.n_heads)
        ]

        self.W_o = layers.Dense(self.d_model)

        # Build each attention layer
        for layer in self.attention_layers:
            layer.build(input_shape)

        # Build the output dense layer on the concatenated head outputs
        self.W_o.build((None, self.n_heads * self.d_v))

        super().build(input_shape)

    def call(self, inputs):
        """Apply every head to ``inputs`` and project the concatenation.

        The same tensor is used as query, key and value for each head.
        """
        head_outputs = [
            layer([inputs, inputs, inputs]) for layer in self.attention_layers
        ]
        # Join all heads along the feature axis in one op instead of growing
        # the tensor pairwise inside the loop.
        H = tf.concat(head_outputs, axis=-1)
        return self.W_o(H)

The cell below defines the Super Attention layer variant in which EACH Super Attention layer owns its own W_a matrix, independent of the Multi-Head Attention module. Run it if you want to use that version of Super Attention (all other attention layer types are the same in both cells).

In [None]:
import keras
import tensorflow as tf
from keras import ops
from keras import layers


"""
## Attention Layers (Individual W_a)
"""
class AttentionLayer(keras.layers.Layer):
    """One attention head supporting the SDPA, Optimised, Efficient and Super variants.

    In this version every Super-attention head creates its own ``W_a`` layer.

    The variants differ in which learned projections they apply:

    * ``'SDPA'``      - learned query, key and value projections.
    * ``'Optimised'`` - learned query and key projections; the value is this
      head's slice of the last axis of ``inp_v``.
    * ``'Efficient'`` - learned query projection only; key and value are this
      head's slices of the last axis of ``inp_k`` / ``inp_v``.
    * ``'Super'``     - as ``'Efficient'``, but the value slice is first
      left-multiplied by ``W_a.kernel``.

    Any other ``layer_type`` is dispatched to the SDPA forward pass (whose
    key/value projections are only created for ``'SDPA'``).

    Args:
        d_model: Size of the last input axis; used to build the projections.
        d_q: Output width of the query projection; also sets the score scale.
        d_k: Output width of the key projection, and width of the key slice
            read by the Efficient/Super variants.
        d_v: Output width of the value projection, and width of the value
            slice read by the Optimised/Efficient/Super variants.
        layer_type: One of ``'SDPA'``, ``'Optimised'``, ``'Efficient'``,
            ``'Super'``.
        idx: Index of this head; selects which feature slice it reads.
        max_len: Width of the per-head ``W_a`` layer (Super variant only).
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self,
                 d_model: int,
                 d_q: int,
                 d_k: int,
                 d_v: int,
                 layer_type: str = 'SDPA',
                 idx: int = 0,
                 max_len: int = 32,
                 **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.d_q = d_q
        self.d_k = d_k
        self.d_v = d_v
        self.layer_type = layer_type
        self.idx = idx
        self.max_len = max_len

    def build(self, input_shape):
        """Create and build only the sub-layers the selected variant needs."""
        self.W_q = layers.Dense(self.d_q)
        self.W_q.build((None, self.d_model))
        if self.layer_type in ['SDPA', 'Optimised']:
            self.W_k = layers.Dense(self.d_k)
            self.W_k.build((None, self.d_model))
        if self.layer_type == 'SDPA':
            self.W_v = layers.Dense(self.d_v)
            self.W_v.build((None, self.d_model))
        if self.layer_type == 'Super':
            # Square (max_len x max_len) kernel private to this head.
            # NOTE(review): only W_a.kernel is read in the forward pass; the
            # Dense bias is created but does not appear to be used.
            self.W_a = layers.Dense(self.max_len)
            self.W_a.build((None, self.max_len))

        super().build(input_shape)

    def call(self, inputs):
        """Run the forward pass for the configured variant.

        Args:
            inputs: Sequence of three tensors ``(inp_q, inp_k, inp_v)``.

        Returns:
            The head output ``S @ V`` for the selected variant.
        """
        inp_q, inp_k, inp_v = inputs
        if self.layer_type == 'Optimised':
            return self._forward_optimised(inp_q, inp_k, inp_v)
        elif self.layer_type == 'Efficient':
            return self._forward_efficient(inp_q, inp_k, inp_v)
        elif self.layer_type == 'Super':
            return self._forward_super(inp_q, inp_k, inp_v)
        else:
            return self._forward_SDPA(inp_q, inp_k, inp_v)

    def _score_weights(self, Q, K):
        """Return the softmax-normalised, scaled ``Q @ K^T`` scores."""
        K_t = tf.transpose(K, perm=[0, 2, 1])
        scale = tf.math.sqrt(tf.cast(self.d_q, tf.float32))
        # NOTE(review): the softmax runs over axis=1, i.e. the axis that comes
        # from Q in the (batch, q, k) score tensor. Conventional scaled
        # dot-product attention normalises over the key axis (-1). Kept as-is
        # to preserve existing results - confirm this is intended.
        return tf.nn.softmax((Q @ K_t) / scale, axis=1)

    def _slice_for_head(self, tensor, width):
        """Return this head's ``width``-wide slice of ``tensor``'s last axis."""
        start = self.idx * width
        return tensor[:, :, start:start + width]

    def _forward_SDPA(self, inp_q, inp_k, inp_v):
        """Attention with learned query, key and value projections."""
        S = self._score_weights(self.W_q(inp_q), self.W_k(inp_k))
        return S @ self.W_v(inp_v)

    def _forward_optimised(self, inp_q, inp_k, inp_v):
        """Attention with learned query/key projections and a sliced value."""
        S = self._score_weights(self.W_q(inp_q), self.W_k(inp_k))
        return S @ self._slice_for_head(inp_v, self.d_v)

    def _forward_efficient(self, inp_q, inp_k, inp_v):
        """Attention with a learned query projection and sliced key/value."""
        S = self._score_weights(self.W_q(inp_q),
                                self._slice_for_head(inp_k, self.d_k))
        # The value slice is d_v wide (not d_k), matching the Optimised
        # variant; the two coincide when d_k == d_v.
        return S @ self._slice_for_head(inp_v, self.d_v)

    def _forward_super(self, inp_q, inp_k, inp_v):
        """Efficient attention whose value slice is pre-multiplied by ``W_a``."""
        S = self._score_weights(self.W_q(inp_q),
                                self._slice_for_head(inp_k, self.d_k))
        # Only the kernel of W_a is used; it left-multiplies the d_v-wide
        # value slice, so it acts along the second axis of inp_v.
        V = self.W_a.kernel @ self._slice_for_head(inp_v, self.d_v)
        return S @ V

class MultiHeadAttention(keras.layers.Layer):
    """Multi-head self-attention built from independent ``AttentionLayer`` heads.

    Args:
        n_heads: Number of ``AttentionLayer`` heads to create.
        d_model: Output width of the final projection; also passed to the heads.
        d_k: Passed to each head as both its query and key width.
        d_v: Passed to each head as its value width; ``n_heads * d_v`` is the
            input width the output projection is built for.
        max_len: Passed through to each head.
        layer_type: Attention variant forwarded to every head.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, n_heads, d_model, d_k, d_v, max_len, layer_type, **kwargs):
        super().__init__(**kwargs)
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.max_len = max_len
        self.layer_type = layer_type

    def build(self, input_shape):
        """Create and build the heads, then the output projection ``W_o``."""
        self.attention_layers = [
            AttentionLayer(d_model=self.d_model,
                           d_q=self.d_k,
                           d_k=self.d_k,
                           d_v=self.d_v,
                           layer_type=self.layer_type,
                           idx=i,
                           max_len=self.max_len)
            for i in range(self.n_heads)
        ]

        # Build each attention layer
        for layer in self.attention_layers:
            layer.build(input_shape)

        # Build the output dense layer on the concatenated head outputs
        self.W_o = layers.Dense(self.d_model)
        self.W_o.build((None, self.n_heads * self.d_v))

        super().build(input_shape)

    def call(self, inputs):
        """Apply every head to ``inputs`` and project the concatenation.

        The same tensor is used as query, key and value for each head.
        """
        head_outputs = [
            layer([inputs, inputs, inputs]) for layer in self.attention_layers
        ]
        # Join all heads along the feature axis in one op instead of growing
        # the tensor pairwise inside the loop.
        H = tf.concat(head_outputs, axis=-1)
        return self.W_o(H)

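This cell runs the main script. Run one of the two attention-layer cells above first, because the script relies on the `MultiHeadAttention` class they define.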

In [None]:
"""
## "You Need to Pay Better Attention" Keras Transformer Example

## Paper Link: https://arxiv.org/abs/2403.01643

## Author: Nicholas Mesa-Cucalon (https://github.com/NMesaC)
"""
import keras
import os
import tempfile
import time

import tensorflow as tf

from keras import ops
from keras import layers


# Fix the global random seed once, before any layers are created.
keras.utils.set_random_seed(1019)

class TransformerBlock(layers.Layer):
    """Attention sub-block plus feed-forward sub-block, each with dropout,
    a residual connection and layer normalisation.

    Args:
        embed_dim: Width of the block's input and output features.
        num_heads: Number of attention heads; each head gets
            ``embed_dim // num_heads`` key and value features.
        ff_dim: Hidden width of the feed-forward network.
        max_len: Forwarded to ``MultiHeadAttention``.
        layer_type: Attention variant forwarded to ``MultiHeadAttention``.
        rate: Dropout rate used after both sub-blocks.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, max_len, layer_type = 'SDPA', rate=0.1):
        super().__init__()
        per_head_dim = embed_dim // num_heads
        self.att = MultiHeadAttention(
            n_heads=num_heads,
            d_model=embed_dim,
            d_k=per_head_dim,
            d_v=per_head_dim,
            max_len=max_len,
            layer_type=layer_type,
        )
        # Sub-layers are created in the same order as before so that seeded
        # initialisation is unaffected.
        expand = layers.Dense(ff_dim, activation="relu")
        project = layers.Dense(embed_dim)
        self.ffn = keras.Sequential([expand, project])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        """Return the block output for ``inputs``."""
        # Attention sub-block: attend, drop, add residual, normalise.
        attended = self.dropout1(self.att(inputs))
        normed_attention = self.layernorm1(inputs + attended)
        # Feed-forward sub-block with the same drop/residual/normalise shape.
        fed_forward = self.dropout2(self.ffn(normed_attention))
        return self.layernorm2(normed_attention + fed_forward)

class TokenAndPositionEmbedding(layers.Layer):
    """Sum of a token embedding and a learned position embedding.

    Args:
        maxlen: Number of rows in the position embedding table.
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Output width of both embeddings.
    """

    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = layers.Embedding(vocab_size, embed_dim)
        self.pos_emb = layers.Embedding(maxlen, embed_dim)

    def call(self, x):
        """Embed ``x`` and add the embedding of each position index."""
        # Position indices 0..L-1, where L is the size of x's last axis.
        seq_len = ops.shape(x)[-1]
        position_ids = ops.arange(start=0, stop=seq_len, step=1)
        position_vectors = self.pos_emb(position_ids)
        token_vectors = self.token_emb(x)
        return token_vectors + position_vectors

def get_model_size(model):
    """Return the on-disk size of ``model``, in MB, when saved as an ``.h5`` file.

    The model is saved (optimizer state included) to a temporary file, the
    file size is measured, and the file is removed again.

    Args:
        model: Object exposing ``save(path, include_optimizer=...)``.

    Returns:
        File size in MB (bytes divided by 1024 * 1024) as a float.
    """
    fd, keras_file = tempfile.mkstemp('.h5')
    # mkstemp hands back an already-open descriptor; close it straight away so
    # it is not leaked and so the file can be reopened for writing.
    os.close(fd)
    try:
        model.save(keras_file, include_optimizer=True)
        size_bytes = os.path.getsize(keras_file)
    finally:
        # Remove the temporary file even when saving fails.
        os.remove(keras_file)

    return size_bytes / (1024 * 1024)


def main():
    """Train and evaluate an IMDB classifier once per attention variant.

    For each of the 'SDPA', 'Optimised', 'Efficient' and 'Super' layer types a
    small Transformer classifier is trained ``num_runs`` times. For every run
    the best weights (by validation accuracy) are checkpointed under
    ``./results/imdb/model/`` and then evaluated on the test split. Per-variant
    summary statistics are appended to ``<layer>_results.txt``.
    """
    # Setup initial variables
    vocab_size = 20000  # Only consider the top 20k words
    # Every review is padded/truncated to 32 tokens. With pad_sequences'
    # default truncating='pre' this keeps the trailing tokens of long reviews.
    maxlen     = 32
    embed_dim  = 32     # Embedding size for each token
    ff_dim     = 32     # Hidden layer size in feed forward network inside transformer
    batch_size = 64     # Batch Size
    epochs     = 10     # Number of epochs
    num_heads  = 4      # Number of attention heads
    # The IMDB dataset is slightly different between Keras and PyTorch, so the results are slightly different
    (x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=vocab_size)
    print(len(x_train), "Training sequences")
    x_train = keras.utils.pad_sequences(x_train, maxlen=maxlen)
    x_test = keras.utils.pad_sequences(x_test, maxlen=maxlen)

    # Train a Transformer Model with Each Attention Layer Type
    num_runs = 5
    layer_types = ['SDPA','Optimised','Efficient','Super']
    for layer in layer_types:
        avg_test_acc, avg_test_loss, avg_model_size = 0, 0, 0
        num_of_attention_params = None
        run_times = []
        for run in range(num_runs):
            print(f"Training with {layer} training layer")
            # Create inputs
            inputs = layers.Input(shape=(maxlen,))
            # Create model
            embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
            x = embedding_layer(inputs)
            transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim, maxlen, layer)
            x = transformer_block(x)
            x = layers.GlobalAveragePooling1D()(x)
            x = layers.Dropout(0.1)(x)
            x = layers.Dense(6, activation="relu")(x)
            x = layers.Dropout(0.1)(x)
            outputs = layers.Dense(1, activation="sigmoid")(x)

            # Initialize model
            model = keras.Model(inputs=inputs, outputs=outputs)
            model.compile(
                optimizer="adam", loss="BCE", metrics=["accuracy"]
            )

            # Create callbacks to save best model
            checkpoint_filepath = (
                f"./results/imdb/model/{layer}_{num_heads}_heads"
                f"/run_num_{run}/model.weights.h5"
            )
            os.makedirs(os.path.dirname(checkpoint_filepath), exist_ok=True)
            # NOTE: Keras callback will save only the weights of the transformer model when val_accuracy goes up
            checkpoint_callback = keras.callbacks.ModelCheckpoint(
                checkpoint_filepath,
                monitor="val_accuracy",
                save_best_only=True,
                save_weights_only=True,
            )
            # Fit model and retrieve history + time taken
            start_time = time.time()
            history = model.fit(x=x_train,
                                y=y_train,
                                batch_size=batch_size,
                                epochs=epochs,
                                validation_split=0.1,
                                callbacks=[checkpoint_callback])
            duration = time.time() - start_time
            run_times.append(duration)
            print(f"Training took {duration:.4f} seconds")

            # Calculate model size
            model_size = get_model_size(model)
            print(f"The size of the model is approximately {model_size:.4f} MB")

            # Compute number of attention parameters. The block is referenced
            # directly instead of through a positional index into model.layers.
            num_of_attention_params = transformer_block.att.count_params()
            print(f"Number of attention params: {num_of_attention_params}")

            # Compute and display test loss and accuracy for best model
            model.load_weights(checkpoint_filepath)
            test_loss, test_acc = model.evaluate(x_test, y_test)
            print(f"Test Loss: {test_loss}, Test Accuracy: {test_acc}\n")
            avg_test_loss  += test_loss
            avg_test_acc   += test_acc
            avg_model_size += model_size

        # Middle element of the sorted run times (upper middle for even counts)
        med_run_time = sorted(run_times)[len(run_times) // 2]
        file_name = f"{layer}_results.txt"
        # Append mode keeps results from earlier invocations; the context
        # manager closes the file even if a write fails.
        with open(file_name, "a") as f:
            f.write(f"Average Test Acc over {num_runs} for {layer}: {avg_test_acc / num_runs} \n")
            f.write(f"Average Test Loss over {num_runs} for {layer}: {avg_test_loss / num_runs} \n")
            f.write(f"Average Model Size over {num_runs} for {layer}: {avg_model_size / num_runs} \n")
            f.write(f"Median Run Time over {num_runs} for {layer}: {med_run_time} \n")
            f.write(f"Number of attention parameters for {layer}: {num_of_attention_params} \n")

# Run the full experiment only when executed as a script / notebook cell.
if __name__ == '__main__':
    main()






Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
[1m17464789/17464789[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
25000 Training sequences
Training with SDPA training layer
Epoch 1/10
