In [1]:
import datetime
import dill as pickle
import numpy as np
import torch
import time
import tqdm

In [2]:
from functools import partial
from mytransformers.data import SimpleTranslationDataset
from mytransformers.data import pad_to_seq_len
from mytransformers.models import TransformerModel

## training config

In [3]:
# --- data / model sizing ---
MAX_SEQ_LEN = 256   # used as the collate max_seq_len, the model seq_len and the tokenizer max_len
VOCAB_SIZE = 8000   # vocabulary size for the tokenizers and for both model vocabularies
BATCH_SIZE = 16     # minibatch size for the train and validation dataloaders

# --- optimisation ---
MAX_EPOCHS = 64     # number of passes the training loop makes over the training set
GRAD_CLIP = 10.0    # max gradient norm passed to clip_grad_norm_

## make datasets and dataloaders

this is being trained for English > Korean translation on the `korean-english-news-v1` dataset from https://github.com/jungyeul/korean-parallel-corpora 

no preprocessing is done except to read in the data and write the (en, ko) pairs to tsv.


In [4]:
# Input data: tab-separated (en, ko) sentence pairs for training and validation.
train_file = "data/translation_news/en-ko-train.tsv"
valid_file = "data/translation_news/en-ko-dev.tsv"
# NOTE: despite their names, these two are the pickle paths for the *source* and
# *target* tokenizers respectively (see the file names and the cell that dumps them).
train_tokenizer = "data/translation_news/src_tokenizer.pkl"
valid_tokenizer = "data/translation_news/tgt_tokenizer.pkl"
# The training loop overwrites this checkpoint at the end of every epoch.
checkpoint_file = "data/translation_news/checkpoint.pt"
# English sentences that are greedily translated after each epoch as a qualitative check.
sample_sentences = [
    "After keeping the world's most powerful supercomputer to themselves for a year, government researchers showed off the $110 million wonder and said it might help save the world from nuclear war.",
    "Most of the people involved in the discussion agree that there is a legitimate area in which the government needs to retain the right to intercept communications.",
    "Several Texan transmission companies announced Monday they were forming a consortium to invest in the $5 billion cost of building new power lines to take advantage of the state's vast wind power."
]

In [5]:
# # for training with the alternate 'bible' dataset.
# # here, the verse notes are removed, and every tenth item starting from the second (index = 1) is set to validation
# train_file = "data/translation_bible/train_pairs.tsv"
# valid_file = "data/translation_bible/valid_pairs.tsv"
# train_tokenizer = "data/translation_bible/src_tokenizer.pkl"
# valid_tokenizer = "data/translation_bible/tgt_tokenizer.pkl"
# checkpoint_file = "data/translation_bible/checkpoint.pt"
# sample_sentences = [
#     "The weapons we fight with are not the weapons of the world. On the contrary, they have divine power to demolish strongholds.",
#     "Make it your ambition to lead a quiet life, to mind your own business and to work with your hands, just as we told you,",
#     "It had a great, high wall with twelve gates, and with twelve angels at the gates. On the gates were written the names of the twelve tribes of Israel."
# ]

In [6]:
# Build the training dataset from the train TSV. No tokenizers are passed here, and the
# output below shows new source and target tokenizers being fitted with VOCAB_SIZE pieces.
train_dataset = SimpleTranslationDataset(train_file, vocab_size=VOCAB_SIZE)

fitting tokenizer...
fitting source tokenizer...


sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input_format: 
  model_prefix: 
  model_type: UNIGRAM
  vocab_size: 8000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  control_symbols: [CLS]
  control_symbols: [SEP]
  control_symbols: [NEW1]
  control_symbols: [NEW2]
  control_symbols: [NEW3]
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 1
  bos_id: 2
  eos_id: 3
  pad_id: 0
  unk_piece: [UNK]
  bos_piece: [BOS]
  eos_piece: [EOS]
  pad_piece: [PAD]
  unk_surface:  ⁇ 

fitting target tokenizer...


 name: nmt_nfkc
  add_dummy_prefix: 1
  remove_extra_whitespaces: 1
  escape_whitespaces: 1
  normalization_rule_tsv: 
}
denormalizer_spec {}
trainer_interface.cc(385) LOG(INFO) Loaded all 94123 sentences
trainer_interface.cc(400) LOG(INFO) Adding meta_piece: [PAD]
trainer_interface.cc(400) LOG(INFO) Adding meta_piece: [UNK]
trainer_interface.cc(400) LOG(INFO) Adding meta_piece: [BOS]
trainer_interface.cc(400) LOG(INFO) Adding meta_piece: [EOS]
trainer_interface.cc(400) LOG(INFO) Adding meta_piece: [CLS]
trainer_interface.cc(400) LOG(INFO) Adding meta_piece: [SEP]
trainer_interface.cc(400) LOG(INFO) Adding meta_piece: [NEW1]
trainer_interface.cc(400) LOG(INFO) Adding meta_piece: [NEW2]
trainer_interface.cc(400) LOG(INFO) Adding meta_piece: [NEW3]
trainer_interface.cc(405) LOG(INFO) Normalizing sentences...
trainer_interface.cc(466) LOG(INFO) all chars count=11967043
trainer_interface.cc(477) LOG(INFO) Done: 99.955% characters are covered.
trainer_interface.cc(487) LOG(INFO) Alphabet si

In [7]:
# Pull the source/target tokenizers out of the training dataset so they can be
# pickled, shared with the validation dataset, and used for sample inference.
src_tokenizer, tgt_tokenizer = train_dataset.get_tokenizers()

In [8]:
# Persist both tokenizers next to the data (`pickle` is dill here, see the imports).
# The files are opened in `with` blocks so each handle is flushed and closed as soon
# as its dump finishes, instead of being left open until garbage collection.
# Note: `train_tokenizer` / `valid_tokenizer` hold the source / target tokenizer paths.
with open(train_tokenizer, "wb") as src_tokenizer_file:
    pickle.dump(src_tokenizer, src_tokenizer_file)
with open(valid_tokenizer, "wb") as tgt_tokenizer_file:
    pickle.dump(tgt_tokenizer, tgt_tokenizer_file)

In [9]:
# Build the validation dataset with the tokenizers obtained from the training dataset,
# so both splits share the same source and target tokenizers.
valid_dataset = SimpleTranslationDataset(valid_file, 
                                         src_tokenizer=src_tokenizer, 
                                         tgt_tokenizer=tgt_tokenizer)

In [10]:
# Report how many sentence pairs each split contains (one line per split).
print(
    "train samples: {}".format(len(train_dataset)),
    "valid samples: {}".format(len(valid_dataset)),
    sep="\n",
)

train samples: 94123
valid samples: 1000


In [11]:
# Both loaders collate with pad_to_seq_len bound to MAX_SEQ_LEN
# (presumably pads each batch to that length; confirm in mytransformers.data).
pad_collate = partial(pad_to_seq_len, max_seq_len=MAX_SEQ_LEN)

# Training batches are reshuffled every epoch.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=pad_collate,
)

# Validation batches keep the file order.
valid_dataloader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=pad_collate,
)

## create model, etc

the model configuration is loosely based on the *Attention is All You Need* base configuration, with the following changes:

- the token embedding space used is smaller than the transformer input dimension, like ALBERT
- a small amount of dropout is added to the QK.T (`attn_dropout`) and to the first FFNN projection (`ffnn_dropout`)
- the GELU activation is used in the FFNN layer, like BERT and GPT
- the pre-layernorm configuration is used
- AdamW optimizer is used

In [12]:
# Instantiate the encoder-decoder transformer and move it to the GPU.
# The settings follow the notes above; anything not passed here falls back to the
# TransformerModel defaults.
mytransformer = TransformerModel(
     src_vocab_sz=VOCAB_SIZE,
     tgt_vocab_sz=VOCAB_SIZE,
     enc_layers=6,
     dec_layers=6,
     seq_len=MAX_SEQ_LEN,
     d_vocab=256,  # token embedding size, smaller than d_in (like ALBERT)
     d_in=512,  # transformer input dimension
     d_attn=64,  # presumably the per-head attention size (8 heads x 64 = d_in); confirm
     d_ffnn=1024, 
     attn_heads=8, 
     dropout=0.1,
     attn_dropout=0.05,  # small dropout on QK.T
     ffnn_dropout=0.05,  # small dropout on the first FFNN projection
     pos_encoding="sinusoidal",
     shared_vocab=False,
     attn_mask_val=-1e08,  # presumably the fill value for masked attention positions; confirm
     ffnn_activation="gelu",  # GELU in the FFNN layer, like BERT and GPT
     pre_ln=True  # pre-layernorm configuration
).cuda()

In [13]:
# Display the model's full config dict; it also lists settings that were not passed
# explicitly above (e.g. the attn_*_bias flags). This dict is stored in each checkpoint.
mytransformer.config

{'src_vocab_sz': 8000,
 'tgt_vocab_sz': 8000,
 'enc_layers': 6,
 'dec_layers': 6,
 'seq_len': 256,
 'd_vocab': 256,
 'd_in': 512,
 'd_attn': 64,
 'd_ffnn': 1024,
 'attn_heads': 8,
 'dropout': 0.1,
 'attn_dropout': 0.05,
 'ffnn_dropout': 0.05,
 'pos_encoding': 'sinusoidal',
 'shared_vocab': False,
 'attn_mask_val': -100000000.0,
 'attn_q_bias': False,
 'attn_kv_bias': False,
 'attn_out_bias': False,
 'ffnn_activation': 'gelu',
 'pre_ln': True}

### learning rate scheduling

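we'll use `OneCycleLR` to roughly approximate warmup and annealing: the learning rate 'warms up' for 1 epoch and then decays for 49 epochs (until the 50th epoch). Note that this schedule covers 50 epochs, fewer than `MAX_EPOCHS` (64), and `OneCycleLR` cannot be stepped past its `total_steps`.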

In [14]:
# Cross-entropy summed over all target positions; positions whose target id is 0
# (the tokenizer's [PAD] id) are ignored. The training loop divides by the token count.
criterion = torch.nn.CrossEntropyLoss(
    reduction="sum",
    ignore_index=0,
)

optimizer = torch.optim.AdamW(
    mytransformer.parameters(),
    lr=0.0001,
    betas=(0.9, 0.98),
    eps=1e-09,
    weight_decay=0.0001,
    amsgrad=False,
)

# One-cycle schedule sized for 50 epochs of minibatches: the first 1/50 of the steps
# (one epoch) ramp the learning rate up to max_lr, the remainder anneal it.
# NOTE: OneCycleLR raises a ValueError if it is stepped more than total_steps times.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.001,
    total_steps=len(train_dataloader)*50,
    pct_start=1./50,
)


## training loop

In [None]:
# Training driver. For each of MAX_EPOCHS epochs it:
#   1. runs one optimisation pass over train_dataloader,
#   2. computes the mean validation loss,
#   3. prints greedy translations of sample_sentences,
#   4. overwrites checkpoint_file with the model / optimizer / scheduler state.


def _timestamp():
    """Return the current local time as sliced from the ISO string (HH:MM:SS.ff)."""
    return datetime.datetime.now().isoformat()[11:22]


def _batch_loss(batch):
    """Run one forward pass on the GPU and return the length-normalised loss.

    `batch` is the 5-tuple produced by the dataloaders:
    (x, y_in, y_true, x_lens, y_lens). The summed cross-entropy is divided by
    sum(y_lens) so the result is an average over target tokens.
    """
    x, y_in, y_true, x_lens, y_lens = (t.to("cuda") for t in batch)
    y_pred = mytransformer(x, y_in, x_lens, y_lens)
    # CrossEntropyLoss expects the class dimension at index 1, hence the transpose
    # (assumes y_pred is (batch, seq, vocab); confirm in TransformerModel).
    loss = criterion(torch.transpose(y_pred, 1, 2), y_true)
    return loss / torch.sum(y_lens)  # scale by all non-zero elements


mytransformer.train()

global_step = 0

# sliding window of the 16 most recent training losses, used for the smoothed display
windowed_losses = []

for epoch in range(MAX_EPOCHS):

    print("starting epoch {} of {}".format(epoch+1, MAX_EPOCHS))
    time.sleep(1)  # short pause before the progress bar starts drawing

    with tqdm.tqdm(train_dataloader, desc="minibatch", total=len(train_dataloader)) as b:

        # the bar is advanced manually with b.update(1) at the end of each step
        for batch in train_dataloader:

            loss = _batch_loss(batch)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(mytransformer.parameters(), GRAD_CLIP)
            optimizer.step()
            # The schedule covers fewer steps than MAX_EPOCHS epochs produce, and
            # OneCycleLR raises ValueError when stepped past total_steps. Once the
            # cycle is complete, stop stepping and keep the final learning rate.
            if scheduler.last_epoch < scheduler.total_steps:
                scheduler.step()
            optimizer.zero_grad()

            windowed_losses.append(loss.item())
            windowed_losses = windowed_losses[-16:]

            global_step += 1

            b.set_postfix(loss="{:.3f}".format(np.mean(windowed_losses)), global_step=global_step)
            b.update(1)

    # end of epoch loss (last minibatch loss and the windowed average)
    print("[{}] epoch {:>03d} global step {:>04d}: loss: {:>8.3f}\tavg: {:>8.3f} (end of epoch)".format(
        _timestamp(), epoch+1, global_step, loss.item(), np.mean(windowed_losses)
    ))

    # evaluate
    eval_losses = []
    time.sleep(1)
    print("\n[{}] evaluating...\n".format(_timestamp()))
    time.sleep(1)

    mytransformer.eval()

    with torch.no_grad():
        for batch in tqdm.tqdm(valid_dataloader):
            eval_losses.append(_batch_loss(batch).item())

    print("\n[{}] epoch {:>03d} eval loss: {:>8.3f}".format(_timestamp(), epoch+1, np.mean(eval_losses)))

    time.sleep(1)
    # infer some results
    print("\n sample greedy outputs:\n")
    with torch.no_grad():
        for sample in sample_sentences:
            x, x_len = src_tokenizer.transform(sample, as_array=True, bos=True, eos=True, max_len=MAX_SEQ_LEN)
            x = torch.from_numpy(x).long().to("cuda")
            x_len = torch.from_numpy(x_len).long().to("cuda")
            # bos=2 / eos=3 match the [BOS] / [EOS] ids in the tokenizer spec
            y_hat = mytransformer.infer_one_greedy(x, x_len, bos=2, eos=3)
            y_hat = tgt_tokenizer.inverse_transform([y_hat], as_tokens=False)[0]
            print("\tsrc: {}".format(sample))
            print("\tprd: {}\n".format(y_hat))

    # save (overwrites the previous epoch's checkpoint)
    torch.save({
            'epoch': epoch+1,
            'global_step': global_step,
            'model_state_dict': mytransformer.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'windowed_losses': windowed_losses,
            # plain Python floats, so the checkpoint does not embed numpy scalar objects
            'avg_loss': float(np.mean(windowed_losses)),
            'eval_loss': float(np.mean(eval_losses)),
            'training_config': mytransformer.config,
            'batch_size': BATCH_SIZE
            }, checkpoint_file)

    # take a fresh timestamp so the message reflects when the save actually finished
    print("\n[{}] checkpoint saved!".format(_timestamp()))

    # back to training mode for the next epoch
    mytransformer.train()
    

starting epoch 1 of 64


minibatch:   3%|███▊                                                                                                                               | 174/5883 [00:32<18:10,  5.23it/s, global_step=174, loss=545.209]

In [None]:
# scratch cell: evaluates the literal passed as attn_mask_val to TransformerModel above
-1e08