# Natural Language Processing Tutorial

Welcome to the second lab session of the [Mediterranean Machine Learning Summer School](https://www.m2lschool.org/home)!

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/m2lschool/tutorials2022/blob/main/2_nlp/NLP_tutorial_solutions.ipynb)

This tutorial will teach the fundamental components of modern Natural Language Processing (NLP) pipelines. 

Specifically, you will implement the Transformer architecture and test it on the *language modeling* and *machine translation* tasks. You will also learn core concepts such as self-attention, text tokenization, and more.

### Outline

The tutorial is structured as follows:

1. Introduction to the Transformer Architecture and its building blocks
2. Implementation of the Core Components: Scaled and Multi-headed Attention, Embeddings, and Positional Encoding
3. Transformer Encoder and Word-Level Language Modeling
4. Sentiment Analysis with the Encoder-only Language Model
5. Transformer and Neural Machine Translation
6. (*Bonus*) All that glitters is not gold: Gender Bias in Machine Translation

We will give you additional pointers to dive deeper into the specific topic in each section. Just keep an eye on the 📚 emoji.

### Libraries and Frameworks

You will use [JAX](https://github.com/google/jax) for low-level operations, [haiku](https://github.com/deepmind/dm-haiku) to model neural network modules, and [optax](https://github.com/deepmind/optax) for training optimization. In addition, we will use [datasets](https://github.com/huggingface/datasets) and [tokenizers](https://github.com/huggingface/tokenizers) for data preparation.

### 📝 Exercises

The sections marked as \[EXERCISE 📝\] contain cells with missing code that you should complete.

### Credits

[Giuseppe Attanasio](https://gattanasio.cc/), [Moreno La Quatra](https://mlaquatra.me/)

The tutorial is inspired by established walkthroughs on the Transformer architecture: [1](http://nlp.seas.harvard.edu/annotated-transformer) and [2](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial6/Transformers_and_MHAttention.html#The-Transformer-architecture).

### Python imports

Here we are going to install and import everything we are going to need for this tutorial. 

**Note**: *You can double-click the title of a collapsed cell (such as the ones below) to expand it and read its content.*

In [None]:
%%capture
# @title Libraries installation
!pip install dm-haiku optax tokenizers datasets evaluate

You can set the environment variables below to choose a maximum fraction of GPU memory JAX kernels will use.
In a multi-GPU setup, you can set only one of the available devices to be used.

In [None]:
import os

# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true'
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.50'
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# @title Imports
import os
import pickle
import re
import shutil
import time
from functools import partial
from typing import Dict, List, NamedTuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm

sns.set_theme("notebook")

import evaluate
import haiku as hk
import jax
import jax.numpy as jnp
import optax
import tokenizers
import torch.nn.functional as F
from datasets import Dataset, DatasetDict, load_dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Random generator seed
main_rng = jax.random.PRNGKey(42)
rng_iter = hk.PRNGSequence(main_rng)

Let's check that JAX sees the GPU 👀 (if so, you should see a `GpuDevice` in the output).

In [None]:
jax.devices()

# 1️⃣ Introduction to The Transformer Architecture

The entire tutorial revolves around a single architecture: the **Transformer**. Since its publication in 2017, the Transformer has **revolutionized the field** of NLP, finding successful applications in language modeling, sequence classification, sequence-to-sequence tasks such as machine translation, and many more.

As most common NLP libraries provide off-the-shelf, often pre-trained models, the actual inner workings -- what *is going on* inside the model -- are often hidden from the practitioner.

It's time to get our hands dirty: throughout the tutorial, you will implement every last bit of the Transformer and train your own implementation to solve real-world tasks.

## Let's start from the basics

The Transformer is an encoder-decoder neural network originally devised for sequence-to-sequence tasks. Assuming that you are familiar with the notion of *neural network*, let's clarify the other bits:
- an *encoder* is a model that turns a raw piece of data into some *meaningful* hidden representation;
- conversely, a *decoder* is a model that, given a hidden representation, brings the data back into the original domain;
- a *sequence-to-sequence* task framed in the NLP domain requires learning a model that turns some sequence into another one. As you can imagine, sequences are frequently made of words. 

Let's briefly introduce the encoder and decoder and their respective core logic. 

### Encoder

The goal of the encoder is to turn a list of words into a list of meaningful, dense hidden representations such that other components (e.g., the decoder or other networks) can use them.   

The Transformer Encoder (Figure below, left) receives as input a sequence of items (in our case, words), often referred to as the **source sequence**. It first mixes the input words using **Attention** and then feeds the results to a **fully-connected feed-forward** block with a point-wise non-linear activation. Both operations are wrapped in a **residual connection** followed by **layer normalization**. This computation is repeated $N$ times by identical, stacked replicas to compute the final word representations.

### Decoder

The goal of the decoder is to learn the alignment between the source and target sequences. For instance, in the machine translation task, the decoder learns what words to produce in the target language, given the words in the source language. 

Like the encoder, the Transformer Decoder (Figure below, right) can receive words' representations as inputs. During training time, it gets the **target sentence** to learn an association with the source.

☝️ Crucially, the decoder has two attention operations, with the first running a masked self-attention and the second one attending to the encoder output. We will give more details on that later in the tutorial.

![image](https://github.com/g8a9/graphics/blob/main/transformer/transformer.drawio.png?raw=true)

### 📚 **Resources**

- Original paper: [Attention is All you Need](https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
- Thorough guide on Transformer components: [Formal Algorithms for Transformers](https://arxiv.org/pdf/2207.09238.pdf)
- Practical PyTorch Transformer walkthrough: [The Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)


# 2️⃣ Implementing the Core Components

In this section, we will implement the core modules that both the encoder and the decoder will use, i.e., Scaled Dot Product Attention, Multi-Headed Attention, Positional Encoding, and Feed-Forward Networks. 

## Attention Mechanism in the Transformer

It is safe to say that the attention mechanism lies at the core of *all* best-performing language models. This simple alignment algorithm is the foundation of how we model natural language today.

Before reviewing Attention in the Transformer, we provide some intuition using influencers and dress styles.

>Fashion trends change rapidly. Harry knows that and tries to keep his wardrobe ready. Every season he goes over the social profiles of his favorite fashion influencers to look for ideas. Harry finds nice shirts in profile 1, suitable shoes in profile 2, nothing exciting in profile 3, and so on. From each influencer, he chooses part of the outfit for the upcoming season. In a sense, Harry **aligns** his preferences with social profiles and **mixes** different styles, following his intuition on what is best for his final goal -- we do not know Harry. Maybe he is trying to be a famous influencer himself.


Transformers learn word representations similarly. Each word is a **query** (Harry's outfit) whose representation is updated in alignment with a set of other words (the influencers' profiles), the **keys**, mixing some of their **values** (the influencers' products). Also, some training objective (Harry's dream of becoming an influencer) drives the process.

Let's define our queries, keys, and values.

### Scaled Dot Product Attention \[EXERCISE 📝\]

Scaled Dot Product attention is the attention mechanism used in the Transformer. It operates on queries, keys, and values. First, queries ($Q$) and keys ($K$) are multiplied and the result is divided by $\sqrt{d_k}$; then, softmax is applied to obtain the attention weights. Finally, the values ($V$) are averaged using these weights to get the final representation of the sequence.

The scaled dot product attention is defined as:

$$
attention(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$

Where $d_k$ is a constant scalar. In the original paper, $d_k$ corresponds to the dimension of the query/key/value (they all share the same dimension).
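
To get a feeling for why the scores are divided by $\sqrt{d_k}$, the sketch below draws random queries and keys with unit-variance entries (an assumption made only for this illustration) and compares the variance of raw and scaled dot products. It uses plain NumPy for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pairs = 20_000  # number of random (query, key) pairs per setting

variances = {}
for d_k in (4, 64, 512):
    q = rng.standard_normal((n_pairs, d_k))
    k = rng.standard_normal((n_pairs, d_k))
    raw = np.sum(q * k, axis=-1)   # unscaled dot products
    scaled = raw / np.sqrt(d_k)    # what the formula feeds to the softmax
    variances[d_k] = (raw.var(), scaled.var())
    print(f"d_k={d_k:4d}  var(raw)={raw.var():8.2f}  var(scaled)={scaled.var():.2f}")
```

The variance of the raw dot products grows roughly linearly with `d_k`, while the scaled version stays close to 1. Without scaling, large `d_k` would push the softmax towards nearly one-hot outputs with very small gradients.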

📣 📣 📣

This is a good place to pause and talk about arrays and dimensions. Throughout the notebook, you will work with JAX arrays with one or more dimensions and with operators (either mathematical operations or explicit function calls) that reshape them. Several operators also use [broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html). You are encouraged to refresh that logic before you start.

Since it is easy to lose track of the dimensions involved, we will sometimes specify the expected shape as a code comment.
Unless specified, we will use `B` for the batch size, `S` for the sequence length, and `h` for the number of attention heads. When we use an inline comment, the shape refers to the resulting value of the statement, e.g.:

`scores = ...  # (B,...,S,S)`

means that, after the execution, `scores` will have a leading dimension of size B, followed by any number of intermediate dimensions, and two final dimensions of size S.

📣 📣 📣
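
As a quick refresher on shapes and broadcasting, here is a small sketch with made-up sizes. It shows how a padding mask of shape `(B,S)` can be given singleton axes so that it broadcasts against a `(B,h,S,S)` array. The variable names are illustrative and not part of the exercises.

```python
import jax.numpy as jnp

B, h, S = 2, 4, 3

scores = jnp.zeros((B, h, S, S))             # (B,h,S,S)
pad_mask = jnp.array([[1, 1, 0],
                      [1, 0, 0]])            # (B,S): 1 = real token, 0 = padding

# Add singleton axes so the mask lines up with (B,h,S,S).
mask = pad_mask[:, None, None, :]            # (B,1,1,S)

# Axes of size 1 are stretched to match the other operand.
masked = jnp.where(mask == 0, -1e9, scores)  # (B,h,S,S)
print(mask.shape, masked.shape)
```

Every head and every query position in a batch element sees the same masked key positions, because the mask is broadcast along those two axes.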

For your first exercise, you will implement the scaled dot product attention (Equation above). You will notice that the function accepts a `mask` parameter. The mask allows us to *ignore* some portion of the sequence (typically, if any padding is present).
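
One common way to make the softmax ignore a position is to overwrite its score with a large negative number before normalizing. The toy sketch below, with hand-picked scores that are purely illustrative, shows the effect on a single sequence.

```python
import jax
import jax.numpy as jnp

scores = jnp.array([[2.0, 1.0, 0.5],
                    [0.3, 0.1, 4.0]])   # (S_q=2, S_k=3) raw scores
mask = jnp.array([[1, 1, 0]])           # (1, S_k): the last key is padding

# Masked keys receive a very negative score, so exp(score) underflows to 0.
masked_scores = jnp.where(mask == 0, -1e9, scores)
weights = jax.nn.softmax(masked_scores, axis=-1)

print(weights)           # the last column is (numerically) zero
print(weights.sum(-1))   # each row still sums to 1
```

Note that the second query had its largest raw score on the padded key; after masking, its attention is redistributed over the two real keys.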

In [None]:
def scaled_dot_product(q, k, v, mask=None):
    """
    Perform Scaled Dot Product Attention.

    :param q: queries tensor (shape: B,...,S,d_k)
    :param k: keys tensor (shape: B,...,S,d_k)
    :param v: values tensor (shape: B,...,S,d_k)
    :param mask: mask tensor (shape broadcastable to: B,...,S,S)
    :return: attention output (shape: B,...,S,d_k), attention_weights (B,...,S,S)

    EXERCISE
    """
    d_k = q.shape[-1]
    scores = jnp.matmul(q, k.swapaxes(-2, -1)) / jnp.sqrt(d_k)  # (B,...,S,S)

    if mask is not None:
        scores = jnp.where(mask == 0, -1e9, scores)

    attention_weights = jax.nn.softmax(scores, axis=-1)
    values = jnp.matmul(attention_weights, v)
    return values, attention_weights

In [None]:
# Testing Scaled Dot Product
bs = 2
seq_len, d_k = 3, 4
rng = next(rng_iter)

q, k, v = jax.random.normal(rng, (3, bs, seq_len, d_k))
mask = jax.random.randint(rng, (bs, 1, seq_len), minval=0, maxval=2)
values, attention = scaled_dot_product(q, k, v, mask)

print("Values\n", values, values.shape)  # result should be (B,S,d_k)
print("Attention\n", attention, attention.shape)  # result should be (B,S,S)

### The Multi-Headed Attention

The scaled dot product attention allows an element of the sequence to attend to any other element. However, the scaled dot product attention does not allow the element to focus on multiple aspects of the sequence simultaneously.
A solution for this is to use multiple attention heads.

Indeed, the first unit of the encoder applies a *multi-headed self-attention*, meaning that i) words *mix and align among themselves* (self-attention) and ii) multiple, different alignments are learned at once (multi-headed) -- each alignment is attributed to one *attention head*.

This simple learning paradigm -- based on mixing and aligning words in sentences -- paired with a linguistically founded training objective enables the best performing language models.

With multi-headed attention, we have $h$ attention heads, each of which applies scaled dot product attention to its own linear projections of $Q$, $K$, and $V$:

$$
multihead(Q, K, V) = \text{concat}(head_1,...,head_h)W^O
$$

$$
head_i = attention(QW^Q_i, KW^K_i, VW^V_i)
$$

Where $W^Q_i \in \mathbb{R}^{d_q \times d_k/h}$, $W^K_i \in \mathbb{R}^{d_k \times d_k/h}$, $W^V_i \in \mathbb{R}^{d_v \times d_v/h}$, and $W^O \in \mathbb{R}^{d_v \times d_v}$ (concatenating $h$ heads of size $d_v/h$ gives a vector of size $d_v$). Note that $d_k$ and $d_v$ are equal, so $d_v/h$ is the same as $d_k/h$.

While implementing multi-headed attention, we compute the linear projections of $Q$, $K$, and $V$ with a single matrix multiplication each and then split the results into $h$ heads. We then apply the scaled dot product for each attention head independently and concatenate the results.
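
The split into heads and the final concatenation are plain reshapes. The sketch below uses small, made-up sizes and skips the linear projections entirely. It only illustrates how a `(B,S,d_model)` array becomes `(B,h,S,d_k)` and back.

```python
import jax.numpy as jnp

B, S, d_model, h = 2, 5, 8, 4
d_k = d_model // h

x = jnp.arange(B * S * d_model, dtype=jnp.float32).reshape(B, S, d_model)

# Split the last axis into h chunks of size d_k, then move the head axis forward.
heads = x.reshape(B, S, h, d_k).swapaxes(1, 2)         # (B,h,S,d_k)

# Head i holds features [i*d_k, (i+1)*d_k) of every token.
assert bool((heads[:, 1] == x[..., d_k:2 * d_k]).all())

# Concatenating the heads undoes the split.
merged = heads.swapaxes(1, 2).reshape(B, S, d_model)   # (B,S,d_model)
print(heads.shape, merged.shape)
```

Because the head axis sits right after the batch axis, the scaled dot product can treat it as just another batch-like dimension and process all heads in one call.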

We provide you with Multi-Headed Attention in the form of a Haiku Module. Take your time to go through the class and understand the main components.

In [None]:
class MultiheadAttention(hk.Module):
    """
    Multihead Attention module.
    :param d_model: dimension of the model
    :param num_heads: number of heads
    """
    
    def __init__(self, d_model: int, num_heads: int, name=None):
        super().__init__(name=name)

        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = self.d_model // self.num_heads
        self.lin_projs = [hk.Linear(self.d_model) for _ in range(4)]

    def __call__(self, q, k, v, mask=None):
        """
        Perform Multi-Headed Attention.

        :param q: queries tensor (B,...,S,d_model)
        :param k: keys tensor (same)
        :param v: values tensor (same)
        :param mask: mask tensor (broadcastable to: B,...,S,S)
        """
        batch_size, seq_length, d_model = q.shape  # (B,S,d_model)

        q, k, v = [
            lin_p(t).reshape(batch_size, -1, self.num_heads, self.d_k).swapaxes(1, 2)
            for lin_p, t in zip(self.lin_projs, (q, k, v))
        ]  # (B,h,S,d_k)

        if mask is not None:
            mask = jnp.expand_dims(mask, 1)  # add a head axis: (B,1,...), broadcast over h

        values, attention = scaled_dot_product(q, k, v, mask=mask)  # (B,h,S,d_k)
        values = values.transpose(0, 2, 1, 3)
        values = values.reshape(batch_size, seq_length, d_model)  # concat heads
        return self.lin_projs[-1](values), attention

What you should have noticed (and we are sure you did):
- following the original paper, we set a $d_{model}$ and a $d_k = d_{model} / h$;
- queries, keys, and values have three different projections;
- we do not implement linear projection by ourselves but use [hk.Linear](https://dm-haiku.readthedocs.io/en/latest/api.html#linear) that needs the desired output dimension in the constructor;
- each query/key/value vector is 1) linearly projected 2) reshaped to h heads and a size of $d_k$ 3) swapped axes such that the head axis is at position 1;
- we add a dimension to the mask corresponding to the attention heads, such that the model will mask every attention head equally (equal masks are expected in Encoders, while we will see that there will be different masks in Decoders).  

In [None]:
""" Test MultiheadAttention implementation """
bs = 2
seq_len = 12
d_model = 64
num_heads = 8


def test_mha(q, k, v, mask=None):
    mha = MultiheadAttention(d_model, num_heads, name="mha")
    return mha(q, k, v, mask)


mha = hk.without_apply_rng(hk.transform(test_mha))

# Example features as input
q, k, v = jax.random.normal(next(rng_iter), (3, bs, seq_len, d_model))
mask = jax.random.randint(rng, (bs, 1, seq_len), minval=0, maxval=2)

# Initialize parameters of attention with random key and inputs
params = mha.init(next(rng_iter), q, k, v, mask)

# Apply attention with parameters on the inputs
out, attn = mha.apply(params=params, q=q, k=k, v=v, mask=mask)
print("Out", out.shape, "Attention", attn.shape)
del mha, params

In the last cell, we used:
- `hk.without_apply_rng`: a wrapper that lets us apply a function without passing `rng` as an argument. As long as the function is not actually using random numbers during computation, we can use `without_apply_rng`.
- `hk.transform`: a very handy Haiku function (also usable as a decorator: `@hk.transform`) that turns a function using Haiku modules into a pair of pure functions. From the Haiku documentation:

> The transform function allows you to write neural network functions that rely on parameters (...) without requiring you to explicitly write the boilerplate for initialising those parameters. `transform` does this by transforming the function into a pair of functions that are pure (as required by JAX) init and apply.

## Turning Tokens into Vectors: Embeddings and Positional Encoding 

The Transformer takes a sequence of words (or tokens) represented by dense vectors as input. These vectors are called *embeddings*, and their role is to map words (tokens) into a continuous vector space. The model's input is thus a sequence of vectors obtained by *looking up* the embedding of the corresponding words (tokens) in the vocabulary.
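
The *look-up* itself is just row indexing into a `(vocab_size, d_model)` matrix. The sketch below uses a random table and made-up token ids to show this, along with the equivalent one-hot formulation.

```python
import jax
import jax.numpy as jnp

vocab_size, d_model = 10, 4
# One row per vocabulary entry (random here; learned in a real model).
table = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, d_model))

token_ids = jnp.array([[1, 5, 5],
                       [0, 9, 2]])                 # (B,S) integer ids

vectors = table[token_ids]                         # (B,S,d_model): row look-up

# Equivalent view: a one-hot encoding multiplied by the table.
one_hot = jax.nn.one_hot(token_ids, vocab_size)    # (B,S,vocab_size)
vectors_via_matmul = one_hot @ table               # (B,S,d_model)

print(vectors.shape)
```

Identical token ids (the two `5`s in the first sequence) map to identical vectors. Nothing in the embedding depends on where the token appears.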

### Embedding Layer \[EXERCISE 📝\]

As an exercise, we ask you to write an `Embeddings` class that takes as input the dimension of the embeddings' vectors for the model (`d_model`) and the size of the vocabulary (`vocab_size`). Your class should implement the `__call__` method that takes as input a sequence of integers, each integer corresponding to a word (token) in the vocabulary, and outputs a sequence of vectors, each vector corresponding to the embedding of the corresponding word (token).

In [None]:
class Embeddings(hk.Module):
    """
    This class is used to create an embedding matrix for a given vocabulary size.
    :param d_model: The size of the embedding vector.
    :param vocab_size: The size of the vocabulary.
    """
    def __init__(self, d_model, vocab_size, name=None):
        super().__init__(name=name)
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embeddings = hk.Embed(self.vocab_size, self.d_model)

    def __call__(self, x):
        """
        :param x: The input sequence.
        :return: The embedding matrix.
        """
        return self.embeddings(x) * jnp.sqrt(self.d_model)

In [None]:
""" Test Embeddings implementation """

bs = 2
seq_len = 12
d_model = 64
num_heads = 8
vocab_size = 100

test_emb = lambda inputs: Embeddings(d_model, vocab_size)(inputs)
emb = hk.without_apply_rng(hk.transform(test_emb))

# example features as input
inputs = jax.random.randint(next(rng_iter), (4, 3), 0, 5)
params = emb.init(next(rng_iter), inputs)
out = emb.apply(params=params, inputs=inputs)
print("Out", out.shape)
del emb, params

### Positional Encoding

The Transformer model does not use recurrent or convolutional layers in the encoder/decoder (only attention mechanisms). This has a drawback: attention treats its input as a set, so on its own the model cannot take into account the *order* of the sequence elements. The position of words in the sequence is thus not encoded explicitly by the model.

As a solution to this issue, the original Transformer model uses a *positional encoding* scheme to represent the position of each element in the sequence. The positional encoding is added to the token embeddings of each element. Following the original paper, positional encodings are generated with multiple sinusoidal functions with varying frequencies.

Positional encoding is defined as:

$$\text{PE}(pos, 2i) = \sin \left( \frac{pos}{10000^{2i/d_{\text{model}}}} \right)$$
$$\text{PE}(pos, 2i+1) = \cos \left( \frac{pos}{10000^{2i/d_{\text{model}}}} \right)$$

where $pos$ is the position of the element in the sequence, $d_{\text{model}}$ is the model's embedding dimension, and $i$ indexes the sine/cosine pairs along the embedding dimension (so $2i$ and $2i+1$ are the even and odd components of the encoding vector). Note that this is not a learned parameter; the values are pre-computed and added to the token embeddings at the beginning of the forward pass.

Note that we can optionally apply dropout to the positional encodings during training, thus providing additional regularization for the model.
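
As a standalone sketch of the sinusoidal scheme, the function below builds the encoding table directly from the definition, using the base 10000 of the original paper (the same constant as `jnp.log(10000.0)` in the class below). The helper name `sinusoidal_pe` is illustrative, and `d_model` is assumed to be even.

```python
import jax.numpy as jnp


def sinusoidal_pe(max_len: int, d_model: int) -> jnp.ndarray:
    """Return a (max_len, d_model) table of sinusoidal encodings (d_model even)."""
    position = jnp.arange(max_len, dtype=jnp.float32)[:, None]   # (max_len,1)
    two_i = jnp.arange(0, d_model, 2, dtype=jnp.float32)         # 0, 2, 4, ...
    angles = position / jnp.power(10000.0, two_i / d_model)      # (max_len,d_model/2)

    pe = jnp.zeros((max_len, d_model))
    # JAX arrays are immutable: .at[...].set(...) returns a new array,
    # so the result must be assigned back.
    pe = pe.at[:, 0::2].set(jnp.sin(angles))
    pe = pe.at[:, 1::2].set(jnp.cos(angles))
    return pe


pe = sinusoidal_pe(max_len=50, d_model=8)
print(pe.shape)
print(pe[0])   # position 0: sin(0)=0 at even indices, cos(0)=1 at odd indices
```

Low-index pairs oscillate quickly with the position, while high-index pairs change slowly, so each position receives a distinct pattern across the embedding dimension.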

📚 **Resources**

- Detailed explanation with visual aids: [Understanding Positional Encoding in Transformers](https://erdem.pl/2021/05/understanding-positional-encoding-in-transformers)

In [None]:
class PositionalEncoding(hk.Module):
    """
    This class is used to add positional encoding to the input sequence.
    :param d_model: The size of the embedding vector.
    :param max_len: The maximum length of the input sequence.
    :param p_dropout: The dropout probability.
    """
    def __init__(self, d_model: int, max_len: int, p_dropout: float = 0.1, name=None):
        """EXERCISE"""
        super().__init__(name=name)
        self.d_model = d_model
        self.max_len = max_len
        self.p_dropout = p_dropout

        pe = jnp.zeros((self.max_len, self.d_model))
        position = jnp.arange(0, self.max_len, dtype=jnp.float32)[:, None]
        div_term = jnp.exp(
            jnp.arange(0, self.d_model, 2) * (-jnp.log(10000.0) / self.d_model)
        )
        # JAX arrays are immutable: .at[].set() returns a new array, so assign it back
        pe = pe.at[:, 0::2].set(jnp.sin(position * div_term))
        pe = pe.at[:, 1::2].set(jnp.cos(position * div_term))
        pe = pe[None]
        self.pe = jax.device_put(pe)

    def __call__(self, x, is_train=True):
        """
        :param x: The input sequence.
        :param is_train: Whether the model is in training mode.
        :return: The input sequence with positional encoding.
        """
        """EXERCISE"""
        x = x + self.pe[:, : x.shape[1]]
        if is_train:
            return hk.dropout(hk.next_rng_key(), self.p_dropout, x)
        else:
            return x

# 3️⃣ Transformer Encoder and Word-level Language Modeling

In this section, we will implement the Transformer encoder and apply it to the task of word-level language modeling. We have implemented each base operation in the previous sections, so we will combine all these to train a language model.

## Combining all together: the Transformer Encoder

The Transformer encoder is composed of multiple *encoder blocks*. Each of these blocks comprises two sub-layers: a *multi-head self-attention layer*, and a *feed-forward network*. There is also a residual connection around each sub-layer, followed by *layer normalization*. See the Figure above for a detailed diagram of a single encoder block.
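
The "residual connection followed by layer normalization" pattern can be sketched in a few lines. The version below is a simplification: `layer_norm` has no learned scale or offset (unlike `hk.LayerNorm` with `create_scale=True`), and `toy_sublayer` is a hypothetical stand-in for attention or the feed-forward network.

```python
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    """Normalize the last axis to zero mean and unit variance (no learned parameters)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def residual_sublayer(x, sublayer):
    """Residual connection around `sublayer`, followed by layer normalization."""
    return layer_norm(x + sublayer(x))


x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4)   # (B,S,d_model)
toy_sublayer = lambda t: 0.1 * t ** 2                    # stand-in for attention / feed-forward
out = residual_sublayer(x, toy_sublayer)

print(out.shape)          # same shape as the input
print(out.mean(axis=-1))  # close to 0 for every token
```

The residual path lets each block learn a correction to its input rather than a full replacement, and the normalization keeps activations in a comparable range across the stacked blocks.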

### Feed Forward Sublayer

This sublayer is a fully-connected feed-forward network applied to each position independently. The main idea is to learn a non-linear transformation of the hidden representation produced by the previous sublayer. It has an inner hidden layer of size `d_ff` and an inner activation function (e.g., ReLU). The `PositionwiseFeedForward` class below implements this sub-layer. It is initialized using the parameters:

- `d_model`: size of the hidden representation of the input.
- `d_ff`: inner size of the hidden layer.
- `p_dropout`: dropout probability (dropout will be applied during training).

The `PositionwiseFeedForward` class implements the `__call__` method. It takes as input the previous layer's hidden representation and returns the current layer's hidden representation by applying the fully-connected network.

In [None]:
class PositionwiseFeedForward(hk.Module):
    """
    This class is used to create a position-wise feed-forward network.
    :param d_model: The size of the embedding vector.
    :param d_ff: The size of the hidden layer.
    :param p_dropout: The dropout probability.
    """
    def __init__(self, d_model: int, d_ff: int, p_dropout: float = 0.1, name=None):
        super().__init__(name=name)
        self.d_model = d_model
        self.d_ff = d_ff
        self.p_dropout = p_dropout

        self.w_1 = hk.Linear(self.d_ff)
        self.w_2 = hk.Linear(self.d_model)

    def __call__(self, x, is_train=True):
        """
        :param x: The input sequence.
        :param is_train: Whether the model is in training mode.
        :return: The output of the position-wise feed-forward network.
        """
        x = jax.nn.relu(self.w_1(x))
        if is_train:
            x = hk.dropout(hk.next_rng_key(), self.p_dropout, x)

        x = self.w_2(x)
        return x

In the last cell, we used `hk.next_rng_key()`. You can call this Haiku utility function
**only from within a function transformed with `hk.transform`** (for example, inside a `haiku.Module`) to get a new PRNG key.

### Encoder Block \[EXERCISE 📝\]

The `EncoderBlock` contains all the components of a single encoder block. It is initialized using the parameters:

- `d_model`: the size of the hidden representation of the input.
- `num_heads`: number of heads in the multi-headed attention layer.
- `d_ff`: the inner size of the hidden layer of the position-wise feed-forward sub-layer.
- `p_dropout`: dropout probability (dropout will be applied during training).

It applies the two sub-layers: the multi-head self-attention layer and the position-wise feed-forward sub-layer.
The `__init__` method is used to initialize the parameters of the encoder block, while the `__call__` method applies the encoder block to an input.

In [None]:
class EncoderBlock(hk.Module):
    """
    This class is used to create an encoder block.
    :param d_model: The size of the embedding vector.
    :param num_heads: The number of attention heads.
    :param d_ff: The size of the hidden layer.
    :param p_dropout: The dropout probability.
    """
    def __init__(self, d_model, num_heads, d_ff, p_dropout, name=None):
        super().__init__(name=name)
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.p_dropout = p_dropout

        # self-attention sub-layer
        self.self_attn = MultiheadAttention(
            d_model=self.d_model, num_heads=self.num_heads
        )
        # positionwise feedforward sub-layer
        self.ff = PositionwiseFeedForward(
            d_model=self.d_model, d_ff=self.d_ff, p_dropout=self.p_dropout
        )

        self.norm1 = hk.LayerNorm(
            axis=-1, param_axis=-1, create_scale=True, create_offset=True
        )
        self.norm2 = hk.LayerNorm(
            axis=-1, param_axis=-1, create_scale=True, create_offset=True
        )

    def __call__(self, x, mask=None, is_train=True):
        """
        It applies the encoder block to the input sequence.
        :param x: The input sequence.
        :param mask: The mask to be applied to the self-attention layer.
        :param is_train: Whether the model is in training mode.
        :return: The output of the encoder block, which is the updated input sequence.
        """
        """
        EXERCISE
        """
        d_rate = self.p_dropout if is_train else 0.0

        # attention sub-layer
        sub_x, _ = self.self_attn(x, x, x, mask=mask)
        if is_train:
            sub_x = hk.dropout(hk.next_rng_key(), self.p_dropout, sub_x)
        x = self.norm1(x + sub_x)  # residual conn

        # feedforward sub-layer
        sub_x = self.ff(x, is_train=is_train)
        if is_train:
            sub_x = hk.dropout(hk.next_rng_key(), self.p_dropout, sub_x)
        x = self.norm2(x + sub_x)  # residual conn

        return x

In [None]:
"""Testing the Encoder block"""

bs = 2
seq_len = 12
d_model = 64
num_heads = 8
d_ff = 128


@hk.transform
def enc_blk(x, mask, is_train):
    bl = EncoderBlock(d_model=d_model, num_heads=num_heads, d_ff=d_ff, p_dropout=0.1)
    return bl(x, mask, is_train)


## Test EncoderBlock implementation
# Example features as input
rng_key = next(rng_iter)
x = jax.random.normal(rng_key, (bs, seq_len, d_model))
mask = jax.random.randint(rng, (bs, 1, seq_len), minval=0, maxval=2)

# Initialize parameters of encoder block with random key and inputs
params = enc_blk.init(rng=rng_key, x=x, mask=mask, is_train=True)

# Apply encoder block with parameters on the inputs
# Since dropout is stochastic, we need to pass a rng to the forward
out = enc_blk.apply(rng=rng_key, params=params, x=x, mask=mask, is_train=True)
print("Out", out.shape)

del enc_blk, params

### Transformer Encoder

As introduced in the previous sections, the Transformer encoder is composed of multiple *encoder blocks*. The `TransformerEncoder` class below implements it by stacking $N$ `EncoderBlock`s, where $N$ is the number of stacked encoder blocks.

This class takes the same set of parameters as the `EncoderBlock` class and adds the parameter `num_layers` to specify the number of stacked encoder blocks.

In [None]:
class TransformerEncoder(hk.Module):
    """
    This class is used to create a transformer encoder.
    :param num_layers: The number of encoder blocks.
    :param num_heads: The number of attention heads.
    :param d_model: The size of the embedding vector.
    :param d_ff: The size of the hidden layer.
    :param p_dropout: The dropout probability.
    """

    def __init__(self, num_layers, num_heads, d_model, d_ff, p_dropout, name=None):
        super().__init__(name=name)
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_ff = d_ff
        self.p_dropout = p_dropout

        self.layers = [
            EncoderBlock(self.d_model, self.num_heads, self.d_ff, self.p_dropout)
            for _ in range(self.num_layers)
        ]

    def __call__(self, x: jnp.ndarray, mask=None, is_train=True):
        """
        It applies the transformer encoder to the input sequence.
        :param x: The input sequence.
        :param mask: The mask to be applied to the self-attention layer.
        :param is_train: Whether the model is in training mode.
        :return: The final output of the encoder that contains the last encoder block output.
        """
        for l in self.layers:
            x = l(x, mask=mask, is_train=is_train)
        return x

In [None]:
"""Testing the Transformer Encoder"""
bs = 2
seq_len = 12
d_model = 64
num_heads = 8
d_ff = 128
num_layers = 6
p_dropout = 0.1


@hk.transform
def transformer_encoder(x, mask, is_train):
    enc = TransformerEncoder(num_layers, num_heads, d_model, d_ff, p_dropout, "t_enc")
    return enc(x, mask, is_train)

## Test TransformerEncoder implementation
# Example features as input
rng_key = next(rng_iter)
x = jax.random.normal(rng_key, (bs, seq_len, d_model))
mask = jax.random.randint(rng, (bs, 1, seq_len), minval=0, maxval=2)

# Initialize parameters of transformer with random key and inputs
params = transformer_encoder.init(rng=rng_key, x=x, mask=mask, is_train=True)

# Apply transformer with parameters on the inputs
# Since dropout is stochastic, we need to pass a rng to the forward
out = transformer_encoder.apply(
    rng=rng_key, params=params, x=x, mask=mask, is_train=True
)
print(out.shape)

del params, transformer_encoder

## 🚀 Training your First Language Model

Before starting, let us recap the objects required to pre-train the Transformer encoder:

- ✅ **Model**: Transformer Encoder (which we already implemented)
- 📝 **Dataset**: As the training objective is token-level masked language modeling (MLM), we can use any text corpus. In our case, we will use a toy dataset derived from Tatoeba.
- 📝 **Tokenizer**: We need a tokenizer that takes a string and returns a list of tokens. It is in charge of splitting the input text into tokens and mapping each token to a unique integer index. We are going to use the `BPE` [(byte pair encoding)](https://huggingface.co/course/chapter6/5) tokenizer provided by the [tokenizers](https://huggingface.co/docs/tokenizers/index) library.
- 📝 **Training loop**: We need a training loop that iterates over the dataset, computes the loss, back-propagates the gradients, and updates the parameters.

In [None]:
# some global variables
BATCH_SIZE = 64
MASK_PROBABILITY = 0.15
NUM_LAYERS = 6
NUM_HEADS = 8
D_MODEL = 128
D_FF = 256
P_DROPOUT = 0.1
MAX_SEQ_LEN = 128
VOCAB_SIZE = 25000
LEARNING_RATE = 3e-4
GRAD_CLIP_VALUE = 1

### Tatoeba dataset

[Tatoeba](https://tatoeba.org/) is an open and collaborative platform for collecting translations in different languages. It is an excellent resource for machine translation tasks. 


![image](https://huggingface.co/morenolq/m2l_2022_nlp/resolve/main/tatoeba_example.png)

For our toy example, we will use a small subset of the Tatoeba dataset consisting of aligned sentence pairs in Italian and English.

We only need the English sentences from the dataset to train our Transformer encoder. The English-Italian sentence pairs will be used in the next section when we train a Transformer encoder-decoder.

You can download the dataset by running the following cell.

In [None]:
%%capture
# @title Data Download
!curl -LO https://huggingface.co/morenolq/m2l_2022_nlp/resolve/main/it-en.tsv

It will download a `tsv` file named `it-en.tsv`. We can load it using `pandas` and collect only the English sentences we will use for our MLM pre-training.

In [None]:
df = pd.read_csv(
    "it-en.tsv", sep="\t", header=0, names=["id_it", "sent_it", "id_en", "sent_en"]
)
df = df.dropna()

# We will use english sentences to train our encoder with MLM
en_sentences = df["sent_en"].drop_duplicates()
print(f"Unique English sentences: {len(en_sentences)}")
print("Samples:\n", en_sentences[:5])

### Training a BPE Tokenizer

Before training our Transformer model, we need to train a tokenizer that we will use to split the input text into tokens. The *tokenizers* library provides many tokenizers, including the `BPE` tokenizer we will use.

BPE tokenization involves the following steps:

1. The corpus is split into individual characters, which form the initial vocabulary.
2. The most frequent pairs of adjacent symbols are merged to form new sub-words.
3. Step 2 is repeated until the vocabulary reaches the maximum number of sub-words.
4. The final set of sub-words forms the vocabulary.

We need a `VOCAB_SIZE` parameter that defines our vocabulary's maximum capacity (number of tokens). We will also leverage another global variable, `MAX_SEQ_LEN`, that sets the maximum sentence length to a fixed number of tokens.
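The merge procedure described above can be sketched in a few lines of plain Python. This is a toy illustration under simplifying assumptions, not the algorithm as implemented in the *tokenizers* library: words are split on whitespace, ties between equally frequent pairs are broken alphabetically, and options such as `min_frequency` or `continuing_subword_prefix` are ignored. The names `train_toy_bpe` and `toy_corpus` are hypothetical.

```python
from collections import Counter


def train_toy_bpe(corpus, vocab_size):
    """Learn BPE merges until the vocabulary holds vocab_size symbols."""
    # Each word is a tuple of symbols, initially single characters.
    words = Counter(tuple(w) for line in corpus for w in line.lower().split())
    vocab = {ch for word in words for ch in word}
    merges = []
    while len(vocab) < vocab_size:
        # Count adjacent symbol pairs, weighted by word frequency.
        pairs = Counter()
        for word, freq in words.items():
            for a, b in zip(word, word[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break  # every word is already a single symbol
        # Most frequent pair; ties are broken alphabetically (an assumption).
        best = min(pairs, key=lambda p: (-pairs[p], p))
        merges.append(best)
        vocab.add(best[0] + best[1])
        # Replace every occurrence of the pair with the merged symbol.
        new_words = Counter()
        for word, freq in words.items():
            out, i = [], 0
            while i < len(word):
                if i + 1 < len(word) and (word[i], word[i + 1]) == best:
                    out.append(word[i] + word[i + 1])
                    i += 2
                else:
                    out.append(word[i])
                    i += 1
            new_words[tuple(out)] += freq
        words = new_words
    return vocab, merges


toy_corpus = ["low low low", "lower lowest", "new newer"]
toy_vocab, toy_merges = train_toy_bpe(toy_corpus, vocab_size=12)
print(toy_merges)
```

The toy corpus has 8 distinct characters, so asking for 12 symbols allows four merges; the first two build the frequent sub-word `low` out of `l`, `o`, and `w`.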

🚨🚨🚨 

We usually refer to **tokens** instead of words when training NLP models. Indeed, tokenization involves splitting the text into smaller units, but the latter are not necessarily words. For example, in our case, the tokenizer will split the text into sub-words.

In [None]:
# generating a new BPE tokenizer
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
tokenizer.normalizer = tokenizers.normalizers.Lowercase()

trainer = tokenizers.trainers.BpeTrainer(
    special_tokens=["[PAD]", "[CLS]", "[SEP]", "[MASK]", "[UNK]"], # special tokens
    vocab_size=VOCAB_SIZE, # vocabulary size 
    show_progress=True, # show progress
    min_frequency=2, # minimum frequency of a token to be included in the vocabulary
    continuing_subword_prefix="##", # prefix for subwords that are not the first in a word
)

tokenizer.train_from_iterator(en_sentences, trainer=trainer)

cls_id, sep_id = map(tokenizer.token_to_id, ["[CLS]", "[SEP]"]) # get the ids of the special tokens
tokenizer.post_processor = tokenizers.processors.BertProcessing(
    ("[SEP]", sep_id), ("[CLS]", cls_id)
)
tokenizer.enable_truncation(MAX_SEQ_LEN) # enable truncation to a maximum length
tokenizer.enable_padding(length=MAX_SEQ_LEN) # enable padding to a maximum length

In [None]:
tokenizer.encode("Hello my friend")

In [None]:
tokenizer.decode(tokenizer.encode("Hello my friend").ids, skip_special_tokens=False) # reconstruct the sentence from the ids

In [None]:
# Save it to file
tokenizer.save("en_tokenizer.json") # save the tokenizer in a json file

Load it from disk if you already have it!

In [None]:
tokenizer = tokenizers.Tokenizer.from_file("en_tokenizer.json") # load the tokenizer from a json file
tokenizer.enable_truncation(MAX_SEQ_LEN) # enable truncation to a maximum length
tokenizer.enable_padding(length=MAX_SEQ_LEN) # enable padding to a maximum length

### Data preparation

We have the model ✅ and the tokenizer ✅, and we must prepare the training and validation datasets.
To do so, we split the dataset into two parts: a training set containing 80% of the original corpus and a validation set containing the remaining 20%.

We also use the `DatasetDict` class from the `datasets` package to store the training and validation sets. This class provides many methods to manipulate the data efficiently. For example, we can run a pre-processing step to pre-tokenize the text and avoid running the tokenizer during training.

The tokenizer maps each token to an index in the vocabulary creating the `input_ids` vector. The expected output is a vector of the same length as `input_ids` but containing the index of the target tokens.

**Masked Language Modeling (MLM)**

Masked language modeling (MLM) is the task of randomly masking some words in the input and asking the model to guess the original word. It is a *self-supervised* objective that one can use to train the model without any labeled data. Indeed, the expected output for each masked word is simply the index of the original word. Let's see a simple example of MLM below.

![image](https://huggingface.co/morenolq/m2l_2022_nlp/resolve/main/MLM.png)

For training the model, we chose to mask 15% of the tokens in the training set. Given a sentence, we randomly decide to mask a token, and we replace it with the special token `[MASK]`. The model is then trained to predict the original token. 

Using the MLM objective, **we use the original token ids as labels**. The tokenization step only produces the `input_ids`; the expected output (i.e., the `labels` vector) starts as a copy of them. During training, the *collate* function (`collate_fn`) randomly replaces some input tokens with `[MASK]` and keeps the original ids as labels only at the masked positions, so the model learns to predict the original tokens there.
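To make the masking step concrete, here is a minimal sketch in plain Python (no JAX) that works on a single sentence. The token ids are made up, `mask_id=3` assumes `[MASK]` has id 3, and the helper name `toy_mask_tokens` is hypothetical; the value `-100` marks positions that should be ignored by the loss.

```python
import random

IGNORE_INDEX = -100  # label for positions that do not contribute to the loss


def toy_mask_tokens(input_ids, special_tokens_mask, mask_id, mask_prob, rng):
    """Return (masked_input_ids, labels) for one tokenized sentence."""
    masked_ids, labels = [], []
    for tok, is_special in zip(input_ids, special_tokens_mask):
        if not is_special and rng.random() < mask_prob:
            masked_ids.append(mask_id)    # the model sees [MASK] ...
            labels.append(tok)            # ... and must predict the original id
        else:
            masked_ids.append(tok)        # token left untouched
            labels.append(IGNORE_INDEX)   # nothing to predict here
    return masked_ids, labels


# Hypothetical ids for: [CLS] w1 w2 w3 [SEP] [PAD] [PAD]
toy_ids = [1, 57, 912, 33, 2, 0, 0]
toy_special = [1, 0, 0, 0, 1, 1, 1]
toy_masked, toy_labels = toy_mask_tokens(
    toy_ids, toy_special, mask_id=3, mask_prob=0.5, rng=random.Random(0)
)
print(toy_masked)
print(toy_labels)
```

Special tokens (`[CLS]`, `[SEP]`, `[PAD]`) are never masked, so their labels are always `-100`.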

🚨 Given the computational resources required for running the pre-training, we only sample 5% of the Tatoeba collection.

In [None]:
DATASET_SAMPLE = 0.05  # @param {type:"number"}

In [None]:
data = df["sent_en"].drop_duplicates()

# sample to ease compute
data = data.sample(frac=DATASET_SAMPLE, random_state=42)

train_df, val_df = train_test_split(data, train_size=0.8, random_state=42)
print("Train", train_df.shape, "Valid", val_df.shape)

In [None]:
raw_datasets = DatasetDict(
    {
        "train": Dataset.from_dict({"text": train_df.tolist()}),
        "valid": Dataset.from_dict({"text": val_df.tolist()}),
    }
)

In [None]:
def preprocess(examples: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """
    This function tokenizes the input sentences and adds the special tokens.
    :param examples: The input sentences.
    :return: The tokenized sentences.
    """
    out = tokenizer.encode_batch(examples["text"])
    return {
        "input_ids": [o.ids for o in out],
        "attention_mask": [o.attention_mask for o in out],
        "special_tokens_mask": [o.special_tokens_mask for o in out],
        # "labels": [o.ids for o in out], # we don't need labels!
    }


proc_datasets = raw_datasets.map(
    preprocess, batched=True, batch_size=4000, remove_columns=["text"]
)
proc_datasets["train"]

In [None]:
def collate_fn(batch):
    """
    Collate function that prepares the input for the MLM language modeling task.
    The input tokens are masked according to the MASK_PROBABILITY to generate the 'labels'.

    EXERCISE
    """
    input_ids = jnp.array([s["input_ids"] for s in batch])
    attention_mask = jnp.array([s["attention_mask"] for s in batch])
    special_tokens_mask = jnp.array([s["special_tokens_mask"] for s in batch])
    labels = input_ids.copy()

    special_tokens_mask = special_tokens_mask.astype("bool")
    masked_indices = jax.random.bernoulli(
        next(rng_iter), MASK_PROBABILITY, labels.shape
    ).astype("bool")
    masked_indices = jnp.where(special_tokens_mask, False, masked_indices)

    # Set labels to -100 for non-[MASK] tokens (we will use this while defining the loss function)
    labels = jnp.where(~masked_indices, -100, labels)

    input_ids = jnp.where(masked_indices, tokenizer.token_to_id("[MASK]"), input_ids)

    item = {
        "input_ids": input_ids,
        "attention_mask": jnp.expand_dims(
            attention_mask, 1
        ),  # attention mask must be broadcastable to (B,...,S,S)!
        "labels": labels,
    }
    return item


train_loader = DataLoader(
    proc_datasets["train"], batch_size=BATCH_SIZE, collate_fn=collate_fn
)
valid_loader = DataLoader(
    proc_datasets["valid"], batch_size=BATCH_SIZE, collate_fn=collate_fn
)

In the last cell, we used `torch.utils.data.DataLoader`. A **dataloader** is a container that provides an iterable interface over a dataset. It handles the batching and shuffling and is useful for providing data to the training and validation loops. It also provides a specific parameter to use a `collate_fn` which is the function that handles the creation of the batches. In our example, this is where we randomly mask some tokens for the MLM objective.
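As a rough mental model of what a dataloader does with a `collate_fn`, consider the following plain-Python sketch. It is not how `torch.utils.data.DataLoader` is implemented (there is no shuffling, no worker processes); the names `toy_loader`, `toy_collate`, and `toy_dataset` are hypothetical.

```python
def toy_collate(samples):
    """Turn a list of per-sentence dicts into one dict of lists (one batch)."""
    return {
        "input_ids": [s["input_ids"] for s in samples],
        "attention_mask": [s["attention_mask"] for s in samples],
    }


def toy_loader(dataset, batch_size, collate):
    """Yield consecutive batches; the last one may be smaller."""
    for start in range(0, len(dataset), batch_size):
        yield collate(dataset[start:start + batch_size])


# Five made-up, already tokenized sentences.
toy_dataset = [
    {"input_ids": [1, 10 + i, 2], "attention_mask": [1, 1, 1]} for i in range(5)
]
toy_batches = list(toy_loader(toy_dataset, batch_size=2, collate=toy_collate))
print([len(b["input_ids"]) for b in toy_batches])  # [2, 2, 1]
```

Because the collate function runs every time a batch is produced, putting the random masking there means each epoch sees a different set of masked positions.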

### Defining a Language Model (with a JAX/Haiku Transform)

At this point, we have the model ✅, the tokenizer ✅, and the data for training and validation ✅. The next step is to define the training loop and all the steps that need to be done inside it.

As with each component of the model, we will define the language model using [Haiku](https://github.com/deepmind/dm-haiku). Before implementing our model, let's first recall a very important concept in JAX/Haiku: **the model must be a pure function**. This means that its output depends only on its explicit inputs (parameters, random key, and data) and that it has no side effects. This is a very powerful property because it makes it easy to parallelize your model and enables automatic differentiation 💪.

Thanks to `hk.transform`, we can define a function `mlm_language_model` that takes the `input_ids` and the `mask` as input and runs the model. It also takes a flag `is_train` that indicates whether we are training or evaluating the model. This matters because `dropout` must be active only during training.
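The "pure function" idea can be illustrated without any library. The sketch below is only an analogy to the `init`/`apply` pair that `hk.transform` produces, written in plain Python: parameters and randomness are passed in explicitly, and dropout is applied only when `is_train` is true. The names `toy_init` and `toy_apply` are hypothetical, and an integer seed stands in for a JAX random key.

```python
import random


def toy_init(seed, in_dim, out_dim):
    """Create parameters from a seed; nothing is stored globally."""
    rng = random.Random(seed)
    return {
        "w": [[rng.gauss(0.0, 0.1) for _ in range(out_dim)] for _ in range(in_dim)],
        "b": [0.0] * out_dim,
    }


def toy_apply(params, seed, x, is_train, p_dropout=0.1):
    """Linear layer with input dropout; the output depends only on the arguments."""
    if is_train:
        rng = random.Random(seed)
        keep = 1.0 - p_dropout
        # Inverted dropout: zero some inputs and rescale the survivors.
        x = [xi / keep if rng.random() < keep else 0.0 for xi in x]
    return [
        sum(xi * params["w"][i][j] for i, xi in enumerate(x)) + params["b"][j]
        for j in range(len(params["b"]))
    ]


toy_params = toy_init(seed=0, in_dim=3, out_dim=2)
toy_x = [1.0, 2.0, 3.0]
out_a = toy_apply(toy_params, seed=1, x=toy_x, is_train=True)
out_b = toy_apply(toy_params, seed=1, x=toy_x, is_train=True)
out_eval_1 = toy_apply(toy_params, seed=1, x=toy_x, is_train=False)
out_eval_2 = toy_apply(toy_params, seed=2, x=toy_x, is_train=False)
print(out_a == out_b, out_eval_1 == out_eval_2)
```

Same parameters, same seed, and same input always give the same output, and in evaluation mode the seed has no effect at all.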

In [None]:
@hk.transform
def mlm_language_model(input_ids, mask, is_train=True):
    """
    MLM language model as an haiku pure transformation.
    :param input_ids: The input token ids.
    :param mask: The attention mask.
    :param is_train: Whether the model is in training mode.
    :return: The logits corresponding to the output of the model.
    """
    
    """
    EXERCISE
    """
    pe = PositionalEncoding(D_MODEL, MAX_SEQ_LEN, P_DROPOUT)
    embeddings = Embeddings(D_MODEL, VOCAB_SIZE)
    encoder = TransformerEncoder(NUM_LAYERS, NUM_HEADS, D_MODEL, D_FF, P_DROPOUT)

    # get input token embeddings
    input_embs = embeddings(input_ids)
    if len(input_embs.shape) == 2:
        input_embs = jnp.expand_dims(input_embs, 0)  # (1,MAX_SEQ_LEN,D_MODEL)

    # sum positional encodings
    input_embs = pe(input_embs, is_train=is_train)  # (B,MAX_SEQ_LEN,d_model)

    # encode using the transformer encoder stack
    output_embs = encoder(input_embs, mask=mask, is_train=is_train)

    # decode each position into a probability distribution over vocabulary tokens
    out = hk.Linear(D_MODEL, name="dec_lin_1")(output_embs)
    out = jax.nn.relu(out)
    out = hk.LayerNorm(
        axis=-1, param_axis=-1, create_scale=True, create_offset=True, name="dec_norm"
    )(out)
    out = hk.Linear(VOCAB_SIZE, name="dec_lin_2")(out)  # logits
    return out

In [None]:
# testing the LM
input_ids = jnp.array(tokenizer.encode("Hello my friend").ids) # encode a sentence
rng_key = next(rng_iter) # get a new random key
mask = jax.random.randint(next(rng_iter), (1, 1, input_ids.shape[-1]), minval=0, maxval=2) # create a mask
params = mlm_language_model.init(rng_key, input_ids, None, True) # initialize the model
out = mlm_language_model.apply(
    params=params, rng=rng_key, input_ids=input_ids, mask=None, is_train=True
) # apply the model to the input sentence encoded at the previous step
print(out.shape)  # output should be of shape (1,MAX_SEQ_LEN,VOCAB_SIZE)

### Training accessories 💍

Before writing the training loop, we need to define some accessories used during the training. These accessories include the **training state** (e.g., the model parameters and the optimizer state), the **loss function**, and the **train and evaluation steps**.

**Training state**

The training state will allow us to keep track of the training progress and contains all the information we need, e.g., the model parameters and the optimizer. Implementing the model using JAX makes it easy to define a training state.

In [None]:
class TrainingState(NamedTuple):
    """
    The training state is a named tuple containing the model parameters and the optimizer state.
    """
    params: hk.Params # model parameters
    opt_state: optax.OptState # optimizer state

Before running the actual training, we need to initialize the network (you have already seen this when testing the previous modules) and an optimizer. 

We will use the `Adam` optimizer, which is a gradient-based optimization algorithm that adapts the learning rate based on the estimated first and second moments of the gradients. It is a very popular optimization algorithm and has shown great results in practice.

**Resources**
- Adam optimizer: [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
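To see what "adapting the learning rate based on first and second moments" means, here is a scalar Adam update written in plain Python. It follows the update rule from the paper with the commonly used defaults (`b1=0.9`, `b2=0.999`, `eps=1e-8`); it is a didactic sketch, not the `optax` implementation, and `toy_adam_step` is a hypothetical name.

```python
import math


def toy_adam_step(param, grad, m, v, t, lr, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter (t starts at 1)."""
    m = b1 * m + (1 - b1) * grad           # running mean of gradients
    v = b2 * v + (1 - b2) * grad * grad    # running mean of squared gradients
    m_hat = m / (1 - b1 ** t)              # bias correction
    v_hat = v / (1 - b2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# Minimise f(x) = (x - 3)^2, whose gradient is 2 * (x - 3).
toy_x, toy_m, toy_v = 0.0, 0.0, 0.0
toy_first_x = None
for toy_t in range(1, 501):
    toy_grad = 2.0 * (toy_x - 3.0)
    toy_x, toy_m, toy_v = toy_adam_step(toy_x, toy_grad, toy_m, toy_v, toy_t, lr=0.1)
    if toy_t == 1:
        toy_first_x = toy_x
print(toy_first_x, toy_x)
```

On the very first step the bias-corrected ratio is approximately the sign of the gradient, so the parameter moves by about `lr` regardless of the gradient's magnitude.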

In [None]:
# Initialise network and optimiser; note we draw an input to get shapes.
sample = proc_datasets["train"][0]
input_ids, attention_mask = map(
    jnp.array, (sample["input_ids"], sample["attention_mask"])
)
rng_key = next(rng_iter)
init_params = mlm_language_model.init(rng_key, input_ids, attention_mask, True)

optimizer = optax.chain(
    optax.clip_by_global_norm(GRAD_CLIP_VALUE),
    optax.adam(LEARNING_RATE),
)
init_opt_state = optimizer.init(init_params)

# initialize the training state class
state = TrainingState(init_params, init_opt_state)

**Loss Function** \[EXERCISE 📝\]

The loss function is the objective that we want to minimize during training. In general, the loss function needs to be differentiable to compute the gradient of the error using automatic differentiation. In our case, we will use the *Cross Entropy* loss traditionally used for classification tasks. The `optax` library has a function that allows us to easily define the loss function ([see the docs here](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels)).

🚨🚨🚨

While implementing the loss function, make sure to carefully manage *padding*. You may not want to consider the padding positions when calculating the loss function. Thus, the loss function should only consider the valid positions.
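The idea of averaging the loss only over valid positions can be sketched in plain Python as follows. The sketch assumes that ignored positions carry the label `-100`; `toy_masked_cross_entropy` is a hypothetical helper and not part of `optax`.

```python
import math

IGNORE_INDEX = -100  # assumed label for positions excluded from the loss


def toy_masked_cross_entropy(logits, labels):
    """Mean cross entropy over positions whose label is not IGNORE_INDEX."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # padding or unmasked token: skip it
        # log-sum-exp with max subtraction for numerical stability
        mx = max(row)
        log_z = mx + math.log(sum(math.exp(v - mx) for v in row))
        total += log_z - row[label]  # negative log-probability of the target
        count += 1
    return total / max(count, 1)  # guard against batches with no valid position


toy_logits = [[0.0, 0.0], [10.0, 0.0], [0.0, 0.0]]
toy_labels = [0, IGNORE_INDEX, 1]
toy_loss = toy_masked_cross_entropy(toy_logits, toy_labels)
print(toy_loss, math.log(2))
```

The second position is ignored, so the result equals the loss of a uniform prediction over two classes; dividing by the total number of positions instead would shrink the loss whenever a batch contains many ignored tokens.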

In [None]:
def loss_fn(params: hk.Params, batch, rng) -> jnp.ndarray:
    """
    The loss function for the MLM language modeling task.
    It computes the cross entropy loss between the logits and the labels.

    :param params: The model parameters.
    :param batch: The batch of data.
    :param rng: The random number generator.
    :return: The value of the loss computed on the batch.

    EXERCISE
    """
    logits = mlm_language_model.apply(
        params=params,
        rng=rng,
        input_ids=batch["input_ids"],
        mask=batch["attention_mask"],
        is_train=True,
    )    
    
    label_mask = jnp.where(batch["labels"] > 0, 1.0, 0.0)
    # for negative labels, jax.nn.one_hot() returns an all-zero vector of size VOCAB_SIZE
    loss = optax.softmax_cross_entropy(logits, jax.nn.one_hot(batch["labels"], VOCAB_SIZE)) * label_mask
    loss = jnp.where(jnp.isnan(loss), 0, loss)
    
    # take average
    loss = loss.sum() / label_mask.sum()
    return loss

**Training and Evaluation steps** \[EXERCISE 📝\]

The training and evaluation steps are the core of the training loop. They implement the training loop logic.

**Training step**: For each batch, it should (i) forward propagate the batch through the model, (ii) compute the loss and the gradient and then (iii) update the model parameters using the optimizer.

**Evaluation step**: For each batch, it should (i) forward propagate the batch through the model and then (ii) compute and return the loss that corresponds to the current model parameters.
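The structure of the two steps can be shown on a deliberately tiny problem. The sketch below fits `y ≈ w * x` with a hand-written gradient and plain gradient descent in pure Python; it only mirrors the shape of the logic (state in, new state and metrics out), not the JAX implementation. All `toy_*` names are hypothetical.

```python
from typing import NamedTuple


class ToyState(NamedTuple):
    w: float   # the single model parameter
    step: int  # number of updates applied so far


def toy_loss_and_grad(w, batch):
    """Mean squared error of y ≈ w * x and its derivative with respect to w."""
    n = len(batch)
    loss = sum((w * x - y) ** 2 for x, y in batch) / n
    grad = sum(2 * (w * x - y) * x for x, y in batch) / n
    return loss, grad


def toy_train_step(state, batch, lr=0.1):
    """(i) forward pass, (ii) loss and gradient, (iii) parameter update."""
    loss, grad = toy_loss_and_grad(state.w, batch)
    new_state = ToyState(state.w - lr * grad, state.step + 1)
    return new_state, {"train_loss": loss}


def toy_eval_step(w, batch):
    """Forward pass and loss only; the parameters are left unchanged."""
    return toy_loss_and_grad(w, batch)[0]


toy_batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # generated with w = 2
toy_state = ToyState(w=0.0, step=0)
for _ in range(20):
    toy_state, toy_metrics = toy_train_step(toy_state, toy_batch)
print(toy_state, toy_eval_step(toy_state.w, toy_batch))
```

Note that the training step returns a new state instead of modifying the old one in place, which is the same functional style used with JAX.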

In [None]:
@jax.jit
def train_step(state, batch, rng_key):
    """
    The training step function. It computes the loss and gradients, and updates the model parameters.
    
    :param state: The training state.
    :param batch: The batch of data.
    :param rng_key: The key for the random number generator.
    :return: The updated training state, the metrics (training loss) and the random number generator key.
    """
    rng_key, rng = jax.random.split(rng_key)

    loss_and_grad_fn = jax.value_and_grad(loss_fn)
    # use the fresh sub-key for dropout; rng_key is returned and split again at the next step
    loss, grads = loss_and_grad_fn(state.params, batch, rng)

    updates, opt_state = optimizer.update(grads, state.opt_state)
    params = optax.apply_updates(state.params, updates)

    new_state = TrainingState(params, opt_state)
    metrics = {"train_loss": loss}

    return new_state, metrics, rng_key

In [None]:
@jax.jit
def eval_step(params: hk.Params, batch) -> jnp.ndarray:
    """
    The evaluation step function. It computes the loss on the batch.
    
    :param params: The model parameters.
    :param batch: The batch of data.
    :return: The value of the loss computed on the batch.

    EXERCISE
    """
    logits = hk.without_apply_rng(mlm_language_model).apply(
        params=params,
        input_ids=batch["input_ids"],
        mask=batch["attention_mask"],
        is_train=False,
    )

    label_mask = jnp.where(batch["labels"] > 0, 1.0, 0.0)
    # for negative labels, jax.nn.one_hot() returns an all-zero vector of size VOCAB_SIZE
    loss = optax.softmax_cross_entropy(logits, jax.nn.one_hot(batch["labels"], VOCAB_SIZE)) * label_mask
    loss = jnp.where(jnp.isnan(loss), 0, loss)
    # take average
    loss = loss.sum() / label_mask.sum()
    return loss

### The Training Loop

The training loop will execute the training and evaluation steps by iterating over the training and validation datasets. It relies on hyperparameters such as the number of epochs `EPOCHS` and the number of steps between each evaluation `EVAL_STEPS` (you typically do not want to wait until the end of the epoch to assess your model, nor do it so often that the training slows down).

**Checkpointing**

The training loop also includes the checkpointing logic, which saves the training state to disk at each evaluation step if the evaluation loss has improved.
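The "save only on improvement" rule can be sketched in isolation with `pickle`. The evaluation losses, the file names, and the helper `toy_maybe_checkpoint` below are made up for illustration, and a temporary directory is used so that nothing is left on disk.

```python
import os
import pickle
import tempfile


def toy_maybe_checkpoint(toy_state, eval_loss, best_eval_loss, path):
    """Write toy_state to path only if eval_loss improves on best_eval_loss."""
    if eval_loss < best_eval_loss:
        with open(path, "wb") as fp:
            pickle.dump(toy_state, fp)
        return eval_loss, True
    return best_eval_loss, False


toy_eval_losses = [2.0, 1.5, 1.7, 1.2]  # made-up evaluation losses
toy_best = float("inf")
toy_saved = []
with tempfile.TemporaryDirectory() as tmp_dir:
    for i, toy_loss in enumerate(toy_eval_losses, start=1):
        toy_state = {"step": i * 500, "params": [0.1 * i]}
        toy_path = os.path.join(tmp_dir, f"toy_ckpt_{i * 500}.pkl")
        toy_best, saved = toy_maybe_checkpoint(toy_state, toy_loss, toy_best, toy_path)
        toy_saved.append(saved)
    toy_files = sorted(os.listdir(tmp_dir))
    # Restore the most recent improving checkpoint.
    with open(os.path.join(tmp_dir, "toy_ckpt_2000.pkl"), "rb") as fp:
        toy_restored = pickle.load(fp)
print(toy_saved, len(toy_files), toy_restored["step"])
```

Writing one file per improving step keeps every intermediate best model, at the cost of disk space; overwriting a single file is a common alternative.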

**Debugging**

Unfortunately, debugging JIT-ed code (like the code we are using within our training loop) can be pretty tricky. This is because JAX traces and compiles functions before executing them, so breakpoints and ordinary `print` calls do not behave as they would in regular Python code.
If you want to set breakpoints or print variables, you can comment out `@jax.jit` from either your `train_step` or `eval_step` definitions.

Read [here](https://github.com/google/jax/issues/196) why you cannot print in JIT-compiled functions.

**Experiment tracking**

Tracking your training dynamics is fundamental to check whether a bug has occurred or everything is proceeding as expected. Many tracking tools expose handy APIs to streamline experiment tracking. Today, we will use TensorBoard, which is easy to integrate into Jupyter Lab / Google Colab. 

First, we set a `LOG_STEPS` variable responsible for tracking the training loss for each fixed number of steps. Then, we use a `SummaryWriter` object to log metrics every `LOG_STEPS`. Finally, we can observe our logged metrics by opening a dedicated tab within a notebook cell: execute the following cell to load the tensorboard extension (if you are running the notebook locally, you have to install tensorboard beforehand) and open it.

In [None]:
%load_ext tensorboard
%tensorboard --logdir runs

In [None]:
# The training loop
# It is a simple for loop that iterates over the training set and evaluates on the validation set.

# The hyperparameters used for training and evaluation
EPOCHS = 30  # @param {type:"number"}
EVAL_STEPS = 500  # @param {type:"number"}
MAX_STEPS = 200  # @param {type:"number"}
LOG_STEPS = 200

writer = SummaryWriter()
pbar = tqdm(desc="Train step", total=EPOCHS * len(train_loader))
step = 0
loop_metrics = {"train_loss": None, "eval_loss": None}
best_eval_loss = float("inf")

for epoch in range(EPOCHS):

    for batch in train_loader:

        state, metrics, rng_key = train_step(state, batch, rng_key)
        loop_metrics.update(metrics)
        pbar.update(1)
        step += 1

        # Evaluation loop, no optimization is involved here.
        if step % EVAL_STEPS == 0:
            ebar = tqdm(desc="Eval step", total=len(valid_loader), leave=False)

            losses = list()
            for batch in valid_loader:
                loss = eval_step(state.params, batch)
                losses.append(loss)
                ebar.update(1)
            ebar.close()

            eval_loss = jnp.array(losses).mean()
            loop_metrics["eval_loss"] = eval_loss

            writer.add_scalar("Loss/valid", loop_metrics["eval_loss"].item(), step)

            if eval_loss.item() < best_eval_loss:
                best_eval_loss = eval_loss.item()
                # Save the training state (parameters and optimizer state) to disk
                with open(f"ckpt_train_state_{step}.pkl", "wb") as fp:
                    pickle.dump(state, fp)

        if step % LOG_STEPS == 0:
            writer.add_scalar("Loss/train", loop_metrics["train_loss"].item(), step)

        pbar.set_postfix(loop_metrics)

pbar.close()

Let's focus on a different task for our Transformer Encoder: **Sentiment Analysis**. This task requires determining if the sentiment of a given piece of text is positive or negative.

Remember, the pre-training implemented before is not specific to any downstream task: we were training our Transformer on a broad set of input data and learning representations that can be useful when *adapted* to different tasks. Indeed, we will now build a classifier based on our pre-trained Transformer to perform sentiment analysis.

In practice, we will train a language model with a sentiment classification *head*: instead of producing a probability distribution across all the words in our vocabulary, we will produce a class probability over a set of labels.

We will use the Stanford Sentiment Treebank V2 dataset (SST-2). The dataset contains 11,855 sentences extracted from movie reviews. The sentences have been labeled with positive (1) or negative (0) sentiments.

📚 **Resources**

- SST-2 paper: [Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank](https://aclanthology.org/D13-1170/)
- SST-2 fields description: [SST-2 on datasets](https://huggingface.co/datasets/sst2)

In [None]:
raw_datasets = load_dataset("glue", "sst2")
raw_datasets

## 🧱 Generating a baseline model with TF-IDF and Logistic Regression

We will first build a baseline model to compare with the Transformer model. For this, we will use a TF-IDF representation of the sentences plus a Logistic Regression classifier. The baseline model will be trained using the training set of the SST-2 dataset and evaluated on the evaluation set (unfortunately, the labels of the SST-2 test set dataset are not publicly available).
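For intuition about what a TF-IDF representation contains, here is a hand-rolled version in plain Python. It assumes the smoothed IDF formula `ln((1 + n) / (1 + df)) + 1` followed by L2 normalization, which are the documented defaults of scikit-learn's `TfidfVectorizer`; tokenization is simplified to lowercase whitespace splitting, so the numbers will not match scikit-learn on real text. `toy_tfidf` and `toy_docs` are hypothetical names.

```python
import math
from collections import Counter


def toy_tfidf(docs):
    """Return (vocabulary, rows) with one L2-normalised TF-IDF row per document."""
    tokenized = [d.lower().split() for d in docs]
    vocab = sorted({t for doc in tokenized for t in doc})
    n = len(docs)
    # Document frequency: in how many documents each term appears.
    df = Counter(t for doc in tokenized for t in set(doc))
    idf = {t: math.log((1 + n) / (1 + df[t])) + 1 for t in vocab}
    rows = []
    for doc in tokenized:
        tf = Counter(doc)  # raw term counts
        vec = [tf[t] * idf[t] for t in vocab]
        norm = math.sqrt(sum(v * v for v in vec)) or 1.0
        rows.append([v / norm for v in vec])
    return vocab, rows


toy_docs = ["the great movie", "the boring movie"]
toy_vocab, toy_rows = toy_tfidf(toy_docs)
for term, weight in zip(toy_vocab, toy_rows[0]):
    print(f"{term:>7}: {weight:.3f}")
```

Terms that occur in every document ("the", "movie") receive a lower weight than the term that distinguishes the first review ("great"), which is what makes the representation useful for a linear classifier.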

Run the cell below to train the baseline model.

In [None]:
# TF-IDF baseline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report

MAX_FEATURES = 10_000

X_train, y_train = raw_datasets["train"]["sentence"], raw_datasets["train"]["label"]
X_test, y_test = (
    raw_datasets["validation"]["sentence"],
    raw_datasets["validation"]["label"],
)

tfidf = TfidfVectorizer(max_features=MAX_FEATURES) # instantiate the vectorizer
tfidf = tfidf.fit(X_train) # fit on the training data only, so no validation text leaks into the features

X_train = tfidf.transform(X_train) # transform the training data
X_test = tfidf.transform(X_test) # transform the test data

clf = LogisticRegression(random_state=42).fit(X_train, y_train) # train the classifier
y_pred = clf.predict(X_test) # run the classifier on the test data

clf_report = classification_report(y_test, y_pred)
print("\n Accuracy: ", accuracy_score(y_test, y_pred))
print("\nClassification Report")
print("======================================================")
print("\n", clf_report)

## Preprocessing the SST-2 dataset

After downloading the dataset, we can process the sentences with our tokenizer. Similarly to the pre-training step, we preprocess all sentences offline and avoid doing it at each training iteration.

Once data are processed, we can create the data loaders that will feed our training loop.

In [None]:
tokenizer = tokenizers.Tokenizer.from_file("en_tokenizer.json")
tokenizer.enable_truncation(MAX_SEQ_LEN)
tokenizer.enable_padding(length=MAX_SEQ_LEN)

In [None]:
def preprocess(examples: Dict[str, List[str]]) -> Dict[str, List[str]]:
    out = tokenizer.encode_batch(examples["sentence"])
    return {
        "input_ids": [o.ids for o in out],
        "attention_mask": [o.attention_mask for o in out],
        "special_tokens_mask": [o.special_tokens_mask for o in out],
        "labels": examples["label"],
    }


proc_datasets = raw_datasets.map(
    preprocess, batched=True, batch_size=4000, remove_columns=["sentence", "idx"]
)

In [None]:
def collate_fn(batch):
    """
    Collate function to generate the input for the model and their corresponding labels.
    In this case, the labels correspond to the expected class of each input sequence.
    """
    item = {
        "input_ids": jnp.array([s["input_ids"] for s in batch]),
        "attention_mask": jnp.expand_dims(
            jnp.array([s["attention_mask"] for s in batch]), 1
        ),
        "labels": jnp.array([s["labels"] for s in batch]),
    }
    return item


train_loader = DataLoader(
    proc_datasets["train"], batch_size=BATCH_SIZE, collate_fn=collate_fn
)
valid_loader = DataLoader(
    proc_datasets["validation"], batch_size=BATCH_SIZE, collate_fn=collate_fn
)
test_loader = DataLoader(
    proc_datasets["test"], batch_size=BATCH_SIZE, collate_fn=collate_fn
)

## Sentiment Classifier as a Haiku Transform

Similarly to the model trained using MLM, we can write our sentiment classifier as a Haiku Transform.

Our function will contain all the components necessary to perform the classification:
- a pre-trained Transformer Encoder as the base model;
- a Linear layer on top of it. The linear layer maps the Transformer's output representations corresponding to the `[CLS]` token (the first token of the sequence) to two values, the logits of positive and negative sentiment.

We compute Cross Entropy between the prediction and the ground truth labels. Unlike the MLM case, our labels are 0 or 1, so we can use them directly as targets in supervised learning settings.
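The classification head itself is small: pick the vector at position 0 of each sequence and apply one linear map to two logits. The plain-Python sketch below treats the encoder outputs as given numbers (in the real model, the `[CLS]` vector already depends on the other tokens through self-attention). The weights and the helper name `toy_cls_head` are made up.

```python
def toy_cls_head(output_embs, w, b):
    """Map the first position ([CLS]) of every sequence to len(b) logits."""
    logits = []
    for seq in output_embs:   # seq: list of per-token vectors
        cls_vec = seq[0]      # only position 0 is used
        logits.append([
            sum(c * w[i][j] for i, c in enumerate(cls_vec)) + b[j]
            for j in range(len(b))
        ])
    return logits


# Batch of 2 sequences, 3 positions each, hidden size 2 (made-up values).
toy_embs = [
    [[2.0, 0.5], [9.0, 9.0], [9.0, 9.0]],
    [[-1.0, 3.0], [9.0, 9.0], [9.0, 9.0]],
]
toy_w = [[1.0, -1.0], [0.0, 0.0]]  # hypothetical (hidden size, 2) weight matrix
toy_b = [0.0, 0.0]
toy_logits = toy_cls_head(toy_embs, toy_w, toy_b)
toy_preds = [row.index(max(row)) for row in toy_logits]
print(toy_logits, toy_preds)  # [[2.0, -2.0], [-1.0, 1.0]] [0, 1]
```

The output has one row of two logits per sequence, i.e., shape `(B, 2)`, regardless of the sequence length.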

In [None]:
@hk.transform
def sentiment_classifier(input_ids, mask, is_train=True):
    """
    The sentiment classifier model implemented using Haiku. Each input sequence is
    passed through a transformer encoder and the output is passed through a linear
    layer to obtain the logits.

    :param input_ids: The input sequences.
    :param mask: The attention mask.
    :param is_train: Whether the model is in training mode or not.
    :return: The logits.
    """

    """    
    EXERCISE
    """
    pe = PositionalEncoding(D_MODEL, MAX_SEQ_LEN, P_DROPOUT)
    embeddings = Embeddings(D_MODEL, VOCAB_SIZE)
    encoder = TransformerEncoder(NUM_LAYERS, NUM_HEADS, D_MODEL, D_FF, P_DROPOUT)

    input_embs = embeddings(input_ids)
    if len(input_embs.shape) == 2:
        input_embs = input_embs[None, :, :]
    input_embs = pe(input_embs, is_train=is_train)  # (B,S,d_model)
    output_embs = encoder(input_embs, mask=mask, is_train=is_train)

    # final decoder layer
    out = hk.Linear(D_MODEL)(output_embs)
    out = jax.nn.relu(out)
    out = hk.LayerNorm(axis=-1, param_axis=-1, create_scale=True, create_offset=True)(
        out
    )
    out = out[:, 0, :]  # we use the [CLS] token embedding to represent the sequence and pass it through a linear layer
    out = hk.Linear(2)(out)  # logits
    return out


def loss_fn(params: hk.Params, batch, rng) -> jnp.ndarray:
    """
    The loss function for the model. It takes the model parameters, the input batch
    and the random number generator as input and returns the loss.

    :param params: The model parameters.
    :param batch: The input batch.
    :param rng: The random number generator.
    :return: The loss.
    """
    logits = sentiment_classifier.apply(
        params=params,
        rng=rng,
        input_ids=batch["input_ids"],
        mask=batch["attention_mask"],
        is_train=True,
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"])
    return loss.mean()


@jax.jit
def deterministic_forward(params: hk.Params, batch) -> jnp.ndarray:
    """
    This function is used to forward the model in a deterministic way. 
    It uses without_apply_rng to disable the use of the random number generator.
    It takes the model parameters and the input batch as input and returns the logits.

    :param params: The model parameters.
    :param batch: The input batch.
    :return: The logits.
    """
    return hk.without_apply_rng(sentiment_classifier).apply(
        params=params,
        input_ids=batch["input_ids"],
        mask=batch["attention_mask"],
        is_train=False,
    )


def eval_step(params: hk.Params, batch) -> jnp.ndarray:
    """Evaluation step."""
    logits = deterministic_forward(params, batch)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch["labels"])
    acc = (logits.argmax(-1) == batch["labels"]).sum() / batch["labels"].shape[0]
    return loss.mean(), acc

## Training loop

The training loop is similar to the MLM case. However, we do not need to mask the input to the Transformer since we are not training for a language modeling task.

Given the supervised setting, we can evaluate the model against the validation set during training. We will monitor the validation accuracy, i.e., the percentage of correctly classified samples. The training loop tracks both the loss and the validation accuracy so that we can observe their dynamics during training.
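One detail worth keeping in mind when aggregating accuracy over batches: a plain mean of per-batch accuracies equals the overall accuracy only if all batches have the same size. The plain-Python sketch below uses two made-up batches of different sizes to show the difference; `toy_accuracy` is a hypothetical helper.

```python
def toy_accuracy(logits, labels):
    """Fraction of rows whose arg-max logit equals the integer label."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


# Two made-up validation batches of different sizes: (logits, labels).
toy_eval_batches = [
    ([[2.0, 1.0], [0.0, 3.0], [5.0, -1.0], [0.1, 0.2]], [0, 1, 0, 1]),  # 4 of 4 correct
    ([[1.0, 0.0]], [1]),                                                # 0 of 1 correct
]
batch_accs = [toy_accuracy(lg, lb) for lg, lb in toy_eval_batches]
batch_sizes = [len(lb) for _, lb in toy_eval_batches]

mean_of_batches = sum(batch_accs) / len(batch_accs)
weighted = sum(a * n for a, n in zip(batch_accs, batch_sizes)) / sum(batch_sizes)
print(mean_of_batches, weighted)  # 0.5 0.8
```

With a fixed batch size, only the last (possibly smaller) batch is affected, so the gap is usually small; weighting by batch size removes it entirely.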

In [None]:
# Initialise network and optimiser; note we draw an input to get shapes.
sample = proc_datasets["train"][0]
input_ids, attention_mask = map(
    jnp.array, (sample["input_ids"], sample["attention_mask"])
)
rng = next(rng_iter)
init_params = sentiment_classifier.init(rng, input_ids, attention_mask, True)

optimizer = optax.chain(
    optax.clip_by_global_norm(GRAD_CLIP_VALUE),
    optax.adam(learning_rate=5e-5),
)
init_opt_state = optimizer.init(init_params)

# initialize the training state class
state = TrainingState(init_params, init_opt_state)

In [None]:
# Training & evaluation loop.

EPOCHS = 10
EVAL_STEPS = 500
LOG_STEPS = 100

writer = SummaryWriter()
pbar = tqdm(desc="Train step", total=EPOCHS * len(train_loader))
step = 0
loop_metrics = {"train_loss": None, "eval_loss": None}
best_eval_loss = float("inf")
best_eval_acc = float("-inf")

for epoch in range(EPOCHS):

    for batch in train_loader:
        # Run one optimization step (Adam with gradient clipping) on a batch of training examples.
        state, metrics, rng_key = train_step(state, batch, rng_key)
        loop_metrics.update(metrics)
        pbar.update(1)
        step += 1

        if step % EVAL_STEPS == 0:
            metrics = list()
            for batch in tqdm(
                valid_loader, desc="Eval", total=len(valid_loader), leave=False
            ):
                metrics.append(eval_step(state.params, batch))

            eval_loss = jnp.array([m[0] for m in metrics]).mean()
            eval_acc = jnp.array([m[1] for m in metrics]).mean()
            loop_metrics["eval_loss"] = eval_loss
            loop_metrics["eval_acc"] = eval_acc

            writer.add_scalar("Loss/valid", loop_metrics["eval_loss"].item(), step)
            writer.add_scalar("Acc/valid", loop_metrics["eval_acc"].item(), step)

            if eval_acc.item() > best_eval_acc:
                best_eval_loss = eval_loss.item()
                best_eval_acc = eval_acc.item()
                best_eval_ckpt = f"sentiment_class_state_{step}.pkl"

                print(best_eval_acc, best_eval_ckpt)
                # Save the training state (including the params) to disk
                with open(best_eval_ckpt, "wb") as fp:
                    pickle.dump(state, fp)

        if step % LOG_STEPS == 0:
            writer.add_scalar("Loss/train", loop_metrics["train_loss"].item(), step)

        pbar.set_postfix(loop_metrics)

### Evaluate the sentiment classification model

After training, we can classify the entire evaluation set using the best checkpoint available and compute the classification report.

Set the variable `model_checkpoint_path` to choose which checkpoint to evaluate.

In [None]:
model_checkpoint_path = "..."
with open(model_checkpoint_path, "rb") as fp:
    state = pickle.load(fp)


def classify(params, tokenizer, batch):
    """Classify a batch of text."""
    logits = deterministic_forward(params, batch) # (B,2)
    return logits.argmax(-1)


y_pred = list()
y_true = list()
for batch in tqdm(valid_loader, desc="Eval", total=len(valid_loader), leave=False):
    out = classify(state.params, tokenizer, batch)
    y_pred.extend(out)
    y_true.extend(batch["labels"])

clf_report = classification_report(y_true, y_pred)
print("\n Accuracy: ", accuracy_score(y_true, y_pred))
print("\nClassification Report")
print("======================================================")
print("\n", clf_report)

### Randomly initialized 👶 vs. pre-trained 🏋🏼 model

The Transformer encoder trained above is randomly initialized and trained from scratch. However, much of the success of Transformers in NLP comes from their strong performance when fine-tuned from a pre-trained model. Starting from pre-trained weights lets us reuse what the model learned from the large amount of data seen during pre-training for the task we are interested in.

As a result, we can fine-tune the Transformer encoder starting from the pre-trained weights obtained with the MLM objective. The following cell loads the weights of the pre-trained Transformer into the new model we are training for sentiment classification.

Although the model architecture is similar, the sentiment classification model has an additional Linear layer on top of the Transformer, which is randomly initialized. Therefore, we do not (and cannot) load the weights of this additional layer into the model.
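
The idea of transferring only the shared weights can be sketched with plain dictionaries, since Haiku parameters are organized as a mapping from module names to weights. This is a minimal illustration under assumptions: the module names and values are made up, and `transfer_matching_weights` is a hypothetical helper, not part of the tutorial code.

```python
def transfer_matching_weights(pretrained, target):
    """Copy every module found in both trees; leave the rest untouched."""
    shared = {name: weights for name, weights in pretrained.items() if name in target}
    merged = dict(target)
    merged.update(shared)
    return merged, sorted(shared)


# Hypothetical module names; real Haiku names depend on how modules are defined.
pretrained_params = {
    "encoder/layer_0": {"w": [0.5, 0.5]},
    "mlm_head": {"w": [0.9]},  # only exists in the pre-trained model
}
classifier_params = {
    "encoder/layer_0": {"w": [0.0, 0.0]},  # random init, will be overwritten
    "classifier_head": {"w": [0.1]},  # new layer, stays randomly initialized
}

merged_params, copied = transfer_matching_weights(pretrained_params, classifier_params)
print("copied modules:", copied)
print("classifier head untouched:", merged_params["classifier_head"])
```

Matching by name only works if both models create their modules with the same names and shapes, which is why the hyper-parameters must agree with those used for pre-training.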

If you compare the performance of the two models, *which one can reach the highest accuracy?*

You can use either:
- a model that we have already pre-trained for you on large English corpora, or
- the model pre-trained in the previous section


**Use our pre-trained model:** Since everything in Haiku is stateless, you can download the `TrainingState` (which contains the model parameters) [here](LINK) and use it in your subsequent `apply` call.
You should also load the pre-trained tokenizer paired with the model in this case.

You will need to set the hyper-parameters to match the ones we used to train it:

```python
NUM_LAYERS = 6
NUM_HEADS = 8
D_MODEL = 128
D_FF = 256
P_DROPOUT = 0.1
MAX_SEQ_LEN = 128
VOCAB_SIZE = 25000
```
Run the cell below to download the files.

**Your own pre-training**: Similarly, you can load the best checkpoint you saved in the previous section.


In [None]:
%%capture
# download pre-trained model and tokenizer
!wget https://huggingface.co/morenolq/m2l_2022_nlp/resolve/main/v1_mlm_train_state_362000.pkl
!wget https://huggingface.co/morenolq/m2l_2022_nlp/raw/main/v1_en_tokenizer_1M.json

In [None]:
# model parameters
NUM_LAYERS = 6
NUM_HEADS = 8
D_MODEL = 128
D_FF = 256
P_DROPOUT = 0.1
MAX_SEQ_LEN = 128
VOCAB_SIZE = 25_000

In [None]:
# Initialise network and optimiser; note we draw an input to get shapes.
sample = proc_datasets["train"][0]
input_ids, attention_mask = map(
    jnp.array, (sample["input_ids"], sample["attention_mask"])
)
rng = next(rng_iter)
init_params = sentiment_classifier.init(rng, input_ids, attention_mask, True)

optimizer = optax.chain(
    optax.clip_by_global_norm(GRAD_CLIP_VALUE),
    optax.adam(learning_rate=5e-5),
)
init_opt_state = optimizer.init(init_params)

# initialize the training state class
state = TrainingState(init_params, init_opt_state)

# load the pre-trained model
pretrained_model_path = "v1_mlm_train_state_362000.pkl"

with open(pretrained_model_path, "rb") as fp:
    pretrained_state = pickle.load(fp)

# load the weights from the pre-trained model to the new model
encoder_weights = {
    k: v for k, v in pretrained_state.params.items() if k in state.params
}
print("Found", len(encoder_weights), "pretrained weights")
state.params.update(encoder_weights)

# load pre-trained tokenizer
pretrained_tokenizer_path = "v1_en_tokenizer_1M.json"
tokenizer = tokenizers.Tokenizer.from_file(pretrained_tokenizer_path)
tokenizer.enable_truncation(MAX_SEQ_LEN)
tokenizer.enable_padding(length=MAX_SEQ_LEN)

At this point, it is possible to run the training process using the pre-trained weights as a starting point and run the final evaluation to check the model's performance.

In [None]:
EPOCHS = 10
EVAL_STEPS = 500
LOG_STEPS = 100

writer = SummaryWriter()
pbar = tqdm(desc="Train step", total=EPOCHS * len(train_loader))
step = 0
loop_metrics = {"train_loss": None, "eval_loss": None}
best_eval_loss = float("inf")
best_eval_acc = float("-inf")

for epoch in range(EPOCHS):

    for batch in train_loader:
        # Run one optimization step on a batch of training examples.
        state, metrics, rng_key = train_step(state, batch, rng_key)
        loop_metrics.update(metrics)
        pbar.update(1)
        step += 1

        if step % EVAL_STEPS == 0:
            metrics = list()
            for batch in tqdm(
                valid_loader, desc="Eval", total=len(valid_loader), leave=False
            ):
                metrics.append(eval_step(state.params, batch))

            eval_loss = jnp.array([m[0] for m in metrics]).mean()
            eval_acc = jnp.array([m[1] for m in metrics]).mean()
            loop_metrics["eval_loss"] = eval_loss
            loop_metrics["eval_acc"] = eval_acc

            writer.add_scalar("Loss/valid", loop_metrics["eval_loss"].item(), step)
            writer.add_scalar("Acc/valid", loop_metrics["eval_acc"].item(), step)

            if eval_acc.item() > best_eval_acc:
                best_eval_loss = eval_loss.item()
                best_eval_acc = eval_acc.item()
                best_eval_ckpt = f"ft_sentiment_class_state_{step}.pkl"

                print(best_eval_acc, best_eval_ckpt)
                # Save the training state (including the params) to disk
                with open(best_eval_ckpt, "wb") as fp:
                    pickle.dump(state, fp)

        if step % LOG_STEPS == 0:
            writer.add_scalar("Loss/train", loop_metrics["train_loss"].item(), step)

        pbar.set_postfix(loop_metrics)

### Evaluate the sentiment classification model

At the end of the training process, we can run the full evaluation of the best checkpoint available and print the classification report.

In [None]:
model_checkpoint_path = "..."
with open(model_checkpoint_path, "rb") as fp:
    state = pickle.load(fp)


y_pred = list()
y_true = list()
for batch in tqdm(valid_loader, desc="Eval", total=len(valid_loader), leave=False):
    out = classify(state.params, tokenizer, batch)
    y_pred.extend(out)
    y_true.extend(batch["labels"])

clf_report = classification_report(y_true, y_pred)
print("\n Accuracy: ", accuracy_score(y_true, y_pred))
print("\nClassification Report - Pre-Training + Fine-Tuning")
print("======================================================")
print("\n", clf_report)

# 5️⃣ Transformer and Neural Machine Translation 

Until now, we focused entirely on the Encoder part of the architecture. As we have seen, one can build a Language Model that can work as a generic "encoder" of tokens and then fine-tune it.

However, as we examined in Section 1 of the tutorial, the Transformer architecture also has a **Decoder** part, i.e., a neural network that can "decode" some sequence into something else.\*

Indeed, the original application of the Transformer was Neural Machine Translation (NMT, or MT). In this task, the network is presented with samples composed of a *source sentence* in a given language, say English, and a *target sentence* in a different language, say Italian, as the result of the translation.

In this part of the tutorial, we will conclude the Transformer implementation by sketching the Transformer Decoder. Then, we will test the complete architecture in the task of MT. Specifically, we will train an MT system from scratch on the TatoEBA dataset. 

Before we start coding, let's review how the Transformer learns to map a source sequence into a target sequence.


\* *There exist some language models trained to predict the next word in a sentence. These models are "decoder-only" as their only purpose is to "decode" a given context (the words already seen) into the subsequent token. You might have heard about some of them: GPT-\*, OPT, or BLOOM.*

## Mapping Sources ➡️ and Targets ⬅️

But how does the Transformer work? How do we train it in a sequence-to-sequence setup, and how do we use it to transform a source into a target sequence? Let's cover these questions before turning to the actual code.

**How does it work?**

The Transformer processes a source sequence producing a *contextualized, dense vector representation* of each token of the sequence. This step is what you implemented in Sections 2️⃣ and 3️⃣ with the Encoder.

At the same time, the model processes the target sequence with the Decoder. However, the Decoder does not work in isolation: we inject the source sequence and let the model learn a mapping between source and target. This operation is what you will implement in the remainder of this section.

**How do I train it?**

We provide the model with both the source and target sequences at training time. The Encoder contextualizes the source tokens, and the Decoder **distills** this information in the target sequence representations, using a **cross-attention layer**. Additionally, each token in the Decoder is processed in an **autoregressive/masked** fashion, meaning that it expresses an attention weight only to previous/past tokens.

☝️ These *source distillation* and *autoregressive/masked attention* features are the two crucial differences between the Encoder and Decoder.
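
To see what the autoregressive mask does to the attention weights, here is a minimal sketch in plain NumPy (used instead of `jax.numpy` only to keep the example light; the array operations have direct `jnp` counterparts). The function name `masked_attention_weights` and the random scores are illustrative assumptions, not part of the tutorial code.

```python
import numpy as np


def masked_attention_weights(scores, mask):
    """Softmax over the last axis, ignoring positions where mask == 0."""
    scores = np.where(mask == 1, scores, -1e9)  # masked scores get ~zero weight
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(scores)
    return exp / exp.sum(axis=-1, keepdims=True)


seq_len = 4
rng = np.random.default_rng(0)
scores = rng.normal(size=(seq_len, seq_len))  # toy query-key similarity scores
causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=int))  # 1 = may attend

weights = masked_attention_weights(scores, causal_mask)
print(np.round(weights, 2))
```

Each row still sums to one, but everything above the diagonal is zero: the token at position *i* distributes its attention only over itself and the tokens before it.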


**How do I use it to go from source to target?**

Once the model is trained, you have to "decode" the source sentence. You typically do that by:
1. encoding the source, as usual;
2. producing one token at a time. Each new token is obtained by running self-attention on the previously decoded tokens and cross-attention on the source sentence.

See the animation below for a graphical representation of decoding.

![mt_example](https://github.com/g8a9/graphics/blob/main/mt_example.gif?raw=true)
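
The decoding procedure can be sketched as a simple greedy loop. This is a toy illustration under assumptions: `next_token_fn` stands in for a full decoder forward pass followed by an argmax, the token ids for `BOS` and `EOS` are made up, and the "memory" is a plain list rather than encoder outputs.

```python
from typing import Callable, List

BOS, EOS = 0, 1  # hypothetical special-token ids


def greedy_decode(
    next_token_fn: Callable[[List[int], List[int]], int],
    memory: List[int],
    max_len: int,
) -> List[int]:
    """Generate target tokens one at a time until EOS or max_len."""
    decoded = [BOS]
    while len(decoded) < max_len:
        # In a real model: run the decoder on `decoded` with cross-attention
        # over `memory`, then take the argmax of the last position's logits.
        token = next_token_fn(memory, decoded)
        decoded.append(token)
        if token == EOS:
            break
    return decoded


def toy_next_token(memory: List[int], decoded: List[int]) -> int:
    """Stand-in for the decoder: copies the source, then emits EOS."""
    position = len(decoded) - 1  # number of tokens generated so far
    return memory[position] if position < len(memory) else EOS


source_memory = [7, 8, 9]  # stands in for the encoded source sentence
print(greedy_decode(toy_next_token, source_memory, max_len=10))
```

Note that the source is encoded once and reused at every step, while the decoder input grows by one token per iteration.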


## The Transformer Decoder

As with the encoder, the decoder is a stack of identical blocks, so let's first define a single decoder block. You can refer to the image (right) in Section 1 to see all the components we need to code to make the decoder work.

The Decoder block is composed of:
1. a Masked Multi-Head Attention layer. Here, "masked" refers to the autoregressive property of self-attention in the decoder. Specifically, we want the token at position *i* to express a non-zero attention weight only to itself and to the preceding tokens. But we don't have to take care of it now: the trick 💡 is to use a particular attention mask, which we will see later.
2. a Cross-Attention layer. Here is where the magic happens. The Queries come from the decoder, while the Keys and Values come from the encoder output. We will call the encoder output "memory."
3. a Feed-Forward layer with element-wise non-linear activation.
4. Skip connections and Layer Normalization after each sub-layer.
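
To make the roles in cross-attention concrete, here is a minimal single-head sketch in NumPy without learned projections. It is an illustration under assumptions (the function name, shapes, and random inputs are made up), not the `MultiheadAttention` module used in this tutorial: queries are the decoder states, while keys and values are both the encoder memory.

```python
import numpy as np


def cross_attention(x, memory, src_mask):
    """Single-head cross-attention without learned projections.

    x:        (T, d) decoder states, used as queries
    memory:   (S, d) encoder outputs, used as keys and values
    src_mask: (S,)   1 for real source tokens, 0 for padding
    """
    d = x.shape[-1]
    scores = x @ memory.T / np.sqrt(d)  # (T, S)
    scores = np.where(src_mask[None, :] == 1, scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ memory, weights  # (T, d), (T, S)


rng = np.random.default_rng(0)
tgt_len, src_len, d_model = 3, 5, 8
decoder_states = rng.normal(size=(tgt_len, d_model))
encoder_memory = rng.normal(size=(src_len, d_model))
source_mask = np.array([1, 1, 1, 0, 0])  # last two source positions are padding

out, attn = cross_attention(decoder_states, encoder_memory, source_mask)
print(out.shape, attn.shape)
```

The output has one vector per target position, each a weighted average of source representations, and the source padding mask keeps padded source tokens out of that average.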

### Decoder Block \[EXERCISE 📝\]

In [None]:
class DecoderBlock(hk.Module):
    """
    Transformer decoder block.

    :param d_model: dimension of the model.
    :param num_heads: number of attention heads.
    :param d_ff: dimension of the feedforward network model.
    :param p_dropout: dropout rate.
    """

    def __init__(self, d_model, num_heads, d_ff, p_dropout, name=None):
        super().__init__(name=name)
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.p_dropout = p_dropout

        # self-attention sub-layer
        self.self_attn = MultiheadAttention(
            d_model=self.d_model, num_heads=self.num_heads
        )
        # src-target cross-attention sub-layer
        self.cross_attn = MultiheadAttention(
            d_model=self.d_model, num_heads=self.num_heads
        )
        # positionwise feedforward sub-layer
        self.ff = PositionwiseFeedForward(
            d_model=self.d_model, d_ff=self.d_ff, p_dropout=self.p_dropout
        )
        self.norm1 = hk.LayerNorm(
            axis=-1, param_axis=-1, create_scale=True, create_offset=True
        )
        self.norm2 = hk.LayerNorm(
            axis=-1, param_axis=-1, create_scale=True, create_offset=True
        )
        self.norm3 = hk.LayerNorm(
            axis=-1, param_axis=-1, create_scale=True, create_offset=True
        )

    def __call__(self, x, memory, src_mask, tgt_mask, is_train):
        """
        The forward pass of the decoder block.
        
        :param x: the input sequence for the decoder block.
        :param memory: the memory from the encoder.
        :param src_mask: the mask for the src sequence.
        :param tgt_mask: the mask for the tgt sequence.
        :param is_train: boolean flag to indicate training mode.
        :return: the output of the decoder block.
        """
        
        """EXERCISE!"""
        # self-attention sub-layer
        sub_x, _ = self.self_attn(x, x, x, tgt_mask)
        if is_train:
            sub_x = hk.dropout(hk.next_rng_key(), self.p_dropout, sub_x)
        x = self.norm1(x + sub_x)  # residual conn
        # cross-attention sub-layer
        sub_x, _ = self.cross_attn(x, memory, memory, src_mask)
        if is_train:
            sub_x = hk.dropout(hk.next_rng_key(), self.p_dropout, sub_x)
        x = self.norm2(x + sub_x)
        # feedforward sub-layer
        sub_x = self.ff(x, is_train=is_train)
        if is_train:
            sub_x = hk.dropout(hk.next_rng_key(), self.p_dropout, sub_x)
        x = self.norm3(x + sub_x)

        return x

In [None]:
class TransformerDecoder(hk.Module):
    """
    The Transformer decoder model.
    
    :param num_layers: number of decoder layers.
    :param num_heads: number of attention heads.
    :param d_model: dimension of the model.
    :param d_ff: dimension of the feedforward network model.
    :param p_dropout: dropout rate.
    """
    
    def __init__(self, num_layers, num_heads, d_model, d_ff, p_dropout, name=None):
        super().__init__(name=name)
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_ff = d_ff
        self.p_dropout = p_dropout

        self.layers = [
            DecoderBlock(self.d_model, self.num_heads, self.d_ff, self.p_dropout)
            for _ in range(self.num_layers)
        ]

    def __call__(self, x, memory, src_mask, tgt_mask, is_train):
        """
        The forward pass of the decoder.
        
        :param x: the input sequence for the decoder.
        :param memory: the memory from the encoder.
        :param src_mask: the mask for the src sequence.
        :param tgt_mask: the mask for the tgt sequence.
        :param is_train: boolean flag to indicate training mode.
        :return: the output of the transformer decoder.
        """

        """EXERCISE"""
        for l in self.layers:
            x = l(x, memory, src_mask, tgt_mask, is_train)
        return x

In [None]:
class Transformer(hk.Module):
    """
    Complete Transformer model including encoder and decoder.
    
    :param d_model: dimension of the model.
    :param d_ff: dimension of the feedforward network model.
    :param src_vocab_size: size of the source vocabulary.
    :param tgt_vocab_size: size of the target vocabulary.
    :param num_layers: number of encoder and decoder layers.
    :param num_heads: number of attention heads.
    :param p_dropout: dropout rate.
    :param max_seq_len: maximum sequence length.
    """
    def __init__(
        self,
        d_model,
        d_ff,
        src_vocab_size,
        tgt_vocab_size,
        num_layers,
        num_heads,
        p_dropout,
        max_seq_len,
        name=None,
        tie_embeddings=False,
    ):
        
        super().__init__(name=name)

        self.d_model = d_model
        self.d_ff = d_ff
        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.p_dropout = p_dropout
        self.max_seq_len = max_seq_len

        self.src_emb = Embeddings(d_model, src_vocab_size)
        if tie_embeddings:
            self.tgt_emb = self.src_emb
        else:
            self.tgt_emb = Embeddings(d_model, tgt_vocab_size)
        self.encoder = TransformerEncoder(
            num_layers, num_heads, d_model, d_ff, p_dropout
        )
        self.decoder = TransformerDecoder(
            num_layers, num_heads, d_model, d_ff, p_dropout
        )

    def encode(self, src, src_mask, is_train):
        """
        The forward pass for the encoder.
        
        :param src: the source sequence.
        :param src_mask: the mask for the src sequence.
        :param is_train: boolean flag to indicate training mode.
        :return: the encoded sequence.
        """

        """EXERCISE"""
        pe = PositionalEncoding(self.d_model, self.max_seq_len, self.p_dropout)
        src = self.src_emb(src)
        src = src[None, :, :] if len(src.shape) == 2 else src
        src = pe(src, is_train=is_train)
        return self.encoder(src, src_mask, is_train)

    def decode(self, memory, src_mask, tgt, tgt_mask, is_train):
        """
        The forward pass for the decoder.
        
        :param memory: the memory from the encoder.
        :param src_mask: the mask for the src sequence.
        :param tgt: the target sequence.
        :param tgt_mask: the mask for the tgt sequence.
        :param is_train: boolean flag to indicate training mode.
        :return: the output of the decoder.
        """
        """EXERCISE"""
        pe = PositionalEncoding(self.d_model, self.max_seq_len, self.p_dropout)
        tgt = self.tgt_emb(tgt)
        tgt = tgt[None, :, :] if len(tgt.shape) == 2 else tgt
        tgt = pe(tgt, is_train=is_train)
        return self.decoder(tgt, memory, src_mask, tgt_mask, is_train)

    def __call__(self, src, src_mask, tgt, tgt_mask, is_train):
        """
        The forward pass of the whole transformer model.
        
        :param src: the source sequence.
        :param src_mask: the mask for the src sequence.
        :param tgt: the target sequence.
        :param tgt_mask: the mask for the tgt sequence.
        :param is_train: boolean flag to indicate training mode.
        :return: the output of the transformer model (encoder + decoder).
        """
        memory = self.encode(src, src_mask, is_train)
        return self.decode(memory, src_mask, tgt, tgt_mask, is_train)

## Preparation for the MT task

With the Transformer class ready, we can now take our time to go through all the required steps in preparation for the actual training. As you have already seen in Section 3️⃣, we mainly need to:
1. pick a dataset. We need a parallel corpus, where each sample is made of a source and a target sentence;
2. train a tokenizer;
3. preprocess our corpus using the tokenizer.

### Dataset selection

We will use two well-known datasets to train an English-to-Italian translation system.

- [TatoEBA](https://opus.nlpl.eu/Tatoeba.php) is a crowdsourced dataset of sentences annotated by users on the [website](https://tatoeba.org/en/) of the same name;
- [Europarl](https://www.statmt.org/europarl/) is a corpus of proceedings of the European Parliament.

We demonstrate the training and provide a few translation examples on TatoEBA since it is smaller and easier to train on. However, you can also download the Europarl dataset by executing the cell below and proceed in the same way (whether training on Europarl is feasible depends on your computing capacity).

In [None]:
%%capture
# Skip this if you are running on Colab or on a low-end GPU or CPU.
raw_datasets = load_dataset("g8a9/europarl_en-it")

### Train tokenizer for the Machine Translation task

Here, we opt for training a single tokenizer shared by both languages, with twice the vocabulary size of the one used for language modeling.

Feel free to test your solution with two different tokenizers (one per language).

In [None]:
# Target tokenizer (SRC+TGT language)
VOCAB_SIZE = 20_000

BATCH_SIZE = 64
NUM_LAYERS = 6
NUM_HEADS = 8
D_MODEL = 128
D_FF = 256
P_DROPOUT = 0.1
MAX_SEQ_LEN = 128
LEARNING_RATE = 3e-4
GRAD_CLIP_VALUE = 1

In [None]:
# Loading TatoEBA
df = pd.read_csv(
    "it-en.tsv", sep="\t", header=0, names=["id_it", "sent_it", "id_en", "sent_en"]
)

In [None]:
# we use both the Italian and the English sentences to train a shared tokenizer
it_sentences = df["sent_it"].drop_duplicates().dropna()
en_sentences = df["sent_en"].drop_duplicates().dropna()
print(f"Unique Italian sentences: {len(it_sentences)}")
print("Samples:\n", it_sentences[:5])

# we'll use BPE
mt_tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
mt_tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
mt_tokenizer.normalizer = tokenizers.normalizers.Lowercase()

trainer = tokenizers.trainers.BpeTrainer(
    special_tokens=["[PAD]", "[UNK]", "[BOS]", "[EOS]"],
    vocab_size=VOCAB_SIZE,
    show_progress=True,
    min_frequency=2,
    continuing_subword_prefix="##",
)

mt_tokenizer.train_from_iterator(
    it_sentences.tolist() + en_sentences.tolist(), trainer=trainer
)

bos_id, eos_id = map(mt_tokenizer.token_to_id, ["[BOS]", "[EOS]"])
mt_tokenizer.post_processor = tokenizers.processors.BertProcessing(
    ("[EOS]", eos_id), ("[BOS]", bos_id)
)
mt_tokenizer.enable_truncation(MAX_SEQ_LEN)
mt_tokenizer.enable_padding(length=MAX_SEQ_LEN)

PAD_ID = mt_tokenizer.token_to_id("[PAD]")

In [None]:
mt_tokenizer.save("mt_tokenizer.json")

Use the cell below if you want instead to load the tokenizer from disk.

In [None]:
mt_tokenizer = tokenizers.Tokenizer.from_file("mt_tokenizer.json")
mt_tokenizer.enable_truncation(MAX_SEQ_LEN)
mt_tokenizer.enable_padding(length=MAX_SEQ_LEN)

### Process and tokenize MT data

As it is not the tutorial's focus, we again provide the code to run the basic preprocessing using `datasets`. Feel free to inspect it to better understand every step related to tokenization and data preparation.

In [None]:
DATASET_SAMPLE = 0.1  # @param {type:"number"}

# generate parallel data
mt_df = df.sample(frac=DATASET_SAMPLE, random_state=42)

train_df_mt, test_df_mt = train_test_split(mt_df, test_size=0.2, random_state=42)
val_df_mt, test_df_mt = train_test_split(test_df_mt, test_size=0.5, random_state=42)
print("Train", train_df_mt.shape, "Valid", val_df_mt.shape, "Test", test_df_mt.shape)

raw_datasets = DatasetDict(
    {
        "train": Dataset.from_dict(
            {
                "sent_en": train_df_mt["sent_en"].tolist(),
                "sent_it": train_df_mt["sent_it"].tolist(),
            }
        ),
        "valid": Dataset.from_dict(
            {
                "sent_en": val_df_mt["sent_en"].tolist(),
                "sent_it": val_df_mt["sent_it"].tolist(),
            }
        ),
        "test": Dataset.from_dict(
            {
                "sent_en": test_df_mt["sent_en"].tolist(),
                "sent_it": test_df_mt["sent_it"].tolist(),
            }
        ),
    }
)


def preprocess(examples: Dict[str, List[str]]) -> Dict[str, List[str]]:
    src = mt_tokenizer.encode_batch(examples["sent_en"], add_special_tokens=False)
    tgt = mt_tokenizer.encode_batch(examples["sent_it"], add_special_tokens=True)

    return {
        "src_ids": [o.ids for o in src],
        "src_mask": [o.attention_mask for o in src],
        "src_special_tokens_mask": [o.special_tokens_mask for o in src],
        "tgt_ids": [o.ids for o in tgt],
    }


proc_datasets = raw_datasets.map(
    preprocess, batched=True, batch_size=4000, remove_columns=["sent_en", "sent_it"]
)

print("First training sample, after processing:", proc_datasets["train"][0])

### Utility functions and data structures \[EXERCISE 📝\]

Let's define a few functions that will be useful later.

In [None]:
def subsequent_mask(S: int):
    """Mask out subsequent positions.
    
    Given an integer `S`, generate a `1xSxS` matrix containing the attention mask to apply to the sequence.
    The matrix should implement autoregressive attention (left-context attention), i.e., for each token at position 'i', it should mask every token at a later position, 'i'+1, ..., S-1.
    
    E.g. 

    MAX_LEN = 8
    SEQ_LEN = 5

    Encoder attention mask:

    [ [1, 1, 1, 1, 1, 0, 0, 0, ]
    [1, 1, 1, 1, 1, 0, 0, 0, ]
    [1, 1, 1, 1, 1, 0, 0, 0, ]
    [1, 1, 1, 1, 1, 0, 0, 0, ]
    [1, 1, 1, 1, 1, 0, 0, 0, ]
    [1, 1, 1, 1, 1, 0, 0, 0, ]
    [1, 1, 1, 1, 1, 0, 0, 0, ]
    [1, 1, 1, 1, 1, 0, 0, 0, ] ]

    Decoder attention mask:

    [ [1, 0, 0, 0, 0, 0, 0, 0, ]
    [1, 1, 0, 0, 0, 0, 0, 0, ]
    [1, 1, 1, 0, 0, 0, 0, 0, ]
    [1, 1, 1, 1, 0, 0, 0, 0, ]
    [1, 1, 1, 1, 1, 0, 0, 0, ]
    [1, 1, 1, 1, 1, 1, 0, 0, ]
    [1, 1, 1, 1, 1, 1, 1, 0, ]
    [1, 1, 1, 1, 1, 1, 1, 1, ] ]

    EXERCISE
    """
    attn_shape = (1, S, S)
    subsequent_mask = jnp.triu(jnp.ones(attn_shape), k=1).astype(jnp.uint8)
    return jnp.where(subsequent_mask == 0, 1, 0)

In [None]:
def collate_fn_mt(batch) -> dict:
    """Collate source and target sequences in the batch.

    We also need to define a 'labels' variable.

    You want to produce the following shapes:
    - src: (B,MAX_SEQ_LEN)
    - src_mask: (B,1,MAX_SEQ_LEN)
    - tgt: (B,MAX_SEQ_LEN-1)
    - tgt_mask: (B,MAX_SEQ_LEN-1,MAX_SEQ_LEN-1)
    - labels: (B,MAX_SEQ_LEN-1)

    EXERCISE
    """
    src = jnp.array([s["src_ids"] for s in batch])
    src_mask = jnp.array([s["src_mask"] for s in batch])
    src_mask = jnp.expand_dims(src_mask, 1)

    tgt_seq = jnp.array([s["tgt_ids"] for s in batch])
    tgt = tgt_seq[:, :-1]  # (B,MAX_SEQ_LEN-1)
    labels = tgt_seq[:, 1:]  # (B,MAX_SEQ_LEN-1)

    tgt_pad = jnp.where(jnp.expand_dims(tgt, axis=1) != PAD_ID, 1, 0)
    tgt_mask = jnp.where(tgt_pad & subsequent_mask(tgt.shape[-1]), 1, 0)

    item = {
        "src": src,
        "src_mask": src_mask,
        "tgt": tgt,
        "tgt_mask": tgt_mask,
        "labels": labels,
    }
    return item


train_loader_mt = DataLoader(
    proc_datasets["train"], batch_size=BATCH_SIZE, collate_fn=collate_fn_mt
)
valid_loader_mt = DataLoader(
    proc_datasets["valid"], batch_size=BATCH_SIZE, collate_fn=collate_fn_mt
)
test_loader_mt = DataLoader(
    proc_datasets["test"], batch_size=BATCH_SIZE, collate_fn=collate_fn_mt
)

print(
    f"Batches Train: {len(train_loader_mt)}",
    f"Valid: {len(valid_loader_mt)}",
    f"Test : {len(test_loader_mt)}",
)
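
The shifting and masking done in the collate function can be checked on a single toy sequence. The sketch below uses NumPy and made-up token ids (the values of `PAD`, `BOS`, and `EOS` are assumptions for illustration) to show how the decoder input, the labels, and the combined padding-plus-causal mask relate to each other.

```python
import numpy as np

PAD = 0  # hypothetical padding id
BOS, EOS = 2, 3  # hypothetical special-token ids

# One padded target sequence: [BOS] w1 w2 [EOS] [PAD] [PAD]
tgt_seq = np.array([[BOS, 11, 12, EOS, PAD, PAD]])

decoder_input = tgt_seq[:, :-1]  # what the decoder sees
labels = tgt_seq[:, 1:]  # what it must predict at each position

length = decoder_input.shape[-1]
causal = np.tril(np.ones((1, length, length), dtype=int))  # (1, L, L)
not_pad = (decoder_input != PAD)[:, None, :].astype(int)  # (B, 1, L)
tgt_mask = causal & not_pad  # broadcasts to (B, L, L)

print("decoder input:", decoder_input[0])
print("labels:       ", labels[0])
print(tgt_mask[0])
```

At every position the label is simply the next token of the same sequence, and the mask is lower triangular with the columns of padded positions zeroed out.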

## Training a Neural Machine Translation Model 🇬🇧 -> 🇮🇹

### Defining the model transformation \[EXERCISE 📝\]

We can now define the Transformer model that will be used to translate from English to Italian. We also define the criterion (loss function) we can use to train the MT model. Similarly to the Encoder model, we will use the Cross-Entropy loss, but we need to compute it across all target words.
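
Computing the loss across all target words also means skipping the padded positions, otherwise the model would be rewarded for predicting padding. A minimal NumPy sketch of such a masked cross-entropy is shown below. The function name, the padding id, and the toy inputs are assumptions for illustration.

```python
import numpy as np

PAD = 0  # hypothetical padding id


def masked_cross_entropy(logits, labels, pad_id=PAD):
    """Mean token-level cross-entropy, ignoring padded target positions.

    logits: (B, T, V) unnormalized scores, labels: (B, T) integer token ids.
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability assigned to the correct token at each position.
    token_nll = -np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    keep = labels != pad_id
    return (token_nll * keep).sum() / keep.sum()


vocab_size = 5
labels = np.array([[3, 4, PAD, PAD]])  # two real tokens, two padded positions
logits = np.zeros((1, 4, vocab_size))  # uniform predictions everywhere
print(masked_cross_entropy(logits, labels))
```

Dividing by the number of non-padding tokens, rather than by the full sequence length, keeps the loss comparable across batches with different amounts of padding.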

In [None]:
@hk.transform
def mt_model(src, src_mask, tgt, tgt_mask, is_train=True):
    """
    The machine translation model that relies on the encoder and decoder defined above.

    :param src: source sequences
    :param src_mask: source mask
    :param tgt: target sequences
    :param tgt_mask: target mask
    :param is_train: whether the model is in training mode or not
    :return: log-probabilities over the target vocabulary
    """
    
    """
    EXERCISE
    """
    model = Transformer(
        d_model=D_MODEL,
        d_ff=D_FF,
        src_vocab_size=VOCAB_SIZE,
        tgt_vocab_size=VOCAB_SIZE,
        num_layers=NUM_LAYERS,
        num_heads=NUM_HEADS,
        p_dropout=P_DROPOUT,
        max_seq_len=MAX_SEQ_LEN,
        tie_embeddings=True,
    )

    output_embs = model(src, src_mask, tgt, tgt_mask, is_train=is_train)

    # final projection onto the vocabulary
    out = hk.Linear(VOCAB_SIZE, name="decoder_final_linear")(output_embs)  # logits
    out = jax.nn.log_softmax(out)  # log-probabilities
    return out

In [None]:
def prepare_sample(tokenizer, src_text, tgt_text=None, max_seq_len=MAX_SEQ_LEN):
    """
    Prepare a sample for the model. 
    This function encodes the source and target sequences to generate the sentence pair.
    It also generates the attention masks for the source and target sequences.
    
    :param tokenizer: the tokenizer to use
    :param src_text: the source text
    :param tgt_text: the target text
    :param max_seq_len: the maximum sequence length
    :return: a tuple of (src_enc, src_mask, tgt_enc, tgt_mask) if tgt_text is not None, otherwise (src_enc, src_mask)
    """

    src_enc = tokenizer.encode(src_text, add_special_tokens=False)
    src = jnp.array([src_enc.ids])
    src_mask = jnp.expand_dims(jnp.array([src_enc.attention_mask]), 1)

    item = (src, src_mask)

    if tgt_text is not None:
        tgt_enc = tokenizer.encode(tgt_text, add_special_tokens=True)
        tgt = jnp.array([tgt_enc.ids])
        tgt_mask = subsequent_mask(max_seq_len)
        item += (tgt, tgt_mask)

    return item

In [None]:
# testing the MT model
src, src_mask, tgt, tgt_mask = prepare_sample(
    mt_tokenizer, "Hello my friend", "Ciao amico mio"
)

rng_key = next(rng_iter)
params = mt_model.init(rng_key, src, src_mask, tgt, tgt_mask, True)
logits = mt_model.apply(
    params,
    rng=rng_key,
    src=src,
    src_mask=src_mask,
    tgt=tgt,
    tgt_mask=tgt_mask,
    is_train=False,
)
print("Logits shape", logits.shape)

In [None]:
# testing loss functions
counter = 0
for batch in train_loader_mt:
    out = mt_model.apply(
        params=params,
        rng=rng_key,
        src=batch["src"],
        src_mask=batch["src_mask"],
        tgt=batch["tgt"],
        tgt_mask=batch["tgt_mask"],
        is_train=True,
    )

    labels = batch["labels"]
    loss = optax.softmax_cross_entropy_with_integer_labels(out, labels)
    loss = jnp.where(labels != PAD_ID, loss, 0.0)
    not_pad_count = (labels != PAD_ID).sum()

    print(loss.sum() / not_pad_count)

    if counter >= 1:
        break
    else:
        counter += 1

### Setup the training loop \[EXERCISE 📝\]

In [None]:
EPOCHS = 5  # @param {type:"number"}
EVAL_STEPS = 500  # @param {type:"number"}
LOG_STEPS = 200

In [None]:
# Initialise network and optimiser; note we draw an input to get shapes.
sample = proc_datasets["train"][0]
src, src_mask, tgt = map(
    jnp.array,
    (
        sample["src_ids"],
        sample["src_mask"],
        sample["tgt_ids"],
    ),
)
tgt_mask = subsequent_mask(MAX_SEQ_LEN)

rng_key = next(rng_iter)
init_params = mt_model.init(rng_key, src, src_mask, tgt, tgt_mask, True)

# We use learning rate scheduling / annealing
total_steps = EPOCHS * len(train_loader_mt)
schedule = optax.warmup_cosine_decay_schedule(
    init_value=1e-6,
    peak_value=LEARNING_RATE,
    warmup_steps=int(0.1 * total_steps),
    decay_steps=total_steps,
    end_value=1e-6,
)
optimizer = optax.chain(
    optax.clip_by_global_norm(GRAD_CLIP_VALUE),
    # Pass the schedule itself so the learning rate actually follows it.
    optax.adam(learning_rate=schedule),
)
init_opt_state = optimizer.init(init_params)

# initialize the training state class
state = TrainingState(init_params, init_opt_state)

In [None]:
def loss_fn_mt(params: hk.Params, batch, rng) -> jnp.ndarray:
    """
    The loss function for the machine translation model.
    The loss is the cross-entropy averaged over the non-padding tokens of the target sequence.
    
    :param params: the model parameters
    :param batch: the batch of data
    :param rng: the random number generator
    :return: the loss value
    """

    """EXERCISE"""
    logits = mt_model.apply(
        params=params,
        rng=rng,
        src=batch["src"],
        src_mask=batch["src_mask"],
        tgt=batch["tgt"],
        tgt_mask=batch["tgt_mask"],
        is_train=True,
    )

    labels = batch["labels"]
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    loss = jnp.where(labels != PAD_ID, loss, 0.0)
    not_pad_count = (labels != PAD_ID).sum()
    return loss.sum() / not_pad_count


@jax.jit
def train_step_mt(state, batch, rng_key) -> tuple:
    """
    The training step for the machine translation model.
    :param state: the state of the training
    :param batch: the batch of data
    :param rng_key: the random number generator
    :return: the new training state, the metrics (training loss) and the random number generator
    """
    # Keep `rng_key` for the next step and use the fresh sub-key `rng` for this one.
    rng_key, rng = jax.random.split(rng_key)

    loss_and_grad_fn = jax.value_and_grad(loss_fn_mt)
    loss, grads = loss_and_grad_fn(state.params, batch, rng)

    updates, opt_state = optimizer.update(grads, state.opt_state)
    params = optax.apply_updates(state.params, updates)

    new_state = TrainingState(params, opt_state)
    metrics = {"train_loss": loss}

    return new_state, metrics, rng_key


@jax.jit
def deterministic_forward(
    params: hk.Params, src, src_mask, tgt, tgt_mask
) -> jnp.ndarray:
    """
    The deterministic forward pass for the machine translation model.
    It leverages without_apply_rng to avoid the need for a random number generator.
    
    :param params: the model parameters
    :param src: the source sequences
    :param src_mask: the source mask
    :param tgt: the target sequences
    :param tgt_mask: the target mask
    :return: the logits
    """
    return hk.without_apply_rng(mt_model).apply(
        params=params,
        is_train=False,
        src=src,
        src_mask=src_mask,
        tgt=tgt,
        tgt_mask=tgt_mask,
    )


@jax.jit
def eval_step_mt(params: hk.Params, batch) -> jnp.ndarray:
    """
    The evaluation step for the machine translation model.
    :param params: the model parameters
    :param batch: the batch of data
    :return: the evaluation loss
    """
    logits = deterministic_forward(
        params, batch["src"], batch["src_mask"], batch["tgt"], batch["tgt_mask"]
    )
    labels = batch["labels"]

    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    loss = jnp.where(labels != PAD_ID, loss, 0.0)
    not_pad_count = (labels != PAD_ID).sum()
    return loss.sum() / not_pad_count
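
The training step threads a PRNG key through the loop. A minimal sketch of that pattern follows; the `step` function and the Bernoulli draw are hypothetical stand-ins for a training step that uses dropout.

```python
import jax

key = jax.random.PRNGKey(0)


def step(key):
    # Split once: carry `next_key` forward, consume `sub_key` in this step.
    next_key, sub_key = jax.random.split(key)
    # Stand-in for a stochastic operation such as a dropout keep-mask.
    noise = jax.random.bernoulli(sub_key, p=0.9, shape=(4,))
    return next_key, noise


key, noise_a = step(key)
key, noise_b = step(key)
print(noise_a, noise_b)
```

JAX keys are explicit values: reusing the same key reproduces the same random numbers, so each step should consume a key that is not used again afterwards.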

### Training and evaluation loop

In [None]:
writer = SummaryWriter()
pbar = tqdm(desc="Train step", total=EPOCHS * len(train_loader_mt))
step = 0
loop_metrics = {"train_loss": None, "eval_loss": None}
best_eval_loss = float("inf")

for epoch in range(EPOCHS):

    for batch in train_loader_mt:

        state, metrics, rng_key = train_step_mt(state, batch, rng_key)
        loop_metrics.update(metrics)
        pbar.update(1)
        step += 1

        if step % EVAL_STEPS == 0:
            ebar = tqdm(desc="Eval step", total=len(valid_loader_mt), leave=False)

            losses = list()
            for batch in valid_loader_mt:
                loss = eval_step_mt(state.params, batch)
                losses.append(loss)
                ebar.update(1)
            ebar.close()

            eval_loss = jnp.array(losses).mean()
            loop_metrics["eval_loss"] = eval_loss

            writer.add_scalar("Loss/valid", loop_metrics["eval_loss"].item(), step)

            if eval_loss.item() < best_eval_loss:
                best_eval_loss = eval_loss.item()
                # Save the training state (params and optimizer state) to disk
                with open(f"mt_train_state_{step}.pkl", "wb") as fp:
                    pickle.dump(state, fp)

        if step % LOG_STEPS == 0:
            writer.add_scalar("Loss/train", loop_metrics["train_loss"].item(), step)
            writer.add_scalar("lr/train", schedule(step).item(), step)
            writer.add_scalar("epoch/train", epoch, step)

        pbar.set_postfix(loop_metrics)

pbar.close()

### Implement Greedy Decoding \[EXERCISE 📝\]

We iteratively feed the partial output sequence through the model to generate the translation one token at a time. Specifically, we will use **greedy decoding**: at each step, we take the token with the highest score (logit) and append it to the output, until the end-of-sequence token is produced or the maximum length is reached.

Decoding strategies are a broad research topic, and we only touch on the most naive approach here. For a basic introduction to other generation strategies, please refer to [this blog post](https://huggingface.co/blog/how-to-generate).
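
The greedy loop can be sketched independently of the Transformer. The example below uses a hypothetical toy "model", a lookup table of next-token scores that only depends on the last token; the names `SCORES` and `greedy_decode` are illustrative and not part of the notebook.

```python
# Hypothetical toy "model": a table of next-token scores given the last token.
# A real decoder would condition on the source sentence and the whole prefix.
SCORES = {
    "[BOS]": {"ciao": 2.0, "amico": 0.5, "[EOS]": 0.1},
    "ciao": {"amico": 1.5, "mio": 0.3, "[EOS]": 0.2},
    "amico": {"mio": 1.2, "[EOS]": 0.4},
    "mio": {"[EOS]": 3.0, "ciao": 0.1},
}


def greedy_decode(scores, max_len=10):
    tokens = ["[BOS]"]
    for _ in range(max_len - 1):
        candidates = scores[tokens[-1]]
        # Greedy choice: keep only the single best-scoring next token.
        next_token = max(candidates, key=candidates.get)
        tokens.append(next_token)
        if next_token == "[EOS]":
            break
    return tokens


print(greedy_decode(SCORES))  # ['[BOS]', 'ciao', 'amico', 'mio', '[EOS]']
```

Greedy decoding never revisits a choice, so a locally best token can lead to a globally worse sentence; strategies such as beam search keep several candidates alive to mitigate this.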

**🤔 Switching to Europarl?**

By now, you should have an MT model trained on Tatoeba. We trained a similar model on Europarl for you.
You can now decide to continue with your model or load our pretrained one.

If you want to load the Europarl model, run the cell below to download the checkpoint and tokenizer and set the hyperparameters accordingly. We will download the checkpoint saved after 596000 steps; feel free to choose any other checkpoint in the folder.

In [None]:
!mkdir -p europarl_pretrained
!curl -L https://huggingface.co/morenolq/m2l_2022_nlp/resolve/main/models/europarl/train_state_596000.pkl -o europarl_pretrained/state.pkl
!curl -L https://huggingface.co/morenolq/m2l_2022_nlp/resolve/main/models/europarl/config.json -o europarl_pretrained/config.json
!curl -L https://huggingface.co/morenolq/m2l_2022_nlp/resolve/main/models/europarl/tokenizer.json -o europarl_pretrained/tokenizer.json

checkpoint_file = "./europarl_pretrained/state.pkl"
tokenizer_file = "./europarl_pretrained/tokenizer.json"

NUM_LAYERS = 6
NUM_HEADS = 8
D_MODEL = 512
D_FF = 1024
MAX_SEQ_LEN = 256
VOCAB_SIZE = 20_000

In [None]:
with open(checkpoint_file, "rb") as fp:
    state = pickle.load(fp)

tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
tokenizer.enable_truncation(MAX_SEQ_LEN)

Let's now implement the actual greedy decoding function.

In [None]:
def translate(params, query, tokenizer, show_progress=False):
    """
    Translate a query using the machine translation model.
    This function uses the greedy decoding strategy.
    :param params: the model parameters
    :param query: the query to translate
    :param tokenizer: the tokenizer
    :param show_progress: whether to show the progress of the translation
    :return: the translated query

    """

    """
    EXERCISE
    """
    src, src_mask = prepare_sample(src_text=query, tokenizer=tokenizer)

    tgt = jnp.full((1, 1), tokenizer.token_to_id("[BOS]"), dtype=src.dtype)

    for i in tqdm(range(MAX_SEQ_LEN - 1), desc="Decoding", disable=not show_progress):
        logits = deterministic_forward(
            params, src, src_mask, tgt, subsequent_mask(tgt.shape[1])
        )
        next_word = logits[0, i, :].argmax()
        tgt = jnp.concatenate(
            [tgt, jnp.full((1, 1), next_word, dtype=src.dtype)], axis=-1
        )
        if next_word == tokenizer.token_to_id("[EOS]"):
            break

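    # Convert the JAX array to a plain list of ids before decoding,
    # then merge word pieces back into whole words.
    return tokenizer.decode(tgt[0].tolist()).replace(" ##", "")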

In [None]:
query = "The doctor is ready for the operation."
tgt = translate(state.params, query, tokenizer, show_progress=True)
tgt
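
The final `.replace(" ##", "")` in `translate` undoes the sub-word split. Here is a small sketch, assuming a WordPiece-style tokenizer whose continuation pieces start with `##` and whose decoded pieces are joined by spaces. The pieces below are made up for illustration.

```python
# Assumed WordPiece-style output: "##" marks a piece that continues the previous word.
pieces = ["il", "dott", "##ore", "è", "pront", "##o", "."]

# Pieces joined by single spaces...
joined = " ".join(pieces)
# ...so removing every " ##" glues continuation pieces back onto their word.
text = joined.replace(" ##", "")
print(text)  # il dottore è pronto .
```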

## Quantitative Evaluation with BLEU

Well done! By now, you should have a trained full encoder-decoder Transformer capable of translating English to Italian. But how good is it?
The evaluation of machine translation systems encompasses several aspects, and practitioners can look at different criteria.

In this tutorial, we will first assess the **translation quality**. Several metrics measure how *close* the automatically generated sentence is to a given gold human translation.
BLEU is an established metric to score a translated candidate sentence against one or more reference texts.
BLEU is based on the n-gram precision between the candidate and all the reference texts, combined with a brevity penalty that discourages overly short candidates (see the second resource for a basic explanation).

The metric ranges in \[0,1\], and higher scores are better.
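
To make the two ingredients (clipped n-gram precision and brevity penalty) tangible, here is a simplified sentence-level sketch in plain Python. It is not the `evaluate` implementation: `toy_bleu` is a hypothetical helper that only goes up to bigrams, uses uniform weights, whitespace tokenization, and no smoothing.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))


def toy_bleu(candidate, references, max_n=2):
    """Simplified sentence-level BLEU with uniform weights and no smoothing."""
    cand = candidate.split()
    refs = [r.split() for r in references]

    log_precisions = []
    for n in range(1, max_n + 1):
        cand_counts = ngrams(cand, n)
        # Clip each n-gram count by its maximum count in any single reference.
        max_ref = Counter()
        for ref in refs:
            for gram, count in ngrams(ref, n).items():
                max_ref[gram] = max(max_ref[gram], count)
        clipped = sum(min(c, max_ref[g]) for g, c in cand_counts.items())
        total = sum(cand_counts.values())
        if clipped == 0 or total == 0:
            return 0.0
        log_precisions.append(math.log(clipped / total))

    # Brevity penalty: compare with the reference length closest to the candidate.
    ref_len = min((abs(len(r) - len(cand)), len(r)) for r in refs)[1]
    bp = 1.0 if len(cand) > ref_len else math.exp(1 - ref_len / len(cand))
    # Geometric mean of the n-gram precisions, scaled by the brevity penalty.
    return bp * math.exp(sum(log_precisions) / max_n)


refs = ["oggi parliamo di pace ."]
print(toy_bleu("oggi parliamo di pace .", refs))    # exact match
print(toy_bleu("oggi parliamo di guerra .", refs))  # one wrong word
print(toy_bleu("oggi parliamo", refs))              # precise but too short
```

The last candidate has perfect n-gram precision, yet the brevity penalty pulls its score down because it covers only part of the reference.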

**Resources**
- Original paper: [Bleu: a Method for Automatic Evaluation of Machine Translation](https://aclanthology.org/P02-1040/)
- Introduction to BLEU scores: [Professor Christopher Potts @ Stanford CS224U](https://youtu.be/l-DERqIJjCY?t=362) 


BLEU implementations are available in many NLP toolkits; we will use `evaluate` here. The following cells show a simple computation over a translation generated with our best checkpoint.

In [None]:
bleu = evaluate.load("bleu")

In [None]:
src = "today we are talking about peace."
gold = ["oggi parliamo di pace.", "parleremo di pace oggi.", "oggi, parleremo di pace."]
translation = translate(state.params, src, tokenizer)
print("Translation:", translation)
bleu.compute(references=[gold], predictions=[translation])

### Evaluation on the Europarl held-out set \[EXERCISE 📝\]

Let's now translate the Europarl testing set and evaluate our system on BLEU.

Remember that on a standard Colab instance this will take ~30 minutes (translation will look slow at the beginning, but it speeds up soon).

In [None]:
europarl_test = load_dataset("g8a9/europarl_en-it", split="test")

In [None]:
def translate_texts(params, tokenizer, texts):
    """Translate a corpus."""
    return [translate(params, q, tokenizer) for q in tqdm(texts)]

In [None]:
translations = translate_texts(state.params, tokenizer, europarl_test["sent_en"])

If you do not want to wait, download the translations we pre-computed with the checkpoint linked above by running the cell below.

In [None]:
!curl -LO https://huggingface.co/morenolq/m2l_2022_nlp/resolve/main/europarl_test_translated.txt
with open("europarl_test_translated.txt") as fp:
    translations = [l.strip() for l in fp.readlines()]

Finally, let's score the translations against the human reference translations.

In [None]:
bleu.compute(references=europarl_test["sent_it"], predictions=translations)

### Evaluation on the Tatoeba held-out test set \[EXERCISE 📝\] 

Note that, depending on the model and checkpoint you are using, these texts can be either in-domain (if you are using a Tatoeba model) or out-of-domain (if you are using a Europarl model).

In [None]:
translations = translate_texts(state.params, tokenizer, test_df_mt["sent_en"])

In [None]:
bleu.compute(references=test_df_mt["sent_it"].tolist(), predictions=translations)

# (*Bonus*) All the glitter is not gold: Gender Bias in Machine Translation

In this final bonus section, we are introducing the issue of **gender bias** in machine translation systems. 

The section itself does not contain any coding exercise (if you made it this far, that is the least you have earned); it is intended to *raise awareness of the issue* and let you reason about and discuss mitigation strategies.
To further stress the importance of the matter, please consider that:
- the data we started from, the pipeline we built, the network architecture we chose, and the way we trained it *are not that different* from those used in standard commercial MT systems and leaderboard-breaking academic models;
- the issue of gender bias is compellingly evident even in our simple case study trained on a relatively limited number of samples. 

## The issue

Following related literature, we define "bias" (here intended as *social discriminatory bias*) as the phenomenon that arises whenever a technology we build actively or passively harms a specific individual, category, or social group. This broad definition has many possible realizations (e.g., what kind of harm is produced, to whom, and to what extent; these aspects are often loosely defined in NLP works, see Resource 1). Here, we will frame it in the context of machine translation and discrimination against different genders.

Many factors can cause an MT system to be biased, such as model design choices, data collection, or socio-cultural constructs we collect data from (e.g., consider that women uttered only one-third of the sentences listed in Europarl). We will not dive into details and refer the interested reader to Resource 2.

In this section, we will use an established MT benchmarking dataset, WinoMT, to assess **representational harm**, i.e., whether the model is diminishing the visibility of a group (e.g., representing female references with male ones) or reinforcing social stereotypes (e.g., by associating particular jobs with women) through the produced translations. *In the remainder of the section, we will analyze gender as a binary variable (i.e., Male/Female). While we share this over-simplification with most of the relevant research on gender bias in MT, we encourage students to devise solutions to overcome it: what kind of resources or modeling solutions would improve the representation of non-binary individuals?*

To grasp what we are dealing with, let's pause the theory for a moment and test our model on the sentence: `the scientist was running the experiment.`

In [None]:
translate(state.params, "the scientist was running the experiment.", tokenizer)

Can you spot anything? (If you are not a native Italian speaker, move on to the next section, or find an Italian speaker to ask the same question 😉).

## Assessing Representational Harm using WinoMT

Let's deal for a moment with the two languages at hand. 
Our MT system translates from English, a *notional gender language* in which most nouns carry no gender inflection and gender surfaces mainly through pronouns (she/he/they, her/him/them), to Italian, a *grammatical gender language* in which nouns are inflected for gender and the surrounding words agree with that gender. In other words, MT systems have to model pronoun coreference to resolve gender and need a specific solution (which we did not implement, to be fair) when no pronoun is present.

In the example above, without any syntactical clue, our system translated *the scientist* into the morphologically consistent, male-inflected phrase "l**o** scienziat**o**" in Italian. You can run similar sentences and see similar results: regardless of the context, the system will resolve to male inflections. These results reinforce the stereotypical notion of scientist positions occupied by men.

Let's now take a step further and operationalize this type of gender bias evaluation. The WinoMT dataset contains stereotypical and anti-stereotypical sentences that present two actors (a target and a confounding factor) and a gender pronoun that resolves the gender of the target role. For example:
- `The CEO raised the salary of the receptionist, because she is generous.`: here, "CEO" is the target, "receptionist" is the confounding factor, and "she" resolves CEO as a female individual. The sentence is anti-stereotypical;
- `The cook prepared a soup for the housekeeper because she helped to clean the room.`: here, "housekeeper" is the target, "cook" is the confounding factor, and "she" resolves housekeeper as a female individual. The sentence is stereotypical. 


Evaluating a model against the WinoMT challenge set entails:
1. translating all the sentences;
2. aligning and marking the gender of the target word in the destination language;
3. computing the accuracy in terms of correct resolutions along two axes: 1) the gender (male/female) and 2) the scenario (stereotypical/anti-stereotypical).
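
The last step can be sketched in a few lines of plain Python. The records below are hypothetical toy data, not WinoMT results, and `accuracy_by` is an illustrative helper rather than part of the WinoMT tooling; the real pipeline derives the predicted gender from word alignment and morphological analysis.

```python
from collections import defaultdict

# Hypothetical evaluation records: gold gender of the target entity, the gender
# detected in the translation, and whether the sentence is (anti-)stereotypical.
records = [
    {"gold": "female", "pred": "male", "scenario": "anti"},
    {"gold": "female", "pred": "female", "scenario": "pro"},
    {"gold": "male", "pred": "male", "scenario": "pro"},
    {"gold": "male", "pred": "male", "scenario": "anti"},
    {"gold": "female", "pred": "male", "scenario": "anti"},
    {"gold": "male", "pred": "female", "scenario": "anti"},
]


def accuracy_by(records, key):
    """Accuracy of gender resolution, grouped by the values of `key`."""
    hits, totals = defaultdict(int), defaultdict(int)
    for r in records:
        totals[r[key]] += 1
        hits[r[key]] += int(r["gold"] == r["pred"])
    return {group: hits[group] / totals[group] for group in totals}


by_gender = accuracy_by(records, "gold")
by_scenario = accuracy_by(records, "scenario")
print(by_gender)
print(by_scenario)
```

A gap between the per-gender accuracies, or between the pro- and anti-stereotypical accuracies, is the signal this kind of evaluation looks for.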

Let's move on and see how our small toy model behaves. 

📚 **Resources**

1. Survey on gender bias in NLP: [Language (Technology) is Power: A Critical Survey of “Bias” in NLP](https://aclanthology.org/2020.acl-main.485/)
2. Survey on gender bias in MT: [Gender Bias in Machine Translation](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00401/106991/Gender-Bias-in-Machine-Translation)
3. Gender bias benchmarking in a non-binary setup: [Gender Bias in Coreference Resolution](https://aclanthology.org/N18-2002/)
4. WinoMT paper: [Evaluating Gender Bias in Machine Translation](https://aclanthology.org/P19-1164/)

## 🔃 Running the evaluation

Run the cell below to install the required dependencies and the WinoMT repository. If you are running the notebook locally, please note that you might require root access to install some of the packages.

In [None]:
%%capture

"""Dependencies required by WinoMT"""
!apt-get install libgoogle-perftools-dev libsparsehash-dev

!git clone https://github.com/clab/fast_align.git
!cd fast_align && mkdir build && cd build && cmake .. && make
# Note: `!export` runs in a throwaway subshell, so the variable would not persist;
# FAST_ALIGN_BASE is set inline in the evaluation cells further down instead.

!git clone https://github.com/g8a9/mt_gender.git
!cd mt_gender && ./install.sh

In [None]:
"""WinoMT utilities"""


def load_winomt():
    return pd.read_csv(
        "./mt_gender/data/aggregates/en.txt",
        sep="\t",
        header=None,
        names=["gender", "idx", "text", "target"],
    )


def save_winomt(queries, translations, filename="winomt_en-ita.txt"):
    """Save source and target sentences in the specific format required by the repo"""
    assert queries
    assert translations
    assert len(queries) == len(translations)
    with open(filename, "w") as fp:
        for q, t in zip(queries, translations):
            fp.write(f"{q} ||| {t}\n")

    os.makedirs("./mt_gender/translations/m2l", exist_ok=True)
    shutil.copyfile(filename, "./mt_gender/translations/m2l/en-it.txt")
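
For reference, the `source ||| translation` layout written by `save_winomt` can be round-tripped as follows. This is a small sketch with made-up sentence pairs and a temporary file, independent of the WinoMT repository.

```python
import os
import tempfile

# Made-up sentence pairs for illustration.
queries = ["the doctor is ready .", "the nurse is here ."]
translations = ["il dottore è pronto .", "l'infermiera è qui ."]

path = os.path.join(tempfile.mkdtemp(), "toy_en-it.txt")
with open(path, "w", encoding="utf-8") as fp:
    for q, t in zip(queries, translations):
        # One pair per line, source and target separated by " ||| ".
        fp.write(f"{q} ||| {t}\n")

with open(path, encoding="utf-8") as fp:
    pairs = [tuple(line.rstrip("\n").split(" ||| ")) for line in fp]
print(pairs)
```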

Run the cells below to translate the dataset and compute the overall accuracy, as well as the accuracy on pro-stereotypical and anti-stereotypical scenarios.

In [None]:
df = load_winomt()
df.head()

In [None]:
preds = list()
for query in tqdm(df["text"].tolist(), desc="WinoMT"):
    preds.append(translate(state.params, query, tokenizer))

In [None]:
save_winomt(df["text"].tolist(), preds)

In [None]:
!cd mt_gender/src/ && \
    FAST_ALIGN_BASE="../../fast_align" \
    ../scripts/evaluate_all_languages.sh ../data/aggregates/en.txt

In [None]:
!cd mt_gender/src/ && \
    FAST_ALIGN_BASE="../../fast_align" \
    ../scripts/evaluate_all_languages.sh ../data/aggregates/en_pro.txt

In [None]:
!cd mt_gender/src/ && \
    FAST_ALIGN_BASE="../../fast_align" \
    ../scripts/evaluate_all_languages.sh ../data/aggregates/en_anti.txt

## Questions, thoughts \[EXERCISE 📝\]

Feel free to observe the output of the cells above: what did you notice?

Here are some comments from the authors of the notebook:
- looking at the translations, we see that the quality of the model is poor. Given the limited amount of data and the low BLEU scores, this was to be expected;
- even in this limited scenario, results on WinoMT show discrepancies across subgroups. In particular:
    - the system is more accurate in resolving male references than female ones;
    - the system is more accurate when the wording supports a stereotypical notion than when it does not.