#### IWSLT German->English translation

This notebook shows a simple example of how to use the transformer provided by this repo for one-direction translation. 

We will use the IWSLT 2016 De-En dataset.

In [1]:
from torchtext import data, datasets
import spacy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

import sys
sys.path.append("..")
from model.transformers import BaseTransformer

from model.utils import device, Batch, BasicIterator
from model.opt import NoamOpt

import time

##### The below does some basic data preprocessing and filtering, in addition to setting special tokens.

In [2]:
# Load the small spaCy pipelines for German and English. Only their `.tokenizer`
# attribute is used below (by de_tokenizer / en_tokenizer).
de_data = spacy.load('de_core_news_sm')
en_data = spacy.load('en_core_web_sm')

def de_tokenizer(data):
    """Tokenize a German string with the spaCy tokenizer and return the token texts as a list."""
    return [token.text for token in de_data.tokenizer(data)]
def en_tokenizer(data):
    """Tokenize an English string with the spaCy tokenizer and return the token texts as a list."""
    return [token.text for token in en_data.tokenizer(data)]


# Special tokens shared by the German and English fields.
BOS = "<s>"
EOS = "</s>"
BLANK = "<blank>"

# Both fields use the same padding / begin / end tokens; only the tokenizer differs.
_special_tokens = dict(pad_token=BLANK, init_token=BOS, eos_token=EOS)
de = data.Field(tokenize=de_tokenizer, **_special_tokens)
en = data.Field(tokenize=en_tokenizer, **_special_tokens)

# Sentence pairs where either side has more than MAX_LEN tokens are dropped.
MAX_LEN = 128


def _within_max_len(example):
    """Return True when both the source and target token lists have at most MAX_LEN tokens."""
    attrs = vars(example)
    return len(attrs['src']) <= MAX_LEN and len(attrs['trg']) <= MAX_LEN


train, val, test = datasets.IWSLT.splits(
    exts=(".de", ".en"),
    fields=(de, en),
    filter_pred=_within_max_len,
)

# Minimum token frequency passed to build_vocab for both vocabularies.
MIN_FREQ = 4

de.build_vocab(train.src, min_freq=MIN_FREQ)
en.build_vocab(train.trg, min_freq=MIN_FREQ)

##### Torchtext required functions. batch_size_fn exists to make sure the batch size stays where it should be.

##### The BasicIterator class helps with dynamic batching, so batches are tightly grouped with minimal padding.

In [4]:
# Running maxima of the source / target lengths seen in the batch currently being built.
# They are initialised here (a bare module-level `global` statement declares nothing) so
# that batch_size_fn cannot hit a NameError if its first call arrives with count != 1.
max_src_in_batch = 0
max_tgt_in_batch = 0


def batch_size_fn(new, count, sofar):
    """Keep augmenting batch and calculate total number of tokens + padding.

    Args:
        new: the example being added; must expose `src` and `trg` sequences.
        count: number of examples in the batch including `new`; a value of 1
            marks the start of a fresh batch and resets the running maxima.
        sofar: effective size reported for the batch so far (unused here).

    Returns:
        `count` times the longest sequence length seen in the current batch,
        i.e. the number of token slots the batch occupies once padded.
    """
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    # +2 on the target side, presumably for the BOS/EOS tokens added at batching time.
    # NOTE(review): the source field also declares init/eos tokens but gets no +2 here;
    # confirm whether that asymmetry is intended.
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
    return count * max(max_src_in_batch, max_tgt_in_batch)

# Per-batch budget handed to BasicIterator; with batch_size_fn it is measured in
# padded tokens rather than in sentences.
BATCH_SIZE = 1100


def make_loader(dataset, train):
    """Build a BasicIterator over `dataset` with the notebook's shared batching settings.

    Args:
        dataset: one of the torchtext dataset splits (train / val / test).
        train: forwarded to BasicIterator's `train` flag.

    Returns:
        A BasicIterator that places batches on the same `device` the model is moved to,
        instead of a hard-coded CUDA device.
    """
    return BasicIterator(dataset, batch_size=BATCH_SIZE, device=device,
                         repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                         batch_size_fn=batch_size_fn, train=train)


train_loader = make_loader(train, train=True)
val_loader = make_loader(val, train=False)
test_loader = make_loader(test, train=False)

##### Single pass (one epoch) over the entire dataset, using gradient accumulation to reach an effective batch size large enough for stable training.

In [5]:
def train_step(dataloader, accumulation_steps=11):
    """Run one pass over `dataloader`, accumulating gradients across several batches.

    Uses the notebook-level `transformer`, `criterion` and `optimizer`.

    Args:
        dataloader: iterable of batches exposing `src` and `trg`.
        accumulation_steps: number of batches whose gradients are summed before each
            optimizer update. The default of 11 simulates a batch size of ~12k.

    Returns:
        Mean per-batch loss over the pass, or 0.0 if the dataloader yielded no batches.
    """
    total_loss = 0
    n_batches = 0
    # Start from clean gradients so nothing left over from earlier work is applied.
    optimizer.zero_grad()
    for batch in dataloader:
        loss, _ = transformer.forward_and_return_loss(criterion, batch.src, batch.trg)
        loss.backward()
        total_loss += loss.item()
        n_batches += 1
        # Step only AFTER the current batch's gradients have been accumulated, so every
        # update (including the first) sees exactly `accumulation_steps` batches.
        if n_batches % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    # Apply gradients from a trailing partial window instead of leaking them into the
    # next call.
    if n_batches % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    if n_batches == 0:
        return 0.0
    return total_loss / n_batches

#### Creating the translation model:

Subclassing the `BaseTransformer` class allows us to implement a `forward_and_return_loss` function and a `generate` function, and requires nothing else before being fully functional. 

The Transformer class handles embedding and the transformer layers itself, including an output Linear layer.

The goal of a basic translation model is to recreate the translation given the input (in a different language). We use cross-entropy between the model's predictions and the ground-truth target tokens.

We use the utils.Batch object to automatically create padding masks, in addition to dec-dec attn. masks.

In [None]:
class TranslationModel(BaseTransformer):
    """Sequence-to-sequence translation model built on BaseTransformer.

    Adds a teacher-forced loss computation and greedy generation. It relies on
    attributes used below and presumably provided by the base class: `encoder`,
    `decoder`, `fc1`, `pad_idx` and `sos_idx`.
    """

    def __init__(
        self, *args, **kwargs
    ):
        # All constructor arguments are forwarded unchanged to BaseTransformer.
        super(TranslationModel, self).__init__(*args, **kwargs)

    def forward_and_return_loss(self, criterion, sources, targets):
        """
        Pass input through transformer and return loss, handles masking automagically
        Args:
            criterion: torch.nn.functional loss function of choice; it is called with
                flattened (N, vocab) predictions, flattened (N,) targets and
                `ignore_index=self.pad_idx`
            sources: source sequences, [seq_len, bs]
            targets: full target sequence of token indices, [seq_len, bs]

        Returns:
            loss, transformer output
        """
        # Batch builds the padding masks and the decoder input / expected-output
        # pair (trg / trg_y) from the raw target sequence.
        batch = Batch(sources, targets, self.pad_idx)
        out = self.forward(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        loss = criterion(
            out.contiguous().view(-1, out.size(-1)),
            batch.trg_y.contiguous().view(-1),
            ignore_index=self.pad_idx,
        )

        return loss, out

    def generate(self, source, source_mask, max_len):
        """
        Greedily decode a translation for every sequence in the batch.

        Args:
            source: input sequence indices, [seq_len, bs]
            source_mask: the source mask to prevent attending to <pad> tokens
            max_len: maximum length of the returned sequences (including the initial
                start token); pass None to use only the source-length heuristic

        Returns:
            generated translations as token indices, [generated_len, bs]; decoding does
            not stop early at an end token, so callers should truncate at it themselves
        """
        # Inference only: skip building an autograd graph for every decoding step.
        with torch.no_grad():
            memory = self.encoder(source, source_mask)
            # Start every sequence with the start-of-sentence token, on the input's device.
            ys = torch.full(
                (1, source.size(1)), self.sos_idx, dtype=torch.long, device=source.device
            )
            # max target length is 1.5x * source + 10 to save compute power
            n_steps = int(1.5 * source.size(0)) - 1 + 10
            if max_len is not None:
                # Honour the caller's cap; ys already holds one token.
                n_steps = min(n_steps, max_len - 1)
            for _ in range(max(n_steps, 0)):
                out = self.decoder(
                    ys, memory, source_mask, Batch(ys, ys, self.pad_idx).raw_mask
                )
                # Only the last position is needed to predict the next token.
                out = self.fc1(out[-1].unsqueeze(0))
                # argmax over logits equals argmax over log-softmax, so no softmax is needed.
                next_word = torch.argmax(out, dim=-1)
                ys = torch.cat([ys, next_word], dim=0)

        return ys

##### These hyperparameters were set for a GTX980. A bigger GPU, such as a P100 or similar, will be able to handle default transformer hyperparameters and bigger batch sizes.

In [6]:
# Vocabulary sizes come from the torchtext fields built above.
input_vocab_size, output_vocab_size = len(de.vocab), len(en.vocab)

# Model hyperparameters.
embedding_dim = 256
n_layers = 4
hidden_dim = 512
n_heads = 8
dropout_rate = .1

transformer = TranslationModel(
    input_vocab_size, output_vocab_size, embedding_dim,
    n_layers, hidden_dim, n_heads, dropout_rate,
).to(device)

# Adam starts at lr=0 and is wrapped by NoamOpt, which presumably drives the learning
# rate (arguments look like model size, factor, warmup steps - confirm in model.opt).
adamopt = Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
optimizer = NoamOpt(embedding_dim, 1, 2000, adamopt)
criterion = F.cross_entropy

# optimization is unstable without this step
for weight in (w for w in transformer.parameters() if w.dim() > 1):
    nn.init.xavier_uniform_(weight)

##### Runs 10 epochs of the entire training dataset.

In [7]:
NUM_EPOCHS = 10

true_start = time.time()
for epoch in range(NUM_EPOCHS):
    # Make sure dropout etc. are in training mode before each pass.
    transformer.train()
    t = time.time()

    loss = train_step(train_loader)

    # Read the clock once so both reported durations are consistent.
    now = time.time()
    # Format the loss numerically; slicing str(loss) mangles values such as 1e-05 or 12.3456.
    print("Epoch {}. Loss: {:.3f}, ".format(epoch + 1, loss))
    print("Total time (s): {}, Last epoch time (s): {}".format(int(now - true_start), int(now - t)))

Epoch 1. Loss: 6.035, 
Total time: 439, Last epoch time (s): 439
Epoch 2. Loss: 4.087, 
Total time: 901, Last epoch time (s): 461
Epoch 3. Loss: 3.289, 
Total time: 1346, Last epoch time (s): 445
Epoch 4. Loss: 2.814, 
Total time: 1786, Last epoch time (s): 439
Epoch 5. Loss: 2.519, 
Total time: 2247, Last epoch time (s): 461
Epoch 6. Loss: 2.295, 
Total time: 2674, Last epoch time (s): 426
Epoch 7. Loss: 2.053, 
Total time: 3102, Last epoch time (s): 427
Epoch 8. Loss: 1.858, 
Total time: 3527, Last epoch time (s): 425
Epoch 9. Loss: 1.698, 
Total time: 3951, Last epoch time (s): 424
Epoch 10. Loss: 1.561, 
Total time: 4376, Last epoch time (s): 424


In [8]:
# Saves the whole module object (not just its state_dict), so loading this file later
# requires the TranslationModel class definition to be importable.
torch.save(transformer, "basic_translation.pt")

##### Finally, generations. 


The model by default uses greedy decoding for generation, and does not have incremental decoding. Currently, this leads to the transformer generating at about 1/2 the speed of Fairseq for short sequences. 

Implementing incremental decoding, however, makes the code for the attention function much harder to read, and has been left out for now. 

In [9]:
def render_tokens(vocab, column):
    """Turn one sequence of token indices into printable text.

    Skips the first position (the start token) and stops at the end-of-sentence token.
    Each token is followed by a single space.

    Args:
        vocab: torchtext vocabulary whose `itos` list maps indices to tokens.
        column: 1-D tensor of token indices for a single sentence.

    Returns:
        The decoded tokens as one string.
    """
    pieces = []
    for idx in column[1:].tolist():
        token = vocab.itos[idx]
        if token == EOS:
            break
        pieces.append(token + " ")
    return "".join(pieces)


transformer.eval()
new_batch = next(iter(val_loader))
inp = new_batch.src
tra = new_batch.trg

# Look the source padding index up from the vocabulary instead of hard-coding it.
src_pad_idx = de.vocab.stoi[BLANK]
out = transformer.generate(inp, Batch(inp, inp, src_pad_idx).src_mask, 120)

# Tensors are [seq_len, batch]: iterate over the batch dimension (dim 1), one sentence
# per column. len(inp) would give seq_len and index the wrong axis.
for i in range(inp.size(1)):
    print("Input sentence: " + render_tokens(de.vocab, inp[:, i]))
    print("Predicted translation: " + render_tokens(en.vocab, out[:, i]))
    print("Ground truth translation: " + render_tokens(en.vocab, tra[:, i]))
    print()

Input sentence: Und der Garten ist wunderschön . 
Predicted translation: And the garden is beautiful . 
Ground truth translation: And the garden , it was beautiful . 

Input sentence: Die <unk> ist nicht nachhaltig . 
Predicted translation: The <unk> system is n't sustainable . 
Ground truth translation: The internal combustion engine is not sustainable . 

Input sentence: Wir sehen immer dieselben Symptome . 
Predicted translation: We see the same symptoms . 
Ground truth translation: We see all the same symptoms . 

Input sentence: Sie ist keine <unk> . " 
Predicted translation: It 's not a <unk> . " 
Ground truth translation: She 's not North Korean . " 

Input sentence: Weil es so schön klingt . 
Predicted translation: Because it sounds beautiful . 
Ground truth translation: Just because it sounds so good . 

Input sentence: Aber man muss es pflegen . 
Predicted translation: But you have to care about it . 
Ground truth translation: But you have to maintain it . 

Input sentence: D