## Задание 
1. Реализуйте задачу машинного перевода с использованием transformer.
Датасет: http://www.manythings.org/anki/

In [1]:
from io import open
import unicodedata
import string
import re
import random
import math
import torch
import torch as tr
import torch.nn as nn
import torch.nn.functional as F

In [2]:
dev = tr.device('cuda:0' if tr.cuda.is_available() else 'cpu')
device = dev
print(f"work on {(tr.cuda.get_device_name() if dev.type == 'cuda' else 'cpu')}")

work on GeForce GTX 1070


In [3]:
!wget https://www.manythings.org/anki/rus-eng.zip
!unzip rus-eng.zip

In [4]:
!tail rus.txt

In [6]:
SOS_token = 0
EOS_token = 1


class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS", 2:'MASK'}
        self.n_words = 3  # Count SOS and EOS and MASK

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

In [7]:
# Turn a Unicode string to plain ASCII, thanks to
# http://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Zа-яА-Я.!?]+", r" ", s)
    return s

In [8]:
def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    lines = open('rus.txt', encoding='utf-8').read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')[:-1]] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs

In [9]:
MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)


def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH and \
        p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

In [10]:
def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs


# input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
input_lang, output_lang, pairs = prepareData('eng', 'rus', True)

print(random.choice(pairs))

Reading lines...
Read 429117 sentence pairs
Trimmed to 25315 sentence pairs
Counting words...
Counted words:
rus 9531
eng 4112
['я иду к своеи бабушке .', 'i m going to my grandmother s .']


In [12]:
def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

In [19]:
def inds2text(inds):
    return ' '.join([output_lang.index2word[ind.item()] for ind in inds])

def inds2text2(inds):
    return ' '.join([input_lang.index2word[ind.item()] for ind in inds])


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [20]:
class Trans1(tr.nn.Module):
    def __init__(self, n_classes_in, n_classes_out, n_emb, n_layers = 2, activ = tr.nn.LeakyReLU):
        super().__init__()
        self.emb_srs = tr.nn.Embedding(n_classes_in, n_emb)
        self.emb_tgt = tr.nn.Embedding(n_classes_out, n_emb)
        self.trans = tr.nn.Transformer(n_emb, num_encoder_layers = n_layers, num_decoder_layers = n_layers)
        self.activ = activ()
        self.pos_encoder = PositionalEncoding(n_emb, 0.1)
        self.fc = tr.nn.Linear(n_emb, n_classes_out)
        self.fc2 = tr.nn.Linear(n_classes_out, n_classes_out)

    def forward(self, srs, tgt, srs_mask = None, tgt_mask = None):

        if srs_mask: 
            srs_mask = self.trans.generate_square_subsequent_mask(srs.shape[0]).to(dev)
        if tgt_mask: 
            tgt_mask = self.trans.generate_square_subsequent_mask(tgt.shape[0]).to(dev)
            
        srs = self.emb_srs(srs)
        tgt = self.emb_tgt(tgt)

        srs = self.pos_encoder(srs)
        tgt = self.pos_encoder(tgt)

        x = self.trans(srs, tgt, tgt_mask = tgt_mask)
        x = self.fc(x)

        x = self.activ(x)
        x = self.fc2(x)

        return x


In [24]:
training_pairs = [tensorsFromPair(x) for x in pairs]

batch = tr.zeros(2, len(training_pairs), 10, 1) + 1
for i, x in enumerate(training_pairs):
    batch[0, i, :x[0].shape[0]] = x[0]
    batch[1, i, :x[1].shape[0]] = x[1]

batch_size = 500
train = tr.utils.data.DataLoader(batch.permute(1,0,2,3), batch_size = batch_size, shuffle = True)


In [36]:
n_classes_in = input_lang.n_words
n_classes_out = output_lang.n_words
n_emb = 128
n_layers = 2

model1 = Trans1(n_classes_in, n_classes_out, n_emb, n_layers).to(dev)
optim = tr.optim.Adam(model1.parameters())
crit = tr.nn.CrossEntropyLoss()

In [None]:
n_iters = 1000
n_prints = 100
loss = tr.tensor([10])
loss2 = tr.tensor([0])

model1.train()
for iter in range(n_iters):
    for data in train:    
        input_tensor = data[:,0].squeeze().long().to(dev)
        target_tensor = data[:,1].squeeze().long().to(dev)
        target_tensor2 = target_tensor.clone()
        m = tr.rand(target_tensor2.shape).to(dev)
        mask = m > (loss.item() / 5)**0.5

        target_tensor2[mask] = 2
        optim.zero_grad() 
        pred = model1(input_tensor, target_tensor2,)  #  True, True
        loss = crit(pred.reshape(-1, n_classes_out), target_tensor.reshape(-1))
        mask_sum = mask.sum()
        if mask_sum > 0:
            loss2 = crit(pred[mask].squeeze(), target_tensor[mask]) * 0.2
            loss += loss2
        else:
            loss2 *= 0
        # print(f'iter {iter}, loss = {loss.item(), loss2.item(), mask_sum.item()}')

        loss.backward()
        optim.step()

    if iter % (n_iters / n_prints) == 0:
        n = tr.randint(0, data.shape[0], (1,)).item()
        pred_text = inds2text(pred[n].max(dim=-1).indices)
        tgt_text =  inds2text(target_tensor[n])
        tgt2_text = inds2text(target_tensor2[n])
        srs_text =  inds2text2(input_tensor[n])
        # print(f'iter {iter}, loss = {loss.item()}')
        print(f"""iter {iter}, loss = {loss.item()} - {loss2.item()}, masks {mask_sum.item()},
                loss per word = {(loss.item() - loss2.item()) / tr.tensor(target_tensor.shape).prod() * 10000} 
                loss per mask = {(loss2.to(dev) / mask_sum).item() * 10000}
                \n{srs_text}\n{tgt_text}\n{tgt2_text}\n{pred_text}\n""")

iter 0, loss = 1.8850069046020508 - 0.5785824656486511, masks 1269,
                loss per word = 4.147378921508789 
                loss per mask = 4.559357475955039
                
он никогда еще не был влюблен . EOS EOS EOS
he s never been in love before . EOS EOS
he MASK never been MASK MASK MASK MASK MASK MASK
he m never been . EOS EOS EOS EOS EOS

iter 10, loss = 1.759853720664978 - 0.5636107921600342, masks 1337,
                loss per word = 3.7975964546203613 
                loss per mask = 4.215488443151116
                
его устраивает его текущее положение . EOS EOS EOS EOS
he is content with his present state . EOS EOS
MASK MASK content with MASK present state . MASK EOS
i m content with . present state . EOS EOS

iter 20, loss = 1.6946673393249512 - 0.5563330054283142, masks 1289,
                loss per word = 3.6137595176696777 
                loss per mask = 4.316004633437842
                
ты пьян . EOS EOS EOS EOS EOS EOS EOS
you are drunk ! EOS EOS EOS E

iter 240, loss = 1.637103796005249 - 0.5287327170372009, masks 1320,
                loss per word = 3.5186378955841064 
                loss per mask = 4.0055508725345135
                
я рад что вы вернулись . EOS EOS EOS EOS
i m glad you got back . EOS EOS EOS
i MASK glad MASK MASK back . MASK MASK EOS
i m glad . . back . EOS EOS EOS

iter 250, loss = 1.6466577053070068 - 0.5192140936851501, masks 1368,
                loss per word = 3.579185962677002 
                loss per mask = 3.795424709096551
                
мы оба говорим правду . EOS EOS EOS EOS EOS
we re both telling the truth . EOS EOS EOS
we re MASK MASK MASK MASK . MASK EOS MASK
we re not . . EOS . EOS EOS EOS

iter 260, loss = 1.7395315170288086 - 0.5356320738792419, masks 1416,
                loss per word = 3.8219027519226074 
                loss per mask = 3.7827124469913542
                
жду твоего ответа . EOS EOS EOS EOS EOS EOS
i m waiting for your answer . EOS EOS EOS
MASK m MASK for your answer MASK

iter 470, loss = 1.7152987718582153 - 0.5445715188980103, masks 1354,
                loss per word = 3.7165944576263428 
                loss per mask = 4.02194622438401
                
ты такои упрямыи . EOS EOS EOS EOS EOS EOS
you re so stubborn . EOS EOS EOS EOS EOS
you MASK so stubborn MASK MASK EOS MASK MASK MASK
you m so stubborn . EOS EOS EOS EOS EOS

iter 480, loss = 1.6020193099975586 - 0.5191721320152283, masks 1314,
                loss per word = 3.4376096725463867 
                loss per mask = 3.9510815986432135
                
она не безденежная . EOS EOS EOS EOS EOS EOS
she s not penniless . EOS EOS EOS EOS EOS
she s MASK MASK MASK MASK EOS MASK EOS MASK
she s not . . EOS EOS EOS EOS EOS

iter 490, loss = 1.6931147575378418 - 0.5145402550697327, masks 1443,
                loss per word = 3.741506338119507 
                loss per mask = 3.5657675471156836
                
вы даже не пытаетесь . EOS EOS EOS EOS EOS
you aren t even trying . EOS EOS EOS EOS
you aren

iter 710, loss = 1.665461778640747 - 0.5410787463188171, masks 1309,
                loss per word = 3.569469451904297 
                loss per mask = 4.133527399972081
                
мы не собираемся ждать тома . EOS EOS EOS EOS
we re not going to wait for tom . EOS
we re MASK going MASK MASK for MASK . EOS
we re not going EOS EOS for EOS . EOS

iter 720, loss = 1.731485366821289 - 0.5400165319442749, masks 1390,
                loss per word = 3.782440662384033 
                loss per mask = 3.8850109558552504
                
она в безопасности . EOS EOS EOS EOS EOS EOS
she is out of danger . EOS EOS EOS EOS
MASK MASK MASK of danger . EOS MASK MASK EOS
i m not of danger . EOS EOS EOS EOS

iter 730, loss = 1.7048298120498657 - 0.5457501411437988, masks 1338,
                loss per word = 3.6796178817749023 
                loss per mask = 4.078850033693016
                
ты чувствительныи . EOS EOS EOS EOS EOS EOS EOS
you re sensitive . EOS EOS EOS EOS EOS EOS
you re MASK MA

iter 950, loss = 1.635324239730835 - 0.5179453492164612, masks 1359,
                loss per word = 3.547234535217285 
                loss per mask = 3.8112240144982934
                
я преподаватель тома по французскому . EOS EOS EOS EOS
i m tom s french teacher . EOS EOS EOS
MASK m MASK s MASK teacher . MASK EOS MASK
i m not s . teacher . EOS EOS EOS

iter 960, loss = 1.6549147367477417 - 0.5260317325592041, masks 1352,
                loss per word = 3.5837554931640625 
                loss per mask = 3.890767111442983
                
я вряд ли выиграю . EOS EOS EOS EOS EOS
i m not likely to win . EOS EOS EOS
i m not likely MASK win . MASK MASK MASK
i m not likely . win . EOS EOS EOS

iter 970, loss = 1.7186903953552246 - 0.5518717169761658, masks 1332,
                loss per word = 3.704185962677002 
                loss per mask = 4.14318114053458
                
я сеичас очень зол . EOS EOS EOS EOS EOS
i m very angry now . EOS EOS EOS EOS
MASK MASK MASK MASK MASK . EOS EO