In [None]:
import nltk
import json
import torch
import torch.utils.data as data


class Dataset(data.Dataset):
    """Parallel-corpus dataset compatible with ``data.DataLoader``.

    Line ``i`` of the source file is paired with line ``i`` of the target
    file; each item is returned as a ``(src_seq, trg_seq)`` pair of 1-D
    tensors of token ids.
    """

    def __init__(self, src_path, trg_path, src_word2id, trg_word2id):
        """Read source and target sentences (one per line) from text files.

        Args:
            src_path: path of the source-language text file.
            trg_path: path of the target-language text file.
            src_word2id: source word-to-id dict; must contain '<start>' and '<end>'.
            trg_word2id: target word-to-id dict; must contain '<start>' and '<end>'.

        Raises:
            ValueError: if the two files have a different number of lines.
        """
        # Close the files after reading and decode explicitly as UTF-8 so
        # non-ASCII text does not depend on the platform's default locale.
        with open(src_path, encoding="utf-8") as src_file:
            self.src_seqs = src_file.readlines()
        with open(trg_path, encoding="utf-8") as trg_file:
            self.trg_seqs = trg_file.readlines()
        # The corpus must be line-aligned; fail early instead of raising an
        # IndexError (or silently ignoring extra lines) in the middle of training.
        if len(self.src_seqs) != len(self.trg_seqs):
            raise ValueError(
                "Source and target files differ in length: {} vs {} lines".format(
                    len(self.src_seqs), len(self.trg_seqs)))
        self.num_total_seqs = len(self.src_seqs)
        self.src_word2id = src_word2id
        self.trg_word2id = trg_word2id

    def __getitem__(self, index):
        """Return the ``index``-th (source, target) pair as id tensors."""
        src_seq = self.preprocess(self.src_seqs[index], self.src_word2id, trg=False)
        trg_seq = self.preprocess(self.trg_seqs[index], self.trg_word2id)
        return src_seq, trg_seq

    def __len__(self):
        """Return the number of sentence pairs."""
        return self.num_total_seqs

    def preprocess(self, sequence, word2id, trg=True):
        """Convert a raw sentence into a 1-D tensor of token ids.

        The sentence is lower-cased, tokenized with NLTK, and wrapped in the
        '<start>' ... '<end>' ids.

        Args:
            sequence: raw sentence string.
            word2id: word-to-id dict used for the lookup.
            trg: unused; kept for call compatibility.

        Returns:
            ``torch.Tensor`` (default float dtype) of ids; ``collate_fn``
            casts the values to long when padding.
        """
        tokens = nltk.tokenize.word_tokenize(sequence.lower())
        # NOTE(review): out-of-vocabulary tokens are dropped rather than
        # mapped to '<unk>' -- confirm this is intended.
        ids = [word2id['<start>']]
        ids.extend(word2id[token] for token in tokens if token in word2id)
        ids.append(word2id['<end>'])
        return torch.Tensor(ids)


def collate_fn(data, pad_idx=0, n_mask_copies=2):
    """Build a padded mini-batch from a list of ``(src_seq, trg_seq)`` pairs.

    A custom collate function is needed because the default one cannot merge
    variable-length sequences. Sequences are padded with ``pad_idx`` to the
    longest sequence in the mini-batch (dynamic padding) and then handed to
    ``Batch``, which builds the masks and the shifted target ``trg_y``.

    Args:
        data: list of tuples (src_seq, trg_seq); each element is a 1-D tensor
            of token ids of variable length. The list is sorted IN PLACE by
            descending source length.
        pad_idx: id used for padding. It must match the '<pad>' entry of the
            vocabulary (``build_word2id`` assigns 0). The same id is passed to
            ``Batch`` so that padded positions are masked.
        n_mask_copies: number of copies of the target mask stacked along a new
            leading dim. 2 matches the two-GPU 'dp' trainer in this notebook --
            presumably so DataParallel can scatter one copy per device; TODO confirm.

    Returns:
        Tuple ``(src, trg, src_mask, trg_mask, trg_y)`` taken from ``Batch``.
        ``src``, ``trg`` and ``trg_y`` are transposed back to batch-first, and
        ``trg_mask`` gets an extra leading dim of size ``n_mask_copies``.
        The exact shapes depend on ``Batch`` -- TODO confirm.
    """
    def merge(sequences):
        # Right-pad every sequence with pad_idx up to the longest one.
        lengths = [len(seq) for seq in sequences]
        padded = torch.full((len(sequences), max(lengths)), pad_idx, dtype=torch.long)
        for i, (seq, length) in enumerate(zip(sequences, lengths)):
            padded[i, :length] = seq[:length]
        return padded, lengths

    # Sort by source length (descending), the order pack_padded_sequence expects.
    data.sort(key=lambda pair: len(pair[0]), reverse=True)

    # Separate source and target sequences.
    src_seqs, trg_seqs = zip(*data)

    # Merge each tuple of 1-D tensors into one padded 2-D tensor [batch, max_len].
    src_seqs, _ = merge(src_seqs)
    trg_seqs, _ = merge(trg_seqs)

    # Batch is fed sequence-first tensors ([max_len, batch]); the third
    # argument is the padding id, which must match the padding used above.
    batch_handler = Batch(src_seqs.transpose(0, 1), trg_seqs.transpose(0, 1), pad_idx)

    return (
        batch_handler.src.transpose(0, 1),
        batch_handler.trg.transpose(0, 1),
        batch_handler.src_mask,
        batch_handler.trg_mask.unsqueeze(0).repeat(n_mask_copies, 1, 1),
        batch_handler.trg_y.transpose(0, 1),
    )


def get_loader(src_path, trg_path, src_word2id, trg_word2id, batch_size=32, shuffle=True):
    """Return a DataLoader over a line-aligned parallel text corpus.

    Args:
        src_path: txt file path for the source domain.
        trg_path: txt file path for the target domain.
        src_word2id: word-to-id dictionary (source domain).
        trg_word2id: word-to-id dictionary (target domain).
        batch_size: mini-batch size.
        shuffle: whether to reshuffle the data every epoch (default True).

    Returns:
        A ``torch.utils.data.DataLoader`` whose batches are the tuples
        produced by ``collate_fn``: ``(src, trg, src_mask, trg_mask, trg_y)``.
    """
    dataset = Dataset(src_path, trg_path, src_word2id, trg_word2id)

    # collate_fn pads each mini-batch dynamically and builds the masks.
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )

import nltk
import json
import argparse
from collections import Counter


def build_word2id(seq_path, min_word_count):
    """Create a word-to-id dictionary from a text file.

    Every line is lower-cased and tokenized with NLTK. Words that occur at
    least ``min_word_count`` times receive an id. Ids 0-3 are reserved for
    '<pad>', '<start>', '<end>' and '<unk>'. The remaining words are numbered
    from 4 in order of first appearance in the file.

    Args:
        seq_path: String; text file path.
        min_word_count: Integer; minimum word count threshold.

    Returns:
        word2id: Dictionary; word-to-id dictionary.
    """
    # Close the file after reading and decode explicitly as UTF-8 so
    # non-ASCII text does not depend on the platform's default locale.
    with open(seq_path, encoding="utf-8") as seq_file:
        sequences = seq_file.readlines()
    num_seqs = len(sequences)
    counter = Counter()

    for i, sequence in enumerate(sequences):
        counter.update(nltk.tokenize.word_tokenize(sequence.lower()))
        if i % 1000 == 0:
            print("[{}/{}] Tokenized the sequences.".format(i, num_seqs))

    # Special tokens first, with fixed ids.
    word2id = {'<pad>': 0, '<start>': 1, '<end>': 2, '<unk>': 3}
    first_free_id = len(word2id)

    # Words less frequent than min_word_count are discarded.
    frequent_words = (word for word, count in counter.items() if count >= min_word_count)
    for offset, word in enumerate(frequent_words):
        word2id[word] = first_free_id + offset

    return word2id


def b_vocab(source_path, target_pad, min_word_count, src_out_path, trg_out_path):
    """Build source and target vocabularies and save each as a JSON file.

    Args:
        source_path: source-language text file.
        target_pad: target-language text file (despite its name, a path).
        min_word_count: minimum frequency for a word to be kept.
        src_out_path: output path for the source word2id JSON.
        trg_out_path: output path for the target word2id JSON.
    """
    # Build both dictionaries before writing anything to disk.
    vocabularies = [
        build_word2id(corpus_path, min_word_count)
        for corpus_path in (source_path, target_pad)
    ]

    for vocabulary, out_path in zip(vocabularies, (src_out_path, trg_out_path)):
        with open(out_path, 'w') as out_file:
            json.dump(vocabulary, out_file)


In [None]:
import json

In [None]:
# Load the source-language word2id dictionary written by b_vocab.
with open("src.vocab", mode="r") as f:
    src_v = json.load(f)

In [None]:
# Load the target-language word2id dictionary written by b_vocab.
with open("trg.vocab", mode="r") as f:
    trg_v = json.load(f)

In [None]:
# b_vocab("en-de/train.tags.en-de.de", "en-de/train.tags.en-de.en", 4, "src.vocab", "trg.vocab")

In [None]:
# German (source) -> English (target) training loader.
dloader = get_loader(
    src_path="en-de/train.tags.en-de.de",
    trg_path="en-de/train.tags.en-de.en",
    src_word2id=src_v,
    trg_word2id=trg_v,
)

#### IWSLT German->English translation

This notebook shows a simple example of how to use the transformer provided by this repo for unidirectional (German→English) translation.

We will use the IWSLT 2016 De-En dataset.

In [None]:
from torchtext import data, datasets
import spacy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

import sys
from transformers import BaseTransformer

from utils import Batch, BasicIterator
from opt import NoamOpt

import time

In [None]:
import torch.multiprocessing as mp


In [None]:
import torch.distributed as dist

##### The cells below do some basic data preprocessing and filtering, and set the special tokens.

In [None]:
import pytorch_lightning as pl


##### A single pass over the entire dataset, with heavy gradient accumulation to reach batch sizes large enough for stable training.

In [None]:
# del transformer
# torch.cuda.empty_cache()

In [None]:
class TranslationModel(BaseTransformer):
    """Translation model built on ``BaseTransformer``.

    Adds a loss helper and greedy decoding on top of the encoder, decoder
    and output projection (``fc1``) provided by the base class.
    """

    def __init__(
        self, *args,
    ):
        # All hyper-parameters are forwarded positionally to BaseTransformer.
        super(TranslationModel, self).__init__(*args)

    def forward_and_return_loss(self, criterion, sources, targets):
        """
        Pass input through the transformer and return the loss; masks are
        built by ``Batch`` using ``self.pad_idx``.

        Args:
            criterion: functional loss (e.g. ``F.cross_entropy``) that accepts
                an ``ignore_index`` keyword.
            sources: source token ids, [seq_len, bs].
            targets: full target token ids (including the start token),
                [seq_len, bs].

        Returns:
            (loss, transformer output)
        """
        batch = Batch(sources, targets, self.pad_idx)
        out = self.forward(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        # Flatten logits to [N, vocab] and targets to [N]; padding is ignored.
        loss = criterion(
            out.contiguous().view(-1, out.size(-1)),
            batch.trg_y.contiguous().view(-1),
            ignore_index=self.pad_idx,
        )

        return loss, out

    def generate(self, source, source_mask, max_len):
        """Greedily decode translations for a batch of source sequences.

        No early stop is performed on end-of-sequence tokens; callers cut the
        output themselves.

        Args:
            source: source token ids, [seq_len, bs].
            source_mask: the source mask that prevents attending to <pad> tokens.
            max_len: upper bound on the returned length (start token included).

        Returns:
            Long tensor of generated ids, [out_len, bs], whose first row is
            ``self.sos_idx``.
        """
        # The target length is capped at 1.5x the source length + 10 to save
        # compute, and additionally at max_len.
        n_steps = min(int(1.5 * source.size(0)) - 1 + 10, max_len - 1)

        # Inference only: no autograd graph is needed for greedy decoding.
        with torch.no_grad():
            memory = self.encoder(source, source_mask)
            # Allocate on the input's device rather than a global `device`.
            ys = torch.full(
                (1, source.size(1)), self.sos_idx, dtype=torch.long, device=source.device
            )
            for _ in range(n_steps):
                # Target mask over the prefix decoded so far (presumably the
                # causal/subsequent mask -- TODO confirm Batch.raw_mask).
                trg_mask = Batch(ys, ys, self.pad_idx).raw_mask
                out = self.decoder(ys, memory, source_mask, trg_mask)
                logits = self.fc1(out[-1].unsqueeze(0))
                prob = F.log_softmax(logits, dim=-1)
                next_word = torch.argmax(prob, dim=-1)
                ys = torch.cat([ys, next_word], dim=0)

        return ys

In [None]:
class TranslationTransformer(pl.LightningModule):
    """PyTorch Lightning wrapper that trains a ``TranslationModel``.

    Args:
        transformer: the wrapped model, called as
            ``transformer(src, trg, src_mask, trg_mask)``.
        criterion: loss module applied to flattened logits and targets. If
            None, the notebook-level ``criterion`` is looked up at every step
            (the previous behaviour).
        d_model: model width used by the Noam learning-rate schedule.
        warmup: number of warm-up steps of the Noam schedule.
        factor: scale factor of the Noam schedule.
    """

    def __init__(self, transformer, criterion=None, d_model=256, warmup=4000, factor=1.0):
        super().__init__()
        self.transformer = transformer
        self.criterion = criterion
        self.d_model = d_model
        self.warmup = warmup
        self.factor = factor

    def forward(self, x, y, src_mask, trg_mask):
        return self.transformer(x, y, src_mask, trg_mask)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch produced by ``collate_fn``."""
        src, trg, src_mask, trg_mask, trg_y = batch
        # Drop singleton dims of the masks.
        # NOTE(review): squeeze() also drops the batch dim when bs == 1.
        src_mask, trg_mask = src_mask.squeeze(), trg_mask.squeeze()

        # Use this module's own model (not the notebook global) so that each
        # DataParallel replica runs its replicated copy on its own device.
        out = self.transformer(src, trg, src_mask, trg_mask)

        loss_fn = self.criterion if self.criterion is not None else criterion
        # trg_y arrives batch-first; transpose it back to line up with `out`
        # (presumably sequence-first -- TODO confirm) before flattening.
        loss = loss_fn(
            out.contiguous().view(-1, out.size(-1)),
            trg_y.transpose(0, 1).to(self.device).contiguous().view(-1),
        )

        result = pl.TrainResult(loss)
        result.log("train_loss", loss, on_epoch=True)
        return result

    def configure_optimizers(self):
        """Adam driven by the Noam schedule from "Attention Is All You Need".

        The base lr is 1 so that LambdaLR's multiplier *is* the learning rate.
        A fixed lr of 0 without a scheduler would never update the weights.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)

        def noam(step):
            step += 1  # LambdaLR starts counting at 0; the schedule starts at 1.
            return self.factor * self.d_model ** -0.5 * min(
                step ** -0.5, step * self.warmup ** -1.5
            )

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam)
        # Step the schedule every optimizer step, not once per epoch.
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]


In [None]:
# Vocabulary sizes come from the loaded word2id dictionaries.
input_vocab_size = len(src_v)
output_vocab_size = len(trg_v)

# Model hyper-parameters.
embedding_dim = 256
n_layers = 4
hidden_dim = 512
n_heads = 8
dropout_rate = .1
transformer = TranslationModel(
    input_vocab_size,
    output_vocab_size,
    embedding_dim,
    n_layers,
    hidden_dim,
    n_heads,
    dropout_rate,
)

# Xavier-initialise every weight matrix (dim > 1);
# optimization is unstable without this step.
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

In [None]:
m = TranslationTransformer(transformer)
# collate_fn pads trg_y with id 0, which build_word2id assigns to '<pad>'.
# Ignore exactly that id; id 1 is '<start>', so it never masked the padding.
criterion = nn.CrossEntropyLoss(ignore_index=trg_v["<pad>"])

# Two-GPU DataParallel; collate_fn stacks 2 copies of the target mask to match.
trainer = pl.Trainer(max_epochs=1, gpus=[0, 1], distributed_backend='dp')

trainer.fit(m, dloader)

##### Train for 10 epochs over the entire training dataset.

In [None]:
# Plain-PyTorch training loop (10 epochs) on up to four GPUs.
# NOTE(review): `train_step` and `train_loader` are not defined in this
# notebook, so this cell fails as written -- TODO define them or drop the cell.
true_start = time.time()
transformer = nn.DataParallel(transformer, device_ids=[0, 1, 2, 3])
transformer = transformer.cuda()
for i in range(10):
    transformer.train()
    t = time.time()

    loss = train_step(train_loader)

    epoch_time = int(time.time() - t)
    # str(loss)[:5] keeps the printout short (e.g. "3.141").
    print("Epoch {}. Loss: {}".format(i + 1, str(loss)[:5]))
    print("Total time (s): {}, Last epoch time (s): {}".format(
        int(time.time() - true_start), epoch_time))

In [None]:
# Pickle the whole module object (not just its state_dict) to disk.
# If the DataParallel cell ran, this is the DataParallel wrapper.
torch.save(obj=transformer, f="basic_translation.pt")

##### Finally, generation.


By default the model uses greedy decoding and does not implement incremental decoding, so every decoding step re-runs the decoder over the whole prefix generated so far. As a result, generation is currently about half as fast as Fairseq for short sequences.

Incremental decoding would, however, make the attention code much harder to read, so it has been left out for now.

In [None]:
def _print_tokens(label, itos, seqs, col, end_token="</s>"):
    """Print `label`, then the tokens of column `col` of a [seq_len, bs] id tensor.

    Position 0 (the start token) is skipped and printing stops at `end_token`.
    """
    print(label, end="")
    for row in range(1, seqs.size(0)):
        token = itos[seqs[row, col]]
        if token == end_token:
            break
        print(token, end=" ")


# The training cell may have wrapped `transformer` in nn.DataParallel;
# generate() is defined on the underlying module.
model = getattr(transformer, "module", transformer)
model.eval()
# NOTE(review): `val_loader`, `de` and `en` are not defined in this notebook
# (leftovers from a torchtext pipeline); the pad id 1 and the '</s>' end token
# presumably belong to that vocabulary -- TODO confirm.
new_batch = next(iter(val_loader))
inp = new_batch.src.to(next(model.parameters()).device)
tra = new_batch.trg

with torch.no_grad():
    out = model.generate(inp, Batch(inp, inp, 1).src_mask, 120)

# Tensors are [seq_len, bs]: iterate over the batch dimension (dim 1).
for i in range(inp.size(1)):
    _print_tokens("Input sentence: ", de.vocab.itos, inp, i)
    _print_tokens("\nPredicted translation: ", en.vocab.itos, out, i)
    _print_tokens("\nGround truth translation: ", en.vocab.itos, tra, i)
    print("\n")