In [1]:
import torch
from torch import nn
from pathlib import Path
import numpy as np
from collections import namedtuple

torch.backends.cudnn.benchmark = True
torch.manual_seed(42)
torch.cuda.manual_seed(42)

import sys
sys.path.append("..")

from nmt.models import NMT
from nmt.datasets import read_corpus, batch_iter, Vocab
import tqdm

In [2]:
data_loc = Path("..") / "nmt" / "datasets" / "data"
en_es_data_loc = data_loc / "en_es_data"
train_data_src_path = en_es_data_loc / "train_tiny.es"
train_data_tgt_path = en_es_data_loc / "train_tiny.en"
dev_data_src_path = en_es_data_loc / "dev_tiny.es"
dev_data_tgt_path = en_es_data_loc / "dev_tiny.en"
vocab_path = data_loc / "vocab_tiny_q2.json"

In [3]:
train_src = read_corpus(train_data_src_path)
train_tgt = read_corpus(train_data_tgt_path, is_target=True)
vocab = Vocab.load(vocab_path)

In [4]:
len(vocab.src), len(vocab.tgt)

(26, 32)

In [5]:
Hypothesis = namedtuple('Hypothesis', ['value', 'score'])

In [6]:
valid_src = read_corpus(dev_data_src_path)
valid_tgt = read_corpus(dev_data_tgt_path, is_target=True)

In [7]:
BATCH_SIZE=2
MAX_EPOCH=201
SEED=42
EMBEDDING_SIZE=256
HIDDEN_SIZE=256
GRAD_CLIP=5.0
UNIFORM_INIT=0.1
USE_CHAR_DECODER=True
LEARNING_RATE=0.001

In [8]:
model = NMT(
    vocab=vocab,
    embedding_dim=EMBEDDING_SIZE,
    hidden_size=HIDDEN_SIZE,
    use_char_decoder=True
)

In [9]:
model = model.train()

In [10]:
uniform_init = UNIFORM_INIT
if np.abs(uniform_init) > 0.:
    print('uniformly initialize parameters [-%f, +%f]' %
            (uniform_init, uniform_init), file=sys.stderr)
    for p in model.parameters():
        p.data.uniform_(-uniform_init, uniform_init)

uniformly initialize parameters [-0.100000, +0.100000]


In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [12]:
example = [train_src[0]] ## Source sent

In [13]:
beam_size = 5

In [14]:
example_char_tensor = model.vocab.src.to_tensor(example, tokens=False)

In [15]:
src_encode, dec_state = model.encoder(example_char_tensor, [len(x) for x in example])

In [16]:
hypotheses = [['<s>']]

In [17]:
completed_hypotheses = []

In [18]:
hyp_scores = torch.zeros(
    len(hypotheses), dtype=torch.float
)
print(hyp_scores.size())

torch.Size([1])


In [19]:
num_hyps = len(hypotheses)

In [20]:
init_attention = torch.zeros(1, model.hidden_size)

In [21]:
end_of_sent = model.vocab.tgt.end_token_idx

In [22]:
#### Iterating through 1 timestep

In [23]:
o_prev = torch.zeros(1, model.hidden_size, device=model.device)

In [24]:
y_t = model.vocab.tgt.to_tensor(
             list([hyp[-1]] for hyp in hypotheses),
             tokens=False, device=model.device
        )


In [25]:
y_t.shape

torch.Size([1, 1, 21])

In [26]:
o_prev, dec_state, _ = model.decoder(y_t, src_encode, dec_state, o_prev, None)

In [27]:
log_p_t = model.generator(o_prev)

In [28]:
log_p_t.shape

torch.Size([1, 32])

In [29]:
log_p_t

tensor([[-3.3516, -3.5049, -3.4377, -3.5354, -3.3738, -3.4165, -3.6984, -3.4240,
         -3.4034, -3.3990, -3.4905, -3.4608, -3.3913, -3.5046, -3.5579, -3.4502,
         -3.4206, -3.3932, -3.5450, -3.6357, -3.4550, -3.5125, -3.3933, -3.5584,
         -3.3560, -3.5021, -3.4387, -3.3057, -3.5281, -3.6132, -3.4817, -3.4801]],
       grad_fn=<LogSoftmaxBackward>)

In [30]:
live_hyp_num = beam_size - len(completed_hypotheses)

In [31]:
continuing_hyp_scores = (hyp_scores.unsqueeze(
                1).expand_as(log_p_t) + log_p_t).view(-1)
continuing_hyp_scores.shape

torch.Size([32])

In [32]:
top_cand_hyp_scores, top_cand_hyp_pos = torch.topk(
                continuing_hyp_scores, k=live_hyp_num)

In [33]:
top_cand_hyp_scores, top_cand_hyp_pos

(tensor([-3.3057, -3.3516, -3.3560, -3.3738, -3.3913], grad_fn=<TopkBackward>),
 tensor([27,  0, 24,  4, 12]))

In [34]:
prev_hyp_ids = top_cand_hyp_pos / len(model.vocab.tgt)
hyp_word_ids = top_cand_hyp_pos % len(model.vocab.tgt)
print(prev_hyp_ids, hyp_word_ids)

tensor([0, 0, 0, 0, 0]) tensor([27,  0, 24,  4, 12])


In [35]:
new_hypotheses = []
live_hyp_ids = []
new_hyp_scores = []

In [36]:
decoderStatesForUNKsHere = []

In [37]:
for prev_hyp_id, hyp_word_id, cand_new_hyp_score in zip(prev_hyp_ids, hyp_word_ids, top_cand_hyp_scores):
    prev_hyp_id = prev_hyp_id.item()
    hyp_word_id = hyp_word_id.item()
    cand_new_hyp_score = cand_new_hyp_score.item()

    hyp_word = model.vocab.tgt.to_tokens(hyp_word_id)

    # Record output layer in case UNK was generated
    if hyp_word == "<unk>":
        hyp_word = "<unk>"+str(len(decoderStatesForUNKsHere))
        decoderStatesForUNKsHere.append(o_prev[prev_hyp_id])

    new_hyp_sent = hypotheses[prev_hyp_id] + [hyp_word]
    if hyp_word == '</s>':
        completed_hypotheses.append(Hypothesis(value=new_hyp_sent[1:-1],
                                                score=cand_new_hyp_score))
    else:
        new_hypotheses.append(new_hyp_sent)
        live_hyp_ids.append(prev_hyp_id)
        new_hyp_scores.append(cand_new_hyp_score)

In [38]:
print(new_hypotheses)
print(live_hyp_ids)
print(new_hyp_scores)
print(completed_hypotheses)

[['<s>', 'what'], ['<s>', '<pad>'], ['<s>', 'sob)'], ['<s>', 'to'], ['<s>', 'this']]
[0, 0, 0, 0, 0]
[-3.3057379722595215, -3.3516037464141846, -3.3559703826904297, -3.373790979385376, -3.3912715911865234]
[]


In [39]:
if len(decoderStatesForUNKsHere) > 0 and model.char_decoder is not None:  # decode UNKs
    decoderStatesForUNKsHere = torch.stack(
        decoderStatesForUNKsHere, dim=0)
    decodedWords = model.greedy_char_decode((decoderStatesForUNKsHere.unsqueeze(
        0), decoderStatesForUNKsHere.unsqueeze(0)), max_length=21)
    assert len(decodedWords) == decoderStatesForUNKsHere.size()[
        0], "Incorrect number of decoded words"
    for hyp in new_hypotheses:
        if hyp[-1].startswith("<unk>"):
            hyp[-1] = decodedWords[int(hyp[-1][5:])]  # [:-1]


In [40]:
live_hyp_ids = torch.tensor(
    live_hyp_ids,
    dtype=torch.long,
    device=model.device
)

In [41]:
o_prev[live_hyp_ids].shape

torch.Size([5, 256])

In [42]:
hypotheses = new_hypotheses
hyp_scores = torch.tensor(
    new_hyp_scores, dtype=torch.float, device=model.device)


In [43]:
if len(completed_hypotheses) == 0:
    completed_hypotheses.append(Hypothesis(value=hypotheses[0][1:],
                                            score=hyp_scores[0].item()))

In [44]:
completed_hypotheses.sort(key=lambda hyp: hyp.score, reverse=True)

In [45]:
completed_hypotheses

[Hypothesis(value=['what'], score=-3.3057379722595215)]

In [46]:
hypotheses = new_hypotheses

In [47]:
hypotheses

[['<s>', 'what'],
 ['<s>', '<pad>'],
 ['<s>', 'sob)'],
 ['<s>', 'to'],
 ['<s>', 'this']]

In [48]:
%%time
for epoch in range(MAX_EPOCH):
    cum_loss = 0
    for i, (src_sents, tgt_sents) in enumerate(batch_iter((train_src, train_tgt), batch_size=BATCH_SIZE, shuffle=True)):
        optimizer.zero_grad()
        batch_size = len(src_sents)

        batch_loss = -model(src_sents, tgt_sents).sum()
        batch_loss /= batch_size
        cum_loss += batch_loss
        batch_loss.backward()

         # clip gradient
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
    cum_loss /= len(train_src)
    print(f"Epoch: {str(epoch).zfill(3)} - Cumulative loss: {cum_loss}")

Epoch: 000 - Cumulative loss: 76.07613372802734
Epoch: 001 - Cumulative loss: 72.86973571777344
Epoch: 002 - Cumulative loss: 70.12184143066406
Epoch: 003 - Cumulative loss: 68.92034149169922
Epoch: 004 - Cumulative loss: 66.57279968261719
Epoch: 005 - Cumulative loss: 63.73548126220703
Epoch: 006 - Cumulative loss: 60.94084930419922
Epoch: 007 - Cumulative loss: 58.18658447265625
Epoch: 008 - Cumulative loss: 54.783241271972656
Epoch: 009 - Cumulative loss: 51.839515686035156
Epoch: 010 - Cumulative loss: 49.84929656982422
Epoch: 011 - Cumulative loss: 48.63710403442383
Epoch: 012 - Cumulative loss: 47.67711639404297
Epoch: 013 - Cumulative loss: 46.95255661010742
Epoch: 014 - Cumulative loss: 46.43537902832031
Epoch: 015 - Cumulative loss: 45.718360900878906
Epoch: 016 - Cumulative loss: 45.62044906616211
Epoch: 017 - Cumulative loss: 45.077247619628906
Epoch: 018 - Cumulative loss: 45.197303771972656
Epoch: 019 - Cumulative loss: 44.873931884765625
Epoch: 020 - Cumulative loss: 44.4

In [49]:
from nmt.scripts import beam_search_decoder

In [50]:
hypothesis = beam_search_decoder(
    model=model,
    src_sent=train_src[0]
)

In [51]:
hypothesis

[Hypothesis(value=['anthe', 'quedado', 'conmovido', 'por', 'confer', 'esta', 'y', 'conferencia,', 'deseo', 'a', 'agradece', 'todos', 'a', 'ustados', 'de', 'denado', 'de', 'de', 'lo', 'lo', 'que', 'que', 'que', 'que', 'que', 'decir', 'decir', 'lo', 'noche.', 'noche.', 'noche', 'noche.', 'que', 'noche.', 'ncenthe', 'nicer', 'coment', 'naca', 'niche', 'noche.', 'que', 'niche', 'niche.', 'niche.', 'niche', 'noche.', 'niche', 'noche.', 'nicer', 'noche.', 'noche.', 'noche.', 'noche.', 'noche.', 'noche.', 'niche.', 'noche.', 'niche.', 'noche.', 'nicer', 'niche.', 'nicer', 'noche.', 'nicer', 'niche.', 'niche.', 'niche.', 'niche.', 'noche.', 'noche.'], score=-1.1814473867416382)]