In [1]:
import time
from typing import Any, MutableMapping, NamedTuple, Tuple

from absl import app
from absl import flags
from absl import logging
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax

import dataset
import model

from datasets import load_dataset
from tqdm.autonotebook import tqdm

In [2]:
IS_TRAINING = True

LEARNING_RATE = 3e-4
SEQ_LENGTH = 512
GRAD_CLIP_VALUE = 1
LOG_EVERY = 50
MAX_STEPS = 2000
SEED = 42

# dataset & tokenizer

## Peek at our data (English-Spanish)

In [3]:
dataset = load_dataset("avacaondata/europarl_en_es_v2", split='train')
dataset



Dataset({
    features: ['id', 'source_en', 'target_es', '__index_level_0__'],
    num_rows: 275203
})

In [4]:
dataset[0]

{'id': 19512,
 'source_en': "Let the Commission go about its business and produce an amended version of the exceptions to heavy metals in two years' time.\nFinally, I should definitely like to mention the discussion on brominated flame retardants.\nIt has become a discussion between believers, such as myself, convinced of the harmful effects on the environment and health, and non-believers.\nWhat I find important is that, for many, this discussion has led to a greater appreciation of the harmfulness of these products.\nA ban in 2006 is impossible.\nThe amendment tabled by my group asks for producers to demonstrate by 2003 that these products are harmless, and I hope that this can be achieved.\nMr President, more than thirty years ago an organisation called 'Friends of the Earth' was born in my country.\n",
 'target_es': 'Dejemos que la Comisión efectúe su trabajo y presente dentro de dos años una versión adecuada de las excepciones en cuanto a metales pesados.\nPor último, quiero entra

## Train our tokenizers for English and Spanish

We have prepared the following python code for you:

- train_tokenizer_en.py
- train_tokenizer_es.py

Should you choose to use other language, please change the codes to fit the data format and output directory.\
Running them in separate terminals would be a good idea, as each takes some time to finish.\
For me, it took about 8m19s and 12m42s for English and Spanish tokenizers to train on a server at Argonne.

NOTE: non-space-separated languages, such as Chinese, Japanese, Korean, Thai, need further adaptation to the tokenizer pipeline.

Useful readings:
- https://huggingface.co/docs/tokenizers/quicktour
- https://huggingface.co/course/chapter6/8?fw=pt#building-a-tokenizer-block-by-block
- https://huggingface.co/docs/tokenizers/index
- https://www.reddit.com/r/MachineLearning/comments/rprmq3/d_sentencepiece_wordpiece_bpe_which_tokenizer_is/

## Tokenize the text data
suppose we have our English and Spanish tokenizers trained and stored at `./vanilla-NMT/en` and `./vanilla-NMT/es`

In [5]:
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [6]:
tokenizer_english = Tokenizer.from_file("vanilla-NMT/en/tokenizer.json")

src_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer_english,
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    cls_token="<cls>",
    sep_token="<sep>",
    mask_token="<mask>",
    padding_side="right",
    truncation_side='right',
)

tokenizer_spanish = Tokenizer.from_file("vanilla-NMT/es/tokenizer.json")
tgt_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer_spanish,
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    cls_token="<cls>",
    sep_token="<sep>",
    mask_token="<mask>",
    padding_side="right",
    truncation_side='right',
)

In [7]:
src_tokenizer

PreTrainedTokenizerFast(name_or_path='', vocab_size=25007, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '<sep>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>'})

In [8]:
tgt_tokenizer

PreTrainedTokenizerFast(name_or_path='', vocab_size=25007, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '<sep>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>'})

In [9]:
# we use streaming version of dataset
dataset = load_dataset("avacaondata/europarl_en_es_v2", split='train', streaming=True)

# encode function to map on each dataset entry
def encode(examples):
    def decorate(text, tokenizer):
        decorated = f"{tokenizer.bos_token} {text} {tokenizer.eos_token}"
        decorated = decorated.replace('\n', tokenizer.sep_token)
        return decorated
    
    src_inputs = src_tokenizer(
        decorate(examples['source_en'], src_tokenizer), 
        truncation=True, max_length=SEQ_LENGTH, padding='max_length',
        return_token_type_ids=False,
        return_attention_mask=False,
    )['input_ids']
    tgt_inputs = tgt_tokenizer(
        decorate(examples['target_es'], tgt_tokenizer),
        truncation=True, max_length=SEQ_LENGTH, padding='max_length',
        return_token_type_ids=False,
        return_attention_mask=False,
    )['input_ids']
    return {
        'src_inputs': src_inputs,
        'tgt_inputs': tgt_inputs,
    }

# now dataset is a iter object
dataset = iter(dataset.map(encode, remove_columns=["id", "source_en", "target_es", "__index_level_0__"]))



In [10]:
# next(dataset)

In [11]:
print(src_tokenizer.decode(next(dataset)['src_inputs'], skip_special_tokens=True).replace('▁', ''))

Let the Commission go about its business and produce an amended version of the exceptions to heavy metals in two years' time.  Finally, I should definitely like to mention the discussion on brom inated flame retardants. It has become a discussion between believers, such as myself, convinced of the harmful effects on the environment and health, and non-believers.  What I find important is that, for many, this discussion has led to a greater appreciation of the harmful ness of these products. A ban in 2006 is impossible. The amendment tabled by my group asks for producers to demonstrate by 2003 that these products are harmless, and I hope that this can be achieved. Mr President, more than thirty years ago an organisation called 'F rie nd s of the Earth'was born in my country. 


In [12]:
src_tokenizer.decode([3])

'<pad>'

# Training

`train.py` script is provided. Please run the code in terminal and find your ckpt at `./ckpt` folder.

# Use

In [13]:
from train import TrainingState, main

# some training and model parameters:
CONFIG = model.TransformerConfig(
    input_vocab_size=src_tokenizer.vocab_size,
    output_vocab_size=tgt_tokenizer.vocab_size,
    model_size=256,
    num_heads=8,
    num_layers=6,
    hidden_size=512,
    dropout_rate=0.1,
    src_pad_token=src_tokenizer.pad_token_id,
    tgt_pad_token=tgt_tokenizer.pad_token_id,
)
IS_TRAINING = False

# Create the model.
def forward(
    src_inputs: jnp.ndarray,
    tgt_inputs: jnp.ndarray,
    is_training: bool,
) -> jnp.ndarray:

    lm = model.Transformer(
        config=CONFIG,
        is_training=is_training
    )
    return lm(src_inputs, tgt_inputs, is_training=is_training)

In [14]:
import pickle
from datetime import datetime

ckpt_file = 'ckpt/state_02-Oct-2022 (01:54:46).pickle'

# Load data (deserialize)
with open(ckpt_file, 'rb') as handle:
    state = pickle.load(handle)

In [15]:
# predict_fn = hk.transform(forward)

def predict(src_inputs, tgt_max_len):

    src_inputs = jnp.asarray(src_inputs, dtype=jnp.int32)[None,:]
    tgt_inputs = jnp.asarray([tgt_tokenizer.bos_token_id], dtype=jnp.int32)[None,:]
    
    @hk.transform
    def one_step(src_inputs, tgt_inputs, is_training=False):
        predictions = forward(src_inputs=src_inputs, tgt_inputs=tgt_inputs, is_training=False)
        predictions = predictions[:, -1, :]
        predicted_id = jnp.argmax(predictions, axis=-1)
        return predicted_id
        
    for i in range(tgt_max_len):
        predicted_id = one_step.apply(state.params, state.rng, src_inputs, tgt_inputs, is_training=False)
        if predicted_id == tgt_tokenizer.eos_token_id:
            return output
        tgt_inputs = np.concatenate([tgt_inputs, [predicted_id]], axis=-1)
        print(i, tgt_inputs)
    return jnp.squeeze(tgt_inputs, axis=0)

In [20]:
def translate(sentence):
    # Load data (deserialize)
    ckpt_file = 'ckpt/state_02-Oct-2022 (01:54:46).pickle'
    with open(ckpt_file, 'rb') as handle:
        state = pickle.load(handle)
    tgt_max_len = 64
    src_sentence = f"{src_tokenizer.bos_token} {sentence} {src_tokenizer.eos_token}"
    src_inputs = src_tokenizer(
            src_sentence, 
            truncation=True, max_length=SEQ_LENGTH, padding='max_length',
            return_token_type_ids=False,
            return_attention_mask=False,
        )['input_ids']
    
    output = predict(src_inputs, tgt_max_len)
    print(output)

In [21]:
translate("This is a sentence in English.")

0 [[5 8]]
1 [[   5    8 1529]]
2 [[   5    8 1529    9]]
3 [[   5    8 1529    9   13]]
4 [[   5    8 1529    9   13   17]]
5 [[   5    8 1529    9   13   17    9]]
6 [[   5    8 1529    9   13   17    9    8]]
7 [[   5    8 1529    9   13   17    9    8   14]]
8 [[   5    8 1529    9   13   17    9    8   14    9]]
9 [[   5    8 1529    9   13   17    9    8   14    9   13]]
10 [[   5    8 1529    9   13   17    9    8   14    9   13   15]]
11 [[   5    8 1529    9   13   17    9    8   14    9   13   15  109]]
12 [[   5    8 1529    9   13   17    9    8   14    9   13   15  109    9]]
13 [[   5    8 1529    9   13   17    9    8   14    9   13   15  109    9
    13]]
14 [[   5    8 1529    9   13   17    9    8   14    9   13   15  109    9
    13   15]]
15 [[   5    8 1529    9   13   17    9    8   14    9   13   15  109    9
    13   15    9]]
16 [[   5    8 1529    9   13   17    9    8   14    9   13   15  109    9
    13   15    9    7]]
17 [[   5    8 1529    9   13   17    9