In [1]:
import datetime
import dill as pickle
import numpy as np
import torch
import time
import tqdm

In [2]:
from functools import partial
from mytransformers.data import SimpleTranslationDataset
from mytransformers.data import pad_to_seq_len
from mytransformers.models import TransformerModel

## training config

In [3]:
MAX_SEQ_LEN = 256
VOCAB_SIZE = 8000
BATCH_SIZE = 16
MAX_EPOCHS = 10

## make datasets and dataloaders

this is being trained for English > Korean translation on the `bible` dataset from https://github.com/jungyeul/korean-parallel-corpora 

no preprocessing is done except to remove the initial verse information, and then the (en, kr) pairs were written to tsv.

every tenth sample, starting from the second (index 1), is used for test set; the rest are used for training.


In [4]:
train_dataset = SimpleTranslationDataset("data/translation/train_pairs.tsv", vocab_size=VOCAB_SIZE)

fitting tokenizer...
fitting source tokenizer...
fitting target tokenizer...


sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input_format: 
  model_prefix: 
  model_type: UNIGRAM
  vocab_size: 8000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  control_symbols: [CLS]
  control_symbols: [SEP]
  control_symbols: [NEW1]
  control_symbols: [NEW2]
  control_symbols: [NEW3]
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 1
  bos_id: 2
  eos_id: 3
  pad_id: 0
  unk_piece: [UNK]
  bos_piece: [BOS]
  eos_piece: [EOS]
  pad_piece: [PAD]
  unk_surface:  ⁇ 

In [5]:
src_tokenizer, tgt_tokenizer = train_dataset.get_tokenizers()

In [6]:
pickle.dump(src_tokenizer, open("data/translation/src_tokenizer.pkl", "wb"))
pickle.dump(tgt_tokenizer, open("data/translation/tgt_tokenizer.pkl", "wb"))

In [7]:
test_dataset = SimpleTranslationDataset("data/translation/valid_pairs.tsv", 
                                        src_tokenizer=src_tokenizer, 
                                        tgt_tokenizer=tgt_tokenizer)

In [8]:
print("train samples: {}".format(len(train_dataset)))
print("test  samples: {}".format(len(test_dataset)))

train samples: 27975
test  samples: 3109


In [9]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=partial(pad_to_seq_len, max_seq_len=MAX_SEQ_LEN))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(pad_to_seq_len, max_seq_len=MAX_SEQ_LEN))

## create model, etc

the model configuration is loosely based on the *Attention is All You Need* base configuration, with the following changes:

- only four attention heads are used
- the token embedding space used is smaller than the transformer input dimension, like ALBERT
- a small amount of dropout is added to the QK.T (`attn_dropout`) and to the first FFNN projection (`ffnn_dropout`)
- the GELU activation is used in the FFNN layer, like BERT and GPT
- the pre-layernorm configuration is used

In [10]:
mytransformer = TransformerModel(
     src_vocab_sz=VOCAB_SIZE,
     tgt_vocab_sz=VOCAB_SIZE,
     enc_layers=6,
     dec_layers=6,
     seq_len=MAX_SEQ_LEN,
     d_vocab=128,
     d_in=512, 
     d_attn=128, 
     d_ffnn=1024, 
     attn_heads=4, 
     dropout=0.1,
     attn_dropout=0.05, 
     ffnn_dropout=0.05,
     shared_vocab=False,
     attn_mask_val=-10e8, 
     ffnn_activation="gelu", 
     pre_ln=True
).cuda()

In [11]:
mytransformer.config

{'src_vocab_sz': 8000,
 'tgt_vocab_sz': 8000,
 'enc_layers': 6,
 'dec_layers': 6,
 'seq_len': 256,
 'd_vocab': 128,
 'd_in': 512,
 'd_attn': 128,
 'd_ffnn': 1024,
 'attn_heads': 4,
 'dropout': 0.1,
 'attn_dropout': 0.05,
 'ffnn_dropout': 0.05,
 'pos_encoding': 'sinusoidal',
 'shared_vocab': False,
 'attn_mask_val': -1000000000.0,
 'attn_q_bias': False,
 'attn_kv_bias': False,
 'attn_out_bias': False,
 'ffnn_activation': 'gelu',
 'pre_ln': True}

### learning rate scheduling

we'll use the `OneCycleLR` to roughly approximate the warmup and annealing by 'warming up' 4000 steps and then decaying for 16000 steps.

In [12]:
criterion = torch.nn.CrossEntropyLoss(ignore_index=0, reduction="sum")

optimizer = torch.optim.Adam(mytransformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-09, weight_decay=0.001, amsgrad=True)

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, total_steps=20000, pct_start=0.2)


## training loop

In [13]:
mytransformer.train()

global_step = 0

windowed_losses = []

for epoch in range(MAX_EPOCHS):

    for batch in train_dataloader:
        x, y_in, y_true, x_lens, y_lens = batch
        x = x.to("cuda")
        y_in = y_in.to("cuda")
        y_true = y_true.to("cuda")
        x_lens = x_lens.to("cuda")
        y_lens = y_lens.to("cuda")

        y_pred = mytransformer(x, y_in, x_lens, y_lens)

        loss = criterion(torch.transpose(y_pred, 1, 2), y_true)

        loss.backward() 

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    
        windowed_losses.append(loss.item())
        windowed_losses = windowed_losses[-200:]
    
        global_step += 1
        
        if global_step <= 2000 and global_step % 100 == 0:
            tme = datetime.datetime.now().isoformat()[11:22]
            print("[{}] epoch {:>03d} global step {:>04d}: loss: {:>8.3f}\tavg: {:>8.3f}".format(
                tme, epoch+1, global_step, loss.item()/BATCH_SIZE, np.mean(windowed_losses)/BATCH_SIZE
            ))
        elif global_step > 2000 and global_step % 200 == 0:
            tme = datetime.datetime.now().isoformat()[11:22]
            print("[{}] epoch {:>03d} global step {:>04d}: loss: {:>8.3f}\tavg: {:>8.3f}".format(
                tme, epoch+1, global_step, loss.item()/BATCH_SIZE, np.mean(windowed_losses)/BATCH_SIZE
            ))
            
    # evaluate
    eval_losses = []
    time.sleep(1)
    tme = datetime.datetime.now().isoformat()[11:22]
    print("\n[{}] evaluating...\n".format(tme))
    time.sleep(1)
    
    mytransformer.eval()
    
    for batch in tqdm.tqdm(test_dataloader):
        x, y_in, y_true, x_lens, y_lens = batch
        x = x.to("cuda")
        y_in = y_in.to("cuda")
        y_true = y_true.to("cuda")
        x_lens = x_lens.to("cuda")
        y_lens = y_lens.to("cuda")
        y_pred = mytransformer(x, y_in, x_lens, y_lens)
        loss = criterion(torch.transpose(y_pred, 1, 2), y_true)
        eval_losses.append(loss.item())
    tme = datetime.datetime.now().isoformat()[11:22]
    print("\n[{}] epoch {:>03d} eval loss: {:>8.3f}".format(tme, epoch+1, np.mean(eval_losses)/BATCH_SIZE))
    
    mytransformer.train()
    
    # save
    torch.save({
            'epoch': epoch+1,
            'global_step': global_step,
            'model_state_dict': mytransformer.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'windowed_losses': windowed_losses,
            'avg_loss': np.mean(windowed_losses)/BATCH_SIZE,
            'eval_loss': np.mean(eval_losses)/BATCH_SIZE,
            'training_config': mytransformer.config,
            'batch_size': BATCH_SIZE
            }, "data/translation/checkpoint.pt")
    
    print("\n[{}] checkpoint saved!".format(tme))
    

[15:06:18.48] epoch 000 global step 0100: loss:  855.400	avg:  885.843
[15:06:33.25] epoch 000 global step 0200: loss:  596.050	avg:  769.033
[15:06:48.11] epoch 000 global step 0300: loss:  589.208	avg:  600.988
[15:07:02.96] epoch 000 global step 0400: loss:  483.039	avg:  514.638
[15:07:17.80] epoch 000 global step 0500: loss:  399.909	avg:  447.875
[15:07:32.65] epoch 000 global step 0600: loss:  371.637	avg:  392.230
[15:07:47.53] epoch 000 global step 0700: loss:  283.004	avg:  353.835
[15:08:02.44] epoch 000 global step 0800: loss:  237.866	avg:  326.004
[15:08:17.38] epoch 000 global step 0900: loss:  256.727	avg:  301.190
[15:08:32.35] epoch 000 global step 1000: loss:  273.828	avg:  290.748
[15:08:47.27] epoch 000 global step 1100: loss:  203.040	avg:  283.346
[15:09:02.19] epoch 000 global step 1200: loss:  300.926	avg:  269.114
[15:09:17.12] epoch 000 global step 1300: loss:  274.900	avg:  263.016
[15:09:32.04] epoch 000 global step 1400: loss:  298.570	avg:  259.900
[15:09

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [00:10<00:00, 18.43it/s]



[15:10:36.72] epoch 000 eval loss:  233.199


[15:10:36.72] checkpoint saved!
[15:10:45.94] epoch 001 global step 1800: loss:  274.008	avg:  249.182
[15:11:00.92] epoch 001 global step 1900: loss:  256.068	avg:  243.836
[15:11:15.84] epoch 001 global step 2000: loss:  228.119	avg:  243.304
[15:11:45.85] epoch 001 global step 2200: loss:  203.430	avg:  236.859
[15:12:15.86] epoch 001 global step 2400: loss:  259.414	avg:  234.023
[15:12:45.94] epoch 001 global step 2600: loss:  209.408	avg:  230.394
[15:13:15.83] epoch 001 global step 2800: loss:  246.744	avg:  224.742
[15:13:45.79] epoch 001 global step 3000: loss:  172.266	avg:  219.727
[15:14:15.82] epoch 001 global step 3200: loss:  236.821	avg:  215.811
[15:14:45.76] epoch 001 global step 3400: loss:  209.244	avg:  213.388

[15:15:01.36] evaluating...



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [00:10<00:00, 18.52it/s]



[15:15:12.89] epoch 001 eval loss:  205.056


[15:15:12.89] checkpoint saved!
[15:15:29.79] epoch 002 global step 3600: loss:  216.116	avg:  212.932
[15:15:59.68] epoch 002 global step 3800: loss:  233.801	avg:  207.705
[15:16:29.61] epoch 002 global step 4000: loss:  223.026	avg:  204.183
[15:16:59.62] epoch 002 global step 4200: loss:  224.804	avg:  202.028
[15:17:29.45] epoch 002 global step 4400: loss:  197.826	avg:  201.144
[15:17:59.31] epoch 002 global step 4600: loss:  208.375	avg:  197.073
[15:18:29.23] epoch 002 global step 4800: loss:  171.988	avg:  200.314
[15:18:59.12] epoch 002 global step 5000: loss:  159.278	avg:  196.568
[15:19:28.95] epoch 002 global step 5200: loss:  185.857	avg:  194.304

[15:19:36.89] evaluating...



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [00:10<00:00, 18.57it/s]



[15:19:48.40] epoch 002 eval loss:  187.747


[15:19:48.40] checkpoint saved!
[15:20:12.87] epoch 003 global step 5400: loss:  201.055	avg:  192.379
[15:20:42.73] epoch 003 global step 5600: loss:  188.163	avg:  191.244
[15:21:12.58] epoch 003 global step 5800: loss:  178.876	avg:  191.146
[15:21:42.42] epoch 003 global step 6000: loss:  184.907	avg:  187.417
[15:22:12.27] epoch 003 global step 6200: loss:  203.344	avg:  188.157
[15:22:42.10] epoch 003 global step 6400: loss:  179.400	avg:  187.000
[15:23:11.94] epoch 003 global step 6600: loss:  161.490	avg:  184.733
[15:23:41.79] epoch 003 global step 6800: loss:  187.059	avg:  187.522

[15:24:11.97] evaluating...



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [00:10<00:00, 18.52it/s]



[15:24:23.50] epoch 003 eval loss:  178.593


[15:24:23.50] checkpoint saved!
[15:24:25.73] epoch 004 global step 7000: loss:  170.055	avg:  187.175
[15:24:55.53] epoch 004 global step 7200: loss:  155.328	avg:  182.858
[15:25:25.36] epoch 004 global step 7400: loss:  207.552	avg:  182.483
[15:25:55.18] epoch 004 global step 7600: loss:  185.946	avg:  181.897
[15:26:25.08] epoch 004 global step 7800: loss:  170.926	avg:  182.648
[15:26:54.91] epoch 004 global step 8000: loss:  185.362	avg:  179.754
[15:27:24.73] epoch 004 global step 8200: loss:  180.423	avg:  179.963
[15:27:54.58] epoch 004 global step 8400: loss:  181.306	avg:  179.876
[15:28:24.41] epoch 004 global step 8600: loss:  181.171	avg:  179.560

[15:28:46.96] evaluating...



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [00:10<00:00, 18.53it/s]



[15:28:58.49] epoch 004 eval loss:  173.059


[15:28:58.49] checkpoint saved!
[15:29:08.30] epoch 005 global step 8800: loss:  172.429	avg:  177.714
[15:29:38.11] epoch 005 global step 9000: loss:  170.585	avg:  176.464
[15:30:07.92] epoch 005 global step 9200: loss:  174.514	avg:  178.252
[15:30:37.84] epoch 005 global step 9400: loss:  149.142	avg:  174.932
[15:31:07.72] epoch 005 global step 9600: loss:  163.473	avg:  175.860
[15:31:37.56] epoch 005 global step 9800: loss:  155.457	avg:  175.831
[15:32:07.46] epoch 005 global step 10000: loss:  169.501	avg:  173.621
[15:32:37.36] epoch 005 global step 10200: loss:  187.753	avg:  174.350
[15:33:07.23] epoch 005 global step 10400: loss:  166.840	avg:  172.865

[15:33:22.20] evaluating...



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [00:10<00:00, 18.51it/s]



[15:33:33.74] epoch 005 eval loss:  168.505


[15:33:33.74] checkpoint saved!
[15:33:51.23] epoch 006 global step 10600: loss:  180.432	avg:  171.914
[15:34:21.16] epoch 006 global step 10800: loss:  169.803	avg:  173.744
[15:34:51.02] epoch 006 global step 11000: loss:  175.237	avg:  171.910
[15:35:20.83] epoch 006 global step 11200: loss:  143.320	avg:  168.460
[15:35:50.71] epoch 006 global step 11400: loss:  146.061	avg:  170.854
[15:36:20.55] epoch 006 global step 11600: loss:  190.681	avg:  168.510
[15:36:50.39] epoch 006 global step 11800: loss:  161.113	avg:  170.824
[15:37:20.20] epoch 006 global step 12000: loss:  167.819	avg:  170.536
[15:37:50.00] epoch 006 global step 12200: loss:  167.414	avg:  170.340

[15:37:57.34] evaluating...



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [00:10<00:00, 18.53it/s]



[15:38:08.87] epoch 006 eval loss:  165.626


[15:38:08.87] checkpoint saved!
[15:38:33.93] epoch 007 global step 12400: loss:  170.905	avg:  166.254
[15:39:03.98] epoch 007 global step 12600: loss:  159.244	avg:  169.695
[15:39:33.87] epoch 007 global step 12800: loss:  175.417	avg:  165.688
[15:40:03.73] epoch 007 global step 13000: loss:  208.833	avg:  167.492
[15:40:33.56] epoch 007 global step 13200: loss:  194.835	avg:  165.960
[15:41:03.41] epoch 007 global step 13400: loss:  156.088	avg:  167.134
[15:41:33.23] epoch 007 global step 13600: loss:  150.271	avg:  167.714
[15:42:03.08] epoch 007 global step 13800: loss:  186.349	avg:  168.065

[15:42:32.63] evaluating...



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [00:10<00:00, 18.55it/s]



[15:42:44.15] epoch 007 eval loss:  162.255


[15:42:44.15] checkpoint saved!
[15:42:46.87] epoch 008 global step 14000: loss:  194.108	avg:  162.667
[15:43:16.60] epoch 008 global step 14200: loss:  152.825	avg:  162.680
[15:43:46.38] epoch 008 global step 14400: loss:  180.556	avg:  164.472
[15:44:16.21] epoch 008 global step 14600: loss:  172.616	avg:  164.843
[15:44:46.06] epoch 008 global step 14800: loss:  179.300	avg:  162.964
[15:45:15.88] epoch 008 global step 15000: loss:  147.744	avg:  162.579
[15:45:45.73] epoch 008 global step 15200: loss:  166.598	avg:  165.124
[15:46:15.58] epoch 008 global step 15400: loss:  186.272	avg:  163.941
[15:46:45.42] epoch 008 global step 15600: loss:  138.625	avg:  164.168

[15:47:07.39] evaluating...



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [00:10<00:00, 18.46it/s]



[15:47:18.95] epoch 008 eval loss:  161.094


[15:47:18.95] checkpoint saved!
[15:47:29.34] epoch 009 global step 15800: loss:  158.638	avg:  161.972
[15:47:59.10] epoch 009 global step 16000: loss:  141.384	avg:  159.934
[15:48:28.87] epoch 009 global step 16200: loss:  193.410	avg:  162.098
[15:48:58.65] epoch 009 global step 16400: loss:  148.191	avg:  163.518
[15:49:28.53] epoch 009 global step 16600: loss:  185.487	avg:  161.614
[15:49:58.41] epoch 009 global step 16800: loss:  160.649	avg:  161.522
[15:50:28.27] epoch 009 global step 17000: loss:  162.731	avg:  161.180
[15:50:58.15] epoch 009 global step 17200: loss:  165.410	avg:  161.454
[15:51:28.04] epoch 009 global step 17400: loss:  161.806	avg:  163.586

[15:51:42.42] evaluating...



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 195/195 [00:10<00:00, 18.53it/s]



[15:51:53.95] epoch 009 eval loss:  160.184


[15:51:53.95] checkpoint saved!
