In [None]:
# External library
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
from matplotlib import pyplot as plt

# Local library
from dataset import DataSet
import transformer as tfr
import seq2seq as s2s

# Prepare the translation dataset

In [None]:
# Build the DataSet object and load the raw translation corpus file into it.
# NOTE(review): fra.txt's exact line format is defined by dataset.DataSet.read_file — confirm there.
corpus_path = './data/fra.txt'
data = DataSet()
data.read_file(corpus_path)

In [None]:
# Tokenise the corpus. `sample` is indexed by field name in later cells
# ('input_ids', 'attention_mask', 'labels'); `sample_dec` holds the decoder inputs.
sample, sample_dec = data.tokenize()

# Spot-check: first encoder input ids, then first decoder input, one per line.
print(sample['input_ids'][0], sample_dec[0], sep='\n')

In [None]:
# Pack every example as (input_ids, valid_lens, dec_inputs, labels) so the
# DataLoader's default collate can stack each field into a batch tensor.
tensors = [
    (
        sample['input_ids'][idx],
        # valid_lens: sum of the attention mask row
        # (assumes mask is 1 for real tokens, 0 for padding — confirm in DataSet.tokenize)
        sample['attention_mask'][idx].sum(),
        sample_dec[idx],
        sample['labels'][idx],
    )
    for idx in tqdm(range(len(sample_dec)))
]

In [None]:
# Display the first packed example: (input_ids, valid_lens, dec_inputs, labels).
tensors[0]

In [None]:
# Batch the packed examples for training.
# A dedicated, seeded generator makes the shuffle order reproducible on
# Restart & Run All, independent of how much of torch's global RNG earlier
# cells have consumed.
batch_size = 32
shuffle_seed = 0
train_dataloader = DataLoader(
    tensors,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(shuffle_seed),
)

In [None]:
# Pull a single mini-batch and print the collated shape of each field
# (encoder inputs, valid lengths, decoder inputs, labels), one per line.
for batch in train_dataloader:
    enc_inputs, valid_lens, dec_inputs, labels = batch
    print(enc_inputs.shape, valid_lens.shape, dec_inputs.shape, labels.shape, sep='\n')
    break

# Create the NMT model

In [None]:
# Create the Transformer Seq2Seq model.
# Positional arguments shared by tfr.TransformerEncoder and tfr.TransformerDecoder:
#   (vocab_size, num_hiddens, ffn_num_hiddens, num_heads, num_blks, dropout, use_bias=False)
vocab_size  = data.vocab_size
num_hiddens = 32
ffn_hiddens = 64    # passed as ffn_num_hiddens
num_heads   = 8
num_blks    = 2
dropout     = 0.2
init_seed   = 0     # seed for torch's global RNG, set just before layer construction

# Re-seed immediately before building any layers. Standard torch.nn layers draw
# their initial weights from torch's global RNG, so without this every kernel
# restart would train a differently-initialised model.
torch.manual_seed(init_seed)

# Use transformer encoder/decoder. Can also use GRU encoder/decoder.
encoder = tfr.TransformerEncoder(vocab_size, num_hiddens, ffn_hiddens, num_heads, num_blks, dropout)
decoder = tfr.TransformerDecoder(vocab_size, num_hiddens, ffn_hiddens, num_heads, num_blks, dropout)

# Seq2Seq model.
# NOTE(review): padding_index is presumably used by s2s.Seq2Seq to ignore padding
# in the loss, and lr to build model.optimizer — confirm in seq2seq.
padding_index = data.tokenizer.pad_token_id
lr = 1e-3

model = s2s.Seq2Seq(encoder, decoder, padding_index, lr)

# Train the NMT model

In [None]:
# Training setup: use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Number of full passes over train_dataloader.
epochs = 5

# Train the model (TODO: monitor training with wandb — not integrated yet)

In [None]:
# Train for `epochs` passes over train_dataloader, recording one loss value per optimiser step.
losses = []

model.to(device)
model.train()
for epoch in trange(epochs):
    for batch in tqdm(train_dataloader):
        # Move each field of the mini-batch onto the training device.
        enc_inputs, valid_lens, dec_inputs, labels = [field.to(device) for field in batch]

        Y_hat = model(enc_inputs, dec_inputs, valid_lens)

        # Swap dims 1 and 2 of the predictions before computing the loss.
        # NOTE(review): assumes Y_hat is (batch, seq, vocab) and model.loss is a
        # cross-entropy-style criterion expecting (N, C, T) — confirm in seq2seq.
        loss = model.loss(Y_hat.transpose(1, 2), labels)

        model.optimizer.zero_grad()
        loss.backward()
        model.optimizer.step()

        losses.append(loss.item())

In [None]:
# Plot the recorded training loss: one point per optimiser step, across all epochs.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(xlabel='Training step', ylabel='Loss', title='Training loss')
plt.show()