Natural Language Processing Tutorial
======

This is the tutorial of the 2023 [Mediterranean Machine Learning Summer School](https://www.m2lschool.org/) on Natural Language Processing!

This tutorial will explore the fundamental aspects of Natural Language Processing (NLP). Whether you are new to NLP or just beginning your journey, there's no need to worry, as the tutorial assumes minimal prior knowledge. Our focus will be on implementing everything from scratch to ensure clarity and understanding. To facilitate this, we will be using [JAX](https://jax.readthedocs.io/en/latest/), a library that offers an API similar to Numpy (and often identical) with the added benefit of automatic differentiation.

## Outline

- 0 Refresher on JAX, Haiku and Optax
- 1 Introduction to NLP
   - 1.1 The NLP pipeline
   - 1.2 Classification pipeline: Multi-hot encoding + MLP model
   - 1.3 Classification pipeline: Embeddings + Sequential Model
- <span style="color:blue">2 Introduction to the Transformers architecture </span>
   - <span style="color:blue">2.1 Transformer architecture </span>
   - <span style="color:blue">2.2 Implementing the core components </span>
   - <span style="color:blue">2.3 Transformer for classification pipeline </span>
- 3 Advanced: Transformers for language translation
  - 3.1 The Transformer Decoder
  - 3.2 Transformer Decoder for character-based Language Modelling
  - 3.3 The full Transformer
  - 3.4 Transformer for Neural Machine Translation

## Emojis

Sections marked as [ 📝 ] contain cells with missing code that you should complete [ &#x1F4C4; ] is used for links to interesting external resources. When we use the words of an external resource we will cite it with &#x1F449; resource &#x1F448;.

## Libraries

We will keep our promise that (almost &#x1F60B; ) everything will be built from scratch. Indeed, all the vital and challenging components will be developed from zero. In fact the whole tutorial could be done only based on JAX. However, we recognize that certain minor technical details can become distracting if much time spent for them. For these tiny bits, we will ask help from [haiku](https://dm-haiku.readthedocs.io/en/latest/) to code neural network architectures, [optax](https://optax.readthedocs.io/en/latest/) to bring us the optimal parameters and [tokenizers](https://huggingface.co/docs/tokenizers/index) to quickly learning the vocabulary in each dataset. And the ubiquitous [numpy](https://numpy.org/) and [pandas](https://pandas.pydata.org/) for tensor handling. That's all!

It would be also nice, if you have access to GPU! And hopefully due to google colab you can immediately access a cudas-enabled environment pressing the the button below.

## Credits

The tutorial is created by [Vasilis Gkolemis](https://givasile.github.io/) and [Matko Bošnjak](https://matko.info). It is highly inspired by [Deep Learning with Python (DLP)](https://www.manning.com/books/deep-learning-with-python) the famous book of Francois Chollet, last year's [M2Lschool](https://github.com/M2Lschool/tutorials2022) NLP tutorial especially for the transformers part, the [annotated transformer](http://nlp.seas.harvard.edu/annotated-transformer/) presentation

## Note for Colab users

To grab a GPU (if available), make sure you go to `Edit -> Notebook settings` and choose a GPU under `Hardware accelerator`

## Practical 2: Introduction to the Transformers architecture

Welcome to Practical 2, where we will delve into the world of the Transformers architecture. Transformers were introduced in 2017 through the seminal work of Vaswani et. al, titled ["Attention is All You Need"](https://arxiv.org/abs/1706.03762). Since their inception, Transformers have garnered immense recognition, dominating the realm of Natural Language Processing (NLP) and serving as a wellspring of inspiration across various domains of Machine Learning, including computer vision and more.

This Practical will focus on the Transformers' encoder component (Transformers consist of two parts: the encoder and the decoder). In the end, we will use the Transformers' encoder to tackle the IMDB sentiment analysis task, which was introduced in Practical 1. The key aim of this session is to attain a comprehensive understanding of the core elements that constitute Transformers, with a significant emphasis on the foundational building-block: the self-attention layer. For this reason, we will code every from scratch.

In Practical 3, we will also code the decoder part (they share common building blocks) and use the complete Transformers architecture to solve natural language generation tasks; these are the tasks where the transformers really shine! In particular, we will attempt to train a English-Greek translation machine!


#### The following cells are identical to the ones implemented in Practical 1. Let's just run them to ensure that all crucial functions are in RAM.


In [101]:
# assert that all important packages are installed
!pip install dm-haiku optax tokenizers

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [102]:
import jax
import jax.numpy as jnp
import haiku as hk
import optax
import numpy as np
import pandas as pd
import tokenizers
import os
import typing

In [103]:
# forces JAX to allocate memory as needed
# see https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

In [104]:
def jax_has_gpu():
    try:
        _ = jax.device_put(jax.numpy.ones(1), device=jax.devices('gpu')[0])
        return True
    except:
        return False

gpu = jax_has_gpu()   # automatically checks for gpus, override if needed.

In [105]:
# initialize a key and a key iteration
init_seed = 21
key_iter = hk.PRNGSequence(jax.random.PRNGKey(init_seed))
key = jax.random.PRNGKey(init_seed)

In [106]:
def funcs_from_stateful(stateful_forward, jit=True):
    """
    Create and return stateless (pure JAX) functions from the stateful_forward function.

    Args:
        stateful_forward (Callable): The stateful forward function to be transformed.
        jit (bool, optional): Whether to jit-compile the functions. Defaults to True.

    Returns:
        Tuple: A tuple containing the transformed model with full PRNG sequence support (model),
        the transformed model with PRNG sequence support removed (model_fw), the predict function
        with or without jit compilation (predict), the predict function with PRNG sequence support
        removed and with or without jit compilation (predict_fw), and the init function for initializing
        the model parameters (init_params).
    """
    model = hk.transform(stateful_forward)
    model_fw = hk.without_apply_rng(model)

    if jit:
        predict = jax.jit(model.apply)
        predict_fw = jax.jit(model_fw.apply)
        init_params = jax.jit(model.init)
    else:
        predict = model.apply
        predict_fw = model_fw.apply
        init_params = model.init
    return model, model_fw, predict, predict_fw, init_params


In [107]:
def init_and_pred(key, x, mask):
    """
    Perform initialization and inference using the provided parameters.

    Args:
        key (jax.random.PRNGKey): Random number generator key for initialization.
        x (jnp.ndarray): Input data with shape (batch_size, num_features).
        mask (jnp.ndarray): Mask with shape (batch_size,) representing valid elements in xx.

    Returns:
        None: This function does not return any value. It prints the results of the inference.
    """
    key, *skey = jax.random.split(key, 3)
    params = init_params(skey[0], x, mask, False)

    print("Inference with predict, is_train=False")
    print(predict(params, skey[1], x, mask, False))

    print("Inference with predict_fw, is_train=False")
    print(predict_fw(params, x, mask, False))

    print("Inference with predict, is_train=True")
    print(predict(params, skey[1], x, mask, True))


In [108]:
@jax.jit
def get_loss(params, skey, x: jnp.ndarray, mask, y_gt: jnp.ndarray):
    """
    Compute the binary cross-entropy loss for the given input and ground truth.

    Args:
        params (List): A list containing the weight vector (W) and bias (b).
        skey (jax.random.PRNGKey): Random number generator key for prediction.
        x (jnp.ndarray): Input data with shape (batch_size, num_features).
        mask (jnp.ndarray): Mask with shape (batch_size,) representing valid elements in x.
        y_gt (jnp.ndarray): Ground truth labels with shape (batch_size,).

    Returns:
        jnp.ndarray: The computed loss value.
    """

    # Predict using skey state and is_train
    y = predict(params, skey, x, mask, is_train=True)

    # Compute the loss value
    loss_value = optax.sigmoid_binary_cross_entropy(y, y_gt).mean(axis=-1)

    return loss_value


In [109]:
@jax.jit
def train_step(params, key, opt_state, x, mask, y_gt):
    """
    Perform a single training step for the given input batch.

    Args:
        params (List): A list containing the weight vector (W) and bias (b).
        key (jax.random.PRNGKey): Random number generator key for randomness.
        opt_state (OptState): The state of the optimizer.
        x (jnp.ndarray): Input data with shape (batch_size, num_features).
        mask (jnp.ndarray): Mask with shape (batch_size,) representing valid elements in x.
        y_gt (jnp.ndarray): Ground truth labels with shape (batch_size,).

    Returns:
        Tuple: A tuple containing the updated parameters (params), the new optimizer state (opt_state),
        the computed loss value (loss), and the updated random number generator key (key).
    """

    # Move the random generator
    key, skey = jax.random.split(key)

    # Define gradients with respect to the loss function
    val_grad = jax.value_and_grad(get_loss)

    # Get loss and gradients with respect to the input batch
    loss, grads = val_grad(params, skey, x, mask, y_gt)

    # Get updates and new optimizer state, based on gradients and previous state
    updates, opt_state = optimizer.update(grads, opt_state, params)

    # Get new params based on previous params and updates
    params = optax.apply_updates(params, updates)

    return params, opt_state, loss, key


In [110]:
@jax.jit
def evaluate_on_batch(params, x: np.ndarray, mask, y_gt: np.ndarray):
    """
    Evaluate the model on a single input batch.

    Args:
        params (List): A list containing the weight vector (W) and bias (b).
        x (jnp.ndarray): Input data with shape (batch_size, num_features).
        mask (jnp.ndarray): Mask with shape (batch_size,) representing valid elements in x.
        y_gt (jnp.ndarray): Ground truth labels with shape (batch_size,).

    Returns:
        Tuple: A tuple containing the accuracy and confusion matrix.
            - accuracy (float): The accuracy of the model's predictions.
            - confusion_matrix (jnp.ndarray): The confusion matrix with shape (2, 2).
    """
    y_pred = predict_fw(params, x, mask, is_train=False)
    y_pred_binary = (y_pred > 0).astype(int)
    accuracy = jnp.mean(y_pred_binary == y_gt)

    # Compute confusion matrix
    true_positive = jnp.sum(jnp.logical_and(y_pred_binary == 1, y_gt == 1))
    false_positive = jnp.sum(jnp.logical_and(y_pred_binary == 1, y_gt == 0))
    true_negative = jnp.sum(jnp.logical_and(y_pred_binary == 0, y_gt == 0))
    false_negative = jnp.sum(jnp.logical_and(y_pred_binary == 0, y_gt == 1))

    confusion_matrix = jnp.array([[true_negative, false_positive],
                                  [false_negative, true_positive]])
    return accuracy, confusion_matrix

In [111]:
def evaluate(params, x: jnp.ndarray, mask, y_gt: jnp.ndarray, batch_encode, batch_size=32):
    """
    Evaluate the model on the given input data using batching.

    Args:
        params (List): A list containing the weight vector (W) and bias (b).
        x (jnp.ndarray): Input data with shape (num_samples, num_features).
        mask (jnp.ndarray): Mask with shape (num_samples,) representing valid elements in x.
        y_gt (jnp.ndarray): Ground truth labels with shape (num_samples,).
        batch_encode (Callable): Function to encode the input data into batches.
        batch_size (int, optional): Batch size. Defaults to 32.

    Returns:
        Tuple: A tuple containing the test accuracy and confusion matrix.
            - test_accuracy (float): The accuracy of the model's predictions on the test data.
            - cm (jnp.ndarray): The confusion matrix with shape (2, 2).
    """
    cm = jnp.zeros([2, 2])

    # Evaluate on the test set with batching
    test_accuracy = 0.0
    num_batches = int(len(x) / batch_size)
    for j in range(num_batches):
        start_idx = j * batch_size
        end_idx = start_idx + batch_size
        x_batch = x[start_idx:end_idx]
        y_batch = y_gt[start_idx:end_idx]

        if mask is not None:
            mask_batch = mask[start_idx:end_idx]
        else:
            mask_batch = None

        if batch_encode is not None:
            x_batch = batch_encode(x_batch)

        # Move to GPU
        if gpu:
            x_batch = jnp.array(x_batch)
            y_batch = jnp.array(y_batch)
            x_batch = jax.device_put(x_batch, jax.devices("gpu")[0])  # Assuming you have only one GPU
            y_batch = jax.device_put(y_batch, jax.devices("gpu")[0])  # Assuming you have only one GPU

        batch_accuracy, batch_cm = evaluate_on_batch(params, x_batch, mask_batch, y_batch)
        test_accuracy += batch_accuracy
        cm += batch_cm

    test_accuracy /= num_batches
    return test_accuracy, cm

In [112]:
def train(params: list,
          key,
          x_tr: jnp.ndarray,
          mask_tr: jnp.ndarray,
          y_tr: jnp.ndarray,
          epochs: int,
          batch_size: int,
          x_te: jnp.ndarray,
          mask_te: jnp.ndarray,
          y_te: jnp.ndarray,
          batch_encode: typing.Union[None, typing.Callable] = None,
          eval_every: int = 1,
          loss_every_batch: int = 32,
          gpu=False):
    """
    Train the model using mini-batch stochastic gradient descent.

    Args:
        params (List): A list containing the weight vector (W) and bias (b).
        key (jax.random.PRNGKey): Random number generator key for randomness.
        x_tr (jnp.ndarray): Training input data with shape (num_train_samples, num_features).
        mask_tr (jnp.ndarray): Mask with shape (num_train_samples,) representing valid elements in x_tr.
        y_tr (jnp.ndarray): Training ground truth labels with shape (num_train_samples,).
        epochs (int): Number of training epochs.
        batch_size (int): Batch size.
        x_te (jnp.ndarray): Test input data with shape (num_test_samples, num_features).
        mask_te (jnp.ndarray): Mask with shape (num_test_samples,) representing valid elements in x_te.
        y_te (jnp.ndarray): Test ground truth labels with shape (num_test_samples,).
        batch_encode (Union[None, Callable], optional): Function to encode the input data into batches. Defaults to None.
        eval_every (int, optional): Number of epochs between evaluations on train and test sets. Defaults to 1.
        loss_every_batch (int, optional): Number of batches between printing the loss value. Set to False to disable printing. Defaults to 32.
        gpu (bool, optional): Whether to use GPU acceleration. Defaults to False.

    Returns:
        List: The updated model parameters after training.
    """
    opt_state = optimizer.init(params)
    nof_instances = x_tr.shape[0]
    for e in range(epochs):
        nof_full_batches = nof_instances // batch_size
        for i in range(nof_full_batches):

            # Get batch
            batch_start = i * batch_size
            batch_end = (i + 1) * batch_size
            x_batch = x_tr[batch_start:batch_end]
            y_batch = y_tr[batch_start:batch_end]

            if mask_tr is not None:
                mask_batch = mask_tr[batch_start:batch_end]
            else:
                mask_batch = None

            # Vectorize if needed
            if batch_encode is not None:
                x_batch = batch_encode(x_batch)

            # Move to GPU
            if gpu:
                x_batch = jnp.array(x_batch)
                y_batch = jnp.array(y_batch)
                x_batch = jax.device_put(x_batch, jax.devices("gpu")[0])  # Assuming you have only one GPU
                y_batch = jax.device_put(y_batch, jax.devices("gpu")[0])  # Assuming you have only one GPU

            params, opt_state, loss, key = train_step(params, key, opt_state, x_batch, mask_batch, y_batch)

            if loss_every_batch is not False:
                if i % loss_every_batch == 0:
                    print("Epoch: %d, Step %d/%d, Loss: %.3f" % (e, i, nof_full_batches, loss))

        if eval_every is not False:
            if e % eval_every == 0:

                # Evaluate on the whole training set
                train_accuracy, train_cm = evaluate(params, x_tr, mask_tr, y_tr, batch_encode, batch_size)
                print("Epoch: %d, Train Accuracy: %.4f" % (e, train_accuracy))
                print("Confusion Matrix:\n", train_cm)

                # Evaluate on the test set
                test_accuracy, test_cm = evaluate(params, x_te, mask_te, y_te, batch_encode, batch_size)
                print("Epoch: %d, Test Accuracy: %.4f" % (e, test_accuracy))
                print("Confusion Matrix:\n", test_cm)

                print("\n")
    return params

In [113]:
# load dataset if needed
!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 80.2M  100 80.2M    0     0  3481k      0  0:00:23  0:00:23 --:--:-- 5295k
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [114]:
def read_all_txts(dir):
    # all files in dir
    files = os.listdir(dir)

    sentences = []
    for file in files:
        filepath = dir + '/' + file
        with open(filepath, 'r') as ff:
            for line in ff:
                line = line.strip()
                sentences.append(line)
    return sentences


def load_dataset(dir):
    pos_dir = dir + '/pos'
    neg_dir = dir + '/neg'

    # read all txts in dir
    pos_sentences = read_all_txts(pos_dir)
    neg_sentences = read_all_txts(neg_dir)

    # to a dataframe
    df_pos = pd.DataFrame({'text': pos_sentences, 'label': 1})
    df_neg = pd.DataFrame({'text': neg_sentences, 'label': 0})

    df = pd.concat([df_pos, df_neg])

    # shuffle
    df = df.sample(frac=1).reset_index(drop=True)
    return df

dir = 'aclImdb/train/'
df_tr = load_dataset(dir)
dir = 'aclImdb/test/'
df_te = load_dataset(dir)

df_tr = df_tr.dropna()
df_te = df_te.dropna()

X_tr = df_tr.iloc[:, 0].to_numpy()
Y_tr = df_tr.iloc[:, 1].to_numpy()

X_te = df_te.iloc[:, 0].to_numpy()
Y_te = df_te.iloc[:, 1].to_numpy()

print("The are %d training examples, where %d are positive and %d are negative reviews" % (df_tr.shape[0], df_tr.loc[df_tr["label"] == 1, :].shape[0], df_tr.loc[df_tr["label"] == 1, :].shape[0]))
print("The are %d testing examples, where %d are positive and %d are negative reviews" % (df_te.shape[0], df_te.loc[df_te["label"] == 1, :].shape[0], df_te.loc[df_te["label"] == 1, :].shape[0]))

The are 25000 training examples, where 12500 are positive and 12500 are negative reviews
The are 25000 testing examples, where 12500 are positive and 12500 are negative reviews


In [115]:
# define the tokenizer
tokenizer = tokenizers.Tokenizer(tokenizers.models.Unigram())
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
tokenizer.normalizer = tokenizers.normalizers.Lowercase()

# say how it will be trained, to learn the vocabulary
vocab_size = 20000
special_tokens = ["[PAD]", "[CLS]", "[SEP]", "[MASK]",]
unk_token = "[UNK]"
max_piece_length = 16
trainer = tokenizers.trainers.UnigramTrainer(
    special_tokens=special_tokens, # special tokens
    vocab_size=vocab_size, # vocabulary size
    unk_token=unk_token, # set the unknown token
    show_progress=True, # show progress
)

# train the tokenizer
tokenizer.train_from_iterator(X_tr, trainer=trainer)





In [116]:
# now let's set it to a normal value
sequence_length = 600
truncation_length = 600
tokenizer.enable_padding(length=sequence_length)
tokenizer.enable_truncation(max_length=truncation_length)
def batch_encode(X):
    enc_list = tokenizer.encode_batch(X)
    enc_list_2 = []
    mask_list = []
    for i, enc in enumerate(enc_list):
        enc_list_2.append(enc.ids)
        mask_list.append(enc.attention_mask)
    return np.array(enc_list_2), np.array(mask_list)

In [117]:
# but since it is not memory-consuming, we can use the whole dataset
X_tr_enc, X_tr_mask = batch_encode(X_tr)
X_te_enc, X_te_mask = batch_encode(X_te)

print(X_tr_enc.shape)
print(X_te_enc.shape)

(25000, 600)
(25000, 600)


In [118]:
class MaskedEmbedding(hk.Module):
    def __init__(self, vocab_size, embed_size):
        super(MaskedEmbedding, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size

    def __call__(self, inputs, mask):
        """
        Apply masked embedding to the input tokens.

        Args:
            inputs (jnp.ndarray): Input data with shape (batch_size, sequence_length).
            mask (jnp.ndarray): Mask with shape (batch_size, sequence_length) representing valid elements in inputs.

        Returns:
            jnp.ndarray: Masked embeddings with shape (batch_size, sequence_length, embed_size).
        """
        # Define an Embedding layer
        embeddings = hk.Embed(self.vocab_size, self.embed_size)(inputs)

        # Expand mask to match the embedding shape
        mask_expanded = jnp.expand_dims(mask, axis=-1)  # (BS, S, 1)

        # Apply mask to zero out embeddings for masked tokens
        masked_embeddings = embeddings * mask_expanded

        return masked_embeddings


# 2 Introduction to the Transformers architecture

The content covered in this section is a summary of the material presented in last year's [M2Lschool](https://github.com/M2Lschool/tutorials2022). For those seeking a deeper dive into Transformers, the original source provides further in-depth information.

## 2.1 The Transformers Architecture

Transformers consist of two parts; the encoder and the decoder. The **encoder** is responsible for transforming raw input data, such as a sequence of words, into meaningful hidden representations and the **decoder** is used for predicting sequences of outputs, such as sequence of words. Therefore, the complete Transformer architecture (encoder and decoder) is appropriate for solving sequence-to-sequence tasks. For simpler classification tasks, such as the one that we will use in this Practical, we can only use the encoder part of the architecture.

Let's briefly introduce the encoder and decoder, along with their core logic:

#### Encoder

The encoder's role is to convert a sequence of words into dense, meaningful hidden representations, making them usable by other components like the decoder or other networks.

The Transformer Encoder (depicted in the left figure below) takes a source sequence, processes it using **Attention**, and passes the results through a **fully-connected feed-forward** block with pointwise non-linear activation. Residual connections and layer normalization are applied to both operations. This process is repeated $N$ times using stacked replicas to compute the final word representations.

#### Decoder

The decoder's task is to learn the alignment between the source and target sequences. For instance, in machine translation, the decoder learns which words to generate in the target language based on the words in the source language.

The Transformer Decoder (shown in the right figure below) can receive word representations as inputs and is given the **target sentence** during training to establish an association with the source.

Notably, the decoder incorporates two attention operations: the first involves masked self-attention, and the second attends to the encoder output. We will delve into more details on these operations later in the tutorial.


![alt text for screen readers](./../../images/transformers.png)
---


- Original paper: [Attention is All you Need](https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
- In-depth guide on Transformer components: [Formal Definitions in Transformers](https://arxiv.org/pdf/2207.09238.pdf)
- Practical PyTorch Transformer walkthrough: [The Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/http://nlp.seas.harvard.edu/annotated-transformer/)

## 2.2 Implementing the Encoder from scratch

In this section, we will focus on implementing the essential building blocks that are utilized by both the encoder and the decoder. The core components of an encoder are:

- Self-Attention (Scaled Dot Product Attention)
- Multi-Headed Attention
- Feed-Forward Networks
- Positional Encoding


#### [ 📝 ] Self-attention

The intuition behind self-attention is that individual entities (tokens), have different meanings (representations) depending on their context. The self-attention layer, the fundamental building-block of transformers, enriches individual token embeddings to context-aware embeddings &#x1F449; [DLP](https://www.manning.com/books/deep-learning-with-python) &#x1F448;

<img src="./../../images/attention_mechanism_chollet.png">

The scaled dot product attention is defined as:

$$
attention(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$

Where $d_k$ is a constant scalar. In the original paper, $d_k$ corresponds to the dimension of the query/key/value (they all share the same dimension).

For your first exercise, you will implement the scaled dot product attention (Equation above). You will notice that the function accepts a `mask` parameter. The mask allows us to *ignore* some portion of the sequence (typically, if any padding is present).

In [119]:
def scaled_dot_product(q, k, v, mask=None):
    """
    Scaled Dot Product Attention.

    Args:
        q (jnp.ndarray): query, shape (batch_size, ..., hidden_dim)
        k (jnp.ndarray): key, shape (batch_size, ..., hidden_dim)
        v (jnp.ndarray): values, shape (batch_size, ..., hidden_dim)
        mask (jnp.ndarray): values, shape (broadcastable to: B, ..., S, S)

    Returns:
        List(jnp.ndarray, jnp.ndarray):
            - attention output (batch_size, ..., hidden_dim)
            - attention_weights (batch_size, ...)
    """
    ###################
    # YOUR CODE HERE #
    d_k = q.shape[-1]
    scores = jnp.matmul(q, k.swapaxes(-2, -1)) / jnp.sqrt(d_k)  # (B,...,S,S)

    if mask is not None:
        scores = jnp.where(mask == 0, -1e9, scores)

    attention_weights = jax.nn.softmax(scores, axis=-1)
    values = jnp.matmul(attention_weights, v)
    # Steps:
    # (i) tensor multiplication between q, k to get the scores
    # (ii) if mask is not None, set the scores of masked out values a very small value
    # (iii) softmax the scores to get attention weights
    # (iv) tensor multiplication between attention_weights and values to get values
    ###################
    return values, attention_weights

Check that your implementation is correct by running the following cell. Check that the output shape is `(BS, S, DIM)` and `(BS, S, S)` for values and attention, respectively.

In [120]:
# generate a random input for testing
BS = 10
S = 100
DIM = 128
xx_emb = np.random.randn(BS, S, DIM)
mask = np.random.randint(0, 2, size=(BS, S))

# Testing Scaled Dot Product
values, attention = scaled_dot_product(xx_emb, xx_emb, xx_emb, jnp.expand_dims(mask, -1))

In [121]:
print("Values shape:", values.shape)
print("Attention shape:", attention.shape)
print("\n")

Values shape: (10, 100, 128)
Attention shape: (10, 100, 100)




Check that self-attention can generalize to any number of dimensions.
But also check that the self-attention happens only on the `[q|v|k].shape[-2]` dimension (not all the dimensions)
For example, see the ouput of the next cell. The attention shape is `(BS, H, W, W)` which means that the self-attention happens only along the `H` dimension.

In [122]:
# Check that self-attention works for any number of dimensions
# But also check that the self-attention happens only on the [q|v|k].shape[-2] dimension (not all except the first)
# You can check that looking at the attention shape (next cell)
BS = 10
H = 32
W = 64
DIM = 8
xx_image = np.random.randn(BS, H, W, DIM)
mask_image = np.random.randint(0, 2, size=(BS, H, W))

In [123]:
# Testing Scaled Dot Product
values, attention = scaled_dot_product(xx_image, xx_image, xx_image, jnp.expand_dims(mask_image, -1))

print("Values shape:", values.shape)
print("Attention shape:", attention.shape)
print("\n")

Values shape: (10, 32, 64, 8)
Attention shape: (10, 32, 64, 64)




#### [ 📝 ] Multi-Headed Attention

In the scaled dot product attention, an element of the sequence can attend to any other element, but it cannot focus on multiple aspects of the sequence simultaneously. To address this limitation, we introduce the concept of multi-headed attention.

In the encoder unit of the Transformer, the first layer applies a *multi-headed self-attention*, which means that words within the sequence interact and align with each other (self-attention), and multiple different alignments are learned simultaneously (multi-headed) through separate *attention heads*. This learning paradigm, combined with a linguistically founded training objective, has contributed to the success of modern language models.

With multi-headed attention, we have $h$ attention heads, where each attention head is a linear projection of the sequence $Q$, $K$, and $V$:

$$
\text{attention}(Q, K, V) = \text{concat}(head_1, ..., head_h)W^O
$$

$$
\text{head}_i = \text{attention}(QW^Q_i, KW^K_i, VW^V_i)
$$

Here, $W^Q_i \in \mathbb{R}^{d_q \times d_k/h}$, $W^K_i \in \mathbb{R}^{d_k \times d_k/h}$, $W^V_i \in \mathbb{R}^{d_v \times d_v/h}$, and $W^O \in \mathbb{R}^{hd_v \times d_v}$. It's important to note that $d_k$ and $d_v$ have the same dimension, so $d_v/h$ is equal to $d_k/h$.

When implementing multi-headed attention, we first apply linear projections for $Q$, $K$, and $V$ using matrix multiplication, and then split the results into $h$ heads. Next, we apply the scaled dot product attention independently for each attention head and concatenate the results.
To facilitate your implementation, we provide a Haiku Module for Multi-Headed Attention. Take your time to review the class and understand its main components.

In [124]:
class MultiheadAttention(hk.Module):
    def __init__(self, d_model: int, num_heads: int, name=None):
        """
        Multi-Headed Attention Module.

        Args:
            d_model (int): The dimension of the model, i.e. last dimension of q, k, v matrices (should be divisible by num_heads)
            num_heads (int): The number of attention heads.
            name (str): Optional name of the module.
        """
        super().__init__(name=name)
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = self.d_model // self.num_heads
        self.proj_q = hk.Linear(self.d_model)
        self.proj_k = hk.Linear(self.d_model)
        self.proj_v = hk.Linear(self.d_model)
        self.proj_o = hk.Linear(self.d_model)
        # self.lin_projs = [hk.Linear(self.d_model) for _ in range(4)]

    def __call__(self, q, k, v, mask=None):
        """
        Apply Multi-Headed Attention.

        Args:
            q (jnp.ndarray): Query tensor with shape (batch_size, sequence_length, d_model).
            k (jnp.ndarray): Key tensor with shape (batch_size, sequence_length, d_model).
            v (jnp.ndarray): Value tensor with shape (batch_size, sequence_length, d_model).
            mask (jnp.ndarray): Mask tensor with shape (batch_size, sequence_length) representing valid elements in the input sequences.

        Returns:
            jnp.ndarray: Output tensor after multi-headed attention with shape (batch_size, sequence_length, d_model).
            jnp.ndarray: Attention weights tensor with shape (batch_size, num_heads, sequence_length, sequence_length).
        """
        ###################
        # YOUR CODE HERE #
        batch_size, seq_length, d_model = q.shape

        # # Reshape q, k, v
        # q, k, v = [
        #     lin_p(t).reshape(batch_size, -1, self.num_heads, self.d_k).swapaxes(1, 2)
        #     for lin_p, t in zip(self.lin_projs, (q, k, v))
        # ]  # (B,h,S,d_k)

        # for each of q, k, v
        # (a) linear projection from (B, D, d_model) to (B, D, d_model)
        # (b) to (B, h, S, d_k)
        q = self.proj_q(q).reshape(batch_size, -1, self.num_heads, self.d_k).swapaxes(1, 2)
        k = self.proj_k(k).reshape(batch_size, -1, self.num_heads, self.d_k).swapaxes(1, 2)
        v = self.proj_v(v).reshape(batch_size, -1, self.num_heads, self.d_k).swapaxes(1, 2)

        # add a dimension to mask
        if mask is not None:
            mask = jnp.expand_dims(mask, 1)  # expand to (B,h,...)

        # scaled dot product
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        # values shape: (B,h,S,d_k)
        # attention shape: (B, h, S, S)

        # values reshape from (B,h,S,d_k) to (B, S, h*d_k)
        values = values.transpose(0, 2, 1, 3) # first move h, d_k to last dimensions
        values = values.reshape(batch_size, seq_length, d_model)  # concat heads

        # last projection
        values = self.proj_o(values)
        y = values , attention
        ###################
        # Steps:
        # (i) for each of q, k, v
        # (a) linear projection from (B, D, d_model) to (B, D, d_model)
        # (b) reshape from (B, D, d_model) to (B, h, S, d_k)
        # (ii) add a dimension to the mask (from (B, ...) to (B, 1, ...)
        # (iii) apply scaled dot product to (h, q, v)
        # note that input of (B,h,S,d_k) will give us output of (B,h,S,d_k) and attention matrix of (B, h, S, S)
        # which is exactly what we want
        # (iv) reshape output from (B,h,S,d_k) to (B, S, h*d_k)
        # (v) last linear projection from (B, S, h*d_k) to (B, S, h*d_k)
        # (vi) return output, attention
        ###################
        return y


In [125]:
# Test MultiheadAttention implementation
bs = xx_emb.shape[0]
seq_len = xx_emb.shape[1]
d_model = xx_emb.shape[-1]
num_heads = 16
def stateful_forward(q, k, v, mask=None):
    mha = MultiheadAttention(d_model, num_heads, name="mha")
    return mha(q, k, v, mask)

In [126]:
model, model_fw, predict, predict_fw, init_params = funcs_from_stateful(stateful_forward, jit=False)

In [127]:
key, skey = jax.random.split(key)

params = init_params(key, xx_emb, xx_emb, xx_emb, jnp.expand_dims(mask, -1))
out = predict_fw(params, xx_emb, xx_emb, xx_emb, jnp.expand_dims(mask, -1))
output_matrix = out[0]
attention_matrix = out[1]
print(output_matrix.shape)
print(attention_matrix.shape)

(10, 100, 128)
(10, 16, 100, 100)


  param = init(shape, dtype)


#### [ 📝 ] Feed Forward Sublayer

This sublayer is composed of a fully-connected feed-forward network. The main idea is to learn a linear transformation of the hidden representation of the previous layer. This layer has an inner hidden layer of size `d_ff`, and an inner activation function (e.g., ReLU). The `PositionwiseFeedForward` class below implements this sub-layer. It is initialized using the parameters:

- `d_model`: size of the hidden representation of the input.
- `d_ff`: inner size of the hidden layer.
- `p_dropout`: dropout probability (dropout will be applied during training).

The `PositionwiseFeedForward` class implements the `__call__` method. It takes as input the previous layer's hidden representation and returns the current layer's hidden representation by applying the fully-connected network.

In [128]:
class PositionwiseFeedForward(hk.Module):
    """
    This class is used to create a position-wise feed-forward network.

    Args:
        d_model (int): The size of the embedding vector.
        d_ff (int): The size of the hidden layer.
        p_dropout (float, optional): The dropout probability. Default is 0.1.
    """
    ###################
    # YOUR CODE HERE #
    def __init__(self, d_model: int, d_ff: int, p_dropout: float = 0.1, name=None):
        super().__init__(name=name)
        self.d_model = d_model
        self.d_ff = d_ff
        self.p_dropout = p_dropout

        self.w_1 = hk.Linear(self.d_ff)
        self.w_2 = hk.Linear(self.d_model)

    def __call__(self, x, is_train=True):
        """
        Apply the position-wise feed-forward network.

        Args:
            x (jnp.ndarray): The input sequence with shape (batch_size, sequence_length, d_model).
            is_train (bool, optional): Whether the model is in training mode. Default is True.

        Returns:
            jnp.ndarray: The output of the position-wise feed-forward network with shape (batch_size, sequence_length, d_model).
        """
        x = jax.nn.relu(self.w_1(x))
        if is_train:
            x = hk.dropout(hk.next_rng_key(), self.p_dropout, x)

        y = self.w_2(x)
        ###################
        return y

#### [ 📝 ] Encoder Block

And let's pack the Multihead Attention and the PositionwiseFeedForward in a single block.

In [129]:
class EncoderBlock(hk.Module):
    """
    This class is used to create an encoder block.

    Args:
        d_model (int): The size of the embedding vector.
        num_heads (int): The number of attention heads.
        d_ff (int): The size of the hidden layer.
        p_dropout (float): The dropout probability.
    """
    def __init__(self, d_model, num_heads, d_ff, p_dropout, name=None):
        super().__init__(name=name)
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.p_dropout = p_dropout

        # self-attention sub-layer
        self.self_attn = MultiheadAttention(
            d_model=self.d_model, num_heads=self.num_heads
        )
        # positionwise feedforward sub-layer
        self.ff = PositionwiseFeedForward(
            d_model=self.d_model, d_ff=self.d_ff, p_dropout=self.p_dropout
        )

        self.norm1 = hk.LayerNorm(
            axis=-1, param_axis=-1, create_scale=True, create_offset=True
        )
        self.norm2 = hk.LayerNorm(
            axis=-1, param_axis=-1, create_scale=True, create_offset=True
        )

    def __call__(self, x, mask=None, is_train=True):
        """
        Apply the encoder block to the input sequence.

        Args:
            x (jnp.ndarray): The input sequence with shape (batch_size, sequence_length, d_model).
            mask (jnp.ndarray, optional): The mask to be applied to the self-attention layer with shape (batch_size, sequence_length). Default is None.
            is_train (bool, optional): Whether the model is in training mode. Default is True.

        Returns:
            jnp.ndarray: The output of the encoder block, which is the updated input sequence with shape (batch_size, sequence_length, d_model).
        """
        d_rate = self.p_dropout if is_train else 0.0

        # attention sub-layer
        sub_x, _ = self.self_attn(x, x, x, mask=mask)
        if is_train:
            sub_x = hk.dropout(hk.next_rng_key(), self.p_dropout, sub_x)
        x = self.norm1(x + sub_x)  # residual conn

        # feedforward sub-layer
        sub_x = self.ff(x, is_train=is_train)
        if is_train:
            sub_x = hk.dropout(hk.next_rng_key(), self.p_dropout, sub_x)
        x = self.norm2(x + sub_x)  # sub_x

        return x

In [130]:
# Testing the Encoder block"""
d_ff = 128

def stateful_forward(x, mask, is_train):
    bl = EncoderBlock(d_model=d_model, num_heads=num_heads, d_ff=d_ff, p_dropout=0.1)
    return bl(x, mask, is_train)

In [131]:
model, model_fw, predict, predict_fw, init_params = funcs_from_stateful(stateful_forward, jit=False)

In [132]:
key, skey = jax.random.split(key)
is_train = True
params = init_params(key, xx_emb,  jnp.expand_dims(mask, -1), is_train)
out = predict(params, key, xx_emb,  jnp.expand_dims(mask, -1), is_train)
print(out[0].shape)
print(out[1].shape)

(100, 128)
(100, 128)


#### Transformer Encoder - multiple encoder blocks

As introduced in the previous sections, the Transformer encoder is composed of multiple *encoder blocks*. The `TransformerEncoder` class below implements it by stacking $N$ `EncoderBlock`s, where $N$ is the number of stacked encoder blocks.

This class inputs the same set of parameters as the `EncoderBlock` class and adds the parameter `num_layers` to specify the number of stacked encoder blocks.

In [133]:
class TransformerEncoder(hk.Module):
    """
    This class is used to create a transformer encoder.
    :param num_layers: The number of encoder blocks.
    :param num_heads: The number of attention heads.
    :param d_model: The size of the embedding vector.
    :param d_ff: The size of the hidden layer.
    :param p_dropout: The dropout probability.
    """

    def __init__(self, num_layers, num_heads, d_model, d_ff, p_dropout, name=None):
        super().__init__(name=name)
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_ff = d_ff
        self.p_dropout = p_dropout

        self.layers = [
            EncoderBlock(self.d_model, self.num_heads, self.d_ff, self.p_dropout)
            for _ in range(self.num_layers)
        ]

    def __call__(self, x: typing.List[int], mask=None, is_train=True):
        """
        It applies the transformer encoder to the input sequence.
        :param x: The input sequence.
        :param mask: The mask to be applied to the self-attention layer.
        :param is_train: Whether the model is in training mode.
        :return: The final output of the encoder that contains the last encoder block output.
        """
        for l in self.layers:
            x = l(x, mask=mask, is_train=is_train)
        return x

In [134]:
# Test the encoder
num_layers = 4
p_dropout = 0.1
def stateful_forward(x, mask, is_train):
    enc = TransformerEncoder(num_layers, num_heads, d_model, d_ff, p_dropout, "t_enc")
    return enc(x, mask, is_train)

In [135]:
model, model_fw, predict, predict_fw, init_params = funcs_from_stateful(stateful_forward, jit=False)

In [136]:
key, skey = jax.random.split(key)
is_train = True
params = init_params(key, xx_emb,  jnp.expand_dims(mask, -1), is_train)
out = predict(params, key, xx_emb,  jnp.expand_dims(mask, -1), is_train)
print(out.shape)

(10, 100, 128)


#### Positional Encoding

The Transformer model does not use recurrent or convolutional layers in the encoder/decoder of the model (only attention mechanisms). However, this also has a drawback: since the model has no memory (no recurrent/convolutional layers), it can not take into account the *order* of the sequence elements. The position of words in the sequence is thus not encoded explicitly by the model.

As a solution to this issue, the original Transformer model uses a *positional encoding* scheme to represent the position of each element in the sequence. The positional encoding is added to the token embeddings of each element. Following the original paper, positional encodings are generated with multiple sinusoidal functions with varying frequencies.

Positional encoding is defined as:

$$\text{PE}(pos, 2i) = \sin \left( \frac{pos}{1000^{2i/d_{\text{model}}}} \right)$$
$$\text{PE}(pos, 2i+1) = \cos \left( \frac{pos}{1000^{2i/d_{\text{model}}}} \right)$$

where $pos$ is the position of the element in the sequence, $d_{\text{model}}$ is the model's embedding dimension, and $i$ is the index of the position vector. Note that this is not a learned parameter; the values are pre-computed and added to the token embeddings at the beginning of the forward pass.

Note that, we can optionally apply dropout to the positional encodings during training, thus providing additional regularization for the model.

📚 **Resources**

- Detailed explanation with visual aids: [Understanding Positional Encoding in Transformers](https://erdem.pl/2021/05/understanding-positional-encoding-in-transformers)

In [137]:
class PositionalEncoding(hk.Module):
    """
    This class is used to add positional encoding to the input sequence.

    Args:
        d_model (int): The size of the embedding vector.
        max_len (int): The maximum length of the input sequence.
        p_dropout (float, optional): The dropout probability. Default is 0.1.
    """

    def __init__(self, d_model: int, max_len: int, p_dropout: float = 0.1, name=None):
        """
        Initialize PositionalEncoding module.

        Args:
            d_model (int): The size of the embedding vector.
            max_len (int): The maximum length of the input sequence.
            p_dropout (float, optional): The dropout probability. Default is 0.1.
        """
        super().__init__(name=name)
        self.d_model = d_model
        self.max_len = max_len
        self.p_dropout = p_dropout

        pe = jnp.zeros((self.max_len, self.d_model))
        position = jnp.arange(0, self.max_len, dtype=jnp.float32)[:, None]
        div_term = jnp.exp(jnp.arange(0, self.d_model, 2) * (-jnp.log(10000.0) / self.d_model))
        pe.at[:, 0::2].set(jnp.sin(position * div_term))
        pe.at[:, 1::2].set(jnp.cos(position * div_term))
        pe = pe[None]
        self.pe = jax.device_put(pe)

    def __call__(self, x, is_train=True):
        """
        Apply positional encoding to the input sequence.

        Args:
            x (jnp.ndarray): The input sequence with shape (batch_size, sequence_length, d_model).
            is_train (bool, optional): Whether the model is in training mode. Default is True.

        Returns:
            jnp.ndarray: The input sequence with positional encoding.
        """
        x = x + self.pe[:, : x.shape[1]]
        if is_train:
            return hk.dropout(hk.next_rng_key(), self.p_dropout, x)
        else:
            return x


## 2.3 Lets Build a transformer model

In this section, we will implement the Transformer encoder and apply it to the task of word-level language modeling. We have implemented each base operation in the previous sections, so we will combine all these to train a language model.

###  Combining all together: the Transformer Encoder

The Transformer encoder is composed of multiple *encoder blocks*. Each of these blocks comprises two sub-layers: a *multi-head self-attention layer*, and a *feed-forward network*. There is also a residual connection around each sub-layer, followed by *layer normalization*. See the Figure above for a detailed diagram of a single encoder block.

In [138]:
# some global variables
MASK_PROBABILITY = 0.15
NUM_LAYERS = 2
NUM_HEADS = 2
D_MODEL = xx_emb.shape[2]
D_FF = 124
P_DROPOUT = 0.1
MAX_SEQ_LEN = xx_emb.shape[1]

### [ 📝 ]  Let's build the transformer classifier

In [139]:
def stateful_forward(input_ids, mask, is_train=True):

    embed_dim = 32
    num_layers = 2
    num_heads = 8
    d_ff = 32
    p_dropout = .1
    max_seq_length = 600
    vocab_size = 20000
    pe = PositionalEncoding(embed_dim, max_seq_length, p_dropout)
    embeddings = MaskedEmbedding(vocab_size, embed_dim)
    encoder = TransformerEncoder(num_layers, num_heads, embed_dim, d_ff, p_dropout)

    out = embeddings(input_ids, mask)
    if len(out.shape) == 2:
        out = out[None, :, :]
    out = pe(out, is_train=is_train)  # (B,S,d_model)
    out = encoder(out, mask=jnp.expand_dims(mask, 1), is_train=is_train)

    # Apply the first linear layer to the masked sequence
    out = hk.Linear(1)(out).squeeze()  # You can adjust the output size (128) to your preference

    # Apply global max pooling along the sequence dimension
    max_pooled_x = jnp.max(out, axis=-1)  # (BS, H)

    # Apply the second linear layer to get the final classification logits
    out = hk.Linear(1)(out).squeeze()

    return out

In [140]:
model, model_fw, predict, predict_fw, init_params = funcs_from_stateful(stateful_forward, False)

In [141]:
# check inference
xx = X_tr_enc[:4]
mask = X_tr_mask[:4]
key, skey = jax.random.split(key)

init_and_pred(key, xx, mask)

Inference with predict, is_train=False
[-0.51695037  0.01022601 -0.5831688  -0.52356744]
Inference with predict_fw, is_train=False
[-0.51695037  0.01022601 -0.5831688  -0.52356744]
Inference with predict, is_train=True
[-0.7890152  -0.16649416 -0.2441364   0.26862055]


In [142]:
# init params
key, skey = jax.random.split(key)
is_train = True
batch_size = 8
params = init_params(skey, xx, mask, is_train)

In [143]:
# train loop
# initialize optimizer
lr = .001
optimizer = optax.adam(learning_rate=lr)
opt_state = optimizer.init(params)
epochs = 2
params = train(params,
               skey,
               X_tr_enc,
               X_tr_mask,
               Y_tr,
               epochs=epochs,
               batch_size=batch_size,
               x_te=X_te_enc,
               mask_te=X_te_mask,
               y_te=Y_te,
               batch_encode=None,
               eval_every = 1,
               loss_every_batch=128)

lr = .0001
optimizer = optax.adam(learning_rate=lr)
opt_state = optimizer.init(params)
epochs = 3
params = train(params,
               skey,
               X_tr_enc,
               X_tr_mask,
               Y_tr,
               epochs=epochs,
               batch_size=batch_size,
               x_te=X_te_enc,
               mask_te=X_te_mask,
               y_te=Y_te,
               batch_encode=None,
               eval_every = 1,
               loss_every_batch=128)

Epoch: 0, Step 0/3125, Loss: 0.790
Epoch: 0, Step 128/3125, Loss: 0.717
Epoch: 0, Step 256/3125, Loss: 0.689
Epoch: 0, Step 384/3125, Loss: 0.681
Epoch: 0, Step 512/3125, Loss: 0.656
Epoch: 0, Step 640/3125, Loss: 0.622
Epoch: 0, Step 768/3125, Loss: 0.939
Epoch: 0, Step 896/3125, Loss: 0.736
Epoch: 0, Step 1024/3125, Loss: 0.391
Epoch: 0, Step 1152/3125, Loss: 0.613
Epoch: 0, Step 1280/3125, Loss: 0.777
Epoch: 0, Step 1408/3125, Loss: 0.166
Epoch: 0, Step 1536/3125, Loss: 0.551
Epoch: 0, Step 1664/3125, Loss: 0.353
Epoch: 0, Step 1792/3125, Loss: 1.082
Epoch: 0, Step 1920/3125, Loss: 0.545
Epoch: 0, Step 2048/3125, Loss: 0.510
Epoch: 0, Step 2176/3125, Loss: 0.463
Epoch: 0, Step 2304/3125, Loss: 0.239
Epoch: 0, Step 2432/3125, Loss: 0.276
Epoch: 0, Step 2560/3125, Loss: 0.104
Epoch: 0, Step 2688/3125, Loss: 0.395
Epoch: 0, Step 2816/3125, Loss: 0.446
Epoch: 0, Step 2944/3125, Loss: 0.289
Epoch: 0, Step 3072/3125, Loss: 0.412
Epoch: 0, Train Accuracy: 0.8706
Confusion Matrix:
 [[10499.

KeyboardInterrupt: 

### [ 📝 ] Inspect and Discuss

(You may use the same inputs as in the previous Ispection and Discussion section)

Play around with the model.Try to find some weaknesses using some inputs either from the training set or using your own inputs.Keep in mind that the model does not treat the input as sequence but as a set of words.Therefore, it does not take into account the order of the words.Can you find a characteristic input that reveals this weakness?



In [144]:
def predict_on_custom_input(test_input):
    test_input = np.array([test_input])
    test_enc, mask = batch_encode(test_input)
    print(model_fw.apply(params, test_enc, mask, False))

In [145]:
##################
# YOUR CODE HERE #
# real positive impact
test_input = "I have seen a lot of excellent movies and this one is one of them"
predict_on_custom_input(test_input)

# the word not inverts the comment, but the model cannot understand, it simply understands not as something that makes the comment more negative
test_input = "I have seen a lot of excellent movies and this one is not one of them"
predict_on_custom_input(test_input)

# same effect
test_input = "This is, for sure, one of the bad movies."
predict_on_custom_input(test_input)

# same effect
test_input = "This was a worth-seeing movie but, for sure, not one of the best."
predict_on_custom_input(test_input)
##################

7.471323
7.1996565
-4.1065116
5.3272247


## 2.4 Discussion

So are transformers all we need?

- How does the accuracy of the transformers model compare to the ones from the previous chapter?

Hidden answer:

Although we tried so much to code quite involved sequential models, first with LSTM and then with transformers, you may have noticed that the best accuracy is achieved with a simple Fully Connected Model after multi-hot encoding. In reality, the best sequential models have almost similar accuracy. Why is this the case? Our dear &#x1F449; [Deep Learning with Python (DLP)](https://www.manning.com/books/deep-learning-with-python) &#x1F448; has a rule of thumb:


> You may sometimes hear that bag-of-words methods are outdated, and that Transformer-
based sequence models are the way to go, no matter what task or dataset you’re look-
ing at. This is definitely not the case: a small stack of Dense layers on top of a bag-of-
bigrams remains a perfectly valid and relevant approach in many cases. In fact, among
the various techniques that we’ve tried on the IMDB dataset throughout this chapter,
the best performing so far was the bag-of-bigrams!
So, when should you prefer one approach over the other?
In 2017, my team and I ran a systematic analysis of the performance of various text-
classification techniques across many different types of text datasets, and we discov-
ered a remarkable and surprising rule of thumb for deciding whether to go with a
bag-of-words model or a sequence model (http://mng.bz/AOzK)—a golden constant
of sorts.
It turns out that when approaching a new text-classification task, you should pay
close attention to the ratio between the number of samples in your training data and
the mean number of words per sample (see figure 11.11). If that ratio is small—less
than 1,500—then the bag-of-bigrams model will perform better (and as a bonus, it will
be much faster to train and to iterate on too). If that ratio is higher than 1,500, then
you should go with a sequence model. In other words, sequence models work best
when lots of training data is available and when each sample is relatively short.

- Try one or two custom inputs on the transformer model. Do you see any differences in the results? If yes could you explain why?