In [15]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import itertools

In [16]:
# Training pairs; passed to load_pair_dataset, which reads one whitespace-separated word pair per line.
TRAIN_FILE = 'data/ru-be-train.txt'
# Held-out pairs; read line by line in the evaluation cell to score translations.
TEST_FILE = 'data/ru-be-test.txt'

In [17]:
MAX_LENGTH = 30  # default length of every index-encoded sequence


class Alphabet:
    """Growable symbol <-> index vocabulary with fixed-length sequence encoding.

    Two reserved symbols are always present: ``START`` (index 0) and ``END``
    (index 1). Every other symbol receives the next free index the first time
    it is looked up, so *any* lookup of an unseen symbol enlarges the alphabet.
    """

    START = '__START__'
    END = '_END_'

    def __init__(self, max_length=MAX_LENGTH):
        """Create an alphabet that knows only the two reserved symbols.

        Parameters
        ----------
        max_length : int
            Length of the lists produced by ``letter2index``. Longer words are
            cropped, shorter ones are padded with the END index.
        """
        self.max_length = max_length
        self.index2letter_ = [Alphabet.START, Alphabet.END]
        self.letter2index_ = {
            symbol: position for position, symbol in enumerate(self.index2letter_)
        }

    def get_index(self, letter):
        """Return the index of ``letter``, registering it first if it is new."""
        known = self.letter2index_.get(letter)
        if known is None:
            known = len(self.index2letter_)
            self.letter2index_[letter] = known
            self.index2letter_.append(letter)
        return known

    @property
    def start_index(self):
        """Index of the reserved START symbol."""
        return self.letter2index_[Alphabet.START]

    @property
    def end_index(self):
        """Index of the reserved END symbol."""
        return self.letter2index_[Alphabet.END]

    def index2letter(self, x):
        """Concatenate the symbols stored at the indices in ``x`` into one string."""
        symbols = [self.index2letter_[position] for position in x]
        return ''.join(symbols)

    def letter2index(self, word):
        """Encode ``word`` as a list of indices terminated and padded with END.

        At most ``max_length - 1`` symbols of the word are kept and at least one
        END index is appended; for ``max_length >= 1`` the result therefore has
        exactly ``max_length`` entries. Unseen symbols are added to the alphabet
        (including the ones that end up cropped).
        """
        encoded = [self.get_index(symbol) for symbol in word]
        end = self.get_index(Alphabet.END)
        padding = max(1, self.max_length - len(encoded))
        return encoded[:self.max_length - 1] + [end] * padding

    def __len__(self):
        """Number of symbols known so far, reserved ones included."""
        return len(self.index2letter_)

    # torch utils
    def get_length(self, input_sequence):
        """Infer per-row sequence lengths from the position of the END index.

        Compares ``input_sequence`` with the END index, takes the position that
        ``max(dim=1)`` reports for that comparison and adds one, so the END
        symbol itself is counted. Presumably ``input_sequence`` is a
        (batch, time) index tensor and the reported position is the first END
        in each row -- confirm for the torch version in use.
        """
        is_end = input_sequence == self.end_index
        _, end_position = is_end.max(dim=1)
        return end_position + 1

    def get_mask(self, input_sequence):
        """Float mask that is 1 up to and including the first END along dim 1, 0 after it."""
        ends_seen = torch.cumsum(input_sequence == self.end_index, dim=1)
        return (ends_seen < 2).float()

In [18]:
# One vocabulary per side of the corpus; both start with only START/END and grow
# as words are encoded. They are passed to load_pair_dataset as alph1 (ru) and alph2 (be).
ru = Alphabet()
be = Alphabet()

In [19]:
def load_pair_dataset(filename, alph1, alph2):
    """Read a file of word pairs and index-encode both sides.

    Each line is split on whitespace. Lines that do not consist of exactly two
    tokens (blank lines, multi-word entries) are skipped on purpose.

    Parameters
    ----------
    filename : str
        Path to a UTF-8 text file with one "<word1> <word2>" pair per line.
    alph1, alph2 : objects with a ``letter2index(word)`` method
        Encoders for the first and the second word of every pair. Whatever
        side effects ``letter2index`` has (e.g. growing a vocabulary) happen
        here, once per kept line.

    Returns
    -------
    (np.ndarray, np.ndarray)
        Two int64 arrays with one row per kept line: the encoded first words
        and the encoded second words, in file order.
    """
    x, y = [], []
    # Decode explicitly as UTF-8: without `encoding`, open() falls back to the
    # platform locale, which garbles or rejects non-ASCII (e.g. Cyrillic) text
    # on systems whose default encoding is not UTF-8.
    with open(filename, 'r', encoding='utf-8') as ftr:
        for line in ftr:
            words = line.split()
            if len(words) != 2:
                continue
            word1, word2 = words
            x.append(alph1.letter2index(word1))
            y.append(alph2.letter2index(word2))
    # Pin the dtype: NumPy's default integer is 32-bit on some platforms, while
    # the embedding lookups downstream are fed tensors made from these arrays.
    return np.array(x, dtype=np.int64), np.array(y, dtype=np.int64)

In [20]:
# Encode the training corpus: train_X holds the ru-encoded first words, train_Y the be-encoded
# second words. Encoding also registers every letter it meets in `ru` / `be`.
train_X, train_Y = load_pair_dataset(TRAIN_FILE, ru, be)

In [21]:
# Sanity check: the encoded sequences should form an integer array.
train_X.dtype

dtype('int64')

In [122]:
class MultiplicativeAttention(nn.Module):
    """Parameter-free dot-product attention over a sequence of encoder states.

    Each encoder state is scored by its dot product with the decoder state, the
    scores are softmax-normalised over the time axis, and the encoder states
    are averaged with those weights.
    """

    def __init__(self):
        super(MultiplicativeAttention, self).__init__()

    def forward(self, decoder_hidden, encoder_hiddens):
        """Return the attention-weighted sum of ``encoder_hiddens``.

        Parameters
        ----------
        decoder_hidden : Tensor (batch, hidden)
            Query vector for every batch element (floating point).
        encoder_hiddens : Tensor (batch, time, hidden)
            States to attend over (floating point).

        Returns
        -------
        Tensor (batch, hidden)
            Context vectors. The whole computation stays on the autograd graph,
            so gradients reach both arguments, including through the attention
            weights.
        """
        # (batch, time, hidden) @ (batch, hidden, 1) -> (batch, time, 1)
        scores = torch.bmm(encoder_hiddens, decoder_hidden.unsqueeze(-1))
        # Normalise over time. No Variable(...)/.data round trip here: that
        # detached the weights and cut the gradient to the decoder state.
        weights = F.softmax(scores, dim=1)
        # (batch, hidden, time) @ (batch, time, 1) -> (batch, hidden, 1)
        context = torch.bmm(encoder_hiddens.transpose(1, 2), weights)
        return context[:, :, 0]

In [123]:
# Attention module instance used for the quick shape check on toy tensors.
attention = MultiplicativeAttention()

In [124]:
# Toy inputs for a shape check: 3 batch elements, hidden size 2, 2 encoder steps.
# Cast to float explicitly -- torch.arange with an integer argument yields an
# integer tensor in current PyTorch, and softmax is not defined for integer dtypes.
decoder_hidden = torch.arange(6).float().view(3, 2)
encoder_hiddens = torch.arange(12).float().view(3, 2, 2)

In [125]:
# Shape check: 3 queries of size 2 over 2 encoder steps give one size-2 context per query, i.e. (3, 2).
attention(decoder_hidden, encoder_hiddens).shape

torch.Size([3, 2])

In [22]:
class SimpleGRUEncoder(nn.Module):
    """Embeds an index sequence, runs a GRU over it and returns one state per row.

    Parameters
    ----------
    alphabet : object supporting ``len()`` and ``get_length(batch)``
        Source vocabulary; its size fixes the number of embedding rows at
        construction time.
    embedding_size : int
        Dimension of the symbol embeddings.
    hidden_size : int
        Dimension of the GRU state.
    """

    def __init__(self, alphabet, embedding_size, hidden_size):
        super().__init__()
        self.alphabet = alphabet
        vocabulary_size = len(alphabet)
        self.embedding = nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_size,
        )
        self.gru = nn.GRU(
            input_size=embedding_size,
            hidden_size=hidden_size,
            batch_first=True,
        )

    def forward(self, input_sequence):
        """Return the GRU output at position ``length - 1`` of every row.

        ``length`` is whatever ``alphabet.get_length`` reports for the batch;
        the GRU is batch-first, so ``input_sequence`` is (batch, time) and the
        result is (batch, hidden_size).
        """
        states, _ = self.gru(self.embedding(input_sequence))
        rows = range(input_sequence.size(0))
        last_step = self.alphabet.get_length(input_sequence) - 1
        # Advanced indexing: pick states[row, last_step[row]] for every row.
        return states[rows, last_step]

In [23]:
class SimpleGRUDecoder(nn.Module):
    """One-step GRU decoder: previous token + previous state -> logits + new state.

    Parameters
    ----------
    alphabet : object supporting ``len()``
        Target vocabulary; its size fixes both the number of embedding rows
        and the number of output logits at construction time.
    embedding_size : int
        Dimension of the token embeddings.
    hidden_size : int
        Dimension of the recurrent state.
    """

    def __init__(self, alphabet, embedding_size, hidden_size):
        super().__init__()
        vocabulary_size = len(alphabet)
        self.alphabet = alphabet
        self.embedding = nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_size,
        )
        self.gru_cell = nn.GRUCell(input_size=embedding_size, hidden_size=hidden_size)
        self.logit_linear = nn.Linear(hidden_size, vocabulary_size)

    def forward(self, token, prev_h):
        """Advance the decoder by one step.

        Returns a pair ``(logits, h)``: unnormalised scores over the target
        vocabulary and the updated recurrent state.
        """
        h = self.gru_cell(self.embedding(token), prev_h)
        return self.logit_linear(h), h

In [24]:
class SimpleGRUSupervisedSeq2Seq(nn.Module):
    """GRU encoder-decoder trained with teacher forcing.

    The encoder summarises the source sequence into one vector, a linear layer
    followed by tanh turns it into the initial decoder state, and the decoder
    is then unrolled one token at a time.

    Parameters
    ----------
    src_alphabet, dst_alphabet
        Source / target vocabularies handed to the encoder / decoder.
    embedding_size : int
        Embedding dimension used by both encoder and decoder.
    hidden_size : int
        Recurrent state dimension used by both encoder and decoder.
    """

    def __init__(self, src_alphabet, dst_alphabet, embedding_size, hidden_size):
        super(SimpleGRUSupervisedSeq2Seq, self).__init__()
        self.encoder = SimpleGRUEncoder(src_alphabet, embedding_size, hidden_size)
        self.h_linear = nn.Linear(hidden_size, hidden_size)
        self.decoder = SimpleGRUDecoder(dst_alphabet, embedding_size, hidden_size)

    def start(self, batch_size):
        """Return ``batch_size`` copies of the target START index as a long tensor."""
        # Built directly in torch with an explicit dtype; going through
        # np.repeat made the integer width depend on the platform.
        return torch.full((batch_size,), self.decoder.alphabet.start_index, dtype=torch.long)

    def forward(self, input_sequence, output_sequence):
        """Score ``output_sequence`` given ``input_sequence`` with teacher forcing.

        Parameters
        ----------
        input_sequence : LongTensor (batch, src_time)
        output_sequence : LongTensor (batch, dst_time)

        Returns
        -------
        Tensor (batch, dst_time, vocabulary)
            Log-probabilities; step ``t`` is predicted from the START token
            (t = 0) or from the true token at ``t - 1``.
        """
        dec_h = torch.tanh(self.h_linear(self.encoder(input_sequence)))
        batch_size = output_sequence.size(0)
        # Decoder inputs are the targets shifted right by one: START first,
        # then every target column except the last.
        decoder_inputs = itertools.chain(
            (self.start(batch_size),),
            output_sequence.transpose(0, 1)[:-1],
        )
        logits = []
        for token in decoder_inputs:
            out, dec_h = self.decoder(token, dec_h)
            logits.append(out)
        return F.log_softmax(torch.stack(logits, dim=1), dim=-1)

    def translate(self, word, strategy='', max_length=30):
        """Greedily decode a translation of ``word``.

        Parameters
        ----------
        word : str
            Source word; encoded with the encoder's alphabet.
        strategy : str
            Currently unused; kept for interface compatibility.
        max_length : int
            Maximum number of decoding steps.

        Returns
        -------
        str
            The decoded symbols joined together. If the END symbol was produced
            within ``max_length`` steps it is the last symbol of the string.

        Notes
        -----
        Switches the module to eval mode and leaves it there.
        """
        self.eval()
        src_alphabet = self.encoder.alphabet
        dst_alphabet = self.decoder.alphabet
        input_sequence = torch.tensor([src_alphabet.letter2index(word)], dtype=torch.long)
        decoded = []
        with torch.no_grad():  # inference only: do not build autograd graphs
            hidden = torch.tanh(self.h_linear(self.encoder(input_sequence)))
            token = self.start(1)
            for _ in range(max_length):
                # Feed the previously emitted token (not the hidden state) back in.
                out, hidden = self.decoder(token, hidden)
                token = out.max(1)[1]
                index = token.item()
                decoded.append(index)
                if index == dst_alphabet.end_index:
                    break
        return dst_alphabet.index2letter(decoded)

In [57]:
def batch_iterator(X, Y=None, batch_size=32):
    """Yield shuffled mini-batches of ``X`` (and of ``Y`` when it is given).

    A fresh permutation of the row indices is drawn from NumPy's global random
    state on every call. The last batch is smaller when the number of rows is
    not a multiple of ``batch_size``.

    Parameters
    ----------
    X : np.ndarray
        Samples, indexed along the first axis.
    Y : np.ndarray, optional
        Targets with as many rows as ``X``.
    batch_size : int
        Maximum number of rows per batch.

    Yields
    ------
    ``X_batch`` when ``Y`` is None, otherwise the pair ``(X_batch, Y_batch)``
    built from the same rows.
    """
    n_samples = X.shape[0]
    assert Y is None or n_samples == Y.shape[0]
    order = np.arange(n_samples)
    np.random.shuffle(order)
    for start in range(0, n_samples, batch_size):
        chosen = order[start:start + batch_size]
        if Y is None:
            yield X[chosen]
        else:
            yield X[chosen], Y[chosen]

In [82]:
# ru -> be model with embedding_size=24 and hidden_size=256. The embedding tables are sized
# from len(ru) / len(be) at construction, so the alphabets must already be populated here.
model = SimpleGRUSupervisedSeq2Seq(ru, be, 24, 256)
# Adam over all model parameters, learning rate 1e-3.
opt = optim.Adam(model.parameters(), lr=1e-3)

In [83]:
def cross_entropy(log_predictions, targets, alphabet):
    """Length-normalised cross entropy for a batch of padded sequences.

    For every sequence the negative log-probabilities of the true tokens are
    averaged over the positions selected by ``alphabet.get_mask(targets)``;
    the per-sequence averages are then averaged over the batch.

    Parameters
    ----------
    log_predictions : Tensor N x T x V
        Log-probabilities over the V target symbols.
    targets : LongTensor N x T
        True index-encoded translations.
    alphabet : Alphabet
        Provides ``get_mask(targets)``, an N x T float mask of the positions
        that count towards the loss.

    Returns
    -------
    Tensor (scalar)
        The loss; differentiable with respect to ``log_predictions``.
    """
    length_mask = alphabet.get_mask(targets)
    # Pick the log-probability of the true symbol at each position directly,
    # instead of multiplying by a dense one-hot mask: 0 * -inf is NaN, so a
    # single -inf on a non-target symbol used to poison the whole loss.
    target_log_probs = log_predictions.gather(2, targets.unsqueeze(-1)).squeeze(-1)
    # Select rather than multiply for the same reason on masked-out positions.
    kept = torch.where(length_mask > 0, target_log_probs, torch.zeros_like(target_log_probs))
    per_sequence = kept.sum(dim=1) / length_mask.sum(dim=1)
    return -per_sequence.sum() / log_predictions.size(0)
 
# Train for 10 passes over the training set, printing an exponentially smoothed
# loss every 50 batches. The average starts at 0, so the first few printed
# values underestimate the true loss.
cur_loss = 0
model.train()
for epoch in range(10):
    for i, (x, y) in enumerate(batch_iterator(train_X, train_Y)):
        inputs = torch.from_numpy(x)
        targets = torch.from_numpy(y)
        # Clear gradients at the start of the step: if a previous run of this
        # cell was interrupted between backward() and the reset, stale
        # gradients would otherwise leak into the first update.
        opt.zero_grad()
        log_predictions = model(inputs, targets)
        loss = cross_entropy(log_predictions, targets, be)
        loss.backward()
        opt.step()
        # loss is a 0-dim tensor; .item() extracts the Python float
        # (indexing it with [0] raises in current PyTorch).
        cur_loss = 0.9 * cur_loss + 0.1 * loss.item()
        if i % 50 == 0:
            print(i, cur_loss)

0 0.4190867424011231
50 2.9168739572860716
100 2.631720038443434
150 2.4751932120122664
200 2.3549136442284455
250 2.2292583443003786
300 2.153186160866823
350 2.0728946457414272
400 2.007772875134808
450 1.8948324387015085
500 1.8464238839206966
550 1.7897904365669128
600 1.7117181047641865
650 1.682086088904152
700 1.596041369400188
750 1.5873877769618059
800 1.532857735104436
850 1.4270432315501707
900 1.3974845777059077
950 1.394875553452432
1000 1.3855915241216377
1050 1.2881448232991752
1100 1.231422457554406
1150 1.282550341639471
1200 1.1620309995468996
0 1.2193324077870857
50 1.1964899299263552
100 1.1416526493590347
150 1.1427525722214957
200 1.1290796344589427
250 1.1790902400506809
300 1.06338651959006
350 1.083659558269009
400 1.0299134590515389
450 1.0181521814348087
500 1.0280448702756815
550 0.9869868898738503
600 0.9744745563684433
650 0.9919636562126649
700 1.0076098636803814
750 1.0001313761563295
800 0.9604946596105047
850 0.9231747298422102
900 0.9051967137445381
9

KeyboardInterrupt: 

In [54]:
def translate(model, word, strategy='', max_length=30):
    """Greedily decode a translation of ``word`` with a trained seq2seq model.

    Parameters
    ----------
    model : nn.Module
        Must expose ``encoder``, ``h_linear``, ``decoder`` and ``start(batch_size)``;
        ``encoder.alphabet`` / ``decoder.alphabet`` encode the input and decode
        the output.
    word : str
        Source word.
    strategy : str
        Currently unused; kept for interface compatibility.
    max_length : int
        Maximum number of decoding steps.

    Returns
    -------
    str
        The decoded symbols joined together. If the END symbol was produced
        within ``max_length`` steps it is the last symbol of the string;
        otherwise the string has no END marker.

    Notes
    -----
    Switches ``model`` to eval mode and leaves it there.
    """
    model.eval()
    src_alphabet = model.encoder.alphabet
    dst_alphabet = model.decoder.alphabet
    # Explicit long dtype: the integer width of np.array(list_of_ints) depends
    # on the platform.
    input_sequence = torch.tensor([src_alphabet.letter2index(word)], dtype=torch.long)
    decoded = []
    with torch.no_grad():  # inference only: do not build autograd graphs
        hidden = torch.tanh(model.h_linear(model.encoder(input_sequence)))
        token = model.start(1)
        for _ in range(max_length):
            out, hidden = model.decoder(token, hidden)
            token = out.max(1)[1]  # greedy choice: arg-max over the vocabulary
            index = token.item()
            decoded.append(index)
            if index == dst_alphabet.end_index:
                break
    return dst_alphabet.index2letter(decoded)
    

In [55]:
# Edge case: an empty source word. The recorded result is just the END marker.
translate(model, "")

'_END_'

In [86]:
from tqdm import tqdm_notebook
import editdistance as ed


# Score the model on the held-out pairs: one edit distance between the
# reference and the predicted word per well-formed (two-token) test line.
scs = []
end_marker = Alphabet.END
with open(TEST_FILE, "r", encoding="utf-8") as ftr:
    for line in tqdm_notebook(ftr.readlines()):
        pair = line.split()
        if len(pair) != 2:
            continue
        ruw, bew = pair
        res = translate(model, ruw)
        # Strip the END marker only when it is really there: a prediction cut
        # off at max_length has no marker, and blindly dropping the last five
        # characters would delete real letters and distort the distance.
        if res.endswith(end_marker):
            res = res[:-len(end_marker)]
        scs.append(ed.eval(bew, res))
        
        




In [87]:
# Mean edit distance between reference and predicted words over the scored test pairs.
np.mean(scs)

1.9062119983785974

In [35]:
! head -n 50 data/ru-be.txt

поражением	паразай
местному	мясцовым
испанских	іспанскіх
сорока	сарока
способна	здольная
факторы	фактары
высадки	высадкі
аргентина	аргенціна
феликс	фелікс
финальном	фінальным
фактором	фактарам
оригинального	арыгінальнага
древнейших	найстаражытных
художественным	мастацкім
резиденции	рэзідэнцыі
кораблях	караблях
православие	праваслаўе
мл	мл
гражданский	грамадзянскі
транспортного	транспартнага
листьями	лісцем
поврежд	поврежд
выехал	выехаў
сицилии	сіцыліі
чтение	чытанне
повсеместно	паўсюдна
украшения	ўпрыгажэнні
передали	перадалі
проявления	праявы
вокалист	вакаліст
склад	склад
уголовное	крымінальная
работало	працавала
ушел	сышоў
запросов	запытаў
спектакля	спектакля
продать	прадаць
сл	сл
ленинского	ленінскага
землями	землямі
свидетелей	сведак
существующей	існуючай
частного	прыватнага
образцу	узоры
отмечали	адзначалі
воздействием	уздзеяннем
категорически	катэгарычна
дальнейшие	далейшыя
сообщается	паведамляецца
периоде	перыядзе
