In [None]:
import random
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchtext as thtext

In [None]:
class Args:
    """Bare attribute container holding the notebook's hyperparameters."""


args = Args()
# Attributes are set in one place, in the same order as before:
#   c_size / en_h_size / src_e_size  -> Encoder context, hidden and embedding widths
#   de_h_size / trg_e_size / s_size / shrink_size -> Decoder widths
#   gpu        -> CUDA device index; a negative value selects the CPU
#   n_epochs   -> number of passes of the training loop
#   p_teacher  -> per-step probability of feeding the gold previous token
#   h_size / max_len -> not read by any cell visible in this notebook
vars(args).update(
    c_size=1024,
    de_h_size=1024,
    en_h_size=1024,
    gpu=0,
    h_size=1024,
    max_len=10,
    n_epochs=10,
    p_teacher=0.5,
    s_size=1024,
    shrink_size=1024,
    src_e_size=1024,
    trg_e_size=1024,
)

In [None]:
# Select the compute device. A negative ``args.gpu`` requests the CPU; any
# other value is treated as a CUDA device index. If a GPU is requested but
# CUDA is not usable on this machine, fall back to the CPU instead of
# crashing in ``th.cuda.set_device``.
cuda = args.gpu >= 0 and th.cuda.is_available()
if args.gpu >= 0 and not cuda:
    print('CUDA is not available; falling back to CPU.')

if cuda:
    device = th.device('cuda', args.gpu)
    # Make the requested GPU the current device so that bare ``.cuda()``
    # calls elsewhere in the notebook land on it.
    th.cuda.set_device(args.gpu)
else:
    device = th.device('cpu')

In [None]:
# TODO train/dev/test
# Two independent fields, each with its own Moses tokenizer: ``sf`` handles
# the '.de' side (exposed as ``.src``) and ``tf`` the '.en' side (``.trg``).
sf, tf = (
    thtext.data.Field(tokenize=thtext.data.utils.get_tokenizer('moses'))
    for _ in range(2)
)
wmt14 = thtext.datasets.WMT14(
    path='WMT14/newstest2009',
    exts=('.de', '.en'),
    fields=(sf, tf),
)

# Each field builds its vocabulary from its own column of the dataset.
for field in (sf, tf):
    field.build_vocab(wmt14)


def sort_key(x):
    """Bucketing key that interleaves source and target lengths of an example."""
    return thtext.data.interleave_keys(len(x.src), len(x.trg))


train_loader = thtext.data.BucketIterator(
    dataset=wmt14,
    batch_size=32,
    sort_key=sort_key,
)

# The start-of-sequence id is the first index past the target vocabulary.
sos_token = len(tf.vocab.itos)

In [None]:
class Encoder(nn.Module):
    """GRU encoder that condenses each source sentence into one context vector.

    Args:
        v_size: number of rows in the embedding table (source vocabulary size).
        e_size: embedding width.
        h_size: GRU hidden width.
        c_size: width of the returned context vector.
    """

    def __init__(self, v_size, e_size, h_size, c_size):
        super().__init__()
        self.h_size = h_size

        self.embedding = nn.Embedding(v_size, e_size)
        self.gru = nn.GRU(e_size, h_size)
        # Projection from the final hidden state to the context vector. The
        # leading dimension of 1 lines up with the (1, batch, h_size) hidden
        # state so the two can be multiplied with ``bmm``.
        # NOTE(review): unit-variance init with a wide h_size likely drives
        # the tanh below into saturation — consider scaling; confirm.
        self.v = nn.Parameter(th.randn(1, h_size, c_size))

    def forward(self, src):
        """Encode a batch of token ids.

        Args:
            src: integer tensor of shape (seq_len, batch).

        Returns:
            Tensor of shape (1, batch, c_size) with values in [-1, 1].
        """
        embedded = self.embedding(src)
        # nn.GRU starts from an all-zero hidden state when none is supplied,
        # created on the same device as its input. Relying on that removes the
        # dependency on a notebook-global ``device`` variable.
        _, h = self.gru(embedded)
        return th.tanh(th.bmm(h, self.v))

class Decoder(nn.Module):
    """Single-step GRU decoder with a two-piece maxout output layer.

    Args:
        c_size: width of the encoder context vector.
        h_size: GRU hidden width.
        v_size: number of target token ids (embedding rows and output logits).
        e_size: target embedding width.
        s_size: width of the maxout layer.
        shrinked_size: width of the layer between the maxout and the logits.
    """

    def __init__(self, c_size, h_size, v_size, e_size, s_size, shrinked_size):
        super().__init__()
        # Maps the context vector to the initial hidden state; the leading
        # dimension of 1 matches the (1, batch, c_size) context in ``bmm``.
        self.v = nn.Parameter(th.randn(1, c_size, h_size))
        self.embedding = nn.Embedding(v_size, e_size)
        self.gru = nn.GRU(e_size, h_size)
        # Two parallel affine maps per input; their sums are combined with an
        # element-wise maximum in ``forward``.
        self.linear_h0, self.linear_h1 = nn.Linear(h_size, s_size), nn.Linear(h_size, s_size)
        self.linear_y0, self.linear_y1 = nn.Linear(e_size, s_size), nn.Linear(e_size, s_size) # TODO embedded?
        self.linear_c0, self.linear_c1 = nn.Linear(c_size, s_size), nn.Linear(c_size, s_size)
        self.linear_shrink = nn.Linear(s_size, shrinked_size)
        self.linear = nn.Linear(shrinked_size, v_size)

    def forward(self, y, h, c):
        """Run one decoding step.

        Args:
            y: integer tensor of shape (batch,) holding the previous token ids.
            h: previous hidden state of shape (1, batch, h_size), or ``None``
                on the first step, in which case it is derived from ``c``.
            c: encoder context of shape (1, batch, c_size).

        Returns:
            Tuple ``(logits, h)`` where ``logits`` has shape (batch, v_size)
            and ``h`` is the new hidden state of shape (1, batch, h_size).
        """
        if h is None:
            h = th.tanh(th.bmm(c, self.v))
        y = self.embedding(y)
        _, h = self.gru(y.unsqueeze(0), h)
        # Drop only the leading layer dimension. A bare squeeze would also
        # remove the batch or feature dimension whenever that has size 1.
        h_squeezed = th.squeeze(h, 0)
        c = th.squeeze(c, 0)
        s0 = self.linear_h0(h_squeezed) + self.linear_y0(y) + self.linear_c0(c) # TODO embedded?
        s1 = self.linear_h1(h_squeezed) + self.linear_y1(y) + self.linear_c1(c) # TODO embedded?
        s = th.max(s0, s1)
        y = self.linear(self.linear_shrink(s))
        return y, h

# TODO: not implemented yet.
class AttnDecoder(nn.Module):
    """Empty placeholder (presumably for an attention-based decoder); adds nothing to nn.Module."""

In [None]:
# Vocabulary sizes. The target side gets one extra slot because the
# start-of-sequence id sits just past the built vocabulary.
src_vocab_size = len(sf.vocab.itos)
trg_vocab_size = len(tf.vocab.itos) + 1

encoder = Encoder(
    v_size=src_vocab_size,
    e_size=args.src_e_size,
    h_size=args.en_h_size,
    c_size=args.c_size,
)
decoder = Decoder(
    c_size=args.c_size,
    h_size=args.de_h_size,
    v_size=trg_vocab_size,
    e_size=args.trg_e_size,
    s_size=args.s_size,
    shrinked_size=args.shrink_size,
)

# Move both networks to the current GPU before the optimizers are created.
if cuda:
    for module in (encoder, decoder):
        module.cuda()

# One plain SGD optimizer per network, both with the same learning rate.
en_optim = optim.SGD(encoder.parameters(), lr=1e-3)
de_optim = optim.SGD(decoder.parameters(), lr=1e-3)

In [None]:
# Training loop: encode each source batch, decode the target one step at a
# time with scheduled teacher forcing, and update both networks with SGD.

# Padded target positions carry no signal, so they are excluded from the loss.
pad_index = tf.vocab.stoi[tf.pad_token]

for i in range(args.n_epochs):
    for j, b in enumerate(train_loader):
        c = encoder(b.src)

        # The first decoder input for every sentence is the start-of-sequence id.
        sos = th.full((b.trg.size(1),), sos_token, dtype=th.long, device=device)
        h = None
        objective = 0
        for k, t in enumerate(b.trg):
            if k == 0:
                prev = sos
            elif random.random() < args.p_teacher:
                # Teacher forcing: feed the gold previous token.
                prev = b.trg[k - 1]
            else:
                # Feed the model's own previous prediction. The argmax over
                # dim 1 already has shape (batch,); squeezing it would collapse
                # a batch of one to a 0-d tensor and break the decoder input.
                prev = th.max(y, 1)[1]
            y, h = decoder(prev, h, c)
            objective += F.nll_loss(F.log_softmax(y, 1), t, ignore_index=pad_index)
        # Average the per-step losses over the target length.
        objective /= k + 1
        en_optim.zero_grad()
        de_optim.zero_grad()
        objective.backward()
        en_optim.step()
        de_optim.step()
        print('[iteration %d]%f' % (j, objective.item()))

    # Reports the loss of the last batch of the epoch. The unconditional debug
    # ``break`` that used to follow is gone so that ``args.n_epochs`` is honored.
    print('[epoch %d]%f' % (i, objective.item()))