In [1]:
import torch

# Local library
from dataset import DataSet
import transformer as tfr
import seq2seq as s2s

# Prepare the translation dataset

In [2]:
# Build an empty DataSet, then load the translation corpus into it.
# The recorded cell output reports the total number of samples.
corpus_path = './data/fra.txt'
data = DataSet()
data.read_file(corpus_path)

Total number of samples: 229803


In [4]:
# tokenize() returns a pair.  Per the recorded output, the first element is a
# mapping with 'input_ids', 'attention_mask' and 'labels' tensors; the second
# is presumably the decoder-side input tensor (verify against DataSet.tokenize).
sample, dec_inputs = data.tokenize()

# Dump both objects so token ids, padding and masks can be checked by eye.
for shown in (sample, dec_inputs):
    print(shown)

{'input_ids': tensor([[  631,   250,     0, 59513, 59513, 59513, 59513, 59513, 59513, 59513,
         59513, 59513, 59513, 59513, 59513, 59513],
        [  631,   250,     0, 59513, 59513, 59513, 59513, 59513, 59513, 59513,
         59513, 59513, 59513, 59513, 59513, 59513],
        [  631,   250,     0, 59513, 59513, 59513, 59513, 59513, 59513, 59513,
         59513, 59513, 59513, 59513, 59513, 59513]]), 'attention_mask': tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'labels': tensor([[  740,   291,     0, 59513, 59513, 59513, 59513, 59513, 59513, 59513,
         59513, 59513, 59513, 59513, 59513, 59513],
        [ 4714,   250,     0, 59513, 59513, 59513, 59513, 59513, 59513, 59513,
         59513, 59513, 59513, 59513, 59513, 59513],
        [   23,  2020,   291,     0, 59513, 59513, 59513, 59513, 59513, 59513,
         59513, 59513, 59513, 59513, 59513, 59

# Create NMT model

In [5]:
# Build the transformer-based Seq2Seq model.
#
# Encoder and decoder are constructed with the same positional arguments:
#   (vocab_size, num_hiddens, ffn_num_hiddens, num_heads, num_blks, dropout, use_bias=False)
vocab_size = data.vocab_size
num_hiddens, ffn_hiddens = 32, 64
num_heads, num_blks = 8, 2
dropout = 0.2

# Share one argument tuple so the encoder and decoder cannot drift apart.
# A GRU encoder/decoder pair can be swapped in here instead.
block_args = (vocab_size, num_hiddens, ffn_hiddens, num_heads, num_blks, dropout)
encoder = tfr.TransformerEncoder(*block_args)
decoder = tfr.TransformerDecoder(*block_args)

# NOTE(review): the meaning of the two trailing positional arguments is not
# visible from this notebook -- presumably a target padding id and a learning
# rate; confirm against s2s.Seq2Seq.  If the first one is a padding id, note
# that the tokenizer output recorded for the tokenize() cell pads with 59513
# (attention_mask is 0 exactly there), while id 0 is the last unmasked token
# of every row, so 0 may not be the intended value.
model = s2s.Seq2Seq(encoder, decoder, 0, 0.001)



# Training our NMT model

In [6]:
# Encoder token ids come straight from the tokenized batch.
enc_inputs = sample['input_ids']

# Summing the 0/1 attention mask over its last axis gives, for each row,
# the number of non-masked positions.
attention_mask = sample['attention_mask']
valid_lens = attention_mask.sum(dim=-1)

In [7]:
# Run one forward pass of the model on the sample batch built above.
# NOTE(review): dropout is configured to 0.2 and no seed or eval() call is
# visible in this notebook, so the values in X may differ between runs --
# confirm whether that is intended for this sanity check.
X = model(enc_inputs, dec_inputs, valid_lens)

In [9]:
# Sanity-check the output shape; the recorded run printed
# torch.Size([3, 16, 59514]).
print(X.shape)

torch.Size([3, 16, 59514])
