In [1]:
import time
from typing import Any, MutableMapping, NamedTuple, Tuple

from absl import app
from absl import flags
from absl import logging
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax

import dataset
import model

from datasets import load_dataset
from tqdm.autonotebook import tqdm

In [2]:
# Global switches and hyper-parameters shared by the cells below.
SEED = 42                # seed for the JAX PRNG key
IS_TRAINING = True       # training-mode flag handed to the model

SEQ_LENGTH = 512         # tokenizer truncation / padding length
MAX_STEPS = 1_000_000    # upper bound on the number of update steps
LOG_EVERY = 50           # emit metrics once every this many steps

LEARNING_RATE = 3e-4     # Adam learning rate
GRAD_CLIP_VALUE = 1      # global-norm threshold for gradient clipping

# Dataset & tokenizer

## Peek at our data (English-Spanish)

In [3]:
# Load the English-Spanish Europarl training split so we can look at a few rows.
# NOTE: this rebinds `dataset`, shadowing the `dataset` module imported in the first cell.
dataset = load_dataset(
    "avacaondata/europarl_en_es_v2",
    split="train",
)
dataset



Dataset({
    features: ['id', 'source_en', 'target_es', '__index_level_0__'],
    num_rows: 275203
})

In [4]:
# Display the first record to see the raw fields of one example.
dataset[0]

{'id': 19512,
 'source_en': "Let the Commission go about its business and produce an amended version of the exceptions to heavy metals in two years' time.\nFinally, I should definitely like to mention the discussion on brominated flame retardants.\nIt has become a discussion between believers, such as myself, convinced of the harmful effects on the environment and health, and non-believers.\nWhat I find important is that, for many, this discussion has led to a greater appreciation of the harmfulness of these products.\nA ban in 2006 is impossible.\nThe amendment tabled by my group asks for producers to demonstrate by 2003 that these products are harmless, and I hope that this can be achieved.\nMr President, more than thirty years ago an organisation called 'Friends of the Earth' was born in my country.\n",
 'target_es': 'Dejemos que la Comisión efectúe su trabajo y presente dentro de dos años una versión adecuada de las excepciones en cuanto a metales pesados.\nPor último, quiero entra

## Train our tokenizers for English and Spanish

We have prepared the following python code for you:

- train_tokenizer_en.py
- train_tokenizer_es.py

If you choose a different language pair, adapt the scripts to match your data format and output directory.\
It is a good idea to run them in separate terminals, since each takes a while to finish.\
For reference, training took about 8m19s for the English tokenizer and 12m42s for the Spanish tokenizer on a server at Argonne.

NOTE: non-space-separated languages, such as Chinese, Japanese, Korean, Thai, need further adaptation to the tokenizer pipeline.

Useful readings:
- https://huggingface.co/docs/tokenizers/quicktour
- https://huggingface.co/course/chapter6/8?fw=pt#building-a-tokenizer-block-by-block
- https://huggingface.co/docs/tokenizers/index
- https://www.reddit.com/r/MachineLearning/comments/rprmq3/d_sentencepiece_wordpiece_bpe_which_tokenizer_is/

## Tokenize the text data
Suppose our English and Spanish tokenizers have been trained and saved to `./vanilla-NMT/en` and `./vanilla-NMT/es`.

In [5]:
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [6]:
# Keyword arguments shared by both wrappers: the special-token strings plus
# right-side padding and truncation.
_FAST_TOKENIZER_KWARGS = dict(
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    cls_token="<cls>",
    sep_token="<sep>",
    mask_token="<mask>",
    padding_side="right",
    truncation_side='right',
)

# English tokenizer (source side).
tokenizer_english = Tokenizer.from_file("vanilla-NMT/en/tokenizer.json")
src_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer_english,
    **_FAST_TOKENIZER_KWARGS,
)

# Spanish tokenizer (target side).
tokenizer_spanish = Tokenizer.from_file("vanilla-NMT/es/tokenizer.json")
tgt_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer_spanish,
    **_FAST_TOKENIZER_KWARGS,
)

In [7]:
# Show the source tokenizer's repr to check its vocabulary size and special tokens.
src_tokenizer

PreTrainedTokenizerFast(name_or_path='', vocab_size=25007, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '<sep>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>'})

In [8]:
# Show the target tokenizer's repr to check its vocabulary size and special tokens.
tgt_tokenizer

PreTrainedTokenizerFast(name_or_path='', vocab_size=25007, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '<sep>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>'})

In [9]:
# Re-open the same split in streaming mode for training.
# NOTE: this replaces the `dataset` object created in the preview cell above.
dataset = load_dataset(
    "avacaondata/europarl_en_es_v2",
    split="train",
    streaming=True,
)

# Map function applied to the dataset: turns raw text into token ids.
def encode(examples):
    """Tokenize the English and Spanish text of `examples`.

    Reads `examples['source_en']` with `src_tokenizer` and
    `examples['target_es']` with `tgt_tokenizer`, truncating and padding both
    to `SEQ_LENGTH`, and keeps only the resulting `input_ids`.

    Returns:
        A dict with keys 'src_inputs' and 'tgt_inputs' holding the token ids.
    """
    # Identical settings for both languages; only the ids are requested.
    shared_kwargs = dict(
        truncation=True,
        max_length=SEQ_LENGTH,
        padding='max_length',
        return_token_type_ids=False,
        return_attention_mask=False,
    )
    source_encoding = src_tokenizer(examples['source_en'], **shared_kwargs)
    target_encoding = tgt_tokenizer(examples['target_es'], **shared_kwargs)
    return {
        'src_inputs': source_encoding['input_ids'],
        'tgt_inputs': target_encoding['input_ids'],
    }

# Apply `encode` in batches, drop the raw columns, and wrap the result in an
# iterator so that `next(dataset)` hands out encoded examples.
# NOTE: after this line `dataset` is an iterator, so this cell cannot be re-run
# without first re-running the load above.
dataset = iter(
    dataset.map(
        encode,
        batched=True,
        remove_columns=["id", "source_en", "target_es", "__index_level_0__"],
    )
)



In [10]:
# Sanity check: decode one encoded example back to text.
# NOTE: next(dataset) consumes this example from the shared iterator, so it is
# not seen again by later cells.
_preview_ids = next(dataset)['src_inputs']
_preview_text = src_tokenizer.decode(_preview_ids, skip_special_tokens=True)
print(_preview_text.replace('▁', ''))

Let the Commission go about its business and produce an amended version of the exceptions to heavy metals in two years' time.
Finally, I should definitely like to mention the discussion on brom inated flame retardants.
It has become a discussion between believers, such as myself, convinced of the harmful effects on the environment and health, and non-believers.
What I find important is that, for many, this discussion has led to a greater appreciation of the harmful ness of these products.
A ban in 2006 is impossible.
The amendment tabled by my group asks for producers to demonstrate by 2003 that these products are harmless, and I hope that this can be achieved.
Mr President, more than thirty years ago an organisation called 'F rie nd s of the Earth'was born in my country.



In [11]:
# Decode token id 3 to see which token it maps to (the recorded output is '<pad>').
src_tokenizer.decode([3])

'<pad>'

# Training functions and settings

In [12]:
# Model hyper-parameters handed to model.TransformerConfig.
# Vocabulary sizes and pad ids are read from the tokenizers so they cannot drift apart.
CONFIG = model.TransformerConfig(
    # vocabulary / padding
    input_vocab_size=src_tokenizer.vocab_size,
    output_vocab_size=tgt_tokenizer.vocab_size,
    src_pad_token=src_tokenizer.pad_token_id,
    tgt_pad_token=tgt_tokenizer.pad_token_id,
    # architecture (exact field meanings are defined by model.TransformerConfig)
    num_layers=6,
    num_heads=8,
    model_size=256,
    hidden_size=512,
    dropout_rate=0.1,
)

In [13]:
class TrainingState(NamedTuple):
    """Container for everything that changes from one training step to the next.

    Fields:
        params: model parameters, as produced by `loss_fn.init`.
        opt_state: state of the Optax optimizer.
        rng: PRNG key carried over to the next step.
        step: counter of update steps applied so far.
    """
    params: hk.Params
    opt_state: optax.OptState
    # `jnp.DeviceArray` no longer exists in current JAX releases, and NamedTuple
    # annotations are evaluated when the class is created, so the old
    # annotation makes this cell fail there. `jnp.ndarray` works on old and new
    # versions and matches the annotations used by `forward`.
    rng: jnp.ndarray
    step: jnp.ndarray

In [14]:
# Create the model.
def forward(
    src_inputs: jnp.ndarray,
    tgt_inputs: jnp.ndarray,
    is_training: bool,
) -> jnp.ndarray:
    """Build the Transformer from `CONFIG` and run it on one batch.

    Args:
        src_inputs: source-side token ids.
        tgt_inputs: target-side token ids.
        is_training: training-mode flag; forwarded both to the model
            constructor and to the model call.

    Returns:
        Whatever `model.Transformer.__call__` returns (used as logits by `loss_fn`).
    """
    # Use the argument rather than the global IS_TRAINING, so a caller passing
    # is_training=False does not get a model constructed in training mode.
    lm = model.Transformer(
        config=CONFIG,
        is_training=is_training,
    )
    return lm(src_inputs, tgt_inputs, is_training=is_training)

In [15]:
# Gradient transformation: clip by global norm first, then apply Adam.
optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=GRAD_CLIP_VALUE),
    optax.adam(learning_rate=LEARNING_RATE, b1=0.9, b2=0.99),
)

In [16]:
@hk.transform
def loss_fn(data) -> jnp.ndarray:
    """Mean negative log-likelihood per non-padding target token of one example.

    Args:
        data: mapping with 'src_inputs' and 'tgt_inputs', each a sequence of
            token ids for a single example.

    Returns:
        Scalar loss: summed NLL over non-padding target positions divided by
        the number of such positions.
    """
    # Add a leading batch axis of size 1 to each id sequence.
    src_inputs = jnp.asarray(data['src_inputs'], dtype=jnp.int32)[None, :]
    tgt_inputs = jnp.asarray(data['tgt_inputs'], dtype=jnp.int32)[None, :]

    # NOTE(review): the same `tgt_inputs` serve as decoder input and as labels.
    # This is only correct if model.Transformer shifts the target internally
    # -- confirm against model.py.
    logits = forward(src_inputs, tgt_inputs, IS_TRAINING)
    targets = jax.nn.one_hot(tgt_inputs, CONFIG.output_vocab_size)
    assert logits.shape == targets.shape

    # Mask out padding by its real id. The previous `tgt_inputs > 0` test kept
    # every padded position (the pad id is not 0) and dropped token id 0.
    mask = jnp.not_equal(tgt_inputs, tgt_tokenizer.pad_token_id)
    log_likelihood = jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    # Guard the division so an all-padding target gives 0 rather than NaN.
    num_tokens = jnp.maximum(jnp.sum(mask), 1)
    return -jnp.sum(log_likelihood * mask) / num_tokens  # NLL per token

In [17]:
_Metrics = MutableMapping[str, Any]

@jax.jit
def update(state: TrainingState, data) -> Tuple[TrainingState, _Metrics]:
    """Run one optimizer update on a single example and report metrics.

    Splits the PRNG key in `state`, differentiates `loss_fn.apply` with respect
    to the parameters, pushes the gradients through `optimizer`, and returns
    the advanced state together with a metrics dict holding the step counter
    before the update ('step') and the loss ('loss').
    """
    # One half of the key is used for this step, the other is carried forward.
    step_rng, next_rng = jax.random.split(state.rng)
    loss, gradients = jax.value_and_grad(loss_fn.apply)(state.params, step_rng, data)

    updates, next_opt_state = optimizer.update(gradients, state.opt_state)
    next_params = optax.apply_updates(state.params, updates)

    metrics = {'step': state.step, 'loss': loss}
    next_state = TrainingState(
        params=next_params,
        opt_state=next_opt_state,
        rng=next_rng,
        step=state.step + 1,
    )
    return next_state, metrics

In [18]:
@jax.jit
def init(rng: jnp.ndarray, data) -> TrainingState:
    """Create the initial training state from a PRNG key and one example.

    The key is split in two: one half initialises the parameters via
    `loss_fn.init`, the other is stored in the state for later steps. The
    optimizer state is initialised from the fresh parameters and the step
    counter starts at zero.
    """
    state_rng, params_rng = jax.random.split(rng)
    params = loss_fn.init(params_rng, data)
    return TrainingState(
        params=params,
        opt_state=optimizer.init(params),
        rng=state_rng,
        step=np.array(0),
    )

In [None]:
# absl only raises its verbosity to INFO once flags are parsed, which never
# happens in a notebook; without this the logging.info calls below print nothing.
logging.set_verbosity(logging.INFO)

# Initialise parameters and optimizer state from the first available example.
rng = jax.random.PRNGKey(SEED)
data = next(dataset)
state = init(rng, data)

prev_time = time.time()
steps_logged = 0  # number of steps already covered by a throughput report
for step in range(MAX_STEPS):
    # The stream holds far fewer examples than MAX_STEPS, so stop cleanly when
    # it runs out instead of crashing with StopIteration.
    data = next(dataset, None)
    if data is None:
        logging.info('Dataset exhausted after %d steps; stopping early.', step)
        break

    state, metrics = update(state, data)

    if step % LOG_EVERY == 0:
        # Converting to float waits for the device results, so the timestamp
        # taken afterwards reflects completed work.
        report = {k: float(v) for k, v in metrics.items()}
        now = time.time()
        # Count the steps actually run since the last report (1 at step 0,
        # LOG_EVERY afterwards) rather than always assuming LOG_EVERY.
        steps_done = step + 1
        report['step_per_sec'] = (steps_done - steps_logged) / (now - prev_time)
        steps_logged = steps_done
        prev_time = now
        logging.info(report)