## Lab assignment 02

### Neural Machine Translation in the wild
In the third homework you are supposed to get the best translation you can for the EN-RU translation task.

A basic approach using RNNs as the encoder and decoder is implemented for you.

Your ultimate task is to use the techniques we've covered, e.g.

* Optimization enhancements (e.g. learning rate decay)

* Transformer/CNN/<whatever you select> encoder (with or without positional encoding)

* Attention/self-attention mechanism

* Pretraining the language models (for decoder and encoder)

* Or just fine-tuning BART/ELECTRA/... ;)

to improve the translation quality. 

__Please use at least three different approaches/models and compare them (translation quality/complexity/training and evaluation time).__

Write a short summary of your experiments and illustrate it with convergence plots, metrics, and your own observations, just as you would when approaching a real problem.
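One possible way to keep the comparison organised is to record each experiment's metrics and wall-clock time in a plain dictionary and print them side by side. The sketch below is only an illustration: `timed`, `format_table`, and the record keys are hypothetical names, and the `sum(range(1000))` call stands in for a real training run.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(record, key):
    """Store the elapsed wall-clock seconds of the with-block in record[key]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        record[key] = time.perf_counter() - start


def format_table(rows, columns):
    """Render a list of dicts as a simple pipe-separated text table."""
    header = " | ".join(columns)
    lines = [header, "-" * len(header)]
    for row in rows:
        lines.append(" | ".join(str(row.get(col, "")) for col in columns))
    return "\n".join(lines)


results = []  # one dict per experiment

record = {"name": "demo"}
with timed(record, "train_s"):
    sum(range(1000))  # placeholder for an actual training run
results.append(record)

print(format_table(results, ["name", "train_s"]))
```

Further columns (BLEU, parameter count, evaluation time) can be added to each record in the same way.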

In [1]:
# Thanks to YSDA NLP course team for the data
# (who thanks tilda and deephack teams for the data in their turn)

import os
path_to_data = 'data.txt'
if not os.path.exists(path_to_data):
    print("Dataset not found locally. Downloading from github.")
    !wget https://raw.githubusercontent.com/neychev/made_nlp_course/master/datasets/Machine_translation_EN_RU/data.txt -nc

In [2]:
# Baseline solution BLEU score is quite low. Try to achieve at least __21__ BLEU on the test set. 
# The checkpoints are:

# * __21__ - minimal score to submit the homework, 30% of points

# * __25__ - good score, 70% of points

# * __27__ - excellent score, 100% of points

In [3]:
import numpy as np
import pandas as pd
import torch
import random
import matplotlib.pyplot as plt
import time

from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from transformers.modeling_outputs import BaseModelOutput
from transformers import T5Model, T5Tokenizer, T5Config, T5ForConditionalGeneration, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoConfig
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from nltk.translate.bleu_score import corpus_bleu
from IPython.display import clear_output

import wandb

In [4]:
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: naumenko-km (vector2text). Use `wandb login --relogin` to force relogin


True

In [5]:
with open('data.txt', 'r') as f:
    texts = f.read()

texts = texts.split(sep='\n')
texts = [row.split('\t') for row in texts]
texts_en = [row[0] for row in texts if len(row) == 2]
texts_ru = [row[1] for row in texts if len(row) == 2]

print('Num texts:', len(texts_en), len(texts_ru))
print('En max len:', max([len(row) for row in texts_en]))
print('Ru max len:', max([len(row) for row in texts_ru]))

Num texts: 50000 50000
En max len: 518
Ru max len: 431


In [6]:
DEVICE = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')
MAX_LEN = 518
DEVICE

device(type='cuda', index=4)

In [7]:
class TextDataset(Dataset):
    def __init__(self, texts_en, texts_ru):
        self.texts_en = texts_en
        self.texts_ru = texts_ru
        
    def __len__(self):
        return len(self.texts_en)
    
    def __getitem__(self, idx):
        return self.texts_en[idx], self.texts_ru[idx]

In [8]:
train_texts_en, val_texts_en, train_texts_ru, val_texts_ru = train_test_split(texts_en, texts_ru, test_size=0.05, random_state=42)
train_texts_en, test_texts_en, train_texts_ru, test_texts_ru = train_test_split(train_texts_en, train_texts_ru, test_size=0.05, random_state=42)

train_dataset = TextDataset(train_texts_en, train_texts_ru)
val_dataset = TextDataset(val_texts_en, val_texts_ru)
test_dataset = TextDataset(test_texts_en, test_texts_ru)
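
The two nested 5% splits leave roughly 90.25% of the data for training. As a quick sanity check, the sketch below recomputes the split sizes and the resulting batch counts with plain integer arithmetic. It assumes the 5% shares come out at exactly 2500 and 2375 examples and uses the batch size of 16 set later in the notebook; the variable names are illustrative only.

```python
import math

n_total = 50_000                 # number of sentence pairs in data.txt
n_val = n_total * 5 // 100       # first split: 5% held out for validation
n_rest = n_total - n_val
n_test = n_rest * 5 // 100       # second split: 5% of the remainder for testing
n_train = n_rest - n_test

batch_size = 16
train_batches = n_train // batch_size          # drop_last=True discards the partial batch
val_batches = math.ceil(n_val / batch_size)    # partial last batch is kept
test_batches = math.ceil(n_test / batch_size)

print(n_train, n_val, n_test)
print(train_batches, val_batches, test_batches)
```

These batch counts can be compared against the lengths shown by the progress bars during training and evaluation.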

In [30]:
n_epochs = 12
batch_size = 16
log_each_n_iterations = 200
generate_n = 1

In [31]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

train_loader = DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size)
test_loader = DataLoader(test_dataset, batch_size)
generate_loader = DataLoader(val_dataset, generate_n, shuffle=True)

In [41]:
model_name = 'cointegrated/rut5-base-multitask'
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

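# Only the tokenizer is pretrained: the model below is a randomly initialised T5
# built from the default T5Config sizes, not the 'rut5-base-multitask' checkpoint.
config = T5Config(vocab_size=tokenizer.vocab_size, decoder_start_token_id=0)
model = T5ForConditionalGeneration(config)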
model.to(DEVICE);

In [42]:
LR = 1e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.5)
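
To see what this schedule does over the run, the closed form of step decay can be evaluated by hand: the learning rate is multiplied by `gamma` once every `step_size` scheduler steps. The sketch below is a plain-Python illustration of that formula, not a replacement for `StepLR`; `step_decay` is a hypothetical helper, and `epoch` counts how many times `scheduler.step()` has already been called.

```python
def step_decay(base_lr, epoch, step_size, gamma):
    """Learning rate after `epoch` scheduler steps under step decay."""
    return base_lr * gamma ** (epoch // step_size)


# Same hyperparameters as the cell above, evaluated for 12 epochs.
schedule = [step_decay(1e-5, epoch, step_size=4, gamma=0.5) for epoch in range(12)]

for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:.2e}")
```

With these values the learning rate stays at 1e-5 for the first four epochs, halves to 5e-6 for the next four, and halves again to 2.5e-6 for the last four.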

In [43]:
def generate(en, max_length=MAX_LEN, repetition_penalty=None):
    """Translate a batch of English strings and return the decoded outputs."""
    model.eval()
    with torch.no_grad():
        x = tokenizer(en, padding=True, truncation=True, max_length=MAX_LEN, return_tensors='pt').to(DEVICE)
        # Forward the function arguments instead of hard-coded values
        out = model.generate(**x, max_length=max_length, repetition_penalty=repetition_penalty)
        generated = [tokenizer.decode(tokens, skip_special_tokens=True) for tokens in out]
        return generated
    

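def calc_bleu(loader):
    original_text = []
    generated_text = []
    for en, ru in tqdm(loader):
        generated = generate(en)
        original_text.extend(ru)
        generated_text.extend(generated)
    # NOTE: corpus_bleu expects lists of tokens. Raw strings are iterated character by
    # character, so this returns character-level BLEU; use text.split() for word-level BLEU.
    return corpus_bleu([[text] for text in original_text], generated_text) * 100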
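
NLTK's `corpus_bleu` operates on sequences of tokens, and a Python string is itself a sequence of characters, so passing untokenized strings scores character n-grams rather than word n-grams. The sketch below illustrates the difference with a hand-written modified n-gram precision, which is one component of BLEU and not the full metric. The helper names and the example sentences are made up for illustration.

```python
from collections import Counter


def ngrams(seq, n):
    """All contiguous n-grams of a sequence (a list of tokens or a string)."""
    return [tuple(seq[i:i + n]) for i in range(len(seq) - n + 1)]


def modified_precision(reference, hypothesis, n):
    """Clipped n-gram precision of the hypothesis against a single reference."""
    hyp_counts = Counter(ngrams(hypothesis, n))
    ref_counts = Counter(ngrams(reference, n))
    total = sum(hyp_counts.values())
    if total == 0:
        return 0.0
    clipped = sum(min(count, ref_counts[gram]) for gram, count in hyp_counts.items())
    return clipped / total


ref = "the cat sat on the mat"
hyp = "the cat sat on a mat"

word_p2 = modified_precision(ref.split(), hyp.split(), 2)  # word bigrams
char_p2 = modified_precision(ref, hyp, 2)                  # character bigrams

print(f"word-level bigram precision: {word_p2:.3f}")
print(f"char-level bigram precision: {char_p2:.3f}")
```

In this toy example the character-level precision is noticeably higher than the word-level one, so scores computed on raw strings are not directly comparable with word-level BLEU thresholds.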

In [44]:
wandb.init(
    # set the wandb project where this run will be logged
    project="nlp-lab2",
    notes="baseline",
    name='t5_1',
    entity='naumenko-km',
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": LR,
    "encoder": model_name,
    "decoder": model_name,
    "epochs": n_epochs,
    }
)
best_bleu = 0
val_bleu = 0
mean_val_loss = 10
iters = 1

for i in range(1, n_epochs + 1):
    print(f'[EPOCH {i}]')
    tqdm_iterator = tqdm(train_loader)
    for text_en_batch, text_ru_batch in tqdm_iterator:
        model.train()
        x = tokenizer(text_en_batch, return_tensors='pt', padding=True, truncation=True, max_length=MAX_LEN).to(DEVICE)
        y = tokenizer(text_ru_batch, return_tensors='pt', padding=True, truncation=True, max_length=MAX_LEN).to(DEVICE)

        y.input_ids[y.input_ids == 0] = -100  # ignore padding tokens in the loss
        loss = model(
            **x,
            labels=y.input_ids,
            decoder_attention_mask=y.attention_mask,
            return_dict=True
        ).loss
                
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        wandb.log({"batch loss": loss.item(), "val loss": mean_val_loss, 'val bleu': val_bleu}, step=iters)

        if iters % log_each_n_iterations == 0:
            en, ru = next(iter(generate_loader))
            generated = generate(en)
            example = wandb.Html(data=f'batches: {iters} <br> True: {ru[0]} <br> Generated: {generated[0]}')
            wandb.log({"texts": example}, step=iters)
        iters += 1
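    # NOTE: this loop iterates over test_loader, so the value logged as "val loss"
    # is measured on the test split, while val BLEU below uses val_loader.
    tqdm_iterator = tqdm(test_loader)
    val_loss = []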
    for text_en_batch, text_ru_batch in tqdm_iterator:
        model.eval()
        with torch.no_grad():
            x = tokenizer(text_en_batch, return_tensors='pt', padding=True, truncation=True, max_length=MAX_LEN).to(DEVICE)
            y = tokenizer(text_ru_batch, return_tensors='pt', padding=True, truncation=True, max_length=MAX_LEN).to(DEVICE)

            y.input_ids[y.input_ids == 0] = -100  # ignore padding tokens in the loss
            loss = model(
                **x,
                labels=y.input_ids,
                decoder_attention_mask=y.attention_mask,
                return_dict=True
            ).loss
            val_loss.append(loss.item())
    mean_val_loss = np.mean(val_loss)
    val_bleu = calc_bleu(val_loader)
    print("val loss:", mean_val_loss, 'val bleu:', val_bleu)
    scheduler.step()
    
    if val_bleu > best_bleu:
        best_bleu = val_bleu
        # torch.save(encoder.state_dict(), 'encoder.pt')
        # torch.save(decoder.state_dict(), 'decoder.pt')
        # print(f'best loss improved!')
wandb.finish()

[EPOCH 1]


100%|██████████| 2820/2820 [06:16<00:00,  7.49it/s] 
100%|██████████| 149/149 [00:04<00:00, 32.26it/s]
100%|██████████| 157/157 [05:47<00:00,  2.21s/it]


val loss: 4.887052372798023 val bleu: 14.367539879601225
[EPOCH 2]


100%|██████████| 2820/2820 [05:44<00:00,  8.18it/s]
100%|██████████| 149/149 [00:04<00:00, 34.81it/s]
100%|██████████| 157/157 [05:14<00:00,  2.00s/it]


val loss: 4.255457145255684 val bleu: 20.49577676135446
[EPOCH 3]


100%|██████████| 2820/2820 [05:54<00:00,  7.96it/s]
100%|██████████| 149/149 [00:04<00:00, 31.03it/s]
100%|██████████| 157/157 [07:13<00:00,  2.76s/it]


val loss: 3.885418815100753 val bleu: 19.259078858520105
[EPOCH 4]


100%|██████████| 2820/2820 [05:58<00:00,  7.86it/s]
100%|██████████| 149/149 [00:04<00:00, 35.12it/s]
100%|██████████| 157/157 [07:33<00:00,  2.89s/it]


val loss: 3.626132713868314 val bleu: 18.5123030085037
[EPOCH 5]


100%|██████████| 2820/2820 [06:02<00:00,  7.77it/s]
100%|██████████| 149/149 [00:04<00:00, 33.41it/s]
100%|██████████| 157/157 [05:33<00:00,  2.12s/it]


val loss: 3.487332924900439 val bleu: 22.921628029687852
[EPOCH 6]


100%|██████████| 2820/2820 [05:57<00:00,  7.89it/s]
100%|██████████| 149/149 [00:04<00:00, 35.23it/s]
100%|██████████| 157/157 [04:59<00:00,  1.91s/it]


val loss: 3.3929932181467146 val bleu: 26.601592643704784
[EPOCH 7]


100%|██████████| 2820/2820 [05:55<00:00,  7.94it/s]
100%|██████████| 149/149 [00:04<00:00, 34.41it/s]
100%|██████████| 157/157 [04:58<00:00,  1.90s/it]


val loss: 3.301370926351355 val bleu: 27.72791526248802
[EPOCH 8]


100%|██████████| 2820/2820 [05:59<00:00,  7.83it/s]
100%|██████████| 149/149 [00:04<00:00, 35.42it/s]
100%|██████████| 157/157 [04:28<00:00,  1.71s/it]


val loss: 3.2161433552735605 val bleu: 29.333407628048725
[EPOCH 9]


100%|██████████| 2820/2820 [05:56<00:00,  7.90it/s]
100%|██████████| 149/149 [00:04<00:00, 32.66it/s]
100%|██████████| 157/157 [04:15<00:00,  1.63s/it]


val loss: 3.179365705323699 val bleu: 32.409613299286875
[EPOCH 10]


100%|██████████| 2820/2820 [06:17<00:00,  7.48it/s]
100%|██████████| 149/149 [00:04<00:00, 35.01it/s]
100%|██████████| 157/157 [03:42<00:00,  1.41s/it]


val loss: 3.140476950863064 val bleu: 34.415834088591566
[EPOCH 11]


100%|██████████| 2820/2820 [05:57<00:00,  7.88it/s]
100%|██████████| 149/149 [00:04<00:00, 32.76it/s]
100%|██████████| 157/157 [03:47<00:00,  1.45s/it]


val loss: 3.109533921184156 val bleu: 34.92120809002799
[EPOCH 12]


100%|██████████| 2820/2820 [06:06<00:00,  7.69it/s]
100%|██████████| 149/149 [00:04<00:00, 35.06it/s]
100%|██████████| 157/157 [03:24<00:00,  1.30s/it]


val loss: 3.0746183811418164 val bleu: 37.33281026958067


Run history:
batch loss: █▆▅▅▄▄▅▄▄▃▄▃▃▃▃▄▃▂▂▃▂▂▃▂▂▃▃▃▂▂▁▂▁▂▁▂▂▂▂▁
val bleu: ▁▁▁▄▄▄▄▅▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇███████
val loss: ███▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:
batch loss: 2.95841
val bleu: 34.92121
val loss: 3.10953


In [45]:
calc_bleu(test_loader)

100%|██████████| 149/149 [03:27<00:00,  1.39s/it]


36.78979854750574