In [None]:
# External library
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm, trange
from matplotlib import pyplot as plt

# Local library
from dataset import DataSet
import transformer as tfr
import seq2seq as s2s

# Prepare Translation DataSet

In [None]:
# Build the dataset wrapper, then load the raw parallel corpus into it.
# max_length=16 is forwarded to DataSet; presumably the per-sentence token
# cap -- TODO confirm against dataset.py.
corpus_path = './data/fra.txt'
data = DataSet(max_length=16)
data.read_file(corpus_path)

In [None]:
# Tokenize the loaded corpus. The result unpacks into `sample`, which later
# cells index by 'input_ids', 'attention_mask' and 'labels', and `sample_dec`,
# which later cells use as the decoder inputs.
tokenized = data.tokenize()
sample, sample_dec = tokenized

In [None]:
# Inspect the first example: encoder ids, label ids, decoder input, and the
# label ids decoded back to text with special tokens stripped.
first_input_ids = sample['input_ids'][0]
first_labels = sample['labels'][0]
print(first_input_ids)
print(first_labels)
print(sample_dec[0])
print(data.tokenizer.decode(first_labels, skip_special_tokens=True))

In [None]:
# Pack every example as (encoder ids, decoder inputs, valid length, labels).
# The valid length is the sum of that example's attention mask.
tensors = []
for i in tqdm(range(len(sample_dec))):
    example = (
        sample['input_ids'][i],
        sample_dec[i],
        sample['attention_mask'][i].sum(),
        sample['labels'][i],
    )
    tensors.append(example)

In [None]:
# Train on the first 512 examples only, in batches of 128.
# shuffle=False keeps the batch order identical on every pass.
train_dataloader = DataLoader(
    dataset=tensors[:512],
    batch_size=128,
    shuffle=False,
)

# Peek at the first batch and print the shape of each of its four tensors.
for batch in train_dataloader:
    enc_inputs, dec_inputs, valid_lens, labels = batch
    print(*(t.shape for t in batch), sep='\n')
    break

# Create NMT model

In [None]:
# Hyperparameters shared by the transformer encoder and decoder.
# Constructor argument order:
# (vocab_size, num_hiddens, ffn_num_hiddens, num_heads, num_blks, dropout, use_bias=False)
vocab_size = data.vocab_size
num_hiddens, ffn_hiddens = 256, 64
num_heads, num_blks = 4, 2
dropout = 0.2
arch_args = (vocab_size, num_hiddens, ffn_hiddens, num_heads, num_blks, dropout)

# Transformer encoder/decoder pair (a GRU encoder/decoder can be used instead).
encoder = tfr.TransformerEncoder(*arch_args)
decoder = tfr.TransformerDecoder(*arch_args)

# Wrap both halves in the Seq2Seq model, passing the tokenizer's pad token id
# and the learning rate.
padding_index = data.tokenizer.pad_token_id
lr = 1e-3
model = s2s.Seq2Seq(encoder, decoder, padding_index, lr)

# Train the NMT model

In [None]:
# Training setup: prefer the GPU when CUDA is available, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Number of full passes over the training dataloader.
epochs = 30

# Monitor training loss (wandb logging is not wired in yet; losses are collected in a list and plotted with matplotlib below)

In [None]:
# One loss value is recorded per batch (not per epoch).
losses = []

model.to(device)
model.train()
for epoch in trange(epochs):
    for batch in tqdm(train_dataloader):
        # Move the four batch tensors onto the training device.
        enc_inputs, dec_inputs, valid_lens, labels = (t.to(device) for t in batch)

        # Forward pass; dims 1 and 2 of the output are swapped before it is
        # handed to the model's loss together with the labels.
        Y_hat = model(enc_inputs, dec_inputs, valid_lens)
        loss = model.loss(torch.transpose(Y_hat, 1, 2), labels)

        # Standard optimisation step: clear old gradients, backprop, update.
        model.optimizer.zero_grad()
        loss.backward()
        model.optimizer.step()

        losses.append(loss.item())

In [None]:
# Plot the training loss; `losses` holds one value per batch, so the x-axis
# counts optimisation steps rather than epochs.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel('Training step (batch)')
ax.set_ylabel('Loss')
ax.set_title('Training loss');  # trailing ';' hides the text repr in the cell output

In [None]:
# Small shuffled batches for qualitative inspection of predictions.
# NOTE(review): this reuses tensors[:512], the same slice the training
# dataloader is built from, so the scores are on training data -- confirm
# that is intended.
test_dataloader = DataLoader(
    dataset=tensors[:512],
    batch_size=4,
    shuffle=True,
)

In [None]:
# Translate one batch and print "source => reference" with the BLEU score of
# the model's prediction against the reference.

# The training cell leaves the model in train mode; switch to eval mode so
# dropout is disabled while predicting.
model.eval()

for batch in test_dataloader:

    # Inference only: no gradients are needed.
    with torch.no_grad():
        preds, _ = model.predict_step(batch, device, data.max_length)

    # batch[0] holds the encoder input ids and batch[3] the label ids.
    # Iterate over the rows actually present instead of assuming exactly 4,
    # so a smaller batch does not raise IndexError and a larger batch_size is
    # not silently truncated.
    engs = [data.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True) for ids in batch[0]]
    fras = [data.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True) for ids in batch[3]]

    for en, fr, p in zip(engs, fras, preds):
        # Keep predicted tokens up to (not including) the end-of-sequence token.
        translation = []
        for token in data.tokenizer.convert_ids_to_tokens(p):
            if token == '</s>':
                break
            translation.append(token)
        str_en = data.tokenizer.convert_tokens_to_string(en)
        str_fr = data.tokenizer.convert_tokens_to_string(fr)
        print(f'{str_en} => {str_fr}, bleu,'
              f'{s2s.bleu(translation, fr, k=2):.3f}')

    # Only the first batch is shown.
    break

In [None]:
# Scratch check of the tokenizer round trip on hand-written id sequences:
# ids -> tokens -> string, compared with decode() on the same ids.
fras = [5682, 21, 2137, 19, 6381, 21, 682, 291, 0]
engs = [631, 250, 0, 59513]

eng_tokens = data.tokenizer.convert_ids_to_tokens(engs)
print(data.tokenizer.convert_tokens_to_string(eng_tokens))

fra_tokens = data.tokenizer.convert_ids_to_tokens(fras)
print(fra_tokens)
print(data.tokenizer.decode(fras))
print(data.tokenizer.convert_tokens_to_string(fra_tokens))

In [None]:
# Sanity-check the BLEU helper on a toy pair with k=2.
# NOTE(review): this call passes whitespace-separated strings, whereas the
# prediction cell passes lists of tokens -- confirm which input type
# s2s.bleu expects.
pred_seq = "a b c d e"
label_seq = "a b c e f"
s2s.bleu(pred_seq, label_seq, k=2)