In [1]:
import datetime
import dill as pickle
import numpy as np
import torch
import time
import tqdm

In [2]:
from functools import partial
from mytransformers.data import SimpleTranslationDataset
from mytransformers.data import pad_to_seq_len
from mytransformers.models import TransformerModel

## training config

In [3]:
MAX_SEQ_LEN = 256
VOCAB_SIZE = 8000
BATCH_SIZE = 32
WARM_EPOCHS = 5
COOL_EPOCHS = 35
MAX_EPOCHS = 50
INIT_LR = 0.0001
MAX_LR = 0.001
GRAD_CLIP = 5.0

## make datasets and dataloaders

this is being trained for English > Korean translation on the `korean-english-news-v1` dataset from https://github.com/jungyeul/korean-parallel-corpora 

no preprocessing is done except to read in the data and write the (en, kr) pairs to tsv.


In [4]:
train_file = "data/translation_news/en-ko-train.tsv"
valid_file = "data/translation_news/en-ko-dev.tsv"
train_tokenizer = "data/translation_news/src_tokenizer.pkl"
valid_tokenizer = "data/translation_news/tgt_tokenizer.pkl"
checkpoint_file = "data/translation_news/checkpoint.pt"
sample_sentences = [
    "After keeping the world's most powerful supercomputer to themselves for a year, government researchers showed off the $110 million wonder and said it might help save the world from nuclear war.",
    "Most of the people involved in the discussion agree that there is a legitimate area in which the government needs to retain the right to intercept communications.",
    "Several Texan transmission companies announced Monday they were forming a consortium to invest in the $5 billion cost of building new power lines to take advantage of the state's vast wind power."
]

In [5]:
# # for training with the alternate 'bible' dataset.
# # here, the verse notes are removed, and every tenth item starting from the second (index = 1) is set to validation
# train_file = "data/translation_bible/train_pairs.tsv"
# valid_file = "data/translation_bible/valid_pairs.tsv"
# train_tokenizer = "data/translation_bible/src_tokenizer.pkl"
# valid_tokenizer = "data/translation_bible/tgt_tokenizer.pkl"
# checkpoint_file = "data/translation_bible/checkpoint.pt"
# sample_sentences = [
#     "The weapons we fight with are not the weapons of the world. On the contrary, they have divine power to demolish strongholds.",
#     "Make it your ambition to lead a quiet life, to mind your own business and to work with your hands, just as we told you,",
#     "It had a great, high wall with twelve gates, and with twelve angels at the gates. On the gates were written the names of the twelve tribes of Israel."
# ]

In [6]:
train_dataset = SimpleTranslationDataset(train_file, vocab_size=VOCAB_SIZE)

fitting tokenizer...
fitting source tokenizer...
fitting target tokenizer...


sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input_format: 
  model_prefix: 
  model_type: UNIGRAM
  vocab_size: 8000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  control_symbols: [CLS]
  control_symbols: [SEP]
  control_symbols: [MASK]
  control_symbols: [NEW1]
  control_symbols: [NEW2]
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 1
  bos_id: 2
  eos_id: 3
  pad_id: 0
  unk_piece: [UNK]
  bos_piece: [BOS]
  eos_piece: [EOS]
  pad_piece: [PAD]
  unk_surface:  ⁇ 

In [7]:
src_tokenizer, tgt_tokenizer = train_dataset.get_tokenizers()

In [8]:
pickle.dump(src_tokenizer, open(train_tokenizer, "wb"))
pickle.dump(tgt_tokenizer, open(valid_tokenizer, "wb"))

In [9]:
valid_dataset = SimpleTranslationDataset(valid_file, 
                                         src_tokenizer=src_tokenizer, 
                                         tgt_tokenizer=tgt_tokenizer)

In [10]:
print("train samples: {}".format(len(train_dataset)))
print("valid samples: {}".format(len(valid_dataset)))

train samples: 94123
valid samples: 1000


In [11]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                                               collate_fn=partial(pad_to_seq_len, max_seq_len=MAX_SEQ_LEN))
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, 
                                               collate_fn=partial(pad_to_seq_len, max_seq_len=MAX_SEQ_LEN))

In [12]:
data_example = next(iter(train_dataloader))

In [13]:
for t in data_example:
    print(t.shape)

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32])
torch.Size([32])


## create model, etc

the model configuration is loosely based on the *Attention is All You Need* base configuration, with the following changes:

- the token embedding space used is smaller than the transformer input dimension, like ALBERT
- a small amount of dropout is added to the first FFNN projection (`ffnn_dropout`)
- the GELU activation is used in the FFNN layer, like BERT and GPT
- the pre-layernorm configuration is used

In [14]:
mytransformer = TransformerModel(
     src_vocab_sz=VOCAB_SIZE,
     tgt_vocab_sz=VOCAB_SIZE,
     enc_layers=6,
     dec_layers=6,
     seq_len=MAX_SEQ_LEN,
     d_vocab=256,
     d_in=512, 
     d_attn=64, 
     d_ffnn=1024, 
     attn_heads=8, 
     dropout=0.1,
     attn_dropout=0.0, 
     ffnn_dropout=0.05,
     pos_encoding="sinusoidal",
     shared_vocab=False,
     attn_mask_val=-1e08, 
     ffnn_activation="gelu", 
     pre_ln=True
).cuda()

In [15]:
mytransformer.config

{'src_vocab_sz': 8000,
 'tgt_vocab_sz': 8000,
 'enc_layers': 6,
 'dec_layers': 6,
 'seq_len': 256,
 'd_vocab': 256,
 'd_in': 512,
 'd_attn': 64,
 'd_ffnn': 1024,
 'attn_heads': 8,
 'dropout': 0.1,
 'attn_dropout': 0.0,
 'ffnn_dropout': 0.05,
 'pos_encoding': 'sinusoidal',
 'shared_vocab': False,
 'attn_mask_val': -100000000.0,
 'attn_q_bias': False,
 'attn_kv_bias': False,
 'attn_out_bias': False,
 'ffnn_activation': 'gelu',
 'pre_ln': True}

### learning rate scheduling

we'll use the `OneCycleLR` to roughly approximate the warmup and annealing by 'warming up' for 1 epoch and then decaying for 49 epochs (until 50th epoch).

In [16]:
criterion = torch.nn.CrossEntropyLoss(ignore_index=0, reduction="sum")

optimizer = torch.optim.Adam(mytransformer.parameters(), lr=INIT_LR, betas=(0.9, 0.98), eps=1e-09, weight_decay=0.0001, amsgrad=False)

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, 
                                                max_lr=MAX_LR, 
                                                total_steps=len(train_dataloader)*(WARM_EPOCHS+COOL_EPOCHS), 
                                                pct_start=WARM_EPOCHS/(WARM_EPOCHS+COOL_EPOCHS))


## training loop

In [None]:
mytransformer.train()

global_step = 0

windowed_losses = []

for epoch in range(MAX_EPOCHS):

    print("starting epoch {} of {}".format(epoch+1, MAX_EPOCHS))
    time.sleep(1)
    
    with tqdm.tqdm(train_dataloader, desc="minibatch", total=len(train_dataloader)) as b:

        for batch in train_dataloader:

            x, y_in, y_true, x_lens, y_lens = batch
            x = x.to("cuda")
            y_in = y_in.to("cuda")
            y_true = y_true.to("cuda")
            x_lens = x_lens.to("cuda")
            y_lens = y_lens.to("cuda")

            y_pred = mytransformer(x, y_in, x_lens, y_lens)

            loss = criterion(torch.transpose(y_pred, 1, 2), y_true)
            loss /= torch.sum(y_lens)  # scale by all non-zero elements

            loss.backward() 
            torch.nn.utils.clip_grad_norm_(mytransformer.parameters(), GRAD_CLIP)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            windowed_losses.append(loss.item())
            windowed_losses = windowed_losses[-16:]

            global_step += 1
            
            b.set_postfix(loss="{:.3f}".format(np.mean(windowed_losses)), global_step=global_step)
            b.update(1)
            
    # end of epoch loss
    tme = datetime.datetime.now().isoformat()[11:22]
    print("[{}] epoch {:>03d} global step {:>04d}: loss: {:>8.3f}\tavg: {:>8.3f} (end of epoch)".format(
        tme, epoch+1, global_step, loss.item(), np.mean(windowed_losses)
    ))
        
    # evaluate
    eval_losses = []
    time.sleep(1)
    tme = datetime.datetime.now().isoformat()[11:22]
    print("\n[{}] evaluating...\n".format(tme))
    time.sleep(1)
    
    mytransformer.eval()
    
    with torch.no_grad():
        for batch in tqdm.tqdm(valid_dataloader):
            x, y_in, y_true, x_lens, y_lens = batch
            x = x.to("cuda")
            y_in = y_in.to("cuda")
            y_true = y_true.to("cuda")
            x_lens = x_lens.to("cuda")
            y_lens = y_lens.to("cuda")
            y_pred = mytransformer(x, y_in, x_lens, y_lens)
            loss = criterion(torch.transpose(y_pred, 1, 2), y_true)
            loss /= torch.sum(y_lens)  # scale by all non-zero elements
            eval_losses.append(loss.item())
            
    tme = datetime.datetime.now().isoformat()[11:22]
    print("\n[{}] epoch {:>03d} eval loss: {:>8.3f}".format(tme, epoch+1, np.mean(eval_losses)))
    
    # infer some results
    if (epoch + 1) % 5 == 0 or (epoch + 1) == MAX_EPOCHS:
        time.sleep(1)
        print("\n sample greedy outputs:\n")
        with torch.no_grad():
            for sample in sample_sentences:
                x, x_len = src_tokenizer.transform(sample, as_array=True, bos=True, eos=True, max_len=MAX_SEQ_LEN)
                x = torch.from_numpy(x).long().to("cuda")
                x_len = torch.from_numpy(x_len).long().to("cuda")
                y_hat = mytransformer.infer_one_greedy(x, x_len, bos=2, eos=3)
                y_hat = tgt_tokenizer.inverse_transform([y_hat], as_tokens=False)[0]
                print("\tsrc: {}".format(sample))
                print("\tprd: {}\n".format(y_hat))
    
    
    # save
    torch.save({
            'epoch': epoch+1,
            'global_step': global_step,
            'model_state_dict': mytransformer.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'windowed_losses': windowed_losses,
            'avg_loss': np.mean(windowed_losses),
            'eval_loss': np.mean(eval_losses),
            'training_config': mytransformer.config,
            'batch_size': BATCH_SIZE
            }, checkpoint_file)
    
    print("\n[{}] checkpoint saved!".format(tme))
    
    
    mytransformer.train()
    

starting epoch 1 of 50


minibatch: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2942/2942 [15:44<00:00,  3.12it/s, global_step=2942, loss=7.660]


[19:57:19.60] epoch 001 global step 2942: loss:    7.641	avg:    7.660 (end of epoch)

[19:57:20.60] evaluating...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.97it/s]



[19:57:25.18] epoch 001 eval loss:    7.643

[19:57:25.18] checkpoint saved!
starting epoch 2 of 50


minibatch: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2942/2942 [15:45<00:00,  3.11it/s, global_step=5884, loss=6.213]


[20:13:13.61] epoch 002 global step 5884: loss:    6.576	avg:    6.213 (end of epoch)

[20:13:14.61] evaluating...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.99it/s]



[20:13:19.18] epoch 002 eval loss:    6.322

[20:13:19.18] checkpoint saved!
starting epoch 3 of 50


minibatch: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2942/2942 [15:41<00:00,  3.13it/s, global_step=8826, loss=5.538]


[20:29:02.79] epoch 003 global step 8826: loss:    5.752	avg:    5.538 (end of epoch)

[20:29:03.80] evaluating...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.96it/s]



[20:29:08.37] epoch 003 eval loss:    5.521

[20:29:08.37] checkpoint saved!
starting epoch 4 of 50


minibatch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2942/2942 [15:40<00:00,  3.13it/s, global_step=11768, loss=5.189]


[20:44:51.55] epoch 004 global step 11768: loss:    5.173	avg:    5.189 (end of epoch)

[20:44:52.55] evaluating...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.98it/s]



[20:44:57.12] epoch 004 eval loss:    5.193

[20:44:57.12] checkpoint saved!
starting epoch 5 of 50


minibatch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2942/2942 [15:39<00:00,  3.13it/s, global_step=14710, loss=4.891]


[21:00:39.34] epoch 005 global step 14710: loss:    4.541	avg:    4.891 (end of epoch)

[21:00:40.34] evaluating...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.91it/s]



[21:00:44.94] epoch 005 eval loss:    4.923

 sample greedy outputs:

	src: After keeping the world's most powerful supercomputer to themselves for a year, government researchers showed off the $110 million wonder and said it might help save the world from nuclear war.
	prd: 세계 최대 규모의 가장 높은 및 세계 최대 및 세계 및 및 세계 및 및 및 및 세계 및 및 및 세계 및 및 및 및 및 및 및 세계 및 및 및 및 및 세계적인 활동이 있다고 주장했다.

	src: Most of the people involved in the discussion agree that there is a legitimate area in which the government needs to retain the right to intercept communications.
	prd: 대부분의 최대 규모의 최대 규모는 정부의 주장이 있다고 주장했다.

	src: Several Texan transmission companies announced Monday they were forming a consortium to invest in the $5 billion cost of building new power lines to take advantage of the state's vast wind power.
	prd: Condordordordordordordordordordordordordordordordordordordordordordordordordordord in the theeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee

minibatch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2942/2942 [15:40<00:00,  3.13it/s, global_step=17652, loss=4.644]


[21:16:30.92] epoch 006 global step 17652: loss:    4.496	avg:    4.644 (end of epoch)

[21:16:31.93] evaluating...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.95it/s]



[21:16:36.50] epoch 006 eval loss:    4.685

[21:16:36.50] checkpoint saved!
starting epoch 7 of 50


minibatch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2942/2942 [15:39<00:00,  3.13it/s, global_step=20594, loss=4.473]


[21:32:18.75] epoch 007 global step 20594: loss:    4.404	avg:    4.473 (end of epoch)

[21:32:19.75] evaluating...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.94it/s]



[21:32:24.34] epoch 007 eval loss:    4.534

[21:32:24.34] checkpoint saved!
starting epoch 8 of 50


minibatch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2942/2942 [15:40<00:00,  3.13it/s, global_step=23536, loss=4.362]


[21:48:07.79] epoch 008 global step 23536: loss:    4.471	avg:    4.362 (end of epoch)

[21:48:08.79] evaluating...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.94it/s]



[21:48:13.38] epoch 008 eval loss:    4.412

[21:48:13.38] checkpoint saved!
starting epoch 9 of 50


minibatch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2942/2942 [15:40<00:00,  3.13it/s, global_step=26478, loss=4.285]


[22:03:56.89] epoch 009 global step 26478: loss:    4.180	avg:    4.285 (end of epoch)

[22:03:57.90] evaluating...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.93it/s]



[22:04:02.48] epoch 009 eval loss:    4.326

[22:04:02.48] checkpoint saved!
starting epoch 10 of 50


minibatch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2942/2942 [15:40<00:00,  3.13it/s, global_step=29420, loss=4.251]


[22:19:45.81] epoch 010 global step 29420: loss:    4.394	avg:    4.251 (end of epoch)

[22:19:46.81] evaluating...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.95it/s]



[22:19:51.39] epoch 010 eval loss:    4.293

 sample greedy outputs:

	src: After keeping the world's most powerful supercomputer to themselves for a year, government researchers showed off the $110 million wonder and said it might help save the world from nuclear war.
	prd: 세계 최대의 세계인 세계 최대의 세계인 세계인 세계 에서 가장 많은 수가 있는 세계인 세계인 세계인의 한 지역은 세계가 세계를 지원할 수 있다고 밝혔다.

	src: Most of the people involved in the discussion agree that there is a legitimate area in which the government needs to retain the right to intercept communications.
	prd: 대부분의 사람들은 정부의 입장을 밝히지 않고 있는 정부를 구성하는 것이 합의를 위해 정부를 구성할 것이라고 주장했다.

	src: Several Texan transmission companies announced Monday they were forming a consortium to invest in the $5 billion cost of building new power lines to take advantage of the state's vast wind power.
	prd: 일부 기업들은 4일(현지시간) 투자은행의 투자은행을 투자하는 주택 투자은행으로 투자하는 주택 투자은행을 투자하는 데 도움이 될 것이라고 밝혔다.


[22:19:51.39] checkpoint saved!
starting epoch 11 of 50


minibatch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2942/2942 [15:40<00:00,  3.13it/s, global_step=32362, loss=4.246]


[22:35:36.03] epoch 011 global step 32362: loss:    4.391	avg:    4.246 (end of epoch)

[22:35:37.03] evaluating...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.93it/s]



[22:35:41.62] epoch 011 eval loss:    4.251

[22:35:41.62] checkpoint saved!
starting epoch 12 of 50


minibatch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2942/2942 [15:39<00:00,  3.13it/s, global_step=35304, loss=4.187]


[22:51:23.86] epoch 012 global step 35304: loss:    4.104	avg:    4.187 (end of epoch)

[22:51:24.86] evaluating...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.94it/s]



[22:51:29.45] epoch 012 eval loss:    4.215

[22:51:29.45] checkpoint saved!
starting epoch 13 of 50


minibatch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2942/2942 [15:39<00:00,  3.13it/s, global_step=38246, loss=4.073]


[23:07:11.27] epoch 013 global step 38246: loss:    4.144	avg:    4.073 (end of epoch)

[23:07:12.27] evaluating...



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.95it/s]



[23:07:16.85] epoch 013 eval loss:    4.205

[23:07:16.85] checkpoint saved!
starting epoch 14 of 50


minibatch:  54%|██████████████████████████████████████████████████████████████████████▌                                                           | 1598/2942 [08:29<07:10,  3.12it/s, global_step=39844, loss=4.068]