## Text Classification using an Encoder-Only Transformer

In this exercise, we will build an encoder-only transformer and apply it to a text
classification task. To better understand the model architecture, we will build and train the model (almost) from scratch.  

**Note**: in practice, you would use a pretrained language model and possibly fine-tune it.

In [None]:
import keras

# Part 1

**2.1 Copy and Change Classes from Previous Exercise**

An encoder-only transformer is very similar to a decoder-only transformer:
*   It also starts with a positional embedding.  
We are going to use the code for ```EmbeddingWithPosition``` from the previous exercise.
*   The encoder blocks employ the same ```FeedForward``` layer as the decoder blocks.  
We are going to use the code for ```FeedForward``` from the previous exercise.
*   The encoder block itself is very similar to the decoder block in a decoder-only transformer.  
We are going to use the code for ```GPTDecoderBlock``` from the previous exercise.





In [None]:
@keras.saving.register_keras_serializable()
class FeedForward(keras.layers.Layer):

  def __init__(self, factor=4, **kwargs):
    super().__init__(**kwargs)
    self.factor = factor
    self.relu = keras.activations.relu

  def build(self, batch_input_shape):
    time_steps, embed_size = batch_input_shape[1:]
    self.kernel1 = self.add_weight(shape=(embed_size, self.factor*embed_size))
    self.bias1 = self.add_weight(shape=(self.factor*embed_size, ),
                                 initializer="zeros")
    self.kernel2 = self.add_weight(shape=(self.factor*embed_size, embed_size))
    self.bias2 = self.add_weight(shape=(embed_size, ),
                                 initializer="zeros")

  def call(self, inputs):
    a = self.relu(keras.ops.matmul(inputs, self.kernel1) + self.bias1)
    return keras.ops.matmul(a, self.kernel2) + self.bias2

  def get_config(self):
    base_config = super().get_config()
    return {
        **base_config,
        "factor": self.factor,
    }

# Embedding with Position
@keras.saving.register_keras_serializable()
class EmbeddingWithPosition(keras.layers.Layer):
  """
  Computes an embedding and also adds a positional embedding.
  This layer does not support masking.
  """

  def __init__(self, num_tokens, max_seq_length, embed_size, **kwargs):
    super().__init__(**kwargs)
    self.num_tokens = num_tokens
    self.max_seq_length = max_seq_length
    self.embed_size = embed_size

  def build(self, batch_input_shape):
    # Shape not needed
    # print(f"Build called with {batch_input_shape} as shape")

    self.kernel = self.add_weight(shape=(self.num_tokens, self.embed_size))
    self.pos_kernel = self.add_weight(shape=(self.max_seq_length, self.embed_size))

  def call(self, inputs):
    _, length = keras.ops.shape(inputs)

    embeddings = keras.ops.take(self.kernel, inputs, axis=0) # (batch, length, embed_size)
    pos_embeddings = self.pos_kernel[:length]

    return embeddings + pos_embeddings # rely on broadcasting. Mask is lost

  def get_config(self):
    base_config = super().get_config()
    return {
        **base_config,
        "num_tokens": self.num_tokens,
        "max_seq_length": self.max_seq_length,
        "embed_size": self.embed_size
    }

Make the following changes to ```GPTDecoderBlock```:
*   Rename the class to ```EncoderBlock```.
*   The multi-headed attention does not need the causal mask when it is called, since each token is allowed to attend to all tokens before and after it. Remove the causal mask.
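
To make the difference concrete, here is a small NumPy sketch (not part of the exercise code; the variable names are illustrative) that builds the attention pattern a causal mask enforces and the pattern an encoder block effectively uses. A 1 at position (i, j) means that token i may attend to token j.

```python
import numpy as np

seq_length = 4  # illustrative sequence length

# Causal (decoder) pattern: token i may only attend to tokens j <= i.
causal_mask = np.tril(np.ones((seq_length, seq_length), dtype=bool))

# Encoder pattern: every token may attend to every token.
full_mask = np.ones((seq_length, seq_length), dtype=bool)

print(causal_mask.astype(int))
print(full_mask.astype(int))
```

The causal pattern is lower triangular, while the encoder pattern has no zeros: dropping the causal mask is what lets a token see the words that follow it.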

In [None]:
# A GPT decoder block (without cross attention)
@keras.saving.register_keras_serializable()
class GPTDecoderBlock(keras.layers.Layer):

  def __init__(self, num_heads, embed_size, **kwargs):
    super().__init__(**kwargs)
    self.num_heads = num_heads
    self.embed_size = embed_size

    self.masked_multi_head_attn = keras.layers.MultiHeadAttention(
        num_heads=self.num_heads,
        key_dim = self.embed_size // self.num_heads
    )
    self.layer_norm_1 = keras.layers.LayerNormalization()
    self.feed_forward = FeedForward()
    self.layer_norm_2 = keras.layers.LayerNormalization()

  def build(self):
    pass

  def call(self, inputs):
    skip = inputs
    inputs = self.masked_multi_head_attn(inputs, inputs, use_causal_mask=True)
    inputs = self.layer_norm_1(keras.layers.Add()([inputs, skip]))

    skip = inputs
    inputs = self.feed_forward(inputs)
    inputs = self.layer_norm_2(keras.layers.Add()([skip, inputs]))
    return inputs

  def get_config(self):
    base_config = super().get_config()
    return {
        **base_config,
        "num_heads": self.num_heads,
        "embed_size": self.embed_size
    }

**2.2 Create an Encoder-Only Classification Model**  
The encoder from the original "Attention Is All You Need" paper outputs contextual
embeddings of shape (```batch_size```, ```seq_length```, ```embed_size```).  
**Note**: in principle, this model works with any sequence length, as
long as it does not exceed the maximum sequence length specified in ```EmbeddingWithPosition```.

We need a way to attach a small classification network on top of these contextual embeddings. We will do the following (other options are certainly possible):

*  First, we will reduce the contextual embeddings of shape
(```batch_size```, ```seq_length```, ```embed_size```) to a tensor of fixed size by
viewing these contextual embeddings as a one-dimensional sequence and applying a one-dimensional global average pooling layer to it. This reduces the
contextual embeddings to a tensor of size (```batch_size```, ```embed_size```).
*  Next, we add a fully connected layer with ```embed_size``` units and the ReLU
activation function.
*  Finally, we add a fully connected layer with the correct number of units and
the correct activation function given the type of classification problem.
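
The pooling step in the first bullet can be sketched in NumPy as follows. This is only an illustration of the shape reduction under assumed toy sizes, not the model code: in Keras, the layer ```keras.layers.GlobalAveragePooling1D``` performs the same averaging over the sequence axis when no mask is involved.

```python
import numpy as np

batch_size, seq_length, embed_size = 2, 5, 3  # illustrative sizes

rng = np.random.default_rng(0)
# Stand-in for the contextual embeddings produced by the encoder blocks.
contextual = rng.normal(size=(batch_size, seq_length, embed_size))

# Average over the sequence axis (axis 1): one vector per example.
pooled = contextual.mean(axis=1)

print(contextual.shape)  # (2, 5, 3)
print(pooled.shape)      # (2, 3)
```

Because the sequence axis is averaged away, the result has the same size regardless of the review length, which is what the fully connected layers need.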

In [None]:
def get_classification_encoder_model(
    num_tokens: int, max_seq_length: int,
    embed_size: int, num_heads: int, num_blocks: int,
    num_classes: int,
    use_mask=False,
    scale_embeddings=False):
  """
  num_tokens: the vocabulary size
  max_seq_length: maximum length of any sequence
  embed_size: the embedding dimension
  num_heads: the number of heads in each multi-headed attention
  num_blocks: the number of encoder blocks
  num_classes: the number of classes to classify the text into
  """
  inputs = keras.layers.Input(shape=[max_seq_length], dtype=int) # (B, LEN)

  # Positional Embeddings
  # YOUR CODE HERE

  # Encoder blocks
  # YOUR CODE HERE

  # Simplest classification head
  # Classification network on top of contextual embeddings
  # YOUR CODE HERE

  return keras.Model(inputs=inputs, outputs=output)

## Load Data

**2.3 Apply the Model to the IMDB Dataset**  
We will now apply this model to the built-in IMDB dataset. Load the data:

In [None]:
MAX_LEN=500
NUM_TOKENS=10_000

In [None]:
(X_train, y_train), (X_test, y_test) = keras.datasets.imdb.load_data(num_words=NUM_TOKENS)

Not all reviews contain the same number of words. Use a method from
```keras.utils``` to pad the reviews so that they all have the same length. Choose the arguments so that
*  padding tokens are added at the end of reviews that are shorter than ```MAX_LEN```.
*  reviews that are longer than ```MAX_LEN``` are truncated so that the start of the review is kept.
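
To see what these two requirements mean for a single review, here is a pure-Python sketch. The helper ```pad_post``` is hypothetical and only mimics the desired behaviour on one list. It assumes 0 is the padding value, and it is not a replacement for the Keras utility you are asked to use.

```python
def pad_post(sequence, max_len, pad_value=0):
    """Hypothetical helper: keep the first max_len tokens, pad at the end."""
    truncated = sequence[:max_len]  # keeps the start of the review
    return truncated + [pad_value] * (max_len - len(truncated))


short_review = [12, 7, 93]                  # made-up token ids
long_review = [5, 8, 13, 21, 34, 55, 89]    # made-up token ids

print(pad_post(short_review, max_len=5))  # [12, 7, 93, 0, 0]
print(pad_post(long_review, max_len=5))   # [5, 8, 13, 21, 34]
```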

In [None]:
# pad sequences
X_train = keras.utils....
X_test = keras.utils....

Set aside the first 10 000 examples of the test data as the validation set:

In [None]:
X_val = ....
y_val = ....
X_test = ....
y_test = ....

Now that the data is ready, we can build a model. Use the following parameters:

In [None]:
NUM_HEADS=4
EMBED_SIZE=32
NUM_BLOCKS=2

In [None]:
model = get_classification_encoder_model(
    num_tokens=NUM_TOKENS,
    max_seq_length=MAX_LEN,
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    num_classes=1)

How many parameters does the model have? You should find that this model has 362 497 parameters, most of which are in the embedding layer.
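
The number can be checked by hand. The sketch below redoes the arithmetic in plain Python under a few assumptions: the attention layer uses biases in all four projections with ```key_dim = embed_size // num_heads``` (as in the block code above), each layer normalization has a scale and an offset vector, and the classification head is the one described in 2.2 with a single output unit.

```python
num_tokens, max_seq_length = 10_000, 500
embed_size, num_heads, num_blocks, factor = 32, 4, 2, 4
key_dim = embed_size // num_heads

# Token embedding table plus positional embedding table.
embedding = num_tokens * embed_size + max_seq_length * embed_size

# Query, key and value projections, then the output projection (all with biases).
proj = embed_size * num_heads * key_dim + num_heads * key_dim
out_proj = num_heads * key_dim * embed_size + embed_size
attention = 3 * proj + out_proj

# Two layer normalizations per block, each with scale and offset.
layer_norms = 2 * (2 * embed_size)

# FeedForward: two kernels and two biases.
hidden = factor * embed_size
feed_forward = embed_size * hidden + hidden + hidden * embed_size + embed_size

block = attention + layer_norms + feed_forward

# Dense(embed_size) followed by Dense(1).
head = (embed_size * embed_size + embed_size) + (embed_size * 1 + 1)

total = embedding + num_blocks * block + head
print(embedding, block, head)  # 336000 12704 1089
print(total)                   # 362497
```

The two embedding tables account for 336 000 of the 362 497 parameters, which confirms that most of the parameters sit in the embedding layer.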


In [None]:
model.summary()

Compile and fit the model:
*  Use the ```Adam``` optimizer with all the default values.
*  Use early stopping to (try to) prevent (severe) overfitting. Monitor the validation accuracy, using a ```patience``` of 4. Restore the best weights.  
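
The effect of ```patience``` and of restoring the best weights can be illustrated with a small pure-Python sketch. The helper below is hypothetical and simplified (it ignores options such as a minimum improvement), and the accuracy values are made up for illustration; they are not results from this model.

```python
def best_epoch_with_patience(val_accuracies, patience):
    """Hypothetical helper: return (best_epoch, stopped_epoch), 0-indexed."""
    best_value = float("-inf")
    best_epoch = -1
    for epoch, value in enumerate(val_accuracies):
        if value > best_value:
            # Strict improvement: remember this epoch as the best one.
            best_value, best_epoch = value, epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` epochs in a row: stop training.
            return best_epoch, epoch
    return best_epoch, len(val_accuracies) - 1


# Made-up validation accuracies, one per epoch.
history = [0.80, 0.84, 0.86, 0.85, 0.86, 0.85, 0.84, 0.83]
print(best_epoch_with_patience(history, patience=4))  # (2, 6)
```

In this toy run, training stops at epoch 6, and restoring the best weights brings the model back to its state after epoch 2.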

When you train the model, you should find that it reaches a validation accuracy of around 86%, while the accuracy on the training data is much higher, so the model is definitely overfitting.