Natural Language Processing Tutorial
======

This is the Natural Language Processing tutorial of the 2023 [Mediterranean Machine Learning Summer School](https://www.m2lschool.org/)!

This tutorial explores the fundamental aspects of Natural Language Processing (NLP). If you are new to NLP, there is no need to worry: the tutorial assumes minimal prior knowledge. We will implement everything from scratch to keep things clear. To facilitate this, we will use [JAX](https://jax.readthedocs.io/en/latest/), a library that offers an API similar (and often identical) to NumPy's, with the added benefit of automatic differentiation.

## Outline

- <span style="color:blue">0 Refresher on JAX, Haiku and Optax </span>
- <span style="color:blue">1 Introduction to NLP </span>
  - <span style="color:blue">1.1 The NLP pipeline </span>
  - <span style="color:blue">1.2 Classification pipeline: Multi-hot encoding + MLP model </span>
  - <span style="color:blue">1.3 Classification pipeline: Embeddings + Sequential Model </span>
- 2 Introduction to the Transformers architecture
  - 2.1 Transformer architecture
  - 2.2 Implementing the core components
  - 2.3 Transformer for classification pipeline
- 3 Advanced: Transformers for language translation
  - 3.1 The Transformer Decoder
  - 3.2 Transformer Decoder for character-based Language Modelling
  - 3.3 The full Transformer
  - 3.4 Transformer for Neural Machine Translation

## Emojis

Sections marked as [ 📝 ] contain cells with missing code that you should complete; [ &#x1F4C4; ] marks links to interesting external resources. When we quote an external resource, we cite it as &#x1F449; resource &#x1F448;.

## Libraries

We will keep our promise that (almost &#x1F60B;) everything will be built from scratch. All the vital and challenging components will be developed from zero; in fact, the whole tutorial could be done with JAX alone. However, certain minor technical details become distracting if too much time is spent on them. For these small pieces, we will get help from [haiku](https://dm-haiku.readthedocs.io/en/latest/) to code neural network architectures, [optax](https://optax.readthedocs.io/en/latest/) to optimize the parameters, and [tokenizers](https://huggingface.co/docs/tokenizers/index) to learn the vocabulary of each dataset, plus the ubiquitous [numpy](https://numpy.org/) and [pandas](https://pandas.pydata.org/) for array and tabular data handling. That's all!

It also helps to have access to a GPU. Thanks to Google Colab, you can get a CUDA-enabled environment right away by pressing the button below.

## Credits

The tutorial was created by [Vasilis Gkolemis](https://givasile.github.io/) and [Matko Bošnjak](https://matko.info). It is mostly inspired by [Deep Learning with Python (DLP)](https://www.manning.com/books/deep-learning-with-python), the famous book by François Chollet; last year's [M2Lschool](https://github.com/M2Lschool/tutorials2022) NLP tutorial (especially for the transformers part); and the [annotated transformer](http://nlp.seas.harvard.edu/annotated-transformer/) presentation.

## Note for Colab users

To grab a GPU (if available), make sure you go to `Edit -> Notebook settings` and choose a GPU under `Hardware accelerator`

In [1]:
!pip install dm-haiku optax tokenizers



In [2]:
import jax
import jax.numpy as jnp
import haiku as hk
import optax
import numpy as np
import pandas as pd
import tokenizers
import os
import typing

In [3]:
# forces JAX to allocate memory as needed
# see https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

In [4]:
# haiku provides a nice helper for returning a seed generator
init_seed = 21
key_iter = hk.PRNGSequence(jax.random.PRNGKey(init_seed))
key = jax.random.PRNGKey(init_seed)

In [5]:
def jax_has_gpu():
    try:
        _ = jax.device_put(jax.numpy.ones(1), device=jax.devices('gpu')[0])
        return True
    except Exception:  # no GPU backend available
        return False

gpu = jax_has_gpu()   # automatically checks for gpus, override if needed.
print("Is my instance gpu-ready?", gpu)

Is my instance gpu-ready? True


# Practical 0: Refresher on JAX, Haiku and Optax

In Part 0, we will (a) briefly introduce the libraries of the tutorial (JAX, Haiku, Optax) and (b) implement a linear classifier. Apart from getting familiar with the libraries, we will also implement some crucial utilities (training and evaluation loops) that we will reuse throughout the tutorial.

## 0.1 Create a linear dataset

To recall the key concepts of JAX, Haiku and Optax, we will solve a very simple linear classification problem:

$$
y = \begin{cases}
1, & \text{if } wx^T + b > 0 \\
0, & \text{otherwise}
\end{cases}
$$

where $ w \in \mathbb{R}^D $ and $ b \in \mathbb{R} $.
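
As a quick sanity check of the rule above, here is a minimal sketch in plain Python. The values of `w`, `b` and the two inputs are made up for illustration, and `label` is a hypothetical helper name.

```python
# Toy values, chosen only for illustration (D = 2).
w = [1.0, -2.0]
b = 0.5

def label(x):
    """Return 1 if w.x + b > 0, else 0."""
    score = sum(w_i * x_i for w_i, x_i in zip(w, x)) + b
    return 1 if score > 0 else 0

print(label([1.0, 0.0]))  # score = 1.5  -> 1
print(label([0.0, 1.0]))  # score = -1.5 -> 0
```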

In [6]:
def generate_dataset(key, n_samples: int, dim: int):
    key, skey = jax.random.split(key)
    x = jax.random.normal(skey, (n_samples, dim))

    # generate ground truth
    key, skey1, skey2 = jax.random.split(key, num=3)
    w_gt = jax.random.normal(skey1, (dim, 1))
    b_gt = jax.random.normal(skey2, (1, 1))
    y_gt = (jnp.dot(x, w_gt) + b_gt).squeeze()

    # convert to 0, 1
    y_gt = (y_gt > 0).astype(jnp.int32)

    return x, y_gt, [w_gt, b_gt]

In [7]:
# generate dataset
n_samples = 10_000  # number of training instances
dim = 20  # feature vector dimension
x, y_gt, params_gt = generate_dataset(key, n_samples, dim)

# split into train and test
n_tr = int(n_samples * 0.8)
x_tr, y_tr = x[:n_tr, :], y_gt[:n_tr]
x_te, y_te = x[n_tr:, :], y_gt[n_tr:]

## 0.2 [ 📝 ] Create a linear classifier with Haiku

[ 📝 ] Fill in the `MyLinear` module to create a custom linear layer. The module should define two parameters, `w` and `b`, and perform the forward pass:

In [8]:
# create a (custom) linear model
class MyLinear(hk.Module):
    def __init__(self):
        """
        Initializes a custom linear layer module.
        """
        # there is no need to add something here
        # except if you want to add some parameters
        # in the initialization of the module
        super().__init__()

    def __call__(self, x):
        """
        Performs the forward pass for the custom linear layer.

        Args:
            x (jnp.ndarray): Input data with shape (batch_size, num_features).

        Returns:
            jnp.ndarray: The output of the linear layer with shape (batch_size,).
        """
        ##################
        # YOUR CODE HERE #
        D = x.shape[-1]
        w_init = hk.initializers.RandomNormal(stddev=0.1)
        b_init = hk.initializers.RandomNormal(stddev=0.1)
        w = hk.get_parameter("w", shape=[D, 1], dtype=x.dtype, init=w_init)
        b = hk.get_parameter("b", shape=[1], dtype=x.dtype, init=b_init)
        y = (jnp.dot(x, w) + b).squeeze()
        # Hint: https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html#A-first-example-with-hk.transform
        # (i) define the initializers of w and b
        # (ii) define w and b as parameters of the module
        # (iii) perform the linear projection: y = w*x + b
        # (iv) make sure the output y has shape (batch_size,)
        # return y
        ##################
        return y

[ 📝 ] Define the stateful forward function. The arguments `mask` and `is_train` are not needed for this part, but will be used later in the tutorial. For now, just ignore them. The function should return the output of the custom linear layer.

In [9]:
def stateful_forward(x, mask, is_train):
    """
    Perform the forward pass.

    Args:
        x (jnp.ndarray): Input data with shape (batch_size, num_features).
        mask (jnp.ndarray): Mask with shape (batch_size,) representing valid elements in x.
        is_train (bool): Flag indicating whether the model is in training mode.

    Returns:
        jnp.ndarray: The output of the custom linear layer with shape (batch_size,).
    """
    linear = MyLinear()
    y = linear(x)
    return y


In the following cells, we (a) transform the stateful forward function to a stateless one (as JAX demands), (b) initialize the parameters of the model and (c) perform a forward pass to make sure everything works as expected.
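
To see what "stateless" means here, the following toy sketch in plain Python keeps the parameters outside the function and passes them in explicitly. It is not how Haiku implements `hk.transform`; `init` and `apply` below are hypothetical stand-ins for the pair of pure functions that the transformed model exposes.

```python
def init(dim):
    """Create the parameters explicitly instead of hiding them in an object."""
    return {"w": [0.0] * dim, "b": 0.0}

def apply(params, x):
    """Pure function: the output depends only on `params` and `x`."""
    return sum(w_i * x_i for w_i, x_i in zip(params["w"], x)) + params["b"]

params = init(dim=3)
x = [1.0, 2.0, 3.0]
print(apply(params, x))  # 0.0 with all-zero parameters

# Updating parameters means building a new dict; nothing is mutated in place.
new_params = {"w": [1.0, 0.0, -1.0], "b": 0.5}
print(apply(new_params, x))  # 1*1 + 0*2 - 1*3 + 0.5 = -1.5
```

Because the function carries no hidden state, the same `params` and `x` always give the same output, which is what JAX transformations such as `jax.jit` and `jax.grad` rely on.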

In [10]:
def funcs_from_stateful(stateful_forward, jit=True):
    """
    Helper function for getting stateless (pure JAX) functions from the stateful_forward function.

    Args:
        stateful_forward (Callable): The stateful forward function to be transformed.
        jit (bool, optional): Whether to jit-compile the functions. Defaults to True.

    Returns:
        Tuple: A tuple containing:
        - model: the transformed model with rng
        - model_fw: the transformed model without rng
        - predict: the predict function with rng
        - predict_fw: the predict function without rng
        - init_params: the init function for initializing the model parameters
    """
    model = hk.transform(stateful_forward)
    model_fw = hk.without_apply_rng(model)

    if jit:
        predict = jax.jit(model.apply)
        predict_fw = jax.jit(model_fw.apply)
        init_params = jax.jit(model.init)
    else:
        predict = model.apply
        predict_fw = model_fw.apply
        init_params = model.init
    return model, model_fw, predict, predict_fw, init_params

In [11]:
model, model_fw, predict, predict_fw, init_params = funcs_from_stateful(stateful_forward)

In [12]:
def init_and_pred(key, x, mask):
    """
    Tests the underlying model. Initializes and performs inference in three ways:
    - using predict (with rng) and `is_train=False`
    - using predict_fw (without rng), `is_train=False`
    - using predict (with rng), `is_train=True`

    Args:
        key (jax.random.PRNGKey): Random number generator key for initialization.
        x (jnp.ndarray): Input data with shape (batch_size, num_features).
        mask (jnp.ndarray): Mask with shape (batch_size,) representing valid elements in x.

    Returns:
        None: This function does not return any value. It prints the results of the inference.
    """
    key, *skey = jax.random.split(key, 3)
    params = init_params(skey[0], x, mask, False)

    print("Inference with predict (with rng), is_train=False")
    print(predict(params, skey[1], x, mask, False))

    print("Inference with predict_fw (without rng), is_train=False")
    print(predict_fw(params, x, mask, False))

    print("Inference with predict (with rng), is_train=True")
    print(predict(params, skey[1], x, mask, True))


In [13]:
# check the model
xx = x[0:5, :]
mask = None
key, skey = jax.random.split(key)
init_and_pred(key, xx, mask)

Inference with predict (with rng), is_train=False
[-0.03593713 -1.6595669   1.1024079   0.17551488  0.70727295]
Inference with predict_fw (without rng), is_train=False
[-0.03593713 -1.6595669   1.1024079   0.17551488  0.70727295]
Inference with predict (with rng), is_train=True
[-0.03593713 -1.6595669   1.1024079   0.17551488  0.70727295]


## 0.3 [ 📝 ] Implement training and evaluation loops

In the following cells we will implement five functions for general use:

- `get_loss`: computes the loss on a specific batch
- `train_step`: implements one training step [ 📝 ]
- `evaluate_on_batch`: evaluates the model on a specific batch
- `evaluate`: evaluates the model on the whole dataset (test set)
- `train`: implements the training loop

In [14]:
# define the function that computes the loss
@jax.jit
def get_loss(params, skey, x: jnp.ndarray, mask, y_gt: jnp.ndarray):
    """
    Compute the binary cross-entropy loss for the given input and ground truth.

    Args:
        params (List): the parameters of the model.
        skey (jax.random.PRNGKey): the random key that is used for the forward pass
        x (jnp.ndarray): the input
        mask (jnp.ndarray): the mask representing valid elements in x
        y_gt (jnp.ndarray): the labels

    Returns:
        jnp.ndarray: The computed loss value.
    """
    # Predict using skey state and is_train
    y = predict(params, skey, x, mask, is_train=True)

    # Compute the loss value
    loss_value = optax.sigmoid_binary_cross_entropy(y, y_gt).mean(axis=-1)
    return loss_value
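
`optax.sigmoid_binary_cross_entropy` takes raw scores (logits), not probabilities. As a rough sketch of the per-example quantity, here is the standard formula in plain Python. The helper name `sigmoid_bce` is ours, and the numerically stable rearrangement is one common choice, not necessarily the exact expression Optax uses internally.

```python
import math

def sigmoid_bce(logit, label):
    """Binary cross-entropy of sigmoid(logit) against a 0/1 label.

    Uses max(z, 0) - z*y + log(1 + exp(-|z|)), which equals
    -[y*log(p) + (1-y)*log(1-p)] with p = sigmoid(z) but avoids overflow.
    """
    z, y = logit, label
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

print(sigmoid_bce(0.0, 1))  # log(2): the model is undecided
print(sigmoid_bce(5.0, 1))  # small: confident and correct
print(sigmoid_bce(5.0, 0))  # large: confident and wrong
```

A freshly initialized model outputs scores close to 0, so its loss should start near $\log 2 \approx 0.693$.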

[ 📝 ] Implement the `train_step` function. This function should perform a single training step for the given input batch. It should return the updated parameters, the updated optimizer state, the computed loss value and the updated random number generator key.

In [15]:
@jax.jit
def train_step(params, key, opt_state, x, mask, y_gt):
    """
    Perform a single training step for the given input batch.

    Args:
        params (List): the parameters of the model.
        key (jax.random.PRNGKey): the random key; it is split inside the function to obtain the key for the forward pass
        opt_state (OptState): The state of the optimizer.
        x (jnp.ndarray): the input
        mask (jnp.ndarray): the mask representing valid elements in x
        y_gt (jnp.ndarray): the labels

    Returns:
        Tuple: A tuple containing:
         - params: the updated parameters
         - opt_state: the updated optimizer state
         - loss: the computed loss value
         - key: the updated random number generator key
    """
    ##################
    # YOUR CODE HERE #
    # Move the random generator
    key, skey = jax.random.split(key)

    # Define gradients with respect to the loss function
    val_grad = jax.value_and_grad(get_loss)

    # Get loss and gradients with respect to the input batch
    loss, grads = val_grad(params, skey, x, mask, y_gt)

    # Get updates and new optimizer state, based on gradients and previous state
    updates, opt_state = optimizer.update(grads, opt_state, params)

    # Get new params based on previous params and updates
    params = optax.apply_updates(params, updates)
    #
    # (i) Move the random generator, hint: use jax.random.split
    # (ii) Define the `gradients function` wrt the loss function, hint: use jax.value_and_grad
    # (iii) Get loss and gradients wrt the input batch, hint: use the gradients function
    # (iv) Get weight updates and new optimizer state, based on gradients and previous state, hint: use optimizer.update
    # (v) Get new params based on previous params and updates, hint: use optax.apply_updates
    # (vi) return params, opt_state, loss, key
    ##################
    return params, opt_state, loss, key
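
To make these steps concrete without any library, here is a hand-written version of one SGD step for a linear classifier, in plain Python. It replaces `jax.value_and_grad` with the known closed-form gradient of the sigmoid cross-entropy, $(\sigma(z) - y)\,x$, so it illustrates the idea rather than substituting for the JAX code above. All names are hypothetical.

```python
import math

def loss_and_grad(w, b, x, y):
    """Sigmoid cross-entropy of one example and its gradient wrt (w, b)."""
    z = sum(w_i * x_i for w_i, x_i in zip(w, x)) + b
    p = 1.0 / (1.0 + math.exp(-z))
    loss = -(y * math.log(p) + (1 - y) * math.log(1 - p))
    grad_w = [(p - y) * x_i for x_i in x]
    grad_b = p - y
    return loss, grad_w, grad_b

def sgd_step(w, b, x, y, lr):
    """Return the loss before the update and the updated parameters."""
    loss, grad_w, grad_b = loss_and_grad(w, b, x, y)
    w = [w_i - lr * g_i for w_i, g_i in zip(w, grad_w)]
    b = b - lr * grad_b
    return loss, w, b

w, b = [0.0, 0.0], 0.0
x, y = [1.0, -1.0], 1
loss_before, w, b = sgd_step(w, b, x, y, lr=0.1)
loss_after, _, _ = loss_and_grad(w, b, x, y)
print(loss_before, loss_after)  # the loss on this example decreases
```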


In [16]:
@jax.jit
def evaluate_on_batch(params, x: np.ndarray, mask, y_gt: np.ndarray):
    """
    Evaluate the model on a single input batch.

    Args:
        params (List): the parameters of the model.
        x (jnp.ndarray): the input
        mask (jnp.ndarray): the mask representing valid elements in x
        y_gt (jnp.ndarray): the labels

    Returns:
        Tuple: A tuple containing the accuracy and confusion matrix.
            - accuracy (float): The accuracy of the model's predictions.
            - confusion_matrix (jnp.ndarray): The confusion matrix with shape (2, 2).
    """
    y_pred = predict_fw(params, x, mask, is_train=False)
    y_pred_binary = (y_pred > 0).astype(int)
    accuracy = jnp.mean(y_pred_binary == y_gt)

    # Compute confusion matrix
    true_positive = jnp.sum(jnp.logical_and(y_pred_binary == 1, y_gt == 1))
    false_positive = jnp.sum(jnp.logical_and(y_pred_binary == 1, y_gt == 0))
    true_negative = jnp.sum(jnp.logical_and(y_pred_binary == 0, y_gt == 0))
    false_negative = jnp.sum(jnp.logical_and(y_pred_binary == 0, y_gt == 1))

    confusion_matrix = jnp.array([[true_negative, false_positive],
                                  [false_negative, true_positive]])
    return accuracy, confusion_matrix
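
The confusion matrix above is laid out as `[[TN, FP], [FN, TP]]`: rows are ground-truth classes and columns are predicted classes. A small plain-Python sketch with made-up scores and labels shows the same bookkeeping:

```python
# Made-up raw scores (logits) and labels, for illustration only.
scores = [2.1, -0.3, 0.7, -1.5, 0.2]
labels = [1, 0, 0, 0, 1]

preds = [1 if s > 0 else 0 for s in scores]  # threshold at 0, as above

cm = [[0, 0], [0, 0]]
for y_true, y_pred in zip(labels, preds):
    cm[y_true][y_pred] += 1  # row = ground truth, column = prediction

accuracy = sum(p == y for p, y in zip(preds, labels)) / len(labels)
print(cm)        # [[2, 1], [0, 2]]
print(accuracy)  # 0.8
```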

In [17]:
def evaluate(params, x: jnp.ndarray, mask, y_gt: jnp.ndarray, batch_encode, batch_size=32):
    """
    Evaluate the model on the full dataset (x) using batches of batch_size.

    Args:
        params (List): the parameters of the model.
        x (jnp.ndarray): the input
        mask (jnp.ndarray): the mask representing valid elements in x
        y_gt (jnp.ndarray): the labels
        batch_encode (Callable): Function to encode the input data into batches.
        batch_size (int, optional): Batch size. Defaults to 32.

    Returns:
        Tuple: A tuple containing the test accuracy and confusion matrix.
            - test_accuracy (float): The accuracy of the model's predictions on the test data.
            - cm (jnp.ndarray): The confusion matrix with shape (2, 2).
    """
    cm = jnp.zeros([2, 2])

    # Evaluate on the test set with batching
    test_accuracy = 0.0
    num_batches = int(len(x) / batch_size)
    for j in range(num_batches):
        start_idx = j * batch_size
        end_idx = start_idx + batch_size
        x_batch = x[start_idx:end_idx]
        y_batch = y_gt[start_idx:end_idx]

        if mask is not None:
            mask_batch = mask[start_idx:end_idx]
        else:
            mask_batch = None

        if batch_encode is not None:
            x_batch = batch_encode(x_batch)

        # Move to GPU
        if gpu:
            x_batch = jnp.array(x_batch)
            y_batch = jnp.array(y_batch)
            x_batch = jax.device_put(x_batch, jax.devices("gpu")[0])  # Assuming you have only one GPU
            y_batch = jax.device_put(y_batch, jax.devices("gpu")[0])  # Assuming you have only one GPU

        batch_accuracy, batch_cm = evaluate_on_batch(params, x_batch, mask_batch, y_batch)
        test_accuracy += batch_accuracy
        cm += batch_cm

    test_accuracy /= num_batches
    return test_accuracy, cm
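
Note that `evaluate` only visits full batches, so the last `len(x) % batch_size` examples are skipped. A quick plain-Python check of the arithmetic for the linear dataset (8,000 training and 2,000 test examples) with the batch size of 1,024 used in the training cell that follows; the helper name is hypothetical:

```python
def evaluated_examples(n_examples, batch_size):
    """Number of examples covered when only full batches are used."""
    num_batches = n_examples // batch_size
    return num_batches * batch_size

print(evaluated_examples(8000, 1024))  # 7168 (832 examples skipped)
print(evaluated_examples(2000, 1024))  # 1024 (976 examples skipped)
```

These totals match the sums of the entries in the confusion matrices printed after training.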

In [18]:
def train(params: list,
          key,
          x_tr: jnp.ndarray,
          mask_tr: jnp.ndarray,
          y_tr: jnp.ndarray,
          epochs: int,
          batch_size: int,
          x_te: jnp.ndarray,
          mask_te: jnp.ndarray,
          y_te: jnp.ndarray,
          batch_encode: typing.Union[None, typing.Callable] = None,
          eval_every: int = 1,
          loss_every_batch: int = 32,
          gpu=False):
    """
    Train the model using mini-batch stochastic gradient descent.

    Args:
        params (List): the parameters of the model.
        key (jax.random.PRNGKey): the random key that is used for the forward pass
        x_tr (jnp.ndarray): Training data
        mask_tr (jnp.ndarray): Mask on training data
        y_tr (jnp.ndarray): Training set labels
        epochs (int): Number of training epochs.
        batch_size (int): Batch size.
        x_te (jnp.ndarray): Test data
        mask_te (jnp.ndarray): Mask on test data
        y_te (jnp.ndarray): Test set labels
        batch_encode (Union[None, Callable], optional): Function to encode the input data into batches. Defaults to None.
        eval_every (int, optional): Number of epochs between evaluations on train and test sets. Defaults to 1.
        loss_every_batch (int, optional): Number of batches between printing the loss value. Set to False to disable printing. Defaults to 32.
        gpu (bool, optional): Whether to use GPU acceleration. Defaults to False.

    Returns:
        params: The updated model parameters after training.
    """
    opt_state = optimizer.init(params)
    nof_instances = x_tr.shape[0]
    for e in range(epochs):
        nof_full_batches = nof_instances // batch_size
        for i in range(nof_full_batches):

            # Get batch
            batch_start = i * batch_size
            batch_end = (i + 1) * batch_size
            x_batch = x_tr[batch_start:batch_end]
            y_batch = y_tr[batch_start:batch_end]

            if mask_tr is not None:
                mask_batch = mask_tr[batch_start:batch_end]
            else:
                mask_batch = None

            # Vectorize if needed
            if batch_encode is not None:
                x_batch = batch_encode(x_batch)

            # Move to GPU
            if gpu:
                x_batch = jnp.array(x_batch)
                y_batch = jnp.array(y_batch)
                x_batch = jax.device_put(x_batch, jax.devices("gpu")[0])  # Assuming you have only one GPU
                y_batch = jax.device_put(y_batch, jax.devices("gpu")[0])  # Assuming you have only one GPU

            params, opt_state, loss, key = train_step(params, key, opt_state, x_batch, mask_batch, y_batch)

            if loss_every_batch is not False:
                if i % loss_every_batch == 0:
                    print("Epoch: %d, Step %d/%d, Loss: %.3f" % (e, i, nof_full_batches, loss))

        if eval_every is not False:
            if e % eval_every == 0:

                # Evaluate on the whole training set
                train_accuracy, train_cm = evaluate(params, x_tr, mask_tr, y_tr, batch_encode, batch_size)
                print("Epoch: %d, Train Accuracy: %.4f" % (e, train_accuracy))
                print("Confusion Matrix:\n", train_cm)

                # Evaluate on the test set
                test_accuracy, test_cm = evaluate(params, x_te, mask_te, y_te, batch_encode, batch_size)
                print("Epoch: %d, Test Accuracy: %.4f" % (e, test_accuracy))
                print("Confusion Matrix:\n", test_cm)

                print("\n")
    return params

Now that we have all the components let's train our model!

In [19]:
# init network parameters
key, skey = jax.random.split(key)
params = init_params(skey, x, None, None)

# set-up the parameters
epochs = 501
lr = 0.01
batch_size = 1024
eval_every = 500
loss_every_batch = False
optimizer = optax.sgd(lr)
key, skey = jax.random.split(key)

# train
params = train(params, skey, x_tr, None, y_tr, epochs, batch_size, x_te, None, y_te, None, eval_every, loss_every_batch)

Epoch: 0, Train Accuracy: 0.4699
Confusion Matrix:
 [[1027. 2167.]
 [1633. 2341.]]
Epoch: 0, Test Accuracy: 0.4727
Confusion Matrix:
 [[134. 318.]
 [222. 350.]]


Epoch: 500, Train Accuracy: 0.9909
Confusion Matrix:
 [[3169.   25.]
 [  40. 3934.]]
Epoch: 500, Test Accuracy: 0.9932
Confusion Matrix:
 [[448.   4.]
 [  3. 569.]]




# Practical 1: Introduction to NLP

In the first part of the tutorial we will introduce the basic NLP concepts.

## 1.1 The basic NLP pipeline

Nearly all NLP solutions follow a common high-level pipeline:

<img src="./../../images/NLP-pipeline.png">

Let's briefly discuss them step by step:

* **Standardization:** Reduces small variations in text. For example, "The cat", "The Cat", and "the cat" all become "the cat".
* **Tokenization**: Splits the text into tokens (small individual units). For example, `"Hello my friend!"` will become `[hello, my, friend, !]`.
* **Indexing**: Converts tokens into integers (indexes). For example, `[hello, my, friend, !]` will become `[1, 2, 3, 4]`.
* **Encoding/Embedding**: Maps indexes into vectors, forming the embedding space. For example, the word "cat" can be mapped to the vector `[0.1, 0.2, 0.3, 0.4]`
* **ML model**: Predicts the output based on the embedding space. For example, the model can predict the sentiment of a movie review.
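
The first three steps can be sketched in a few lines of plain Python. This is a toy illustration only: the regular expression, the helper names and the tiny vocabulary are assumptions made here, and the tutorial itself uses a trained HuggingFace tokenizer instead.

```python
import re

def standardize(text):
    """Lowercase and trim surrounding whitespace."""
    return text.lower().strip()

def tokenize(text):
    """Split into words and single punctuation marks."""
    return re.findall(r"\w+|[^\w\s]", text)

def build_index(tokens):
    """Assign each distinct token an integer, starting from 1."""
    index = {}
    for tok in tokens:
        index.setdefault(tok, len(index) + 1)
    return index

tokens = tokenize(standardize("Hello my friend!"))
index = build_index(tokens)
ids = [index[tok] for tok in tokens]
print(tokens)  # ['hello', 'my', 'friend', '!']
print(ids)     # [1, 2, 3, 4]
```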


### 1.1.1 Load the IMDB dataset

For the rest of Practical 1, we will use the [IMDB](https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz) dataset, a sentiment analysis dataset containing 50,000 movie reviews from IMDB users, each labeled as either positive (1) or negative (0). It is divided into 25,000 reviews for training and 25,000 for testing, and each split is balanced between positive and negative reviews.

In [20]:
!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz



In [21]:
def read_all_txts(dir):
    # all files in dir
    files = os.listdir(dir)

    sentences = []
    for file in files:
        filepath = os.path.join(dir, file)
        with open(filepath, 'r', encoding='utf-8') as ff:
            for line in ff:
                line = line.strip()
                sentences.append(line)
    return sentences

def load_dataset(dir):
    pos_dir = dir + '/pos'
    neg_dir = dir + '/neg'

    # read all txts in dir
    pos_sentences = read_all_txts(pos_dir)
    neg_sentences = read_all_txts(neg_dir)

    # to a dataframe
    df_pos = pd.DataFrame({'text': pos_sentences, 'label': 1})
    df_neg = pd.DataFrame({'text': neg_sentences, 'label': 0})

    df = pd.concat([df_pos, df_neg])

    # shuffle
    df = df.sample(frac=1).reset_index(drop=True)
    return df

dir = 'aclImdb/train/'
df_tr = load_dataset(dir)
dir = 'aclImdb/test/'
df_te = load_dataset(dir)

df_tr = df_tr.dropna()
df_te = df_te.dropna()

X_tr = df_tr.iloc[:, 0].to_numpy()
Y_tr = df_tr.iloc[:, 1].to_numpy()

X_te = df_te.iloc[:, 0].to_numpy()
Y_te = df_te.iloc[:, 1].to_numpy()

print("There are %d training examples, where %d are positive and %d are negative reviews" % (df_tr.shape[0], df_tr.loc[df_tr["label"] == 1, :].shape[0], df_tr.loc[df_tr["label"] == 0, :].shape[0]))
print("There are %d testing examples, where %d are positive and %d are negative reviews" % (df_te.shape[0], df_te.loc[df_te["label"] == 1, :].shape[0], df_te.loc[df_te["label"] == 0, :].shape[0]))

There are 25000 training examples, where 12500 are positive and 12500 are negative reviews
There are 25000 testing examples, where 12500 are positive and 12500 are negative reviews


### 1.1.2 Tokenization

We will use the [HuggingFace tokenizers](https://huggingface.co/docs/tokenizers/index), a fast and efficient library for tokenizing text.

In [22]:
# define the tokenizer
tokenizer = tokenizers.Tokenizer(tokenizers.models.Unigram())
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
tokenizer.normalizer = tokenizers.normalizers.Lowercase()

# say how it will be trained, to learn the vocabulary
vocab_size = 20000
special_tokens = ["[PAD]", "[CLS]", "[SEP]", "[MASK]",]
unk_token = "[UNK]"
max_piece_length = 16
trainer = tokenizers.trainers.UnigramTrainer(
    special_tokens=special_tokens, # special tokens
    vocab_size=vocab_size, # vocabulary size
    unk_token=unk_token, # set the unknown token
    max_piece_length=max_piece_length, # maximum length of a single token
    show_progress=True, # show progress
)

# train the tokenizer
tokenizer.train_from_iterator(X_tr, trainer=trainer)





Let's check the output of the tokenizer:

In [23]:
# inspect the tokenizer
def inspect_tokenizer(inp):
    tok_example = tokenizer.encode(inp)
    print("tokens:", tok_example.tokens)
    print("ids:", tok_example.ids)
    print("attention mask", tok_example.attention_mask)

In [24]:
inp = "Hello my friend!"
inspect_tokenizer(inp)

tokens: ['hello', 'my', 'friend', '!']
ids: [3870, 81, 447, 84]
attention mask [1, 1, 1, 1]


In [25]:
# now see the output of the tokenizer
tokenizer.decode(tokenizer.encode("Hello my friend!").ids) # reconstruct the sentence from the ids

'hello my friend !'

### 1.1.3 ML Model

In the following cells, we will test two approaches: first, we will treat the input text as a **set of words**, and later as a **sequence of words**. In the set-of-words case, we will use **multi-hot encoding plus an MLP classifier**. In the sequence-of-words case, we will use a **learnable embedding plus a recurrent neural network (LSTM)**.

## 1.2 Classification treating text as a Set of words

Let's first experiment with the simplest approach: multi-hot encoding + a Fully-Connected NN.


### [ 📝 ]  Multi-hot encode the input

We need to transform the input text into a multi-hot encoding. First, we will use the tokenizer (defined above) to transform the input text into a list of integers, where each integer represents a word in the vocabulary. Then, we will use the list of integers to create a multi-hot encoding of the input text. The multi-hot encoding is a binary vector with length VOCAB_SIZE, where the $i$-th element indicates whether the $i$-th word of the vocabulary exists in the input. If a word is present, its corresponding element in the tensor is set to 1; otherwise, it is set to 0.
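
As a small illustration of the encoding itself, independent of the tokenizer, here is a plain-Python sketch with a hypothetical vocabulary of 6 entries. The function name `multi_hot` is ours.

```python
def multi_hot(ids, vocab_size):
    """Binary vector with a 1 at every index that occurs in `ids`."""
    vec = [0] * vocab_size
    for i in ids:
        vec[i] = 1  # repeated ids still give 1, counts are discarded
    return vec

print(multi_hot([4, 1, 4], vocab_size=6))  # [0, 1, 0, 0, 1, 0]
```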

In [26]:
def batch_encode(X):
    """
    Encode a batch of text data into a multi-hot encoding.

    Args:
        X (List[str]): A list of input texts.

    Returns:
        np.ndarray: The multi-hot encoding of the input texts with shape (num_samples, vocab_size).
    """
    ##################
    # YOUR CODE HERE #
    enc_list = tokenizer.encode_batch(X)
    vocab_size = tokenizer.get_vocab_size()
    multi_hot_encoding = np.zeros([len(enc_list), vocab_size])
    for i, enc in enumerate(enc_list):
        multi_hot_encoding[i, enc.ids] = 1

    # (i) get a list with encoded vector, hint: use tokenizer.encode_batch
    # (ii) get the vocabulary size, hint: use tokenizer.get_vocab_size
    # (iii) create and return the multi_hot encoded tensor as a numpy array
    ##################
    return multi_hot_encoding

Check the shape of the output tensor. It should be (32, 20000).

In [27]:
# inspect the batch
batch_size = 32
inp_list = X_tr[:batch_size]
X_batch = batch_encode(inp_list)
print("The shape of the batch is:", X_batch.shape)

The shape of the batch is: (32, 20000)


### [ 📝 ]  Define the MLP

We will use a simple MLP with one hidden layer of 16 units (ReLU activation) followed by a dropout layer. The output layer has 1 unit. Implement the model below.

In [28]:
def stateful_forward(x, mask, is_train):
    """
    A Neural Network.

    Args:
        x (jnp.ndarray): Input data with shape (batch_size, num_features).
        mask (jnp.ndarray): Mask with shape (batch_size,) representing valid elements in x.
        is_train (bool): Whether the model is in training mode or not.

    Returns:
        jnp.ndarray: The output of the model with shape (batch_size,).
    """
    ##################
    # YOUR CODE HERE #
    x = hk.Linear(16)(x)
    x = jax.nn.relu(x)
    if is_train:
        x = hk.dropout(hk.next_rng_key(), 0.5, x)
    y = hk.Linear(1)(x).squeeze()
    # Implement a neural network with
    # (i) a linear layer with 16 units and relu activation
    # (ii) dropout with 0.5 probability if in training mode
    # (iii) final linear layer
    ##################
    return y
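
For intuition about step (ii), here is a plain-Python sketch of inverted dropout, the common formulation in which surviving activations are rescaled by `1 / (1 - rate)` so that their expected value is unchanged. It uses Python's `random` module rather than JAX keys, and the function below is a hypothetical stand-in, not Haiku's code.

```python
import random

def dropout(x, rate, rng):
    """Zero each element with probability `rate`; rescale the survivors."""
    keep_prob = 1.0 - rate
    return [x_i / keep_prob if rng.random() < keep_prob else 0.0 for x_i in x]

rng = random.Random(0)
x = [1.0] * 10
y = dropout(x, rate=0.5, rng=rng)
print(y)  # every entry is either 0.0 (dropped) or 2.0 (kept and rescaled)
```

Because dropout is random and only active when `is_train=True`, a training-mode forward pass generally differs from an inference-mode one on the same input.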

In [29]:
model, model_fw, predict, predict_fw, init_params = funcs_from_stateful(stateful_forward, jit=False)

In [30]:
# init parameters and check forward pass
xx = X_batch[:2]
mask = None
key, skey = jax.random.split(key)

init_and_pred(key, xx, mask)



Inference with predict (with rng), is_train=False
[0.05670299 0.02233897]
Inference with predict_fw (without rng), is_train=False
[0.05670299 0.02233897]
Inference with predict (with rng), is_train=True
[-0.00528321 -0.04502412]


In [31]:
# initialize optimizer
lr = 0.001
optimizer = optax.adam(learning_rate=lr)
# note: `train` creates the optimizer state itself via optimizer.init(params)

### Train the model

In this section we will train the model for 3 epochs, using a batch size of 124. There is no need to create a training loop, as we have already implemented it above. Therefore, we will simply call the `train` function with the appropriate arguments.

In [32]:
# train
epochs = 3
batch_size = 124
is_train=True
key, skey = jax.random.split(key)

# init params
params = init_params(skey, xx, mask, is_train)

# train loop
params = train(params, skey, X_tr, None, Y_tr, epochs, batch_size, X_te, None, Y_te, batch_encode, loss_every_batch=25, gpu=gpu)

Epoch: 0, Step 0/201, Loss: 0.696
Epoch: 0, Step 25/201, Loss: 0.626
Epoch: 0, Step 50/201, Loss: 0.517
Epoch: 0, Step 75/201, Loss: 0.483
Epoch: 0, Step 100/201, Loss: 0.376
Epoch: 0, Step 125/201, Loss: 0.383
Epoch: 0, Step 150/201, Loss: 0.405
Epoch: 0, Step 175/201, Loss: 0.411
Epoch: 0, Step 200/201, Loss: 0.374
Epoch: 0, Train Accuracy: 0.9103
Confusion Matrix:
 [[10999.  1470.]
 [  766. 11689.]]
Epoch: 0, Test Accuracy: 0.8791
Confusion Matrix:
 [[10625.  1830.]
 [ 1183. 11286.]]


Epoch: 1, Step 0/201, Loss: 0.337
Epoch: 1, Step 25/201, Loss: 0.331
Epoch: 1, Step 50/201, Loss: 0.224
Epoch: 1, Step 75/201, Loss: 0.290
Epoch: 1, Step 100/201, Loss: 0.248
Epoch: 1, Step 125/201, Loss: 0.286
Epoch: 1, Step 150/201, Loss: 0.321
Epoch: 1, Step 175/201, Loss: 0.257
Epoch: 1, Step 200/201, Loss: 0.329
Epoch: 1, Train Accuracy: 0.9335
Confusion Matrix:
 [[11310.  1159.]
 [  498. 11957.]]
Epoch: 1, Test Accuracy: 0.8843
Confusion Matrix:
 [[10634.  1821.]
 [ 1062. 11407.]]


Epoch: 2, St

In [33]:
def inspect_an_input(i):
    # encode the i-th test example and score it in inference mode (is_train=False)
    inp = batch_encode(X_te[i:i+1])
    y_pr_score = model_fw.apply(params, inp, None, False)
    # predict class 1 when the score is positive
    y_pr = jnp.greater(y_pr_score, 0).astype(int)

    print("input:", X_te[i])
    print("\n")
    print("ground truth: %d" % Y_te[i])
    print("\n")
    print("prediction  : %d" % y_pr)
    print("\n")
    print("prediction score: %f" % y_pr_score)
    print("\n")

In [34]:
ii = 10
inspect_an_input(ii)

ii = 12
inspect_an_input(ii)

input: Rigoletto is Verdi's masterpiece, full of drama, emotion and powerful, memorable music. The maestro must have rolled in his grave when this bawdy travesty of his work was released with its needless frontal nudity and cheap copulating and its portrayal of the naive but principalled Gilda as a horny ditz. Opera certainly can be adapted to cinema --- look at Zeferelli's magnificent La Traviata --- but when a work is as superb as Rigoletto, it doesn't need cheap gimmicks. It might even have been acceptable if the dubbed in music had been good but it is a mediocre rendering of the libretto with second rate sound quality at that.


ground truth: 0


prediction  : 1


prediction score: 0.158785


input: This movie was horrible and the only reason it was even made was because the story appealed to the far-left. I consider my self a moderate, so I was able to see this film as the pile of garbage it was. While I'm not a Bush fan, your dislike for GW is not enough of a reason to see this m

### [ 📝 ] Inspect and Discuss

Play around with the model. Try to find some weaknesses, using inputs either from the training set or of your own. Keep in mind that the model does not treat the input as a sequence but as a set of words, so it ignores word order. Can you find a characteristic input that reveals this weakness?
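
To see why word order is invisible to a bag-of-words model, here is a minimal sketch. It uses a toy whitespace tokenizer and a hypothetical seven-word vocabulary (not the tutorial's tokenizer or `batch_encode`) to show that two sentences with opposite meanings can map to exactly the same multi-hot vector.

```python
import numpy as np


def multi_hot(text: str, vocab: dict) -> np.ndarray:
    """Encode a text as a multi-hot vector: 1 if the word occurs, 0 otherwise."""
    vec = np.zeros(len(vocab), dtype=np.float32)
    for word in text.lower().split():
        if word in vocab:
            vec[vocab[word]] = 1.0
    return vec


# toy vocabulary (hypothetical, for illustration only)
words = ["the", "movie", "was", "good", "not", "bad", "but"]
vocab = {w: i for i, w in enumerate(words)}

a = multi_hot("the movie was good not bad", vocab)
b = multi_hot("the movie was bad not good", vocab)

print(np.array_equal(a, b))  # True: opposite meanings, identical encodings
```

Any model that only sees these vectors must assign both sentences the same score.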

If we were to ask ChatGPT to classify the same inputs, would it have the same weaknesses? Let's try it.

In [35]:
def predict_on_custom_input(test_input):
    test_input = np.array([test_input])
    test_enc = batch_encode(test_input)
    print(model_fw.apply(params, test_enc, None, False))

In [36]:
##################
# YOUR CODE HERE #
# real positive impact
test_input = "I have seen a lot of excellent movies and this one is one of them"
predict_on_custom_input(test_input)

# the word "not" inverts the meaning, but the model cannot capture this; it merely treats "not" as a word that makes the comment slightly more negative
test_input = "I have seen a lot of excellent movies and this one is not one of them"
predict_on_custom_input(test_input)

# same effect
test_input = "This is, for sure, one of the bad movies."
predict_on_custom_input(test_input)

# same effect
test_input = "This was a worth-seeing movie but, for sure, not one of the best."
predict_on_custom_input(test_input)
##################

2.0342286
1.8297132
-0.44754705
0.88967776


## 1.3 Classification treating the input text as a Sequence

In this part, we will treat the input text as a sequence of words, using a learnable embedding followed by an LSTM network.
The network will have (a) an embedding layer, (b) an LSTM and (c) a Dense layer.
Before we code the model, let's set the length of the sequence returned by the tokenizer appropriately.

### [ 📝 ]  Define Tokenizer

Since we will experiment with sequential models, we want the tokenizer's output to be a tensor of shape `(BS, SEQUENCE_LENGTH)`.
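
To make the target shape concrete, the following sketch mimics what padding and truncation to a fixed length do to a batch of token ids. `pad_or_truncate` is a hypothetical helper written for illustration, not part of the `tokenizers` API; the token ids are made up, and the padding id is assumed to be 0.

```python
import numpy as np


def pad_or_truncate(ids, length, pad_id=0):
    """Return (ids, attention_mask), each with exactly `length` entries."""
    ids = list(ids)[:length]   # truncation: drop everything past `length`
    mask = [1] * len(ids)      # 1 marks a real token
    n_pad = length - len(ids)  # padding: fill the remainder with pad_id
    return ids + [pad_id] * n_pad, mask + [0] * n_pad


# toy token ids (hypothetical values)
batch = [[5, 8, 2], [7, 7, 7, 7, 7, 7, 9]]
encoded = [pad_or_truncate(seq, length=5) for seq in batch]

ids = np.array([e[0] for e in encoded])
mask = np.array([e[1] for e in encoded])

print(ids.shape)  # (2, 5)
print(ids)
print(mask)
```

Because every row now has the same length, the batch can be stacked into a single array, and the attention mask records which positions hold real tokens.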

In [37]:
# inspect what happens when both the padding length and the truncation length are set to 10
sequence_length = 10
truncation_length = 10
tokenizer.enable_padding(length=sequence_length)
tokenizer.enable_truncation(max_length=truncation_length)

In [38]:
# inspect the tokenizer
inp = "Hello my friend!"
inspect_tokenizer(inp)

inp = "Hello my dear dear dear dear dear dear dear dearest friend!"
inspect_tokenizer(inp)

tokens: ['hello', 'my', 'friend', '!', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
ids: [3870, 81, 447, 84, 0, 0, 0, 0, 0, 0]
attention mask [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
tokens: ['hello', 'my', 'dear', 'dear', 'dear', 'dear', 'dear', 'dear', 'dear', 'd']
ids: [3870, 81, 2841, 2841, 2841, 2841, 2841, 2841, 2841, 23]
attention mask [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [39]:
# now let's set it to a normal value
sequence_length = 600
truncation_length = 600
tokenizer.enable_padding(length=sequence_length)
tokenizer.enable_truncation(max_length=truncation_length)

In [40]:
def batch_encode(X):
    enc_list = tokenizer.encode_batch(X)
    enc_list_2 = []
    mask_list = []
    for i, enc in enumerate(enc_list):
        enc_list_2.append(enc.ids)
        mask_list.append(enc.attention_mask)
    return np.array(enc_list_2), np.array(mask_list)

In [41]:
# the encoded dataset is not memory-consuming, so we can encode all of it up front
X_tr_enc, X_tr_mask = batch_encode(X_tr)
X_te_enc, X_te_mask = batch_encode(X_te)

print(X_tr_enc.shape)
print(X_te_enc.shape)

(25000, 600)
(25000, 600)


Please define the network in the next cell. It must have:

- an embedding of shape `(BS, S, H)`
- an LSTM to capture the sequential information
- a final layer that uses all hidden states of the LSTM to obtain the prediction
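
To build some intuition for the `(BS, S, H)` shape before writing the network, here is a minimal NumPy sketch of what an embedding lookup does: each token id selects one row of a table. The sizes below are toy values chosen for illustration, and the table is random rather than learned; in the tutorial, `hk.Embed` creates and trains this table for us.

```python
import numpy as np

rng = np.random.default_rng(0)

vocab_size, H = 50, 8  # toy sizes (hypothetical)
BS, S = 2, 5

table = rng.normal(size=(vocab_size, H))         # one H-dimensional row per token id
ids = rng.integers(0, vocab_size, size=(BS, S))  # token ids, shape (BS, S)

emb = table[ids]  # integer indexing performs the lookup
print(emb.shape)  # (2, 5, 8)
```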

In [42]:
def stateful_forward(x, mask, is_train):
    """
    Implement the forward pass of the stateful model using a DeepRNN with LSTM layers.

    Args:
        x (jnp.ndarray): Input data with shape (batch_size, sequence_length).
        mask (jnp.ndarray): Mask with shape (batch_size, sequence_length) representing valid elements in x.
        is_train (bool): Whether the model is in training mode or not.

    Returns:
        jnp.ndarray: The final predictions with shape (batch_size,).
    """
    ##################
    # YOUR CODE HERE #
    # Initialize classes and dimensions
    batch_size = x.shape[0]
    embed_dim = 128
    embed = hk.Embed(vocab_size, embed_dim)
    core = hk.DeepRNN([hk.LSTM(64), hk.Linear(1)])
    linear = hk.Linear(1)

    # Create the initial recurrent state and embed the input: (BS, S) -> (BS, S, embed_dim)
    initial_state = core.initial_state(batch_size=batch_size)
    x = embed(x)

    # Unroll the LSTM across time steps using dynamic_unroll: outs has shape (BS, S, 1)
    outs, states = hk.dynamic_unroll(core, x, initial_state, time_major=False)

    # Apply dropout if in training mode
    if is_train:
        outs = hk.dropout(hk.next_rng_key(), 0.2, outs)

    # Compute final predictions: (BS, S, 1) -> (BS, S) -> (BS, 1) -> (BS,)
    outs = jnp.squeeze(outs, axis=-1)
    outs = linear(outs)
    y = jnp.squeeze(outs, axis=-1)
    #
    # (i) Embed the input to a tensor (BS, S, 128)
    # (ii) Unroll an LSTM over the embedded sequence, use a linear layer to map each LSTM output to a single value
    # (iii) use the single values of all intermediate steps of the LSTM as input to a linear layer to get y
    # (iv) return y
    ##################
    return y


In [43]:
model, model_fw, predict, predict_fw, init_params = funcs_from_stateful(stateful_forward, jit=False)

In [44]:
# init parameters and check forward pass
xx = X_tr_enc[:2]
mask = X_tr_mask[:2]
key, skey = jax.random.split(key)

# init_and_pred(key, xx, mask)
params = init_params(key, xx, mask, is_train=False)
out = predict(params, key, xx, mask, False)

In [45]:
# initialize optimizer
lr = 0.01
optimizer = optax.adam(learning_rate=lr)
opt_state = optimizer.init(params)

In [46]:
# train
epochs = 3
batch_size = 128
key, skey = jax.random.split(key)

# init params
params = init_params(skey, xx, mask, is_train)

# train loop
params = train(params,
               skey,
               X_tr_enc,
               X_tr_mask,
               Y_tr,
               epochs=epochs,
               batch_size=batch_size,
               x_te=X_te_enc,
               mask_te=X_te_mask,
               y_te=Y_te,
               batch_encode=None)

Epoch: 0, Step 0/195, Loss: 0.700
Epoch: 0, Step 32/195, Loss: 0.708
Epoch: 0, Step 64/195, Loss: 0.709
Epoch: 0, Step 96/195, Loss: 0.640
Epoch: 0, Step 128/195, Loss: 0.435
Epoch: 0, Step 160/195, Loss: 0.320
Epoch: 0, Step 192/195, Loss: 0.327
Epoch: 0, Train Accuracy: 0.9154
Confusion Matrix:
 [[11175.  1308.]
 [  803. 11674.]]
Epoch: 0, Test Accuracy: 0.8696
Confusion Matrix:
 [[10639.  1840.]
 [ 1415. 11066.]]


Epoch: 1, Step 0/195, Loss: 0.391
Epoch: 1, Step 32/195, Loss: 0.287
Epoch: 1, Step 64/195, Loss: 0.253
Epoch: 1, Step 96/195, Loss: 0.336
Epoch: 1, Step 128/195, Loss: 0.265
Epoch: 1, Step 160/195, Loss: 0.188
Epoch: 1, Step 192/195, Loss: 0.162
Epoch: 1, Train Accuracy: 0.9460
Confusion Matrix:
 [[11260.  1223.]
 [  124. 12353.]]
Epoch: 1, Test Accuracy: 0.8628
Confusion Matrix:
 [[ 9758.  2721.]
 [  704. 11777.]]


Epoch: 2, Step 0/195, Loss: 0.325
Epoch: 2, Step 32/195, Loss: 0.152
Epoch: 2, Step 64/195, Loss: 0.106
Epoch: 2, Step 96/195, Loss: 0.186
Epoch: 2, Step 12

### [ 📝 ] Inspect and Discuss

(You may use the same inputs as in the previous Inspect and Discuss section)

Play around with the model. Try to find some weaknesses, using inputs either from the training set or of your own. Keep in mind that, unlike the bag-of-words model of the previous section, this model processes the input as a sequence, so the order of the words can affect the prediction. Can you find an input on which the two models behave differently?

In [59]:
def predict_on_custom_input(test_input):
    test_input = np.array([test_input])
    # batch_encode now returns both the token ids and the attention mask
    test_enc, mask = batch_encode(test_input)
    print(model_fw.apply(params, test_enc, mask, False))

In [60]:
##################
# YOUR CODE HERE #
# real positive impact
test_input = "I have seen a lot of excellent movies and this one is one of them"
predict_on_custom_input(test_input)

# the word not inverts the comment, but the model cannot understand, it simply understands not as something that makes the comment more negative
test_input = "I have seen a lot of excellent movies and this one is not one of them"
predict_on_custom_input(test_input)

# same effect
test_input = "This is, for sure, one of the bad movies."
predict_on_custom_input(test_input)

# same effect
test_input = "This was a worth-seeing movie but, for sure, not one of the best."
predict_on_custom_input(test_input)
##################

[3.3750224]
[2.9429648]
[-0.5702515]
[2.429214]


## 1.4 Discussion

- Which model had the best performance? How do you explain that?
- Among the sequence-based models which one had the best performance? How would you explain that?
- Take the bag of words model and the best sequence-based model, and try one or two custom inputs. Do you see any differences in the results? Could you explain that somehow?

Hidden answer:

Although we put a lot of effort into coding fairly involved sequential models, first with an LSTM and then with Transformers, you may have noticed that the best accuracy is achieved by a simple fully connected model on top of multi-hot encoding. In practice, the best sequential models reach nearly the same accuracy. Why is this the case? Our dear &#x1F449; [Deep Learning with Python (DLP)](https://www.manning.com/books/deep-learning-with-python) &#x1F448; has a rule of thumb:


> You may sometimes hear that bag-of-words methods are outdated, and that Transformer-based sequence models are the way to go, no matter what task or dataset you’re looking at. This is definitely not the case: a small stack of Dense layers on top of a bag-of-bigrams remains a perfectly valid and relevant approach in many cases. In fact, among the various techniques that we’ve tried on the IMDB dataset throughout this chapter, the best performing so far was the bag-of-bigrams!
>
> So, when should you prefer one approach over the other?
>
> In 2017, my team and I ran a systematic analysis of the performance of various text-classification techniques across many different types of text datasets, and we discovered a remarkable and surprising rule of thumb for deciding whether to go with a bag-of-words model or a sequence model (http://mng.bz/AOzK)—a golden constant of sorts.
>
> It turns out that when approaching a new text-classification task, you should pay close attention to the ratio between the number of samples in your training data and the mean number of words per sample (see figure 11.11). If that ratio is small—less than 1,500—then the bag-of-bigrams model will perform better (and as a bonus, it will be much faster to train and to iterate on too). If that ratio is higher than 1,500, then you should go with a sequence model. In other words, sequence models work best when lots of training data is available and when each sample is relatively short.
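
The ratio from the quoted rule of thumb is easy to compute. The sketch below does so for a toy corpus, counting words by whitespace splitting; `samples_to_length_ratio` and `suggest_model` are hypothetical helpers, and the four example texts are made up for illustration.

```python
def samples_to_length_ratio(texts):
    """Number of samples divided by the mean number of words per sample."""
    num_samples = len(texts)
    mean_words = sum(len(t.split()) for t in texts) / num_samples
    return num_samples / mean_words


def suggest_model(ratio, threshold=1500):
    """Apply the quoted rule of thumb to a given ratio."""
    return "bag-of-words" if ratio < threshold else "sequence model"


# toy corpus (hypothetical): 4 samples with 2, 4, 6 and 8 words
texts = [
    "great film",
    "not a good movie",
    "one of the best this year",
    "the plot was thin and the acting poor",
]

ratio = samples_to_length_ratio(texts)
suggestion = suggest_model(ratio)
print(ratio)       # 0.8
print(suggestion)  # bag-of-words
```

Assuming `X_tr` holds the raw review strings, calling `samples_to_length_ratio(X_tr)` would give the corresponding value for our training set.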


### [ 📝 ] Advanced: Open-ended Exercise

Try to design a better model. You can go with a bag-of-words approach or a sequence-based model. Your idea can be an incremental improvement on the previous models, e.g. stacking more LSTM layers or increasing the dimension of the hidden states, or something completely new, e.g. trying a GRU layer instead of the LSTM.

In [61]:
# def stateful_forward(x, mask, is_train):
#     ##################
#     # YOUR CODE HERE #
#     ##################
#     return y

In [62]:
# model, model_fw, predict, predict_fw, init_params = funcs_from_stateful(stateful_forward, jit=False)

In [63]:
# # init parameters and check forward pass
# xx = X_tr_enc[:2]
# mask = X_tr_mask[:2]
# key, skey = jax.random.split(key)

# # init_and_pred(key, xx, mask)
# params = init_params(key, xx, mask, is_train=False)
# out = predict(params, key, xx, mask, False)

In [64]:
# # initialize optimizer
# optimizer = # complete optimizer
# opt_state = optimizer.init(params)

In [65]:
# # train
# epochs = #
# batch_size = #
# key, skey = jax.random.split(key)

# # init params
# params = init_params(skey, xx, mask, is_train)

# # train loop
# params = train(params,
#                skey,
#                X_tr_enc,
#                X_tr_mask,
#                Y_tr,
#                epochs=epochs,
#                batch_size=batch_size,
#                x_te=X_te_enc,
#                mask_te=X_te_mask,
#                y_te=Y_te,
#                batch_encode=None)