In [1]:
%%capture
import trax
from trax import layers as tl
from trax.fastmath import numpy as fastnp
from trax.supervised import training

In [2]:
import numpy as np
import pandas as pd

import random
import os

from sklearn.model_selection import train_test_split

In [3]:
#!pip install sentencepiece
import sentencepiece as spm
from unicodedata import normalize

In [4]:
# Fixed seed so the train/test split (and any metric computed on test_data)
# is reproducible under Restart & Run All.
SPLIT_SEED = 42
df = pd.read_csv('./eng-spa.csv')
train_data, test_data = train_test_split(df, random_state=SPLIT_SEED)

In [5]:
def generate_txt(filepath, data):
    """Write sentences to a UTF-8 text file used as a SentencePiece corpus.

    The sentences are joined with a blank line between consecutive entries,
    and the whole text is NFKC-normalized before it is written.

    Args:
        filepath: Destination path. An existing file is overwritten.
        data: Iterable of str, e.g. a pandas Series of sentences.
    """
    text = '\n\n'.join(data)        # one sentence per paragraph
    text = normalize('NFKC', text)  # fold compatibility forms (ligatures, decomposed accents)

    # Write UTF-8 explicitly. The default encoding depends on the platform or
    # locale and may be unable to encode accented Spanish characters, while
    # SentencePiece reads its training input as UTF-8.
    with open(filepath, 'w', encoding='utf-8') as fw:
        fw.write(text)

In [6]:
english_text_path = 'eng.txt'
spanish_text_path = 'spa.txt'

# One plain-text training corpus per language.
generate_txt(english_text_path, df['eng'])
generate_txt(spanish_text_path, df['spa'])

english_vocab_size = 30000
spanish_vocab_size = 30000

model_dir = './sentencepiece/'
os.makedirs(model_dir, exist_ok=True)

# SentencePiece's default special ids are unk=0, bos=1, eos=2, with <pad>
# disabled. The rest of this notebook treats id 0 as padding: the pipeline uses
# AddLossWeights(id_to_mask=0) and the attention mask is `inputs != 0`. With
# the defaults, every <unk> token would silently be masked out like padding.
# Reserve 0 for <pad> and move <unk> to 3. bos/eos keep their default ids 1/2.
special_id_flags = '--pad_id=0 --bos_id=1 --eos_id=2 --unk_id=3'

spm.SentencePieceTrainer.train('--input={} --model_prefix={}en_bpe --vocab_size={} --model_type=bpe {}'.format(english_text_path, model_dir, english_vocab_size, special_id_flags))
spm.SentencePieceTrainer.train('--input={} --model_prefix={}esp_bpe --vocab_size={} --model_type=bpe {}'.format(spanish_text_path, model_dir, spanish_vocab_size, special_id_flags))

# Both processors append <eos> (id 2) to every encoded sequence and add no <bos>.
en_tokenizer = spm.SentencePieceProcessor(add_eos=True, add_bos=False)
en_tokenizer.load(model_dir+'en_bpe.model')

esp_tokenizer = spm.SentencePieceProcessor(add_eos=True, add_bos=False)
esp_tokenizer.load(model_dir+'esp_bpe.model')

True

In [7]:
def data_generator(dataset, en_tokenizer, esp_tokenizer, rng=None):
    """Yield randomly sampled (English ids, Spanish ids) pairs from `dataset` forever.

    Args:
        dataset: DataFrame with 'eng' and 'spa' text columns. Rows are sampled
            by position, so any index (e.g. the shuffled index produced by
            train_test_split) is fine.
        en_tokenizer: Object whose `Tokenize(str)` returns English token ids.
        esp_tokenizer: Object whose `Tokenize(str)` returns Spanish token ids.
        rng: Optional `random.Random`-like object used for sampling. Defaults
            to the global `random` module; pass a seeded instance for
            reproducible streams.

    Yields:
        Tuple `(np.array(eng_ids), np.array(spa_ids))`.

    Raises:
        ValueError: On the first `next()` if `dataset` has no rows.
    """
    sampler = rng if rng is not None else random
    n_rows = len(dataset)
    if n_rows == 0:
        raise ValueError('dataset is empty')
    while True:
        # Sample a *position*, not an index label: after a train/test split the
        # labels are shuffled and non-contiguous, so a label passed to .iloc
        # would pick the wrong row or go out of range.
        row = dataset.iloc[sampler.randrange(n_rows)]
        eng = en_tokenizer.Tokenize(row['eng'])
        spa = esp_tokenizer.Tokenize(row['spa'])
        yield np.array(eng), np.array(spa)

In [8]:
# Sanity check: draw one (English ids, Spanish ids) pair from the generator.
next(data_generator(df, en_tokenizer, esp_tokenizer))

(array([    9, 29965, 29954,   118,   185,  1620, 29953,     2]),
 array([   61,   387,  1042,  3310, 29949,     2]))

In [9]:
# Length-bucket edges for trax.data.BucketByLength and the batch size used for
# each bucket. Longer sequences get smaller batches to bound memory per batch.
boundaries = [16, 32, 128, 512, 2048]
batch_sizes = [512, 256, 64, 32, 16, 1]

# BucketByLength needs exactly one more batch size than boundaries. Check it
# here so an edit to either list fails fast instead of inside the pipeline.
assert len(batch_sizes) == len(boundaries) + 1, 'BucketByLength needs len(batch_sizes) == len(boundaries) + 1'

In [10]:
# Input pipeline applied to the (English ids, Spanish ids) stream, stage by stage.
data_pipeline = trax.data.Serial(
    trax.data.Shuffle(),                         # shuffle incoming examples
    trax.data.FilterByLength(max_length=2048),   # drop examples longer than 2048
    trax.data.BucketByLength(                    # batch examples of similar length
        boundaries, batch_sizes, length_keys=[0, 1]),
    trax.data.AddLossWeights(id_to_mask=0),      # zero loss weight wherever id == 0
)

In [11]:
# Build the training stream from the training split only, so test_data stays held out.
train_batch_stream = data_pipeline(data_generator(train_data, en_tokenizer, esp_tokenizer))

In [12]:
def input_encoder_fn(input_vocab_size, d_model, n_encoder_layers):
    """Build the encoder: a token embedding followed by stacked LSTM layers.

    Args:
        input_vocab_size: Size of the source-language vocabulary.
        d_model: Embedding size, also used as the width of every LSTM.
        n_encoder_layers: Number of stacked LSTM layers (an int).

    Returns:
        A `tl.Serial` layer mapping source token ids to encoder activations.
    """
    input_encoder = tl.Serial(
        tl.Embedding(input_vocab_size, d_model),
        # n_encoder_layers is an int count. Iterating the int directly (as
        # before) raises TypeError, so loop over range(...).
        [tl.LSTM(d_model) for _ in range(n_encoder_layers)])

    return input_encoder

In [13]:
def pre_attention_decoder_fn(mode, target_vocab_size, d_model):
    """Build the decoder stack that runs before attention.

    The stack applies ShiftRight (in the given `mode`), then a target-token
    embedding, then a single LSTM of width `d_model`.

    Args:
        mode: trax mode string forwarded to `tl.ShiftRight`.
        target_vocab_size: Size of the target-language vocabulary.
        d_model: Embedding size and LSTM width.

    Returns:
        A `tl.Serial` layer.
    """
    decoder_layers = [
        tl.ShiftRight(mode=mode),
        tl.Embedding(target_vocab_size, d_model),
        tl.LSTM(d_model),
    ]
    return tl.Serial(*decoder_layers)

In [14]:
## The input encoder's outputs serve as both the keys and the values for attention.
## The mask marks padding positions (token id 0) so the attention softmax gives them no probability.

In [16]:
def prepare_attention_input(encoder_activations, decoder_activations, inputs):
    """Assemble queries, keys, values and a padding mask for attention.

    Keys and values are both the encoder activations, and queries are the
    decoder activations. The mask is 1 where `inputs` is non-zero and 0 where
    it is zero (padding). It is reshaped and broadcast to
    [batch_size, 1, decoder_len, encoder_len], where decoder_len is
    `decoder_activations.shape[1]` and encoder_len is `inputs.shape[1]`.

    Returns:
        Tuple `(queries, keys, values, mask)`.
    """
    batch_size, encoder_len = inputs.shape[0], inputs.shape[1]
    decoder_len = decoder_activations.shape[1]

    # Insert singleton axes for the attention heads and the decoder positions.
    padding_mask = fastnp.reshape(inputs != 0, (batch_size, 1, 1, encoder_len))
    # Adding zeros broadcasts the mask across every decoder position.
    padding_mask = padding_mask + fastnp.zeros((1, 1, decoder_len, 1))

    return decoder_activations, encoder_activations, encoder_activations, padding_mask

In [17]:
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Attention layer over explicit (queries, keys, values, mask) inputs.

    Each of q, k and v goes through its own `tl.Dense(d_feature)` projection.
    The projections feed `tl.PureAttention`, and its activations go through a
    final `tl.Dense(d_feature)`.

    Args:
        d_feature: Width of every Dense projection.
        n_heads: Number of attention heads passed to PureAttention.
        dropout: Dropout rate passed to PureAttention.
        mode: trax mode string passed to PureAttention.

    Returns:
        A `tl.Serial` layer.
    """
    project_q, project_k, project_v = (tl.Dense(d_feature) for _ in range(3))
    attention = tl.PureAttention(n_heads=n_heads, dropout=dropout, mode=mode)
    project_out = tl.Dense(d_feature)
    return tl.Serial(
        tl.Parallel(project_q, project_k, project_v),
        attention,
        project_out,
    )