In [1]:
import mxnet as mx
import gluonnlp as nlp

import time
import random
import numpy as np
import sacremoses
from tqdm import tqdm_notebook as tqdm

# Local Libraries
import nmt
import dataprocessor
import utils
import nmt.transformer_hparams

# Seeds for reproducibility
np.random.seed(100)
random.seed(100)
mx.random.seed(10000)

# CPU setup
# ctx = mx.cpu()
# Single GPU setup
ctx = mx.gpu(0)

[nltk_data] Downloading package punkt to /home/andreto/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/andreto/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /home/andreto/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
# WMT2016 Dataset (Train and Evaluation)

# Dataset Parameters
src_lang, tgt_lang = 'en', 'de'

wmt_train_text_bpe = nlp.data.WMT2016BPE("train", # BPE: cheapest --> cheap@@, est
                                         src_lang=src_lang,
                                         tgt_lang=tgt_lang)

wmt_train_text     = nlp.data.WMT2016("train",
                                      src_lang=src_lang,
                                      tgt_lang=tgt_lang)

wmt_test_text_bpe  = nlp.data.WMT2016BPE("newstest2016", # BPE: cheapest --> cheap@@, est
                                         src_lang=src_lang,
                                         tgt_lang=tgt_lang)

wmt_test_text      = nlp.data.WMT2016("newstest2016",
                                     src_lang=src_lang,
                                     tgt_lang=tgt_lang)

In [3]:
# Processing datasets
# Filtering training data to a maximum number of samples,
# so that training can be handled in a reasonable time (~8 hrs)
# in single GPU setups
max_samples = int(1e4)
wmt_train_text_bpe = mx.gluon.data.SimpleDataset([wmt_train_text_bpe[i] for i in range(max_samples)])
wmt_train_text     = mx.gluon.data.SimpleDataset([wmt_train_text[i] for i in range(max_samples)])
wmt_test_text_bpe  = mx.gluon.data.SimpleDataset(wmt_test_text_bpe)
wmt_test_text      = mx.gluon.data.SimpleDataset(wmt_test_text)

In [4]:
# Dataset example (human-readable): English and German
print(wmt_test_text[16][0])
print(wmt_test_text[16][1])

By the end of the day, there would be one more death: Lamb took his own life as police closed in on him.
Bis zum Ende des Tages gab es einen weiteren Tod: Lamm nahm sich das Leben, als die Polizei ihn einkesselte.


In [5]:
# Retrieve (split) translated sequences (target)
wmt_train_tgt_sentences = wmt_train_text.transform(lambda src, tgt: tgt)
wmt_test_tgt_sentences  = wmt_test_text.transform(lambda src, tgt: tgt)
print("Sample target sentence:")
print(wmt_test_tgt_sentences[16])

Sample target sentence:
Bis zum Ende des Tages gab es einen weiteren Tod: Lamm nahm sich das Leben, als die Polizei ihn einkesselte.


In [8]:
# Model
transformer_model, wmt_src_vocab, wmt_tgt_vocab = nlp.model.get_model(
    "transformer_en_de_512",
    dataset_name="WMT2014",
    pretrained=True,
    # pretrained=False,
    ctx=ctx)

transformer_model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
static_alloc = True
transformer_model.hybridize(static_alloc=static_alloc)

In [7]:
# Dataset processing: clipping, tokenizing, indexing and adding of EOS (src/tgt) / BOS (tgt)
wmt_transform_fn = dataprocessor.TrainValDataTransform(wmt_src_vocab, wmt_tgt_vocab)

wmt_train_processed = wmt_train_text_bpe.transform(wmt_transform_fn, lazy=False)
wmt_test_processed  = wmt_test_text_bpe.transform(wmt_transform_fn, lazy=False)

wmt_train_text_with_len = wmt_train_processed.transform(nmt.utils.get_length_index_fn(), lazy=False)
wmt_test_text_with_len  = wmt_test_processed.transform(nmt.utils.get_length_index_fn(), lazy=False)

print(wmt_test_text_with_len[16][0])
print(wmt_test_text_with_len[16][1])

[ 2105 28768 16772 23915 28768 15253    24 28798 31241 12938 24036 23280
 15283   576  5994 12819 29096 20120 24350 22332 12302 24827 14456 20608
 24023 20084    58     3]
[    2  1914 31623  3287 15561  9451 18640 17378 16427 30867  9672   576
  5994 22652 23470 27121 15222  6064    24 11711 15705  7698 20480 16414
 21512 26873 28556    58     3]


In [8]:
# Batcher
wmt_batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(),                   # Source Token IDs
    nlp.data.batchify.Pad(),                   # Target Token IDs
    nlp.data.batchify.Stack(dtype='float32'),  # Source Sequence Length
    nlp.data.batchify.Stack(dtype='float32'),  # Target Sequence Length
    nlp.data.batchify.Stack())                 # Index

  'Padding value is not given and will be set automatically to 0 '


In [9]:
# Hyperparameters
hparams = nmt.transformer_hparams

In [10]:
# Samplers
wmt_train_batch_sampler = nlp.data.FixedBucketSampler(
    lengths=wmt_train_text_with_len.transform(lambda src, tgt, src_len, tgt_len, idx: (src_len, tgt_len)),
    num_buckets=hparams.num_buckets,
    batch_size=hparams.batch_size)
print(wmt_train_batch_sampler.stats())

wmt_test_batch_sampler = nlp.data.FixedBucketSampler(
    lengths=wmt_test_text_with_len.transform(lambda src, tgt, src_len, tgt_len, idx: (src_len, tgt_len)),
    num_buckets=hparams.num_buckets,
    batch_size=hparams.test_batch_size)
print(wmt_test_batch_sampler.stats())

FixedBucketSampler:
  sample_num=10000, batch_num=168
  key=[(17, 25), (21, 29), (25, 33), (29, 37), (33, 41), (37, 45), (41, 49), (45, 53), (49, 57), (53, 61), (57, 65), (61, 69), (65, 73), (69, 77), (73, 81), (77, 85), (81, 89), (85, 93), (89, 97), (93, 101)]
  cnt=[2474, 1053, 1076, 971, 869, 730, 594, 470, 386, 335, 266, 199, 137, 116, 105, 92, 63, 37, 17, 10]
  batch_size=[64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64]
FixedBucketSampler:
  sample_num=2999, batch_num=62
  key=[(8, 19), (13, 24), (18, 29), (23, 34), (28, 39), (33, 44), (38, 49), (43, 54), (48, 59), (53, 64), (58, 69), (63, 74), (68, 79), (73, 84), (78, 89), (83, 94), (88, 99), (93, 104), (98, 109), (103, 114)]
  cnt=[92, 392, 529, 475, 385, 323, 255, 198, 131, 85, 52, 26, 23, 17, 3, 7, 1, 2, 1, 2]
  batch_size=[64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64]


In [11]:
# DataLoaders
wmt_train_data_loader = mx.gluon.data.DataLoader(
    wmt_train_text_with_len,
    batch_sampler=wmt_train_batch_sampler,
    batchify_fn=wmt_batchify_fn,
    num_workers=8)
print('Number of training batches:', len(wmt_train_data_loader))

wmt_test_data_loader = mx.gluon.data.DataLoader(
    wmt_test_text_with_len,
    batch_sampler=wmt_test_batch_sampler,
    batchify_fn=wmt_batchify_fn,
    num_workers=8)
print('Number of testing batches:', len(wmt_test_data_loader))

Number of training batches: 168
Number of testing batches: 62


In [12]:
# For Evaluation
scorer=nlp.model.BeamSearchScorer(
    alpha=hparams.lp_alpha,
    K=hparams.lp_k)

transformer_translator = nmt.translation.BeamSearchTranslator(
    model=transformer_model,
    beam_size=hparams.beam_size,
    scorer=scorer,
    max_length=hparams.max_length)

print("Use beam_size={}, alpha={}, K={}".format(hparams.beam_size, hparams.lp_alpha, hparams.lp_k))

Use beam_size=4, alpha=0.6, K=5


In [13]:
# Evaluation (Baseline)
eval_start_time = time.time()
wmt_loss_function = nlp.loss.MaskedSoftmaxCELoss()
wmt_loss_function.hybridize()
wmt_detokenizer = nlp.data.SacreMosesDetokenizer()

transformer_test_loss, transformer_test_translation_out = nmt.utils.evaluate(
    transformer_model,
    wmt_test_data_loader,
    wmt_loss_function,
    transformer_translator,
    wmt_tgt_vocab,
    wmt_detokenizer,
    ctx)

transformer_test_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu(
    [wmt_test_tgt_sentences],
    transformer_test_translation_out,
    tokenized=False,
    tokenizer=hparams.bleu,
    split_compound_word=False,
    bpe=False)

print('WMT16 EN-DE Transformer model test loss: %.2f; test bleu score: %.2f; time cost %.2fs' %(transformer_test_loss, transformer_test_bleu_score * 100, (time.time() - eval_start_time)))

  0%|          | 0/62 [00:00<?, ?it/s]

Extension horovod.torch has not been built: /home/ubuntu/anaconda3/envs/mxnet_p37/lib/python3.7/site-packages/horovod/torch/mpi_lib/_mpi_lib.cpython-37m-x86_64-linux-gnu.so not found
If this is not expected, reinstall Horovod with HOROVOD_WITH_PYTORCH=1 to debug the build error.
[2022-06-11 11:40:22.772 ip-172-31-28-47:26085 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2022-06-11 11:40:22.802 ip-172-31-28-47:26085 INFO profiler_config_parser.py:111] Unable to find config at /opt/ml/input/config/profilerconfig.json. Profiler is disabled.
WMT16 EN-DE Transformer model test loss: 1.39; test bleu score: 30.19; time cost 183.00s


In [14]:
# Training
trainer = mx.gluon.Trainer(transformer_model.collect_params(), 'adam', {'learning_rate': hparams.lr})

In [15]:
hparams.epochs = 10

test_loss, test_translation_out = nmt.utils.train(
    transformer_model,
    wmt_train_data_loader,
    wmt_test_data_loader,
    wmt_loss_function,
    trainer,
    # transformer_translator,
    wmt_tgt_vocab,
    wmt_test_tgt_sentences,
    wmt_detokenizer,
    hparams.file_name,
    hparams,
    ctx)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/168 [00:00<?, ?it/s]

[Epoch 0 Batch 10/168] loss=10.6772, ppl=43355.0385, gnorm=11.6596, throughput=16.98K wps, wc=63.87K
[Epoch 0 Batch 20/168] loss=8.3144, ppl=4082.0349, gnorm=1.6977, throughput=28.21K wps, wc=61.03K
[Epoch 0 Batch 30/168] loss=7.7640, ppl=2354.2756, gnorm=0.6805, throughput=27.07K wps, wc=53.92K
[Epoch 0 Batch 40/168] loss=7.5429, ppl=1887.2953, gnorm=0.4322, throughput=26.81K wps, wc=51.33K
[Epoch 0 Batch 50/168] loss=7.4008, ppl=1637.3582, gnorm=0.2987, throughput=27.64K wps, wc=47.70K
[Epoch 0 Batch 60/168] loss=7.3407, ppl=1541.7434, gnorm=0.2538, throughput=25.07K wps, wc=42.88K
[Epoch 0 Batch 70/168] loss=7.2820, ppl=1453.8531, gnorm=0.2687, throughput=25.32K wps, wc=39.54K
[Epoch 0 Batch 80/168] loss=7.2411, ppl=1395.6796, gnorm=0.2621, throughput=24.59K wps, wc=37.88K
[Epoch 0 Batch 90/168] loss=7.2312, ppl=1381.9313, gnorm=0.2400, throughput=22.62K wps, wc=34.78K
[Epoch 0 Batch 100/168] loss=7.2237, ppl=1371.5189, gnorm=0.2713, throughput=21.69K wps, wc=29.41K
[Epoch 0 Batch 1

  0%|          | 0/62 [00:00<?, ?it/s]

[Epoch 0] valid Loss=7.2806, valid ppl=1451.8367, valid bleu=0.00


  0%|          | 0/168 [00:00<?, ?it/s]

[Epoch 1 Batch 10/168] loss=7.4388, ppl=1700.6920, gnorm=0.6407, throughput=22.08K wps, wc=63.87K
[Epoch 1 Batch 20/168] loss=7.1792, ppl=1311.8772, gnorm=0.4267, throughput=28.18K wps, wc=61.03K
[Epoch 1 Batch 30/168] loss=7.1938, ppl=1331.1514, gnorm=0.3529, throughput=27.35K wps, wc=53.92K
[Epoch 1 Batch 40/168] loss=7.1710, ppl=1301.1846, gnorm=0.3638, throughput=26.87K wps, wc=51.33K
[Epoch 1 Batch 50/168] loss=7.2138, ppl=1358.0271, gnorm=0.2787, throughput=27.52K wps, wc=47.70K
[Epoch 1 Batch 60/168] loss=7.1922, ppl=1328.9878, gnorm=0.2636, throughput=25.05K wps, wc=42.88K
[Epoch 1 Batch 70/168] loss=7.1541, ppl=1279.3798, gnorm=0.2775, throughput=25.46K wps, wc=39.54K
[Epoch 1 Batch 80/168] loss=7.1264, ppl=1244.4135, gnorm=0.2826, throughput=25.02K wps, wc=37.88K
[Epoch 1 Batch 90/168] loss=7.1282, ppl=1246.6253, gnorm=0.2798, throughput=22.97K wps, wc=34.78K
[Epoch 1 Batch 100/168] loss=7.0853, ppl=1194.2458, gnorm=0.3322, throughput=21.77K wps, wc=29.41K
[Epoch 1 Batch 110/

  0%|          | 0/62 [00:00<?, ?it/s]

[Epoch 1] valid Loss=7.6062, valid ppl=2010.7079, valid bleu=0.00


  0%|          | 0/168 [00:00<?, ?it/s]

[Epoch 2 Batch 10/168] loss=7.4394, ppl=1701.7766, gnorm=0.5685, throughput=21.74K wps, wc=63.87K
[Epoch 2 Batch 20/168] loss=7.1701, ppl=1299.9511, gnorm=0.3847, throughput=28.02K wps, wc=61.03K
[Epoch 2 Batch 30/168] loss=7.2013, ppl=1341.2192, gnorm=0.3120, throughput=26.13K wps, wc=53.92K
[Epoch 2 Batch 40/168] loss=7.1982, ppl=1337.0268, gnorm=0.3102, throughput=27.05K wps, wc=51.33K
[Epoch 2 Batch 50/168] loss=7.1619, ppl=1289.3603, gnorm=0.2758, throughput=27.31K wps, wc=47.70K
[Epoch 2 Batch 60/168] loss=7.1761, ppl=1307.7469, gnorm=0.2501, throughput=25.10K wps, wc=42.88K
[Epoch 2 Batch 70/168] loss=7.1083, ppl=1222.1133, gnorm=0.2727, throughput=25.44K wps, wc=39.54K
[Epoch 2 Batch 80/168] loss=7.0957, ppl=1206.8085, gnorm=0.2598, throughput=25.05K wps, wc=37.88K
[Epoch 2 Batch 90/168] loss=7.0851, ppl=1194.0367, gnorm=0.2776, throughput=23.85K wps, wc=34.78K
[Epoch 2 Batch 100/168] loss=7.0661, ppl=1171.5677, gnorm=0.3305, throughput=20.38K wps, wc=29.41K
[Epoch 2 Batch 110/

  0%|          | 0/62 [00:00<?, ?it/s]

[Epoch 2] valid Loss=7.5910, valid ppl=1980.3742, valid bleu=0.00


  0%|          | 0/168 [00:00<?, ?it/s]

[Epoch 3 Batch 10/168] loss=7.4335, ppl=1691.7161, gnorm=0.7444, throughput=21.26K wps, wc=63.87K
[Epoch 3 Batch 20/168] loss=7.1750, ppl=1306.4083, gnorm=0.4588, throughput=25.31K wps, wc=61.03K
[Epoch 3 Batch 30/168] loss=7.1713, ppl=1301.5784, gnorm=0.3152, throughput=26.38K wps, wc=53.92K
[Epoch 3 Batch 40/168] loss=7.1452, ppl=1267.9969, gnorm=0.3600, throughput=26.07K wps, wc=51.33K
[Epoch 3 Batch 50/168] loss=7.1685, ppl=1297.9416, gnorm=0.3108, throughput=26.10K wps, wc=47.70K
[Epoch 3 Batch 60/168] loss=7.1235, ppl=1240.7333, gnorm=0.2645, throughput=24.37K wps, wc=42.88K
[Epoch 3 Batch 70/168] loss=7.0959, ppl=1206.9974, gnorm=0.2830, throughput=25.20K wps, wc=39.54K
[Epoch 3 Batch 80/168] loss=7.0751, ppl=1182.1475, gnorm=0.2838, throughput=23.13K wps, wc=37.88K
[Epoch 3 Batch 90/168] loss=7.0670, ppl=1172.6012, gnorm=0.2836, throughput=23.64K wps, wc=34.78K
[Epoch 3 Batch 100/168] loss=7.0468, ppl=1149.2247, gnorm=0.3261, throughput=20.30K wps, wc=29.41K
[Epoch 3 Batch 110/

  0%|          | 0/62 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
from importlib import reload

reload(nmt)
reload(nmt.utils)
reload(nmt.transformer_hparams)

hparams = nmt.transformer_hparams

In [None]:
print("Qualitative Evaluation: Translating from English to German:")

sample_src_seq = "I love reading technical books."
print("[\'" + sample_src_seq + "\']")

sample_tgt_seq = nmt.utils.translate(
    transformer_translator,
    sample_src_seq,
    wmt_src_vocab,
    wmt_tgt_vocab,
    wmt_detokenizer,
    ctx)

print("The German translation is:")
print(sample_tgt_seq)

In [None]:
print('Sample target sentence: "{}"'.format(wmt_test_tgt_sentences[16]))

In [None]:
len(transformer_test_translation_out)

In [None]:
len(wmt_test_tgt_sentences)

In [None]:
# For Evaluation
scorer=nlp.model.BeamSearchScorer(
    alpha=hparams.lp_alpha,
    K=hparams.lp_k)

transformer_translator = nmt.translation.BeamSearchTranslator(
    model=transformer_model,
    beam_size=hparams.beam_size,
    scorer=scorer,
    max_length=tgt_max_len + 100)

In [None]:
# Evaluation (Baseline)
eval_start_time = time.time()
wmt_loss_function = nlp.loss.MaskedSoftmaxCELoss()
wmt_loss_function.hybridize()
wmt_detokenizer = nlp.data.SacreMosesDetokenizer()

transformer_test_loss, transformer_test_translation_out = nmt.utils.evaluate(
    transformer_model,
    wmt_test_data_loader,
    wmt_loss_function,
    transformer_translator,
    wmt_tgt_vocab,
    wmt_detokenizer,
    ctx)

transformer_test_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu(
    [wmt_test_tgt_sentences],
    transformer_test_translation_out,
    tokenized=False,
    tokenizer=hparams.bleu,
    split_compound_word=False,
    bpe=False)

print('WMT16 EN-DE Transformer model test loss: %.2f; test bleu score: %.2f; time cost %.2fs' %(transformer_test_loss, transformer_test_bleu_score * 100, (time.time() - eval_start_time)))