In [1]:
# Load IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell is executed, so edits to local packages imported later in
# this notebook (e.g. my_utils, torch_models) take effect without a restart.
%load_ext autoreload
%autoreload 2
# Select matplotlib's interactive "notebook" backend.
# NOTE(review): no plotting call is visible in this part of the notebook —
# confirm this magic is still needed.
%matplotlib notebook

In [2]:
from my_utils import Dictionary

# Number of distinct symbols; the tokens themselves are str(0) .. str(n_unique - 1).
n_unique = 10

# The source vocabulary is seeded with '<EOS>'; the target vocabulary with
# '<BOS>' and '<EOS>'.
src_dict = Dictionary(['<EOS>'])
tgt_dict = Dictionary(['<BOS>', '<EOS>'])

# Register every symbol in both vocabularies (source first, then target),
# in ascending numeric order.
for n in range(n_unique):
    symbol = str(n)
    for vocab in (src_dict, tgt_dict):
        vocab.add_word(symbol)

In [3]:
from my_utils.toy_data import invert_seq
# Number of examples requested from invert_seq for each split.
n_train, n_test = 5000, 100

# Both splits are drawn over the same n_unique symbols.
train = invert_seq(n_train, n_unique=n_unique)
test = invert_seq(n_test, n_unique=n_unique)

In [4]:
import torch
from my_utils import DataLoader
from torch_models.utils import seq2seq

def numericalize(dataset, src_dict, tgt_dict):
    """Convert every (source, target) pair of token sequences in ``dataset``.

    Args:
        dataset: iterable of ``(src, tgt)`` pairs, each itself an iterable
            of tokens.
        src_dict: callable applied to each source token.
        tgt_dict: callable applied to each target token.

    Returns:
        A list with one ``(src_values, tgt_values)`` tuple per input pair.
        Each element is a list holding the result of calling the matching
        dictionary on every token, in the original order.
    """
    converted = []
    for src_tokens, tgt_tokens in dataset:
        src_values = list(map(src_dict, src_tokens))
        tgt_values = list(map(tgt_dict, tgt_tokens))
        converted.append((src_values, tgt_values))
    return converted

# Everything runs on the CPU here; 'cuda:0' is the alternative for a GPU run.
device = torch.device('cpu')
trans_func = seq2seq(device)

# Convert each split with the two vocabularies, then wrap it in a loader.
train_numeric = numericalize(train, src_dict, tgt_dict)
train_loader = DataLoader(train_numeric, batch_size=3, trans_func=trans_func)

test_numeric = numericalize(test, src_dict, tgt_dict)
test_loader = DataLoader(test_numeric, batch_size=50, trans_func=trans_func)

# Show the loader summary for the training split.
print(train_loader)

DataLoader(
	datasize: 5000
	batchsize: 3
	n_batches: 1667
	trans_func: seq2seq
	device: cpu
)


In [23]:
from torch_models import AttnSeq2Seq, Seq2Seq

# Hyper-parameters; the hidden size is tied to the embedding size below.
embed_size = 64
dropout = 0

model = AttnSeq2Seq(
    embed_size=embed_size,
    hidden_size=embed_size,
    src_vocab_size=len(src_dict),
    tgt_vocab_size=len(tgt_dict),
    # Special-token values are looked up from the vocabularies built earlier.
    src_EOS=src_dict('<EOS>'),
    tgt_BOS=tgt_dict('<BOS>'),
    tgt_EOS=tgt_dict('<EOS>'),
    num_layers=2,
    bidirectional=True,
    dropout=dropout,
    rnn='LSTM',
    attention='bilinear',
    fuse_query='add',
    input_feeding=True,
)
print(model)

AttnSeq2Seq(
  (encoder): RNNEncoder(
    (embedding): Embedding(12, 64, padding_idx=11)
    (rnn): LSTM(64, 64, num_layers=2, bidirectional=True)
  )
  (decoder): RNNEncoder(
    (embedding): Embedding(13, 64, padding_idx=12)
    (rnn): LSTM(128, 64, num_layers=2)
  )
  (generator): MLP(
    (fc_out): Linear(in_features=64, out_features=12, bias=True)
    (dropout): Dropout(p=0)
    (criterion): CrossEntropyLoss()
    (activation): Tanh()
  )
  (attention): BiLinearAttn(
    (linear): Linear(in_features=64, out_features=64, bias=False)
  )
)


In [25]:
# Smoke test: take one batch from a fresh iterator over the training loader
# and push it through the encoder and the input-feeding decoder.
# next(iter(...)) is used instead of a bare iter(...) followed by
# next(train_loader), so the cell does not depend on the loader being its own
# stateful iterator.
inputs, targets = next(iter(train_loader))

encoded = model.encode(inputs)
# Last expression of the cell: its value is what the notebook displays.
model.decode_input_feeding(targets, encoded)

tensor([[-0.0435,  0.0482,  0.1312,  0.1326,  0.1308,  0.0004,  0.0260,
         -0.0086,  0.0914,  0.0520, -0.0205,  0.2289,  0.0085,  0.1192,
         -0.0428,  0.1020,  0.1002, -0.0133, -0.0635, -0.1770, -0.0924,
         -0.1275, -0.0476,  0.0240, -0.1513, -0.0148, -0.1380,  0.1871,
          0.0782, -0.0613,  0.0673, -0.0782, -0.0476,  0.0240, -0.0239,
          0.1622, -0.0246, -0.0250, -0.0239, -0.0447, -0.0939,  0.1039,
         -0.0374, -0.0226, -0.0015,  0.1455, -0.0287,  0.0617,  0.1312,
         -0.0151,  0.1560,  0.0158,  0.0607,  0.0279, -0.0392, -0.0542,
         -0.1133, -0.0898, -0.0582,  0.0746,  0.0428,  0.0838, -0.0336,
          0.0275],
        [-0.0180,  0.0337,  0.1198,  0.1294,  0.0881,  0.0095,  0.0109,
         -0.0351,  0.0690,  0.0355, -0.0142,  0.2128,  0.0356,  0.1107,
         -0.0376,  0.0726,  0.0653, -0.0207, -0.0436, -0.1445, -0.0812,
         -0.1157, -0.0591,  0.0130, -0.1331,  0.0094, -0.1223,  0.1424,
          0.0724, -0.0550,  0.0942, -0.0768, 

In [319]:
%%time
# Benchmark: one full pass over train_loader, running encode() followed by
# decode_input_feeding() on every batch. The result is bound to `decoded_`
# (trailing underscore) and is not used further in this cell.
for inputs, targets in train_loader:
    encoded = model.encode(inputs)
    decoded_ = model.decode_input_feeding(targets, encoded)

CPU times: user 16.8 s, sys: 114 ms, total: 16.9 s
Wall time: 8.66 s


In [320]:
%%time
# Benchmark: one full pass over train_loader, running encode() followed by
# decode() on every batch. The result is bound to `decoded` and is not used
# further in this cell.
for inputs, targets in train_loader:
    encoded = model.encode(inputs)
    decoded = model.decode(targets, encoded)

CPU times: user 13.4 s, sys: 85.5 ms, total: 13.5 s
Wall time: 6.84 s


In [161]:
%%time
# Benchmark: ten passes over train_loader, calling decode() only (no encode()).
# NOTE(review): unlike the other timing cells, this passes `inputs` rather than
# `targets` as the first argument and an empty dict in place of the encoder
# output — presumably a leftover experiment; confirm it is still meaningful.
for _ in range(10):
    for inputs, targets in train_loader:
        decoded = model.decode(inputs, {})

CPU times: user 7.76 s, sys: 50.5 ms, total: 7.81 s
Wall time: 3.98 s


In [6]:
%%time
from my_utils import Trainer, EvaluatorSeq, EvaluatorLoss
from my_utils.misc.logging import init_logger
from torch.optim import Adam, SGD

init_logger()

# Training hyper-parameters.
learning_rate = 0.001
max_epoch = 3

optimizer = Adam(model.parameters(), lr=learning_rate)

# Evaluation runs on the test loader and reports 'accuracy'.
# Alternative kept for reference: EvaluatorLoss(model, test_loader)
evaluator = EvaluatorSeq(model, test_loader, measure='accuracy')

trainer = Trainer(model, train_loader)
trainer.train_epoch(
    optimizer,
    max_epoch=max_epoch,
    evaluator=evaluator,
    score_monitor=None,
)

[2018-10-27 02:13:24,574 INFO] steps [100/100]	loss: 1.9606751823425292	
[2018-10-27 02:13:24,640 INFO] Evaluator accuracy: 0.671875	
[2018-10-27 02:13:27,495 INFO] steps [100/100]	loss: 0.33240895237773654	
[2018-10-27 02:13:27,561 INFO] Evaluator accuracy: 1.0	
[2018-10-27 02:13:30,236 INFO] steps [100/100]	loss: 0.03726597668603063	
[2018-10-27 02:13:30,297 INFO] Evaluator accuracy: 1.0	


CPU times: user 23.1 s, sys: 332 ms, total: 23.4 s
Wall time: 8.49 s


In [30]:
# Take one batch from a fresh iterator over the test loader.
# next(iter(...)) replaces a bare iter(...) followed by next(test_loader), so
# the cell does not depend on the loader being its own stateful iterator.
# NOTE(review): the following cell rebinds `inputs`/`targets` from
# train_loader, so this batch looks unused — confirm whether it is needed.
inputs, targets = next(iter(test_loader))

In [31]:
# Preview model predictions on (at most) the first `l` examples of one
# training batch.
# next(iter(...)) replaces a bare iter(...) followed by next(train_loader), so
# the cell does not depend on the loader being its own stateful iterator.
l = 10  # maximum number of examples to show
inputs, targets = next(iter(train_loader))
inputs = inputs[:l]
targets = targets[:l]  # sliced alongside inputs; not used below in this cell
generated = model.predict(inputs)

# Source side: each element is converted with .item() before the vocabulary lookup.
print('======= input ======')
for seq in inputs:
    print([src_dict[s.item()] for s in seq])

# Target side: elements of the generated sequences are looked up directly.
print('======= output ======')
for seq in generated[:l]:
    print([tgt_dict[s] for s in seq])

tensor([[[ 1.4424, -0.1017,  1.5236, -1.4988,  1.3046, -1.0514,  1.2535,
          -0.0197,  1.1458, -1.2257, -0.0167,  0.9817,  1.2132,  0.2181,
          -1.4476, -1.3340, -1.3247,  1.3937, -1.2628,  1.1180,  1.1383,
          -1.2805, -1.0017,  1.4529, -0.1208,  1.2628, -0.1153,  0.8560,
           1.3472, -1.3998,  1.2472, -1.5000,  1.1410,  0.8864, -0.3519,
          -1.2725,  1.2763,  1.2862,  1.4910,  0.3846, -0.2490,  1.4636,
          -1.3695,  1.3621,  1.1179, -0.2767, -1.4731, -1.3272, -1.2645,
          -0.8719, -0.6327,  0.1412, -1.2650, -1.4987, -1.3373, -1.5591,
           1.2529, -1.2680, -1.0812,  1.3748, -0.1330, -1.4495, -1.2359,
           0.9917]]])
tensor([[[ 1.4424, -0.1017,  1.5236, -1.4988,  1.3046, -1.0514,  1.2535,
          -0.0197,  1.1458, -1.2257, -0.0167,  0.9817,  1.2132,  0.2181,
          -1.4476, -1.3340, -1.3247,  1.3937, -1.2628,  1.1180,  1.1383,
          -1.2805, -1.0017,  1.4529, -0.1208,  1.2628, -0.1153,  0.8560,
           1.3472, -1.3998,  