In [1]:
import mxnet as mx
import gluonnlp as nlp

import matplotlib.pyplot as plt
import numpy as np
import os
import random
import sacremoses
import time
from tqdm.notebook import tqdm
import io
from importlib import reload

# Local Libraries
import nmt
import dataprocessor
import utils
import nmt.transformer_hparams
import transformer_model

# Hyperparameters for Dataloaders and Training
hparams = nmt.transformer_hparams

# Seeds for reproducibility
np.random.seed(100)
random.seed(100)
mx.random.seed(100)

# CPU setup
# ctx = mx.cpu()
# Single GPU setup
ctx = mx.gpu(0)

In [2]:
# WMT2016 Dataset
# Train/Validation: Pre-processed
# Test: Raw (for evaluation)

# Dataset Parameters
# src_lang, tgt_lang = "de", "en"
src_lang, tgt_lang = "en", "de"
src_max_len, tgt_max_len = 50, 50

# wmt_train_text = nlp.data.WMT2016(
#     "train",
#     src_lang=src_lang,
#     tgt_lang=tgt_lang)

# wmt_test_text  = nlp.data.WMT2016(
#     "newstest2016",
#     src_lang=src_lang,
#     tgt_lang=tgt_lang)

wmt_train_text = nlp.data.WMT2016BPE(
    "train",
    src_lang=src_lang,
    tgt_lang=tgt_lang)

wmt_test_text  = nlp.data.WMT2016BPE(
    "newstest2016",
    src_lang=src_lang,
    tgt_lang=tgt_lang)

In [3]:
print("Length of train set:", len(wmt_train_text))
# print("Length of val set  :", len(wmt_val_text)) XXXX NOT CREATED YET
print("Length of test set :", len(wmt_test_text))

Length of train set: 4500966
Length of test set : 2999


In [4]:
# wmt_val_text = nlp.data.WMT2016(
#     "train",
#     src_lang=src_lang,
#     tgt_lang=tgt_lang)

wmt_val_text = nlp.data.WMT2016BPE(
    "train",
    src_lang=src_lang,
    tgt_lang=tgt_lang)

In [5]:
# Validation dataset generation (from training dataset)
val_length = 3000

wmt_val_text._data[0] = wmt_train_text._data[0][-val_length:]
wmt_val_text._data[1] = wmt_train_text._data[1][-val_length:]
wmt_val_text._length = val_length

# Modify Training dataset to remove validation dataset
# Mini Training set
train_length = int(100e3)
wmt_train_text._data[0] = wmt_train_text._data[0][:train_length]
wmt_train_text._data[1] = wmt_train_text._data[1][:train_length]
wmt_train_text._length = train_length

In [6]:
print("Length of train set:", len(wmt_train_text))
print("Length of val set  :", len(wmt_val_text))
print("Length of test set :", len(wmt_test_text))

Length of train set: 100000
Length of val set  : 3000
Length of test set : 2999


In [8]:
wmt2016_src_vocab, wmt2016_tgt_vocab = wmt_train_text.src_vocab, wmt_train_text.tgt_vocab
wmt2016_src_vocab, wmt2016_tgt_vocab

(Vocab(size=36548, unk="<unk>", reserved="['<eos>', '<bos>']"),
 Vocab(size=36548, unk="<unk>", reserved="['<eos>', '<bos>']"))

In [9]:
# Target Sequences (Val, Test)
fetch_tgt_sentence = lambda src, tgt: tgt.split()
val_tgt_sentences = list(wmt_val_text.transform(fetch_tgt_sentence))
test_tgt_sentences = list(wmt_test_text.transform(fetch_tgt_sentence))

In [11]:
# Dataset processing: clipping, tokenizing, indexing and adding of EOS (src/tgt) / BOS (tgt)
wmt_train_processed = wmt_train_text.transform(
    dataprocessor.TrainValDataTransform(
        wmt2016_src_vocab,
        wmt2016_tgt_vocab,
        src_max_len,
        tgt_max_len),
    lazy=False)

wmt_val_processed   = wmt_val_text.transform(
    dataprocessor.TrainValDataTransform(
        wmt2016_src_vocab,
        wmt2016_tgt_vocab,
        src_max_len,
        tgt_max_len),
    lazy=False)

wmt_test_processed  = wmt_test_text.transform(
    dataprocessor.TrainValDataTransform(
        wmt2016_src_vocab,
        wmt2016_tgt_vocab,
        src_max_len,
        tgt_max_len),
    lazy=False)

In [13]:
# # Create Gluon Datasets
# # Not needed for training, as training data will be sharded later
# wmt_train_transformed = wmt_train_processed.transform(
#     lambda src, tgt: (src, tgt, len(src), len(tgt)),
#     lazy=False)

# wmt_val_dataset = mx.gluon.data.SimpleDataset(
#     [(ele[0], ele[1], len(ele[0]), len(ele[1]),i) for i, ele in enumerate(wmt_val_processed)])

# wmt_test_dataset = mx.gluon.data.SimpleDataset(
#     [(ele[0], ele[1], len(ele[0]), len(ele[1]), i) for i, ele in enumerate(wmt_test_processed)])

def get_length_index_fn():
    global idx
    idx = 0
    def transform(src, tgt):
        global idx
        result = (src, tgt, len(src), len(tgt), idx)
        idx += 1
        return result
    return transform

wmt_data_test_with_len = wmt_test_processed.transform(get_length_index_fn(), lazy=False)

In [14]:
# Updates for this recipe
hparams.num_hidden = 512
hparams.num_layers = 4
hparams.dropout = 0.2
hparams.num_buckets = 5
hparams.lr = 0.001
#hparams.lr = 0.0001
#hparams.lr = 0.0003 achieves 21.44 test_bleu: qualitative evaluation didn't work
#hparams.lr = 0.0001 achieves 19.66 test_bleu: qualitative evaluation worked
hparams.clip = 5
hparams.epochs = 12
hparams.beam_size = 10
hparams.lp_alpha = 1.0
hparams.lp_k = 5

hparams.max_length = 150
hparams.batch_size = 256

In [15]:
# Create Gluon Samplers and DataLoaders

# Helper function for lengths
def get_data_lengths(dataset):
    get_lengths = lambda *args: (args[2], args[3])
    return list(dataset.transform(get_lengths))

# Bucket scheme
bucket_scheme = nlp.data.ExpWidthBucket(bucket_len_step=1.2)

wmt_train_lengths = get_data_lengths(wmt_train_transformed)
wmt_val_lengths = get_data_lengths(wmt_val_dataset)
wmt_test_lengths = get_data_lengths(wmt_test_dataset)

train_batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(),
    nlp.data.batchify.Pad(),
    nlp.data.batchify.Stack(dtype='float32'),
    nlp.data.batchify.Stack(dtype='float32'))

test_batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(),
    nlp.data.batchify.Pad(),
    nlp.data.batchify.Stack(dtype='float32'),
    nlp.data.batchify.Stack(dtype='float32'),
    nlp.data.batchify.Stack())

target_val_lengths = list(map(lambda x: x[-1], wmt_val_lengths))
target_test_lengths = list(map(lambda x: x[-1], wmt_test_lengths))

train_batch_sampler = nlp.data.FixedBucketSampler(
    lengths=wmt_train_lengths,
    batch_size=hparams.batch_size,
    num_buckets=hparams.num_buckets,
    ratio=0,
    shuffle=True,
    use_average_length=False,
    num_shards=0,
    bucket_scheme=bucket_scheme)
    
train_data_loader = nlp.data.ShardedDataLoader(
    wmt_train_transformed,
    batch_sampler=train_batch_sampler,
    batchify_fn=train_batchify_fn,
    num_workers=8)

val_batch_sampler = nlp.data.FixedBucketSampler(
    lengths=wmt_val_lengths,
    batch_size=hparams.test_batch_size,
    num_buckets=hparams.num_buckets,
    ratio=0,
    shuffle=False,
    use_average_length=False,
    bucket_scheme=bucket_scheme)

val_data_loader = mx.gluon.data.DataLoader(
    wmt_val_dataset,
    batch_sampler=val_batch_sampler,
    batchify_fn=test_batchify_fn,
    num_workers=8)

test_batch_sampler = nlp.data.FixedBucketSampler(
    lengths=wmt_test_lengths,
    batch_size=hparams.test_batch_size,
    num_buckets=hparams.num_buckets,
    ratio=0,
    shuffle=False,
    use_average_length=False,
    bucket_scheme=bucket_scheme)

test_data_loader = mx.gluon.data.DataLoader(
    wmt_test_dataset,
    batch_sampler=test_batch_sampler,
    batchify_fn=test_batchify_fn,
    num_workers=8)

NameError: name 'wmt_train_transformed' is not defined

## Training from scratch

In [14]:
# Transformer Model
transformer_encoder, transformer_decoder, transformer_one_step_ahead_decoder = nlp.model.transformer.get_transformer_encoder_decoder(
    hidden_size=hparams.num_hidden,
    dropout=hparams.dropout,
    num_layers=hparams.num_layers)

transformer_model_ts = nlp.model.translation.NMTModel(
    src_vocab=wmt_src_vocab,
    tgt_vocab=wmt_tgt_vocab,
    encoder=transformer_encoder,
    decoder=transformer_decoder,
    one_step_ahead_decoder=transformer_one_step_ahead_decoder,
    #embed_size=hparams.num_hidden,
    embed_size=hparams.num_units,
    prefix='transformer_')

transformer_model_ts.initialize(init=mx.init.Xavier(magnitude=1.0), ctx=ctx)
transformer_model_ts.hybridize(static_alloc=True)



In [15]:
# Translator (using model defined above)
transformer_ts_translator = nmt.translation.BeamSearchTranslator(
    model=transformer_model_ts,
    beam_size=hparams.beam_size,
    scorer=nlp.model.BeamSearchScorer(
        alpha=hparams.lp_alpha,
        K=hparams.lp_k),
    # max_length=150)
    max_length=50)

In [16]:
# Loss function
loss_function = nlp.loss.MaskedSoftmaxCELoss()
loss_function.hybridize(static_alloc=True)

In [17]:
# Let's train
trainer = mx.gluon.Trainer(transformer_model_ts.collect_params(), hparams.optimizer, {'learning_rate': hparams.lr})

best_valid_bleu = 0.0

train_losses = []
valid_losses = []
valid_bleus  = []
valid_perplexities = []

for epoch_id in tqdm(range(hparams.epochs)):

    log_loss = 0
    log_denom = 0
    log_avg_gnorm = 0
    log_wc = 0
    log_start_time = time.time()
    
    for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length) in enumerate(tqdm(train_data_loader)):
        
        src_seq = src_seq.as_in_context(ctx)
        tgt_seq = tgt_seq.as_in_context(ctx)
        src_valid_length = src_valid_length.as_in_context(ctx)
        tgt_valid_length = tgt_valid_length.as_in_context(ctx)
        
        with mx.autograd.record():
            out, _ = transformer_model(
                src_seq,
                tgt_seq[:, :-1],
                src_valid_length,
                tgt_valid_length - 1)

            loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean()
            loss = loss * (tgt_seq.shape[1] - 1)
            log_loss += loss * tgt_seq.shape[0]
            log_denom += (tgt_valid_length - 1).sum()
            loss = loss / (tgt_valid_length - 1).mean()
            loss.backward()
        
        grads = [p.grad(ctx) for p in transformer_model.collect_params().values() if p.grad_req != 'null']
        gnorm = mx.gluon.utils.clip_global_norm(grads, hparams.clip)
        trainer.step(1)
        
        src_wc = src_valid_length.sum().asscalar()
        tgt_wc = (tgt_valid_length - 1).sum().asscalar()
        log_loss = log_loss.asscalar()
        log_denom = log_denom.asscalar()
        log_avg_gnorm += gnorm
        log_wc += src_wc + tgt_wc
        
        train_loss = log_loss / log_denom
        
        if (batch_id + 1) % hparams.log_interval == 0:
            wps = log_wc / (time.time() - log_start_time)
            print("[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, "
                         "throughput={:.2f}K wps, wc={:.2f}K"
                         .format(epoch_id, batch_id + 1, len(train_data_loader),
                                 train_loss,
                                 np.exp(log_loss / log_denom),
                                 log_avg_gnorm / hparams.log_interval,
                                 wps / 1000, log_wc / 1000))
            
            log_start_time = time.time()
            log_loss = 0
            log_denom = 0
            log_avg_gnorm = 0
            log_wc = 0
            
    train_losses.append(train_loss)
    
    valid_loss, valid_translation_out = nmt.utils.evaluate(
        val_data_loader,
        transformer_model,
        transformer_translator,
        loss_function,
        wmt_tgt_vocab,
        ctx)

    valid_perplexity = np.exp(valid_loss)
    valid_perplexities.append(valid_perplexity)
    
    valid_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu([val_tgt_sentences], valid_translation_out)
    print("[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}"
          .format(epoch_id, valid_loss, valid_perplexity, valid_bleu_score * 100))
    
    valid_losses.append(valid_loss)
    valid_bleus.append(valid_bleu_score * 100)

    if valid_bleu_score > best_valid_bleu:
        best_valid_bleu = valid_bleu_score
        print("Save best parameters to {}".format(hparams.file_name))
        transformer_model.save_parameters(hparams.file_name)
    
    if epoch_id + 1 >= (hparams.epochs * 2) // 3:
        new_lr = trainer.learning_rate * hparams.lr_update_factor
        print("Learning rate change to {}".format(new_lr))
        trainer.set_learning_rate(new_lr)
        
    print("Qualitative Evaluation: Translating from Vietnamese to English")

    expected_tgt_seq = "I like to read books."
    print("Expected translation:")
    print(expected_tgt_seq)
    # From Google Translate
    src_seq = "Tôi thích đọc sách kỹ thuật."
    print("In Vietnamese (from Google Translate):")
    print(src_seq)

    translation_out = nmt.utils.translate_with_unk(
    # translation_out = nmt.utils.translate(
        transformer_translator,
        src_seq,
        wmt_src_vocab,
        wmt_tgt_vocab,
        ctx)

    print("The English translation is:")
    print(" ".join(translation_out[0]))

if os.path.exists(hparams.file_name):
    transformer_model.load_parameters(hparams.file_name)

valid_loss, valid_translation_out = nmt.utils.evaluate(
    val_data_loader,
    transformer_model,
    transformer_translator,
    loss_function,
    wmt_tgt_vocab,
    ctx)

valid_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu([val_tgt_sentences], valid_translation_out)
print("Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}"
      .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100))

test_loss, test_translation_out = nmt.utils.evaluate(
    test_data_loader,
    transformer_model,
    transformer_translator,
    loss_function,
    wmt_tgt_vocab,
    ctx)

test_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu([test_tgt_sentences], test_translation_out)
print("Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'"
      .format(test_loss, np.exp(test_loss), test_bleu_score * 100))

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/785 [00:00<?, ?it/s]

[Epoch 0 Batch 100/785] loss=7.2770, ppl=1446.6195, gnorm=1.0628, throughput=39.78K wps, wc=640.02K
[Epoch 0 Batch 200/785] loss=6.9850, ppl=1080.2670, gnorm=0.8397, throughput=44.56K wps, wc=682.64K
[Epoch 0 Batch 300/785] loss=6.9788, ppl=1073.5958, gnorm=0.7445, throughput=42.69K wps, wc=601.62K
[Epoch 0 Batch 400/785] loss=6.9852, ppl=1080.5719, gnorm=0.7487, throughput=43.33K wps, wc=622.22K
[Epoch 0 Batch 500/785] loss=6.9809, ppl=1075.8973, gnorm=0.6450, throughput=43.51K wps, wc=615.55K
[Epoch 0 Batch 600/785] loss=6.9887, ppl=1084.2783, gnorm=0.6042, throughput=41.96K wps, wc=572.96K
[Epoch 0 Batch 700/785] loss=6.9689, ppl=1063.0504, gnorm=0.4530, throughput=43.56K wps, wc=644.42K


  0%|          | 0/50 [00:00<?, ?it/s]

[Epoch 0] valid Loss=8.8308, valid ppl=6842.0041, valid bleu=0.00
Qualitative Evaluation: Translating from Vietnamese to English
Expected translation:
I like to read books.
In Vietnamese (from Google Translate):
Tôi thích đọc sách kỹ thuật.
The English translation is:



  0%|          | 0/785 [00:00<?, ?it/s]

[Epoch 1 Batch 100/785] loss=6.9312, ppl=1023.6773, gnorm=0.4437, throughput=40.27K wps, wc=627.10K
[Epoch 1 Batch 200/785] loss=6.9542, ppl=1047.5575, gnorm=0.5122, throughput=40.71K wps, wc=603.48K
[Epoch 1 Batch 300/785] loss=6.9507, ppl=1043.8606, gnorm=0.4681, throughput=42.08K wps, wc=646.65K
[Epoch 1 Batch 400/785] loss=6.9542, ppl=1047.4990, gnorm=0.4005, throughput=41.86K wps, wc=651.48K
[Epoch 1 Batch 500/785] loss=6.9468, ppl=1039.8311, gnorm=0.4278, throughput=41.91K wps, wc=653.35K
[Epoch 1 Batch 600/785] loss=6.9363, ppl=1028.9366, gnorm=0.3741, throughput=41.50K wps, wc=636.45K
[Epoch 1 Batch 700/785] loss=6.9520, ppl=1045.1875, gnorm=0.3871, throughput=40.51K wps, wc=596.07K


  0%|          | 0/50 [00:00<?, ?it/s]

[Epoch 1] valid Loss=8.8647, valid ppl=7077.7704, valid bleu=0.00
Qualitative Evaluation: Translating from Vietnamese to English
Expected translation:
I like to read books.
In Vietnamese (from Google Translate):
Tôi thích đọc sách kỹ thuật.
The English translation is:



  0%|          | 0/785 [00:00<?, ?it/s]

[Epoch 2 Batch 100/785] loss=6.9357, ppl=1028.3254, gnorm=0.4197, throughput=40.47K wps, wc=608.10K
[Epoch 2 Batch 200/785] loss=6.9451, ppl=1038.0244, gnorm=0.4487, throughput=41.61K wps, wc=636.38K
[Epoch 2 Batch 300/785] loss=6.9261, ppl=1018.5643, gnorm=0.4701, throughput=41.03K wps, wc=628.10K


KeyboardInterrupt: 

In [None]:
# plot losses and validation accuracy
plt.style.use("ggplot")
plt.figure()
plt.plot(range(0, hparams.epochs), valid_losses, label="Validation Loss")
plt.plot(range(0, hparams.epochs), train_losses, label="Training Loss")
plt.plot(range(0, hparams.epochs), valid_perplexities, label="Validation Perplexity")
plt.plot(range(0, hparams.epochs), valid_bleus, label="Validation BLEU")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(loc="upper right")
plt.title("Losses / Perplexity / BLEU")
plt.show()

In [None]:
print("Qualitative Evaluation: Translating from Vietnamese to English")

expected_tgt_seq = "I like to read books."
print("Expected translation:")
print(expected_tgt_seq)
# From Google Translate
src_seq = "Tôi thích đọc sách kỹ thuật."
print("In Vietnamese (from Google Translate):")
print(src_seq)

translation_out = nmt.utils.translate_with_unk(
# translation_out = nmt.utils.translate(
    transformer_translator,
    src_seq,
    wmt_src_vocab,
    wmt_tgt_vocab,
    ctx)

print("The English translation is:")
print(" ".join(translation_out[0]))

## Pre-Trained Model

In [4]:
wmt_model_name = 'transformer_en_de_512'
wmt_transformer_model_pt, wmt_src_vocab, wmt_tgt_vocab = nlp.model.get_model(
    wmt_model_name,
    dataset_name='WMT2014',
    pretrained=True,
    ctx=ctx)

wmt_transformer_model_pt.hybridize(static_alloc=True)

model_filename_pt = "transformer_en_de_512_pt.params"



In [5]:
# Reload data with model vocab
src_lang, tgt_lang = "en", "de"

wmt_data_test = nlp.data.WMT2016BPE(
    'newstest2016',
    src_lang=src_lang,
    tgt_lang=tgt_lang)
print('Sample BPE tokens: "{}"'.format(wmt_data_test[0]))

wmt_test_text = nlp.data.WMT2016(
    'newstest2016',
    src_lang=src_lang,
    tgt_lang=tgt_lang)
print('Sample raw text: "{}"'.format(wmt_test_text[0]))

wmt_test_tgt_sentences = wmt_test_text.transform(lambda src, tgt: tgt)
print('Sample target sentence: "{}"'.format(wmt_test_tgt_sentences[0]))

# wmt_src_vocab, wmt_tgt_vocab = nmt.utils.create_vocab(wmt_test_text)
src_max_len, tgt_max_len = 50, 50

Sample BPE tokens: "('Obama receives Net@@ any@@ ah@@ u', 'Obama empfängt Net@@ any@@ ah@@ u')"
Sample raw text: "('Obama receives Netanyahu', 'Obama empfängt Netanyahu')"
Sample target sentence: "Obama empfängt Netanyahu"


In [6]:
# Pre-processing WMT2016 with WMT2014 model vocab
wmt_dataset_processed = wmt_data_test.transform(
    dataprocessor.TrainValDataTransform(
        wmt_src_vocab,
        wmt_tgt_vocab,
        src_max_len,
        tgt_max_len),
    lazy=False)

def get_length_index_fn():
    global idx
    idx = 0
    def transform(src, tgt):
        global idx
        result = (src, tgt, len(src), len(tgt), idx)
        idx += 1
        return result
    return transform

wmt_data_test_with_len = wmt_dataset_processed.transform(get_length_index_fn(), lazy=False)

In [7]:
wmt_test_batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(pad_val=0),
    nlp.data.batchify.Pad(pad_val=0),
    nlp.data.batchify.Stack(dtype='float32'),
    nlp.data.batchify.Stack(dtype='float32'),
    nlp.data.batchify.Stack())

In [8]:
wmt_bucket_scheme = nlp.data.ExpWidthBucket(bucket_len_step=1.2)
wmt_test_batch_sampler = nlp.data.FixedBucketSampler(
    lengths=wmt_data_test_with_len.transform(lambda src, tgt, src_len, tgt_len, idx: tgt_len), # target length
    use_average_length=True, # control the element lengths (i.e. number of tokens) to be about the same
    bucket_scheme=wmt_bucket_scheme,
    batch_size=hparams.batch_size)
print(wmt_test_batch_sampler.stats())

FixedBucketSampler:
  sample_num=2999, batch_num=390
  key=[7, 9, 12, 15, 18, 23, 28, 35, 43, 52]
  cnt=[26, 67, 169, 275, 302, 435, 421, 442, 355, 507]
  batch_size=[36, 28, 21, 17, 14, 11, 9, 7, 6, 4]


In [9]:
num_workers=0

wmt_test_data_loader = mx.gluon.data.DataLoader(
    wmt_data_test_with_len,
    batch_sampler=wmt_test_batch_sampler,
    batchify_fn=wmt_test_batchify_fn,
    num_workers=num_workers)
len(wmt_test_data_loader)

390

In [10]:
transformer_pt_translator = nmt.translation.BeamSearchTranslator(
    model=wmt_transformer_model_pt,
    beam_size=hparams.beam_size,
    scorer=nlp.model.BeamSearchScorer(alpha=hparams.lp_alpha, K=hparams.lp_k),
    max_length=200)

wmt_test_loss_function = nlp.loss.MaskedSoftmaxCELoss()
wmt_test_loss_function.hybridize()

wmt_detokenizer = nlp.data.SacreMosesDetokenizer()

In [11]:
reload(transformer_model)

wmt_test_loss, wmt_test_translation_out = transformer_model.evaluate(
    wmt_transformer_model_pt,
    wmt_test_data_loader,
    wmt_test_loss_function,
    transformer_pt_translator,
    wmt_tgt_vocab,
    wmt_detokenizer,
    ctx)

wmt_test_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu(
    [wmt_test_tgt_sentences],
    wmt_test_translation_out,
    tokenized=False,
    tokenizer="13a",
    split_compound_word=False,
    bpe=False)

print('WMT16 test loss: %.2f; test bleu score: %.2f'
      %(wmt_test_loss, wmt_test_bleu_score * 100))

  0%|          | 0/390 [00:00<?, ?it/s]

WMT16 test loss: 1.59; test bleu score: 29.76


## Transfer Learning

In [12]:
## TEMP XXXXXXX #####
#### LOAD TRAIN VAL DATASETS FOR TRAINING XXXXXXX

In [13]:
# Reload data with model vocab
src_lang, tgt_lang = "en", "de"

wmt2016_train_data = nlp.data.WMT2016BPE(
    'train',
    src_lang=src_lang,
    tgt_lang=tgt_lang)

wmt2016_val_data = nlp.data.WMT2016BPE(
    'train',
    src_lang=src_lang,
    tgt_lang=tgt_lang)

wmt2016_test_data = nlp.data.WMT2016BPE(
    'newstest2016',
    src_lang=src_lang,
    tgt_lang=tgt_lang)

# Text samples not required for train/val
# wmt2016_train_text = nlp.data.WMT2016(
#     'train',
#     src_lang=src_lang,
#     tgt_lang=tgt_lang)

# wmt2016_val_text = nlp.data.WMT2016(
#     'train',
#     src_lang=src_lang,
#     tgt_lang=tgt_lang)

# wmt2016_src_vocab, wmt2016_tgt_vocab = nmt.utils.create_vocab(wmt2016_test_text)
src_max_len, tgt_max_len = 50, 50

In [14]:
# Validation dataset generation (from training dataset)
val_length = 3000

wmt2016_val_data._data[0] = wmt2016_train_data._data[0][-val_length:]
wmt2016_val_data._data[1] = wmt2016_train_data._data[1][-val_length:]
wmt2016_val_data._length = val_length

# Modify Training dataset to remove validation dataset
# Mini Training set
train_length = int(3000)
wmt2016_train_data._data[0] = wmt2016_train_data._data[0][:train_length]
wmt2016_train_data._data[1] = wmt2016_train_data._data[1][:train_length]
wmt2016_train_data._length = train_length

In [31]:
# Target sentences for validation and test
fetch_tgt_sentence = lambda src, tgt: tgt.split()
wmt2016_val_tgt_sentences = wmt2016_val_data.transform(fetch_tgt_sentence)
wmt2016_test_tgt_sentences = wmt2016_test_data.transform(fetch_tgt_sentence)

In [16]:
reload(transformer_model)

# Pre-processing WMT2016 with WMT2014 model vocab
wmt_transform_fn = dataprocessor.TrainValDataTransform(
    wmt_src_vocab,
    wmt_tgt_vocab,
    src_max_len,
    tgt_max_len)

wmt2016_train_data_processed = wmt2016_train_data.transform(
    wmt_transform_fn,
    lazy=False)

wmt2016_val_data_processed = wmt2016_val_data.transform(
    wmt_transform_fn,
    lazy=False)

wmt2016_test_data_processed = wmt2016_test_data.transform(
    wmt_transform_fn,
    lazy=False)

wmt2016_train_data_lengths = transformer_model.get_data_lengths(wmt2016_train_data_processed)
wmt2016_val_data_lengths = transformer_model.get_data_lengths(wmt2016_val_data_processed)
wmt2016_test_data_lengths = transformer_model.get_data_lengths(wmt2016_test_data_processed)

In [17]:
# Add Lengths to the datasets and indexes for validation and test
wmt2016_train_data_len_processed = wmt2016_train_data_processed.transform(lambda src, tgt: (src, tgt, len(src), len(tgt)), lazy=False)
wmt2016_val_data_len_processed = wmt2016_val_data_processed.transform(transformer_model.get_length_index_fn(), lazy=False)
wmt2016_test_data_len_processed = wmt2016_test_data_processed.transform(transformer_model.get_length_index_fn(), lazy=False)

In [18]:
train_batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(pad_val=0),
    nlp.data.batchify.Pad(pad_val=0),
    nlp.data.batchify.Stack(dtype='float32'),
    nlp.data.batchify.Stack(dtype='float32'))

val_batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(pad_val=0),
    nlp.data.batchify.Pad(pad_val=0),
    nlp.data.batchify.Stack(dtype='float32'),
    nlp.data.batchify.Stack(dtype='float32'),
    nlp.data.batchify.Stack())

In [19]:
reload(transformer_model)

bucket_scheme = nlp.data.ExpWidthBucket(bucket_len_step=1.2)

wmt2016_train_batch_sampler = nlp.data.FixedBucketSampler(
    lengths=wmt2016_train_data_lengths,
    use_average_length=True, # control the element lengths (i.e. number of tokens) to be about the same
    num_buckets=hparams.num_buckets,
    bucket_scheme=bucket_scheme,
    batch_size=hparams.batch_size,
    shuffle=True)
print(wmt2016_train_batch_sampler.stats())

wmt2016_val_batch_sampler = nlp.data.FixedBucketSampler(
    lengths=wmt2016_val_data_lengths,
    use_average_length=True, # control the element lengths (i.e. number of tokens) to be about the same
    num_buckets=hparams.num_buckets,
    bucket_scheme=bucket_scheme,
    batch_size=hparams.batch_size,
    shuffle=False)
print(wmt2016_val_batch_sampler.stats())

wmt2016_test_batch_sampler = nlp.data.FixedBucketSampler(
    lengths=wmt2016_test_data_lengths,
    use_average_length=True, # control the element lengths (i.e. number of tokens) to be about the same
    num_buckets=hparams.num_buckets,
    bucket_scheme=bucket_scheme,
    batch_size=hparams.batch_size,
    shuffle=False)
print(wmt2016_test_batch_sampler.stats())

FixedBucketSampler:
  sample_num=3000, batch_num=445
  key=[(9, 10), (16, 17), (26, 27), (37, 38), (51, 52)]
  cnt=[139, 468, 775, 697, 921]
  batch_size=[27, 16, 10, 7, 4]
FixedBucketSampler:
  sample_num=3000, batch_num=484
  key=[(9, 10), (17, 18), (26, 27), (38, 39), (51, 52)]
  cnt=[82, 357, 615, 853, 1093]
  batch_size=[28, 16, 10, 7, 4]
FixedBucketSampler:
  sample_num=2999, batch_num=369
  key=[(10, 11), (18, 19), (27, 28), (38, 39), (51, 52)]
  cnt=[155, 694, 782, 663, 705]
  batch_size=[25, 15, 10, 7, 5]


In [20]:
num_workers = 0

wmt2016_train_data_loader = nlp.data.ShardedDataLoader(
    wmt2016_train_data_len_processed,
    batch_sampler=wmt2016_train_batch_sampler,
    batchify_fn=train_batchify_fn,
    num_workers=num_workers)

wmt2016_val_data_loader = nlp.data.ShardedDataLoader(
    wmt2016_val_data_len_processed,
    batch_sampler=wmt2016_val_batch_sampler,
    batchify_fn=val_batchify_fn,
    num_workers=num_workers)

wmt2016_test_data_loader = nlp.data.ShardedDataLoader(
    wmt2016_test_data_len_processed,
    batch_sampler=wmt2016_test_batch_sampler,
    batchify_fn=val_batchify_fn,
    num_workers=num_workers)

In [21]:
wmt_transformer_model_tl = wmt_transformer_model_pt

# # Freeze Layers (keeping track of the updated parameters)
# updated_params = []
# for param in wmt_transformer_model_tl.collect_params().values():
#     if param.grad_req == "write":
#         param.grad_req = "null"
#         updated_params += [param.name]

# # Re-enable gradients for last layer
# for param in wmt_transformer_model_tl.tgt_proj.collect_params().values():
#     if param in updated_params:
#         param.grad_req = "write"

# What if we don't overwrite this?
# wmt_transformer_model_tl.tgt_proj = mx.gluon.nn.Dense(units=len(wmt_tgt_vocab), flatten=False, prefix='tgt_proj_')
# wmt_transformer_model_tl.tgt_proj.initialize(ctx=ctx)

# wmt_transformer_model_tl.hybridize(static_alloc=True)

model_filename_tl = "transformer_en_de_512_tl.params"

In [22]:
wmt_translator_tl = nmt.translation.BeamSearchTranslator(
    model=wmt_transformer_model_tl,
    beam_size=hparams.beam_size,
    scorer=nlp.model.BeamSearchScorer(alpha=hparams.lp_alpha, K=hparams.lp_k),
    max_length=200)

In [26]:
reload(transformer_model)

# Let's train
trainer = mx.gluon.Trainer(wmt_transformer_model_tl.collect_params(), hparams.optimizer, {'learning_rate': hparams.lr})

hparams.epochs = 3

loss_function = nlp.loss.MaskedSoftmaxCELoss()
loss_function.hybridize(static_alloc=True)

wmt_detokenizer = nlp.data.SacreMosesDetokenizer()

best_valid_bleu = 0.0

wmt2016_train_losses = []
wmt2016_valid_losses = []
wmt2016_valid_bleus  = []
wmt2016_valid_perplexities = []

for epoch_id in tqdm(range(hparams.epochs)):

    log_loss = 0
    log_denom = 0
    log_avg_gnorm = 0
    log_wc = 0
    log_start_time = time.time()

    # Iterate through each batch
    for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length) in enumerate(tqdm(wmt2016_train_data_loader)):
        
        # print("XXXXXX TEST XXXXXX:", src_seq,
        #         tgt_seq[:, :-1],
        #         src_valid_length,
        #         tgt_valid_length - 1)
        
        src_seq = src_seq.as_in_context(ctx)
        tgt_seq = tgt_seq.as_in_context(ctx)
        src_valid_length = src_valid_length.as_in_context(ctx)
        tgt_valid_length = tgt_valid_length.as_in_context(ctx)
        
        with mx.autograd.record():
            out, _ = wmt_transformer_model_tl(
                src_seq,
                tgt_seq[:, :-1],
                src_valid_length,
                tgt_valid_length - 1)

            loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean()
            loss = loss * (tgt_seq.shape[1] - 1)
            log_loss += loss * tgt_seq.shape[0]
            log_denom += (tgt_valid_length - 1).sum()
            loss = loss / (tgt_valid_length - 1).mean()
            loss.backward()

        grads = [p.grad(ctx) for p in wmt_transformer_model_tl.collect_params().values() if p.grad_req != 'null']
        gnorm = mx.gluon.utils.clip_global_norm(grads, hparams.clip)
        trainer.step(1)
        
        src_wc = src_valid_length.sum().asscalar()
        tgt_wc = (tgt_valid_length - 1).sum().asscalar()
        log_loss = log_loss.asscalar()
        log_denom = log_denom.asscalar()
        log_avg_gnorm += gnorm
        log_wc += src_wc + tgt_wc
        
        wmt2016_train_loss = log_loss / log_denom
        wmt2016_train_losses.append(wmt2016_train_loss)
        
        if (batch_id + 1) % hparams.log_interval == 0:
            wps = log_wc / (time.time() - log_start_time)
            print("[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, "
                         "throughput={:.2f}K wps, wc={:.2f}K"
                         .format(epoch_id, batch_id + 1, len(wmt2016_train_data_loader),
                                 wmt2016_train_loss,
                                 np.exp(log_loss / log_denom),
                                 log_avg_gnorm / hparams.log_interval,
                                 wps / 1000, log_wc / 1000))
            
            log_start_time = time.time()
            log_loss = 0
            log_denom = 0
            log_avg_gnorm = 0
            log_wc = 0

    # Validation step
    wmt2016_valid_loss, wmt2016_valid_translation_out = transformer_model.evaluate(
        wmt_transformer_model_tl,
        wmt2016_val_data_loader,
        loss_function,
        wmt_translator_tl,
        wmt_tgt_vocab,
        wmt_detokenizer,
        ctx)
    
    wmt2016_valid_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu(
        [wmt2016_val_tgt_sentences],
        wmt2016_valid_translation_out,
        tokenized=False,
        tokenizer="13a",
        split_compound_word=False,
        bpe=False)

    wmt2016_valid_perplexity = np.exp(wmt2016_valid_loss)
    wmt_2016_valid_perplexities.append(wmt2016_valid_perplexity)
    wmt2016_valid_losses.append(wmt2016_valid_loss)
    wmt2016_valid_bleus.append(wmt2016_valid_bleu_score * 100)
    
    print("[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}"
          .format(epoch_id, wmt2016_valid_loss, wmt2016_valid_perplexity, wmt2016_valid_bleu_score * 100))
    
    if wmt2016_valid_bleu_score > best_valid_bleu:
        best_valid_bleu = valid_bleu_score
        print("Save best parameters to {}".format(model_filename_tl))
        wmt_transformer_model_tl.save_parameters(model_filename_tl)
    
    # if epoch_id + 1 >= (hparams.epochs * 2) // 3:
    #     new_lr = trainer.learning_rate * hparams.lr_update_factor
    #     print("Learning rate change to {}".format(new_lr))
    #     trainer.set_learning_rate(new_lr)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/445 [00:00<?, ?it/s]

[Epoch 0 Batch 100/445] loss=0.9405, ppl=2.5614, gnorm=5.0395, throughput=6.79K wps, wc=37.98K
[Epoch 0 Batch 200/445] loss=0.9812, ppl=2.6677, gnorm=5.0645, throughput=6.54K wps, wc=38.19K
[Epoch 0 Batch 300/445] loss=1.0089, ppl=2.7426, gnorm=5.2692, throughput=6.61K wps, wc=38.13K
[Epoch 0 Batch 400/445] loss=1.0695, ppl=2.9138, gnorm=5.6947, throughput=6.76K wps, wc=37.63K


  0%|          | 0/484 [00:00<?, ?it/s]

AssertionError: references and translation should have format of list(list(str)) and list(str), respectively, when tokenized is False.

In [33]:
wmt2016_val_tgt_sentences

<mxnet.gluon.data.dataset._LazyTransformDataset at 0x7f29d2cd6e80>

In [29]:
wmt_test_tgt_sentences

<mxnet.gluon.data.dataset._LazyTransformDataset at 0x7f2b53b47730>

In [None]:
wmt2016_test_loss, wmt2016_test_translation_out = transformer_model.evaluate(
    wmt_transformer_model_tl,
    wmt2016_test_data_loader,
    loss_function,
    wmt_translator_tl,
    wmt_tgt_vocab,
    wmt_detokenizer,
    ctx)

wmt2016_valid_bleu_score, _, _, _, _ = nmt.bleu.compute_bleu(
    [wmt2016_test_tgt_sentences],
    wmt2016_test_translation_out,
    tokenized=False,
    tokenizer="13a",
    split_compound_word=False,
    bpe=False)

print('WMT16 test loss: %.2f; test bleu score: %.2f'
      %(wmt_test_loss, wmt_test_bleu_score * 100))

  0%|          | 0/369 [00:00<?, ?it/s]

## Fine-Tuning (after Transfer Learning)

In [25]:
# Model built on top of Transfer Learning model
# wmt_transformer_model_tl.load_parameters(model_filename_tl)
wmt_transformer_model_ft = wmt_transformer_model_tl

# Un-freeze weights
for param in wmt_transformer_model_ft.collect_params().values(): 
    if param.name in updated_params:
        param.grad_req = 'write'

wmt_transformer_model_ft.hybridize()

model_filename_ft = "transformer_en_de_512_ft.params"

In [26]:
wmt_translator_ft = nmt.translation.BeamSearchTranslator(
    model=wmt_transformer_model_ft,
    beam_size=hparams.beam_size,
    scorer=nlp.model.BeamSearchScorer(alpha=hparams.lp_alpha, K=hparams.lp_k),
    max_length=200)

In [27]:
reload(transformer_model)

trainer = mx.gluon.Trainer(wmt_transformer_model_ft.collect_params(), 'adam', {'learning_rate': hparams.lr})

loss_function = nlp.loss.MaskedSoftmaxCELoss()
loss_function.hybridize()

wmt_detokenizer = nlp.data.SacreMosesDetokenizer()

best_valid_bleu = 0.0

# Run through each epoch
for epoch_id in tqdm(range(hparams.epochs)):
    log_avg_loss = 0
    log_avg_gnorm = 0
    log_wc = 0
    log_start_time = time.time()

    # Iterate through each batch
    for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length)\
            in enumerate(tqdm(train_data_loader)):

        src_seq = src_seq.as_in_context(ctx)
        tgt_seq = tgt_seq.as_in_context(ctx)
        src_valid_length = src_valid_length.as_in_context(ctx)
        tgt_valid_length = tgt_valid_length.as_in_context(ctx)

        # Compute gradients and losses
        with mx.autograd.record():
            out, _ = wmt_transformer_model_ft(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
            loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean()
            loss = loss * (tgt_seq.shape[1] - 1) / (tgt_valid_length - 1).mean()
            loss.backward()

        grads = [p.grad(ctx) for p in wmt_transformer_model_ft.collect_params().values() if p.grad_req != "null"]
        gnorm = mx.gluon.utils.clip_global_norm(grads, hparams.clip)
        trainer.step(1)
        src_wc = src_valid_length.sum().asscalar()
        tgt_wc = (tgt_valid_length - 1).sum().asscalar()
        step_loss = loss.asscalar()
        log_avg_loss += step_loss
        log_avg_gnorm += gnorm
        log_wc += src_wc + tgt_wc
        if (batch_id + 1) % hparams.log_interval == 0:
            wps = log_wc / (time.time() - log_start_time)
            print('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, '
                         'throughput={:.2f}K wps, wc={:.2f}K'
                         .format(epoch_id, batch_id + 1, len(train_data_loader),
                                 log_avg_loss / hparams.log_interval,
                                 np.exp(log_avg_loss / hparams.log_interval),
                                 log_avg_gnorm / hparams.log_interval,
                                 wps / 1000, log_wc / 1000))
            log_start_time = time.time()
            log_avg_loss = 0
            log_avg_gnorm = 0
            log_wc = 0


  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/785 [00:00<?, ?it/s]

[Epoch 0 Batch 100/785] loss=4.8289, ppl=125.0683, gnorm=3.4245, throughput=30.42K wps, wc=655.63K
[Epoch 0 Batch 200/785] loss=4.1356, ppl=62.5249, gnorm=2.5131, throughput=29.60K wps, wc=582.26K
[Epoch 0 Batch 300/785] loss=4.1546, ppl=63.7279, gnorm=2.7742, throughput=30.41K wps, wc=626.88K
[Epoch 0 Batch 400/785] loss=4.1127, ppl=61.1099, gnorm=3.4123, throughput=29.14K wps, wc=635.07K
[Epoch 0 Batch 500/785] loss=4.0549, ppl=57.6773, gnorm=2.4709, throughput=29.68K wps, wc=662.89K
[Epoch 0 Batch 600/785] loss=3.7495, ppl=42.4993, gnorm=2.5455, throughput=30.36K wps, wc=594.40K
[Epoch 0 Batch 700/785] loss=3.7386, ppl=42.0393, gnorm=2.0752, throughput=31.23K wps, wc=644.82K


  0%|          | 0/785 [00:00<?, ?it/s]

[Epoch 1 Batch 100/785] loss=3.1350, ppl=22.9897, gnorm=1.9700, throughput=29.35K wps, wc=614.72K
[Epoch 1 Batch 200/785] loss=3.2306, ppl=25.2949, gnorm=2.2025, throughput=31.86K wps, wc=661.64K
[Epoch 1 Batch 300/785] loss=3.3254, ppl=27.8103, gnorm=5.0830, throughput=31.12K wps, wc=665.94K
[Epoch 1 Batch 400/785] loss=3.3146, ppl=27.5108, gnorm=13.7271, throughput=28.07K wps, wc=597.31K
[Epoch 1 Batch 500/785] loss=3.3871, ppl=29.5803, gnorm=3.6264, throughput=32.07K wps, wc=676.88K
[Epoch 1 Batch 600/785] loss=3.2095, ppl=24.7671, gnorm=2.5420, throughput=28.17K wps, wc=615.45K
[Epoch 1 Batch 700/785] loss=3.0889, ppl=21.9533, gnorm=2.3815, throughput=29.50K wps, wc=589.80K


  0%|          | 0/785 [00:00<?, ?it/s]

[Epoch 2 Batch 100/785] loss=2.6497, ppl=14.1493, gnorm=2.4708, throughput=28.68K wps, wc=589.19K
[Epoch 2 Batch 200/785] loss=2.8581, ppl=17.4278, gnorm=4.8855, throughput=30.29K wps, wc=655.39K
[Epoch 2 Batch 300/785] loss=2.7027, ppl=14.9204, gnorm=3.2143, throughput=31.67K wps, wc=594.79K
[Epoch 2 Batch 400/785] loss=2.8658, ppl=17.5637, gnorm=6.6410, throughput=31.48K wps, wc=651.27K
[Epoch 2 Batch 500/785] loss=2.8147, ppl=16.6881, gnorm=3.2738, throughput=31.63K wps, wc=622.22K
[Epoch 2 Batch 600/785] loss=2.9138, ppl=18.4259, gnorm=4.6055, throughput=30.76K wps, wc=648.76K
[Epoch 2 Batch 700/785] loss=2.7949, ppl=16.3616, gnorm=3.4296, throughput=31.14K wps, wc=630.86K


  0%|          | 0/785 [00:00<?, ?it/s]

KeyboardInterrupt: 

Process ForkPoolWorker-37:
Process ForkPoolWorker-34:
Process ForkPoolWorker-39:
Process ForkPoolWorker-38:
Process ForkPoolWorker-35:
Process ForkPoolWorker-36:
Process ForkPoolWorker-40:
Process ForkPoolWorker-33:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 108

## Fine-Tuning (directly)

In [16]:
ctx = mx.gpu()

# Build on top of Pre-Trained model
wmt_model_name = 'transformer_en_de_512'
wmt_transformer_model_ft_direct, _, _ = nlp.model.get_model(
    wmt_model_name,
    dataset_name='WMT2014',
    pretrained=True,
    ctx=ctx)

# # What if we don't overwrite this?
# wmt_transformer_model_ft_direct.tgt_proj = mx.gluon.nn.Dense(units=len(wmt_tgt_vocab), flatten=False, prefix='tgt_proj_')
# wmt_transformer_model_ft_direct.tgt_proj.initialize(ctx=ctx)

wmt_transformer_model_ft_direct.hybridize()

model_filename_ft_direct = "transformer_en_de_512_ft_direct.params"



In [17]:
wmt_translator_ft_direct = nmt.translation.BeamSearchTranslator(
    model=wmt_transformer_model_ft_direct,
    beam_size=hparams.beam_size,
    scorer=nlp.model.BeamSearchScorer(alpha=hparams.lp_alpha, K=hparams.lp_k),
    max_length=200)

In [18]:
reload(transformer_model)

trainer = mx.gluon.Trainer(wmt_transformer_model_ft_direct.collect_params(), 'adam', {'learning_rate': hparams.lr})

loss_function = nlp.loss.MaskedSoftmaxCELoss()
loss_function.hybridize()

wmt_detokenizer = nlp.data.SacreMosesDetokenizer()

best_valid_bleu = 0.0

# Run through each epoch
for epoch_id in tqdm(range(hparams.epochs)):
    log_avg_loss = 0
    log_avg_gnorm = 0
    log_wc = 0
    log_start_time = time.time()

    # Iterate through each batch
    for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length)\
            in enumerate(tqdm(train_data_loader)):

        src_seq = src_seq.as_in_context(ctx)
        tgt_seq = tgt_seq.as_in_context(ctx)
        src_valid_length = src_valid_length.as_in_context(ctx)
        tgt_valid_length = tgt_valid_length.as_in_context(ctx)

        # Compute gradients and losses
        with mx.autograd.record():
            out, _ = wmt_transformer_model_ft_direct(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
            loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean()
            loss = loss * (tgt_seq.shape[1] - 1) / (tgt_valid_length - 1).mean()
            loss.backward()

        grads = [p.grad(ctx) for p in wmt_transformer_model_ft_direct.collect_params().values() if p.grad_req != "null"]
        gnorm = mx.gluon.utils.clip_global_norm(grads, hparams.clip)
        trainer.step(1)
        src_wc = src_valid_length.sum().asscalar()
        tgt_wc = (tgt_valid_length - 1).sum().asscalar()
        step_loss = loss.asscalar()
        log_avg_loss += step_loss
        log_avg_gnorm += gnorm
        log_wc += src_wc + tgt_wc
        if (batch_id + 1) % hparams.log_interval == 0:
            wps = log_wc / (time.time() - log_start_time)
            print('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, '
                         'throughput={:.2f}K wps, wc={:.2f}K'
                         .format(epoch_id, batch_id + 1, len(train_data_loader),
                                 log_avg_loss / hparams.log_interval,
                                 np.exp(log_avg_loss / hparams.log_interval),
                                 log_avg_gnorm / hparams.log_interval,
                                 wps / 1000, log_wc / 1000))
            log_start_time = time.time()
            log_avg_loss = 0
            log_avg_gnorm = 0
            log_wc = 0


  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/785 [00:00<?, ?it/s]

KeyboardInterrupt: 