In [34]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [35]:
from my_utils import Dictionary

n_unique = 10

src_dict = Dictionary(['<EOS>'])
tgt_dict = Dictionary(['<BOS>', '<EOS>'])
for n in range(n_unique):
    src_dict.add_word(str(n))
    tgt_dict.add_word(str(n))

In [36]:
from my_utils.toy_data import invert_seq
train = invert_seq(5000, n_unique=n_unique)
test = invert_seq(100, n_unique=n_unique)

In [37]:
import torch
from my_utils import DataLoader
from torch_models.utils import seq2seq

def numericalize(dataset, src_dict, tgt_dict):
    numericalized = [([src_dict(s) for s in src], [tgt_dict(t) for t in tgt]) for src, tgt in dataset]
    return numericalized

# device = 'cuda:0'
device = torch.device('cpu')
trans_func = seq2seq(device)

train_loader = DataLoader(numericalize(train, src_dict, tgt_dict), batch_size=64, trans_func=trans_func)
test_loader = DataLoader(numericalize(test, src_dict, tgt_dict), batch_size=50, trans_func=trans_func)
print(train_loader)

DataLoader(
	datasize: 5000
	batchsize: 64
	n_batches: 79
	trans_func: seq2seq
	device: cpu
)


In [38]:
from torch_models import AttnSeq2Seq, Seq2Seq

embed_size=64
dropout=0.3
model = Seq2Seq(embed_size=embed_size, hidden_size=embed_size, src_vocab_size=len(src_dict), tgt_vocab_size=len(tgt_dict),
                    src_EOS=src_dict('<EOS>'), tgt_BOS=tgt_dict('<BOS>'), tgt_EOS=tgt_dict('<EOS>'),
                    num_layers=1, bidirectional=True, dropout=dropout, rnn='LSTM')
print(model)

Seq2Seq(
  (encoder): RNNEncoder(
    (embedding): Embedding(12, 64, padding_idx=11)
    (rnn): LSTM(64, 64, dropout=0.3, bidirectional=True)
  )
  (decoder): RNNEncoder(
    (embedding): Embedding(13, 64, padding_idx=12)
    (rnn): LSTM(64, 64, dropout=0.3)
  )
  (generator): MLP(
    (fc_out): Linear(in_features=64, out_features=12, bias=True)
    (dropout): Dropout(p=0.3)
    (criterion): CrossEntropyLoss()
    (activation): Tanh()
  )
)


  "num_layers={}".format(dropout, num_layers))


In [50]:
%%time
from my_utils import Trainer, EvaluatorSeq, EvaluatorLoss
from my_utils.misc.logging import init_logger
from torch.optim import Adam, SGD

init_logger()

optimizer = Adam(model.parameters(), lr=0.001)
# optimizer = SGD(model.parameters(), lr=0.01)
evaluator = EvaluatorSeq(model, test_loader, measure='accuracy')
# evaluator = EvaluatorLoss(model, test_loader)

trainer = Trainer(model, train_loader)
trainer.train_epoch(optimizer, max_epoch=5,
              evaluator=evaluator, score_monitor=None)

[2018-10-23 19:04:22,118 INFO] steps [79/79]	loss: 0.06758517437154733	


[[5, 3, 11, 11], [8, 9, 10, 8], [5, 6, 9, 8, 7], [7, 4, 9, 9, 6], [7, 8, 7, 2, 6], [8, 2, 11, 9], [7, 8, 7, 7], [11, 4, 4, 2, 2], [3, 11, 3], [5, 4, 2, 2], [6, 9, 2, 5], [2, 9, 10, 4], [2, 9, 9, 9, 3], [2, 10, 7, 9], [7, 10, 2, 8, 11], [3, 5, 4, 3, 6], [10, 5, 11, 11, 8], [4, 11, 8, 8], [5, 5, 5, 7, 11], [4, 6, 9, 3], [5, 7, 7, 2, 11], [9, 10, 6, 8], [2, 9, 9, 2], [9, 5, 3, 7], [2, 6, 4], [5, 8, 3, 3], [2, 4, 10], [11, 8, 9], [3, 10, 10, 7], [8, 6, 5], [8, 5, 6], [7, 8, 9], [5, 8, 7], [8, 4, 9], [9, 3, 8, 11], [5, 9, 2, 8, 4], [2, 6, 5, 8], [6, 5, 5, 7, 4], [10, 5, 8, 11], [5, 10, 10], [10, 9, 6], [10, 7, 8, 11, 7], [7, 8, 4, 2, 8], [5, 11, 9], [8, 8, 6], [3, 4, 9, 8], [9, 3, 8, 8, 6], [9, 3, 10, 6], [10, 7, 4, 4, 8], [6, 2, 10]]


[2018-10-23 19:04:22,889 INFO] Evaluator accuracy: 0.9950980392156863	


[[10, 2, 5, 9, 11], [4, 4, 3], [3, 2, 7, 11], [4, 8, 2, 11, 3], [8, 10, 7, 5], [4, 4, 6, 6], [3, 6, 2, 5], [6, 5, 4], [9, 7, 10, 6, 4], [6, 6, 7], [4, 11, 7, 11], [9, 2, 7, 4], [7, 3, 7, 5], [8, 2, 4, 4, 2], [9, 5, 9, 4, 4], [5, 4, 10], [2, 8, 10, 11, 6], [7, 2, 6, 9, 11], [6, 11, 5], [3, 6, 5, 5, 9], [7, 11, 2, 8], [7, 8, 4], [8, 2, 7, 6, 3], [10, 6, 11, 2, 7], [8, 6, 6, 4], [9, 4, 6], [7, 10, 2, 7, 7], [11, 6, 11, 6], [3, 10, 6], [9, 9, 10], [9, 6, 7, 5], [8, 11, 8, 3, 3], [5, 6, 3, 6], [5, 6, 6, 11], [6, 4, 6, 11], [3, 4, 5, 10, 7], [11, 8, 7, 5, 11], [6, 2, 3, 5, 8], [2, 7, 4, 8], [11, 7, 7, 7, 11], [8, 4, 8, 4, 7], [11, 2, 3, 2], [4, 3, 6, 3], [2, 7, 11], [5, 11, 2], [6, 4, 7], [3, 2, 5, 10], [7, 10, 7], [4, 11, 10, 5, 7], [2, 2, 3, 9, 2]]


[2018-10-23 19:04:25,179 INFO] steps [79/79]	loss: 0.04027925794826278	


[[5, 3, 11, 11], [8, 9, 10, 8], [5, 6, 9, 8, 7], [7, 4, 9, 9, 6], [7, 8, 7, 2, 6], [8, 2, 11, 9], [7, 8, 7, 7], [11, 4, 4, 2, 2], [3, 11, 3], [5, 4, 2, 2], [6, 9, 2, 5], [2, 9, 10, 4], [2, 9, 9, 9, 3], [2, 10, 7, 9], [7, 10, 2, 8, 11], [3, 5, 4, 3, 6], [10, 5, 11, 11, 8], [4, 11, 8, 8], [5, 5, 5, 7, 11], [4, 6, 9, 3], [5, 7, 7, 2, 11], [9, 10, 6, 8], [2, 9, 9, 2], [9, 5, 3, 7], [2, 6, 4], [5, 8, 3, 3], [2, 4, 10], [11, 8, 9], [3, 10, 10, 7], [8, 6, 5], [8, 5, 6], [7, 8, 9], [5, 8, 7], [8, 4, 9], [9, 3, 8, 11], [5, 9, 2, 8, 4], [2, 6, 5, 8], [6, 5, 5, 7, 4], [10, 5, 8, 11], [5, 10, 10], [10, 9, 6], [10, 7, 8, 11, 7], [7, 8, 4, 2, 8], [5, 11, 9], [8, 8, 6], [3, 4, 9, 8], [9, 3, 8, 8, 6], [9, 3, 10, 6], [10, 7, 4, 4, 8], [6, 2, 10]]


[2018-10-23 19:04:26,071 INFO] Evaluator accuracy: 0.9950980392156863	


[[10, 2, 5, 9, 11], [4, 4, 3], [3, 2, 7, 11], [4, 8, 2, 11, 3], [8, 10, 7, 5], [4, 4, 6, 6], [3, 6, 2, 5], [6, 5, 4], [9, 7, 10, 6, 4], [6, 6, 7], [4, 11, 7, 11], [9, 2, 7, 4], [7, 3, 7, 5], [8, 2, 4, 4, 2], [9, 5, 9, 4, 4], [5, 4, 10], [2, 8, 10, 11, 6], [7, 2, 6, 9, 11], [6, 11, 5], [3, 6, 5, 5, 9], [7, 11, 2, 8], [7, 8, 4], [8, 2, 7, 6, 3], [10, 6, 11, 2, 7], [8, 6, 6, 4], [9, 4, 6], [7, 10, 2, 7, 7], [11, 6, 11, 6], [3, 10, 6], [9, 9, 10], [9, 6, 7, 5], [8, 11, 8, 3, 3], [5, 6, 3, 6], [5, 6, 6, 11], [6, 4, 6, 11], [3, 4, 5, 10, 7], [11, 8, 7, 5, 11], [6, 2, 3, 5, 8], [2, 7, 4, 8], [11, 7, 7, 7, 11], [8, 4, 8, 4, 7], [11, 2, 3, 2], [4, 3, 6, 3], [2, 7, 11], [5, 11, 2], [6, 4, 7], [3, 2, 5, 10], [7, 10, 7], [4, 11, 10, 5, 7], [2, 2, 3, 9, 2]]


[2018-10-23 19:04:28,696 INFO] steps [79/79]	loss: 0.03252100456458858	


[[5, 3, 11, 11], [8, 9, 10, 8], [5, 6, 9, 8, 7], [7, 4, 9, 9, 6], [7, 8, 7, 2, 6], [8, 2, 11, 9], [7, 8, 7, 7], [11, 4, 4, 2, 2], [3, 11, 3], [5, 4, 2, 2], [6, 9, 2, 5], [2, 9, 10, 4], [2, 9, 9, 9, 3], [2, 10, 7, 9], [7, 10, 2, 8, 11], [3, 5, 4, 3, 6], [10, 5, 11, 11, 8], [4, 11, 8, 8], [5, 5, 5, 7, 11], [4, 6, 9, 3], [5, 7, 7, 2, 11], [9, 10, 6, 8], [2, 9, 9, 2], [9, 5, 3, 7], [2, 6, 4], [5, 8, 3, 3], [2, 4, 10], [11, 8, 9], [3, 10, 10, 7], [8, 6, 5], [8, 5, 6], [7, 8, 9], [5, 8, 7], [8, 4, 9], [9, 3, 8, 11], [5, 9, 2, 8, 4], [2, 6, 5, 8], [6, 5, 5, 7, 4], [10, 5, 8, 11], [5, 10, 10], [10, 9, 6], [10, 7, 8, 11, 7], [7, 8, 4, 2, 8], [5, 11, 9], [8, 8, 6], [3, 4, 9, 8], [9, 3, 8, 8, 6], [9, 3, 10, 6], [10, 7, 4, 4, 8], [6, 2, 10]]


[2018-10-23 19:04:29,620 INFO] Evaluator accuracy: 1.0	


[[10, 2, 5, 9, 11], [4, 4, 3], [3, 2, 7, 11], [4, 8, 2, 11, 3], [8, 10, 7, 5], [4, 4, 6, 6], [3, 6, 2, 5], [6, 5, 4], [9, 7, 10, 6, 4], [6, 6, 7], [4, 11, 7, 11], [9, 2, 7, 4], [7, 3, 7, 5], [8, 2, 4, 4, 2], [9, 5, 9, 4, 4], [5, 4, 10], [2, 8, 10, 11, 6], [7, 2, 6, 9, 11], [6, 11, 5], [3, 6, 5, 5, 9], [7, 11, 2, 8], [7, 8, 4], [2, 8, 7, 6, 3], [10, 6, 11, 2, 7], [8, 6, 6, 4], [9, 4, 6], [7, 10, 2, 7, 7], [11, 6, 11, 6], [3, 10, 6], [9, 9, 10], [9, 6, 7, 5], [8, 11, 8, 3, 3], [5, 6, 3, 6], [5, 6, 6, 11], [6, 4, 6, 11], [3, 4, 5, 10, 7], [11, 8, 7, 5, 11], [6, 2, 3, 5, 8], [2, 7, 4, 8], [11, 7, 7, 7, 11], [8, 4, 8, 4, 7], [11, 2, 3, 2], [4, 3, 6, 3], [2, 7, 11], [5, 11, 2], [6, 4, 7], [3, 2, 5, 10], [7, 10, 7], [4, 11, 10, 5, 7], [2, 2, 3, 9, 2]]


KeyboardInterrupt: 

In [47]:
for i in range(1, 10):
    model.beam_width = 3
    print(evaluator.evaluate())

[0, 1, 2]
[0, 0, 1]
[0, 1, 0]
[0, 1, 2]
[0, 1, 2]
[0, 0, 0]
[2, 1, 0]
[2, 0, 1]
[0, 1, 2]
[0, 1, 0]
[2, 0, 0]
[1, 0, 2]
[2, 1, 0]
[0, 1, 2]
[0, 1, 0]
[0, 1, 0]
[2, 0, 1]
[0, 1, 0]
[0, 1, 2]
[0, 0, 1]
[0, 0, 0]
[0, 1, 0]
[0, 2, 0]
[0, 1, 2]
[0, 0, 0]
[0, 2, 1]
[2, 1, 0]
[0, 1, 2]
[0, 0, 0]
[0, 0, 1]
[0, 1, 2]
[0, 1, 2]
[1, 0, 2]
[0, 1, 2]
[0, 1, 0]
[0, 1]
[0, 1, 2]
[0, 0, 1]
[0, 1, 2]
[0, 1, 2]
[0, 2, 0]
[1, 0, 2]
[0, 1, 2]
[0, 1, 0]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[2, 0, 0]
[2, 0, 1]
[1, 0, 2]
[0, 1, 2]
[0, 0, 1]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 0]
[1, 1, 2]
[2, 0, 1]
[0, 1, 2]
[0, 0, 0]
[0, 0, 1]
[0, 0, 1]
[0, 1, 0]
[0, 1, 2]
[0, 0, 1]
[0, 0, 1]
[1, 0, 0]
[0, 1, 1]
[0, 1, 2]
[0, 1, 0]
[0, 1, 2]
[0, 1, 2]
[1, 0]
[0, 1, 2]
[0, 0, 1]
[1, 0, 2]
[0, 0, 2]
[0, 1, 2]
[0, 1, 0]
[0, 0, 2]
[0, 0, 1]
[0, 1, 0]
[0, 1, 2]
[0, 1, 2]
[1, 0, 2]
[1, 2, 0]
[0, 1, 2]
[0, 0, 0]
[0, 1, 2]
[0, 1, 2]
[1, 0, 2]
[0, 1, 2]
[0, 2, 0]
[2, 0, 1]
[1, 2, 0]
[0, 1, 2]
[0, 1, 0]
[0, 1, 1]
[0, 1, 2]
[0, 1,

[0, 1, 2]
[0, 0, 1]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 0]
[1, 1, 2]
[2, 0, 1]
[0, 1, 2]
[0, 0, 0]
[0, 0, 1]
[0, 0, 1]
[0, 1, 0]
[0, 1, 2]
[0, 0, 1]
[0, 0, 1]
[1, 0, 0]
[0, 1, 1]
[0, 1, 2]
[0, 1, 0]
[0, 1, 2]
[0, 1, 2]
[1, 0]
[0, 1, 2]
[0, 0, 1]
[1, 0, 2]
[0, 0, 2]
[0, 1, 2]
[0, 1, 0]
[0, 0, 2]
[0, 0, 1]
[0, 1, 0]
[0, 1, 2]
[0, 1, 2]
[1, 0, 2]
[1, 2, 0]
[0, 1, 2]
[0, 0, 0]
[0, 1, 2]
[0, 1, 2]
[1, 0, 2]
[0, 1, 2]
[0, 2, 0]
[2, 0, 1]
[1, 2, 0]
[0, 1, 2]
[0, 1, 0]
[0, 1, 1]
[0, 1, 2]
[0, 1, 2]
[1, 0, 0]
[0, 1, 0]
[2, 0, 1]
[0, 1, 2]
[0, 1, 2]
[0, 2, 1]
[0, 1, 2]
[0, 1, 2]
[0, 2, 1]
[0, 0]
[0, 1, 2]
[0, 0, 0]
[0, 1, 2]
[0, 1, 2]
[1, 1, 0]
[0, 1, 2]
[0, 1, 2]
[0, 0, 1]
[1, 0, 2]
[0, 2, 2]
[0, 1, 2]
[0, 2, 1]
[0, 2, 1]
[0, 1, 2]
[0, 0, 1]
[1, 2, 0]
[0, 1, 2]
[0, 1, 2]
[0, 1, 0]
[0, 1, 2]
[0, 0, 1]
[1, 0, 2]
[0, 1, 2]
[0, 0, 1]
[0, 2, 1]
[0, 1, 2]
[0, 0, 0]
[2, 1, 0]
[0, 2, 0]
[0, 1, 2]
[0, 2, 0]
[0, 1, 2]
[1, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 0]
[0, 1, 0]
[0, 2, 0]
[0, 1, 2]
[1, 0,

[1, 0, 2]
[0, 2, 2]
[0, 1, 2]
[0, 2, 1]
[0, 2, 1]
[0, 1, 2]
[0, 0, 1]
[1, 2, 0]
[0, 1, 2]
[0, 1, 2]
[0, 1, 0]
[0, 1, 2]
[0, 0, 1]
[1, 0, 2]
[0, 1, 2]
[0, 0, 1]
[0, 2, 1]
[0, 1, 2]
[0, 0, 0]
[2, 1, 0]
[0, 2, 0]
[0, 1, 2]
[0, 2, 0]
[0, 1, 2]
[1, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 0]
[0, 1, 0]
[0, 2, 0]
[0, 1, 2]
[1, 0, 0]
[0, 1, 2]
[0, 0, 1]
[1, 0, 0]
[0, 1, 2]
[0, 1, 0]
[0, 0, 1]
[1, 0, 0]
[0, 1, 2]
[0, 0, 1]
[0, 0, 1]
[0, 1, 2]
[0, 0, 1]
[0, 1, 0]
[0, 1, 2]
[0, 2, 1]
[0, 1, 0]
[0, 1, 2]
[1, 0, 0]
[0, 1, 2]
[0, 1, 0]
[0, 1, 0]
[2, 0, 1]
[0, 2, 1]
[0, 1, 2]
[1, 0, 0]
[1, 2, 0]
[0, 1, 2]
[0, 0, 0]
[0, 1, 0]
[0, 1, 2]
[1, 0, 0]
[2, 1, 0]
[2, 1, 0]
[0, 1, 2]
[0, 1, 2]
[2, 0, 0]
[2, 1, 0]
[0, 2, 1]
[0, 1, 2]
[0, 1, 0]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 0]
[0, 1, 2]
[1, 2, 0]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[1, 0, 0]
[0, 0, 1]
[0, 2, 1]
[0, 2, 1]
[0, 1, 2]
[0, 0, 1]
[0, 0, 1]
[0, 1, 2]
[0, 0, 0]
[0, 1, 0]
[0, 2, 0]
[0, 1, 2]
[0, 1, 0]
[0, 2, 1]
[0, 1, 2]
[2, 0, 1]
[0, 1, 2]


[0, 1, 2]
[0, 1, 2]
[0, 1, 0]
[0, 1, 2]
[1, 2, 0]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[1, 0, 0]
[0, 0, 1]
[0, 2, 1]
[0, 2, 1]
[0, 1, 2]
[0, 0, 1]
[0, 0, 1]
[0, 1, 2]
[0, 0, 0]
[0, 1, 0]
[0, 2, 0]
[0, 1, 2]
[0, 1, 0]
[0, 2, 1]
[0, 1, 2]
[2, 0, 1]
[0, 1, 2]
[0, 0, 0]
[1, 0, 2]
[2, 0, 1]
[0, 1, 2]
[0, 1, 1]
[0, 1, 0]
[0, 2, 1]
[0, 1, 2]
[1, 0, 2]
[1, 0, 0]
[0, 0, 2]
[0, 1, 2]
[0, 1, 0]
[1, 0, 2]
[0, 1, 2]
[0, 0, 1]
[0, 0, 0]
[0, 2, 1]
[0, 0, 2]
[0, 1, 2]
[0, 0, 1]
[0, 0, 1]
[0, 1, 2]
[1, 0, 2]
[0, 2, 1]
[0, 1, 1]
[0, 1, 2]
[0, 2, 0]
[2, 0, 0]
[2, 0, 1]
[0, 1, 2]
[0, 0, 1]
[0, 0, 0]
[0, 0, 1]
[0, 1, 2]
[1, 0, 0]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 0, 1]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 1]
[0, 2, 1]
[0, 1, 2]
[0, 1, 0]
[0, 1, 2]
[0, 2, 1]
[0, 2, 1]
[0, 1, 2]
[1, 0, 0]
[1, 0, 0]
[1, 0, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 0]
[0, 2, 0]
[0, 1, 2]
[1, 0, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 0]
[0, 0, 0]
[2, 1, 0]
[0, 1, 2]
[1, 0, 0]
[1, 0, 2]
[0, 1, 2]


[0, 1, 0]
[0, 0, 0]
[2, 1, 0]
[0, 1, 2]
[1, 0, 0]
[1, 0, 2]
[0, 1, 2]
[1, 1, 0]
[2, 0, 0]
[1, 0, 2]
[1, 0, 0]
[0, 1, 2]
[0, 0, 2]
[1, 0, 0]
[0, 1, 2]
[0, 0, 1]
[0, 1, 2]
[0, 0, 2]
[1, 0, 2]
[1, 0, 0]
[0, 1, 2]
[0, 0, 1]
[0, 1, 2]
[0, 1, 2]
[1, 0, 0]
[0, 1, 2]
[0, 0, 1]
[0, 0, 1]
[0, 1, 2]
[0, 1, 0]
[0, 0, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 1]
[0, 1, 2]
[0, 1, 2]
[0, 0, 1]
[0, 0, 1]
[0, 1, 2]
[0, 0, 1]
[0, 1, 2]
[0, 2, 0]
[0, 1, 2]
[0, 0, 0]
[0, 1, 2]
[0, 1, 0]
[0, 2, 0]
[0, 1, 2]
[0, 0, 1]
[0, 0, 1]
[0, 2, 1]
[0, 1, 2]
[0, 0, 1]
[0, 1, 2]
[0, 1]
[0, 1, 2]
[0, 0, 1]
[0, 0, 1]
[0, 0, 0]
[0, 1, 2]
[0, 1, 0]
[2, 0, 0]
[0, 1, 0]
[0, 1, 2]
[0, 1, 2]
[0, 1, 0]
[0, 1, 0]
[1, 0, 2]
[0, 1, 2]
[0, 1, 2]
[2, 1, 0]
[1, 0, 2]
[0, 0, 1]
[0, 0, 1]
[0, 1, 2]
[0, 0, 2]
[0, 0, 1]
[1, 0, 0]
[0, 1, 2]
[0, 0, 1]
[0, 1, 0]
[0, 1, 2]
[0, 1, 2]
[0, 1, 2]
[0, 1, 0]
[0, 0, 2]
[0, 2, 1]
[0, 2, 1]
[0, 1, 2]
[0, 0, 2]
[0, 0, 1]
[0, 2, 1]
[0, 1, 2]
[0, 0, 1]
[0, 0, 0]
[0, 1, 2]
[0, 1, 2]
[0, 0, 1]
[0, 1, 2]
[0, 1, 2]
[0,

In [43]:
torch.tensor(4).item()

4

In [9]:
iter(train_loader)
l = 10
inputs, targets = next(train_loader)
inputs = inputs[:l]
targets = targets[:l]
generated = model.predict(inputs)
print('======= input ======')
for seq in inputs:
    print([src_dict[s.item()] for s in seq])
print('======= output ======')
for seq in generated[:l]:
    print([tgt_dict[s] for s in seq])

['1', '2', '7', '7']
['9', '3', '8', '4', '2']
['9', '7', '9']
['1', '8', '0', '2', '1']
['4', '6', '3']
['7', '5', '4', '4', '8']
['3', '9', '9', '3']
['8', '3', '7']
['9', '7', '2', '6']
['9', '3', '3', '8']
['7', '7', '2', '1']
['2', '4', '8', '3', '9']
['9', '7', '9']
['1', '2', '0', '8', '1']
['3', '6', '4']
['8', '4', '4', '5', '7']
['3', '9', '9', '3']
['7', '3', '8']
['6', '2', '7', '9']
['8', '3', '3', '9']
