In [1]:
import mxnet as mx
import gluonnlp as nlp

import time
import random
import numpy as np
import sacremoses
from tqdm import tqdm_notebook as tqdm

# Local Libraries
import nmt
import dataprocessor
import utils
import nmt.gnmt_hparams

# Seeds for reproducibility
np.random.seed(100)
random.seed(100)
mx.random.seed(10000)

# CPU setup
# ctx = mx.cpu()
# Single GPU setup
ctx = mx.gpu(0)

[nltk_data] Downloading package punkt to /home/andreto/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/andreto/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /home/andreto/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
# Dataset Parameters
src_lang, tgt_lang = 'en', 'de'
# No limit on sentences length
src_max_len, tgt_max_len = -1, -1

In [3]:
# WMT2016 Dataset (Train and Evaluation)
wmt_train_text_bpe = nlp.data.WMT2016BPE("train", # BPE: cheapest --> cheap@@, est
                                         src_lang=src_lang,
                                         tgt_lang=tgt_lang)

wmt_train_text     = nlp.data.WMT2016("train",
                                      src_lang=src_lang,
                                      tgt_lang=tgt_lang)

wmt_test_text_bpe  = nlp.data.WMT2016BPE("newstest2016", # BPE: cheapest --> cheap@@, est
                                         src_lang=src_lang,
                                         tgt_lang=tgt_lang)

wmt_test_text      = nlp.data.WMT2016("newstest2016",
                                     src_lang=src_lang,
                                     tgt_lang=tgt_lang)

wmt_src_vocab = wmt_train_text_bpe.src_vocab
wmt_tgt_vocab = wmt_train_text_bpe.tgt_vocab

  'Detected a corrupted index in the deserialize vocabulary. '


In [4]:
# Processing datasets
# Filtering training data to a maximum number of samples,
# so that training can be handled in a reasonable time (~8 hrs)
# in single GPU setups
max_samples = int(1e4)
wmt_train_text_bpe = mx.gluon.data.SimpleDataset([wmt_train_text_bpe[i] for i in range(max_samples)])
wmt_train_text     = mx.gluon.data.SimpleDataset([wmt_train_text[i] for i in range(max_samples)])
wmt_test_text_bpe  = mx.gluon.data.SimpleDataset(wmt_test_text_bpe)
wmt_test_text      = mx.gluon.data.SimpleDataset(wmt_test_text)

In [5]:
# Dataset example (human-readable): English and German
print(wmt_test_text[16][0])
print(wmt_test_text[16][1])

By the end of the day, there would be one more death: Lamb took his own life as police closed in on him.
Bis zum Ende des Tages gab es einen weiteren Tod: Lamm nahm sich das Leben, als die Polizei ihn einkesselte.


In [6]:
# Retrieve (split) translated sequences (target)
wmt_train_tgt_sentences = wmt_train_text.transform(lambda src, tgt: tgt)
wmt_test_tgt_sentences  = wmt_test_text.transform(lambda src, tgt: tgt)
print("Sample target sentence:")
print(wmt_test_tgt_sentences[16])

Sample target sentence:
Bis zum Ende des Tages gab es einen weiteren Tod: Lamm nahm sich das Leben, als die Polizei ihn einkesselte.


In [7]:
# Dataset processing: clipping, tokenizing, indexing and adding of EOS (src/tgt) / BOS (tgt)
wmt_transform_fn = dataprocessor.TrainValDataTransform(wmt_src_vocab, wmt_tgt_vocab)

wmt_train_processed = wmt_train_text_bpe.transform(wmt_transform_fn, lazy=False)
wmt_test_processed  = wmt_test_text_bpe.transform(wmt_transform_fn, lazy=False)

wmt_train_text_with_len = wmt_train_processed.transform(nmt.utils.get_length_index_fn(), lazy=False)
wmt_test_text_with_len  = wmt_test_processed.transform(nmt.utils.get_length_index_fn(), lazy=False)

print(wmt_test_text_with_len[16][0])
print(wmt_test_text_with_len[16][1])

[ 2083 28753 16760 23875 28753 15230    28 28783 31223 12931 24017 23247
 15259   569  5971 12813 29083 20097 24348 22312 12290 24829 14439 20585
 24004 20061    62     3]
[    2  1897 31601  3259 15535  9414 18646 17382 16407 30851  9629   569
  5971 22642 23439 27119 15199  6041    28 11681 15681  7670 20454 16394
 21488 26868 28535    62     3]


In [8]:
# Batcher
wmt_batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(),                   # Source Token IDs
    nlp.data.batchify.Pad(),                   # Target Token IDs
    nlp.data.batchify.Stack(dtype='float32'),  # Source Sequence Length
    nlp.data.batchify.Stack(dtype='float32'),  # Target Sequence Length
    nlp.data.batchify.Stack())                 # Index

  'Padding value is not given and will be set automatically to 0 '


In [9]:
# Hyperparameters
hparams = nmt.gnmt_hparams

In [10]:
# Samplers
wmt_train_batch_sampler = nlp.data.FixedBucketSampler(
    lengths=wmt_train_text_with_len.transform(lambda src, tgt, src_len, tgt_len, idx: (src_len, tgt_len)),
    num_buckets=hparams.num_buckets,
    batch_size=hparams.batch_size)
print(wmt_train_batch_sampler.stats())

wmt_test_batch_sampler = nlp.data.FixedBucketSampler(
    lengths=wmt_test_text_with_len.transform(lambda src, tgt, src_len, tgt_len, idx: (src_len, tgt_len)),
    num_buckets=hparams.num_buckets,
    batch_size=hparams.test_batch_size)
print(wmt_test_batch_sampler.stats())

FixedBucketSampler:
  sample_num=10000, batch_num=159
  key=[(21, 25), (39, 44), (57, 63), (75, 82), (93, 101)]
  cnt=[3409, 3991, 1797, 622, 181]
  batch_size=[64, 64, 64, 64, 64]
FixedBucketSampler:
  sample_num=2999, batch_num=97
  key=[(23, 26), (43, 48), (63, 70), (83, 92), (103, 114)]
  cnt=[1417, 1191, 329, 56, 6]
  batch_size=[32, 32, 32, 32, 32]


In [11]:
# DataLoaders
wmt_train_data_loader = mx.gluon.data.DataLoader(
    wmt_train_text_with_len,
    batch_sampler=wmt_train_batch_sampler,
    batchify_fn=wmt_batchify_fn,
    num_workers=8)
print('Number of training batches:', len(wmt_train_data_loader))

wmt_test_data_loader = mx.gluon.data.DataLoader(
    wmt_test_text_with_len,
    batch_sampler=wmt_test_batch_sampler,
    batchify_fn=wmt_batchify_fn,
    num_workers=8)
print('Number of testing batches:', len(wmt_test_data_loader))

Number of training batches: 159
Number of testing batches: 97


In [12]:
# Model
encoder, decoder, one_step_ahead_decoder = nmt.gnmt.get_gnmt_encoder_decoder(
    hidden_size=hparams.num_hidden,
    dropout=hparams.dropout,
    num_layers=hparams.num_layers,
    num_bi_layers=hparams.num_bi_layers)

gnmt_model = nlp.model.translation.NMTModel(
    src_vocab=wmt_src_vocab,
    tgt_vocab=wmt_tgt_vocab,
    encoder=encoder,
    decoder=decoder,
    one_step_ahead_decoder=one_step_ahead_decoder,
    embed_size=hparams.num_hidden,
    prefix='gnmt_')

gnmt_model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
static_alloc = True
gnmt_model.hybridize(static_alloc=static_alloc)

In [13]:
scorer=nlp.model.BeamSearchScorer(
    alpha=hparams.lp_alpha,
    K=hparams.lp_k)

gnmt_translator = nmt.translation.BeamSearchTranslator(
    model=gnmt_model,
    beam_size=hparams.beam_size,
    scorer=scorer,
    max_length=tgt_max_len + 100)

print("Use beam_size={}, alpha={}, K={}".format(hparams.beam_size, hparams.lp_alpha, hparams.lp_k))

Use beam_size=10, alpha=1.0, K=5


In [14]:
# Evaluation (Baseline)
eval_start_time = time.time()
wmt_loss_function = nlp.loss.MaskedSoftmaxCELoss()
wmt_loss_function.hybridize()
wmt_detokenizer = nlp.data.SacreMosesDetokenizer()

gnmt_test_loss, gnmt_test_translation_out = nmt.utils.evaluate(
    gnmt_model,
    wmt_test_data_loader,
    wmt_loss_function,
    gnmt_translator,
    wmt_tgt_vocab,
    wmt_detokenizer,
    ctx)

gnmt_test_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu(
    [wmt_test_tgt_sentences],
    gnmt_test_translation_out,
    tokenized=False,
    tokenizer=hparams.bleu,
    split_compound_word=False,
    bpe=False)

print('WMT16 EN-DE GNMT model test loss: %.2f; test bleu score: %.2f; time cost %.2fs' %(gnmt_test_loss, gnmt_test_bleu_score * 100, (time.time() - eval_start_time)))

  0%|          | 0/97 [00:00<?, ?it/s]

Extension horovod.torch has not been built: /home/ubuntu/anaconda3/envs/mxnet_p37/lib/python3.7/site-packages/horovod/torch/mpi_lib/_mpi_lib.cpython-37m-x86_64-linux-gnu.so not found
If this is not expected, reinstall Horovod with HOROVOD_WITH_PYTORCH=1 to debug the build error.
[2022-06-05 12:17:37.605 ip-172-31-28-47:9936 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2022-06-05 12:17:37.636 ip-172-31-28-47:9936 INFO profiler_config_parser.py:111] Unable to find config at /opt/ml/input/config/profilerconfig.json. Profiler is disabled.
WMT16 EN-DE SOTA model test loss: 7.75; test bleu score: 0.00; time cost 125.61s


In [15]:
# Training
trainer = mx.gluon.Trainer(gnmt_model.collect_params(), 'adam', {'learning_rate': hparams.lr})

In [19]:
hparams.epochs = 1

nmt.utils.train(
    gnmt_model,
    wmt_train_data_loader,
    wmt_test_data_loader,
    wmt_loss_function,
    trainer,
    gnmt_translator,
    wmt_tgt_vocab,
    wmt_test_tgt_sentences,
    wmt_detokenizer,
    hparams.file_name,
    hparams,
    ctx)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/159 [00:00<?, ?it/s]

[Epoch 0 Batch 100/159] loss=6.3304, ppl=561.3720, gnorm=0.3461, throughput=20.05K wps, wc=486.77K


  0%|          | 0/97 [00:00<?, ?it/s]

[Epoch 0] valid Loss=5.7555, valid ppl=315.9161, valid bleu=0.12
Save best parameters to gnmt_en_de_512.params
Learning rate change to 0.0005


In [None]:
print("Qualitative Evaluation: Translating from English to German:")

sample_src_seq = "I love reading technical books from Packt."
print("[\'" + sample_src_seq + "\']")

sample_tgt_seq = nmt.utils.translate(
    transformer_translator,
    sample_src_seq,
    wmt_src_vocab,
    wmt_tgt_vocab,
    wmt_detokenizer,
    ctx)

print("The German translation is:")
print(sample_tgt_seq)