# DL Lab 3.3 - Text Classification with Transformers

Welcome to the DL Lab! In the lectures on sequence processing, you heard a lot about Recurrent Neural Networks (RNNs) and their different architectures. Finally, one of the most groundbreaking architectural mechanisms was discussed: the **self-attention** mechanism, which is the key concept of the so-called **Transformer** architecture.

In this homework, you will implement an **encoder** based on **multi-head attention** and use it for **text classification**.

***

**After completing this homework you will be able to**

- Use TF's **subclassing** API to efficiently create **custom layers** and **blocks**.
- Implement and use **scaled dot product attention** in **multiple attention heads**.

***

# 1 - Data Preparation

You will use the same text data as in DL Lab 3.1, i.e., the **[AG_NEWS](http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html)** dataset containing news articles from 4 different categories.

Execute the cell below to download and prepare the text data.

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

BATCHSIZE = 32

dataset = tfds.load('ag_news_subset')
train_ds = dataset['train']
val_ds = dataset['test']

classes = ['World', 'Sports', 'Business', 'Sci/Tech']
num_classes = len(classes)

def extract_text(x):
    return x['title'] + ' ' + x['description']

We also need a word-level vectorizer to convert the text into a numerical representation:

In [None]:
from tensorflow.keras import layers, Model

max_vocab_size = 10000
max_sequence_length = 100

vectorizer = layers.TextVectorization(
    max_tokens = max_vocab_size,
    output_sequence_length = max_sequence_length
)
vectorizer.adapt(train_ds.take(1000).map(extract_text))
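To see what the vectorizer does conceptually, here is a minimal pure-Python sketch. It assumes the default `TextVectorization` behavior used above (lowercasing and punctuation stripping, whitespace splitting, index 0 reserved for padding and index 1 for `[UNK]`). The helpers `standardize`, `build_vocab` and `vectorize` are hypothetical, and breaking ties between equally frequent words alphabetically is our own assumption, not necessarily what Keras does.

```python
import string
from collections import Counter

PAD, UNK = "", "[UNK]"

def standardize(text):
    # Lowercase and strip ASCII punctuation (approximates 'lower_and_strip_punctuation')
    return text.lower().translate(str.maketrans("", "", string.punctuation))

def build_vocab(texts, max_tokens):
    # Count word frequencies over the adaptation corpus
    counts = Counter(word for t in texts for word in standardize(t).split())
    # Most frequent words first; ties broken alphabetically (assumption)
    words = sorted(counts, key=lambda w: (-counts[w], w))
    # Reserve index 0 for padding and index 1 for out-of-vocabulary words
    return [PAD, UNK] + words[: max_tokens - 2]

def vectorize(text, vocab, sequence_length):
    index = {word: i for i, word in enumerate(vocab)}
    # Unknown words are mapped to the [UNK] index 1
    ids = [index.get(word, 1) for word in standardize(text).split()]
    # Truncate long sequences, pad short ones with 0
    ids = ids[:sequence_length]
    return ids + [0] * (sequence_length - len(ids))

vocab = build_vocab(["The cat sat.", "The dog!"], max_tokens=5)
print(vocab)                                # ['', '[UNK]', 'the', 'cat', 'dog']
print(vectorize("The sat cat!", vocab, 5))  # [2, 1, 3, 0, 0]
```

Note how `max_tokens` counts the two reserved entries, so only the 3 most frequent words make it into this toy vocabulary and "sat" becomes `[UNK]`.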

In [None]:
# Get the unique words in the vocabulary
vocab = vectorizer.get_vocabulary()

# Length of the vocabulary
vocab_size = len(vocab)
print(f"Number of words in vocab: {vocab_size}")

# most common tokens (index 0 is the padding token '' and index 1 the [UNK] token for "unknown" words)
top_5_words = vocab[:5]
print(f"Top 5 most common words: {top_5_words}")

# least common tokens
bottom_5_words = vocab[-5:]
print(f"Bottom 5 least common words: {bottom_5_words}")

Let's tokenize the text and optimize the input pipeline for performance:

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

def tupelize(x):
    return (vectorizer(extract_text(x)), x['label'])

train_ds_opt = train_ds.map(tupelize)
train_ds_opt = train_ds_opt.cache().shuffle(1000).batch(BATCHSIZE).prefetch(AUTOTUNE)

val_ds_opt = val_ds.map(tupelize)
val_ds_opt = val_ds_opt.cache().batch(BATCHSIZE).prefetch(AUTOTUNE)

# 2 - Transformers

Before the rise of Transformer models, the state of the art in sequence-related tasks (translation, image captioning, ...) was essentially an RNN with an added attention mechanism (compare Lecture 3.2 - Attention).

Transformer architectures also rely on attention mechanisms, but without any recurrent processing. The influential [*Attention is all you need*](https://arxiv.org/abs/1706.03762) paper shows that the attention mechanism alone is as powerful as RNNs with added attention. Moreover, recurrent processing requires sequences to be processed strictly in order, both during forward and backward propagation, whereas a Transformer processes all tokens of a sequence *at once*, in parallel. This significantly reduces training times and allows much larger datasets to be processed efficiently.

## 2.1 - Define Multi-head Attention Encoder

The heart of the original Transformer is an encoder-decoder structure. The encoder processes the input sequence to generate an encoding of each token in such a way that it contains contextual information, i.e., the encoder encodes which tokens of the input sequence are relevant to each other. The decoder then transforms these token encodings into an output sequence. For text classification, only the encoder is needed.

### Scaled Dot-Product Attention

To compute this contextual token relevance, both the encoder and the decoder use a specific type of attention, the **scaled dot-product attention**, which dynamically computes a weighted average of the features of the input tokens. "Dynamic" means that the weights are not pre-defined but depend on the actual values of the tokens.

To this end, for each token embedding $x_i$ of the input sequence (the $i$-th row of $X$), three vectors are computed: a **query** vector $q_i$, a **key** vector $k_i$, and a **value** vector $v_i$. These vectors are computed by linear projection using trainable weight matrices:
$$
\begin{aligned}
q_i &= x_i W_q,\\
k_i &= x_i W_k,\\
v_i &= x_i W_v.
\end{aligned}
$$

As you can see, computing these vectors is exactly the same as computing the output of a fully connected layer without bias, and this is how you will implement it. ;-)
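As a minimal NumPy sketch (sizes and random weights are made up for illustration), projecting every token $x_i$ separately gives the same result as multiplying the whole sequence matrix $X$ with $W_q$ at once, which is what a bias-free `Dense` layer applied to a `(sequence_length, embedding_dim)` input computes:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 5, 8, 8           # illustrative sizes

X = rng.normal(size=(seq_len, d_model))   # one row per token embedding x_i
W_q = rng.normal(size=(d_model, d_k))     # trainable in the real model
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))

# Token by token: q_i = x_i W_q
Q_rows = np.stack([X[i] @ W_q for i in range(seq_len)])

# All tokens at once: Q = X W_q (what a bias-free Dense layer computes)
Q, K, V = X @ W_q, X @ W_k, X @ W_v

print(np.allclose(Q_rows, Q))     # True
print(Q.shape, K.shape, V.shape)  # (5, 8) (5, 8) (5, 8)
```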

The unnormalized attention weight (score)

$$a_{ij} = q_i \cdot k_j$$

from token $i$ to token $j$ measures the similarity between the query vector $q_i$ of token $i$ and the key vector $k_j$ of token $j$, using the dot product as similarity metric. Hence, the scores for all token pairs can be computed at once by the matrix-matrix multiplication $QK^T$, where $Q$ ($K$) contains the query (key) vectors of all input tokens as rows. The scores are then normalized with a softmax over $j$, so that the attention weights of each token are between $0$ and $1$ and sum up to $1$. Since the magnitude of the dot product grows with the dimension $d_k$ of the key vectors (for vectors with independent zero-mean, unit-variance components, its variance is $d_k$), the scores are divided by $\sqrt{d_k}$ before the softmax. Otherwise, large scores would push the softmax into regions with extremely small gradients.

Ultimately, the hidden representations of the tokens are computed by multiplying the normalized attention weights with the matrix $V$ of value vectors:

$$H(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}}\right)V.$$

<img src='https://3.bp.blogspot.com/-mnlTQLXKuiU/XfgOSZ2eBsI/AAAAAAAAB0w/6jjXEtzO6_M1IlPkNVzR_wcmP62u0jI0ACLcBGAsYHQ/s1600/attention.png' width="760">
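The formula translates almost line by line into code. The following NumPy sketch is a framework-independent illustration, not the TF implementation you will write below; the helper names `softmax` and `scaled_dot_product_attention` are our own. It computes $H(Q, K, V)$ for a single sequence and also returns the attention weights, so you can check that each row sums to 1:

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    d_k = K.shape[-1]
    logits = Q @ K.swapaxes(-1, -2) / np.sqrt(d_k)  # (..., seq_q, seq_k)
    weights = softmax(logits, axis=-1)              # each row sums to 1
    return weights @ V, weights

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(5, 8)) for _ in range(3))
H, A = scaled_dot_product_attention(Q, K, V)
print(H.shape, A.shape)                  # (5, 8) (5, 5)
print(np.allclose(A.sum(axis=-1), 1.0))  # True
```

A useful sanity check: if all keys are identical, every score in a row is the same, the weights become uniform, and each output row is simply the mean of the value vectors.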

### Multi-head Attention

With multiple attention heads, each computing the scaled dot-product attention described above, the model can learn different relevance relations between the input tokens. Each head operates on its own chunk of the projected queries, keys, and values, of size `hidden_size / num_heads`. To speed up computation, we first compute the full $Q, K, V$ matrices and then split them into chunks that are distributed across the heads. You can think of this as stacking the weight matrices $W_q^{(h)}, W_k^{(h)}, W_v^{(h)}$ of the different heads $h$ side by side into combined weight matrices $W_q, W_k, W_v$. Finally, the outputs of all heads are concatenated, which restores the hidden size of the model, and propagated through a final linear layer that mixes the information from the different heads.
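The claim that stacking the per-head weight matrices is equivalent to one large projection followed by a split can be checked numerically. In this NumPy sketch (sizes are illustrative), `separate_heads` and `combine_heads` mirror the reshape/transpose logic of the methods with the same names in the class below, applied to a batch of sequences:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, hidden_size, num_heads = 2, 5, 8, 4
proj_dim = hidden_size // num_heads

def separate_heads(x):
    # (batch, seq, hidden) -> (batch, heads, seq, hidden/heads)
    b, s, _ = x.shape
    return x.reshape(b, s, num_heads, proj_dim).transpose(0, 2, 1, 3)

def combine_heads(x):
    # (batch, heads, seq, hidden/heads) -> (batch, seq, hidden)
    b, _, s, _ = x.shape
    return x.transpose(0, 2, 1, 3).reshape(b, s, hidden_size)

X = rng.normal(size=(batch, seq_len, hidden_size))
# One small query matrix per head, stacked side by side into W_q
W_heads = [rng.normal(size=(hidden_size, proj_dim)) for _ in range(num_heads)]
W_q = np.concatenate(W_heads, axis=-1)  # (hidden, hidden)

Q_split = separate_heads(X @ W_q)       # project once, then split
for h in range(num_heads):
    # Head h sees exactly the per-head projection X W_q^(h)
    assert np.allclose(Q_split[:, h], X @ W_heads[h])

print(Q_split.shape)                                 # (2, 4, 5, 2)
print(np.allclose(combine_heads(Q_split), X @ W_q))  # True
```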


**Task**: Complete the `MultiHeadSelfAttention` class defined below.
- Initialize the [dense layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense) `query_linear`, `key_linear`, `value_linear` for computing the linear transformations of the inputs. Each layer shall have `hidden_size` neurons, use linear activation and no bias.
- Compute the (unscaled) dot-product attention logits $QK^T$ in the `attention` method using [`tf.matmul()`](https://www.tensorflow.org/api_docs/python/tf/linalg/matmul).

In [None]:
class MultiHeadSelfAttention(layers.Layer):
  ''' Multi-head self attention layer consisting of four parts:
    - linear layers and separation into multiple heads
    - scaled dot product attention
    - concatenation of multiple heads
    - final linear layer
  '''

  def __init__(self, hidden_size, num_heads=4):
    assert hidden_size % num_heads == 0, \
      'The hidden size (={}) has to be divisible by the number of heads (={})'.format(
          hidden_size, num_heads)
    super(MultiHeadSelfAttention, self).__init__()
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.projection_dim = hidden_size // num_heads

    # Layers for linearly projecting to queries, keys and values
    ### START YOUR CODE HERE ###  (3 LOC)
    self.query_linear =
    self.key_linear =
    self.value_linear =
    ### END YOUR CODE HERE ###
    self.output_linear = layers.Dense(hidden_size, name='transform_output')


  def attention(self, query, key, value):
    ''' Apply scaled dot product attention

    Args:
      query : Tensor with shape (batch_size, num_heads, sequence_length, hidden_size/num_heads)
      key : Tensor with shape (batch_size, num_heads, sequence_length, hidden_size/num_heads)
      value : Tensor with shape (batch_size, num_heads, sequence_length, hidden_size/num_heads)

    Returns:
      Tensor with shape (batch_size, num_heads, sequence_length, hidden_size/num_heads)
    '''

    # compute dot product attention
    ### START YOUR CODE HERE ###  (1 LOC)
    logits =
    ### END YOUR CODE HERE ###

    # scale logits
    dim_key = tf.cast( tf.shape(key)[-1], tf.float32 )
    logits = logits / tf.math.sqrt(dim_key)

    # apply softmax
    attention_weights = tf.nn.softmax(logits, axis=-1, name='attention_weights')

    # multiply attention weights with value
    output = tf.matmul(attention_weights, value)

    return output


  def separate_heads(self, x, batch_size):
    ''' Separate x into different heads and transpose resulting value.

    Args:
      x : Tensor with shape (batch_size, sequence_length, hidden_size)

    Returns:
      Tensor with shape (batch_size, num_heads, sequence_length, hidden_size/num_heads)
    '''
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))
    return tf.transpose(x, perm=[0,2,1,3])


  def combine_heads(self, x, batch_size):
    ''' Combine the separated heads into a single tensor.

    Args:
      x : Tensor with shape (batch_size, num_heads, sequence_length, hidden_size/num_heads)

    Returns:
      Tensor with shape (batch_size, sequence_length, hidden_size)
    '''
    x = tf.transpose(x, perm=[0,2,1,3])
    return tf.reshape(x, (batch_size, -1, self.hidden_size))


  def call(self, inputs):
    ''' Apply self-attention mechanism to inputs.

    Args:
      inputs : Tensor with shape (batch_size, sequence_length, hidden_size)

    Returns:
      Tensor with shape (batch_size, sequence_length, hidden_size)
    '''
    # store batch_size
    batch_size = tf.shape(inputs)[0]

    # transform q, k, v by linear projection
    query = self.query_linear(inputs)
    key = self.key_linear(inputs)
    value = self.value_linear(inputs)

    # separate q, k, v into heads
    query = self.separate_heads(query, batch_size)
    key = self.separate_heads(key, batch_size)
    value = self.separate_heads(value, batch_size)

    # apply self-attention
    attention = self.attention(query, key, value)

    # combine heads
    attention = self.combine_heads(attention, batch_size)

    # (batch_size, sequence_length, hidden_size)
    attention_output = self.output_linear(attention)

    return attention_output

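Each encoder layer of the transformer consists of the multi-head self-attention layer followed by a position-wise feed-forward network of two dense layers. Each of these two sublayers is followed by dropout, a residual connection, and layer normalization.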

In [None]:
class EncoderLayer(layers.Layer):
  def __init__(self, embedding_dim, ff_dim, num_heads=4, rate=.1):
    super(EncoderLayer, self).__init__()
    self.attention = MultiHeadSelfAttention(embedding_dim, num_heads)
    self.ffn = tf.keras.Sequential( [layers.Dense(ff_dim, activation='relu'),
                                     layers.Dense(embedding_dim)] )
    self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
    self.dropout1 = layers.Dropout(rate)
    self.dropout2 = layers.Dropout(rate)

  def call(self, inputs, training=None):
    attn_output = self.attention(inputs)
    attn_output = self.dropout1(attn_output, training=training)
    out = self.layernorm1(inputs + attn_output)
    ffn_output = self.ffn(out)
    ffn_output = self.dropout2(ffn_output, training=training)
    return self.layernorm2(out + ffn_output)
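To make the data flow of `EncoderLayer.call` explicit, here is a NumPy sketch of the same post-norm residual pattern, `LayerNorm(x + Sublayer(x))`. Assumptions: dropout is omitted (as at inference time), the layer-norm scale and offset are at their initial values 1 and 0, the biases of the feed-forward layers are dropped, and `toy_attention` and `toy_ffn` are hypothetical stand-ins for the real sublayers.

```python
import numpy as np

def layer_norm(x, epsilon=1e-6):
    # Normalize each token's feature vector (last axis) to mean 0, variance 1.
    # The learnable scale (gamma=1) and offset (beta=0) are left out.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

def encoder_layer(x, attention, ffn):
    # Post-norm residual blocks, as in EncoderLayer.call (dropout omitted)
    out = layer_norm(x + attention(x))
    return layer_norm(out + ffn(out))

rng = np.random.default_rng(0)
d_model, ff_dim = 8, 16
W1 = rng.normal(size=(d_model, ff_dim))
W2 = rng.normal(size=(ff_dim, d_model))

def toy_attention(x):
    # Hypothetical stand-in: every token attends uniformly to all tokens
    return np.broadcast_to(x.mean(axis=0), x.shape)

def toy_ffn(x):
    # Dense(ff_dim, relu) followed by Dense(d_model), biases omitted
    return np.maximum(x @ W1, 0.0) @ W2

x = rng.normal(size=(5, d_model))
y = encoder_layer(x, toy_attention, toy_ffn)
print(y.shape)  # (5, 8)
```

Because the output passes through layer normalization last, every token vector leaves the layer with zero mean and unit variance across its features, regardless of the sublayers.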

Unlike an RNN, the self-attention mechanism has no built-in notion of the order of its inputs. Hence, the position of each token must be encoded explicitly. You can either train a *positional embedding* using the `EmbeddingWithPositionalEmbedding` layer below (the model then learns to encode positions), or use the fixed sinusoidal *positional encoding* layer `EmbeddingWithPositionalEncoding`. Both should perform similarly well.
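The missing order information can be checked numerically: without positional information, self-attention is *permutation-equivariant*, i.e., shuffling the input tokens only shuffles the output rows in the same way, so the layer alone cannot tell "dog bites man" from "man bites dog". A minimal NumPy check, assuming random projection matrices and a single head:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    return softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V

rng = np.random.default_rng(0)
seq_len, d = 6, 8
X = rng.normal(size=(seq_len, d))
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

perm = rng.permutation(seq_len)
out = self_attention(X, W_q, W_k, W_v)
out_perm = self_attention(X[perm], W_q, W_k, W_v)

# Permuting the inputs only permutes the outputs
print(np.allclose(out_perm, out[perm]))  # True
```

Adding a position-dependent vector to each row of `X` before the attention breaks this symmetry, which is exactly what both embedding layers below do.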

In [None]:
class EmbeddingWithPositionalEmbedding(layers.Layer):
  def __init__(self, max_length, vocabulary_size, embedding_dim):
    super(EmbeddingWithPositionalEmbedding, self).__init__()
    self.token_emb = layers.Embedding(vocabulary_size, embedding_dim)
    self.position_emb = layers.Embedding(max_length, embedding_dim)

  def call(self, x):
    max_length = tf.shape(x)[-1]
    positions = tf.range(start=0, limit=max_length, delta=1)
    positions = self.position_emb(positions)
    x = self.token_emb(x)
    return x + positions

In [None]:
class EmbeddingWithPositionalEncoding(layers.Layer):
  def __init__(self, max_length, vocabulary_size, embedding_dim):
    super(EmbeddingWithPositionalEncoding, self).__init__()
    self.embedding_dim = embedding_dim
    self.vocabulary_size = vocabulary_size
    self.token_emb = layers.Embedding(vocabulary_size, embedding_dim)
    # precompute the encodings for all max_length positions
    self.pos_encoding = self.positional_encoding(max_length)

  def _get_angles(self, position, i):
    angles = 1 / tf.pow(10000.0, (2 * (i // 2)) / tf.cast(self.embedding_dim, tf.float32))
    return position * angles

  def positional_encoding(self, position):
    angle_rads = self._get_angles(
        position=tf.range(position, dtype=tf.float32)[:, tf.newaxis],
        i=tf.range(self.embedding_dim, dtype=tf.float32)[tf.newaxis, :]
        )
    # apply sin to even index in the array
    sines = tf.math.sin(angle_rads[:, 0::2])
    # apply cos to odd index in the array
    cosines = tf.math.cos(angle_rads[:, 1::2])

    pos_encoding = tf.concat([sines, cosines], axis=-1)
    pos_encoding = pos_encoding[tf.newaxis, ...]
    return tf.cast(pos_encoding, tf.float32)

  def call(self, x):
    x = self.token_emb(x)
    positions = self.pos_encoding[:, :tf.shape(x)[1], :]
    return x + positions
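For reference, the *Attention is all you need* paper defines the sinusoidal encoding with interleaved dimensions, $PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d})$ and $PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d})$, whereas the layer above concatenates all sine columns followed by all cosine columns. The NumPy sketch below (function names are hypothetical) computes both layouts and checks that they contain the same columns in a different fixed order. Since the encoding is simply added to learned token embeddings, such a fixed reordering of dimensions should not affect what the model can learn.

```python
import numpy as np

def angles(max_length, d_model):
    pos = np.arange(max_length)[:, None]              # (max_length, 1)
    k = np.arange(d_model // 2)[None, :]              # frequency index
    return pos / np.power(10000.0, 2 * k / d_model)   # (max_length, d_model/2)

def pe_interleaved(max_length, d_model):
    # Paper layout: sin at even dimensions, cos at odd dimensions
    a = angles(max_length, d_model)
    pe = np.empty((max_length, d_model))
    pe[:, 0::2] = np.sin(a)
    pe[:, 1::2] = np.cos(a)
    return pe

def pe_concatenated(max_length, d_model):
    # Layout of EmbeddingWithPositionalEncoding: all sines, then all cosines
    a = angles(max_length, d_model)
    return np.concatenate([np.sin(a), np.cos(a)], axis=-1)

inter = pe_interleaved(100, 32)
concat = pe_concatenated(100, 32)
# Reorder the interleaved columns as [0, 2, 4, ..., 1, 3, 5, ...]
order = np.concatenate([np.arange(0, 32, 2), np.arange(1, 32, 2)])
print(np.allclose(inter[:, order], concat))  # True
```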

## 2.2 - Create Model

Now you can use the custom layers to build your own transformer encoder:

In [None]:
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import SparseCategoricalAccuracy

def build_transformer_encoder(
    embedding_dim,
    num_heads,
    ff_dim,
    max_length,
    vocabulary_size,
    learning_rate=1e-3,
    summary=True
    ):

  input_ = layers.Input( shape=(max_length,) )

  # choose one
  #embedding_layer = EmbeddingWithPositionalEmbedding(max_length, vocabulary_size, embedding_dim)
  embedding_layer = EmbeddingWithPositionalEncoding(max_length, vocabulary_size, embedding_dim)
  x = embedding_layer(input_)

  transformer_block = EncoderLayer(embedding_dim, ff_dim, num_heads)

  x = transformer_block(x)
  x = layers.GlobalAveragePooling1D()(x)
  x = layers.Dropout(.1)(x)
  x = layers.Dense(x.shape[1], activation='relu')(x)
  x = layers.Dropout(.1)(x)
  output_ = layers.Dense(num_classes, activation='softmax')(x)

  model = Model(input_, output_)

  model.compile(
      loss=SparseCategoricalCrossentropy(),
      optimizer=Adam(
          learning_rate=learning_rate,
          beta_2=0.98,
          epsilon=1e-9
      ),
      metrics=[SparseCategoricalAccuracy()]
  )

  if summary:
    model.summary()

  return model

Transformers are typically trained with learning rate warm-up. Instead of making large updates from the very beginning, we start learning slowly by gradually increasing the learning rate from 0 to a peak value. After this warm-up, the learning rate decays again, proportionally to the inverse square root of the step number.

The [*Attention is all you need* paper](https://arxiv.org/abs/1706.03762) provides the following formula, which we implement as a custom learning rate schedule:

$$\mathrm{lrate} = d_{\mathrm{model}}^{-0.5} \cdot \min\left(\mathrm{step\_num}^{-0.5},\ \mathrm{step\_num} \cdot \mathrm{warmup\_steps}^{-1.5}\right)$$
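Plugging in the values used in this notebook gives a quick sanity check. The learning rate peaks where both arguments of the min are equal, i.e., at step $w$ with $w$ = `warmup_steps`, so the peak value is $d_{model}^{-0.5} \cdot w^{-0.5}$; for $d_{model} = 32$ and $w = 200$ this is $1/\sqrt{6400} = 0.0125$. A minimal pure-Python version of the formula (the function name `transformer_lr` is hypothetical):

```python
def transformer_lr(step_num, d_model=32, warmup_steps=200):
    # lrate = d_model^-0.5 * min(step_num^-0.5, step_num * warmup_steps^-1.5)
    if step_num == 0:
        return 0.0  # the linear warm-up term is 0 at step 0
    return d_model ** -0.5 * min(step_num ** -0.5, step_num * warmup_steps ** -1.5)

for step in (0, 50, 200, 800):
    print(step, round(transformer_lr(step), 6))
# 0 0.0
# 50 0.003125
# 200 0.0125
# 800 0.00625
```

The TF implementation below also yields 0 at step 0, because `rsqrt(0)` is infinite and the minimum then picks the warm-up term `0 * warmup_steps**-1.5`.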

In [None]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):

  def __init__(self, d_model, warmup_steps=200):
    super(CustomSchedule, self).__init__()
    self.d_model = d_model
    self.d_model = tf.cast(self.d_model, tf.float32)
    self.warmup_steps = warmup_steps

  def __call__(self, step):
    step = tf.cast(step, tf.float32)
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)

    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

Check the scheduled learning rate:

In [None]:
from matplotlib import pyplot as plt

EMBEDDING_DIM = 32

LearningRateSchedule = CustomSchedule(EMBEDDING_DIM)

plt.plot(LearningRateSchedule(tf.range(2000, dtype=tf.float32)))
plt.ylabel("Learning Rate")
plt.xlabel("Train Step")
plt.show()

## 2.3 - Train and Evaluate Model

Finally, you can train your transformer! If everything works well, the model should converge quickly and start overfitting after only a few epochs.

In [None]:
# @title define `plot_history()`
from matplotlib import pyplot as plt
import numpy as np

def plot_history(history):
  fig, (ax1, ax2) = plt.subplots(2,1, sharex=True, dpi=150)
  ax1.plot(history.history['loss'], label='training')
  ax1.plot(history.history['val_loss'], label='validation')
  ax1.set_ylabel('Loss')
  ax1.set_yscale('log')
  if 'lr' in history.history:
    ax1b = ax1.twinx()
    ax1b.plot(history.history['lr'], 'g-', linewidth=1)
    ax1b.set_yscale('log')
    ax1b.set_ylabel('Learning Rate', color='g')
  ax1.legend()

  key = None
  for k in sorted(history.history.keys()):
    if 'acc' in k and 'val_' not in k:
      key = k
      break
  if key:
    ax2.plot(history.history[key], label='training')
    ax2.plot(history.history['val_'+key], label='validation')
    ax2.set_ylabel('Accuracy')
    ax2.set_xlabel('Epochs')
  plt.show()

In [None]:
num_heads = 4 # Number of attention heads
ff_dim = 16 # Hidden layer size in feed forward network inside transformer

BATCHSIZE = 32

LearningRateSchedule = CustomSchedule(EMBEDDING_DIM)

EarlyStoppingCallback = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                         patience=3,
                                                         restore_best_weights=True)

my_transformer = build_transformer_encoder(
    EMBEDDING_DIM, num_heads, ff_dim,
    max_sequence_length,
    vocab_size,
    learning_rate=LearningRateSchedule
)

history = my_transformer.fit(
    train_ds_opt,  # already batched with BATCHSIZE, so no batch_size argument is needed
    epochs=10,
    validation_data=val_ds_opt,
    callbacks=[EarlyStoppingCallback]
)

plot_history(history)
