
# Assignment 7

Train a Transformer model for Machine Translation from Russian to English.  
Dataset: http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz   
Convert all source and target text to lower case.  
Use the following tokenization for English:  
```
import sentencepiece as spm

...
spm.SentencePieceTrainer.Train('--input=data/text.en --model_prefix=bpe_en --vocab_size=32000 --character_coverage=0.98 --model_type=bpe')

tok_en = spm.SentencePieceProcessor()
tok_en.load('bpe_en.model')

TGT = data.Field(
    fix_length=50,
    init_token='<s>',
    eos_token='</s>',
    lower=True,
    tokenize = lambda x: tok_en.encode_as_pieces(x),
    batch_first=True,
)

...
TGT.build_vocab(..., min_freq=5)
...

```
Score: corpus BLEU, `nltk.translate.bleu_score.corpus_bleu`.  
Use the last 1000 sentences for model evaluation (test dataset).  
Compute the BLEU score on your target-sequence tokenization.  
Use max_len=50 for sequence prediction.  
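Corpus-level BLEU pools clipped n-gram matches over all test sentences before combining them, rather than averaging per-sentence scores. The following is a minimal pure-Python sketch of that computation, for illustration only: `corpus_bleu_sketch` is a hypothetical helper, it uses uniform weights over 1- to 4-grams and a closest-reference-length brevity penalty, and it does no smoothing (it returns 0.0 if any n-gram order has no matches), so its numbers will not always agree with `corpus_bleu` called with a smoothing function.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Multiset of the n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu_sketch(references, hypotheses, max_n=4):
    """references: one list of reference token lists per sentence; hypotheses: token lists."""
    matches = [0] * max_n   # clipped n-gram matches, pooled over the whole corpus
    totals = [0] * max_n    # hypothesis n-gram counts, pooled over the whole corpus
    hyp_len = ref_len = 0
    for refs, hyp in zip(references, hypotheses):
        hyp_len += len(hyp)
        # closest reference length (the shorter one on ties)
        ref_len += min((abs(len(r) - len(hyp)), len(r)) for r in refs)[1]
        for n in range(1, max_n + 1):
            max_ref = Counter()
            for r in refs:
                max_ref |= ngrams(r, n)   # per-n-gram maximum count over references
            hyp_ngrams = ngrams(hyp, n)
            matches[n - 1] += sum(min(c, max_ref[g]) for g, c in hyp_ngrams.items())
            totals[n - 1] += sum(hyp_ngrams.values())
    if min(matches) == 0:
        return 0.0  # no smoothing in this sketch
    log_p = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    brevity = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return brevity * math.exp(log_p)
```

Passing lists of SentencePiece pieces, as the notebook does, measures BLEU on subword units rather than whole words, which is what the assignment asks for.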


Hint: You may consider a much smaller model than the one shown in the example.  

Baselines:  
[4 points] BLEU = 0.05  
[6 points] BLEU = 0.10  
[9 points] BLEU = 0.15  

[1 point] Share weights between the target embeddings and the output dense layer. Note that they have the same shape.
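Weight tying works because `nn.Embedding(vocab_size, d_model)` stores a `(vocab_size, d_model)` matrix and `nn.Linear(d_model, vocab_size)` also stores its weight as `(vocab_size, d_model)`, so a single `Parameter` can serve both. A minimal PyTorch sketch with made-up sizes (not the notebook's model):

```python
import torch
import torch.nn as nn

vocab_size, d_model = 1000, 64  # illustrative sizes only

embed = nn.Embedding(vocab_size, d_model)             # weight: (vocab_size, d_model)
proj = nn.Linear(d_model, vocab_size, bias=False)     # weight: (vocab_size, d_model)
assert embed.weight.shape == proj.weight.shape

# Tie: both modules now hold the very same Parameter object
proj.weight = embed.weight

tokens = torch.tensor([[1, 2, 3]])
logits = proj(embed(tokens))                          # (1, 3, vocab_size)

# parameters() de-duplicates shared tensors, so the matrix is counted once
n_params = sum(p.numel() for p in nn.ModuleList([embed, proj]).parameters())
```

Gradients from the input embedding and the output projection then accumulate into the same tensor. The tie has to target the parameter the forward pass actually reads: assigning to an attribute the output layer never uses just registers an extra parameter without tying anything.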


Readings:
1. How to calculate the BLEU score in Python: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
1. The Annotated Transformer (code and comments): http://nlp.seas.harvard.edu/2018/04/03/attention.html

In [33]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [34]:
!pip install tqdm



In [35]:
!pip install sentencepiece



In [36]:
!pip install transformers



In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from torchtext import datasets, data
import sentencepiece as spm
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [38]:
# lower-case the English side of the corpus and train a BPE model on it
with open('news-commentary-v13.ru-en.en', encoding='utf-8') as f:
    with open('text.en', 'w', encoding='utf-8') as out:
        out.write(f.read().lower())

spm.SentencePieceTrainer.Train('--input=text.en --model_prefix=bpe_en --vocab_size=32000 --character_coverage=0.98 --model_type=bpe')

True

In [39]:
# lower-case the Russian side of the corpus and train a BPE model on it
with open('news-commentary-v13.ru-en.ru', encoding='utf-8') as f:
    with open('text.ru', 'w', encoding='utf-8') as out:
        out.write(f.read().lower())

spm.SentencePieceTrainer.Train('--input=text.ru --model_prefix=bpe_ru --vocab_size=32000 --character_coverage=0.98 --model_type=bpe')

True
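Both cells above train SentencePiece with `--model_type=bpe`. BPE starts from characters and repeatedly merges the most frequent adjacent pair of symbols into a new symbol. The toy sketch below shows only that core loop: `learn_bpe` is a hypothetical helper, ties are broken by first occurrence, and SentencePiece's real trainer differs in many details (it treats whitespace as the `▁` symbol, applies `--character_coverage`, stops at `--vocab_size`, and so on).

```python
from collections import Counter


def learn_bpe(words, num_merges):
    """Toy BPE: return the learned merges and the final segmentation counts."""
    # each word becomes a tuple of characters, with "▁" marking the word start
    vocab = Counter(tuple("▁" + w) for w in words)
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for word, freq in vocab.items():
            for a, b in zip(word, word[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # most frequent adjacent pair
        merges.append(best)
        new_vocab = Counter()
        for word, freq in vocab.items():
            merged, i = [], 0
            while i < len(word):
                if i + 1 < len(word) and (word[i], word[i + 1]) == best:
                    merged.append(word[i] + word[i + 1])
                    i += 2
                else:
                    merged.append(word[i])
                    i += 1
            new_vocab[tuple(merged)] += freq
        vocab = new_vocab
    return merges, vocab


merges, vocab = learn_bpe(["low", "low", "lower", "lowest"], num_merges=3)
```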

In [0]:
tok_ru = spm.SentencePieceProcessor()
tok_ru.load('bpe_ru.model')

tok_en = spm.SentencePieceProcessor()
tok_en.load('bpe_en.model')

SRC = data.Field(
    fix_length=50,
    init_token='<s>',
    eos_token='</s>',
    lower=True,
    tokenize = lambda x: tok_ru.encode_as_pieces(x),
    batch_first=True,
)

TGT = data.Field(
    fix_length=50,
    init_token='<s>',
    eos_token='</s>',
    lower=True,
    tokenize = lambda x: tok_en.encode_as_pieces(x),
    batch_first=True,
)

fields = (('src', SRC), ('tgt', TGT))
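With `fix_length=50` and both `init_token` and `eos_token` set, every example becomes exactly 50 ids. As far as we understand legacy torchtext's `Field.pad`, the token list is cut to `fix_length - 2` pieces, wrapped in `<s>` and `</s>`, and right-padded with `<pad>`, so long sentences lose their tail but keep `</s>`. A pure-Python sketch of that behavior (`pad_like_field` is a hypothetical name, not a torchtext function):

```python
def pad_like_field(tokens, fix_length=50, init="<s>", eos="</s>", pad="<pad>"):
    """Mimic (as we understand it) legacy torchtext Field.pad with fix_length, init and eos."""
    body = tokens[:fix_length - 2]        # leave room for <s> and </s>
    out = [init] + body + [eos]
    return out + [pad] * (fix_length - len(out))


short = pad_like_field(["▁so", "▁far", ","])
long_ = pad_like_field([str(i) for i in range(100)])
```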

In [41]:
with open('text.ru') as f:
    src_snt = list(map(str.strip, f.readlines()))
    
with open('text.en') as f:
    tgt_snt = list(map(str.strip, f.readlines()))
    
examples = [data.Example.fromlist(x, fields) for x in tqdm(zip(src_snt, tgt_snt))]
test = data.Dataset(examples[-1000:], fields)
train, valid = data.Dataset(examples[:-1000], fields).split(0.9)

235159it [01:03, 3688.63it/s]


In [42]:
print('src: ' + " ".join(train.examples[100].src))
print('tgt: ' + " ".join(train.examples[100].tgt))

src: ▁до ▁сих ▁пор , ▁он ▁не ▁сделал ▁ничего ▁для ▁того , ▁чтобы ▁противостоять ▁право вому ▁ниги ли зму , ▁против ▁которого ▁он ▁выступил ▁в ▁прошлом .
tgt: ▁so ▁far , ▁he ▁has ▁done ▁nothing ▁to ▁counteract ▁the ▁legal ▁nihilism ▁against ▁which ▁he ▁himself ▁has ▁spoken .


In [43]:
len(train), len(valid), len(test)

(210743, 23416, 1000)

In [0]:
TGT.build_vocab(train, min_freq=5)
SRC.build_vocab(train, min_freq=5)

In [0]:
from transformer import make_model, Batch

    
class BucketIteratorWrapper(DataLoader):
    __initialized = False

    def __init__(self, iterator: data.Iterator):
        self.batch_size = iterator.batch_size
        self.num_workers = 1
        self.collate_fn = None
        self.pin_memory = False
        self.drop_last = False
        self.timeout = 0
        self.worker_init_fn = None
        self.sampler = iterator
        self.batch_sampler = iterator
        self.__initialized = True

    def __iter__(self):
        return map(
            lambda batch: Batch(batch.src, batch.tgt, pad=TGT.vocab.stoi['<pad>']),
            self.batch_sampler.__iter__()
        )

    def __len__(self):
        return len(self.batch_sampler)
    
class MyCriterion(nn.Module):
    def __init__(self, pad_idx):
        super(MyCriterion, self).__init__()
        self.pad_idx = pad_idx
        self.criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=pad_idx)
        
    def forward(self, x, target):
        x = x.contiguous().permute(0,2,1)
        ntokens = (target != self.pad_idx).data.sum()
        
        return self.criterion(x, target) / ntokens
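`MyCriterion` sums the cross-entropy over non-pad target positions and divides by their count. For `nn.CrossEntropyLoss` without class weights, that is the same value the default `reduction='mean'` returns when `ignore_index` is set; the small check below illustrates this on random data. It also shows why the `permute(0, 2, 1)` is needed (for sequence targets the loss expects `(batch, vocab, seq_len)`) and that, because `log_softmax` is idempotent, feeding log-probabilities instead of raw logits into `CrossEntropyLoss` gives the same loss. The sizes and tensors are made up.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
pad_idx, vocab_size = 0, 7
logits = torch.randn(2, 5, vocab_size)            # (batch, seq_len, vocab)
target = torch.tensor([[3, 4, 5, 0, 0],
                       [1, 2, 6, 6, 0]])          # 0 plays the role of <pad>

# MyCriterion-style: summed loss over non-pad positions / number of non-pad tokens
scores = logits.permute(0, 2, 1)                  # (batch, vocab, seq_len)
summed = nn.CrossEntropyLoss(reduction="sum", ignore_index=pad_idx)(scores, target)
per_token = summed / (target != pad_idx).sum()

# Built-in equivalent: 'mean' averages over non-ignored targets only
builtin = nn.CrossEntropyLoss(ignore_index=pad_idx)(scores, target)

# log_softmax(log_softmax(x)) == log_softmax(x), so log-probabilities in -> same loss out
log_probs = F.log_softmax(logits, dim=-1)
from_log_probs = nn.CrossEntropyLoss(ignore_index=pad_idx)(log_probs.permute(0, 2, 1), target)
```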

In [0]:
torch.cuda.empty_cache()

batch_size = 64
num_epochs = 5

train_iter, valid_iter, test_iter = data.BucketIterator.splits(
    (train, valid, test),
    batch_sizes=(batch_size, batch_size, batch_size),
    sort_key=lambda x: len(x.src),
    shuffle=True,
    device=DEVICE,
    sort_within_batch=False,
)

train_iter = BucketIteratorWrapper(train_iter)
valid_iter = BucketIteratorWrapper(valid_iter)
test_iter = BucketIteratorWrapper(test_iter)

model = make_model(len(SRC.vocab), len(TGT.vocab), N=3, 
               d_model=256, d_ff=512, h=8, dropout=0.1)
model = model.to(DEVICE)
criterion = MyCriterion(pad_idx=TGT.vocab.stoi["<pad>"])

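# share weights: the target embedding and the output projection are both (vocab_size x d_model).
# NOTE: this only ties them if the generator's forward pass reads `generator.weight`; in the
# Harvard reference Generator the projection lives in `generator.proj`, and the tie would be
# model.generator.proj.weight = model.tgt_embed[0].lut.weight
model.generator.weight = model.tgt_embed[0].lut.weight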

In [0]:
class NoamOpt:
    "Optimizer wrapper that sets the learning rate from the Noam schedule before each step."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Noam schedule: factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
        
def get_std_opt(model):
    return NoamOpt(model.src_embed[0].d_model, 2, 4000,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

In [0]:
optimizer = NoamOpt(model.src_embed[0].d_model, 1, 2000,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
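The schedule in `NoamOpt.rate` is `factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)`: a linear warm-up for `warmup` steps, then inverse-square-root decay. The two terms are equal at `step == warmup`, so the peak rate is `factor / sqrt(d_model * warmup)`. For the optimizer above (`d_model=256`, `factor=1`, `warmup=2000`) that is `1/sqrt(512000)`, about 1.4e-3; `get_std_opt` (factor 2, warmup 4000) is defined but not used. A standalone sketch of the same formula (`noam_rate` is just a local name):

```python
import math


def noam_rate(step, d_model=256, factor=1.0, warmup=2000):
    """Same formula as NoamOpt.rate; step must be >= 1."""
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


for step in (1, 500, 2000, 8000, 32000):
    print(step, f"{noam_rate(step):.2e}")
```

The step count starts at 1 because `NoamOpt.step` increments `_step` before calling `rate()`; calling `rate(0)` directly would raise `ZeroDivisionError`.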

In [57]:
def train_epoch(data_iter, model, criterion):
    total_loss = 0
    counter = 0
    for n, batch in enumerate(data_iter):
        model.zero_grad()
        out = model.forward(batch)
        loss = criterion(out, batch.tgt_y)
        loss.backward()
        optimizer.step()  # NoamOpt: update the learning rate, then step Adam
        total_loss += loss.item()  # .item() so the autograd graph is not kept alive
        counter += 1
        if n % 50 == 1:
            print("Epoch Step: %d Loss: %f" % (n, loss.item()))
    return total_loss / counter


def valid_epoch(data_iter, model, criterion):
    total_loss = 0
    counter = 0
    for n, batch in enumerate(data_iter):
        out = model.forward(batch)
        loss = criterion(out, batch.tgt_y)
        total_loss += loss.item()
        counter += 1
        if n % 50 == 1:
            print("Epoch Step: %d Loss: %f" % (n, loss.item()))
    return total_loss / counter


for epoch in range(num_epochs):
    model.train()
    loss = train_epoch(train_iter, model, criterion)
    print('train', loss)
    
    model.eval()
    with torch.no_grad():
        loss = valid_epoch(valid_iter, model, criterion)
        print('valid', loss)

Epoch Step: 1 Loss: 10.253040
Epoch Step: 51 Loss: 9.861892
Epoch Step: 101 Loss: 9.071190
Epoch Step: 151 Loss: 8.034776
Epoch Step: 201 Loss: 7.123394
Epoch Step: 251 Loss: 6.847806
Epoch Step: 301 Loss: 6.702139
Epoch Step: 351 Loss: 6.525114
Epoch Step: 401 Loss: 6.374889
Epoch Step: 451 Loss: 6.223004
Epoch Step: 501 Loss: 6.056887
Epoch Step: 551 Loss: 6.006391
Epoch Step: 601 Loss: 5.988037
Epoch Step: 651 Loss: 5.748144
Epoch Step: 701 Loss: 5.808645
Epoch Step: 751 Loss: 5.571938
Epoch Step: 801 Loss: 5.489046
Epoch Step: 851 Loss: 5.516775
Epoch Step: 901 Loss: 5.259710
Epoch Step: 951 Loss: 5.413601
Epoch Step: 1001 Loss: 5.395199
Epoch Step: 1051 Loss: 5.386392
Epoch Step: 1101 Loss: 5.185882
Epoch Step: 1151 Loss: 5.174133
Epoch Step: 1201 Loss: 5.202023
Epoch Step: 1251 Loss: 4.970671
Epoch Step: 1301 Loss: 5.024116
Epoch Step: 1351 Loss: 5.101749
Epoch Step: 1401 Loss: 4.936704


KeyboardInterrupt: ignored

In [0]:
start_symbol_id = TGT.vocab.stoi["<s>"]

Beam search implementation (from homework 1)

In [0]:
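from transformer import subsequent_mask  # assumed to live in the same local module as make_model


def beam_search(model, src, src_mask, max_len=20, k=5, offset=0):
    """Keep the k highest-scoring partial translations at each step.

    `offset` skips the `offset` best candidates when re-filling the beam
    (offset=0 is ordinary beam search). Scores are sums of np.log(proba),
    so this assumes `model.decode` returns probabilities, not log-probabilities.
    """
    memory = model.encode(src, src_mask)
    start_token = TGT.vocab.stoi["<s>"]
    end_token = TGT.vocab.stoi["</s>"]
    ys = torch.ones(1, 1).fill_(start_token).type_as(src.data)

    beam = [(ys, 0)]
    for i in range(max_len):
        candidates = []
        candidates_proba = []
        for snt, snt_proba in beam:
            if snt[0][-1] == end_token:
                # finished hypothesis: carry it over unchanged
                candidates.append(snt)
                candidates_proba.append(snt_proba)
            else:
                proba = model.decode(memory, src_mask, snt,
                                     subsequent_mask(snt.size(1)).type_as(src.data))
                proba = proba[0][i]  # distribution over the next token
                best_k = torch.argsort(-proba)[:k].tolist()
                proba = proba.tolist()
                for tok in best_k:
                    candidates.append(torch.cat([snt, torch.ones(1, 1).type_as(src.data).fill_(tok)], dim=1))
                    candidates_proba.append(snt_proba + np.log(proba[tok]))

        # sorted best-first
        best_candidates = np.argsort(-np.array(candidates_proba))[offset:k+offset]
        beam = [(candidates[j], candidates_proba[j]) for j in best_candidates]

    return beam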
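For comparison, here is a minimal, framework-free beam search sketch. It assumes a hypothetical `step_fn(prefix)` that returns a dict of next-token probabilities, sums log-probabilities along each hypothesis, moves hypotheses that emit the end token to a `finished` list (so the live beam can shrink), and returns the best-scoring hypothesis. Unlike the notebook's version it has no `offset` and no length normalization; the toy table only exists to show a case where a beam of 2 beats greedy decoding.

```python
import math


def beam_search_sketch(step_fn, start, end, k=2, max_len=10):
    beams = [([start], 0.0)]                 # (tokens, summed log-probability)
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            for tok, p in step_fn(tokens).items():
                candidates.append((tokens + [tok], score + math.log(p)))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:k]:
            # completed hypotheses leave the beam; the rest are expanded next step
            (finished if tokens[-1] == end else beams).append((tokens, score))
        if not beams:
            break
    finished.extend(beams)                   # keep unfinished ones if max_len was hit
    return max(finished, key=lambda c: c[1])


# Toy "model": next-token distributions keyed by prefix; any other prefix ends the sentence
TABLE = {
    ("<s>",): {"A": 0.6, "B": 0.4},
    ("<s>", "A"): {"X": 0.3, "Y": 0.3, "Z": 0.4},
    ("<s>", "B"): {"X": 0.9, "Y": 0.1},
}


def toy_step(tokens):
    return TABLE.get(tuple(tokens), {"</s>": 1.0})


beam_best = beam_search_sketch(toy_step, "<s>", "</s>", k=2)
greedy_best = beam_search_sketch(toy_step, "<s>", "</s>", k=1)
```

In the notebook's version the scores are `np.log` of whatever `model.decode` returns, so it only behaves like this if that output is a probability distribution; if the model already returns log-probabilities, they should be added directly.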

In [61]:
model.eval()
with torch.no_grad():
    for n, batch in enumerate(valid_iter):
        src = batch.src[:1]
        src_mask = (src != SRC.vocab.stoi["<pad>"]).unsqueeze(-2)
        beam = beam_search(model, src, src_mask, max_len=20, k=5, offset=3)

        # source sentence: skip <s>, stop at </s>
        seq = []
        for i in range(1, src.size(1)):
            sym = SRC.vocab.itos[src[0, i]]
            if sym == "</s>": break
            seq.append(sym)
        print("\nSource:", tok_ru.decode_pieces(seq))

        # every hypothesis left in the beam
        print("Translation:")
        for pred, pred_proba in beam:
            seq = []
            for i in range(1, pred.size(1)):
                sym = TGT.vocab.itos[pred[0, i]]
                if sym == "</s>": break
                seq.append(sym)
            print(tok_en.decode_pieces(seq))

        # reference translation
        seq = []
        for i in range(1, batch.tgt.size(1)):
            sym = TGT.vocab.itos[batch.tgt[0, i]]
            if sym == "</s>": break
            seq.append(sym)
        print("Target:", tok_en.decode_pieces(seq))
        if n == 10:
            break




Source: выбор сербии
Translation:
israel
israel
israel
israel
israel
Target: serbia's choice

Source: мягкая сила оон
Translation:
japan
japan
japan
qaeda
japan
Target: the soft power of the united nations

Source: 3.
Translation:
only “2
only “2
only “2)
only “2”
only “2
Target: 3.

Source: старший брат  ⁇ 
Translation:
“the "
“the "
“the "
“the "
“the "
Target: big brother google?

Source: лига демократических государств?
Translation:
zens
zens
zens
question?
zens
Target: a league of democracies?

Source: на выходе из демократии?
Translation:
for equestion??
for equestion?
for equestion?
for example?
for equestion?
Target: exiting from democracy?

Source: сравнение может показаться чрезмерным.
Translation:
it will be equity
it will be eze.
it will be expectations
it will be equity
it will be equity
Target: the comparison may seem over the top.

Source: выход из положения в ираке
Translation:
questions
questions
questions
questions
questions
Target: the way out of iraq

Source: 
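The cells above repeat the same "skip `<s>`, stop at `</s>`" loop for sources, hypotheses and references. One way to factor it out is a small helper. `ids_to_pieces` and `pieces_to_text` are hypothetical names, and `pieces_to_text` only approximates `SentencePieceProcessor.decode_pieces` (it joins pieces and turns `▁` back into spaces), which is enough for eyeballing output but is not a substitute for the real decoder.

```python
def ids_to_pieces(ids, itos, bos="<s>", eos="</s>", pad="<pad>"):
    """Map vocabulary ids to pieces, skipping <s> and <pad> and stopping at </s>."""
    pieces = []
    for idx in ids:
        sym = itos[int(idx)]
        if sym == eos:
            break
        if sym in (bos, pad):
            continue
        pieces.append(sym)
    return pieces


def pieces_to_text(pieces):
    """Rough stand-in for SentencePiece detokenization: '▁' marks a word start."""
    return "".join(pieces).replace("▁", " ").strip()
```

In the notebook, `ids_to_pieces(pred[0], TGT.vocab.itos)` would give the hypothesis pieces for BLEU, and `tok_en.decode_pieces` the readable text.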

In [0]:
from nltk.translate.bleu_score import corpus_bleu

In [63]:
hypotheses = []
references = []

model.eval()
with torch.no_grad():
    for batch in tqdm(test_iter):
        for src, gold in zip(batch.src, batch.tgt):
            src = src.unsqueeze(0)
            src_mask = (src != SRC.vocab.stoi["<pad>"]).unsqueeze(-2)
            beam = beam_search(model, src, src_mask, max_len=40, k=5)
            # beam_search returns (tokens, log-prob) pairs sorted best-first
            transl, score = beam[0]

            # hypothesis: skip <s>, stop at </s>
            hyp = []
            for i in range(1, transl.size(-1)):
                sym = TGT.vocab.itos[transl[0, i]]
                if sym == "</s>":
                    break
                hyp.append(sym)
            hypotheses.append(hyp)

            # reference: same treatment; corpus_bleu expects a list of references per sentence
            ref = []
            for i in range(1, gold.size(-1)):
                sym = TGT.vocab.itos[gold[i]]
                if sym == "</s>":
                    break
                ref.append(sym)
            references.append([ref])

100%|██████████| 16/16 [43:32<00:00, 163.28s/it]


In [0]:
from nltk import translate

In [65]:
corpus_bleu(references, hypotheses, 
            smoothing_function=translate.bleu_score.SmoothingFunction().method3,
            auto_reweigh=True
           )

0.017830902426596573