In [21]:
%load_ext autoreload
%autoreload 2

from unsupervised_mt.dataset import Dataset
from unsupervised_mt.train import Trainer 
from unsupervised_mt.models import Embedding, Encoder, DecoderHat, Attention, Discriminator
from unsupervised_mt.batch_iterator import BatchIterator
from unsupervised_mt.utils import log_probs2indices, noise

from functools import partial
import torch.nn as nn

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Build the dataset from one corpus file and one word-vector file per language.
# NOTE(review): the path tuples are presumably paired positionally with
# `languages` ('.en' files -> 'src', '.fr' files -> 'tgt') — confirm in Dataset.
ds = Dataset(
    languages=('src', 'tgt'),
    corp_paths=(
        '../data/train.lc.norm.tok.en',
        '../data/train.lc.norm.tok.fr',
    ),
    emb_paths=(
        '../data/wiki.multi.en.vec',
        '../data/wiki.multi.fr.vec',
    ),
)

In [22]:
# Batch source over the dataset; used below both to load a single sample batch
# (`load_batch`) and to feed the training loop (`batch_generator`).
batch_iter = BatchIterator(ds)

In [4]:
# Recurrent-network hyper-parameters shared by the encoder and decoder GRUs.
hidden_size = 100
num_layers = 3

# One embedding layer per language, built from the dataset's embedding matrices.
src_embedding = Embedding(ds.emb_matrix['src'])
tgt_embedding = Embedding(ds.emb_matrix['tgt'])

# Both GRUs and the attention module are sized from the *source* embedding
# dimension.
# NOTE(review): this presumes the 'src' and 'tgt' embeddings have the same
# dimensionality — confirm before swapping in different vectors.
encoder_rnn = nn.GRU(
    input_size=src_embedding.embedding_dim,
    hidden_size=hidden_size,
    num_layers=num_layers,
)
decoder_rnn = nn.GRU(
    input_size=src_embedding.embedding_dim,
    hidden_size=hidden_size,
    num_layers=num_layers,
)
attention = Attention(
    src_embedding.embedding_dim,
    hidden_size,
    max_length=ds.max_length,
)

# One output head per language, each sized to that language's vocabulary.
src_hat = DecoderHat(hidden_size, ds.vocabs['src'].size)
tgt_hat = DecoderHat(hidden_size, ds.vocabs['tgt'].size)
discriminator = Discriminator(hidden_size)

# Trainer takes everything positionally, so the argument order below must be
# kept exactly as is.
trainer = Trainer(
    # Word-by-word batch translators for each direction.
    partial(ds.translate_batch_word_by_word, language1='src', language2='tgt'),
    partial(ds.translate_batch_word_by_word, language1='tgt', language2='src'),
    # Model components.
    src_embedding,
    tgt_embedding,
    encoder_rnn,
    decoder_rnn,
    attention,
    src_hat,
    tgt_hat,
    discriminator,
    # Special-token indices: <sos>, <eos>, <pad> for src then tgt.
    ds.get_sos_index('src'),
    ds.get_sos_index('tgt'),
    ds.get_eos_index('src'),
    ds.get_eos_index('tgt'),
    ds.get_pad_index('src'),
    ds.get_pad_index('tgt'),
)

In [5]:
# Load one small batch and display it as a quick sanity check of the iterator.
# `batch_size` is also reused by the training loop further down.
batch_size = 5
batch = batch_iter.load_batch(batch_size)
# Bare last expression so the notebook renders the batch contents.
batch

{'src': tensor([[    4,     4,   180,     2,     4],
         [ 1997,    20,  1509,    38,   380],
         [   20,    27,    84,  1029,   681],
         [   66,    31,     4,    78,   177],
         [ 1998,    14,   818,     4,    38],
         [   21,   196,   182,   139,    16],
         [  845,    11,    11,    11,  1763],
         [  449,     1,     1,     1,    11],
         [   11,     3,     3,     3,     1],
         [    1,     3,     3,     3,     3]]),
 'tgt': tensor([[  159,   771,    23,    12,     4],
         [    4,   164,   138,    82,   209],
         [  117,    63,   164,   350,    16],
         [  288,   350,   973,     2,   363],
         [   22,     2,    12,  1845,   282],
         [   44,    12,  1334,  2068,    25],
         [  157,  1351,    83,   472,   133],
         [ 1162,    11,   329,   473,    68],
         [   11,     1,    11,    11,    11],
         [    1,     3,     1,     1,     1]])}

In [25]:
# Train for a fixed budget of `num_steps` optimisation steps.
#
# `range(num_steps)` is deliberately the FIRST argument to `zip`: zip advances
# its iterables left to right, so once the step budget is exhausted it stops
# without requesting another batch from the generator. The previous
# `enumerate(...)` + `if i == num_steps: break` form always loaded one extra
# batch only to throw it away (and loaded a batch even when num_steps == 0).
num_steps = 1000
for i, batch in zip(range(num_steps), batch_iter.batch_generator(batch_size)):
    trainer.train_step(batch)

KeyboardInterrupt: 

In [26]:
# Qualitative check: draw a single sentence, print it, then print what the
# src -> tgt model produces for it.
batch = batch_iter.load_batch(1)
ds.print_batch(batch['src'], 'src')
ds.print_batch(
    log_probs2indices(
        trainer.src2tgt.evaluate(
            batch['src'],
            ds.get_sos_index('tgt'),
            ds.get_eos_index('tgt'),
            20,  # NOTE(review): presumably the maximum decode length — confirm
        )
    ),
    'tgt',
)

['a', 'black', 'and', 'a', 'tan', 'dog', '.', '<eos>']
['un', '.', '.', '.', '<eos>']


In [13]:
# NOTE(review): standalone scratch expression; it references no notebook state
# and nothing in the visible cells depends on it — candidate for removal.
None != 0

True

<generator object BatchIterator.batch_generator at 0x1103f0d58>