In [1]:
# Notebook setup.
%load_ext autoreload
# Re-import edited project modules (my_utils, torch_models) automatically before
# executing code, so library edits are picked up without restarting the kernel.
%autoreload 2
# NOTE(review): none of the cells shown here plot anything, and the `notebook`
# (nbagg) backend is not supported in JupyterLab — consider `%matplotlib widget`
# or removing this line if no later cell needs figures.
%matplotlib notebook

In [2]:
from my_utils import Dictionary

n_unique = 10  # number of distinct digit tokens in the toy task

# Each vocabulary starts from its special markers: the source side only has
# '<EOS>', the target side additionally reserves '<BOS>' (passed to the model
# as tgt_BOS). The digit strings '0' .. str(n_unique - 1) are appended to both.
src_dict = Dictionary(['<EOS>'])
tgt_dict = Dictionary(['<BOS>', '<EOS>'])
for token in map(str, range(n_unique)):
    for vocab in (src_dict, tgt_dict):
        vocab.add_word(token)

In [3]:
import random

import numpy as np
import torch

from my_utils.toy_data import invert_seq

# Seed every RNG the toy-data generator could plausibly draw from, so that
# Restart & Run All regenerates the same train/test sets. Seeding torch here
# also makes later torch randomness (e.g. weight initialisation) repeatable
# when the cells are run top to bottom.
# NOTE(review): which RNG `invert_seq` actually uses is not visible from this
# notebook — confirm and drop the seeds that are not needed.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Number of (source, inverted-target) pairs generated for each split.
N_TRAIN, N_TEST = 5000, 100
train = invert_seq(N_TRAIN, n_unique=n_unique)
test = invert_seq(N_TEST, n_unique=n_unique)

In [4]:
import torch
from my_utils import DataLoader
from torch_models.utils import seq2seq

def numericalize(dataset, src_dict, tgt_dict):
    """Convert (source tokens, target tokens) pairs into index lists.

    Args:
        dataset: iterable of ``(src, tgt)`` pairs, each an iterable of tokens.
        src_dict: callable mapping one source token to its index.
        tgt_dict: callable mapping one target token to its index.

    Returns:
        A new list of ``(src_ids, tgt_ids)`` tuples of lists, in the same order
        as ``dataset``.
    """
    def encode(vocab, tokens):
        # Look up every token of one sequence with the given vocabulary.
        return [vocab(token) for token in tokens]

    return [(encode(src_dict, src), encode(tgt_dict, tgt)) for src, tgt in dataset]

# Batches are built for this device (presumably via the seq2seq transform).
# The model created further down is not moved to `device` in any cell shown,
# so switching to torch.device('cuda:0') presumably requires moving the model too.
device = torch.device('cpu')
trans_func = seq2seq(device)

# Encode both splits to index sequences and wrap each in a batching loader:
# 16 examples per training batch, 50 per test batch.
train_loader, test_loader = (
    DataLoader(numericalize(split, src_dict, tgt_dict), batch_size=size, trans_func=trans_func)
    for split, size in ((train, 16), (test, 50))
)

In [6]:
from torch_models import AttnSeq2Seq, Seq2Seq

embed_size = 64
dropout = 0

# Attention seq2seq model: the hidden size is tied to the embedding size, and
# the encoder RNN is a single-layer bidirectional LSTM. The special-token
# indices come from the vocabularies built earlier.
model_config = dict(
    embed_size=embed_size,
    hidden_size=embed_size,
    src_vocab_size=len(src_dict),
    tgt_vocab_size=len(tgt_dict),
    src_EOS=src_dict('<EOS>'),
    tgt_BOS=tgt_dict('<BOS>'),
    tgt_EOS=tgt_dict('<EOS>'),
    num_layers=1,
    bidirectional=True,
    dropout=dropout,
    rnn='lstm',
)
model = AttnSeq2Seq(**model_config)
print(model)

AttnSeq2Seq(
  (encoder): RNNEncoder(
    (embedding): Embedding(12, 64, padding_idx=11)
    (rnn): LSTM(64, 64, bidirectional=True)
  )
  (decoder): RNNEncoder(
    (embedding): Embedding(13, 64, padding_idx=12)
    (rnn): LSTM(64, 64)
  )
  (generator): MLP(
    (fc_out): Linear(in_features=64, out_features=12, bias=True)
    (dropout): Dropout(p=0)
    (criterion): CrossEntropyLoss()
    (activation): Tanh()
  )
  (attn_hidden): Linear(in_features=128, out_features=64, bias=True)
  (attention): DotAttn()
)


In [7]:
%%time
from my_utils import Trainer, EvaluatorSeq, EvaluatorLoss
from my_utils.misc.logging import init_logger
from torch.optim import Adam, SGD

init_logger()

optimizer = Adam(model.parameters())
# Score the held-out set with BLEU after every epoch.
# Alternative evaluator: EvaluatorLoss(model, test_loader).
evaluator = EvaluatorSeq(model, test_loader, measure='BLEU')

trainer = Trainer(model, train_loader)
trainer.train_epoch(
    optimizer,
    max_epoch=3,
    evaluator=evaluator,
    score_monitor=None,
)

[2018-10-18 01:37:32,336 INFO] epoch [1/3]	loss: 0.6721453758510062	
[2018-10-18 01:37:32,411 INFO] Evaluator BLEU: 0.8948559423170824	
[2018-10-18 01:37:37,626 INFO] epoch [2/3]	loss: 0.008922914647833464	
[2018-10-18 01:37:37,697 INFO] Evaluator BLEU: 0.9032311716182105	
[2018-10-18 01:37:42,869 INFO] epoch [3/3]	loss: 0.0023302935984747383	
[2018-10-18 01:37:42,956 INFO] Evaluator BLEU: 0.9186254657926081	


CPU times: user 43.8 s, sys: 586 ms, total: 44.4 s
Wall time: 15.9 s


In [26]:
# Re-score the trained model on the test set (100 examples) as one shuffled
# batch, ten times over — presumably to check that BLEU is stable under
# shuffling.
# NOTE(review): this rebinds `test_loader`, which other cells also read;
# consider a distinct name so re-running earlier cells is not affected.
test_loader = DataLoader(
    numericalize(test, src_dict, tgt_dict),
    batch_size=100,
    trans_func=trans_func,
    shuffle=True,
)
test_evaluator = EvaluatorSeq(model, test_loader, measure='BLEU')
for _ in range(10):
    print(test_evaluator.evaluate())

0.9155832484543639
0.9155832484543639
0.9155832484543639
0.9155832484543639
0.9155832484543639
0.9155832484543639
0.9155832484543639
0.9155832484543639
0.9155832484543639
0.9155832484543639


In [24]:
# Print the first test batch next to the model's predictions for it.
n_show = 100  # examples to print; slicing caps this at the actual batch size

# `next(iter(loader))` fetches the first batch whether the loader's __iter__
# returns itself or a fresh iterator; calling next() on the loader object
# directly only works if the loader happens to implement __next__ itself.
inputs, targets = next(iter(test_loader))
inputs = inputs[:n_show]
targets = targets[:n_show]
generated = model.predict(inputs)

print('======= input ======')
for seq in inputs:
    # Input elements are read with .item(), i.e. presumably 0-d tensors.
    print([src_dict[idx.item()] for idx in seq])
print('======= output ======')
for seq in generated[:n_show]:
    print([tgt_dict[idx] for idx in seq])

['1', '9', '3']
['2', '5', '7', '1']
['5', '8', '6', '6', '9']
['2', '5', '4', '4', '1']
['8', '2', '6', '3']
['0', '2', '9']
['4', '5', '9', '5', '8']
['4', '9', '9', '3', '8']
['0', '5', '2', '3', '7']
['0', '3', '7', '4', '7']
['7', '5', '5', '0']
['6', '4', '2']
['9', '8', '8', '7']
['5', '0', '0']
['9', '2', '4', '6']
['2', '6', '7', '7']
['5', '2', '4', '6', '2']
['5', '9', '9', '0', '8']
['0', '5', '9', '8']
['6', '5', '9']
['3', '1', '8', '1', '5']
['6', '7', '4', '8']
['9', '6', '2', '7', '8']
['2', '5', '3', '6']
['0', '1', '4', '7', '7']
['6', '5', '1', '2']
['2', '2', '0', '2', '9']
['5', '9', '5']
['9', '8', '7', '9', '9']
['1', '1', '5']
['8', '4', '0', '9', '1']
['2', '5', '2']
['0', '0', '1']
['8', '2', '7']
['4', '5', '3', '0']
['3', '6', '4', '3', '6']
['3', '2', '4', '7']
['9', '0', '5']
['5', '0', '4']
['5', '3', '0']
['8', '6', '7']
['5', '4', '4']
['7', '8', '9', '9']
['6', '7', '0', '4']
['4', '7', '0', '7']
['3', '1', '1']
['2', '9', '6']
['3', '2', '9', '6', '1