In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import itertools

In [36]:
# Paths to the word-pair data files. TRAIN_FILE is the file passed to
# load_pair_dataset below; TEST_FILE is not referenced in the cells shown here.
TRAIN_FILE = 'data/ru-be-train.txt'
TEST_FILE = 'data/ru-be-test.txt'

In [3]:
MAX_LENGTH = 30


class Alphabet:
    """Growable symbol <-> integer-index table with fixed-length word encoding.

    Index 0 is reserved for the START symbol and index 1 for the END symbol.
    Any other symbol receives the next free index the first time it is looked
    up, so encoding a word may enlarge the alphabet as a side effect.
    """

    START = '__START__'
    END = '_END_'

    def __init__(self, max_length=MAX_LENGTH):
        """Create an alphabet holding only START and END.

        max_length: exact length of every list produced by ``letter2index``.
        """
        self.max_length = max_length
        self.index2letter_ = [Alphabet.START, Alphabet.END]
        self.letter2index_ = {
            symbol: position for position, symbol in enumerate(self.index2letter_)
        }

    def get_index(self, letter):
        """Return the index of ``letter``, registering it first if it is new."""
        try:
            return self.letter2index_[letter]
        except KeyError:
            fresh_index = len(self.index2letter_)
            self.letter2index_[letter] = fresh_index
            self.index2letter_.append(letter)
            return fresh_index

    @property
    def start_index(self):
        """Index of the START symbol."""
        return self.letter2index_[Alphabet.START]

    @property
    def end_index(self):
        """Index of the END symbol."""
        return self.letter2index_[Alphabet.END]

    def index2letter(self, x):
        """Concatenate the symbols stored at the indices in iterable ``x``."""
        symbols = [self.index2letter_[position] for position in x]
        return ''.join(symbols)

    def letter2index(self, word):
        """Encode ``word`` as a list of exactly ``max_length`` indices.

        Words are cut to ``max_length - 1`` symbols and the remainder is filled
        with END, so the result always contains at least one END index. Every
        symbol of ``word`` is registered, including symbols that get cut off.
        """
        codes = [self.get_index(symbol) for symbol in word]
        kept = codes[:self.max_length - 1]
        pad_count = max(1, self.max_length - len(codes))
        return kept + [self.get_index(Alphabet.END)] * pad_count

    def __len__(self):
        """Number of symbols registered so far (START and END included)."""
        return len(self.index2letter_)

    # torch utils
    def get_length(self, input_sequence):
        """Per-row length: the position ``max`` reports for the END indicator along dim 1, plus one.

        NOTE(review): this is the first END only if ``max`` breaks ties in favour
        of the earliest position - confirm for the torch version/device in use.
        """
        is_end = input_sequence == self.end_index
        _, end_position = is_end.max(dim=1)
        return end_position + 1

    def get_mask(self, input_sequence):
        """Float mask that is 1.0 wherever fewer than two END symbols have been seen so far.

        For END-padded rows this keeps every position up to and including the
        first END and zeroes the padding after it.
        """
        ends_seen = torch.cumsum(input_sequence == self.end_index, dim=1)
        return (ends_seen < 2).float()

In [4]:
# Source (`ru`) and target (`be`) alphabets, as used by the model cell below.
# Both start with only the START/END symbols and grow as words are encoded.
ru = Alphabet()
be = Alphabet()

In [5]:
def load_pair_dataset(filename, alph1, alph2, encoding='utf-8'):
    """Read a two-column word-pair file and encode both columns as index matrices.

    Each line is split on whitespace; lines that do not contain exactly two
    fields (blank lines, multi-word entries) are skipped.

    Parameters
    ----------
    filename : path of the text file to read.
    alph1, alph2 : objects exposing ``letter2index(word) -> list of int``; used
        for the first and second column respectively. With ``Alphabet`` this
        call also registers unseen symbols, so loading mutates both alphabets.
    encoding : text encoding of the file. Given explicitly because the data is
        Cyrillic and the platform default encoding is not always UTF-8.

    Returns
    -------
    (x, y) : two ``int64`` numpy arrays with one row per accepted line.
        int64 is forced so the arrays have the same dtype on every platform
        (and for an empty file) before being handed to ``torch.from_numpy``.
    """
    x, y = [], []
    with open(filename, 'r', encoding=encoding) as ftr:
        for line in ftr:
            fields = line.split()
            if len(fields) != 2:
                continue
            word1, word2 = fields
            x.append(alph1.letter2index(word1))
            y.append(alph2.letter2index(word2))
    return np.array(x, dtype=np.int64), np.array(y, dtype=np.int64)

In [6]:
# Encode the training pairs. load_pair_dataset calls letter2index on `ru`/`be`,
# so this also registers every symbol of the accepted lines in both alphabets.
train_X, train_Y = load_pair_dataset(TRAIN_FILE, ru, be)

In [7]:
# Check the dtype of the encoded matrix; it is later wrapped with
# torch.from_numpy and used as input to nn.Embedding.
train_X.dtype

dtype('int64')

In [8]:
class SimpleGRUEncoder(nn.Module):
    """Embeds an index sequence, runs a batch-first GRU over it and returns one state per row.

    The returned state is the GRU output at position ``length - 1``, where
    ``length`` comes from ``alphabet.get_length``.
    """

    def __init__(self, alphabet, embedding_size, hidden_size):
        """alphabet must support ``len()`` and ``get_length(tensor)``."""
        super().__init__()
        self.alphabet = alphabet
        vocab_size = len(alphabet)
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size)
        self.gru = nn.GRU(input_size=embedding_size, hidden_size=hidden_size, batch_first=True)

    def forward(self, input_sequence):
        """Return the GRU output selected at each row's ``length - 1`` position.

        input_sequence is indexed as (batch, time); the result has one row per
        batch element.
        """
        embedded = self.embedding(input_sequence)
        states, _ = self.gru(embedded)
        # Position of the last meaningful step in every row.
        last_step = self.alphabet.get_length(input_sequence) - 1
        rows = range(input_sequence.size(0))
        return states[rows, last_step]

In [9]:
class SimpleGRUDecoder(nn.Module):
    """Single-step decoder: embedding -> GRU cell -> linear layer over the alphabet."""

    def __init__(self, alphabet, embedding_size, hidden_size):
        """alphabet must support ``len()``; its size fixes both the embedding table and the output width."""
        super().__init__()
        self.alphabet = alphabet
        n_symbols = len(alphabet)
        self.embedding = nn.Embedding(num_embeddings=n_symbols, embedding_dim=embedding_size)
        self.gru_cell = nn.GRUCell(input_size=embedding_size, hidden_size=hidden_size)
        self.logit_linear = nn.Linear(hidden_size, n_symbols)

    def forward(self, token, prev_h):
        """Advance one step.

        token: indices of the previous symbols (one per batch row).
        prev_h: previous hidden state of the GRU cell.
        Returns ``(logits, new_hidden)``; logits are unnormalised scores with
        one column per alphabet symbol.
        """
        next_h = self.gru_cell(self.embedding(token), prev_h)
        return self.logit_linear(next_h), next_h

In [10]:
class SimpleGRUSupervisedSeq2Seq(nn.Module):
    """GRU encoder-decoder trained with teacher forcing.

    The encoder state is passed through a linear layer and ``tanh`` to produce
    the decoder's initial hidden state.
    """

    def __init__(self, src_alphabet, dst_alphabet, embedding_size, hidden_size):
        """Build the encoder (source alphabet), the bridge layer and the decoder (target alphabet)."""
        super(SimpleGRUSupervisedSeq2Seq, self).__init__()
        self.encoder = SimpleGRUEncoder(src_alphabet, embedding_size, hidden_size)
        self.h_linear = nn.Linear(hidden_size, hidden_size)
        self.decoder = SimpleGRUDecoder(dst_alphabet, embedding_size, hidden_size)

    def start(self, batch_size):
        """Return ``batch_size`` copies of the target START index as an int64 tensor."""
        start_index = self.decoder.alphabet.start_index
        # int64 is forced so the tensor is a valid embedding index on every platform.
        tokens = np.repeat(start_index, batch_size).astype(np.int64)
        return Variable(torch.from_numpy(tokens))

    def forward(self, input_sequence, output_sequence):
        """Teacher-forced pass returning log-probabilities for every target position.

        The decoder is fed START followed by all target columns except the last,
        so the prediction at step ``t`` is conditioned on the true symbols before
        ``t``. The result is indexed as (batch, time, symbol).
        """
        enc_h = self.encoder(input_sequence)
        dec_h = torch.tanh(self.h_linear(enc_h))
        batch_size = output_sequence.size(0)
        # Iterate over time steps: transpose makes the time axis the leading one.
        teacher_tokens = output_sequence.transpose(0, 1)[:-1]
        logits = []
        for x in itertools.chain((self.start(batch_size),), teacher_tokens):
            out, dec_h = self.decoder(x, dec_h)
            logits.append(out)
        return F.log_softmax(torch.stack(logits, dim=1), dim=-1)

    def translate(self, word, strategy='', max_length=30):
        """Greedily decode the translation of a single ``word``.

        Parameters
        ----------
        word : source string; encoded with the encoder's alphabet.
        strategy : accepted for interface compatibility; currently unused.
        max_length : maximum number of symbols to generate.

        Returns
        -------
        The decoded string. Decoding stops after the END symbol is produced, and
        that END symbol is part of the returned string.

        Side effect: switches the module to evaluation mode.
        """
        self.eval()
        src_alphabet = self.encoder.alphabet
        dst_alphabet = self.decoder.alphabet
        encoded = np.array([src_alphabet.letter2index(word)], dtype=np.int64)
        input_sequence = Variable(torch.from_numpy(encoded))
        hidden = torch.tanh(self.h_linear(self.encoder(input_sequence)))
        token = self.start(1)
        decoded = []
        for _ in range(max_length):
            # Feed the previously produced token (START on the first step).
            out, hidden = self.decoder(token, hidden)
            token = out.max(1)[1]
            index = int(token.data[0])
            decoded.append(index)
            if index == dst_alphabet.end_index:
                break
        return dst_alphabet.index2letter(decoded)

In [11]:
def batch_iterator(X, Y=None, batch_size=32):
    """Yield consecutive slices of at most ``batch_size`` rows.

    With ``Y`` given, yields ``(X_batch, Y_batch)`` tuples and requires both
    inputs to have the same number of rows; otherwise yields ``X_batch`` alone.
    The last batch may be shorter. Rows are taken in order, without shuffling.
    """
    has_targets = Y is not None
    assert not has_targets or X.shape[0] == Y.shape[0]
    n_rows = X.shape[0]
    for begin in range(0, n_rows, batch_size):
        window = slice(begin, begin + batch_size)
        if has_targets:
            yield X[window], Y[window]
        else:
            yield X[window]

In [12]:
# ru -> be model with embedding_size=20 and hidden_size=128, optimised with Adam (lr=1e-3).
model = SimpleGRUSupervisedSeq2Seq(ru, be, 20, 128)
opt = optim.Adam(model.parameters(), lr=1e-3)

In [13]:
def cross_entropy(log_predictions, targets, alphabet):
    """Masked negative log-likelihood, averaged per sequence and then over the batch.

    log_predictions: log-probabilities indexed as (batch, time, symbol).
    targets: target indices indexed as (batch, time).
    alphabet: object whose ``get_mask(targets)`` returns a (batch, time) float
        mask selecting the positions that count towards the loss.

    For every sequence the log-probabilities of the target symbols at the
    selected positions are summed and divided by the number of selected
    positions; the negated mean over the batch is returned as a scalar.
    """
    batch_size = log_predictions.size(0)
    valid_steps = alphabet.get_mask(targets).unsqueeze(-1)
    # One-hot encoding of the targets along the symbol axis.
    one_hot = torch.zeros_like(log_predictions).scatter_(2, targets.unsqueeze(-1), 1.0)
    mask = one_hot * valid_steps
    # Number of selected positions per sequence, kept broadcastable.
    per_sequence_count = mask.sum(1, keepdim=True).sum(2, keepdim=True)
    weighted = log_predictions * mask / (per_sequence_count * -batch_size)
    return weighted.sum()
 
# Training loop: 5 passes over the training pairs in their stored order
# (batch_iterator does not shuffle), one Adam step per batch.
# `cur_loss` is an exponential moving average of the batch loss (decay 0.9).
# It starts from 0, so the first printed values underestimate the actual loss.
cur_loss = 0
model.train()
for epoch in range(5):
    for i, (x, y) in enumerate(batch_iterator(train_X, train_Y)):
        inputs = Variable(torch.from_numpy(x))
        targets = Variable(torch.from_numpy(y))
        # Clear gradients before the forward/backward pass so nothing left over
        # from earlier cells can leak into this step.
        opt.zero_grad()
        log_predictions = model(inputs, targets)
        loss = cross_entropy(log_predictions, targets, be)
        loss.backward()
        opt.step()
        # The loss is a 0-dim tensor; .item() extracts the Python float
        # (indexing it with [0] raises on current PyTorch).
        cur_loss = 0.9 * cur_loss + 0.1 * loss.item()
        if i % 50 == 0:
            print(i, cur_loss)

0 0.4267039775848389
50 3.017125253379641
100 2.8566916491802283
150 2.6569024035058244
200 2.5558981824584874
250 2.4629374067439977
300 2.448218905982744
350 2.408644325049544
400 2.3337143830146885
450 2.2942711534909717
500 2.2577187371134197
550 2.225521624412744
600 2.191419548497916
650 2.1366446513904354
700 2.096031303106419
750 2.0582389565804435
800 1.9783964410218875
850 1.9377282285932376
900 1.903832688812762
950 1.8791217350708709
1000 1.8393332935364415
1050 1.817865797141295
1100 1.7627719991147452
1150 1.7399911065736364
1200 1.7073639922944785
1250 1.6991308142751853
1300 1.631982261550076
1350 1.646041470315422
1400 1.6432528764843082
1450 1.5972227410757223
1500 1.6097965745619114
1550 1.5725634690439907
1600 1.4905670423466602
1650 1.4828901499001503
1700 1.5043077476634845
1750 1.4790364644010026
1800 1.384018368218199
1850 1.3745882533296847
0 1.3620776218242419
50 1.3671799249667576
100 1.3361650960330826
150 1.3489249794967286
200 1.34003538274784
250 1.277185

In [14]:
# `x` is the batch variable bound by the training loop above, i.e. the last
# source batch that loop processed.
x

array([[ 5,  4, 15,  9, 12,  7, 14, 12,  3,  4, 18,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [12,  7,  8,  8,  9, 11, 13,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [25,  9,  4,  3, 14,  5, 28,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [11, 10,  3, 21, 13, 12,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [10,  9,  4,  8, 18, 15,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [ 2,  4,  3, 12,  3, 12,  9,  2,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [ 5, 20,  5, 10,  5,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
       [29, 12,  3, 16,  1,  1,  1,  1,  

In [30]:
def translate(model, word, strategy='', max_length=30):
    """Greedily decode the translation of a single ``word`` with ``model``.

    Parameters
    ----------
    model : object exposing ``encoder`` and ``decoder`` (each with an
        ``alphabet``), ``h_linear``, ``start(batch_size)`` and ``eval()``.
    word : source string; encoded with ``model.encoder.alphabet``.
    strategy : accepted for interface compatibility; currently unused.
    max_length : maximum number of symbols to generate.

    Returns
    -------
    The decoded string. Decoding stops after the END symbol is produced, and
    that END symbol is part of the returned string.

    Side effect: switches ``model`` to evaluation mode.
    """
    model.eval()
    src_alphabet = model.encoder.alphabet
    dst_alphabet = model.decoder.alphabet
    encoded = np.array([src_alphabet.letter2index(word)], dtype=np.int64)
    input_sequence = Variable(torch.from_numpy(encoded))
    hidden = torch.tanh(model.h_linear(model.encoder(input_sequence)))
    token = model.start(1)
    decoded = []
    for _ in range(max_length):
        out, hidden = model.decoder(token, hidden)
        # Greedy choice: the highest-scoring symbol becomes the next input.
        token = out.max(1)[1]
        index = int(token.data[0])
        decoded.append(index)
        if index == dst_alphabet.end_index:
            break
    return dst_alphabet.index2letter(decoded)
    

In [37]:
# Decode one Russian word with the model; the returned string includes the END marker.
translate(model, "дерево")

'дрэва_END_'

In [35]:
# Show the first 50 lines of data/ru-be.txt via a shell command (IPython `!` syntax).
! head -n 50 data/ru-be.txt

поражением	паразай
местному	мясцовым
испанских	іспанскіх
сорока	сарока
способна	здольная
факторы	фактары
высадки	высадкі
аргентина	аргенціна
феликс	фелікс
финальном	фінальным
фактором	фактарам
оригинального	арыгінальнага
древнейших	найстаражытных
художественным	мастацкім
резиденции	рэзідэнцыі
кораблях	караблях
православие	праваслаўе
мл	мл
гражданский	грамадзянскі
транспортного	транспартнага
листьями	лісцем
поврежд	поврежд
выехал	выехаў
сицилии	сіцыліі
чтение	чытанне
повсеместно	паўсюдна
украшения	ўпрыгажэнні
передали	перадалі
проявления	праявы
вокалист	вакаліст
склад	склад
уголовное	крымінальная
работало	працавала
ушел	сышоў
запросов	запытаў
спектакля	спектакля
продать	прадаць
сл	сл
ленинского	ленінскага
землями	землямі
свидетелей	сведак
существующей	існуючай
частного	прыватнага
образцу	узоры
отмечали	адзначалі
воздействием	уздзеяннем
категорически	катэгарычна
дальнейшие	далейшыя
сообщается	паведамляецца
периоде	перыядзе
