In [58]:
import importlib
import io
import logging
import os
import random
import time

import gluonnlp as nlp
import mxnet as mx
import numpy as np
import sacremoses
from tqdm import tqdm

# Local Libraries
import dataprocessor
import nmt
import nmt.gnmt_hparams
import utils

# Reproducibility: pin the Python, NumPy and MXNet random number generators.
random.seed(100)
np.random.seed(100)
mx.random.seed(10000)

# Device used for model parameters and computation: the first GPU.
ctx = mx.gpu(0)

In [2]:
# Dataset Parameters
# The notebook loads the WMT2016 corpora (nlp.data.WMT2016 / WMT2016BPE and the
# "newstest2016" split), so the dataset label must say WMT2016 as well.
dataset = 'WMT2016'
# Translation direction: English source, German target.
src_lang, tgt_lang = 'en', 'de'
# No limit on sentences length (-1 is the "unlimited" sentinel).
src_max_len, tgt_max_len = -1, -1

In [3]:
# WMT2016 corpora for training and evaluation.
# Every split is loaded in two flavours: the BPE variant (sub-word units,
# e.g. cheapest --> cheap@@, est) and the plain-text variant.
wmt_train_text_bpe = nlp.data.WMT2016BPE("train", src_lang=src_lang, tgt_lang=tgt_lang)
wmt_train_text = nlp.data.WMT2016("train", src_lang=src_lang, tgt_lang=tgt_lang)

wmt_test_text_bpe = nlp.data.WMT2016BPE("newstest2016", src_lang=src_lang, tgt_lang=tgt_lang)
wmt_test_text = nlp.data.WMT2016("newstest2016", src_lang=src_lang, tgt_lang=tgt_lang)

# Source and target vocabularies are taken from the BPE training split.
wmt_src_vocab, wmt_tgt_vocab = wmt_train_text_bpe.src_vocab, wmt_train_text_bpe.tgt_vocab

  'Detected a corrupted index in the deserialize vocabulary. '


In [4]:
# Wrap each corpus in a gluon SimpleDataset (the names are rebound in place),
# which provides the .transform() method used by the cells below.
wmt_train_text_bpe, wmt_train_text, wmt_test_text_bpe, wmt_test_text = (
    mx.gluon.data.SimpleDataset(corpus)
    for corpus in (wmt_train_text_bpe, wmt_train_text, wmt_test_text_bpe, wmt_test_text)
)

In [5]:
# Show one human-readable sentence pair from the test split:
# the English source first, then the German reference.
print(wmt_test_text[16][0], wmt_test_text[16][1], sep="\n")

By the end of the day, there would be one more death: Lamb took his own life as police closed in on him.
Bis zum Ende des Tages gab es einen weiteren Tod: Lamm nahm sich das Leben, als die Polizei ihn einkesselte.


In [22]:
# Keep only the target (German) side of every plain-text sentence pair.
wmt_train_tgt_sentences = wmt_train_text.transform(lambda _src, tgt: tgt)
wmt_test_tgt_sentences = wmt_test_text.transform(lambda _src, tgt: tgt)

print("Sample target sentence:", wmt_test_tgt_sentences[16], sep="\n")

Sample target sentence:
Bis zum Ende des Tages gab es einen weiteren Tod: Lamm nahm sich das Leben, als die Polizei ihn einkesselte.


In [7]:
# Dataset processing: clipping, tokenizing, indexing and adding of EOS (src/tgt) / BOS (tgt)
# The same transform (built from the source and target vocabularies) is applied to
# both BPE splits; lazy=False applies it to the whole dataset right away.
wmt_transform_fn = dataprocessor.TrainValDataTransform(wmt_src_vocab, wmt_tgt_vocab)

wmt_train_processed = wmt_train_text_bpe.transform(wmt_transform_fn, lazy=False)
wmt_test_processed  = wmt_test_text_bpe.transform(wmt_transform_fn, lazy=False)

# Extend every sample to (src, tgt, src_len, tgt_len, idx), the record layout the
# sampler and batchify cells consume.
# NOTE(review): get_length_index_fn() is called once per split on purpose; it
# presumably returns a fresh indexer on each call -- confirm before sharing one instance.
wmt_train_text_with_len = wmt_train_processed.transform(nmt.utils.get_length_index_fn(), lazy=False)
wmt_test_text_with_len  = wmt_test_processed.transform(nmt.utils.get_length_index_fn(), lazy=False)

# Sanity check: token ids of the source and target sentence of test sample 16.
print(wmt_test_text_with_len[16][0])
print(wmt_test_text_with_len[16][1])

[ 2083 28753 16760 23875 28753 15230    28 28783 31223 12931 24017 23247
 15259   569  5971 12813 29083 20097 24348 22312 12290 24829 14439 20585
 24004 20061    62     3]
[    2  1897 31601  3259 15535  9414 18646 17382 16407 30851  9629   569
  5971 22642 23439 27119 15199  6041    28 11681 15681  7670 20454 16394
 21488 26868 28535    62     3]


In [8]:
# Batcher: one batchify function per field of a (src, tgt, src_len, tgt_len, idx) sample.
#   1. source token ids  -> padded
#   2. target token ids  -> padded
#   3. source length     -> stacked as float32
#   4. target length     -> stacked as float32
#   5. sample index      -> stacked
wmt_batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(),
    nlp.data.batchify.Pad(),
    nlp.data.batchify.Stack(dtype='float32'),
    nlp.data.batchify.Stack(dtype='float32'),
    nlp.data.batchify.Stack(),
)

  'Padding value is not given and will be set automatically to 0 '


In [45]:
# Hyperparameters
# Alias for the nmt.gnmt_hparams module; later cells read settings such as
# batch_size, num_buckets, num_hidden, beam_size and lr from it.
hparams = nmt.gnmt_hparams

In [9]:
# Samplers
# Samples are bucketed by their (source length, target length) pair; the lengths
# are read from the (src, tgt, src_len, tgt_len, idx) records.
wmt_train_lengths = wmt_train_text_with_len.transform(
    lambda _src, _tgt, src_len, tgt_len, _idx: (src_len, tgt_len))
wmt_train_batch_sampler = nlp.data.FixedBucketSampler(
    lengths=wmt_train_lengths,
    batch_size=hparams.batch_size,
    num_buckets=hparams.num_buckets)
print(wmt_train_batch_sampler.stats())

# The test sampler uses the same bucketing with its own batch size.
wmt_test_lengths = wmt_test_text_with_len.transform(
    lambda _src, _tgt, src_len, tgt_len, _idx: (src_len, tgt_len))
wmt_test_batch_sampler = nlp.data.FixedBucketSampler(
    lengths=wmt_test_lengths,
    batch_size=hparams.test_batch_size,
    num_buckets=hparams.num_buckets)
print(wmt_test_batch_sampler.stats())

FixedBucketSampler:
  sample_num=4500966, batch_num=2250484
  key=[(163, 115), (324, 227), (485, 339)]
  cnt=[4497590, 3303, 73]
  batch_size=[2, 2, 2]
FixedBucketSampler:
  sample_num=2999, batch_num=1501
  key=[(37, 42), (70, 78), (103, 114)]
  cnt=[2353, 619, 27]
  batch_size=[2, 2, 2]


In [10]:
# DataLoaders
# Each loader draws batches of sample indices from its bucket sampler and
# collates them with the shared batchify function, using 8 worker processes.
wmt_train_data_loader = mx.gluon.data.DataLoader(
    wmt_train_text_with_len,
    num_workers=8,
    batchify_fn=wmt_batchify_fn,
    batch_sampler=wmt_train_batch_sampler)
print('Number of training batches:', len(wmt_train_data_loader))

wmt_test_data_loader = mx.gluon.data.DataLoader(
    wmt_test_text_with_len,
    num_workers=8,
    batchify_fn=wmt_batchify_fn,
    batch_sampler=wmt_test_batch_sampler)
print('Number of testing batches:', len(wmt_test_data_loader))

Number of training batches: 2250484
Number of testing batches: 1501


In [12]:
# Model
# Build the GNMT encoder, decoder and one-step-ahead decoder from the hyperparameters.
encoder, decoder, one_step_ahead_decoder = nmt.gnmt.get_gnmt_encoder_decoder(
    num_layers=hparams.num_layers,
    num_bi_layers=hparams.num_bi_layers,
    hidden_size=hparams.num_hidden,
    dropout=hparams.dropout)

# Assemble the translation model; the embedding size equals the hidden size.
gnmt_model = nlp.model.translation.NMTModel(
    encoder=encoder,
    decoder=decoder,
    one_step_ahead_decoder=one_step_ahead_decoder,
    src_vocab=wmt_src_vocab,
    tgt_vocab=wmt_tgt_vocab,
    embed_size=hparams.num_hidden,
    prefix='gnmt_')

# Uniform(0.1) parameter initialisation on the selected device, then hybridize
# with static allocation enabled.
static_alloc = True
gnmt_model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
gnmt_model.hybridize(static_alloc=static_alloc)

In [13]:
# Beam-search translator: hypotheses are ranked by a scorer configured from
# hparams.lp_alpha and hparams.lp_k.
scorer = nlp.model.BeamSearchScorer(alpha=hparams.lp_alpha, K=hparams.lp_k)

# NOTE(review): tgt_max_len is the -1 "no limit" sentinel from the dataset
# parameters cell, so max_length evaluates to 99 here -- confirm this cap is intended.
gnmt_translator = nmt.translation.BeamSearchTranslator(
    model=gnmt_model,
    scorer=scorer,
    beam_size=hparams.beam_size,
    max_length=tgt_max_len + 100)

print("Use beam_size={}, alpha={}, K={}".format(
    hparams.beam_size, hparams.lp_alpha, hparams.lp_k))

Use beam_size=10, alpha=1.0, K=5


In [23]:
# Evaluation: compute the test loss and BLEU score of the current model on newstest2016.
eval_start_time = time.time()
wmt_loss_function = nlp.loss.MaskedSoftmaxCELoss()
wmt_loss_function.hybridize()
wmt_detokenizer = nlp.data.SacreMosesDetokenizer()

# Use the loss function created above. The notebook never defines
# `wmt_test_loss_function`, so referring to it fails on a fresh kernel.
gnmt_test_loss, gnmt_test_translation_out = nmt.utils.evaluate(
    gnmt_model,
    wmt_test_data_loader,
    wmt_loss_function,
    gnmt_translator,
    wmt_tgt_vocab,
    wmt_detokenizer,
    ctx)

# BLEU is computed against the plain-text (non-BPE) reference sentences;
# only the first returned value (the score) is used.
gnmt_test_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu(
    [wmt_test_tgt_sentences],
    gnmt_test_translation_out,
    tokenized=False,
    tokenizer=hparams.bleu,
    split_compound_word=False,
    bpe=False)

print('WMT16 EN-DE SOTA model test loss: %.2f; test bleu score: %.2f; time cost %.2fs' %(gnmt_test_loss, gnmt_test_bleu_score * 100, (time.time() - eval_start_time)))

100%|██████████| 1501/1501 [11:40<00:00,  2.14it/s]


WMT16 EN-DE SOTA model test loss: 9.06; test bleu score: 0.00; time cost 706.79s


In [50]:
# Display the beam size currently configured in the hyperparameter module.
nmt.gnmt_hparams.beam_size

10

In [71]:
# Training
# Adam optimizer over all model parameters, with the learning rate taken from hparams.
trainer = mx.gluon.Trainer(
    gnmt_model.collect_params(),
    optimizer='adam',
    optimizer_params={'learning_rate': hparams.lr})

In [87]:
# Start training: hands the model, both data loaders, the loss, the trainer,
# the beam-search translator, the target vocabulary, the detokenizer, the
# save directory, the hyperparameters and the device to nmt.utils.train.
# NOTE(review): the recorded output below comes from a manually interrupted
# run (KeyboardInterrupt) and also shows a NameError for 'src_seq', which is
# presumably raised inside the nmt package -- verify before relying on it.
nmt.utils.train(
    gnmt_model,
    wmt_train_data_loader,
    wmt_test_data_loader,
    wmt_loss_function,
    trainer,
    gnmt_translator,
    wmt_tgt_vocab,
    wmt_detokenizer,
    hparams.save_dir,
    hparams,
    ctx)

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/2250484 [00:00<?, ?it/s][A
  0%|          | 1/2250484 [00:00<575:43:48,  1.09it/s][A
  0%|          | 2/2250484 [00:01<557:02:56,  1.12it/s][A
  0%|          | 3/2250484 [00:02<544:16:26,  1.15it/s][A
  0%|          | 4/2250484 [00:03<616:59:25,  1.01it/s][A
  0%|          | 5/2250484 [00:04<589:11:21,  1.06it/s][A
  0%|          | 6/2250484 [00:05<576:37:46,  1.08it/s][A
  0%|          | 7/2250484 [00:06<581:55:56,  1.07it/s][A
  0%|          | 8/2250484 [00:07<558:54:41,  1.12it/s][A
  0%|          | 9/2250484 [00:08<569:11:28,  1.10it/s][A
  0%|          | 10/2250484 [00:09<728:36:31,  1.17s/it][A

[Epoch 0 Batch 10/2250484] loss=6.4054, ppl=605.0831, gnorm=3.6659, throughput=0.79K wps, wc=7.86K



  0%|          | 11/2250484 [00:11<706:14:03,  1.13s/it][A
  0%|          | 12/2250484 [00:11<641:25:10,  1.03s/it][A
  0%|          | 13/2250484 [00:12<627:00:11,  1.00s/it][A
  0%|          | 14/2250484 [00:13<666:49:31,  1.07s/it][A
  0%|          | 15/2250484 [00:14<571:59:33,  1.09it/s][A
  0%|          | 16/2250484 [00:15<565:26:34,  1.11it/s][A
  0%|          | 17/2250484 [00:16<516:20:33,  1.21it/s][A
  0%|          | 18/2250484 [00:16<502:13:12,  1.24it/s][A
  0%|          | 19/2250484 [00:18<656:43:13,  1.05s/it][A
  0%|          | 20/2250484 [00:19<647:04:45,  1.04s/it][A

[Epoch 0 Batch 20/2250484] loss=9.5011, ppl=13375.0654, gnorm=2.5180, throughput=0.81K wps, wc=7.71K



  0%|          | 21/2250484 [00:20<634:29:45,  1.01s/it][A
  0%|          | 22/2250484 [00:21<576:49:35,  1.08it/s][A
  0%|          | 23/2250484 [00:21<548:19:09,  1.14it/s][A
  0%|          | 24/2250484 [00:22<566:49:05,  1.10it/s][A
  0%|          | 25/2250484 [00:23<504:04:58,  1.24it/s][A
  0%|          | 26/2250484 [00:24<490:05:28,  1.28it/s][A
  0%|          | 27/2250484 [00:24<459:24:13,  1.36it/s][A
  0%|          | 28/2250484 [00:25<442:25:32,  1.41it/s][A
  0%|          | 29/2250484 [00:26<436:56:30,  1.43it/s][A
  0%|          | 30/2250484 [00:26<465:20:26,  1.34it/s][A

[Epoch 0 Batch 30/2250484] loss=7.1100, ppl=1224.1965, gnorm=1.7908, throughput=0.93K wps, wc=7.03K



  0%|          | 31/2250484 [00:28<546:15:55,  1.14it/s][A
  0%|          | 32/2250484 [00:28<507:22:06,  1.23it/s][A
  0%|          | 33/2250484 [00:30<596:23:22,  1.05it/s][A
  0%|          | 34/2250484 [00:31<587:52:41,  1.06it/s][A
  0%|          | 35/2250484 [00:31<538:50:01,  1.16it/s][A
  0%|          | 36/2250484 [00:34<898:49:12,  1.44s/it][A
  0%|          | 37/2250484 [00:35<842:09:10,  1.35s/it][A
  0%|          | 38/2250484 [00:36<697:56:01,  1.12s/it][A
  0%|          | 39/2250484 [00:36<587:43:24,  1.06it/s][A
  0%|          | 40/2250484 [00:37<495:15:36,  1.26it/s][A

[Epoch 0 Batch 40/2250484] loss=8.4313, ppl=4588.3653, gnorm=1.8762, throughput=0.67K wps, wc=6.87K



  0%|          | 41/2250484 [00:37<458:12:51,  1.36it/s][A
  0%|          | 42/2250484 [00:38<407:40:15,  1.53it/s][A
  0%|          | 43/2250484 [00:38<369:38:03,  1.69it/s][A
  0%|          | 44/2250484 [00:39<344:43:15,  1.81it/s][A
  0%|          | 45/2250484 [00:39<350:44:33,  1.78it/s][A
  0%|          | 46/2250484 [00:40<354:40:23,  1.76it/s][A
  0%|          | 47/2250484 [00:40<334:42:22,  1.87it/s][A
  0%|          | 48/2250484 [00:41<320:51:15,  1.95it/s][A
  0%|          | 49/2250484 [00:41<299:39:40,  2.09it/s][A
  0%|          | 50/2250484 [00:41<277:30:04,  2.25it/s][A

[Epoch 0 Batch 50/2250484] loss=9.6225, ppl=15101.1743, gnorm=1.5339, throughput=0.93K wps, wc=4.47K



  0%|          | 51/2250484 [00:42<330:36:08,  1.89it/s][A
  0%|          | 52/2250484 [00:43<289:55:08,  2.16it/s][A
  0%|          | 53/2250484 [00:43<290:13:50,  2.15it/s][A
  0%|          | 54/2250484 [00:44<300:15:27,  2.08it/s][A
  0%|          | 55/2250484 [00:44<302:11:35,  2.07it/s][A
  0%|          | 56/2250484 [00:45<325:15:21,  1.92it/s][A
  0%|          | 57/2250484 [00:45<315:23:44,  1.98it/s][A
  0%|          | 58/2250484 [00:45<287:41:21,  2.17it/s][A
  0%|          | 59/2250484 [00:46<285:08:20,  2.19it/s][A
  0%|          | 60/2250484 [00:47<323:48:33,  1.93it/s][A

[Epoch 0 Batch 60/2250484] loss=9.0639, ppl=8637.4898, gnorm=1.3245, throughput=0.82K wps, wc=4.15K



  0%|          | 61/2250484 [00:47<321:18:44,  1.95it/s][A
  0%|          | 62/2250484 [00:48<325:26:37,  1.92it/s][A
  0%|          | 63/2250484 [00:48<336:50:05,  1.86it/s][A
  0%|          | 64/2250484 [00:49<343:41:23,  1.82it/s][A
  0%|          | 65/2250484 [00:49<323:31:35,  1.93it/s][A
  0%|          | 66/2250484 [00:50<291:57:44,  2.14it/s][A
  0%|          | 67/2250484 [00:50<295:26:41,  2.12it/s][A
  0%|          | 68/2250484 [00:50<276:40:01,  2.26it/s][A
  0%|          | 69/2250484 [00:51<272:31:05,  2.29it/s][A
  0%|          | 70/2250484 [00:51<283:12:32,  2.21it/s][A

[Epoch 0 Batch 70/2250484] loss=8.9813, ppl=7953.0270, gnorm=1.1442, throughput=0.88K wps, wc=4.19K



  0%|          | 71/2250484 [00:52<270:00:17,  2.32it/s][A
  0%|          | 72/2250484 [00:52<262:42:49,  2.38it/s][A
  0%|          | 73/2250484 [00:52<252:04:32,  2.48it/s][A
  0%|          | 74/2250484 [00:53<259:30:53,  2.41it/s][A
  0%|          | 75/2250484 [00:53<252:32:30,  2.48it/s][A
  0%|          | 76/2250484 [00:54<276:04:52,  2.26it/s][A
  0%|          | 77/2250484 [00:54<271:37:10,  2.30it/s][A
  0%|          | 78/2250484 [00:55<381:42:05,  1.64it/s][A
  0%|          | 79/2250484 [00:56<369:05:15,  1.69it/s][A
  0%|          | 80/2250484 [00:56<330:56:13,  1.89it/s][A

[Epoch 0 Batch 80/2250484] loss=8.6340, ppl=5619.4209, gnorm=1.0450, throughput=0.86K wps, wc=4.16K



  0%|          | 81/2250484 [00:57<321:34:04,  1.94it/s][A
  0%|          | 82/2250484 [00:57<374:42:31,  1.67it/s][A
  0%|          | 83/2250484 [00:58<347:48:22,  1.80it/s][A
  0%|          | 84/2250484 [00:58<357:28:23,  1.75it/s][A
  0%|          | 85/2250484 [00:59<373:27:36,  1.67it/s][A
  0%|          | 86/2250484 [01:00<380:25:26,  1.64it/s][A
  0%|          | 87/2250484 [01:00<337:34:09,  1.85it/s][A
  0%|          | 88/2250484 [01:01<402:01:40,  1.55it/s][A
  0%|          | 89/2250484 [01:02<395:25:04,  1.58it/s][A
  0%|          | 90/2250484 [01:02<355:34:35,  1.76it/s][A

[Epoch 0 Batch 90/2250484] loss=8.1087, ppl=3323.3713, gnorm=1.0961, throughput=0.87K wps, wc=5.14K



  0%|          | 91/2250484 [01:03<339:24:55,  1.84it/s][A
  0%|          | 92/2250484 [01:03<324:01:00,  1.93it/s][A
  0%|          | 93/2250484 [01:03<302:38:00,  2.07it/s][A
  0%|          | 94/2250484 [01:04<293:21:08,  2.13it/s][A
  0%|          | 95/2250484 [01:04<277:48:16,  2.25it/s][A
  0%|          | 96/2250484 [01:05<260:51:42,  2.40it/s][A
  0%|          | 97/2250484 [01:05<298:33:22,  2.09it/s][A
  0%|          | 98/2250484 [01:06<308:34:28,  2.03it/s][A
  0%|          | 99/2250484 [01:06<283:26:16,  2.21it/s][A
  0%|          | 100/2250484 [01:07<293:39:24,  2.13it/s][A

[Epoch 0 Batch 100/2250484] loss=8.8250, ppl=6802.4975, gnorm=0.9962, throughput=0.87K wps, wc=3.96K



  0%|          | 101/2250484 [01:07<267:01:16,  2.34it/s][A
  0%|          | 102/2250484 [01:07<251:04:05,  2.49it/s][A
  0%|          | 103/2250484 [01:08<292:04:34,  2.14it/s][A
  0%|          | 104/2250484 [01:08<302:29:48,  2.07it/s][A
  0%|          | 105/2250484 [01:09<301:54:10,  2.07it/s][A
  0%|          | 106/2250484 [01:09<296:44:06,  2.11it/s][A
  0%|          | 107/2250484 [01:10<295:13:55,  2.12it/s][A
  0%|          | 108/2250484 [01:10<299:54:27,  2.08it/s][A
  0%|          | 109/2250484 [01:11<308:36:53,  2.03it/s][A
  0%|          | 110/2250484 [01:12<336:01:04,  1.86it/s][A

[Epoch 0 Batch 110/2250484] loss=8.7276, ppl=6171.1156, gnorm=0.8806, throughput=0.85K wps, wc=4.16K



  0%|          | 111/2250484 [01:12<315:05:37,  1.98it/s][A
  0%|          | 112/2250484 [01:12<298:45:43,  2.09it/s][A
  0%|          | 113/2250484 [01:14<480:20:34,  1.30it/s][A
  0%|          | 114/2250484 [01:14<404:59:42,  1.54it/s][A
  0%|          | 115/2250484 [01:14<340:11:38,  1.84it/s][A
  0%|          | 116/2250484 [01:15<324:57:08,  1.92it/s][A
  0%|          | 117/2250484 [01:15<295:25:35,  2.12it/s][A
  0%|          | 118/2250484 [01:16<272:06:23,  2.30it/s][A
  0%|          | 119/2250484 [01:16<268:02:35,  2.33it/s][A
  0%|          | 120/2250484 [01:17<273:27:04,  2.29it/s][A

[Epoch 0 Batch 120/2250484] loss=8.5823, ppl=5336.3750, gnorm=1.0365, throughput=0.73K wps, wc=3.64K



  0%|          | 121/2250484 [01:17<402:44:33,  1.55it/s][A
  0%|          | 0/1 [01:17<?, ?it/s]


KeyboardInterrupt: 

  0%|          | 0/2250484 [00:00<?, ?it/s]

5
(
[[  617 10790   638  9378  6311  9368  1430 10790 10794  7187   617 10790
   7168  9384  3038  7337  5929 10791  3038  7182   616   589  2152  9809
   9368  9393   589  2998   587  5934  9390 10794  9390  2557   615   599
   2986 10408  7188  5927  2553  7173 10409  2117  7939  3595  7187  9809
   2570  6303  2986  9385  5935  4098  5023  4993   640  7168  7350   641
   7173 10408  7188  5927  2553  7173 10409  2117  7939  3595  5014   574
   9987  7177   641  7173  2162  4993  4998   606   618  2141  9987  5924
   5929 10419   618    28  8356  4630  9804  2553  8356  9801  4614   616
    638  9378  6311  9368  1438  6285   595  2986  7187   630  8356  5020
    617  2134  3585  8356  9801  4614   616   638  9374  5432  1438  7373
   9981  4975  3009    28  1445  4960 10815 10815   624  2553  3031  9383
   9373  4998  6285  3033  7945  3039  9981  3041  9385  7937  4981  4637
   9390  8367  2998  5432  2557   615   599  3041  3574  7951   617 10790
   8356  9801  4614  9825  3038  9




NameError: name 'src_seq' is not defined

In [85]:
# Pick up edits made to the local nmt package without restarting the kernel.
# Python 3 has no builtin `reload`; use importlib.reload instead.
importlib.reload(nmt)
importlib.reload(nmt.utils)
importlib.reload(nmt.gnmt_hparams)

# Re-bind the alias to the reloaded hyperparameter module.
hparams = nmt.gnmt_hparams

In [86]:
# Re-create the detokenizer and the (hybridized) masked softmax cross-entropy
# loss, the same two helpers the evaluation cell sets up.
wmt_detokenizer = nlp.data.SacreMosesDetokenizer()

wmt_loss_function = nlp.loss.MaskedSoftmaxCELoss()
wmt_loss_function.hybridize()