#### IWSLT German->English translation

This notebook shows a simple example of how to use the transformer provided by this repo for one-direction translation. 

We will use the IWSLT 2016 De-En dataset.

In [1]:
from torchtext import data, datasets
import spacy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

import sys
sys.path.append("..")
from model.transformers import BaseTransformer

from model.utils import Batch, BasicIterator
from model.opt import NoamOpt

import time

In [2]:
import torch.multiprocessing as mp


In [3]:
import torch.distributed as dist

##### The below does some basic data preprocessing and filtering, in addition to setting special tokens.

In [4]:
import pytorch_lightning as pl


In [5]:
de_data = spacy.load('de_core_news_sm')
en_data = spacy.load('en_core_web_sm')

def de_tokenizer(data):
    raw_data = [x.text for x in de_data.tokenizer(data)]
    return raw_data
def en_tokenizer(data):
    raw_data = [x.text for x in en_data.tokenizer(data)]
    return raw_data


BOS = "<s>"
EOS = "</s>"
BLANK = "<blank>"

de = data.Field(tokenize=de_tokenizer, pad_token=BLANK, init_token=BOS, eos_token=EOS)
en = data.Field(tokenize=en_tokenizer, pad_token=BLANK, init_token=BOS, eos_token=EOS)

MAX_LEN = 128

train, val, test = datasets.IWSLT.splits(
    exts=(".de", ".en"), fields=(de, en),
    filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN and len(vars(x)['trg']) <= MAX_LEN
)

MIN_FREQ = 4

de.build_vocab(train.src, min_freq=MIN_FREQ)
en.build_vocab(train.trg, min_freq=MIN_FREQ)



In [41]:
global max_src_in_batch, max_tgt_in_batch
def batch_size_fn(new, count, sofar):
    "Keep augmenting batch and calculate total number of tokens + padding."
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch,  len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch,  len(new.trg) + 2)
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)

train_loader = BasicIterator(train, batch_size=6000,
                   repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                   batch_size_fn=batch_size_fn, train=True)

train_loader = torch.utils.data.DataLoader(train, batch_size=32)

val_loader = BasicIterator(val, batch_size=6000,
                   repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                   batch_size_fn=batch_size_fn, train=False)
test_loader = BasicIterator(test, batch_size=6000,
                   repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                   batch_size_fn=batch_size_fn, train=False)

##### Single step over entire dataset, with tons of gradient accumulation to get batch sizes big enough for stable training.

In [7]:
# del transformer
# torch.cuda.empty_cache()

In [8]:
class TranslationModel(BaseTransformer):
    def __init__(
        self, *args,
    ):
        super(TranslationModel, self).__init__(*args)

    def forward_and_return_loss(self, criterion, sources, targets):
        """
        Pass input through transformer and return loss, handles masking automagically
        Args:
            criterion: torch.nn.functional loss function of choice
            sources: source sequences, [seq_len, bs]
            targets: full target sequence, [seq_len, bs, embedding_dim]

        Returns:
            loss, transformer output
        """

        batch = Batch(sources, targets, self.pad_idx)
        seq_len, batch_size = batch.trg.size()
        out = self.forward(batch.src, batch.trg, batch.src_mask.to(self.device), batch.trg_mask.to(self.device))
        loss = criterion(
            out.contiguous().view(-1, out.size(-1)),
            batch.trg_y.contiguous().view(-1),
            ignore_index=self.pad_idx,
        )

        return loss, out

    def generate(self, source, source_mask, max_len):
        """
        Args:
            source: input sequence indices, [seq_len, bs,
            source_mask: the source mask to prevent attending to <pad> tokens
            max_len: maximum length

        Returns:
            generated translations
        """
        memory = self.encoder(source, source_mask)
        ys = torch.ones(1, source.size(1)).long().fill_(self.sos_idx).to(device)
        # max target length is 1.5x * source + 10 to save compute power
        for _ in range(int(1.5 * source.size(0)) - 1 + 10):
            out = self.decoder(ys, memory, source_mask, Batch(ys, ys, 1).raw_mask)
            out = self.fc1(out[-1].unsqueeze(0))
            prob = F.log_softmax(out, dim=-1)
            next_word = torch.argmax(prob, dim=-1)
            ys = torch.cat([ys, next_word.detach()], dim=0)

        return ys

In [38]:
class TranslationTransformer(pl.LightningModule):
    def __init__(self, transformer):
        super().__init__()
        self.transformer = transformer
        
    def forward(self, x, y, src_mask, trg_mask):
        return self.transformer(x, y, src_mask, trg_mask)
    
    def training_step(self, batch, batch_idx):
        print(batch)
        
        batch_handler = Batch(batch.src, batch.trg, 1)

            
        out = transformer(batch_handler.src.transpose(0, 1).to(self.device), 
                          batch_handler.trg.transpose(0, 1).to(self.device), 
                          batch_handler.src_mask.to(self.device), 
                          batch_handler.trg_mask.to(self.device))
        
        loss = criterion(
            out.contiguous().view(-1, out.size(-1)),
            batch_handler.trg_y.to(self.device).contiguous().view(-1)
        )
        
        result = pl.TrainResult(loss)
        result.log("train_loss", loss, on_epoch=True)
        return result
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)


In [42]:
input_vocab_size = len(de.vocab)
output_vocab_size = len(en.vocab)

#input_vocab_size = 2000
#output_vocab_size = 2000
embedding_dim = 256
n_layers = 4
hidden_dim = 512
n_heads = 8
dropout_rate = .1
transformer = TranslationModel(input_vocab_size, output_vocab_size, embedding_dim, 
                               n_layers,hidden_dim, n_heads, dropout_rate)

# optimization is unstable without this step
for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

In [43]:
m = TranslationTransformer(transformer)
criterion = nn.CrossEntropyLoss(ignore_index = 1)

trainer = pl.Trainer(max_epochs=1, gpus=1)#, distributed_backend='dp')

trainer.fit(m, train_loader, )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name        | Type             | Params
-------------------------------------------------
0 | transformer | TranslationModel | 28 M  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'torchtext.data.example.Example'>

##### Runs 10 epochs of the entire training dataset.

In [None]:
true_start = time.time()
world_size = 0
transformer = nn.DataParallel(transformer, device_ids=[0, 1, 2, 3])
# torch.cuda.set_device(0)
transformer = transformer.cuda()
for i in range(10):
    transformer.train()
    t = time.time()
    
    loss = train_step(train_loader)
    
    print("Epoch {}. Loss: {}, ".format((i+1), str(loss)[:5], int(time.time() - t)))
    print("Total time (s): {}, Last epoch time (s): {}".format(int(time.time()- true_start), int(time.time() - t)))

In [None]:
torch.save(transformer, "basic_translation.pt")

##### Finally, generations. 


The model by default uses greedy decoding for generation, and does not have incremental decoding. Currently, this leads to the transformer generating at about 1/2 the speed of Fairseq for short sequences. 

Implementing incremental decoding, however, makes the code for the attention function much harder to read, and has been left out for now. 

In [None]:
transformer.eval()
new_batch = next(iter(val_loader))
inp = new_batch.src
tra = new_batch.trg

out = transformer.generate(inp, Batch(inp, inp, 1).src_mask, 120)
for i in range(len(inp)):
    print("Input sentence: ", end="")
    for j in range(1, inp.size(0)):
        char = de.vocab.itos[inp[j, i]]
        if char == "</s>": 
            break
        print(char, end =" ")
    print("\nPredicted translation: ", end="")
    for j in range(1, out.size(0)):
        char = en.vocab.itos[out[j, i]]
        if char == "</s>": 
            break
        print(char, end =" ")
    print("\nGround truth translation: ", end="")
    for j in range(1, tra.size(0)):
        char = en.vocab.itos[tra[j, i]]
        if char == "</s>": 
            break
        print(char, end =" ")    
    print("\n")