## Lab assignment 02

### Neural Machine Translation in the wild
In this homework, your goal is to get the best translation quality you can on the EN-RU translation task.

A basic approach that uses RNNs as the encoder and decoder is implemented for you.

Your ultimate task is to use the techniques we've covered, e.g.

* Optimization enhancements (e.g. learning rate decay)

* Transformer/CNN/<whatever you select> encoder (with or without positional encoding)

* attention/self-attention mechanism

* pretraining the language models (for decoder and encoder)

* or just fine-tuning BART/ELECTRA/... ;)

to improve the translation quality. 
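
As one illustration of the techniques listed above, the sketch below computes the fixed sinusoidal positional encoding from the original Transformer architecture in plain Python. The function name `sinusoidal_positions` and the chosen sizes are illustrative assumptions, not part of the assignment code; a real model would typically build the same table as a tensor and add it to the token embeddings.

```python
import math


def sinusoidal_positions(n_positions, d_model):
    """Return an n_positions x d_model table of sinusoidal position encodings."""
    table = []
    for pos in range(n_positions):
        row = []
        for i in range(d_model):
            # Dimensions 2k and 2k+1 share the frequency 1 / 10000^(2k / d_model).
            angle = pos / (10000 ** ((i - i % 2) / d_model))
            # Even dimensions use sine, odd dimensions use cosine.
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


# Illustrative sizes only: 4 positions, 8-dimensional embeddings.
pe = sinusoidal_positions(n_positions=4, d_model=8)
```

Because the table depends only on the position and the embedding size, it needs no training and can be precomputed once for the longest sequence the encoder will see.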

__Please use at least three different approaches/models and compare them (translation quality/complexity/training and evaluation time).__

Write a summary of your experiments and illustrate it with convergence plots, metrics, and your own observations, just as you would when approaching a real problem.

In [1]:
# Thanks to YSDA NLP course team for the data
# (who thanks tilda and deephack teams for the data in their turn)

import os
path_to_data = 'data.txt'
if not os.path.exists(path_to_data):
    print("Dataset not found locally. Downloading from GitHub.")
    !wget https://raw.githubusercontent.com/neychev/made_nlp_course/master/datasets/Machine_translation_EN_RU/data.txt -nc

In [2]:
# The baseline solution's BLEU score is quite low. Try to achieve at least __21__ BLEU on the test set.
# The checkpoints are:

# * __21__ - minimal score to submit the homework, 30% of points

# * __25__ - good score, 70% of points

# * __27__ - excellent score, 100% of points
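
The checkpoints above can be read as a simple mapping from test BLEU to the share of points. The sketch below assumes that each threshold is inclusive and that a score below 21 earns nothing, since 21 is described as the minimum needed to submit; the function name `points_fraction` is hypothetical.

```python
def points_fraction(test_bleu):
    """Map a test-set BLEU score to the fraction of points from the checkpoints."""
    if test_bleu >= 27:
        return 1.0  # excellent score
    if test_bleu >= 25:
        return 0.7  # good score
    if test_bleu >= 21:
        return 0.3  # minimal score to submit
    return 0.0  # assumption: below the submission threshold earns no points


example = points_fraction(21.27)
```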

In [3]:
import numpy as np
import pandas as pd
import torch
import random
import matplotlib.pyplot as plt
import time

from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from transformers.modeling_outputs import BaseModelOutput
from transformers import T5Model, T5Tokenizer, T5Config, T5ForConditionalGeneration, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from nltk.translate.bleu_score import corpus_bleu
from IPython.display import clear_output

import wandb

In [4]:
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mnaumenko-km[0m ([33mvector2text[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [6]:
with open('data.txt', 'r') as f:
    texts = f.read()

texts = texts.split(sep='\n')
texts = [row.split('\t') for row in texts]
texts_en = [row[0] for row in texts if len(row) == 2]
texts_ru = [row[1] for row in texts if len(row) == 2]

print('Num texts:', len(texts_en), len(texts_ru))
print('En max len:', max([len(row) for row in texts_en]))
print('Ru max len:', max([len(row) for row in texts_ru]))

Num texts: 50000 50000
En max len: 518
Ru max len: 431


In [7]:
DEVICE = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')
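# NOTE: 518 is the longest English text measured in characters, but it is used below as a
# token limit. DistilBERT accepts at most 512 positions, so a smaller cap would be safer.
MAX_LEN = 518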
DEVICE

device(type='cuda', index=4)

In [8]:
class TextDataset(Dataset):
    def __init__(self, texts_en, texts_ru):
        self.texts_en = texts_en
        self.texts_ru = texts_ru
        
    def __len__(self):
        return len(self.texts_en)
    
    def __getitem__(self, idx):
        return self.texts_en[idx], self.texts_ru[idx]

In [9]:
train_texts_en, val_texts_en, train_texts_ru, val_texts_ru = train_test_split(texts_en, texts_ru, test_size=0.05, random_state=42)
train_texts_en, test_texts_en, train_texts_ru, test_texts_ru = train_test_split(train_texts_en, train_texts_ru, test_size=0.05, random_state=42)

train_dataset = TextDataset(train_texts_en, train_texts_ru)
val_dataset = TextDataset(val_texts_en, val_texts_ru)
test_dataset = TextDataset(test_texts_en, test_texts_ru)

In [10]:
n_epochs = 10
batch_size = 16
log_each_n_iterations = 200
generate_n = 1
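
As a sanity check on the split and batch settings, the sketch below works out how many batches each loader should yield. It assumes that `train_test_split` rounds the 5% hold-out size up to a whole number of samples, that the training loader drops its last incomplete batch, and that the other loaders keep theirs; the helper `holdout_size` is hypothetical.

```python
import math

N_PAIRS = 50_000       # sentence pairs reported when the data was loaded
HOLDOUT_PERCENT = 5    # test_size=0.05 in both splits
BATCH_SIZE = 16


def holdout_size(n_samples, percent):
    """Held-out sample count, assuming the fraction is rounded up."""
    return math.ceil(n_samples * percent / 100)


n_val = holdout_size(N_PAIRS, HOLDOUT_PERCENT)    # first split: validation set
n_rest = N_PAIRS - n_val
n_test = holdout_size(n_rest, HOLDOUT_PERCENT)    # second split: test set
n_train = n_rest - n_test

train_batches = n_train // BATCH_SIZE             # incomplete last batch is dropped
val_batches = math.ceil(n_val / BATCH_SIZE)       # incomplete last batch is kept
test_batches = math.ceil(n_test / BATCH_SIZE)
```

Under these assumptions the counts come out to 2820 training, 157 validation, and 149 test batches per epoch.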

In [11]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

train_loader = DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size)
test_loader = DataLoader(test_dataset, batch_size)
generate_loader = DataLoader(val_dataset, generate_n, shuffle=True)


enc_name = 'distilbert-base-multilingual-cased'
dec_name = 't5-small'
# dec_name = "cointegrated/rut5-base-multitask"

enc_tokenizer = AutoTokenizer.from_pretrained(enc_name)
encoder = AutoModel.from_pretrained(enc_name).to(DEVICE)

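dec_tokenizer = AutoTokenizer.from_pretrained(dec_name)
# NOTE: this pretrained model is overwritten below, so only the tokenizer of `dec_name` is reused.
decoder = AutoModelForSeq2SeqLM.from_pretrained(dec_name).to(DEVICE)
# dec_tokenizer = T5Tokenizer.from_pretrained("cointegrated/rut5-base-multitask")
# Build a randomly initialized T5 whose hidden size matches the DistilBERT encoder output.
config = T5Config(vocab_size=dec_tokenizer.vocab_size, d_model=encoder.config.dim, decoder_start_token_id=0)
decoder = T5ForConditionalGeneration(config).to(DEVICE)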

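# T5's own encoder stack is never called (encoder outputs come from DistilBERT), so freeze it.
for p in decoder.encoder.parameters():
    p.requires_grad = False
# Train the T5 decoder stack.
for p in decoder.decoder.parameters():
    p.requires_grad = True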


LR = 1e-5
optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(decoder.parameters()), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

Some weights of the model checkpoint at distilbert-base-multilingual-cased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
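
The `StepLR` scheduler configured above multiplies the learning rate by `gamma` once every `step_size` scheduler steps. The sketch below reproduces that rule in plain Python for the settings in this cell (`LR = 1e-5`, `step_size=1`, `gamma=0.5`) over the 10 training epochs; the helper name `step_lr` is hypothetical and the code does not call PyTorch.

```python
BASE_LR = 1e-5
GAMMA = 0.5
STEP_SIZE = 1
N_EPOCHS = 10


def step_lr(base_lr, gamma, step_size, epochs_done):
    """Learning rate after `epochs_done` scheduler steps under step decay."""
    return base_lr * gamma ** (epochs_done // step_size)


# Entry k is the learning rate used during epoch k + 1.
schedule = [step_lr(BASE_LR, GAMMA, STEP_SIZE, e) for e in range(N_EPOCHS)]
for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:.3e}")
```

With halving after every epoch, the rate in the last epoch is roughly 2e-8, about 500 times smaller than the starting value. This may be one reason training progress slows in later epochs, so a gentler decay could be worth comparing.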


In [12]:
def encode(texts):
    """Tokenize a batch of English texts and return the encoder's last hidden states (no gradients)."""
    encoded_input = enc_tokenizer(texts, padding=True, truncation=True, max_length=MAX_LEN, return_tensors='pt')
    with torch.no_grad():
        model_output = encoder(**encoded_input.to(encoder.device))
        embeddings = model_output.last_hidden_state
    return embeddings


def decode(embeddings, max_length=MAX_LEN, repetition_penalty=3.0, **kwargs):
    """Generate translations from encoder hidden states and return them as decoded strings."""
    with torch.no_grad():
        out = decoder.generate(
            encoder_outputs=BaseModelOutput(last_hidden_state=embeddings), 
            max_length=max_length, 
            repetition_penalty=repetition_penalty,
            **kwargs
        )
        return [dec_tokenizer.decode(tokens, skip_special_tokens=True) for tokens in out]
    

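def calc_bleu(loader):
    """Translate every batch in `loader` and return corpus BLEU scaled to 0-100."""
    original_text = []
    generated_text = []
    encoder.eval()
    decoder.eval()

    for en, ru in tqdm(loader):
        embeds = encode(en)
        generated = decode(embeds, max_length=MAX_LEN, repetition_penalty=None)
        
        original_text.extend(ru)
        generated_text.extend(generated)

    # NOTE: references and hypotheses are passed as plain strings, so NLTK iterates over
    # characters; split each text into tokens first to get word-level BLEU.
    return corpus_bleu([[text] for text in original_text], generated_text) * 100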
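
NLTK's `corpus_bleu` expects each reference and hypothesis to be a sequence of tokens, and a plain Python string is a sequence of characters. The sketch below, which does not use NLTK, shows how much that choice can matter by computing clipped bigram precision (one ingredient of BLEU) on characters and on whitespace-separated words. The helper `ngram_precision` and the example sentences are illustrative assumptions.

```python
from collections import Counter


def ngram_precision(reference, hypothesis, n):
    """Clipped n-gram precision of a hypothesis against a single reference."""
    ref = Counter(tuple(reference[i:i + n]) for i in range(len(reference) - n + 1))
    hyp = Counter(tuple(hypothesis[i:i + n]) for i in range(len(hypothesis) - n + 1))
    total = sum(hyp.values())
    if total == 0:
        return 0.0
    # Each hypothesis n-gram is credited at most as often as it occurs in the reference.
    matched = sum(min(count, ref[gram]) for gram, count in hyp.items())
    return matched / total


reference = "the cat sat on the mat"
hypothesis = "the cat sits on a mat"

char_precision = ngram_precision(reference, hypothesis, 2)                  # strings: characters
word_precision = ngram_precision(reference.split(), hypothesis.split(), 2)  # lists: words

print("character bigram precision:", char_precision)
print("word bigram precision:", word_precision)
```

Character n-grams overlap far more easily than word n-grams, so scores computed on raw strings are not directly comparable to word-level BLEU thresholds.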

In [14]:
wandb.init(
    # set the wandb project where this run will be logged
    project="nlp-lab2",
    notes="baseline",
    name='experiment_4',
    entity='naumenko-km',
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": LR,
    "encoder": enc_name,
    "decoder": dec_name,
    "epochs": n_epochs,
    }
)
best_bleu = 0
val_bleu = 0
mean_val_loss = 10
iters = 1

for i in range(1, n_epochs + 1):
    print(f'[EPOCH {i}]')
    tqdm_iterator = tqdm(train_loader)
    for text_en_batch, text_ru_batch in tqdm_iterator:
        encoder.train()
        decoder.train()
        x = enc_tokenizer(text_en_batch, return_tensors='pt', padding=True, truncation=True, max_length=MAX_LEN).to(DEVICE)
        y = dec_tokenizer(text_ru_batch, return_tensors='pt', padding=True, truncation=True, max_length=MAX_LEN).to(DEVICE)

        y.input_ids[y.input_ids == 0] = -100  # ignore padding in the loss
        embeds = encoder(**x.to(encoder.device))
        embeds = embeds.last_hidden_state.to(DEVICE)

        loss = decoder(
            encoder_outputs=BaseModelOutput(last_hidden_state=embeds),
            labels=y.input_ids,
            decoder_attention_mask=y.attention_mask,
            return_dict=True
        ).loss
                
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        wandb.log({"batch loss": loss.item(), "val loss": mean_val_loss, 'val bleu': val_bleu}, step=iters)

        if iters % log_each_n_iterations == 0:
            encoder.eval()
            decoder.eval()    
            en, ru = next(iter(generate_loader))
            embeds = encode(en)
            generated = decode(embeds, max_length=MAX_LEN, repetition_penalty=None)
            example = wandb.Html(data=f'batches: {iters} <br> True: {ru[0]} <br> Generated: {generated[0]}')
            wandb.log({"texts": example}, step=iters)
        iters += 1

    # NOTE: the loss logged as "val loss" is computed on test_loader; BLEU below uses val_loader.
    tqdm_iterator = tqdm(test_loader)
    val_loss = []
    for text_en_batch, text_ru_batch in tqdm_iterator:
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            x = enc_tokenizer(text_en_batch, return_tensors='pt', padding=True, truncation=True, max_length=MAX_LEN).to(DEVICE)
            y = dec_tokenizer(text_ru_batch, return_tensors='pt', padding=True, truncation=True, max_length=MAX_LEN).to(DEVICE)

            y.input_ids[y.input_ids == 0] = -100  # ignore padding in the loss
            embeds = encoder(**x.to(encoder.device))
            embeds = embeds.last_hidden_state.to(DEVICE)

            loss = decoder(
                encoder_outputs=BaseModelOutput(last_hidden_state=embeds),
                labels=y.input_ids,
                decoder_attention_mask=y.attention_mask,
                return_dict=True
            ).loss
            val_loss.append(loss.item())
    
    mean_val_loss = np.mean(val_loss)
    val_bleu = calc_bleu(val_loader)
    print("val loss:", mean_val_loss, 'val bleu:', val_bleu)
    scheduler.step()
    
    if val_bleu > best_bleu:
        best_bleu = val_bleu
        # torch.save(encoder.state_dict(), 'encoder.pt')
        # torch.save(decoder.state_dict(), 'decoder.pt')
        # print(f'best loss improved!')

wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mnaumenko-km[0m. Use [1m`wandb login --relogin`[0m to force relogin


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[EPOCH 1]


100%|██████████| 2820/2820 [12:04<00:00,  3.89it/s] 
100%|██████████| 149/149 [00:10<00:00, 14.44it/s]
100%|██████████| 157/157 [12:47<00:00,  4.89s/it]


val loss: 2.481465419666879 val bleu: 7.520443114850823
[EPOCH 2]


100%|██████████| 2820/2820 [11:34<00:00,  4.06it/s] 
100%|██████████| 149/149 [00:10<00:00, 14.56it/s]
100%|██████████| 157/157 [12:09<00:00,  4.65s/it]


val loss: 2.0935654952222067 val bleu: 14.446332161079756
[EPOCH 3]


100%|██████████| 2820/2820 [11:38<00:00,  4.04it/s] 
100%|██████████| 149/149 [00:10<00:00, 14.40it/s]
100%|██████████| 157/157 [10:49<00:00,  4.14s/it]


val loss: 1.9358698561687597 val bleu: 18.0393771522862
[EPOCH 4]


100%|██████████| 2820/2820 [11:39<00:00,  4.03it/s] 
100%|██████████| 149/149 [00:10<00:00, 14.43it/s]
100%|██████████| 157/157 [10:46<00:00,  4.12s/it]


val loss: 1.8647736174948264 val bleu: 18.883952113572956
[EPOCH 5]


100%|██████████| 2820/2820 [11:31<00:00,  4.08it/s]
100%|██████████| 149/149 [00:10<00:00, 14.44it/s]
100%|██████████| 157/157 [09:55<00:00,  3.79s/it]


val loss: 1.8324492169706614 val bleu: 20.688191561119112
[EPOCH 6]


100%|██████████| 2820/2820 [11:33<00:00,  4.06it/s]
100%|██████████| 149/149 [00:10<00:00, 14.45it/s]
100%|██████████| 157/157 [09:22<00:00,  3.58s/it]


val loss: 1.8181486813814047 val bleu: 20.909259701147707
[EPOCH 7]


100%|██████████| 2820/2820 [11:32<00:00,  4.07it/s]
100%|██████████| 149/149 [00:10<00:00, 14.54it/s]
100%|██████████| 157/157 [09:11<00:00,  3.51s/it]


val loss: 1.8099692319863594 val bleu: 21.286235918792013
[EPOCH 8]


100%|██████████| 2820/2820 [11:34<00:00,  4.06it/s]
100%|██████████| 149/149 [00:10<00:00, 14.43it/s]
100%|██████████| 157/157 [09:12<00:00,  3.52s/it]


val loss: 1.8059835601973053 val bleu: 21.399760867574617
[EPOCH 9]


100%|██████████| 2820/2820 [11:30<00:00,  4.08it/s]
100%|██████████| 149/149 [00:10<00:00, 14.57it/s]
100%|██████████| 157/157 [09:14<00:00,  3.53s/it]


val loss: 1.805214831893076 val bleu: 21.448090742136188
[EPOCH 10]


100%|██████████| 2820/2820 [11:34<00:00,  4.06it/s]
100%|██████████| 149/149 [00:10<00:00, 14.48it/s]
100%|██████████| 157/157 [09:17<00:00,  3.55s/it]


val loss: 1.8045732034932846 val bleu: 21.423642001663367


batch loss  █▆▅▃▃▃▃▂▂▃▃▂▃▂▂▃▂▂▂▂▂▃▃▂▂▂▂▁▂▂▂▂▁▂▂▂▁▃▃▂
val bleu    ▁▁▁▁▃▃▃▃▆▆▆▆▇▇▇▇▇▇▇▇████████████████████
val loss    ████▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

batch loss  1.93026
val bleu    21.44809
val loss    1.80521
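
The training and evaluation loops replace padding ids in the labels with `-100`, which is the default `ignore_index` of PyTorch's cross-entropy loss, so padded positions do not contribute to the reported loss. The sketch below mimics that masking in plain Python; the helper `masked_mean_nll` and the toy probabilities are illustrative assumptions rather than the library's implementation.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def masked_mean_nll(token_probs, labels, ignore_index=IGNORE_INDEX):
    """Average negative log-likelihood over positions whose label is not ignored.

    token_probs[t] maps a token id to the model's probability at position t.
    """
    losses = [
        -math.log(probs[label])
        for probs, label in zip(token_probs, labels)
        if label != ignore_index
    ]
    # Sketch-only choice: return 0.0 when every position is ignored.
    return sum(losses) / len(losses) if losses else 0.0


# Three positions; the last one is padding and is skipped.
token_probs = [{5: 0.5, 7: 0.5}, {5: 0.25, 7: 0.75}, {5: 0.9, 7: 0.1}]
labels = [5, 7, IGNORE_INDEX]
loss = masked_mean_nll(token_probs, labels)
```

Averaging over real tokens only keeps the loss comparable between batches with different amounts of padding.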


In [15]:
calc_bleu(test_loader)

100%|██████████| 149/149 [08:32<00:00,  3.44s/it]


21.265918481687468