In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import itertools

In [4]:
TRAIN_FILE = 'data/ru-be-train.txt'
TEST_FILE = 'data/ru-be-test.txt'

In [5]:
MAX_LENGTH = 30
 
class Alphabet:
    START = '__START__'
    END = '_END_'
 
    def __init__(self, max_length=MAX_LENGTH):
        """Initialize the class which works with letter and index representations of sequences.
        Parameters
        ----------
        max_length : int
            The largest permitted length for sequence. Longer sequences are cropped.
        """
        self.max_length = max_length
        self.letter2index_ = {Alphabet.START : 0, Alphabet.END : 1}
        self.index2letter_ = [Alphabet.START, Alphabet.END]
        
    def get_index(self, letter):
        if letter not in self.letter2index_:
            self.letter2index_[letter] = len(self.index2letter_)
            self.index2letter_.append(letter)
        return self.letter2index_[letter]
    
    @property
    def start_index(self):
        return self.letter2index_[Alphabet.START]
    
    @property
    def end_index(self):
        return self.letter2index_[Alphabet.END]
    
    def index2letter(self, x):
        return ''.join(self.index2letter_[index] for index in x)
    
    def letter2index(self, word):
        lst = [self.get_index(letter) for letter in word]
        return lst[:self.max_length - 1] + [self.get_index(Alphabet.END)] * max(1, self.max_length - len(lst))
    
    def __len__(self):
        return len(self.index2letter_)
    
    # torch utils
    def get_length(self, input_sequence):
        """Infers the lengths of sequences in batch
        
        """
        return (input_sequence == self.end_index).max(dim=1)[1] + 1
    
    def get_mask(self, input_sequence):
        return (torch.cumsum(input_sequence == self.end_index, dim=1) < 2).float()

In [6]:
ru = Alphabet()
be = Alphabet()

In [7]:
def load_pair_dataset(filename, alph1, alph2):
    x, y = [], []
    with open(filename, 'r') as ftr:
        for line in ftr:
            try:
                word1, word2 = line.split()
            except ValueError:
                continue
            x.append(alph1.letter2index(word1))
            y.append(alph2.letter2index(word2))
    return np.array(x), np.array(y)       

In [8]:
train_X, train_Y = load_pair_dataset(TRAIN_FILE, ru, be)

In [9]:
train_X.dtype

dtype('int64')

In [37]:
def seq2seq_softmax_with_mask(entries, mask):
    entries = torch.where(mask > 0.5, entries[:,:,0], -np.inf)
    maxs = entries.max(1, keepdim=True)[0]
    #print(entries.shape, maxs.shape, mask.shape)
    entries = torch.exp(entries - maxs) * mask
    print(entries.max(1))
    return entries / (entries.sum(dim=1, keepdim=True) + 1e-15)


class MultiplicativeAttentionWithMask(nn.Module):
    def __init__(self):
        super(MultiplicativeAttentionWithMask, self).__init__()
        
    def forward(self, decoder_hidden, encoder_hiddens, encoder_mask):
        """
        decoder_hidden: NxH
        encoder_hiddens: NxTxH
        
        """
        weights = torch.bmm(encoder_hiddens, decoder_hidden.unsqueeze(2))
        weights = seq2seq_softmax_with_mask(weights, encoder_mask)
        return torch.bmm(encoder_hiddens.transpose(1, 2), weights.unsqueeze(2))[:,:,0]

In [38]:
class SimpleGRUEncoder(nn.Module):
    def __init__(self, alphabet, embedding_size, hidden_size):
        super(SimpleGRUEncoder, self).__init__()
        self.alphabet = alphabet
        self.embedding = nn.Embedding(num_embeddings=len(self.alphabet), embedding_dim=embedding_size)
        self.gru = nn.GRU(input_size=embedding_size, hidden_size=hidden_size, batch_first=True)
        
    def forward(self, input_sequence):
        batch_size = input_sequence.size(0)
        embeddings = self.embedding(input_sequence)
        out, _ = self.gru(embeddings)
        return out, self.alphabet.get_mask(input_sequence)

In [39]:
class SimpleGRUDecoderWithAttention(nn.Module):
    def __init__(self, alphabet, embedding_size, hidden_size):
        super(SimpleGRUDecoderWithAttention, self).__init__()
        self.alphabet = alphabet
        self.embedding = nn.Embedding(num_embeddings=len(alphabet), embedding_dim=embedding_size)
        self.gru_cell = nn.GRUCell(input_size=embedding_size + hidden_size, hidden_size=hidden_size)
        self.logit_linear = nn.Linear(hidden_size, len(alphabet))
        self.attention = MultiplicativeAttentionWithMask()
        
    def forward(self, token, prev_h, encoder_hs, encoder_mask):
        embedding = self.embedding(token)
        attention = self.attention(prev_h, encoder_hs, encoder_mask)
        #print(attention.shape, embedding.shape)
        h = self.gru_cell(torch.cat((embedding, attention), dim=1), prev_h)
        out = self.logit_linear(h)
        return out, h

In [40]:
class SimpleGRUSupervisedSeq2Seq(nn.Module):
    def __init__(self, src_alphabet, dst_alphabet, embedding_size, hidden_size):
        super(SimpleGRUSupervisedSeq2Seq, self).__init__()
        self.encoder = SimpleGRUEncoder(src_alphabet, embedding_size, hidden_size)
        self.h_linear = nn.Linear(hidden_size, hidden_size)
        self.decoder = SimpleGRUDecoderWithAttention(dst_alphabet, embedding_size, hidden_size)
        
    def start(self, batch_size):
        return Variable(torch.from_numpy(np.repeat(self.decoder.alphabet.start_index, batch_size)))
    
    def middle_layer(self, out, mask):
        #print(mask.sum(1))
        return F.tanh(self.h_linear(out[range(out.shape[0]), mask.sum(1).long() - 1]))
    
    def forward(self, input_sequence, output_sequence):
        enc_out, enc_mask = self.encoder(input_sequence)
        dec_h = self.middle_layer(enc_out, enc_mask)
        logits = []
        for x in itertools.chain((self.start(output_sequence.size(0)),), output_sequence.transpose(0, 1)[:-1]):
            out, dec_h = self.decoder(x, dec_h, enc_out, enc_mask)
            logits.append(out)
        return F.log_softmax(torch.stack(logits, dim=1), dim=-1)
    
    def translate(self, word, strategy='', max_length=30):
        self.eval()
        input_sequence = Variable(torch.from_numpy(np.array([model.encoder.alphabet.letter2index(word)])))
        #print(input_sequence.shape)
        enc_out, enc_mask = self.encoder(input_sequence)
        hidden = self.middle_layer(enc_out, enc_mask)
        token = self.start(1)
        #print(token.shape, hidden.shape)
        lst = []
        for i in range(max_length):
            out, hidden = self.decoder(token, hidden, enc_out, enc_mask)
            token = out.max(1)[1]
            #print(token, out)
            lst.append(token.data[0])
            if token.data[0] == self.decoder.alphabet.end_index:
                break
        return ''.join(model.decoder.alphabet.index2letter(lst))

In [41]:
def batch_iterator(X, Y=None, batch_size=32):
    assert Y is None or X.shape[0] == Y.shape[0]
    ind = np.arange(X.shape[0])
    np.random.shuffle(ind)
    for i in range(0, X.shape[0], batch_size):
        if Y is not None:
            yield X[ind[i:i + batch_size]], Y[ind[i:i + batch_size]]
        else:
            yield X[ind[i:i + batch_size]]

In [42]:
model = SimpleGRUSupervisedSeq2Seq(ru, be, 65, 256)
opt = optim.Adam(model.parameters(), lr=1e-3)

In [43]:
model.state_dict

<bound method Module.state_dict of SimpleGRUSupervisedSeq2Seq(
  (encoder): SimpleGRUEncoder(
    (embedding): Embedding(35, 65)
    (gru): GRU(65, 256, batch_first=True)
  )
  (h_linear): Linear(in_features=256, out_features=256)
  (decoder): SimpleGRUDecoderWithAttention(
    (embedding): Embedding(65, 65)
    (gru_cell): GRUCell(321, 256)
    (logit_linear): Linear(in_features=256, out_features=65)
    (attention): MultiplicativeAttentionWithMask(
    )
  )
)>

In [44]:
def cross_entropy(log_predictions, targets, alphabet):
    """ Cross entropy loss for sequences
    Parameters
    ---------
    log_predictions: Tensor NxTxH
        Log probabilities
    targets: Tensor NxT
        True index-encoded translations
    alphabet: Alphabet
        Alphabet object
    
    """
    length_mask = alphabet.get_mask(targets)
    targets_mask = torch.zeros_like(log_predictions).scatter_(2, targets.view(*targets.shape, 1), 1.0)
    mask = targets_mask * length_mask.view(*length_mask.shape, 1)
    #print(mask.sum(1, keepdim=True).sum(2, keepdim=True))
    return (log_predictions * mask / (mask.sum(2, keepdim=True).sum(1, keepdim=True) * -log_predictions.size(0))).sum()
 
cur_loss = 0
model.train()
for epoch in range(10):
    for i, (x, y) in enumerate(batch_iterator(train_X, train_Y)):
        inputs = Variable(torch.from_numpy(x))
        targets = Variable(torch.from_numpy(y))
        log_predictions = model(inputs, targets)
        #print(x)
        loss = cross_entropy(log_predictions, targets, be)
        #print(loss.data, log_predictions.data.min())
        loss.backward()
        cur_loss = 0.9 * cur_loss + 0.1 * loss.data[0]
        opt.step()
        opt.zero_grad()
        if i % 50 == 0:
            print(i, cur_loss)
        #break
    #break

AttributeError: module 'torch' has no attribute 'where'

In [None]:
model.translate("зеленый")

In [16]:
def translate(model, word, strategy='', max_length=30):

    

SyntaxError: unexpected EOF while parsing (<ipython-input-16-c58232b29394>, line 3)

In [None]:
translate(model, "")

In [29]:
from tqdm import tqdm_notebook
import editdistance as ed


scs = []
with open(TEST_FILE, "r") as ftr:
    for ruw, bew in map(lambda x: x.split(), 
                        filter(lambda x: len(x.split()) == 2, tqdm_notebook(ftr.readlines()[:10]))):
        res = model.translate(ruw)
        scs.append(ed.eval(bew, res[:-5]))
        print(ruw, bew, res)
    
        

указателя паказальніка указаталя_END_
развитые развітыя развітыя_END_
томский томскі томскі_END_
софией сафіяй сафіей_END_
иски пазовы іскі_END_
сдаться здацца сдатацца_END_
дюка дзюка дюка_END_
врожд врожд врожд_END_
груз груз груз_END_
стартовом стартавым стартавам_END_



In [27]:
np.mean(scs)

2.4730290456431536

In [35]:
! head -n 50 data/ru-be.txt

поражением	паразай
местному	мясцовым
испанских	іспанскіх
сорока	сарока
способна	здольная
факторы	фактары
высадки	высадкі
аргентина	аргенціна
феликс	фелікс
финальном	фінальным
фактором	фактарам
оригинального	арыгінальнага
древнейших	найстаражытных
художественным	мастацкім
резиденции	рэзідэнцыі
кораблях	караблях
православие	праваслаўе
мл	мл
гражданский	грамадзянскі
транспортного	транспартнага
листьями	лісцем
поврежд	поврежд
выехал	выехаў
сицилии	сіцыліі
чтение	чытанне
повсеместно	паўсюдна
украшения	ўпрыгажэнні
передали	перадалі
проявления	праявы
вокалист	вакаліст
склад	склад
уголовное	крымінальная
работало	працавала
ушел	сышоў
запросов	запытаў
спектакля	спектакля
продать	прадаць
сл	сл
ленинского	ленінскага
землями	землямі
свидетелей	сведак
существующей	існуючай
частного	прыватнага
образцу	узоры
отмечали	адзначалі
воздействием	уздзеяннем
категорически	катэгарычна
дальнейшие	далейшыя
сообщается	паведамляецца
периоде	перыядзе
