# Transformer for Translation Task

this notebook demonstrates training a transformer for English to Dutch translation from scratch.

the `SimpleTranslationDataset` is a basic in-memory dataset that uses on-the-fly `sentencepiece` tokenization.

the `TransformerModel` is our model class, which includes the transformer encoder and decoder layers, as well as the input embedding, positional encoding, and output token prediction layers.

details about each component, as well as the training loop, are explained in each section below.

In [1]:
import datetime
import dill as pickle
import numpy as np
import os
import torch
import time
import tqdm
import traceback

In [2]:
from functools import partial
from mytransformers.data import SimpleTranslationDataset
from mytransformers.data import pad_to_seq_len
from mytransformers.models import TransformerModel

## training config

these values are set roughly based on the original *Attention is All You Need* paper.  
vocabulary size is reduced from 37,000 to 16,000 because we are using separate encoder and decoder embedding spaces, and our dataset is smaller (2M pairs vs WMT 2014 en-de's 4.5M pairs).  
warmup steps are slightly increased from 4,000 to 5,000, and a gradient clipping max norm of 5 is configured (`GRAD_CLIP`), although the `clip_grad_norm_` call in the training loop below is commented out for this run.
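clipping by global norm rescales all gradients together whenever their combined L2 norm exceeds the limit. a minimal pure-Python sketch, assuming the same rule as `torch.nn.utils.clip_grad_norm_` (scale factor `max_norm / (total_norm + 1e-6)`, applied only when it is below 1); `clip_grad_norm` is a hypothetical helper, not part of `mytransformers`:

```python
import math

def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient vectors in place so their combined L2 norm is at most max_norm.

    One norm is computed over *all* parameters and every gradient is multiplied
    by the same factor, so the update direction is preserved.
    Returns the total norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm

grads = [[3.0, 4.0], [0.0, 12.0]]  # total norm = sqrt(9 + 16 + 144) = 13
before = clip_grad_norm(grads, max_norm=5.0)
after = math.sqrt(sum(g * g for grad in grads for g in grad))
print(before, round(after, 4))  # 13.0 5.0
```

gradients whose global norm is already below the limit are left untouched.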

In [3]:
MAX_SEQ_LEN = 256
VOCAB_SIZE  = 16000
BATCH_SIZE  = 16
WARM_STEPS  = 5000     # learning-rate increase (warmup) steps
COOL_STEPS  = 495000   # learning-rate decrease (annealing) steps
MAX_STEPS   = 1000000  # total training steps (the schedule ends after WARM_STEPS + COOL_STEPS)
EVAL_EVERY  = 20000    # run evaluation every n steps
EVAL_STEPS  = 500      # only run n steps of the eval set
INIT_LR     = 0.0001   # starting learning rate before warmup
MAX_LR      = 0.001    # maximum learning rate after warmup
GRAD_CLIP   = 5.0      # gradient norm clip value
LOSS_WIN    = 32       # use the last n losses for the rolling loss
RETRAIN_TOKENIZER = False

## make datasets and dataloaders

we are training on europarl english-dutch data, which has approx. 2M pairs: https://www.statmt.org/europarl/

the data is split 90/10 into train and validation sets, sending every tenth line (lines 1, 11, 21, ...) to validation:

```
$ awk 'NR % 10 != 1' europarl-v7.nl-en.en > train.en
$ awk 'NR % 10 == 1' europarl-v7.nl-en.en > valid.en
$ awk 'NR % 10 != 1' europarl-v7.nl-en.nl > train.nl
$ awk 'NR % 10 == 1' europarl-v7.nl-en.nl > valid.nl
```
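the same deterministic split in Python, as a sketch (the function name `split_parallel` is hypothetical): awk's `NR` is 1-based, so lines 1, 11, 21, ... go to validation, and applying one rule to both files keeps the sentence pairs aligned.

```python
def split_parallel(src_lines, tgt_lines):
    """Deterministic 90/10 split mirroring `awk 'NR % 10 == 1'` (validation) / `!= 1` (train).

    awk's NR is 1-based, so lines 1, 11, 21, ... go to validation. Applying the same rule
    to both sides keeps every English sentence paired with its Dutch translation.
    """
    if len(src_lines) != len(tgt_lines):
        raise ValueError("source and target must have the same number of lines")
    train, valid = [], []
    for nr, pair in enumerate(zip(src_lines, tgt_lines), start=1):
        (valid if nr % 10 == 1 else train).append(pair)
    return train, valid

src = ["en {}".format(i) for i in range(1, 21)]
tgt = ["nl {}".format(i) for i in range(1, 21)]
train, valid = split_parallel(src, tgt)
print(len(train), len(valid))  # 18 2
print(valid)                   # [('en 1', 'nl 1'), ('en 11', 'nl 11')]
```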

In [4]:
train_source_file = "data/europarl-en-nl/train.en"
train_target_file = "data/europarl-en-nl/train.nl"

valid_source_file = "data/europarl-en-nl/valid.en"
valid_target_file = "data/europarl-en-nl/valid.nl"

src_tokenizer_path = "data/europarl-en-nl/src_tokenizer_v{}.pkl".format(VOCAB_SIZE)
tgt_tokenizer_path = "data/europarl-en-nl/tgt_tokenizer_v{}.pkl".format(VOCAB_SIZE)
checkpoint_file = "data/translation_news/checkpoint_v{}.pt".format(VOCAB_SIZE)

# three sentences from the validation set
sample_sentences = [
    "The key goal of the structural funds is to strengthen social and economic cohesion between the regions within the European Union.",
    "There is, in fact, a risk of a military coup in the future.",
    "This means that there must be a comprehensive partnership between local authorities and national governments with regard to how these funds are to be spent."
]

In [5]:
%%time
if os.path.exists(src_tokenizer_path) and os.path.exists(tgt_tokenizer_path) and not RETRAIN_TOKENIZER:
    print("loading saved tokenizers...")
    src_tokenizer = pickle.load(open(src_tokenizer_path, "rb"))
    tgt_tokenizer = pickle.load(open(tgt_tokenizer_path, "rb"))
    train_dataset = SimpleTranslationDataset(
        source_file=train_source_file,
        target_file=train_target_file,
        src_tokenizer=src_tokenizer, 
        tgt_tokenizer=tgt_tokenizer
    )
else:
    print("training new tokenizers...")
    
    # first, create the training dataset with only input, output texts
    # this will train new source, target tokenizers (or single tokenizer if share_tokenizer is True)
    train_dataset = SimpleTranslationDataset(
        source_file=train_source_file,
        target_file=train_target_file,
        vocab_size=VOCAB_SIZE
    )
    
    # you can then get the tokenizers, and pickle them
    src_tokenizer, tgt_tokenizer = train_dataset.get_tokenizers()
    pickle.dump(src_tokenizer, open(src_tokenizer_path, "wb"))
    pickle.dump(tgt_tokenizer, open(tgt_tokenizer_path, "wb"))
    
    # you may also export the id : token mapping
    src_tokenizer.export_vocab(src_tokenizer_path.replace(".pkl", ".vocab.txt"))
    src_tokenizer.export_vocab(tgt_tokenizer_path.replace(".pkl", ".vocab.txt"))
    

loading saved tokenizers...
CPU times: user 1.42 s, sys: 1.28 s, total: 2.7 s
Wall time: 2.7 s


In [6]:
# we can then initialize the validation dataset with the pre-fit tokenizers
# this will skip the token-fitting and use the pre-fit tokenizers instead
valid_dataset = SimpleTranslationDataset(
    source_file=valid_source_file,
    target_file=valid_target_file,
    src_tokenizer=src_tokenizer,
    tgt_tokenizer=tgt_tokenizer
)

In [7]:
print("train samples: {}".format(len(train_dataset)))
print("valid samples: {}".format(len(valid_dataset)))

train samples: 1797997
valid samples: 199778


In [8]:
# the dataset class provides a visual tokenization check of the first, middle, and last examples (to verify source/target alignment)
train_dataset.preview()

index 0
	source input : I declare resumed the session of the European Parliament adjourned on Friday 17 December 1999, and I would like once again to wish you a happy new year in the hope that you enjoyed a pleasant festive period.
	target input : Ik verklaar de zitting van het Europees Parlement, die op vrijdag 17 december werd onderbroken, te zijn hervat. Ik wens u allen een gelukkig nieuwjaar en hoop dat u een goede vakantie heeft gehad.
	source tokens: ▁I ▁declare ▁resumed ▁the ▁session ▁of ▁the ▁European ▁Parliament ▁adjourned ▁on ▁Friday ▁17 ▁December ▁1999 , ▁and ▁I ▁would ▁like ▁once ▁again ▁to ▁wish ▁you ▁a ▁happy ▁new ▁year ▁in ▁the ▁hope ▁that ▁you ▁enjoyed ▁a ▁pleasant ▁f est ive ▁period .
	target tokens: ▁Ik ▁verklaar ▁de ▁zitting ▁van ▁het ▁Europees ▁Parlement , ▁die ▁op ▁vrijdag ▁17 ▁december ▁werd ▁onderbroken , ▁te ▁zijn ▁hervat . ▁Ik ▁wens ▁u ▁allen ▁een ▁gelukkig ▁nieuw jaar ▁en ▁hoop ▁dat ▁u ▁een ▁goede ▁vakantie ▁heeft ▁gehad .
-------------------------------------

In [9]:
# the dataloaders need a collate_fn to zero-pad the results
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                                               collate_fn=partial(pad_to_seq_len, max_seq_len=MAX_SEQ_LEN))
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, 
                                               collate_fn=partial(pad_to_seq_len, max_seq_len=MAX_SEQ_LEN))
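`pad_to_seq_len` comes from `mytransformers.data` and its exact behavior isn't shown here. a hypothetical stand-in (`pad_batch`), assuming pad id 0 (matching the criterion's `ignore_index=0`), BOS/EOS ids 2 and 3 (as passed to `infer_one_greedy`), and teacher-forcing targets built by shifting the decoder input by one token. it returns plain lists instead of tensors, but produces the same five fields each batch carries (`x, y_in, y_true, x_lens, y_lens`):

```python
PAD, BOS, EOS = 0, 2, 3  # assumed special-token ids

def pad_batch(pairs, max_seq_len):
    """Hypothetical collate: zero-pad a list of (src_ids, tgt_ids) pairs to max_seq_len.

    y_in = [BOS] + tgt is the decoder input, y_true = tgt + [EOS] is the target
    shifted by one position; sequences longer than max_seq_len are truncated.
    """
    def pad(ids):
        ids = ids[:max_seq_len]
        return ids + [PAD] * (max_seq_len - len(ids))

    x, y_in, y_true, x_lens, y_lens = [], [], [], [], []
    for src, tgt in pairs:
        dec_in = [BOS] + tgt
        dec_out = tgt + [EOS]
        x.append(pad(src))
        y_in.append(pad(dec_in))
        y_true.append(pad(dec_out))
        x_lens.append(min(len(src), max_seq_len))
        y_lens.append(min(len(dec_out), max_seq_len))  # = number of non-pad target tokens
    return x, y_in, y_true, x_lens, y_lens

batch = pad_batch([([5, 6, 7], [8, 9]), ([5], [8, 9, 10, 11])], max_seq_len=6)
for field in batch:
    print(field)
```

a real collate function would wrap each list in `torch.tensor(...)` so the batch can be moved to the GPU.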

In [10]:
print(MAX_STEPS)
print(len(train_dataloader))
print(len(valid_dataloader))
print(MAX_STEPS/len(train_dataloader))

1000000
112375
12487
8.898776418242491
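the printed lengths follow from ceiling division, since `DataLoader` keeps the final partial batch by default (`drop_last=False`). a quick arithmetic check using the dataset sizes printed earlier:

```python
import math

train_samples, valid_samples = 1_797_997, 199_778          # dataset lengths printed above
batch_size, max_steps, eval_every = 16, 1_000_000, 20_000  # BATCH_SIZE, MAX_STEPS, EVAL_EVERY

# the last, smaller batch is kept, hence ceiling rather than floor division
train_batches = math.ceil(train_samples / batch_size)
valid_batches = math.ceil(valid_samples / batch_size)

print(train_batches, valid_batches)          # 112375 12487
print(round(max_steps / train_batches, 3))   # 8.899 epochs in total
print(round(eval_every / train_batches, 3))  # 0.178 epochs between evaluations
```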


In [11]:
data_example = next(iter(train_dataloader))

In [12]:
for t in data_example:
    print(t.shape)

torch.Size([16, 256])
torch.Size([16, 256])
torch.Size([16, 256])
torch.Size([16])
torch.Size([16])


## create model

the model configuration is loosely based on the *Attention is All You Need* base configuration, with the following changes:

- the token embedding space (`d_vocab=128`) is smaller than the transformer input dimension (`d_model=512`), like ALBERT
- the original Transformer paper shares weights between the embedding layers and the pre-softmax projection (section 3.4), but following other implementations, we disable weight tying (`weight_tying=False`)
- the multi-head attention q, k, v dimension (`d_attn`) is not tied to `d_model / attn_heads`, following other implementations
- a small amount of dropout is added to the query, key and value attention inputs (`attn_dropout`) and the first FFNN projection (`ffnn_dropout`)
- the GELU activation is used in the FFNN layer, like BERT and GPT (it supports "relu", "selu" or "gelu")
- the pre-layer norm ("pre-LN Transformer") configuration is used
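the saving from the ALBERT-style factorized embedding is easy to estimate. a back-of-the-envelope count, assuming the `d_vocab`-sized embeddings are projected up to `d_model` by a single bias-free linear layer (an assumption about `TransformerModel`'s internals):

```python
vocab_size, d_vocab, d_model = 16_000, 128, 512  # VOCAB_SIZE, d_vocab, d_model

# a single vocab_size x d_model lookup table
full = vocab_size * d_model
# assumed factorization: vocab_size x d_vocab table, then a bias-free d_vocab -> d_model projection
factorized = vocab_size * d_vocab + d_vocab * d_model

print(full, factorized)             # 8192000 2113536
print(round(full / factorized, 2))  # 3.88
```

with `shared_vocab=False` the encoder and decoder each have their own embedding table, so this saving applies twice.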

In [13]:
mytransformer = TransformerModel(
    src_vocab_sz=VOCAB_SIZE,
    tgt_vocab_sz=VOCAB_SIZE,
    enc_layers=6,
    dec_layers=6,
    seq_len=MAX_SEQ_LEN,
    d_vocab=128,
    d_model=512,
    d_attn=128,
    d_ffnn=2048,
    attn_heads=8,
    dropout=0.1,
    attn_dropout=0.05,
    ffnn_dropout=0.05,
    pos_encoding="sinusoidal",
    shared_vocab=False,
    weight_tying=False,
    attn_mask_val=-1e08,
    ffnn_activation="gelu",
    pre_ln=True
).cuda()

# this initializes parameters with xavier uniform
mytransformer.initialize()
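with `pre_ln=True`, layer norm is applied to each sub-layer's input rather than after the residual addition. a toy pure-Python sketch of the two orderings, using a layer norm without learned gain/bias and a doubling function as a stand-in sub-layer (both are simplifications, not how `TransformerModel` is implemented):

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned gain or bias, for brevity)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def post_ln_block(x, sublayer):
    # original ("post-LN") Transformer: LayerNorm(x + Sublayer(x))
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])

def pre_ln_block(x, sublayer):
    # pre-LN: x + Sublayer(LayerNorm(x)); the residual stream itself is left unnormalized
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]

def double(v):
    # stand-in for an attention or FFNN sub-layer
    return [2.0 * t for t in v]

x = [1.0, 2.0, 3.0]
post = post_ln_block(x, double)
pre = pre_ln_block(x, double)
print([round(v, 3) for v in post])  # [-1.225, 0.0, 1.225]
print([round(v, 3) for v in pre])   # [-1.449, 2.0, 5.449]
```

in the pre-LN variant the input passes through the residual connection unchanged, which is the property usually credited with more stable training without a long warmup.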


In [14]:
mytransformer.config

{'src_vocab_sz': 16000,
 'tgt_vocab_sz': 16000,
 'enc_layers': 6,
 'dec_layers': 6,
 'seq_len': 256,
 'd_vocab': 128,
 'd_model': 512,
 'd_attn': 128,
 'd_ffnn': 2048,
 'attn_heads': 8,
 'dropout': 0.1,
 'attn_dropout': 0.05,
 'ffnn_dropout': 0.05,
 'pos_encoding': 'sinusoidal',
 'shared_vocab': False,
 'weight_tying': False,
 'attn_mask_val': -100000000.0,
 'attn_q_bias': False,
 'attn_kv_bias': False,
 'attn_out_bias': False,
 'ffnn_activation': 'gelu',
 'pre_ln': True}

### learning rate scheduling

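out of laziness, we'll use PyTorch's built-in `OneCycleLR` scheduler to roughly approximate the warmup and annealing from *Attention is All You Need*. the paper warms up linearly and then decays the learning rate with the inverse square root of the step number, whereas `OneCycleLR` (with its default `anneal_strategy="cos"`) follows cosine curves for both the increase and the decrease.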
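for comparison, the paper's schedule (Section 5.3) can be written down directly. a small sketch of that formula, with the paper's base settings `d_model=512` and `warmup_steps=4000` as defaults (`noam_lr` is just an illustrative name):

```python
def noam_lr(step, d_model=512, warmup_steps=4000):
    """Learning rate schedule from "Attention is All You Need" (Section 5.3):

        lrate = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

    i.e. a linear increase for the first warmup_steps steps, then decay proportional to 1 / sqrt(step).
    """
    step = max(step, 1)  # the formula is undefined at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for s in (1_000, 4_000, 16_000, 100_000):
    print(s, "{:.2e}".format(noam_lr(s)))
```

with these defaults the peak is `(512 * 4000) ** -0.5`, roughly 7e-4, at step 4000; the `OneCycleLR` setup in this notebook instead peaks at `MAX_LR` (1e-3) after `WARM_STEPS` steps.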

In [15]:
criterion = torch.nn.CrossEntropyLoss(ignore_index=0, reduction="sum")

optimizer = torch.optim.Adam(mytransformer.parameters(), lr=INIT_LR, betas=(0.9, 0.98), eps=1e-09, weight_decay=0.0001, amsgrad=False)

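# OneCycleLR overrides the optimizer's lr and starts from max_lr / div_factor (default div_factor=25),
# so div_factor is set to make the starting learning rate equal INIT_LR
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                max_lr=MAX_LR,
                                                total_steps=WARM_STEPS+COOL_STEPS,
                                                pct_start=WARM_STEPS/(WARM_STEPS+COOL_STEPS),
                                                div_factor=MAX_LR/INIT_LR)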
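`CrossEntropyLoss(ignore_index=0, reduction="sum")` adds up the per-token losses while skipping padding, and the training loop then divides by `torch.sum(y_lens)`. a pure-Python sketch of that per-token average (hypothetical helper `masked_token_ce`), assuming pad id 0; the division only matches a true per-token mean if `y_lens` counts exactly the non-pad positions of `y_true`:

```python
import math

def masked_token_ce(logits, targets, pad_id=0):
    """Cross-entropy summed over non-pad positions, divided by the number of non-pad tokens.

    logits: one list of unnormalized scores (length = vocab size) per position
    targets: one token id per position; positions equal to pad_id are skipped,
             like CrossEntropyLoss(ignore_index=0, reduction="sum")
    """
    total, n_tokens = 0.0, 0
    for scores, t in zip(logits, targets):
        if t == pad_id:
            continue  # padding adds neither loss nor count
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))  # stable log-sum-exp
        total += log_z - scores[t]  # -log softmax(scores)[t]
        n_tokens += 1
    return total / n_tokens  # raises ZeroDivisionError if every position is padding

# vocab of 3 ids where id 0 is padding; the last position is padded
logits = [[0.1, 2.0, 0.5], [0.1, 0.2, 3.0], [0.0, 0.0, 0.0]]
targets = [1, 2, 0]
print(masked_token_ce(logits, targets))
```

summing first and dividing by the token count weights every real token equally, regardless of how sentence lengths are distributed within the batch.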


## training loop

for easier `tqdm` support for an arbitrary number of steps between evaluations, we eschew the usual "for epoch in epochs, for batch in dataset" nesting and instead use a while loop with a try-except that ticks the epoch counter up each time the data iterator is exhausted. it's slightly convoluted, but it lets us show progress for an arbitrary step count between evaluations (due to the dataset size, evaluating after every epoch would mean waiting 112,375 steps at a minibatch size of 16).
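the same pattern can also be written as a generator that hands out `(epoch, batch)` pairs and starts a fresh pass over the data when it runs out. a sketch with a hypothetical `batches_forever` helper and a list standing in for the `DataLoader`; combined with `itertools.islice`, it yields exactly as many steps as requested:

```python
from itertools import islice

def batches_forever(make_iter):
    """Yield (epoch, batch) pairs indefinitely, starting a fresh pass over the data each epoch.

    make_iter: zero-argument callable returning a new iterator, e.g. lambda: iter(train_dataloader)
    Only normal exhaustion advances the epoch; any exception raised while loading a batch propagates.
    """
    epoch = 0
    while True:
        empty = True
        for batch in make_iter():
            empty = False
            yield epoch, batch
        if empty:
            raise ValueError("the data iterator produced no batches")
        epoch += 1

data = [[1, 2], [3, 4], [5, 6]]  # stand-in for a DataLoader with 3 batches
steps = list(islice(batches_forever(lambda: iter(data)), 7))
print([epoch for epoch, _ in steps])  # [0, 0, 0, 1, 1, 1, 2]
```

in the notebook this could be used as `for epoch, batch in islice(batches_forever(lambda: iter(train_dataloader)), MAX_STEPS):`. the design choice is that a data or collation error surfaces immediately instead of being counted as a finished epoch.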

In [None]:
mytransformer.train()

train_iterator = iter(train_dataloader)

global_step = 0

epoch = 0

windowed_losses = []

pbar = tqdm.tqdm(total=EVAL_EVERY)

while global_step < MAX_STEPS:

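    try:
        batch = next(train_iterator)
    except StopIteration:
        # only an exhausted iterator marks the end of an epoch; a bare `except` would also
        # swallow data/collation errors, restart the iterator, and inflate the epoch count
        epoch += 1
        train_iterator = iter(train_dataloader)
        batch = next(train_iterator)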

    x, y_in, y_true, x_lens, y_lens = batch
    x = x.to("cuda")
    y_in = y_in.to("cuda")
    y_true = y_true.to("cuda")
    x_lens = x_lens.to("cuda")
    y_lens = y_lens.to("cuda")

    _, y_pred = mytransformer(x, y_in, x_lens, y_lens)

    loss = criterion(y_pred.transpose(1, 2), y_true)
    loss /= torch.sum(y_lens)  # scale by all non-zero elements

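    loss.backward()
    # gradient clipping is disabled for this run; uncomment to clip the global grad norm to GRAD_CLIP
    # torch.nn.utils.clip_grad_norm_(mytransformer.parameters(), GRAD_CLIP)
    optimizer.step()
    # don't step the scheduler past its total_steps: OneCycleLR raises a ValueError
    # once it is stepped more than WARM_STEPS + COOL_STEPS times
    if global_step < WARM_STEPS + COOL_STEPS:
        scheduler.step()
    optimizer.zero_grad()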

    windowed_losses.append(loss.item())
    windowed_losses = windowed_losses[-LOSS_WIN:]

    global_step += 1

    pbar.set_postfix(loss="{:.3f}".format(np.mean(windowed_losses)), global_step=global_step)
    pbar.update(1)
    
    if global_step % EVAL_EVERY == 0:
        
        pbar.close()
        time.sleep(1)
            
        # training loss at this evaluation point
        tme = datetime.datetime.now().isoformat()[11:22]
        print("[{}] epoch {:>03d} global step {:>04d}: loss: {:>8.3f}\tavg: {:>8.3f}".format(
            tme, epoch+1, global_step, loss.item(), np.mean(windowed_losses)
        ))

        # evaluate
        eval_losses = []
        time.sleep(1)
        tme = datetime.datetime.now().isoformat()[11:22]
        print("\n[{}] evaluating...\n".format(tme))
        time.sleep(1)

        mytransformer.eval()

        with torch.no_grad():
            for idx, batch in tqdm.tqdm(enumerate(valid_dataloader)):
                if idx >= EVAL_STEPS:  # stop after exactly EVAL_STEPS batches
                    break
                x, y_in, y_true, x_lens, y_lens = batch
                x = x.to("cuda")
                y_in = y_in.to("cuda")
                y_true = y_true.to("cuda")
                x_lens = x_lens.to("cuda")
                y_lens = y_lens.to("cuda")
                _, y_pred = mytransformer(x, y_in, x_lens, y_lens)
                loss = criterion(y_pred.transpose(1, 2), y_true)
                loss /= torch.sum(y_lens)  # scale by all non-zero elements
                eval_losses.append(loss.item())
            time.sleep(1)

        tme = datetime.datetime.now().isoformat()[11:22]
        print("\n[{}] epoch {:>03d} eval loss: {:>8.3f}".format(tme, epoch+1, np.mean(eval_losses)))

        # infer some results
        time.sleep(1)
        print("\n sample greedy outputs:\n")
        with torch.no_grad():
            for sample in sample_sentences:
                x, x_len = src_tokenizer.transform(sample, as_array=True, bos=True, eos=True, max_len=MAX_SEQ_LEN)
                x = torch.from_numpy(x).long().to("cuda")
                x_len = torch.from_numpy(x_len).long().to("cuda")
                y_hat = mytransformer.infer_one_greedy(x, x_len, bos=2, eos=3)
                y_hat = tgt_tokenizer.inverse_transform([y_hat], as_tokens=False)[0]
                print("\tsrc: {}".format(sample))
                print("\tprd: {}\n".format(y_hat))

        # save
        torch.save({
                'epoch': epoch+1,
                'global_step': global_step,
                'model_state_dict': mytransformer.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'windowed_losses': windowed_losses,
                'avg_loss': np.mean(windowed_losses),
                'eval_loss': np.mean(eval_losses),
                'training_config': mytransformer.config,
                'batch_size': BATCH_SIZE
                }, checkpoint_file)

        print("\n[{}] checkpoint saved!".format(tme))

        mytransformer.train()
        
        pbar = tqdm.tqdm(total=EVAL_EVERY)


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20000/20000 [1:29:06<00:00,  3.74it/s, global_step=2e+4, loss=4.296]


[01:57:36.73] epoch 024 global step 20000: loss:    4.382	avg:    4.296

[01:57:37.74] evaluating...



500it [00:47, 10.60it/s]



[01:58:26.91] epoch 024 eval loss:    4.068

 sample greedy outputs:

	src: The key goal of the structural funds is to strengthen social and economic cohesion between the regions within the European Union.
	prd: Het belangrijkste doel is de structuurfondsen te versterken om de economische cohesie en economische cohesie tussen de Europese Unie te versterken.

	src: There is, in fact, a risk of a military coup in the future.
	prd: Er is een risico voor een toekomst van een militaire toekomst.

	src: This means that there must be a comprehensive partnership between local authorities and national governments with regard to how these funds are to be spent.
	prd: Dit betekent dat er een akkoord tussen een interinstitutionele en regionale autoriteiten moeten worden ontwikkeld om de lidstaten te maken.



[01:58:26.91] checkpoint saved!


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20000/20000 [1:29:04<00:00,  3.74it/s, global_step=4e+4, loss=4.223]


[03:27:36.89] epoch 055 global step 40000: loss:    4.657	avg:    4.223

[03:27:37.89] evaluating...



500it [00:47, 10.60it/s]



[03:28:27.05] epoch 055 eval loss:    4.012

 sample greedy outputs:

	src: The key goal of the structural funds is to strengthen social and economic cohesion between the regions within the European Union.
	prd: Het doel van destructuurfondsen is de economische en sociale samenhang tussen de Europese Unie en de Europese Unie.

	src: There is, in fact, a risk of a military coup in the future.
	prd: Er is een risico, een gevaar voor een toekomst van de toekomst.

	src: This means that there must be a comprehensive partnership between local authorities and national governments with regard to how these funds are to be spent.
	prd: Dat betekent dat er een goed partnerschap tussen regeringen en nationale regeringen moeten worden besteed aan de nationale regeringen van deze middelen.



[03:28:27.05] checkpoint saved!


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20000/20000 [1:29:10<00:00,  3.74it/s, global_step=6e+4, loss=4.077]


[04:57:43.00] epoch 076 global step 60000: loss:    4.556	avg:    4.077

[04:57:44.00] evaluating...



500it [00:47, 10.61it/s]



[04:58:33.14] epoch 076 eval loss:    3.995

 sample greedy outputs:

	src: The key goal of the structural funds is to strengthen social and economic cohesion between the regions within the European Union.
	prd: Het doel van de structuurfondsen is de versterking van de versterking van de economische en cohesie tussen de Europese Unie.

	src: There is, in fact, a risk of a military coup in the future.
	prd: Er is een risico, een risico voor een risico in de toekomst.

	src: This means that there must be a comprehensive partnership between local authorities and national governments with regard to how these funds are to be spent.
	prd: Dit betekent dat er een evenwichtige samenwerking tussen de nationale overheden en de nationale overheden worden uitgevoerd.



[04:58:33.14] checkpoint saved!


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20000/20000 [1:29:10<00:00,  3.74it/s, global_step=8e+4, loss=4.124]


[06:27:49.10] epoch 096 global step 80000: loss:    3.985	avg:    4.124

[06:27:50.10] evaluating...



500it [00:47, 10.59it/s]



[06:28:39.34] epoch 096 eval loss:    3.959

 sample greedy outputs:

	src: The key goal of the structural funds is to strengthen social and economic cohesion between the regions within the European Union.
	prd: Het belangrijkste doel van de structurele structuurfondsen is een belangrijk doel tussen de sociale en sociale cohesie tussen de Europese Unie en de Europese Unie.

	src: There is, in fact, a risk of a military coup in the future.
	prd: Er is echter een risicobeoordeling van een militaire risicobeoordeling.

	src: This means that there must be a comprehensive partnership between local authorities and national governments with regard to how these funds are to be spent.
	prd: Het betekent dat er een open en ander partnerschap tussen de nationale autoriteiten en de nationale regeringen moet worden aangepakt.



[06:28:39.34] checkpoint saved!


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20000/20000 [1:29:11<00:00,  3.74it/s, global_step=1e+5, loss=4.071]


[07:57:56.58] epoch 117 global step 100000: loss:    4.202	avg:    4.071

[07:57:57.59] evaluating...



500it [00:47, 10.61it/s]



[07:58:46.72] epoch 117 eval loss:    3.948

 sample greedy outputs:

	src: The key goal of the structural funds is to strengthen social and economic cohesion between the regions within the European Union.
	prd: Het is van cruciaal belang dat de structuurfondsen middelen en de economische samenhang tussen de economische en sociale samenhang tussen de Europese Unie.

	src: There is, in fact, a risk of a military coup in the future.
	prd: Er is een risico dat een risico is een risico op militaire manier in de toekomst.

	src: This means that there must be a comprehensive partnership between local authorities and national governments with regard to how these funds are to be spent.
	prd: Dat betekent dat er een partnerschap tussen de plaatselijke overheden en de nationale overheden moeten worden besteed aan de financiële middelen.



[07:58:46.72] checkpoint saved!


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20000/20000 [1:29:16<00:00,  3.73it/s, global_step=120000, loss=3.992]


[09:28:09.03] epoch 137 global step 120000: loss:    3.948	avg:    3.992

[09:28:10.03] evaluating...



500it [00:47, 10.58it/s]



[09:28:59.28] epoch 137 eval loss:    3.906

 sample greedy outputs:

	src: The key goal of the structural funds is to strengthen social and economic cohesion between the regions within the European Union.
	prd: Het belangrijkste doel van de structuurfondsen is een betere samenhang tussen de sociale samenhang en de sociale samenhang binnen de Europese Unie.

	src: There is, in fact, a risk of a military coup in the future.
	prd: Er is een risico op een militaire militaire militaire militaire macht.

	src: This means that there must be a comprehensive partnership between local authorities and national governments with regard to how these funds are to be spent.
	prd: Dit betekent dat er een associatie tussen de nationale autoriteiten en nationale autoriteiten worden betrokken.



[09:28:59.28] checkpoint saved!


 21%|████████████████████████████▎                                                                                                            | 4135/20000 [18:27<1:11:06,  3.72it/s, global_step=124135, loss=3.993]

In [None]:
# with masks: stuck at ~4.6