In [28]:
! pip3 install --user torch torchvision

Collecting torch
  Downloading https://files.pythonhosted.org/packages/e8/c5/0763a145e051ce7c84c128621693d1c5dfad5a42d551e8d79742261002e2/torch-0.3.1-cp35-cp35m-manylinux1_x86_64.whl (496.4MB)
[K    100% |████████████████████████████████| 496.4MB 2.3kB/s eta 0:00:01  7% |██▌                             | 38.6MB 35.4MB/s eta 0:00:13    10% |███▍                            | 52.1MB 39.0MB/s eta 0:00:12    13% |████▍                           | 68.3MB 26.1MB/s eta 0:00:17    14% |████▊                           | 73.2MB 26.8MB/s eta 0:00:16    15% |█████                           | 76.5MB 37.9MB/s eta 0:00:12    19% |██████▏                         | 95.2MB 40.8MB/s eta 0:00:10    20% |██████▍                         | 99.5MB 33.0MB/s eta 0:00:13    24% |████████                        | 122.5MB 21.3MB/s eta 0:00:18    27% |████████▊                       | 135.9MB 42.2MB/s eta 0:00:09    31% |██████████                      | 155.7MB 44.5MB/s eta 0:00:08    33% |██████████▉             

In [133]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import itertools

In [247]:
# Paths to the word-pair files; each useful line holds two whitespace-separated
# words (first encoded with `ru`, second with `be` by load_pair_dataset).
TRAIN_FILE = 'data/ru-be-train.txt'
TEST_FILE = 'data/ru-be-test.txt'

In [248]:
MAX_LENGTH = 30


class Alphabet:
    """Letter <-> index vocabulary with fixed-length sequence encoding.

    Index 0 is ``START`` and index 1 is ``END``; any other letter receives the
    next free index the first time it is looked up. ``letter2index`` produces
    ``[START, letters..., END, END, ...]`` where the trailing ``END`` tokens
    double as padding.
    """
    START = '__START__'
    END = '_END_'

    def __init__(self, max_length=MAX_LENGTH):
        """Initialize the class which works with letter and index representations of sequences.

        Parameters
        ----------
        max_length : int
            The largest permitted length for a sequence. Longer sequences are cropped.
        """
        self.max_length = max_length
        self.letter2index_ = {Alphabet.START: 0, Alphabet.END: 1}
        self.index2letter_ = [Alphabet.START, Alphabet.END]

    def get_index(self, letter):
        """Return the index of ``letter``, registering it first if it is new."""
        if letter not in self.letter2index_:
            self.letter2index_[letter] = len(self.index2letter_)
            self.index2letter_.append(letter)
        return self.letter2index_[letter]

    @property
    def start_index(self):
        """Index of the START token."""
        return self.letter2index_[Alphabet.START]

    @property
    def end_index(self):
        """Index of the END token."""
        return self.letter2index_[Alphabet.END]

    def index2letter(self, x, with_start_end=True):
        """Decode an index sequence back into a string, stopping at the first END.

        With ``with_start_end=True`` the first element (normally START) is kept
        and the first END is rendered as the ``END`` marker. Otherwise the first
        element is skipped and END is not emitted.
        """
        result = []
        for index in x[0 if with_start_end else 1:]:
            if index == self.end_index:
                if with_start_end:
                    result.append(self.index2letter_[index])
                break
            result.append(self.index2letter_[index])
        return ''.join(result)

    def letter2index(self, word):
        """Encode ``word`` as ``[START] + letters + [END] * k``.

        Letters beyond the first ``max_length - 2`` are dropped, but they are still
        registered in the alphabet. ``k >= 1`` pads the result with END so that
        it is exactly ``max_length`` long (for ``max_length >= 2``).
        """
        lst = [self.get_index(letter) for letter in word]
        return [self.start_index] + lst[:self.max_length - 2] + [self.end_index] * max(1, self.max_length - len(lst) - 1)

    def __len__(self):
        """Number of known symbols, including START and END."""
        return len(self.index2letter_)

    # torch utils
    def get_length(self, input_sequence):
        """Infers the lengths of the sequences in batch.

        input_sequence: Tensor NxT

        returns: LongTensor N, the position of the first END plus one
        (rows without any END get the full length T).
        """
        # Counting the mask instead of taking argmax over `== END` avoids
        # depending on which of the many padding END tokens max() reports.
        return self.get_mask(input_sequence).sum(dim=1).long()

    def get_mask(self, input_sequence):
        """Infers the mask of the sequences in batch.

        input_sequence: Tensor NxT

        returns: FloatTensor NxT with 1s up to and including the first END
        token of every row and 0s after it.
        """
        is_end = (input_sequence == self.end_index).long()
        # Exclusive prefix count of END tokens: it is 0 exactly before the first
        # END and on the first END itself. The old `cumsum < 2` also kept
        # non-END tokens that follow the first END (e.g. free-running decoder
        # output such as START a END b c END).
        ends_before = torch.cumsum(is_end, dim=1) - is_end
        return (ends_before == 0).float()

    def get_one_hot_repr(self, input_sequence):
        """Produces the one-hot representation from the label representation.

        input_sequence: LongTensor NxT

        returns: FloatTensor NxTxH, where H = len(self)
        """
        onehot = torch.FloatTensor(*input_sequence.shape, len(self)).zero_()
        onehot.scatter_(2, input_sequence.unsqueeze(2), 1.)
        return onehot

In [249]:
# Source-side (`ru`) and target-side (`be`) alphabets. Each starts with only
# START/END and grows as words are encoded.
ru = Alphabet()
be = Alphabet()

In [250]:
def load_pair_dataset(filename, alph1, alph2, encoding='utf-8'):
    """Read a file of whitespace-separated word pairs and index-encode both sides.

    Lines that do not split into exactly two tokens (blank or malformed lines)
    are skipped.

    Parameters
    ----------
    filename : str
        Path to the pair file.
    alph1, alph2 : Alphabet
        Alphabets whose ``letter2index`` encodes the first / second word of
        each pair (Alphabet.letter2index registers unseen letters).
    encoding : str
        Text encoding of the file. Defaults to UTF-8, so Cyrillic text decodes
        the same way regardless of the platform locale.

    Returns
    -------
    (np.ndarray, np.ndarray)
        Encoded first words and encoded second words, one row per kept pair.
    """
    x, y = [], []
    with open(filename, 'r', encoding=encoding) as ftr:
        for line in ftr:
            parts = line.split()
            if len(parts) != 2:
                continue
            word1, word2 = parts
            x.append(alph1.letter2index(word1))
            y.append(alph2.letter2index(word2))
    return np.array(x), np.array(y)

In [252]:
len(ru)

35

In [253]:
# Encode the training pairs; this also fills the `ru` and `be` alphabets with every letter seen.
X, Y = load_pair_dataset(TRAIN_FILE, ru, be)

In [254]:
onehot = torch.FloatTensor(2, 2, 35)

In [255]:
indx = torch.LongTensor([[1, 4], [4, 1]])

In [256]:
onehot.scatter_

<function FloatTensor.scatter_>

In [257]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the training pairs for validation (fixed random_state for a reproducible split).
train_X, val_X, train_Y, val_Y = train_test_split(X, Y, test_size=0.1, random_state=42)

In [258]:
train_X.dtype

dtype('int64')

In [259]:
def seq2seq_softmax_with_mask(entries, mask):
    """Softmax over the time axis that ignores masked-out positions.

    entries: Tensor NxTx1 of unnormalised attention scores.
    mask: FloatTensor NxT, 1 for valid positions and 0 for padding.

    returns: Tensor NxT whose valid entries in each row sum to 1 and whose
    masked entries are 0. A row with no valid position is all zeros.
    """
    scores = entries[:, :, 0]
    # Take the max over valid positions only. Padding scores may exceed every
    # valid score by a wide margin, and subtracting them would underflow all
    # valid exponentials to 0. Masked slots are pushed far negative so exp()
    # maps them to 0 instead of inf (inf * 0 would give nan).
    masked_scores = scores.masked_fill(mask == 0, -1e30)
    maxs = masked_scores.max(1, keepdim=True)[0]
    weights = torch.exp(masked_scores - maxs) * mask
    return weights / (weights.sum(dim=1, keepdim=True) + 1e-15)


class MultiplicativeAttentionWithMask(nn.Module):
    """Attention that scores encoder states by a dot product of tanh-projected keys."""

    def __init__(self, input_size, output_size):
        """
        input_size: size of the decoder hidden state and of each encoder hidden state.
        output_size: size of the key space in which the dot product is taken.
        """
        super(MultiplicativeAttentionWithMask, self).__init__()
        self.encoder_linear = nn.Linear(input_size, output_size)
        self.decoder_linear = nn.Linear(input_size, output_size)

    def forward(self, decoder_hidden, encoder_hiddens, encoder_mask):
        """
        decoder_hidden: NxH
        encoder_hiddens: NxTxH
        encoder_mask: NxT, 1 for valid encoder positions and 0 for padding

        returns: NxH context vector, the attention-weighted sum of encoder_hiddens.
        """
        decoder_hidden_key = torch.tanh(self.decoder_linear(decoder_hidden))       # N x K
        encoder_hiddens_keys = torch.tanh(self.encoder_linear(encoder_hiddens))    # N x T x K
        scores = torch.bmm(encoder_hiddens_keys, decoder_hidden_key.unsqueeze(2))  # N x T x 1
        weights = seq2seq_softmax_with_mask(scores, encoder_mask)                  # N x T
        return torch.bmm(encoder_hiddens.transpose(1, 2), weights.unsqueeze(2))[:, :, 0]

In [260]:
class SimpleGRUEncoder(nn.Module):
    """Embeds index-encoded source sequences and runs a unidirectional GRU over them.

    forward returns the GRU outputs for every time step (N x T x hidden_size)
    together with the alphabet's mask for the input (N x T).
    """

    def __init__(self, alphabet, embedding_size, hidden_size):
        super(SimpleGRUEncoder, self).__init__()
        self.alphabet = alphabet
        vocab_size = len(alphabet)
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size)
        self.gru = nn.GRU(input_size=embedding_size, hidden_size=hidden_size, batch_first=True)

    def forward(self, input_sequence):
        """input_sequence: LongTensor NxT of letter indices."""
        outputs = self.gru(self.embedding(input_sequence))[0]
        mask = self.alphabet.get_mask(input_sequence)
        return outputs, mask

In [261]:
class SimpleGRUDecoderWithAttention(nn.Module):
    """Single-step GRU decoder conditioned on an attention context over encoder states.

    Each call consumes one target token per batch row plus the previous hidden
    state, and returns unnormalised logits over the target alphabet together
    with the new hidden state.
    """

    def __init__(self, alphabet, embedding_size, hidden_size):
        super(SimpleGRUDecoderWithAttention, self).__init__()
        self.hidden_size = hidden_size
        self.alphabet = alphabet
        vocab_size = len(alphabet)
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size)
        # GRU input = token embedding concatenated with the hidden_size-wide attention context.
        self.gru_cell = nn.GRUCell(input_size=embedding_size + hidden_size, hidden_size=hidden_size)
        self.logit_linear = nn.Linear(hidden_size, vocab_size)
        # Attention keys are projected into an embedding_size-wide space.
        self.attention = MultiplicativeAttentionWithMask(hidden_size, embedding_size)

    def init_hidden(self, batch_size):
        """Return an all-zeros initial hidden state of shape (batch_size, hidden_size)."""
        return Variable(torch.zeros(batch_size, self.hidden_size))

    def forward(self, token, prev_h, encoder_hs, encoder_mask):
        """
        token: LongTensor N of current input tokens.
        prev_h: N x hidden_size previous hidden state.
        encoder_hs: N x T x hidden_size encoder outputs.
        encoder_mask: N x T mask of valid encoder positions.

        returns: (logits N x len(alphabet), new hidden state N x hidden_size)
        """
        context = self.attention(prev_h, encoder_hs, encoder_mask)
        gru_input = torch.cat((self.embedding(token), context), dim=1)
        new_h = self.gru_cell(gru_input, prev_h)
        return self.logit_linear(new_h), new_h

In [321]:
class SimpleGRUSupervisedSeq2Seq(nn.Module):
    """Attention-based GRU encoder-decoder for letter-level word translation.

    ``forward`` runs the decoder with teacher forcing and returns per-step log
    probabilities; ``translate`` decodes greedily from the model's own outputs.
    """

    def __init__(self, src_alphabet, dst_alphabet, embedding_size, hidden_size):
        super(SimpleGRUSupervisedSeq2Seq, self).__init__()
        self.encoder = SimpleGRUEncoder(src_alphabet, embedding_size, hidden_size)
        # Not used by forward/translate. It is kept because saved state_dicts
        # contain h_linear.* keys and would otherwise fail a strict load.
        self.h_linear = nn.Linear(hidden_size, hidden_size)
        self.decoder = SimpleGRUDecoderWithAttention(dst_alphabet, embedding_size, hidden_size)

    def start(self, batch_size):
        """Return a LongTensor of length batch_size filled with the target START index."""
        return Variable(torch.from_numpy(np.repeat(self.decoder.alphabet.start_index, batch_size)))

    def forward(self, input_sequence, output_sequence):
        """Teacher-forced decoding.

        input_sequence: LongTensor N x T_src of source indices.
        output_sequence: LongTensor N x T_dst of target indices (first column START).

        returns: N x (T_dst - 1) x len(dst alphabet) log probabilities; step t
        predicts output_sequence[:, t + 1] from output_sequence[:, t].
        """
        enc_out, enc_mask = self.encoder(input_sequence)
        dec_h = self.decoder.init_hidden(input_sequence.size(0))
        logits = []
        # Feed every target token except the last; each step predicts the next token.
        for x in output_sequence.transpose(0, 1)[:-1]:
            out, dec_h = self.decoder(x, dec_h, enc_out, enc_mask)
            logits.append(out)
        return F.log_softmax(torch.stack(logits, dim=1), dim=-1)

    def translate(self, word, strategy='', max_length=30, with_start_end=True):
        """Greedily decode a single word or a batch of encoded source sequences.

        word: str (one source word) or Variable/LongTensor N x T of source indices.
        strategy: unused; only greedy decoding is implemented.
        max_length: maximum decoded length, START included.
        with_start_end: for str input, whether the returned string keeps the
            START/END markers.

        returns: the translated string for str input. Otherwise a LongTensor
        N x max_length of predicted indices (starting with START, not cut at END).

        raises: TypeError for any other input type.
        """
        if isinstance(word, str):
            as_word = True
            # Use this instance's alphabet, not a global `model`.
            input_sequence = Variable(torch.from_numpy(np.array([self.encoder.alphabet.letter2index(word)])))
        elif isinstance(word, Variable):
            as_word = False
            input_sequence = word
        else:
            # An explicit raise instead of `assert False`, which is stripped under `python -O`.
            raise TypeError("word argument must be str or Variable, got {}".format(type(word).__name__))

        enc_out, enc_mask = self.encoder(input_sequence)
        hidden = self.decoder.init_hidden(input_sequence.size(0))
        tokens = self.start(input_sequence.size(0))
        lst = [tokens]
        for _ in range(max_length - 1):
            out, hidden = self.decoder(tokens, hidden, enc_out, enc_mask)
            tokens = out.max(1)[1]
            lst.append(tokens)
            # For a single word, stop as soon as END is produced.
            if as_word and int(tokens.data[0]) == self.decoder.alphabet.end_index:
                break
        if as_word:
            return self.decoder.alphabet.index2letter(
                [int(x.data[0]) for x in lst],
                with_start_end=with_start_end)
        return torch.stack(lst).transpose(0, 1)

In [322]:
def batch_iterator(X, Y=None, batch_size=32):
    """Yield shuffled mini-batches of rows of X (and the aligned rows of Y).

    The row order is shuffled once per call with numpy's global RNG. The last
    batch may be smaller than batch_size. Yields ``(X_batch, Y_batch)`` when Y
    is given, otherwise just ``X_batch``.
    """
    assert Y is None or X.shape[0] == Y.shape[0]
    n_rows = X.shape[0]
    order = np.arange(n_rows)
    np.random.shuffle(order)
    for start in range(0, n_rows, batch_size):
        picked = order[start:start + batch_size]
        if Y is None:
            yield X[picked]
        else:
            yield X[picked], Y[picked]

In [323]:
val_src_words[25]

'оспорен'

In [324]:
model.translate(val_src_words[21], with_start_end=False)

'рынго'

In [325]:
# Seq2seq model ru -> be with 65-dim letter embeddings and 256-dim GRU hidden states.
model = SimpleGRUSupervisedSeq2Seq(ru, be, 65, 256)
opt = optim.Adam(model.parameters(), lr=1e-3)

In [326]:
len(ru)

35

In [327]:
model.state_dict

<bound method Module.state_dict of SimpleGRUSupervisedSeq2Seq(
  (encoder): SimpleGRUEncoder(
    (embedding): Embedding(35, 65)
    (gru): GRU(65, 256, batch_first=True)
  )
  (h_linear): Linear(in_features=256, out_features=256)
  (decoder): SimpleGRUDecoderWithAttention(
    (embedding): Embedding(65, 65)
    (gru_cell): GRUCell(321, 256)
    (logit_linear): Linear(in_features=256, out_features=65)
    (attention): MultiplicativeAttentionWithMask(
      (encoder_linear): Linear(in_features=256, out_features=65)
      (decoder_linear): Linear(in_features=256, out_features=65)
    )
  )
)>

In [330]:
# `os` and CHECKPOINTS are otherwise only defined in later cells. They are set
# here too so this cell does not raise NameError on a fresh kernel.
import os

CHECKPOINTS = './checkpoints'  # same value as the CHECKPOINTS config cell
model.load_state_dict(torch.load(os.path.join(CHECKPOINTS, 'state_dict_4_0.7777775635820685.pth')))

In [296]:
def cross_entropy(log_predictions, targets, alphabet):
    """ Cross entropy loss for sequences.

    Each sequence's negative log likelihood is averaged over the positions
    selected by ``alphabet.get_mask(targets)``. The per-sequence averages are
    then averaged over the batch.

    Parameters
    ---------
    log_predictions: Tensor NxTxH
        Log probabilities
    targets: LongTensor NxT
        True index-encoded translations
    alphabet: Alphabet
        Alphabet object, used to build the length mask of ``targets``

    Returns
    -------
    Scalar tensor: mean over the batch of the per-token negative log likelihood.
    """
    length_mask = alphabet.get_mask(targets)  # N x T
    # Pick the log probability of the true token directly instead of
    # multiplying by an N x T x H one-hot tensor. This uses less memory and
    # avoids 0 * (-inf) = nan if a non-target log probability is -inf.
    target_log_probs = log_predictions.gather(2, targets.unsqueeze(2)).squeeze(2)  # N x T
    per_sequence = (target_log_probs * length_mask).sum(1) / length_mask.sum(1)
    return -per_sequence.mean()
 

In [297]:
# Validation pairs decoded back to plain strings (no START/END markers);
# these are the inputs/references passed to the metric functions.
val_src_words = [ru.index2letter(x, with_start_end=False) for x in val_X]
val_trg_words = [be.index2letter(y, with_start_end=False) for y in val_Y]

In [667]:
# The first 500 training pairs decoded back to plain strings (no START/END markers).
src_words = [ru.index2letter(x, with_start_end=False) for x in train_X[:500]]
trg_words = [be.index2letter(y, with_start_end=False) for y in train_Y[:500]]

In [None]:
# Exact-match accuracy on the validation words.
# NOTE(review): compute_accuracy is defined in the later metrics/train cell;
# on a fresh kernel that cell must run before this one.
compute_accuracy(model, val_src_words, val_trg_words)

In [298]:
len(ru)

35

In [184]:
# Directory that train() writes its per-epoch state_dict checkpoints into.
CHECKPOINTS = './checkpoints'

! mkdir -p {CHECKPOINTS}

In [669]:
import editdistance as ed
import time
import os
import nltk.translate.bleu_score as bl
from tqdm import tqdm_notebook

def compute_bleu_score(model, src_words, trg_words):
    """Average letter-level sentence BLEU of the model's translations.

    Each translation of ``src_words[i]`` is the hypothesis and ``trg_words[i]``
    the single reference; letters play the role of tokens.
    """
    # _compute_metric_average calls metric(translation, reference), while
    # nltk's sentence_bleu expects (references, hypothesis). The reference must
    # therefore go in the list; swapping them changes the brevity penalty.
    return _compute_metric_average(
        model, src_words, trg_words,
        lambda hyp, ref: bl.sentence_bleu([list(ref)], list(hyp)))

def compute_editdistance(model, src_words, trg_words):
    """Average Levenshtein distance between the model's translations and the references."""
    def distance(hypothesis, reference):
        return ed.eval(hypothesis, reference)

    return _compute_metric_average(model, src_words, trg_words, distance)

def compute_accuracy(model, src_words, trg_words):
    """Fraction of source words whose translation exactly equals the reference."""
    def exact_match(hypothesis, reference):
        return hypothesis == reference

    return _compute_metric_average(model, src_words, trg_words, exact_match)

def _compute_metric_average(model, src_words, trg_words, metric):
    """Translate every source word and return the mean of ``metric(translation, reference)``.

    Translations are produced without START/END markers, and a tqdm notebook
    progress bar tracks the source words.
    """
    scores = []
    for src, ref in zip(tqdm_notebook(src_words), trg_words):
        hypothesis = model.translate(src, with_start_end=False)
        scores.append(metric(hypothesis, ref))
    return np.mean(scores)

def train(model, opt, train_X, train_Y, val_src_words, val_trg_words, checkpoints_folder, metrics_compute_freq=50, n_epochs=7):
    """Teacher-forced training loop with per-epoch validation and checkpointing.

    Parameters
    ----------
    model : SimpleGRUSupervisedSeq2Seq
    opt : torch optimizer over ``model``'s parameters
    train_X, train_Y : np.ndarray
        Index-encoded source / target sequences with aligned rows.
    val_src_words, val_trg_words : list of str
        Validation words (no START/END markers) used for the BLEU score.
    checkpoints_folder : str
        Directory receiving ``state_dict_<epoch>_<val_score>.pth`` after each epoch.
    metrics_compute_freq : int
        Print the smoothed training loss every this many batches.
    n_epochs : int
        Number of passes over the training data.
    """
    # Use the model's own target alphabet rather than a global one.
    target_alphabet = model.decoder.alphabet
    cur_loss = 0
    for epoch in range(n_epochs):
        model.train()
        start_time = time.time()
        for i, (x, y) in enumerate(batch_iterator(train_X, train_Y)):
            inputs = Variable(torch.from_numpy(x))
            targets = Variable(torch.from_numpy(y))
            # Clear gradients before the backward pass, so gradients left over
            # from an interrupted earlier run are never applied.
            opt.zero_grad()
            log_predictions = model(inputs, targets)
            # Step t predicts target token t + 1, so the START column is dropped.
            loss = cross_entropy(log_predictions, targets[:, 1:].contiguous(), target_alphabet)
            loss.backward()
            opt.step()
            # Works whether the loss is a 1-element or a 0-dim tensor.
            loss_value = float(loss.data.view(-1)[0])
            # Exponential moving average of the loss, for logging only.
            cur_loss = 0.9 * cur_loss + 0.1 * loss_value
            if (i + 1) % metrics_compute_freq == 0:
                print("epoch: {} iter: {} loss: {}".format(epoch, i, cur_loss))
        model.eval()
        val_score = compute_bleu_score(model, val_src_words, val_trg_words)
        print("epoch: {} val_score: {} time: {}"
              .format(epoch, val_score, time.time() - start_time))
        torch.save(model.state_dict(), os.path.join(checkpoints_folder, "state_dict_{}_{}.pth".format(epoch, val_score)))
                
# 5 epochs, logging the smoothed loss every 50 batches and checkpointing into
# CHECKPOINTS after each epoch (the recorded run was stopped with KeyboardInterrupt).
train(model, opt, train_X, train_Y, val_src_words, val_trg_words, checkpoints_folder=CHECKPOINTS, metrics_compute_freq=50, n_epochs=5)

KeyboardInterrupt: 

In [681]:
class LSTMDiscriminator(nn.Module):
    """Scores index-encoded sequences with a probability in (0, 1).

    A single-direction LSTM reads the embedded sequence. Its output at each
    row's last position (``alph.get_length`` - 1) goes through a linear layer
    and a sigmoid.
    """

    def __init__(self, alph, embedding_size, hidden_size):
        super(LSTMDiscriminator, self).__init__()
        self.alph = alph
        self.embedding = nn.Embedding(embedding_dim=embedding_size, num_embeddings=len(alph))
        # NOTE: not bidirectional despite the name; renaming would change state_dict keys.
        self.bilstm = nn.LSTM(input_size=embedding_size, hidden_size=hidden_size, batch_first=True)
        self.output = nn.Linear(hidden_size, 1)

    def forward(self, input_sequence):
        """input_sequence: LongTensor NxT. returns: Nx1 tensor of probabilities."""
        states, _ = self.bilstm(self.embedding(input_sequence))
        batch_size = states.size(0)
        last_positions = self.alph.get_length(input_sequence) - 1
        last_states = states[range(batch_size), last_positions].view(batch_size, -1)
        return torch.sigmoid(self.output(last_states))
        

In [687]:
def train_discriminator(disc_model, gen_model, opt, alph_X, train_X, train_Y, n_epochs=50, n_samples=100):
    """Train ``disc_model`` to separate reference translations from ``gen_model`` outputs.

    Real examples are the encoded targets ``train_Y``. Fake examples are the
    greedy batch translations of the matching ``train_X`` rows. The loss is the
    sum of two binary cross-entropies (real -> 1, generated -> 0). The smoothed
    loss is printed at the end of every epoch.

    Parameters
    ----------
    alph_X : Alphabet
        Unused; kept for call compatibility.
    n_samples : int or None
        Only the first ``n_samples`` rows of train_X / train_Y are used
        (default 100, the previously hard-coded value). None uses all rows.
    """
    if n_samples is not None:
        train_X, train_Y = train_X[:n_samples], train_Y[:n_samples]
    cur_loss = 0
    for epoch in range(n_epochs):
        disc_model.train()
        gen_model.eval()
        for i, (x, y) in enumerate(batch_iterator(train_X, train_Y)):
            inputs = Variable(torch.from_numpy(x))
            targets = Variable(torch.from_numpy(y))
            # Clear gradients first, so leftovers from an interrupted run are not applied.
            opt.zero_grad()
            real_data_pred = disc_model(targets)
            gen_data_pred = disc_model(gen_model.translate(inputs))
            loss = F.binary_cross_entropy(gen_data_pred, torch.zeros_like(gen_data_pred)) \
                + F.binary_cross_entropy(real_data_pred, torch.ones_like(real_data_pred))
            loss.backward()
            opt.step()
            # Works whether the loss is a 1-element or a 0-dim tensor.
            loss_value = float(loss.data.view(-1)[0])
            cur_loss = 0.9 * cur_loss + 0.1 * loss_value
            if i % 10 == 9:
                print(loss_value)
        print(cur_loss)

In [688]:
len(ru)

35

In [689]:
# Discriminator over target-side (`be`) sequences: 65-dim embeddings, 128-dim LSTM state.
disc = LSTMDiscriminator(be, 65, 128)
disc_opt = optim.Adam(disc.parameters(), lr=1e-3)

In [690]:
# Train the discriminator against the seq2seq `model` translations (50 epochs by default).
# Each printed value below is the smoothed loss at the end of an epoch.
train_discriminator(disc, model, disc_opt, ru, train_X, train_Y)

0.47828762354850773
0.7891616330778266
0.9927513984005869
1.1246100995517252
1.2097757312396695
1.264631254033487
1.298679733613282
1.315615722915447
1.3327620490002825
1.3283503470036722
1.3357163702143897
1.3015101269327118
1.2952145269594424
1.276770088341116
1.2837131696264008
1.2676731819778375
1.2707987941426044
1.229980041669185
1.211134419065387
1.2125875161289557
1.1641344943987744
1.2033323353890424
1.1905862552767312
1.160494529600224
1.1233999980764031
1.1019565296386356
1.0876009395734516
1.090764420394066
1.0215543484382592
0.9934764123630544
1.0085826131660256
0.9832224847477753
0.9864519969911502
0.9845802245398967
0.9435776073795903
0.9810906967744041
0.9470984337126809
0.9522093755753093
0.8778221385823783
0.8356584701484652
0.8129246926392784
0.8318862313371707
0.8050565543599587
0.8164536657012731
0.8888504459811123
0.9235375377705614
0.943011606540823
0.9092179370736815
0.946454079448224
0.9943259457970866


In [686]:
model.translate("полемики")

'__START__палемікі_END_'

In [None]:
# Score a few reference target sequences with the discriminator (values near 1 = judged real).
# The original call `disc()` passed no input and raised TypeError.
disc(Variable(torch.from_numpy(val_Y[:5])))

In [None]:
val_trg_words[1]

In [None]:
# Compare greedy translations with the references for the first 10 validation words.
for x, y in zip(val_src_words[:10], val_trg_words):
    # val_*_words carry no START/END markers. Translate without them too and
    # compare whole strings; the old [:-5] slices cut real letters off both
    # words and left "__START__" in the translation.
    tr = model.translate(x, with_start_end=False)
    print(tr, y, ed.eval(tr, y))

In [None]:
# There is no free `translate` function in this notebook; call the model's method.
model.translate("")

In [None]:
from tqdm import tqdm_notebook
import editdistance as ed


# Edit distance between greedy translations and references on the first 2000 test lines.
scs = []
with open(TEST_FILE, "r", encoding="utf-8") as ftr:
    lines = ftr.readlines()[:2000]
for line in tqdm_notebook(lines):
    parts = line.split()
    # Skip blank or malformed lines (anything that is not exactly "ru_word be_word").
    if len(parts) != 2:
        continue
    ruw, bew = parts
    # Without START/END markers the output is directly comparable with the
    # reference. The old res[:-5] only stripped "_END_" and kept "__START__",
    # which inflated every distance (and cut letters off when no END was produced).
    res = model.translate(ruw, with_start_end=False)
    scs.append(ed.eval(bew, res))
    print(ruw, bew, res)
    
        

In [None]:
! pip3 install --user editdistance

In [None]:
np.mean(scs)

In [None]:
! head -n 50 data/ru-be.txt