In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import itertools

In [2]:
TRAIN_FILE = 'data/ru-be-train.txt'
TEST_FILE = 'data/ru-be-test.txt'

In [3]:
MAX_LENGTH = 30
 
class Alphabet:
    START = '__START__'
    END = '_END_'
 
    def __init__(self, max_length=MAX_LENGTH):
        """Initialize the class which works with letter and index representations of sequences.
        Parameters
        ----------
        max_length : int
            The largest permitted length for sequence. Longer sequences are cropped.
        """
        self.max_length = max_length
        self.letter2index_ = {Alphabet.START : 0, Alphabet.END : 1}
        self.index2letter_ = [Alphabet.START, Alphabet.END]
        
    def get_index(self, letter):
        if letter not in self.letter2index_:
            self.letter2index_[letter] = len(self.index2letter_)
            self.index2letter_.append(letter)
        return self.letter2index_[letter]
    
    @property
    def start_index(self):
        return self.letter2index_[Alphabet.START]
    
    @property
    def end_index(self):
        return self.letter2index_[Alphabet.END]
    
    def index2letter(self, x):
        return ''.join(self.index2letter_[index] for index in x)
    
    def letter2index(self, word):
        lst = [self.get_index(letter) for letter in word]
        return lst[:self.max_length - 1] + [self.get_index(Alphabet.END)] * max(1, self.max_length - len(lst))
    
    def __len__(self):
        return len(self.index2letter_)
    
    # torch utils
    def get_length(self, input_sequence):
        """Infers the lengths of sequences in batch
        
        """
        return (input_sequence == self.end_index).max(dim=1)[1] + 1
    
    def get_mask(self, input_sequence):
        return (torch.cumsum(input_sequence == self.end_index, dim=1) < 2).float()

In [4]:
ru = Alphabet()
be = Alphabet()

In [5]:
def load_pair_dataset(filename, alph1, alph2):
    x, y = [], []
    with open(filename, 'r') as ftr:
        for line in ftr:
            try:
                word1, word2 = line.split()
            except ValueError:
                continue
            x.append(alph1.letter2index(word1))
            y.append(alph2.letter2index(word2))
    return np.array(x), np.array(y)       

In [6]:
train_X, train_Y = load_pair_dataset(TRAIN_FILE, ru, be)

In [7]:
train_X.dtype

dtype('int64')

In [8]:
def seq2seq_softmax_with_mask(entries, mask):
    entries = entries[:,:,0]
    maxs = entries.max(1, keepdim=True)[0]
    #print(entries.shape, maxs.shape, mask.shape)
    entries = torch.exp(entries - maxs) * mask
    return entries / (entries.sum(dim=1, keepdim=True) + 1e-15)


class MultiplicativeAttentionWithMask(nn.Module):
    def __init__(self):
        super(MultiplicativeAttentionWithMask, self).__init__()
        
    def forward(self, decoder_hidden, encoder_hiddens, encoder_mask):
        #print('#', decoder_hidden.unsqueeze(2).shape, encoder_hiddens.shape)
        weights = torch.bmm(encoder_hiddens, decoder_hidden.unsqueeze(2))
        weights = seq2seq_softmax_with_mask(weights, encoder_mask)
        return torch.bmm(encoder_hiddens.transpose(1, 2), weights.unsqueeze(2))[:,:,0]

In [9]:
class SimpleGRUEncoder(nn.Module):
    def __init__(self, alphabet, embedding_size, hidden_size):
        super(SimpleGRUEncoder, self).__init__()
        self.alphabet = alphabet
        self.embedding = nn.Embedding(num_embeddings=len(self.alphabet), embedding_dim=embedding_size)
        self.gru = nn.GRU(input_size=embedding_size, hidden_size=hidden_size, batch_first=True)
        
    def forward(self, input_sequence):
        batch_size = input_sequence.size(0)
        embeddings = self.embedding(input_sequence)
        out, _ = self.gru(embeddings)
        return out, self.alphabet.get_mask(input_sequence)

In [10]:
class SimpleGRUDecoderWithAttention(nn.Module):
    def __init__(self, alphabet, embedding_size, hidden_size):
        super(SimpleGRUDecoderWithAttention, self).__init__()
        self.alphabet = alphabet
        self.embedding = nn.Embedding(num_embeddings=len(alphabet), embedding_dim=embedding_size)
        self.gru_cell = nn.GRUCell(input_size=embedding_size, hidden_size=hidden_size)
        self.logit_linear = nn.Linear(hidden_size, len(alphabet))
        self.attention = MultiplicativeAttentionWithMask()
        
    def forward(self, token, prev_h, encoder_hs, encoder_mask):
        embedding = self.embedding(token)
        attention = self.attention(prev_h, encoder_hs, encoder_mask)
        #h = self.gru_cell(torch.cat((embedding, attention), dim=1), prev_h)
        h = self.gru_cell(embedding, prev_h)
        out = self.logit_linear(h)
        return out, h

In [11]:
class SimpleGRUSupervisedSeq2Seq(nn.Module):
    def __init__(self, src_alphabet, dst_alphabet, embedding_size, hidden_size):
        super(SimpleGRUSupervisedSeq2Seq, self).__init__()
        self.encoder = SimpleGRUEncoder(src_alphabet, embedding_size, hidden_size)
        self.h_linear = nn.Linear(hidden_size, hidden_size)
        self.decoder = SimpleGRUDecoderWithAttention(dst_alphabet, embedding_size, hidden_size)
        
    def start(self, batch_size):
        return Variable(torch.from_numpy(np.repeat(self.decoder.alphabet.start_index, batch_size)))
    
    def middle_layer(self, out, mask):
        #print(mask.sum(1))
        return F.tanh(self.h_linear(out[range(out.shape[0]), mask.sum(1).long() - 1]))
    
    def forward(self, input_sequence, output_sequence):
        enc_out, enc_mask = self.encoder(input_sequence)
        dec_h = self.middle_layer(enc_out, enc_mask)
        logits = []
        for x in itertools.chain((self.start(output_sequence.size(0)),), output_sequence.transpose(0, 1)[:-1]):
            out, dec_h = self.decoder(x, dec_h, enc_out, enc_mask)
            logits.append(out)
        return F.log_softmax(torch.stack(logits, dim=1), dim=-1)
    
    def translate(self, word, strategy='', max_length=30):
        self.eval()
        input_sequence = Variable(torch.from_numpy(np.array([model.encoder.alphabet.letter2index(word)])))
        #print(input_sequence.shape)
        enc_out, enc_mask = self.encoder(input_sequence)
        hidden = self.middle_layer(enc_out, enc_mask)
        token = self.start(1)
        #print(token.shape, hidden.shape)
        lst = []
        for i in range(max_length):
            out, hidden = self.decoder(token, hidden, enc_out, enc_mask)
            token = out.max(1)[1]
            #print(token, out)
            lst.append(token.data[0])
            if token.data[0] == self.decoder.alphabet.end_index:
                break
        return ''.join(model.decoder.alphabet.index2letter(lst))

In [12]:
def batch_iterator(X, Y=None, batch_size=32):
    assert Y is None or X.shape[0] == Y.shape[0]
    ind = np.arange(X.shape[0])
    np.random.shuffle(ind)
    for i in range(0, X.shape[0], batch_size):
        if Y is not None:
            yield X[ind[i:i + batch_size]], Y[ind[i:i + batch_size]]
        else:
            yield X[ind[i:i + batch_size]]

In [17]:
model = SimpleGRUSupervisedSeq2Seq(ru, be, 65, 256)
opt = optim.Adam(model.parameters(), lr=1e-3)

In [18]:
def cross_entropy(log_predictions, targets, alphabet):
    """ Cross entropy loss for sequences
    Parameters
    ---------
    log_predictions: Tensor NxTxH
        Log probabilities
    targets: Tensor NxT
        True index-encoded translations
    alphabet: Alphabet
        Alphabet object
    
    """
    length_mask = alphabet.get_mask(targets)
    targets_mask = torch.zeros_like(log_predictions).scatter_(2, targets.view(*targets.shape, 1), 1.0)
    mask = targets_mask * length_mask.view(*length_mask.shape, 1)
    #print(mask.sum(1, keepdim=True).sum(2, keepdim=True))
    return (log_predictions * mask / (mask.sum(2, keepdim=True).sum(1, keepdim=True) * -log_predictions.size(0))).sum()
 
cur_loss = 0
model.train()
for epoch in range(10):
    for i, (x, y) in enumerate(batch_iterator(train_X, train_Y)):
        inputs = Variable(torch.from_numpy(x))
        targets = Variable(torch.from_numpy(y))
        log_predictions = model(inputs, targets)
        #print(x)
        loss = cross_entropy(log_predictions, targets, be)
        #print(loss.data, log_predictions.data.min())
        loss.backward()
        cur_loss = 0.9 * cur_loss + 0.1 * loss.data[0]
        opt.step()
        opt.zero_grad()
        if i % 50 == 0:
            print(i, cur_loss)

0 0.4200156688690186
50 2.7578244786646855
100 2.476875003504594
150 2.2598716689577083
200 2.1065707596336085
250 2.01535555634957
300 1.876824888887845
350 1.7951637667744584
400 1.6633113864847724
450 1.6105036111570512
500 1.5305022972919486
550 1.5100637966351476
600 1.408566580184508
650 1.3746855779744651
700 1.3140967167384212
750 1.362946494030467
800 1.2373363631253536
850 1.2333869452152861
900 1.2481400593116936
950 1.1982124097484257
1000 1.1443143058662386
1050 1.157661315740318
1100 1.0945232432196785
1150 1.124057089954417
1200 1.0943175566566314
0 1.08509125660035
50 1.0140618514667548
100 0.9940563239478807
150 0.9764196344233472
200 0.9075055310552524
250 0.9224959634548793
300 0.9859912421201463
350 0.8751660133411552
400 0.9414093511153143
450 0.878290504321146
500 0.877956566902798
550 0.8428007214933687
600 0.8380512539265206
650 0.8925832643114691
700 0.7932154460600708
750 0.8695035853252
800 0.8293131378794679
850 0.818125127833568
900 0.816307650533187
950 0.

KeyboardInterrupt: 

In [19]:
model.translate("зеленый")

'зялёны_END_'

In [16]:
def translate(model, word, strategy='', max_length=30):

    

SyntaxError: unexpected EOF while parsing (<ipython-input-16-c58232b29394>, line 3)

In [None]:
translate(model, "")

In [22]:
from tqdm import tqdm_notebook
import editdistance as ed


scs = []
with open(TEST_FILE, "r") as ftr:
    for ruw, bew in map(lambda x: x.split(), 
                        filter(lambda x: len(x.split()) == 2, tqdm_notebook(ftr.readlines()))):
        res = model.translate(ruw)
        scs.append(ed.eval(bew, res[:-5]))
        #print(ruw, bew, res)
    
        




In [27]:
np.mean(scs)

2.4730290456431536

In [35]:
! head -n 50 data/ru-be.txt

поражением	паразай
местному	мясцовым
испанских	іспанскіх
сорока	сарока
способна	здольная
факторы	фактары
высадки	высадкі
аргентина	аргенціна
феликс	фелікс
финальном	фінальным
фактором	фактарам
оригинального	арыгінальнага
древнейших	найстаражытных
художественным	мастацкім
резиденции	рэзідэнцыі
кораблях	караблях
православие	праваслаўе
мл	мл
гражданский	грамадзянскі
транспортного	транспартнага
листьями	лісцем
поврежд	поврежд
выехал	выехаў
сицилии	сіцыліі
чтение	чытанне
повсеместно	паўсюдна
украшения	ўпрыгажэнні
передали	перадалі
проявления	праявы
вокалист	вакаліст
склад	склад
уголовное	крымінальная
работало	працавала
ушел	сышоў
запросов	запытаў
спектакля	спектакля
продать	прадаць
сл	сл
ленинского	ленінскага
землями	землямі
свидетелей	сведак
существующей	існуючай
частного	прыватнага
образцу	узоры
отмечали	адзначалі
воздействием	уздзеяннем
категорически	катэгарычна
дальнейшие	далейшыя
сообщается	паведамляецца
периоде	перыядзе
