In [1]:
import re
import collections
import itertools


# http://www.manythings.org/anki/spa-eng.zip
remove_marks_regex = re.compile('[,\.\(\)\[\]\*:;¿¡]|<.*?>')
shift_marks_regex = re.compile('([?!\.])')
unk = 0
sos = 1
eos = 2

def normalize(text):
    text = text.lower()
    text = remove_marks_regex.sub('', text)
    text = shift_marks_regex.sub(r' \1', text)
    return text

def parse_line(line):
    line = normalize(line.strip())
    src, trg = line.split('\t')
    src_tokens = src.strip().split()
    trg_tokens = trg.strip().split()
    return src_tokens, trg_tokens

def build_vocab(tokens):
    counts = collections.Counter(tokens)
    sorted_counts = sorted(counts.items(), key=lambda c: c[1], reverse=True)
    word_list = ['<UNK>', '<SOS>', '<EOS>'] + [x[0] for x in sorted_counts]
    word_dict = dict((w, i) for i, w in enumerate(word_list))
    return word_list, word_dict

def words2tensor(words, word_dict, max_len, padding=0):
    words = words + ['<EOS>']
    words = [word_dict.get(w, 0) for w in words]
    seq_len = len(words)
    if seq_len < max_len + 1:
        words = words + [padding] * (max_len + 1 - seq_len)
    return torch.LongTensor(words), seq_len

In [3]:
from torch.utils.data import Dataset


class TranslationPairDataset(Dataset):
    def __init__(self, path, max_len=15):
        def filter_pair(p):
            return not (len(p[0]) > max_len or len(p[1]) > max_len)
        
        with open(path) as fp:
            pairs = map(parse_line, fp)
            pairs = filter(filter_pair, pairs)
            pairs = list(pairs)
        src = [p[0] for p in pairs]
        trg = [p[1] for p in pairs]
        self.src_word_list, self.src_word_dict = build_vocab(itertools.chain.from_iterable(src))
        self.trg_word_list, self.trg_word_dict = build_vocab(itertools.chain.from_iterable(trg))
        self.src_data = [words2tensor(words, self.src_word_dict, max_len) for words in src]
        self.trg_data = [words2tensor(words, self.trg_word_dict, max_len, -100) for words in trg]
        
    def __len__(self):
        return len(self.src_data)
    
    def __getitem__(self, idx):
        src, lsrc = self.src_data[idx]
        trg, ltrg = self.trg_data[idx]
        return src, lsrc, trg, ltrg

In [6]:
import torch
from torch.utils.data import DataLoader


batch_size = 64
max_len = 10
path = './spa.txt'
ds = TranslationPairDataset(path, max_len=max_len)
loader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)

In [8]:
from torch import nn


class Encoder(nn.Module):
    def __init__(self, num_embeddings, embedding_dim=50, hidden_size=50, num_layers=1, dropout=0.2):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=True, dropout=dropout)
        
    def forward(self, x, h0=None, l=None):
        x = self.emb(x)
        if l is not None:
            x = nn.utils.rnn.pack_padded_sequence(x, l, batch_first=True)
        _, h = self.lstm(x, h0)
        return h

In [9]:
class Decoder(nn.Module):
    def __init__(self, num_embeddings, embedding_dim=50, hidden_size=50, num_layers=1, dropout=0.2):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.linear = nn.Linear(hidden_size, num_embeddings)
    
    def forward(self, x, h):
        x = self.emb(x)
        x, h = self.lstm(x, h)
        x = x.view(-1, self.lstm.hidden_size)
        x = self.linear(x)
        return x, h

In [10]:
def translate(input_str, enc, dec, max_len=15):
    words = normalize(input_str).split()
    input_tensor, seq_len = words2tensor(words, ds.src_word_dict, max_len=max_len)
    input_tensor = input_tensor.unsqueeze(0)
    seq_len = [seq_len]
    sos_inputs = torch.LongTensor([sos]).unsqueeze(1)
    ctx = enc(V(input_tensor, volatile=True), l=seq_len)
    z = V(sos_inputs, volatile=True)
    h = ctx
    results = []
    for i in range(max_len):
        o, h = dec(z, h)
        wi = o.data.max(1)[1].view(1)
        if wi[0] == eos:
            break
        results.append(wi[0])
        z = V(wi.view(1, 1), volatile=True)
    return " ".join(ds.trg_word_list[i] for i in results)

In [12]:
from torch.autograd import Variable as V


enc = Encoder(len(ds.src_word_list), 100, 100, 2)
dec = Decoder(len(ds.trg_word_list), 100, 100, 2)
translate('I am a student.', enc, dec)

'decís mocos barrita productivo productivo mentales mente australia recae suyas suyas remate decís canción canción'

In [14]:
from torch import optim


enc = Encoder(len(ds.src_word_list), 100, 100, 1, dropout=0.1)
dec = Decoder(len(ds.trg_word_list), 100, 100, 1, dropout=0.1)
opt_enc = optim.Adam(enc.parameters(), 0.002)
opt_dec = optim.Adam(dec.parameters(), 0.01)
loss_f = nn.CrossEntropyLoss()

In [15]:
from statistics import mean


for epoc in range(10):
    enc.train()
    dec.train()
    losses = []
    for x, lx, y, ly in loader:
        sos_inputs = torch.LongTensor([sos] * len(x)).unsqueeze(1)
        lx, sort_idx = lx.sort(descending=True)
        x, y = x[sort_idx], y[sort_idx]
        x, y = V(x), V(y)
        loss = 0
        ctx = enc(x, l=list(lx))
        z = V(sos_inputs)
        h = ctx
        for i in range(max_len):
            o, h = dec(z, h)
            loss += loss_f(o, y[:, i])
            wi = o.data.max(1)[1].unsqueeze(1)
            z = V(wi)
        enc.zero_grad()
        dec.zero_grad()
        loss.backward()
        opt_enc.step()
        opt_dec.step()
        losses.append(loss.data[0])
    enc.eval()
    dec.eval()
    print('===================================================================================')
    print(epoc, mean(losses))
    print(translate('I am a student.', enc, dec, max_len=max_len))
    print(translate('He likes to eat pizza.', enc, dec, max_len=max_len))
    print(translate('She is my mother.', enc, dec, max_len=max_len))

0 47.26453171384473
soy un
él siempre estar a los
ella es mi mi
1 37.29820858401656
soy un estudiante
a me gusta los
ella es mi madre
2 32.39493524187697
soy estudiante estudiante
a me gusta estar pizza
ella es mi madre
3 29.46315075414156
soy estudiante estudiante
a tom gusta gusta pizza
ella es mi madre
4 27.39850279657772
soy un estudiante
a gusta le gusta las las
ella es mi madre
5 25.88043307731811
soy un estudiante
a tom le gusta pizza pizza
ella es mi madre
6 24.75429555961941
soy un estudiante
a gusta le gusta pizza pizza
ella es mi madre
7 23.910765654320077
soy un estudiante
a no no gusta pizza pizza
ella es mi madre
8 23.180759546875148
soy un estudiante
a gusta me gusta pizza
ella mi mi madre
9 22.614906204616563
soy un estudiante
a gusta le gusta pizza pizza
ella es mi madre
