# ![](https://img.shields.io/badge/language-cb2888) Train Transformer

In this guide,  a simple transformer model is trained from scratch to perform machine translation as detailed in [Attention is all you need](https://arxiv.org/abs/1706.03762) from Arabic to English. But you are free to choose any language pair from the configuration.

## Imports

In [None]:
!pip install git+https://github.com/ASEM000/serket --quiet
!pip install tokenizers --quiet
!pip install datasets --quiet
!pip install ml_collections --quiet
!pip install tqdm --quiet
!pip install more-itertools --quiet

In [1]:
# numerics related
import jax
import jax.random as jr
import jax.numpy as jnp
import serket as sk
import optax
import numpy as np

# dataset related
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer
from more_itertools import chunked

# config related
from ml_collections import ConfigDict

# typing related
from typing_extensions import Annotated, TypedDict
from typing import TypeVar, Generic

# other
import functools as ft
from tqdm.notebook import tqdm

T = TypeVar("T")


class Batched(Generic[T]):
    # a type marker to indicate batch dimension
    ...

## Configs

In [2]:
config = ConfigDict()

# dataset
config.dataset = ConfigDict()
config.dataset.name = "opus100"
config.dataset.src_lang = "ar"
config.dataset.tgt_lang = "en"


config.tokenizer = ConfigDict()
config.tokenizer.unk_token = "[UNK]"
config.tokenizer.special_tokens = ["[UNK]", "[PAD]", "[SOS]", "[EOS]"]
config.tokenizer.min_frequency = 2

config.model = ConfigDict()
config.model.d_model = 512
config.model.seq_len = 100
config.model.num_heads = 8
config.model.num_blocks = 2
config.model.drop_rate = 0.1
config.model.seed = 0

config.train = ConfigDict()
config.train.epochs = 1
config.train.seed = 1
config.train.batch_size = 32

config.optim = ConfigDict()
config.optim.lr = 1e-4

## Layers

In [3]:
embed_init = jax.nn.initializers.normal()


class InputEmbedding(sk.TreeClass):
    """Lookup table for input tokens"""

    def __init__(self, vocab_size: int, d_model: int, *, key: jax.Array):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.embed = embed_init(key, (vocab_size, d_model))

    def __call__(self, input: jax.Array):
        return jnp.take(self.embed, input, axis=0)


class PositionalEmbedding(sk.TreeClass):
    def __init__(self, d_model: int, seq_len: int):
        self.d_model = d_model
        self.seq_len = seq_len

        i = jnp.arange(0, seq_len)[:, None]
        j = jnp.arange(0, d_model, 2)[None, :]
        angle = i * 1e-4 ** (j / d_model)

        # interleave sin and cos
        self.pos_embed = (
            jnp.zeros([seq_len, d_model])
            .at[:, ::2]
            .set(jnp.sin(angle))
            .at[:, 1::2]
            .set(jnp.cos(angle))
        )

    def __call__(self, input: jax.Array) -> jax.Array:
        assert len(input) <= self.seq_len, f"{len(input)=} > {self.seq_len=}"
        input *= jnp.sqrt(self.d_model)
        return input + jax.lax.stop_gradient(self.pos_embed[: len(input)])


class LayerNorm(sk.TreeClass):
    """Normalize over the last dimension"""

    def __init__(self, in_features: int):
        self.in_features = in_features
        self.weight = jnp.ones(in_features)
        self.bias = jnp.zeros(in_features)

    def __call__(self, input: jax.Array):
        mean = jnp.mean(input, axis=-1, keepdims=True)
        var = jnp.var(input, axis=-1, keepdims=True)
        return self.weight * ((input - mean) / jnp.sqrt(var + 1e-5)) + self.bias


class MHA(sk.TreeClass):
    def __init__(self, d_model: int, num_heads: int, *, key: jax.Array):
        self.num_heads = num_heads
        assert d_model % num_heads == 0, f"{d_model=} not divisible by {num_heads=}"
        k1, k2, k3, k4 = jr.split(key, 4)
        self.q_projection = sk.nn.Linear(d_model, d_model, bias_init=None, key=k1)
        self.k_projection = sk.nn.Linear(d_model, d_model, bias_init=None, key=k2)
        self.v_projection = sk.nn.Linear(d_model, d_model, bias_init=None, key=k3)
        self.o_projection = sk.nn.Linear(d_model, d_model, bias_init=None, key=k4)

    @staticmethod
    def split_heads(input: jax.Array, num_heads: int):
        return input.reshape(*input.shape[:-1], num_heads, input.shape[-1] // num_heads)

    @staticmethod
    def merge_heads(input: jax.Array):
        return input.reshape(*input.shape[:-2], -1)

    def __call__(
        self,
        q_input: jax.Array,
        k_input: jax.Array,
        v_input: jax.Array,
        *,
        mask: jax.Array | None = None,
    ) -> jax.Array:
        q_heads = self.split_heads(self.q_projection(q_input), self.num_heads)
        k_heads = self.split_heads(self.k_projection(k_input), self.num_heads)
        v_heads = self.split_heads(self.v_projection(v_input), self.num_heads)

        logits = jnp.einsum("...qhd,...khd->...hqk", q_heads, k_heads)
        logits = logits / jnp.sqrt(q_heads.shape[-1])
        min_num = jnp.finfo(q_input.dtype).min
        logits = logits if mask is None else jnp.where(mask, logits, min_num)

        attention_weight = jax.nn.softmax(logits)
        attention = jnp.einsum("...hqk,...khd->...qhd", attention_weight, v_heads)
        attention = self.merge_heads(attention)

        return self.o_projection(attention)


class FeedForward(sk.TreeClass):
    def __init__(self, d_model: int, *, key: jax.Array):
        k1, k2 = jr.split(key)
        self.linear1 = sk.nn.Linear(d_model, 4 * d_model, key=k1)
        self.linear2 = sk.nn.Linear(4 * d_model, d_model, key=k2)

    def __call__(self, input: jax.Array) -> jax.Array:
        input = self.linear1(input)
        input = jax.nn.relu(input)
        input = self.linear2(input)
        return input


class EncoderBlock(sk.TreeClass):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        drop_rate: float,
        *,
        key: jax.Array,
    ):
        self.d_model = d_model
        self.num_heads = num_heads
        self.drop_rate = drop_rate

        sa_key, ff_key = jr.split(key)

        self.sa = MHA(d_model, num_heads, key=sa_key)
        self.sa_dropout = sk.nn.Dropout(drop_rate)
        self.sa_norm = LayerNorm(d_model)

        self.ff = FeedForward(d_model, key=ff_key)
        self.ff_dropout = sk.nn.Dropout(drop_rate)
        self.ff_norm = LayerNorm(d_model)

    def __call__(
        self,
        input: jax.Array,
        *,
        key: jax.Array,
        mask: jax.Array | None = None,
    ):
        sa_dropout_key, ff_dropout_key = jr.split(key)
        # add and norm as done in the original paper
        sa_out = self.sa(input, input, input, mask=mask)
        sa_out = self.sa_dropout(sa_out, key=sa_dropout_key)
        input = sa_out + input
        input = self.sa_norm(input)

        ff_out = self.ff(input)
        ff_out = self.ff_dropout(ff_out, key=ff_dropout_key)
        input = ff_out + input
        input = self.ff_norm(input)

        return input


class DecoderBlock(sk.TreeClass):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        drop_rate: float,
        *,
        key: jax.Array,
    ):
        sa_key, ca_key, ff_key = jr.split(key, 3)

        self.sa = MHA(d_model, num_heads, key=sa_key)
        self.sa_dropout = sk.nn.Dropout(drop_rate)
        self.sa_norm = LayerNorm(d_model)

        self.ca = MHA(d_model, num_heads, key=ca_key)
        self.ca_dropout = sk.nn.Dropout(drop_rate)
        self.ca_norm = LayerNorm(d_model)

        self.ff = FeedForward(d_model, key=ff_key)
        self.ff_dropout = sk.nn.Dropout(drop_rate)
        self.ff_norm = LayerNorm(d_model)

    def __call__(
        self,
        dec_input: jax.Array,
        enc_output: jax.Array,
        *,
        key: jax.Array,
        dec_mask: jax.Array | None = None,
        enc_mask: jax.Array | None = None,
    ):
        sa_dropout_key, ca_dropout_key, ff_dropout_key = jr.split(key, 3)

        sa_out = self.sa(dec_input, dec_input, dec_input, mask=dec_mask)
        sa_out = self.sa_dropout(sa_out, key=sa_dropout_key)
        dec_input = dec_input + sa_out
        dec_input = self.sa_norm(dec_input)

        ca_out = self.ca(dec_input, enc_output, enc_output, mask=enc_mask)
        ca_out = self.ca_dropout(ca_out, key=ca_dropout_key)
        dec_input = dec_input + ca_out
        dec_input = self.ca_dropout(dec_input, key=ca_dropout_key)

        ff_out = self.ff(dec_input)
        ff_out = self.ff_dropout(ff_out, key=ff_dropout_key)
        dec_input = dec_input + ff_out
        dec_input = self.ff_dropout(dec_input, key=ff_dropout_key)

        return dec_input


class Transformer(sk.TreeClass):
    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int,
        num_heads: int,
        num_blocks: int,
        drop_rate: float,
        seq_len: int,
        *,
        key: jax.Array,
    ):
        embed_key, encoder_key, decoder_key, fc_key = jr.split(key, 4)

        # embedding
        enc_embed_key, dec_embed_key = jr.split(embed_key)
        self.enc_embed = InputEmbedding(src_vocab_size, d_model, key=enc_embed_key)
        self.dec_embed = InputEmbedding(tgt_vocab_size, d_model, key=dec_embed_key)
        self.pos_embed = PositionalEmbedding(d_model, seq_len)
        self.embed_dropout = sk.nn.Dropout(drop_rate)

        # encoder
        self.encoders = tuple(
            EncoderBlock(d_model, num_heads, drop_rate, key=ki)
            for ki in jr.split(encoder_key, num_blocks)
        )

        # decoder
        self.decoders = tuple(
            DecoderBlock(d_model, num_heads, drop_rate, key=ki)
            for ki in jr.split(decoder_key, num_blocks)
        )

        # out projection
        self.fc = sk.nn.Linear(d_model, tgt_vocab_size, key=fc_key)

    def encode(
        self,
        enc_input: jax.Array,
        mask: jax.Array | None = None,
        *,
        key: jax.Array,
    ):
        embed_dropout_key, *enc_keys = jr.split(key, len(self.encoders) + 1)
        enc_input = self.enc_embed(enc_input)
        enc_input = self.pos_embed(enc_input)
        enc_input = self.embed_dropout(enc_input, key=embed_dropout_key)

        for encoder, key in zip(self.encoders, enc_keys):
            enc_input = encoder(enc_input, key=key, mask=mask)
        return enc_input

    def decode(
        self,
        enc_output: jax.Array,
        dec_input: jax.Array,
        enc_mask: jax.Array | None = None,
        dec_mask: jax.Array | None = None,
        *,
        key: jax.Array,
    ):
        embed_dropout_key, *dec_keys = jr.split(key, len(self.decoders) + 1)

        dec_input = self.dec_embed(dec_input)
        dec_input = self.pos_embed(dec_input)
        dec_input = self.embed_dropout(dec_input, key=embed_dropout_key)

        for decoder, key in zip(self.decoders, dec_keys):
            dec_input = decoder(
                dec_input,
                enc_output,
                enc_mask=enc_mask,
                dec_mask=dec_mask,
                key=key,
            )

        return dec_input

    def project(self, input: jax.Array) -> jax.Array:
        return self.fc(input)

## Data preparation

### Load the dataset

In [4]:
train_dataset, test_dataset = load_dataset(
    config.dataset.name,
    name=f"{config.dataset.src_lang}-{config.dataset.tgt_lang}",
    split=["train", "test"],
    trust_remote_code=True,
)

### Tokenizer

In [5]:
def build_tokenizer(config: ConfigDict, dataset, lang: str) -> Tokenizer:
    tokenizer = Tokenizer(WordLevel(unk_token=config.unk_token))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(
        special_tokens=config.special_tokens,
        min_frequency=config.min_frequency,
    )

    def iterator():
        for item in dataset:
            yield item["translation"][lang]

    tokenizer.train_from_iterator(iterator(), trainer=trainer)
    return tokenizer


src_tokenizer = build_tokenizer(
    config.tokenizer,
    dataset=train_dataset,
    lang=config.dataset.src_lang,
)
tgt_tokenizer = build_tokenizer(
    config.tokenizer,
    dataset=train_dataset,
    lang=config.dataset.tgt_lang,
)

### Dataloader

In [6]:
class DataItem(TypedDict):
    # the source language token ids
    enc_input: Annotated[jax.Array, "seq_len"]
    # the target language token ids starting with sos
    dec_input: Annotated[jax.Array, "seq_len"]
    # the target language token ids ending with eos
    dec_output: Annotated[jax.Array, "seq_len"]
    # pad skipping mask for the encoder input
    enc_mask: Annotated[jax.Array, "1, 1, seq_len"]
    # pad skipping + causal mask for the decoder input
    dec_mask: Annotated[jax.Array, "1, seq_len, seq_len"]
    # the source language text
    src_text: Annotated[np.ndarray, ""]
    # the target language text
    tgt_text: Annotated[np.ndarray, ""]


def generate_dataloader(
    dataset,
    *,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    src_lang: str,
    tgt_lang: str,
    batch_size: int,
    seq_len: int,
):
    sos_id = np.array([tgt_tokenizer.token_to_id("[SOS]")])
    eos_id = np.array([tgt_tokenizer.token_to_id("[EOS]")])
    pad_id = np.array([tgt_tokenizer.token_to_id("[PAD]")])

    def get_dataitem(index: int) -> DataItem:
        src_tgt_item = dataset[index]

        # encoder
        src_text = src_tgt_item["translation"][src_lang]
        # text -> jax array of integer indices
        # subtract 2 for sos and eos at the beginning and end of the sequence
        src_ids = jnp.array(src_tokenizer.encode(src_text).ids)[: seq_len - 2]
        # fill the rest with pad tokens if the sequence is shorter than seq_len
        enc_pad = pad_id.repeat(max(0, seq_len - len(src_ids) - 2))
        # add sos and eos tokens to the beginning and end of the sequence
        enc_input = jnp.concatenate([sos_id, src_ids, eos_id, enc_pad])
        # mask out the pad tokens for the encoder input
        enc_mask = jnp.where(enc_input == pad_id, False, True)[None, None, :]

        # decoder
        tgt_text = src_tgt_item["translation"][tgt_lang]
        # text -> jax array of integer indices
        # subtract 1 for sos at the beginning of the sequence for the decoder input
        # or eos at the end of the sequence for the decoder output
        tgt_ids = jnp.array(tgt_tokenizer.encode(tgt_text).ids)[: seq_len - 1]
        # fill the rest with pad tokens
        dec_pad = pad_id.repeat(max(0, seq_len - len(tgt_ids) - 1))
        # [<SOS> ...] decoder input
        dec_input = jnp.concatenate([sos_id, tgt_ids, dec_pad])
        # [... <EOS>] decoder output
        dec_output = jnp.concatenate([tgt_ids, eos_id, dec_pad])
        # causal mask for the decoder self-attention layer
        dec_mask = jnp.tril(jnp.ones([1, seq_len, seq_len])).astype(bool)
        # also mask out the pad tokens for the decoder input
        dec_mask &= jnp.where(dec_input == pad_id, False, True)[None, None, :]

        return dict(
            enc_input=enc_input,
            dec_input=dec_input,
            dec_output=dec_output,
            enc_mask=enc_mask,
            dec_mask=dec_mask,
            src_text=src_text,
            tgt_text=tgt_text,
        )

    def get_batch(indices: list[int]) -> Batched[DataItem]:
        batch = [get_dataitem(i) for i in indices]
        return jax.tree_map(lambda *args: np.stack(args), *batch)

    indices = jnp.arange(len(dataset))

    def _dataloader(key: jax.Array):
        indices_: list[int] = jax.random.permutation(key, indices).tolist()
        for batch_indices in chunked(indices_, batch_size):
            yield get_batch(batch_indices)

    return _dataloader


train_dl = generate_dataloader(
    train_dataset,
    src_tokenizer=src_tokenizer,
    tgt_tokenizer=tgt_tokenizer,
    src_lang=config.dataset.src_lang,
    tgt_lang=config.dataset.tgt_lang,
    batch_size=config.train.batch_size,
    seq_len=config.model.seq_len,
)

## Train

### Model

In [7]:
optim = optax.adam(config.optim.lr)
train_key = jr.key(config.train.seed)
PAD_ID = jnp.array([tgt_tokenizer.token_to_id("[PAD]")])
net = Transformer(
    src_vocab_size=src_tokenizer.get_vocab_size(),
    tgt_vocab_size=tgt_tokenizer.get_vocab_size(),
    seq_len=config.model.seq_len,
    d_model=config.model.d_model,
    num_heads=config.model.num_heads,
    num_blocks=config.model.num_blocks,
    drop_rate=config.model.drop_rate,
    key=jr.key(config.model.seed),
)
net = sk.tree_mask(net)
optim_state = optim.init(net)

### Load the pretrained weights (optional)

In [8]:
import pickle

# download from google drive
# https://drive.google.com/file/d/1h9fV9xlLZFr2XDWzccreelB9QneJNad_/view?usp=sharing

# d_model = 512
# seq_len = 100
# num_heads = 8
# num_blocks = 2
# drop_rate = 0.1
# seed = 0

treedef = jax.tree_util.tree_structure(net)
with open("transformer_weights.pickle", "rb") as file:
    leaves = pickle.load(file)
net = jax.tree_util.tree_unflatten(treedef, leaves)

### Train step

In [9]:
def softmax_cross_entropy(logits, idx):
    logprop = jax.nn.log_softmax(logits)
    logprop = jnp.take_along_axis(logprop, idx[..., None], axis=-1)[..., 0]
    # ignore pad tokens
    logprop = jnp.where(idx == PAD_ID, 0.0, logprop)
    return jnp.sum(-logprop)


@jax.jit
def train_step(
    net: Transformer,
    optim_state: optax.OptState,
    enc_input: Batched[Annotated[jax.Array, "seq_len"]],
    dec_input: Batched[Annotated[jax.Array, "seq_len"]],
    dec_output: Batched[Annotated[jax.Array, "seq_len"]],
    enc_mask: Batched[Annotated[jax.Array, "1,1,seq_len"]],
    dec_mask: Batched[Annotated[jax.Array, "1,seq_len,seq_len"]],
    key: jax.Array,
):
    k1, k2 = jr.split(key)

    @ft.partial(jax.grad, has_aux=True)
    def loss_func(net: Transformer):
        net = sk.tree_unmask(net)
        encode = jax.vmap(ft.partial(net.encode, key=k1))
        decode = jax.vmap(ft.partial(net.decode, key=k2))
        project = jax.vmap(net.project)
        enc_pred: Batched[jax.Array] = encode(enc_input, enc_mask)
        dec_pred: Batched[jax.Array] = decode(enc_pred, dec_input, enc_mask, dec_mask)
        logits: Batched[jax.Array] = project(dec_pred)
        loss: Batched[jax.Array] = jax.vmap(softmax_cross_entropy)(logits, dec_output)
        # similar to torch cross_entropy, with ignore_index=PAD_ID and reduction="mean"
        loss = jnp.sum(loss) / jnp.sum(jnp.where(dec_output == PAD_ID, 0.0, 1.0))
        return loss, loss

    grads, loss = loss_func(net)
    updates, optim_state = optim.update(grads, optim_state)
    net = optax.apply_updates(net, updates)
    return net, optim_state, loss

### Train loop

In [None]:
train_key = jr.key(config.train.seed)
batches = len(train_dataset) // config.train.batch_size
for i in (pbar := tqdm(range(config.train.epochs))):
    train_key = jr.fold_in(train_key, i)

    dl_key, train_step_key = jr.split(train_key)

    epoch_loss = []
    for j, batch in tqdm(enumerate(train_dl(key=dl_key)), total=batches):
        train_step_key = jr.fold_in(train_step_key, j)

        net, optim_state, loss = train_step(
            net,
            optim_state,
            batch["enc_input"],
            batch["dec_input"],
            batch["dec_output"],
            batch["enc_mask"],
            batch["dec_mask"],
            train_step_key,
        )
        epoch_loss += [loss]
        pbar.set_description(f"loss: {loss:.4f}")

    pbar.set_description(f"loss: {np.mean(epoch_loss):.4f}")

## Evaluation

In [10]:
def predict(
    text: str,
    *,
    net: Transformer,
    src_tokenizer,
    tgt_tokenizer,
    max_len: int,
    key: jax.Array,
) -> str:
    sos_id = np.array([tgt_tokenizer.token_to_id("[SOS]")])
    eos_id = np.array([tgt_tokenizer.token_to_id("[EOS]")])
    pad_id = np.array([tgt_tokenizer.token_to_id("[PAD]")])

    # text -> jax array of integer indices
    # subtract 2 for sos and eos at the beginning and end of the sequence
    src_ids = jnp.array(src_tokenizer.encode(text).ids)[: max_len - 2]
    # fill the rest with pad tokens if the sequence is shorter than seq_len
    enc_pad = pad_id.repeat(max(0, max_len - len(src_ids) - 2))
    # add sos and eos tokens to the beginning and end of the sequence
    enc_input = jnp.concatenate([sos_id, src_ids, eos_id, enc_pad])

    # 1, 1, seq_len
    enc_mask = jnp.where(enc_input == pad_id, False, True)[None, None, :]
    enc_key, dec_key = jr.split(key)
    enc_output = net.encode(enc_input, enc_mask, key=enc_key)
    dec_input = jnp.array([tgt_tokenizer.token_to_id("[SOS]")])

    while True:
        dec_key = jr.fold_in(dec_key, len(dec_input))
        # causal mask
        dec_mask = jnp.tril(jnp.ones([1, len(dec_input), len(dec_input)])).astype(bool)
        # skip pad tokens
        dec_mask &= jnp.where(dec_input == pad_id, False, True)[None, None, :]

        # seq_len, d_model
        dec_output = net.decode(enc_output, dec_input, enc_mask, dec_mask, key=dec_key)
        logits = net.project(dec_output)[..., -1, :]
        next_token = jnp.argmax(logits, axis=-1)[None]

        if next_token == eos_id:
            # reached end of sequence
            break

        dec_input = jnp.concatenate([dec_input, next_token], axis=-1)

    return tgt_tokenizer.decode(dec_input.tolist())


translate_from_arabic_to_english = ft.partial(
    predict,
    net=sk.tree_eval(sk.tree_unmask(net)),
    src_tokenizer=src_tokenizer,
    tgt_tokenizer=tgt_tokenizer,
    max_len=config.model.seq_len,
)

In [11]:
for index in range(0, 50, 2):
    text_ar = test_dataset[index]["translation"]["ar"]
    text_en = test_dataset[index]["translation"]["en"]
    text_en_pred = translate_from_arabic_to_english(text_ar, key=jr.key(0))

    print(
        f"input arabic: {text_ar}\n"
        f"true english: {text_en}\n"
        f"pred english: {text_en_pred}\n"
    )

input arabic: حسناً ، تبعاً للتقرير هناك ثلاث عبوات ماء مفقودة
true english: Well, according to the report, there were three water bottles missing.
pred english: Well , according to the report is three of the missing water .

input arabic: هل سترجع للعيش معنا هنا
true english: (Maya) So Are You Gonna Move Back In,
pred english: You ' re gonna get him in here .

input arabic: مما أنت خائف؟
true english: What are you scared of?
pred english: Which you ' re afraid ?

input arabic: هيّا يا مهدّئات العضلات.
true english: Come on, muscle relaxers.
pred english: Come on , .

input arabic: كما انك طردتها.
true english: As you rejected it.
pred english: As you ' re a .

input arabic: ثم أنه ليس له.
true english: Then it's not him.
pred english: Then he ' s not for him .

input arabic: وينبغي ألا تقدم التصويبات إلا للنص باللغات الأصلية.
true english: Corrections should be submitted to the original languages only.
pred english: Corrections should not be submitted only in the original languages .
