In [1]:
%load_ext autoreload
%autoreload 2

import random
from functools import partial

import torch
import torch.nn as nn
from tqdm import tqdm_notebook as tqdm

from unsupervised_mt.batch_iterator import BatchIterator
from unsupervised_mt.dataset import Dataset
from unsupervised_mt.models import Embedding, Encoder, DecoderHat, Attention, Discriminator
from unsupervised_mt.train import Trainer
from unsupervised_mt.utils import log_probs2indices, noise


# Seed the RNGs up front so weight initialisation, batching and training are repeatable
# across Restart & Run All.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and every CUDA device
# TODO(review): if the unsupervised_mt modules draw from numpy's RNG, seed it here as well.

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
# Parallel (src, tgt) file pairs: tokenised training corpora and the matching word-vector files.
corpus_files = ('../data/train.lc.norm.tok.en', '../data/train.lc.norm.tok.fr')
embedding_files = ('../data/wiki.multi.en.vec', '../data/wiki.multi.fr.vec')

ds = Dataset(
    languages=('src', 'tgt'),
    corp_paths=corpus_files,
    emb_paths=embedding_files,
)

In [3]:
# Batch source over `ds`; later cells call `batch_iter.load_batch(n)` and index the
# returned batch by language key ('src' / 'tgt').
batch_iter = BatchIterator(ds)

In [4]:
# Hyper-parameters shared by the encoder and decoder GRUs.
hidden_size = 100
num_layers = 3

# Per-language embedding layers built from the dataset's embedding matrices.
src_embedding = Embedding(ds.emb_matrix['src']).to(device)
tgt_embedding = Embedding(ds.emb_matrix['tgt']).to(device)

# Encoder and decoder GRUs share one configuration. Both take inputs of the *source*
# embedding width.
# NOTE(review): the decoder presumably also consumes target-side embeddings, so this assumes
# both embedding matrices have the same dimensionality — confirm.
# NOTE(review): only the embeddings are moved to `device` here; presumably Trainer (which
# receives `device`) moves the remaining modules — confirm.
# Construction order below is kept as-is: each constructor may draw from the RNG for weight init.
gru_config = dict(input_size=src_embedding.embedding_dim,
                  hidden_size=hidden_size,
                  num_layers=num_layers)
encoder_rnn = nn.GRU(**gru_config)
decoder_rnn = nn.GRU(**gru_config)
attention = Attention(src_embedding.embedding_dim, hidden_size, max_length=ds.max_length)

# Per-language decoder heads, sized from the hidden state and each vocabulary.
src_hat = DecoderHat(hidden_size, ds.vocabs['src'].size)
tgt_hat = DecoderHat(hidden_size, ds.vocabs['tgt'].size)
discriminator = Discriminator(hidden_size)

# Word-by-word translators in both directions, passed to the Trainer.
word_by_word_src2tgt = partial(ds.translate_batch_word_by_word, language1='src', language2='tgt')
word_by_word_tgt2src = partial(ds.translate_batch_word_by_word, language1='tgt', language2='src')

trainer = Trainer(
    word_by_word_src2tgt, word_by_word_tgt2src,
    src_embedding, tgt_embedding,
    encoder_rnn, decoder_rnn, attention,
    src_hat, tgt_hat,
    discriminator,
    ds.get_sos_index('src'), ds.get_sos_index('tgt'),
    ds.get_eos_index('src'), ds.get_eos_index('tgt'),
    ds.get_pad_index('src'), ds.get_pad_index('tgt'),
    device,
)

In [None]:
# Training schedule: `num_steps` calls to `trainer.train_step`, each on `batch_size` examples.
batch_size = 50
num_steps = 10000

progress = tqdm(range(num_steps))
for step in progress:
    # Draw a new batch every step and hand it straight to the trainer.
    trainer.train_step(batch_iter.load_batch(batch_size))

HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))

In [38]:
# Qualitative check: show one target-language sentence, then its tgt -> src translation.
# Presumably an upper bound on decoded tokens; the sample output has exactly 20 tokens
# when no <eos> was emitted.
MAX_DECODE_LENGTH = 20

batch = batch_iter.load_batch(1)
ds.print_batch(batch['tgt'], 'tgt')

# Inference only: no_grad stops autograd from recording the decode and holding on to
# intermediate activations.
# NOTE(review): the modules may still be in training mode; if any use dropout, switch them
# to eval mode first — check the Trainer API.
with torch.no_grad():
    log_probs = trainer.tgt2src.evaluate(
        batch['tgt'].to(device),
        ds.get_sos_index('src'),
        ds.get_eos_index('src'),
        MAX_DECODE_LENGTH,
    )
ds.print_batch(log_probs2indices(log_probs), 'src')

['un', 'homme', 'savoure', 'un', 'sandwich', 'subway', '.', '<eos>']
['swerves', 'swerves', 'job', 'job', 'job', 'job', 'job', 'job', 'job', 'job', 'urban', 'urban', 'urban', 'urban', 'urban', 'urban', 'urban', 'urban', 'urban', 'urban']
