In [1]:
from torchtext import data, datasets
import spacy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

from model.transformers import MLT

from model.utils import device, Batch
from model.opt import NoamOpt

import time

In [2]:
# spaCy pipelines for German and English. The tokenizer functions defined right
# after this only use their `.tokenizer` attribute.
de_data = spacy.load('de_core_news_sm')
en_data = spacy.load('en_core_web_sm')

def de_tokenizer(data):
    """Tokenize a German string with `de_data.tokenizer`; return the token texts as a list."""
    return [token.text for token in de_data.tokenizer(data)]
def en_tokenizer(data):
    """Tokenize an English string with `en_data.tokenizer`; return the token texts as a list."""
    return [token.text for token in en_data.tokenizer(data)]


# Special tokens shared by both languages.
BOS = "<s>"
EOS = "</s>"
BLANK = "<blank>"

# One Field per language: tokenized with the matching spaCy tokenizer, with BOS
# as init token, EOS as end token and BLANK as padding token.
de = data.Field(tokenize=de_tokenizer, pad_token=BLANK, init_token=BOS, eos_token=EOS)
en = data.Field(tokenize=en_tokenizer, pad_token=BLANK, init_token=BOS, eos_token=EOS)

# Upper bound on the length of either side of a sentence pair; longer pairs are
# filtered out below (length presumably counted in tokens — TODO confirm).
MAX_LEN = 128

# German (".de") is bound to the `de` field and English (".en") to the `en`
# field; the examples expose them as `src` and `trg` respectively.
train, val, test = datasets.IWSLT.splits(
    exts=(".de", ".en"), fields=(de, en),
    filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN and len(vars(x)['trg']) <= MAX_LEN
)

# Minimum frequency passed to build_vocab; rarer words presumably fall back to
# the unknown token — TODO confirm against torchtext.
MIN_FREQ = 4

# Vocabularies are built from the training split only.
de.build_vocab(train.src, min_freq=MIN_FREQ)
en.build_vocab(train.trg, min_freq=MIN_FREQ)

In [3]:
# Length-bucketing iterator in the spirit of "The Annotated Transformer":
# it looks across many batches' worth of examples to group sentences of
# similar length, which keeps padding to a minimum.
class BasicIterator(data.Iterator):
    """Iterator whose batches are built from examples ordered by `sort_key`."""

    def create_batches(self):
        """Fill `self.batches` for the next pass over the data.

        Evaluation mode: the data is cut into batches with `batch_size_fn` and
        each batch is sorted by `sort_key`; `self.batches` is a list.

        Training mode: the data is read in pools sized `batch_size * 100`; each
        pool is sorted by `sort_key`, cut into batches with `batch_size_fn`,
        and the pool's batches are passed through `random_shuffler`.
        `self.batches` is then a lazy generator.
        """
        if not self.train:
            self.batches = [
                sorted(chunk, key=self.sort_key)
                for chunk in data.batch(self.data(), self.batch_size, self.batch_size_fn)
            ]
            return

        def pooled_batches(examples, shuffle):
            for pool in data.batch(examples, self.batch_size * 100):
                ordered_pool = sorted(pool, key=self.sort_key)
                tight_batches = list(
                    data.batch(ordered_pool, self.batch_size, self.batch_size_fn)
                )
                # Shuffle whole batches so their order is random while each
                # batch still holds neighbours from the sorted pool.
                yield from shuffle(tight_batches)

        self.batches = pooled_batches(self.data(), self.random_shuffler)

In [4]:
global max_src_in_batch, max_tgt_in_batch
def batch_size_fn(new, count, sofar):
    """Return the padded size of the current batch once `new` has been added.

    The running maxima of source and target length are kept in module-level
    globals and restart whenever `count == 1` (first example of a batch).
    The result is `count` times the longer of the two maxima. `sofar` is
    accepted for the caller's signature but not used.
    """
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        # First example of a fresh batch: forget the previous batch's maxima.
        max_src_in_batch = max_tgt_in_batch = 0
    max_src_in_batch = max(len(new.src), max_src_in_batch)
    # The target side is counted two tokens longer than its raw length
    # (presumably for the begin/end markers — confirm).
    max_tgt_in_batch = max(len(new.trg) + 2, max_tgt_in_batch)
    padded_src = count * max_src_in_batch
    padded_tgt = count * max_tgt_in_batch
    return padded_tgt if padded_tgt > padded_src else padded_src

# Size limit handed to every iterator. Because a `batch_size_fn` is supplied,
# the limit is measured by that function (padded tokens per batch) rather than
# in sentences — see the torchtext Iterator documentation.
TOKENS_PER_BATCH = 1100


def _by_length(example):
    """Sort key: source length first, target length as tie-breaker."""
    return (len(example.src), len(example.trg))


def _make_loader(dataset, is_train):
    """Build a BasicIterator over `dataset` with the settings shared by all splits."""
    # Use the same `device` the model is moved to instead of hard-coding CUDA,
    # so batches and model parameters always end up on the same device.
    return BasicIterator(dataset, batch_size=TOKENS_PER_BATCH, device=device,
                         repeat=False, sort_key=_by_length,
                         batch_size_fn=batch_size_fn, train=is_train)


train_loader = _make_loader(train, True)
val_loader = _make_loader(val, False)
test_loader = _make_loader(test, False)

In [5]:
def train_step(dataloader, accumulate_every=11):
    """Run one epoch over `dataloader` using gradient accumulation.

    Relies on the notebook-level `transformer`, `criterion` and `optimizer`.

    Args:
        dataloader: iterable of batches exposing `.src` and `.trg`.
        accumulate_every: number of batches whose gradients are accumulated
            before each optimizer step. The default of 11 simulates a batch
            size of ~12k.

    Returns:
        The mean per-batch loss as a float, or NaN if `dataloader` yields no
        batches.

    Raises:
        ValueError: if `accumulate_every` is smaller than 1.
    """
    if accumulate_every < 1:
        raise ValueError("accumulate_every must be >= 1")

    n_batches = 0
    total_loss = 0
    for batch in dataloader:
        loss, _ = transformer.forward_and_return_loss(criterion, batch.src, batch.trg)
        loss.backward()
        total_loss += loss.item()
        n_batches += 1

        # Step only after the current batch's gradients have been added, so
        # every update (including the first) covers `accumulate_every` batches.
        if n_batches % accumulate_every == 0:
            optimizer.step()
            optimizer.zero_grad()

    # Apply whatever is left from an incomplete final group; otherwise those
    # gradients would silently carry over into the next epoch.
    if n_batches % accumulate_every != 0:
        optimizer.step()
        optimizer.zero_grad()

    if n_batches == 0:
        # Nothing was processed: report NaN instead of dividing by zero.
        return float("nan")
    return total_loss / n_batches

In [6]:
# Model hyperparameters.
embedding_dim, hidden_dim = 256, 512
n_layers, n_heads = 4, 8
dropout_rate = 0.1

# German vocabulary size first, English vocabulary size second.
transformer = MLT(
    len(de.vocab), len(en.vocab),
    embedding_dim, n_layers, hidden_dim, n_heads, dropout_rate,
)
transformer = transformer.to(device)

# Adam is created with lr=0 and handed to NoamOpt (presumably NoamOpt sets the
# actual learning rate on each step — confirm in model.opt).
adamopt = Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
optimizer = NoamOpt(embedding_dim, 1, 2000, adamopt)
criterion = F.cross_entropy

# Optimization is unstable unless every parameter with more than one
# dimension is re-initialised with Xavier-uniform values.
for weight in filter(lambda param: param.dim() > 1, transformer.parameters()):
    nn.init.xavier_uniform_(weight)

In [7]:
# Number of passes over the training set.
N_EPOCHS = 10

true_start = time.time()
for epoch in range(N_EPOCHS):
    transformer.train()
    epoch_start = time.time()

    loss = train_step(train_loader)

    # Read the clock once so both reported durations refer to the same instant.
    now = time.time()
    # Format the loss numerically; slicing str(loss) misreports values whose
    # repr uses an exponent (e.g. 1.234e-05 would print as "1.234").
    print("Epoch {}. Loss: {:.3f}".format(epoch + 1, loss))
    print("Total time: {}, Last epoch time (s): {}".format(
        int(now - true_start), int(now - epoch_start)))

Epoch 1. Loss: 6.035, 
Total time: 439, Last epoch time (s): 439
Epoch 2. Loss: 4.087, 
Total time: 901, Last epoch time (s): 461
Epoch 3. Loss: 3.289, 
Total time: 1346, Last epoch time (s): 445
Epoch 4. Loss: 2.814, 
Total time: 1786, Last epoch time (s): 439
Epoch 5. Loss: 2.519, 
Total time: 2247, Last epoch time (s): 461
Epoch 6. Loss: 2.295, 
Total time: 2674, Last epoch time (s): 426
Epoch 7. Loss: 2.053, 
Total time: 3102, Last epoch time (s): 427
Epoch 8. Loss: 1.858, 
Total time: 3527, Last epoch time (s): 425
Epoch 9. Loss: 1.698, 
Total time: 3951, Last epoch time (s): 424
Epoch 10. Loss: 1.561, 
Total time: 4376, Last epoch time (s): 424


In [8]:
# Saves the whole module object rather than only its state_dict.
# NOTE(review): reloading this file with torch.load presumably requires the MLT
# class to be importable under the same module path — confirm against the
# PyTorch serialization docs before relying on this checkpoint elsewhere.
torch.save(transformer, "model_save.pt")

In [9]:
# Upper bound on how many sentences of the batch are printed.
MAX_EXAMPLES_SHOWN = 10


def _tokens_until_eos(column, vocab, eos="</s>"):
    """Render one sentence (a 1-D tensor of token ids) as text.

    Position 0 is skipped (the fields put "<s>" there; the generated output
    is assumed to start the same way) and decoding stops at the first `eos`
    token. Every token is followed by a single space, which keeps the printed
    format of the earlier per-token loop.
    """
    words = []
    for token_id in column[1:].tolist():
        word = vocab.itos[token_id]
        if word == eos:
            break
        words.append(word + " ")
    return "".join(words)


transformer.eval()
new_batch = next(iter(val_loader))
inp = new_batch.src
tra = new_batch.trg

# Inference only: no need to record the autograd graph while decoding.
with torch.no_grad():
    # NOTE(review): Batch(inp, inp, 1) is built only for its src_mask; the 1 is
    # presumably the padding index and 120 the maximum output length — confirm
    # against model.utils.Batch and MLT.generate.
    out = transformer.generate(inp, Batch(inp, inp, 1).src_mask, 120)

# The tensors are indexed [position, sentence], so the number of sentences is
# size(1); len(inp) would be the sequence length instead.
for i in range(min(inp.size(1), MAX_EXAMPLES_SHOWN)):
    print("Input sentence: " + _tokens_until_eos(inp[:, i], de.vocab))
    print("Predicted translation: " + _tokens_until_eos(out[:, i], en.vocab))
    print("Ground truth translation: " + _tokens_until_eos(tra[:, i], en.vocab))
    print()

Input sentence: Und der Garten ist wunderschön . 
Predicted translation: And the garden is beautiful . 
Ground truth translation: And the garden , it was beautiful . 

Input sentence: Die <unk> ist nicht nachhaltig . 
Predicted translation: The <unk> system is n't sustainable . 
Ground truth translation: The internal combustion engine is not sustainable . 

Input sentence: Wir sehen immer dieselben Symptome . 
Predicted translation: We see the same symptoms . 
Ground truth translation: We see all the same symptoms . 

Input sentence: Sie ist keine <unk> . " 
Predicted translation: It 's not a <unk> . " 
Ground truth translation: She 's not North Korean . " 

Input sentence: Weil es so schön klingt . 
Predicted translation: Because it sounds beautiful . 
Ground truth translation: Just because it sounds so good . 

Input sentence: Aber man muss es pflegen . 
Predicted translation: But you have to care about it . 
Ground truth translation: But you have to maintain it . 

Input sentence: D