In [1]:
%matplotlib inline


Language Translation with nn.Transformer and torchtext
======================================================

This tutorial shows, how to train a translation model from scratch using
Transformer. We will be using `Multi30k <http://www.statmt.org/wmt16/multimodal-task.html#task1>`__ 
dataset to train a German to English translation model.



Data Sourcing and Processing
----------------------------

`torchtext library <https://pytorch.org/text/stable/>`__ has utilities for creating datasets that can be easily
iterated through for the purposes of creating a language translation
model. In this example, we show how to use torchtext's inbuilt datasets, 
tokenize a raw text sentence, build vocabulary, and numericalize tokens into tensor. We will use
`Multi30k dataset from torchtext library <https://pytorch.org/text/stable/datasets.html#multi30k>`__
that yields a pair of source-target raw sentences. 





In [12]:
import collections
from pprint import pprint
import torch

from torch.utils.data import random_split, DataLoader

from data import HebrewWords

import data

In [13]:
PAD_IDX = 0
SOS_token = 1
EOS_token = 2

OUTPUT_WORD_TO_IDX = {
        'PAD': PAD_IDX,
        'SOS': SOS_token,
        'EOS': EOS_token,
        ' ': 3,
        '!': 4,
        '&': 5,
        '(': 6,
        '+': 7,
        '-': 8,
        '/': 9,
        ':': 10,
        '<': 11,
        '=': 12,
        '>': 13,
        '[': 14,
        ']': 15,
        '_': 16,
        '~': 17,
        'a': 18,
        'B': 19,
        'c': 20,
        'C': 21,
        'd': 22,
        'D': 23,
        'F': 24,
        'G': 25,
        'H': 26,
        'J': 27,
        'K': 28,
        'L': 29,
        'M': 30,
        'n': 31,
        'N': 32,
        'o': 33,
        'p': 34,
        'P': 35,
        'Q': 36,
        'R': 37,
        'S': 38,
        'T': 39,
        'u': 40,
        'V': 41,
        'W': 42,
        'X': 43,
        'Y': 44,
        'Z': 45
    }

In [14]:
OUTPUT_IDX_TO_WORD = {idx: char for char, idx in OUTPUT_WORD_TO_IDX.items()}
OUTPUT_IDX_TO_WORD

{0: 'PAD',
 1: 'SOS',
 2: 'EOS',
 3: ' ',
 4: '!',
 5: '&',
 6: '(',
 7: '+',
 8: '-',
 9: '/',
 10: ':',
 11: '<',
 12: '=',
 13: '>',
 14: '[',
 15: ']',
 16: '_',
 17: '~',
 18: 'a',
 19: 'B',
 20: 'c',
 21: 'C',
 22: 'd',
 23: 'D',
 24: 'F',
 25: 'G',
 26: 'H',
 27: 'J',
 28: 'K',
 29: 'L',
 30: 'M',
 31: 'n',
 32: 'N',
 33: 'o',
 34: 'p',
 35: 'P',
 36: 'Q',
 37: 'R',
 38: 'S',
 39: 'T',
 40: 'u',
 41: 'V',
 42: 'W',
 43: 'X',
 44: 'Y',
 45: 'Z'}

In [15]:
SEQUENCE_LENGTH = 4

In [16]:
bible = HebrewWords('../data/t-in_voc', '../data/t-out', SEQUENCE_LENGTH)

len_train = int(0.7 * len(bible))
len_eval = len(bible) - len_train
# alwyas use the same seed for train/test split
training_data, evaluation_data = random_split(
        bible, [len_train, len_eval], generator=torch.Generator().manual_seed(42))

In [17]:
len_train

207294

In [8]:
#from torchtext.data.utils import get_tokenizer
#from torchtext.vocab import build_vocab_from_iterator
#from torchtext.datasets import Multi30k
#from typing import Iterable, List


#SRC_LANGUAGE = 'de'
#TGT_LANGUAGE = 'en'

# Place-holders
#token_transform = {}
#vocab_transform = {}


# Create source and target language tokenizer. Make sure to install the dependencies.
# pip install -U spacy
# python -m spacy download en_core_web_sm
# python -m spacy download de_core_news_sm
#token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')
#token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')


# helper function to yield list of tokens
#def yield_tokens(data_iter: Iterable) -> List[str]:
    #language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}
    
#    for data_sample in data_iter[0:300]:

        #print(data_sample)
        #print(token_transform[language](data_sample[language_index[language]]))
        #print(' ')
        #yield token_transform[language](data_sample[language_index[language]])
#        yield [sign for sign in data_iter]
        
# Define special symbols and indices
#UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
#special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
 
#for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    
#    if ln == SRC_LANGUAGE:
#        train_iter = bible.input_data
#    elif ln == TGT_LANGUAGE:    
#        train_iter = bible.output_data
    # Training data Iterator 
    #train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    # Create torchtext's Vocab object 
#    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter),
#                                                    min_freq=1,
#                                                    specials=special_symbols,
#                                                    special_first=True)
    
# Set UNK_IDX as the default index. This index is returned when the token is not found. 
# If not set, it throws RuntimeError when the queried token is not found in the Vocabulary. 
#for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
#    vocab_transform[ln].set_default_index(UNK_IDX)


Seq2Seq Network using Transformer
---------------------------------

Transformer is a Seq2Seq model introduced in `“Attention is all you
need” <https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>`__
paper for solving machine translation tasks. 
Below, we will create a Seq2Seq network that uses Transformer. The network
consists of three parts. First part is the embedding layer. This layer converts tensor of input indices
into corresponding tensor of input embeddings. These embedding are further augmented with positional
encodings to provide position information of input tokens to the model. The second part is the 
actual `Transformer <https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`__ model. 
Finally, the output of Transformer model is passed through linear layer
that give un-normalized probabilities for each token in the target language. 




In [18]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# Seq2Seq Network 
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, 
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

During training, we need a subsequent word mask that will prevent model to look into
the future words when making predictions. We will also need masks to hide
source and target padding tokens. Below, let's define a function that will take care of both. 




In [19]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

Let's now define the parameters of our model and instantiate the same. Below, we also 
define our loss function which is the cross-entropy loss and the optmizer used for training.




In [20]:
torch.manual_seed(0)

PAD_IDX = 0

SRC_VOCAB_SIZE = 37 #len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = 46 #len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, 
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

Collation
---------

As seen in the ``Data Sourcing and Processing`` section, our data iterator yields a pair of raw strings. 
We need to convert these string pairs into the batched tensors that can be processed by our ``Seq2Seq`` network 
defined previously. Below we define our collate function that convert batch of raw strings into batch tensors that
can be fed directly into our model.   




In [21]:
from torch.nn.utils.rnn import pad_sequence

def token_transform(input_str: str):
    return [sign for sign in input_str]

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
#def tensor_transform(token_ids: List[int]):
#    return torch.cat((torch.tensor([BOS_IDX]), 
#                      torch.tensor(token_ids), 
#                      torch.tensor([EOS_IDX])))

# src and tgt language text transforms to convert raw strings into tensors indices
#text_transform = {}
#for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
#    text_transform[ln] = sequential_transforms(token_transform, #Tokenization
#                                               vocab_transform[ln], #Numericalization
#                                               tensor_transform) # Add BOS/EOS and create tensor


# function to collate data samples into batch tesors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    
    for sample in batch:
        src_batch.append(sample['encoded_text'])
        tgt_batch.append(sample['encoded_output'])

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch

Let's define training and evaluation loop that will be called for each 
epoch.




In [22]:
from torch.utils.data import DataLoader

def train_epoch(model, optimizer, training_data):
    model.train()
    losses = 0
    
    train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)
    
    n = 0
    for src, tgt in train_dataloader:
        n += 1
        if n % 100 == 0:
            print(n)
    
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(train_dataloader)


def evaluate(model, evaluation_data):
    model.eval()
    losses = 0

    val_dataloader = DataLoader(evaluation_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in val_dataloader:
    #for src, tgt in evaluation_data:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
        
        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()

    return losses / len(val_dataloader)

Now we have all the ingredients to train our model. Let's do it!




In [52]:
from timeit import default_timer as timer
NUM_EPOCHS = 9

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer, training_data)
    end_time = timer()
    val_loss = evaluate(transformer, evaluation_data)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}", f"Epoch time = {(end_time - start_time):.3f}s"))


100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
('Epoch: 1, Train loss: 0.748, Val loss: 0.176', 'Epoch time = 7602.011s')
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
('Epoch: 2, Train loss: 0.191, Val loss: 0.089', 'Epoch time = 7361.652s')
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
('Epoch: 3, Train loss: 0.111, Val loss: 0.058', 'Epoch time = 7473.641s')
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
('Epoch: 4, Train loss: 0.079, Val loss: 0.044', 'Epoch time = 7413.358s')
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
('Epoch: 5, Train loss: 0.062, Val loss: 0.036', 'Epoch time = 6897.201s')
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
('Epoch: 6, Train loss: 0.051, Val loss: 0.030', 'Epoch time = 6785.708s')
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
('Epoch: 7, Train loss: 0.044, Val loss: 0.026', 'Epo

In [23]:
save_path = './seq2seq_transformer.pth'

In [54]:
torch.save(transformer.state_dict(), save_path)

In [24]:
loaded_transf = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, 
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
loaded_transf.load_state_dict(torch.load(save_path))
loaded_transf.eval()

Seq2SeqTransformer(
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
        (1): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_feature

In [25]:
# function to generate output sequence using greedy algorithm 
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_token:
            break
    return ys


# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, encoded_sentence): #src_sentence: str):
    model.eval()
    src = encoded_sentence.view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=num_tokens + 5, start_symbol=SOS_token).flatten()
    return "".join([OUTPUT_IDX_TO_WORD[idx] for idx in list(tgt_tokens.cpu().numpy())]).replace("SOS", "").replace("EOS", "")

In [28]:
SOS_token = 1
EOS_token = 2

word_eval_dict = collections.defaultdict(lambda: collections.defaultdict(list))

correct_complete_sequence = 0
correct_all_words = [0 for i in range(SEQUENCE_LENGTH)]

test_len = 10

for i in range(test_len):
    
    predicted = translate(loaded_transf, evaluation_data[i]['encoded_text'])
    gold = evaluation_data[i]['output']
    print(predicted)
    print(gold)
    print(' ')
    
    predicted_words = predicted.split()
    gold_words = gold.split()
    
    if len(predicted_words) != SEQUENCE_LENGTH:
        continue
        
    if predicted == gold:
        correct_complete_sequence += 1
        
    for word_idx in range(SEQUENCE_LENGTH):
        if predicted_words[word_idx] == gold_words[word_idx]:
            correct_all_words[word_idx] += 1
            
            word_eval_dict[gold_words[word_idx]][word_idx].append('correct')
        else:
            word_eval_dict[gold_words[word_idx]][word_idx].append('wrong')

print('complete string', correct_complete_sequence / test_len)
print('distinct words', [correct_count / test_len for correct_count in correct_all_words])    

pprint(word_eval_dict)

Y(W[ GBQ/+K= GCKJC/+H=->!JC<[H=
SBJB/(WTJ+NW W-<TH !CM<[ >LH(J(M/J+NW
 
H <XHKC/Ta T=!(H]C(W&JF[N(W+N!(H(T]C&NCD<[/
H]B(W&J>[(W+HW L-QBR/J MLK/J JFR>L/
 
J!(H](JDC[(JH-DH(J/H H(J&>CQDC/~H </
B-(H-J(WM/JM H-HM CN(H/JM L-!H]B(W&J>[/c
 
Q<C/ C-<VQCNC/ J!(H](JF&JF[
N> B->ZN/J= KL/ B<L/J
 
W-JCJ<JHCB(H/H
GM >TH T!(NTN[ B-JD/+NW
 
J GTC>NDC/ SGPC/JM
B-(H-BQR=/ CKR=/ J!RDP[W M!>XR[/Jd
 
K-(H-JL<C/J+HW GDC/J+HMJ !(H]C&F&JR[H=
Wn-J!(JY>[ H-GWRL/a H-XMJCJ/ L-MVH/c
 
H-<>C/+N GC/+H K-(H-CKH=/
H->C/ W-KBWD/c JHWH/ <L
 
W-B-(H-CMCKJB/J H]GC&JC[+H](W&
BJT LXM/ Wn-T=!(N]H(WM[ KL/
 
PCD[HduCD<LCK(J(M/J+K&C-K-(H-SKSKS(J&F/+K=
M(N-PRJ/ H->RY/a Wn-J!(H](J&WRD[W >L&J+NW
 
complete string 0.0
distinct words [0.0, 0.0, 0.0, 0.0]
defaultdict(<function <lambda> at 0x000002AB30F4FB80>, {})


In [3]:
for i in range(1):
    print(i)

0


References
----------

1. Attention is all you need paper.
   https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
2. The annotated transformer. https://nlp.seas.harvard.edu/2018/04/03/attention.html#positional-encoding

