## Загрузка данных

Флаг _-nc_ позволяет не скачивать файлы, если они уже есть. 

In [1]:
!wget -nc https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/tok/train.lc.norm.tok.fr
!wget -nc https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/tok/train.lc.norm.tok.en
!wget -nc https://s3.amazonaws.com/arrival/embeddings/wiki.multi.en.vec
!wget -nc https://s3.amazonaws.com/arrival/embeddings/wiki.multi.fr.vec

File ‘train.lc.norm.tok.fr’ already there; not retrieving.

File ‘train.lc.norm.tok.en’ already there; not retrieving.

File ‘wiki.multi.en.vec’ already there; not retrieving.

File ‘wiki.multi.fr.vec’ already there; not retrieving.



## Определения

После первого запуска имеет смысл установить __LOAD_PICKLED = True__, это позволит загрузить переведённый датасет, а не переводить всё заново.

In [2]:
import pickle
import numpy as np
import unicodedata
import string
import re
import torch
from torch import nn
from torch import optim
import matplotlib.pyplot as plt

LOAD_PICKLED = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print (device)

cuda


### Вспомогательные функции

In [3]:
# Turn a Unicode string to plain ASCII, thanks to
# http://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s.strip()

def read_sentences(path):
    lines = []
    with open(path) as f:
        for line in f:
            lines.append(normalizeString(line))
    return lines

In [4]:
def seq_format(seq, max_words):
    nwords = len(seq)
    seq_new = seq + ["<EOS>"]
    seq_new += ["<PAD>" for i in range(max_words - nwords)]
    return seq_new

def freq_filter(seq, lang, freq):
    if freq == -1:
        return True
    for w in seq:
        if lang.word2count[w] < freq:
            return False
    return True

def len_filter(seq, max_len):
    if max_len == -1:
        return True
    else:
        return len(seq) <= max_len

def prepare_list(list, lang, max_words, freq):
    list_seq = [s.split() for s in list]
    list_clean = []
    for s in list_seq:
        if len_filter(s, max_words) and freq_filter(s, lang, freq):
            list_clean.append(s)
        else:
            list_clean.append(None)
            
    return list_clean

def seq2ind(seq, lang):
    return [lang.word2index[w] for w in seq]

def noise(seq, drop_prob=0.1, shuffle_len=3):
    n = len(seq)
    ind = np.argsort(np.arange(0, n) + np.random.uniform(0, shuffle_len, n))
    drop_mask = np.random.binomial(1, 1-drop_prob, n).astype(np.bool)
    ind = ind[drop_mask]
    res = []
    for i in ind:
        res.append(seq[i])
    return res

### Словарь

Собирает статистику по словам в тексте и умеет выдавать идекс каждого слова.

In [5]:
class Vocabulary:
    """
    Tracks info about known words, their indices and frequences.
    """
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.dummies = self.get_dummies()
        for i, name in enumerate(self.dummies):
            self.word2index[name] = i
            self.word2count[name] = 0
        self.index2word = self.dummies.copy()
        self.n_words = len(self.dummies)
   
    def add_list(self, list):
        for s in list:
            self.add_sentence(s)
    
    def add_sentence(self, sentence):
        for word in sentence.split(' '):
            self.add_word(word)

    def add_seq(self, seq):
        for word in seq:
            self.add_word(word)

    def add_word(self, word):
        if word not in self.word2index:   
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word.append(word)
            self.n_words += 1
        else:
            self.word2count[word] += 1
    
    @staticmethod
    def get_dummies():
        return ["<SOS>", "<EOS>", "<PAD>", "<UNK>"]
    
    @staticmethod
    def get_dummy_ind(w):
        return Vocabulary.get_dummies().index(w)

### Датасет

Хранит данные, отвечает за генерацию перевода, тренировочной и валидационной выборок. 

In [6]:
class Dataset:
    """
    Data storage and preprocessing.
    """
    def __init__(self, lang_info, max_len=-1, min_freq=-1, val_ratio=0.1):
        """
        Arguments: 
        lang_info   -- dictionary with following info:
                      --> name = language name (str)
                      --> corpus_path = path to file with sentences (str)
        max_length  -- maximum sentence length (-1 for no limit)
        min_freq    -- minimum word appearing frequency (-1 for no limit)
        val_ration -- fraction of sentences to use for validation
        """
        self.max_len = max_len
        self.min_freq = min_freq
        
        # Sentences lists
        self.s_list = {}
        
        if len(lang_info) != 2:
            raise ValueError('Only pairs of languages are supported, but {} was passed.'.format(len(lang_info)))
        self.names = []
        for l, path in lang_info.items():
            self.names.append(l)
            self.s_list[l] = read_sentences(path)
            
        # Vocabularies 
        self.v_list = {}
        for l in self.names:
            self.v_list[l] = Vocabulary(l)
            self.v_list[l].add_list(self.s_list[l])
    
        # Filter sentences
        self.seq_list = {} 
        nsents = []
        
        seq_list = []
        seq_names = []
        for l in self.names:
            tmp = prepare_list(self.s_list[l], self.v_list[l], 
                                           max_len, min_freq)
            nsents.append(len(tmp))
            seq_list.append(tmp)
            seq_names.append(l)

        if len(set(nsents)) != 1:
            raise Warning('Numbers of sentences are not equal for the languages.')
        
        # Combaine parallel and non-parallel data
        nopair = {l:[] for l in seq_names}
        pair = {l:[] for l in seq_names}
        nfiltered = [0, 0]
        npairs = 0
        for s in zip(*seq_list):
            if s[0] == None and s[1] == None:
                continue
            if s[1] == None:
                nfiltered[0] += 1
                nopair[seq_names[0]].append(s[0])
            elif s[0] == None:
                nfiltered[1] += 1
                nopair[seq_names[1]].append(s[1])
            else:
                nfiltered[0] += 1
                nfiltered[1] += 1
                npairs += 1
                for i in range(2):
                    pair[seq_names[i]].append(s[i])
       
        # Form test and train sentences
        wanted_len = int(val_ratio*min(nfiltered))
        if wanted_len > npairs:
            raise Warning('Asked for {} test samples, but only {} can be provided.'.format(wanted_len, npairs))
        res_len = min(npairs, wanted_len)
        self.test_list = {}
        self.seq_list = {}
        self.val_size = res_len
        for l in self.names:
            self.seq_list[l] = pair[l][res_len:] + nopair[l]
            self.test_list[l] = pair[l][:res_len]
        # No translated version present
        self.seq_tr_list = {}
        for l in self.names:
            self.seq_tr_list[l] = None
        
    def translate(self, translator):
        """Build translation of stored sentences.
        
            Arguments:
            translator -- an object that has translate_seq(seq, from_lang, to_lang) function,
                          where: seq -- sequence of words
                                 from_lang, to_lang -- strings
        """
        other = dict(zip(self.names, self.names[::-1]))
        for l, seq_list in self.seq_list.items():
            self.seq_tr_list[l] = [translator.translate_seq(s, l, other[l]) for s in seq_list]
            for seq in self.seq_tr_list[l]:
                self.v_list[other[l]].add_seq(seq)
            
    def get_train(self, batch_size=1):
        """Get train data.
          
          Returns:
           X_auto{'en', 'fr'}   --  indexed noisy src sentences
           Y_auto{'en', 'fr'}   --  indexed clean src sentences
           X_cross{'en', 'fr'}   --  indexed translated noisy src sentences
           Y_cross{'en', 'fr'}   --  indexed clean src sentences
        """
        X_auto = {}
        Y_auto = {}
        
        X_cross = {}
        Y_cross = {}
        other = dict(zip(self.names, self.names[::-1]))
        for l, lang in self.v_list.items():
            # Autoencoders train
            batch_ind = np.random.choice(range(len(self.seq_list[l])), batch_size, replace=False)
            seq_list_tmp = [self.seq_list[l][i] for i in batch_ind]
            
            X_auto_tmp = list(map(noise, seq_list_tmp))
            Y_auto_tmp = seq_list_tmp
            
            # Cross-domain train
            batch_ind = np.random.choice(range(len(self.seq_tr_list[l])), batch_size, replace=False)
            seq_list_tmp = [self.seq_list[l][i] for i in batch_ind]
            seq_tr_list_tmp = [self.seq_tr_list[l][i] for i in batch_ind]
            
            X_cross_tmp = list(map(noise, seq_tr_list_tmp))
            Y_cross_tmp = seq_list_tmp
            
            vocabs = 3*[self.v_list[l]] + [self.v_list[other[l]]]
            seq_lists = [X_auto_tmp, Y_auto_tmp, Y_cross_tmp, X_cross_tmp]
            ind_lists = []
            for lang, seq_list in zip(vocabs, seq_lists):
                max_len = max(list(map(len, seq_list)))
                formatted = list(map(lambda x: seq_format(x, max_len), seq_list))
                inds = torch.tensor(list(map(lambda x: seq2ind(x, lang), formatted))) 
                ind_lists.append(inds)
                
            X_auto[l], Y_auto[l], Y_cross[l], X_cross[l] = ind_lists
                
        return X_auto, Y_auto, X_cross, Y_cross
    
    def get_test(self, nsamples=-1):
        """Get test data.
        
        Returns:
        X{'fr', 'en'} -- pairs of translated sentences 
        """
        X = {}
        if nsamples==-1:
            nsamples = self.val_size
        inds = np.random.choice(range(self.val_size), nsamples, replace=False)
        for l, lang in self.v_list.items():
            test_list_tmp = [self.test_list[l][i] for i in inds]
            max_len = max(list(map(len, test_list_tmp)))
            formatted = list(map(lambda x: seq_format(x, max_len), test_list_tmp))
            X[l] = torch.tensor(list(map(lambda x: seq2ind(x, lang), formatted))) 
        return X
    
    def ind2sent(self, ind_seq, lang):
        """Translate word indices to words.
        
            Arguments:
            ind_seq  -- sequence of indices
            lang     -- corresponding language ('en', 'fr', ...)
        """
        words = list(map(lambda x: self.v_list[lang].index2word[x], ind_seq))
        try:
            end = words.index('<EOS>')
        except ValueError:
            end = len(ind_seq)
        return ' '.join(words[:end])

### Переводчик

Нулевое приближение перевода, используя натренированные представления слов. Близким словам соответствуют близкие векторы, переводим вектор в вектор.

In [7]:
def load_vec(emb_path, max_words=-1):
    vectors = []
    word2id = {}
    it = 0
    with open(emb_path) as f:
        nvec, ndim = [int(k) for k in f.readline().split()]
        for line in f:
            if max_words != -1 and it > max_words:
                break
            it += 1
            orig_word, vect = line.rstrip().split(' ', 1)
            
            word = normalizeString(orig_word)
            vect = np.fromstring(vect, sep=' ')
       
            # Words are sorted by frequency, no need to add less 
            # frequent version of the same word  
            if not (word in word2id):
                vectors.append(vect)
                word2id[word] = len(word2id)

    id2word = {v: k for k, v in word2id.items()}
    embeddings = np.vstack(vectors)
    return embeddings, id2word, word2id

class NaiveTranslator:
    """Naive word-by-word translation with caching.
    """
    def __init__(self, lang_info, max_words=-1):
        """
        Arguments: 
        lang_info  -- dictionary with following info:
                      --> name = language name (str)
                      --> emb_path = path to file with embeddings (str)
        max_words -- maximum number of embeddings to load (sorted by frequency)
        """   
        self.emb = {}
        self.id2word = {}
        self.word2id = {}
        self.names = []
        for l, path in lang_info.items():
            self.names.append(l)
            self.emb[l], self.id2word[l], self.word2id[l] = load_vec(path, 
                                                                     max_words)
            
        self.cache = {l: {} for l in self.names}
        # Add dummies
        for l in self.names:
            for w in ["<SOS>", "<EOS>", "<PAD>", "<UNK>"]:
                self.cache[l][w] = w
        
    def translate(self, word, from_lang, to_lang):
        if word in self.cache[from_lang]:
            return self.cache[from_lang][word]
        else:
            # Handle unknown
            if word in self.word2id[from_lang]:
                id = self.word2id[from_lang][word]
            else:
                self.cache[from_lang][word] = "<UNK>"
                return "<UNK>"
            
            vec = self.emb[from_lang][id]
            dist = np.dot(self.emb[to_lang], vec)
            ind = np.asscalar(np.argmax(dist, axis=0))
            tr = self.id2word[to_lang][ind]
            self.cache[from_lang][word] = tr    
            return tr
    
    def translate_sent(self, sent, from_lang, to_lang):
        new_sent = ' '.join([self.translate(w, from_lang, to_lang) for w in sent.split()])
        return new_sent
    
    def translate_seq(self, seq, from_lang, to_lang):
        return [self.translate(w, from_lang, to_lang) for w in seq]

### Энкодер

Обычный GRU. Параметры:
* __embeddings__  -- оптимизируемые представления слов
* __hidden_size__ -- размерность векторов в скрытом пространстве (предложений), куда отображает энкодер. Совпадает с размерностью вектора состояния RNN (энкодера)  

In [8]:
class Encoder(torch.nn.Module):
    def __init__(self, embeddings, hidden_size):
        super().__init__()
        self.emb = embeddings
        self.hidden_size = hidden_size
        for emb in embeddings.values():
            self.input_size = emb.embedding_dim
            
        self.gru = nn.GRU(self.input_size, self.hidden_size, batch_first=True)

    def step(self, input, hidden, from_lang):
#         print('from', from_lang)
#         print('input', input)
#         print('hidden', hidden)

        embedded = self.emb[from_lang](input)
        output, hidden = self.gru(embedded, hidden)

        return output, hidden
    
    def forward(self, ind_batch, nsteps, from_lang):
#         print('>>Encoder:step')
        encoder_outputs = torch.zeros((ind_batch.shape[0], nsteps, hidden_size), 
                                      device=ind_batch.device)
        encoder_hidden = torch.zeros((1, ind_batch.shape[0], hidden_size), 
                                     device=ind_batch.device)
#         print('encoder_outputs', encoder_outputs.shape)
        for i in range(nsteps):
#             print('step:', i)
            encoder_output, encoder_hidden = self.step(ind_batch[:, [i]], encoder_hidden, from_lang)
#             print('encoder_output', encoder_output.shape)
            encoder_outputs[:, i, :] += encoder_output.squeeze()
#         print('<<Encoder:step')
        
        return encoder_outputs, encoder_hidden

### Сеть внимания

По вектору скрытого пространства и вектору состояния декодера говорит, какое внимание должно быть уделено первому.

* __input_size__ -- размер вектора в скрытом пространстве предложений
* __state_size__ -- размер вектора скрытого состояния
* __inner_size__ -- размер внутреннего слоя

#### Линейный вариант

In [9]:
class AttnLinear(torch.nn.Module):
    def __init__(self, input_size, state_size, inner_size = 10):
        super().__init__()
        self.W = nn.Linear(input_size + state_size, inner_size)
        self.v = nn.Linear(inner_size, 1)
        
    def forward(self, input, hidden):
#         print('>>AttnLinear')
#         print('input', input.shape)
#         print('hidden', hidden.shape)
#         print('<<AttnLinear')
        expanded = hidden.expand(-1, input.shape[1], -1)
        return torch.relu(self.v(self.W(torch.cat((input, expanded), dim=2))))

#### Вариант в виде сети

In [10]:
class AttnNet(torch.nn.Module):
    def __init__(self, input_size, state_size, inner_size=10):
        super().__init__()
        self.v = nn.Linear(inner_size, 1, bias=False)
        self.W = nn.Linear(input_size, inner_size, bias=False)
        self.U = nn.Linear(state_size, inner_size, bias=False)
        
    def forward(self, input, hidden):
#         print('>>AttnNet')
#         print('input', input.shape)
#         print('hidden', hidden.shape)
#         print('<<AttnNet')
        return torch.relu(self.v(self.W(input) + self.U(hidden)))

### Декодер

Умеет по представлению предложения в скрытом пространстве (последовательность векторов размерности __hidden_size__) получать последовательность слов для одного из языков.

* __embeddings__  -- оптимизируемые представления слов
* __hidden_size__ -- размер вектора в скрытом пространстве (предложений)
* __state_size__  -- размер вектора состояния RNN (декодера)
* __inner_size__  -- размер внутреннего слоя сети внимания 

In [11]:
class AttnDecoder(torch.nn.Module):
    def __init__(self, embeddings, hidden_size, state_size, attn_size):
        super().__init__()
        self.emb = embeddings
        for emb in embeddings.values():
            self.emb_size = emb.embedding_dim
        self.hidden_size = hidden_size
        self.attn_size = attn_size
        self.state_size = state_size

        self.attn = AttnLinear(hidden_size, state_size, attn_size)
        self.gru = nn.GRU(hidden_size + self.emb_size, state_size, batch_first=True)
        self.out = nn.ModuleDict()
        for l, emb in embeddings.items():
            self.out[l] = nn.Linear(state_size, emb.num_embeddings)
        
    def step(self, ind, hidden, encoder_outputs, to_lang):
#         print('>>AttnDecoder')
#         print('hidden', hidden.shape)
        input = self.emb[to_lang](ind)
        attn_weights = torch.softmax(
            self.attn(encoder_outputs, hidden.transpose(0, 1)), dim=1)
#         print('attn_weights', attn_weights.shape)
        attn_applied = torch.bmm(attn_weights.transpose(1, 2), encoder_outputs)
#         print('encoder_outputs', encoder_outputs.shape)
#         print('attn_applied', attn_applied.shape)
        gru_input = torch.cat((input, attn_applied), dim=2)
#         print('gru_input', gru_input.shape)
        output, hidden = self.gru(torch.cat((input, attn_applied), dim=2), hidden)
        output = self.out[to_lang](attn_applied)
        output = torch.log_softmax(output, dim=2)
        
        return output, hidden, attn_weights

    def forward(self, encoder_outputs, nsteps, to_lang):
#         print('>>AttnDecoder')
        decoder_outputs = torch.zeros((encoder_outputs.shape[0], nsteps,
                                       self.emb[to_lang].num_embeddings), 
                                       device=encoder_outputs.device)
#         print('decoder_outputs', decoder_outputs.shape)
        # Load encoded hidden state
        decoder_hidden = encoder_outputs[:, -1, :].unsqueeze(0).contiguous()
        # Get SOS (start of sentence)
        input = torch.full((encoder_outputs.shape[0], 1),
                           Vocabulary.get_dummy_ind('<SOS>'),
                           dtype=torch.long)
        input = input.to(encoder_outputs.device)
        for i in range(nsteps):
#             print('step', i)
            decoder_output, decoder_hidden, attn_weights = self.step(input, 
                                                       decoder_hidden, 
                                                       encoder_outputs,
                                                       to_lang)
#             print('decoder_output', decoder_output.shape)
            decoder_outputs[:, [i], :] += decoder_output
            _, input = decoder_output.topk(1, dim=2)
#             print('input', input.shape)
            input = input.view(encoder_outputs.shape[0], 1)
        
#         print('<<AttnDecoder')
        return decoder_outputs, decoder_hidden

#### Дискриминатор

Умеет по представлению предложения в скрытом пространстве (последовательность векторов размерности __hidden_size__) говорить, какому из двух языков (__0__ или __1__) она принадлежит.

* __hidden_size__ -- размер вектора в скрытом пространстве (предложений)
* __hidden_len__  -- максимальная длина последовательности векторов в скрытом пространстве; она же максимальная длина входный предложений
* __hidden_layer_size__  -- размер скрытого слоя

In [12]:
class Discriminator(nn.Module):
    def __init__(self, hidden_size, hidden_len, hidden_layer_size):
        super().__init__()
        self.hidden_len = hidden_len
        self.hidden_size = hidden_size
        self.hidden_layer_size = hidden_layer_size
        self.mesh = nn.Linear(hidden_len, 1)
        self.hid = nn.Linear(hidden_size, hidden_layer_size)
        self.out = nn.Linear(hidden_layer_size, 1)
        
    def forward(self, input):
#         print('>>Discriminator')
        device = input.device
        pad = self.hidden_len - input.shape[1]
#         print('pad', pad)
        padding = torch.zeros((input.shape[0], pad, input.shape[2]), device=device) 
#         print('input', input.shape)
        input = torch.cat((input, padding), dim=1)
        input = torch.relu(self.mesh(input.transpose(1, 2)))
#         print('<<Discriminator')
        return torch.sigmoid(self.out(torch.relu(self.hid(input.transpose(1, 2)))))

#### Обёртка

Управляет всеми частями модели. Параметры соответствуют описанным выше.

In [13]:
class Wrapper(nn.Module):
    def __init__(self, hidden_size, decoder_size, attn_size, discr_size, max_in_len, dataset):
        super().__init__()
        self.emb = nn.ModuleDict()
        self.names = []
        for name, vocab in dataset.v_list.items():
            self.emb[name] = nn.Embedding(len(vocab.word2index), hidden_size)
            self.names.append(name)
        self.hidden_size = hidden_size
        self.decoder_size = decoder_size
        self.attn_size = attn_size
        self.discr_size = discr_size
        self.max_in_len = max_in_len
        
        self.enc = Encoder(self.emb, hidden_size)
        self.dec = AttnDecoder(self.emb, hidden_size, decoder_size, attn_size)
        self.discr = Discriminator(hidden_size, max_in_len, discr_size)
        
    def encode(self, ind_batch, from_lang):
        return self.enc(ind_batch, ind_batch.shape[1], from_lang)
    
    def decode(self, encoder_outputs, output_len, to_lang):
        return self.dec(encoder_outputs, output_len, to_lang)
    
    def encode_decode(self, ind_batch, from_lang, to_lang, out_len=None):
        if out_len == None:
            out_len = ind_batch.shape[1]
        
        encoder_outputs, encoder_hidden = self.encode(ind_batch, from_lang)
        decoder_outputs, decoder_hidden = self.decode(encoder_outputs, out_len, to_lang)
        
        return encoder_outputs, decoder_outputs
    
    def discriminate(self, encoder_outputs):
        return self.discr(encoder_outputs)

## Загрузка датасета

Если был указан параметр __LOAD_PICKLED = True__, то загружается из файла. Иначе создаётся заново: процесс не очень быстрый.

In [14]:
if LOAD_PICKLED:
    with open('NaiveTranslator', 'rb') as f:
        tr = pickle.load(f)
else:
    lang_info = {'fr': 'wiki.multi.fr.vec',
                 'en': 'wiki.multi.en.vec'}

    tr = NaiveTranslator(lang_info)
    with open('NaiveTranslator', 'wb') as f:
        pickle.dump(tr, f)

In [15]:
max_in_len = 10
min_in_freq = 5

if LOAD_PICKLED:
    with open('Dataset', 'rb') as f:
        D = pickle.load(f)
else:
    lang_info = {'fr': 'train.lc.norm.tok.fr',
                 'en': 'train.lc.norm.tok.en'}

    D = Dataset(lang_info, max_in_len, min_in_freq)
    D.translate(tr)
    with open('Dataset', 'wb') as f:
        pickle.dump(D, f)

Пример генерации обучающей/тестовой выборки.

In [16]:
train = D.get_train(50)
l = 'fr'
other = 'en'
print('X_auto:', D.ind2sent(train[0][l][0], l))
print('Y_auto:', D.ind2sent(train[1][l][0], l))
print('X_cross:', D.ind2sent(train[2][l][0], other))
print('Y_cross:', D.ind2sent(train[3][l][0], l))

X_auto: une femme tenant un jaune la sous pluie .
Y_auto: une femme tenant un parapluie jaune sous la pluie .
X_cross: dogs leaping four by above a . obstacle
Y_cross: quatre chiens sautant par dessus un obstacle .


## Тренировка модели

In [29]:
# Load from saved iteration
# -1 for new model
LOAD_ITER = -1

batch_size = 50
niters = 1000
# Number of iterations between validations
val_per = 100
# Number of iterations between printing current iteration
it_per = 10
# Number of iterations between model saves
save_per = 100

# Ability to translate
tr_crit = nn.NLLLoss()
# Ability to fool the discriminator
tr_fake_crit = nn.BCELoss()
# Ability to predict language correctly
discr_crit = nn.BCELoss()

hidden_size = 10
decoder_size = 10
attn_size = 5
discr_size= 300 
# Add one position for <EOS> symbol
max_len = max_in_len + 1

wr = Wrapper(hidden_size, decoder_size, attn_size, discr_size, max_len, D).to(device)

if LOAD_ITER != -1:
    checkpoint = torch.load(f'checkpoint.{LOAD_ITER}')
    wr.load_state_dict(checkpoint['state_dict'])

# Discriminator optimizer
d_opt = torch.optim.Adam(wr.discr.parameters())
# Encoder and decoder optimizer
tr_opt = torch.optim.Adam(list(wr.enc.parameters()) + list(wr.dec.parameters()))

# Assign classes to languages
class_num = {name:cl for cl, name in enumerate(D.names)}
# Dict with lang pairs
other = dict(zip(D.names, D.names[::-1]))
# Get index of PAD 
pad_ind = Vocabulary.get_dummy_ind('<PAD>')

for it in range(niters):
    # Refresh gradients
    tr_opt.zero_grad()
    d_opt.zero_grad()
    
    # Get training batch
    X_auto, Y_auto, X_cross, Y_cross = D.get_train(batch_size)
    
    # Losses to be accumulated
    d_loss = 0
    tr_loss = 0
    # Train autoencoders
    for l in D.names:
        # Noisy sentences 
        X_auto[l] = X_auto[l].to(device)
        # Clean sentences
        Y_auto[l] = Y_auto[l].to(device)
        # Noisy translation to other language 
        X_cross[l] = X_cross[l].to(device)
        # Source sentences 
        Y_cross[l] = Y_cross[l].to(device)

        ## AUTOENCODER PHASE
        
        encoder_outputs, decoder_outputs =\
                           wr.encode_decode(X_auto[l], 
                                            l,
                                            l,
                                            Y_auto[l].shape[1])
        
        # Dont penalize padding
        if torch.any(Y_auto[l] == pad_ind):
            decoder_outputs[Y_auto[l] == pad_ind][:, pad_ind] = 0
        tr_loss += tr_crit(decoder_outputs.transpose(1, 2), Y_auto[l])
        # We want to predict wrong class labels (fool the discriminator)
        predicted = wr.discriminate(encoder_outputs)
        wanted = torch.full_like(predicted, class_num[other[l]], device=device)
        tr_loss += tr_fake_crit(predicted, wanted)
        
        # And predict correct classes by discriminator
        # .detach() allows us to ignore subgraph, 
        # connected with encoder+decoder
        correct = torch.full_like(predicted, class_num[l], device=device)
        predicted_det = wr.discriminate(encoder_outputs.detach())
        d_loss += discr_crit(predicted_det, correct)
        
        ## CROSS-DOMAIN PHASE
        ## Repeats previous phase with the only difference of source
        ## language change
        
        encoder_outputs, decoder_outputs =\
                       wr.encode_decode(X_cross[l], 
                                        other[l], # language changed here
                                        l,
                                        Y_cross[l].shape[1])
        if torch.any(Y_cross[l] == pad_ind):
            decoder_outputs[Y_cross[l] == pad_ind][:, pad_ind] = 0
        tr_loss += tr_crit(decoder_outputs.transpose(1, 2), Y_cross[l])
        
        # Language is changed here too
        predicted = wr.discriminate(encoder_outputs)
        wanted = torch.full_like(predicted, class_num[l], device=device)
        tr_loss += tr_fake_crit(predicted, wanted)
        
        correct = torch.full_like(predicted, class_num[other[l]], device=device)
        predicted_det = wr.discriminate(encoder_outputs.detach())
        d_loss += discr_crit(predicted_det, correct)
        
    ## BACKPROPAGATION PHASE
    
    # We don't want Discriminator weights to be affected
    for p in wr.discr.parameters():
        p.requires_grad = False
    # Computed encoder and decoder gradients
    tr_loss.backward()
    # Undo the changes
    for p in wr.discr.parameters():
            p.requires_grad = True
    
    # Now calculate the Discriminator loss
    d_loss.backward()
    
    ## OPTIMISATION STEP
    d_opt.step()
    tr_opt.step()

    if not it % it_per:
        print('Iterations:', it)
        
    if not it % save_per:
        path = f'checkpoint.{it}'
        torch.save({'state_dict': wr.state_dict()}, path)
        print('Saved model to:', path)
        
    if not it % val_per: 
        print('Last loss:')
        print('  d_loss =', d_loss.data)
        print('  tr_loss =', tr_loss.data)
        X = D.get_test(1)
        print(f'\n[{it}] Validation:')
        for l in D.names:
            X[l] = X[l].to(device)
            print('\n', l, ' --> ', other[l])
            print('<< ', D.ind2sent(X[l][0].cpu(), l))
            print('== ', D.ind2sent(X[other[l]][0].cpu(), other[l]))
            encoder_outputs, decoder_outputs =\
                       wr.encode_decode(X[l], 
                                        l,
                                        other[l],
                                        X[other[l]].shape[1])
            _, ind = decoder_outputs[0].cpu().topk(1, dim=1)
            print('>> ', D.ind2sent(ind, other[l]))

Iterations: 0
Saved model to: checkpoint.0
Last loss:
  d_loss = tensor(2.7549, device='cuda:0')
  tr_loss = tensor(40.1062, device='cuda:0')
[0] Validation:

 fr  -->  en
<<  une statue argentee d apos hommes sur des velos .
==  a silver statue of men on bikes .
>>  chelsea chelsea chelsea chelsea chelsea chelsea chelsea chelsea chelsea

 en  -->  fr
<<  a silver statue of men on bikes .
==  une statue argentee d apos hommes sur des velos .
>>  danseur danseur danseur danseur danseur danseur danseur danseur danseur danseur danseur
Iterations: 10
Iterations: 20
Iterations: 30
Iterations: 40
Iterations: 50
Iterations: 60
Iterations: 70
Iterations: 80
Iterations: 90
Iterations: 100
Saved model to: checkpoint.100
Last loss:
  d_loss = tensor(2.9585, device='cuda:0')
  tr_loss = tensor(34.4259, device='cuda:0')
[100] Validation:

 fr  -->  en
<<  le chien chasse la chevre dans la cour .
==  the dog is chasing the goat around the yard .
>>  a a a a a a a a a a a

 en  -->  fr
<<  the dog is

KeyboardInterrupt: 

Начало обучения не предвещает ничего хорошего:
 
```
Iterations: 0
Saved model to: checkpoint.0
Last loss:
  d_loss = tensor(2.7549, device='cuda:0')
  tr_loss = tensor(40.1062, device='cuda:0')
[0] Validation:

 fr  -->  en
<<  une statue argentee d apos hommes sur des velos .
==  a silver statue of men on bikes .
>>  chelsea chelsea chelsea chelsea chelsea chelsea chelsea chelsea chelsea

 en  -->  fr
<<  a silver statue of men on bikes .
==  une statue argentee d apos hommes sur des velos .
>>  danseur danseur danseur danseur danseur danseur danseur danseur danseur danseur danseur
Iterations: 10
Iterations: 20
Iterations: 30
Iterations: 40
Iterations: 50
Iterations: 60
Iterations: 70
Iterations: 80
Iterations: 90
Iterations: 100
Saved model to: checkpoint.100
Last loss:
  d_loss = tensor(2.9585, device='cuda:0')
  tr_loss = tensor(34.4259, device='cuda:0')
[100] Validation:

 fr  -->  en
<<  le chien chasse la chevre dans la cour .
==  the dog is chasing the goat around the yard .
>>  a a a a a a a a a a a

 en  -->  fr
<<  the dog is chasing the goat around the yard .
==  le chien chasse la chevre dans la cour .
>>  plage plage plage plage plage plage plage plage plage plage
Iterations: 110
Iterations: 120
Iterations: 130
Iterations: 140
Iterations: 150
Iterations: 160
Iterations: 170
Iterations: 180
Iterations: 190
Iterations: 200
Saved model to: checkpoint.200
Last loss:
  d_loss = tensor(2.7812, device='cuda:0')
  tr_loss = tensor(26.7634, device='cuda:0')
[200] Validation:

 fr  -->  en
<<  un jeune homme tenant une enorme tronconneuse .
==  a young man holding a huge chainsaw .
>>  <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>

 en  -->  fr
<<  a young man holding a huge chainsaw .
==  un jeune homme tenant une enorme tronconneuse .
>>  sur sur sur sur sur sur sur sur sur
Iterations: 210
Iterations: 220
Iterations: 230
Iterations: 240
Iterations: 250
Iterations: 260
Iterations: 270
Iterations: 280
Iterations: 290
Iterations: 300
Saved model to: checkpoint.300
Last loss:
  d_loss = tensor(2.7569, device='cuda:0')
  tr_loss = tensor(24.1434, device='cuda:0')
[300] Validation:

 fr  -->  en
<<  une foule se rassemble pour ecouter un orchestre
==  crowd gathers to listen to a band
>>  <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>

 en  -->  fr
<<  crowd gathers to listen to a band
==  une foule se rassemble pour ecouter un orchestre
>>  une une une une une une une une une
Iterations: 310
Iterations: 320
```