# Simplification experiments
This notebook contains the code used to fine-tune all encoder-decoder LMs on the simplification task and to evaluate them on it.

## Part 1. Fine-tuning encoder-decoder LMs

In [None]:
# dataset load
def read_segments(path):
    """Return the non-empty ``<s> ... </s>`` segments of the file at *path*.

    The file is split on ``</s>``, every ``<s>`` marker is removed, and the
    result is stripped. Stripping happens *after* the marker is removed, so no
    whitespace that followed ``<s>`` survives. The evaluation prompt strips its
    input as well, so training prompts end up formatted the same way.
    """
    with open(path, encoding="utf-8") as inf:
        raw = inf.read()
    segments = []
    for chunk in raw.split("</s>"):
        text = chunk.replace("<s>", "").strip()
        if text:
            segments.append(text)
    return segments


train = read_segments("./data/simplification/train.txt")
dev = read_segments("./data/simplification/dev.txt")

print(len(train), len(dev))
print(train[0])

In [None]:
import pandas as pd

# Split each "<complex sentence> ==> <simplified sentence>" training row into
# two columns. A row without exactly one separator is reported explicitly
# instead of failing with an opaque unpacking error.
source, corrected = [], []
for row in train:
    parts = row.split(" ==> ")
    if len(parts) != 2:
        raise ValueError("Expected exactly one ' ==> ' separator in training row: %r" % row)
    origin, seq = parts
    source.append(origin)
    corrected.append(seq)
data = {'text_origin': source, 'text_par': corrected}
data_train = pd.DataFrame.from_dict(data)
data_train.tail()

In [None]:
# Split each "<complex sentence> ==> <simplified sentence>" dev row into two
# columns. A row without exactly one separator is reported explicitly instead
# of failing with an opaque unpacking error.
source, corrected = [], []
for row in dev:
    parts = row.split(" ==> ")
    if len(parts) != 2:
        raise ValueError("Expected exactly one ' ==> ' separator in dev row: %r" % row)
    origin, seq = parts
    source.append(origin)
    corrected.append(seq)
data = {'text_origin': source, 'text_par': corrected}
data_val = pd.DataFrame.from_dict(data)
data_val.tail()

In [None]:
from datasets import Dataset

# Wrap the pandas frames as Hugging Face datasets (train first, then validation).
dataset_train, dataset_val = map(Dataset.from_pandas, (data_train, data_val))
dataset_train

In [None]:
def add_eos_to_examples(example):
    """Add the formatted prompt and target text to *example* in place.

    ``input_text`` becomes the ``text_origin`` sentence placed after the
    ``"Упрости: "`` ("Simplify: ") instruction. ``target_text`` becomes
    ``text_par``. Both get a literal ``" </s>"`` suffix. The same mapping is
    returned, which is what ``Dataset.map`` expects.

    NOTE(review): T5 tokenizers typically append their own EOS token, so the
    literal "</s>" may yield a doubled EOS. It is kept so that training prompts
    match the evaluation prompt template.
    """
    prompt_template, target_template = 'Упрости: %s </s>', '%s </s>'
    example['input_text'] = prompt_template % (example['text_origin'])
    example['target_text'] = target_template % example['text_par']
    return example

def convert_to_features(example_batch, max_length=512):
    """Tokenize a batch of formatted prompts and targets.

    Intended for ``Dataset.map(..., batched=True)``: *example_batch* maps
    column names to lists and must contain ``input_text`` and ``target_text``.
    Both sides are padded to, and truncated at, *max_length* tokens. Every
    example therefore has the same length, which ``torch.stack`` in the
    collator requires.

    Uses the module-level ``tokenizer``.

    Returns a dict with ``input_ids``/``attention_mask`` for the prompts and
    ``target_ids``/``target_attention_mask`` for the targets.
    """
    # padding="max_length" + truncation=True is the explicit, non-deprecated
    # form of the old pad_to_max_length=True (which enabled truncation
    # implicitly when max_length was given).
    input_encodings = tokenizer(
        example_batch['input_text'],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    target_encodings = tokenizer(
        example_batch['target_text'],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

    return {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'target_ids': target_encodings['input_ids'],
        'target_attention_mask': target_encodings['attention_mask'],
    }

# Sanity check: show the formatted prompt/target for one training example.
add_eos_to_examples(example=dataset_train[10])

In [None]:
# Uncomment to install/upgrade the dependencies:
#!pip install transformers -U
#!pip install tokenizers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "ai-forever/ruT5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Two passes per split: add the formatted prompt/target columns, then tokenize them in batches.
tokenized_dataset_train = dataset_train.map(add_eos_to_examples).map(convert_to_features, batched=True)
tokenized_dataset_val = dataset_val.map(add_eos_to_examples).map(convert_to_features, batched=True)

In [None]:
# Return only these columns, as torch tensors, when the datasets are indexed.
columns = ['input_ids', 'target_ids', 'attention_mask', 'target_attention_mask']
for tokenized_split in (tokenized_dataset_train, tokenized_dataset_val):
    tokenized_split.set_format(type='torch', columns=columns)

In [None]:
from transformers import DataCollator
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import torch


class T2TDataCollator:
    """Collate tokenized seq2seq examples into a batch for ``Trainer``.

    Each example must provide ``input_ids``, ``attention_mask``, ``target_ids``
    and ``target_attention_mask`` tensors. Within a batch, the tensors for a
    given key must have equal shapes, because they are combined with
    ``torch.stack``.
    """

    def __call__(self, batch: List):
        """
        Take a list of samples from a Dataset and collate them into a batch.

        Returns: A dictionary of tensors with the keys ``input_ids``,
        ``attention_mask``, ``labels`` and ``decoder_attention_mask``.
        """
        def stack(key):
            return torch.stack([example[key] for example in batch])

        input_ids = stack('input_ids')
        labels = stack('target_ids')
        attention_mask = stack('attention_mask')
        decoder_attention_mask = stack('target_attention_mask')
        # Exclude padded target positions from the loss. -100 is the default
        # ignore_index of PyTorch's CrossEntropyLoss. Masking by the target
        # attention mask, rather than a hard-coded pad id of 0, works for any
        # tokenizer's pad token.
        labels = labels.masked_fill(decoder_attention_mask == 0, -100)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'decoder_attention_mask': decoder_attention_mask,
        }

In [None]:
from transformers import TrainingArguments
from transformers import EarlyStoppingCallback

# Output directory, e.g. ./models/simplification/ai-forever/ruT5-large_01_08
data_output = "./models/simplification/"+model_name +'_01_08'

training_args = TrainingArguments(
    data_output,
    num_train_epochs=10,
    overwrite_output_dir=True,  # a bool; previously the string 'True'
    # NOTE(review): newer transformers releases renamed this to `eval_strategy`.
    evaluation_strategy="steps",
    eval_steps=500,
    logging_steps=500,
    learning_rate=1e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    seed=0,
    save_total_limit=1,
    # Reload the best checkpoint (by eval loss unless metric_for_best_model is
    # set) at the end of training; EarlyStoppingCallback requires this flag.
    # NOTE(review): this also requires the save schedule to line up with
    # eval_steps; the defaults are assumed to match - confirm for your version.
    load_best_model_at_end=True,
    # Keep the custom target_* columns: T2TDataCollator reads them.
    remove_unused_columns=False,
)

In [None]:
from transformers import T5ForConditionalGeneration

# Start fine-tuning from the pretrained checkpoint named by `model_name`.
pretrained_checkpoint = model_name
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_checkpoint)

In [None]:
from transformers import Trainer

# Stop when the eval loss has not improved for 3 consecutive evaluations.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset_train,
    eval_dataset=tokenized_dataset_val,
    data_collator=T2TDataCollator(),
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
trainer.train()

In [None]:
# Write the trained model to the output directory; Part 2 loads it from there.
trainer.save_model(output_dir=data_output)

## Part 2. Evaluating encoder-decoder LMs

In [None]:
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Point `model_name` at the fine-tuned checkpoint directory written in Part 1.
# Later cells print `model_name` and derive result-file names from it. The
# guard keeps a re-run of this cell from prefixing the path a second time,
# which would produce a non-existent directory.
finetuned_prefix = "./models/simplification/"
if not model_name.startswith(finetuned_prefix):
    model_name = finetuned_prefix + model_name + '_01_08'
mymodel = T5ForConditionalGeneration.from_pretrained(model_name)
mytokenizer = T5Tokenizer.from_pretrained(model_name)

In [None]:
# Uncomment to install the `evaluate` package:
# !pip install evaluate

from evaluate import load

# SARI (simplification metric) and BERTScore (embedding-based similarity).
sari, bertscore = load("sari"), load("bertscore")

In [None]:
import tqdm
import torch
import pandas as pd

def add_eos_to_examples(example):
    """Return the generation prompt for one raw source sentence.

    Surrounding whitespace is stripped. The sentence is placed after the
    ``"Упрости: "`` ("Simplify: ") instruction, and a literal ``" </s>"`` is
    appended, following the same template as the training prompts.
    """
    cleaned = example.strip()
    return 'Упрости: %s </s>' % (cleaned)

def convert_to_features(example, max_length=800):
    """Tokenize one formatted prompt for generation with ``mytokenizer``.

    Returns PyTorch tensors for a batch containing this single prompt, with
    no padding. Prompts longer than *max_length* tokens are truncated.
    """
    return mytokenizer(
        example,
        # Explicit truncation; previously it was only enabled implicitly
        # because max_length was set (with a deprecation warning).
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )



In [None]:
import os
import tqdm

# Decoding settings (the hidden-test cell reuses these globals).
do_sample = False
num_beams = 5

data_answers = pd.read_csv("./data/simplification/public_test_sents.csv", sep=",")

# Each row pairs a source sentence with one reference simplification; a source
# can appear in several rows. Group the references per source, keeping
# first-appearance order. A dict keeps references with their own source even
# if the rows are not sorted, and gives empty lists (not [[]]) for an empty file.
references_by_source = {}
for source, ans in zip(data_answers["INPUT:source"], data_answers["OUTPUT:output"]):
    references_by_source.setdefault(source, []).append(ans)
sources = list(references_by_source)
answers = list(references_by_source.values())

path_to_file = "./results/simplification/simplification_"+model_name.replace('/','_')+'_public.txt'
os.makedirs(os.path.dirname(path_to_file), exist_ok=True)
with open(path_to_file, "w", encoding="utf-8") as out:
    for text in tqdm.tqdm(sources):
        encodings = convert_to_features(add_eos_to_examples(text))
        input_ids, attention_mask = encodings.input_ids, encodings.attention_mask
        # Let the output be up to roughly twice the prompt length (in tokens).
        length = input_ids.shape[1]
        with torch.no_grad():
            output = mymodel.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_sample=do_sample,
                max_length=2 * length + 10,
                num_beams=num_beams,
            )
        prediction = mytokenizer.decode(output[0], skip_special_tokens=True)
        # Write exactly one line per source: a line break inside a prediction
        # would shift every later prediction when the file is read back.
        out.write(" ".join(prediction.splitlines()) + "\n")

with open(path_to_file, encoding="utf-8") as inf:
    predictions = [line.strip() for line in inf]

print(len(answers), len(predictions), len(sources))
if len(sources) > 20:
    print(answers[20], predictions[20], sources[20])

In [None]:
import numpy as np

print(model_name)
print('public results')
# NOTE(review): BERTScore is scored against the *source* sentences, not the
# reference simplifications, presumably as a meaning-preservation measure.
# Confirm this is intended.
bert_results = bertscore.compute(predictions=predictions, references=sources, lang="ru")
print('BertScore', np.mean(bert_results["f1"]))
results = sari.compute(predictions=predictions, sources=sources, references=answers)
print('Sari', np.mean(results["sari"]))

In [None]:
import os
import tqdm

# Uses the decoding settings do_sample / num_beams set in the public-test cell.
data_answers = pd.read_csv("./data/simplification/hidden_test_sents.csv", sep=",")

# Each row pairs a source sentence with one reference simplification; a source
# can appear in several rows. Group the references per source, keeping
# first-appearance order. A dict keeps references with their own source even
# if the rows are not sorted, and gives empty lists (not [[]]) for an empty file.
references_by_source = {}
for source, ans in zip(data_answers["INPUT:source"], data_answers["OUTPUT:output"]):
    references_by_source.setdefault(source, []).append(ans)
sources = list(references_by_source)
answers = list(references_by_source.values())

path_to_file = "./results/simplification/simplification_"+model_name.replace('/','_')+'_hidden.txt'
os.makedirs(os.path.dirname(path_to_file), exist_ok=True)
with open(path_to_file, "w", encoding="utf-8") as out:
    for text in tqdm.tqdm(sources):
        encodings = convert_to_features(add_eos_to_examples(text))
        input_ids, attention_mask = encodings.input_ids, encodings.attention_mask
        # Let the output be up to roughly twice the prompt length (in tokens).
        length = input_ids.shape[1]
        with torch.no_grad():
            output = mymodel.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_sample=do_sample,
                max_length=2 * length + 10,
                num_beams=num_beams,
            )
        prediction = mytokenizer.decode(output[0], skip_special_tokens=True)
        # Write exactly one line per source: a line break inside a prediction
        # would shift every later prediction when the file is read back.
        out.write(" ".join(prediction.splitlines()) + "\n")

with open(path_to_file, encoding="utf-8") as inf:
    predictions = [line.strip() for line in inf]

print(len(answers), len(predictions), len(sources))
if len(sources) > 20:
    print(answers[20], predictions[20], sources[20])

In [None]:
import numpy as np

print(model_name)
print('private results')
# NOTE(review): BERTScore is scored against the *source* sentences, not the
# reference simplifications, presumably as a meaning-preservation measure.
# Confirm this is intended.
bert_results = bertscore.compute(predictions=predictions, references=sources, lang="ru")
print('BertScore', np.mean(bert_results["f1"]))
results = sari.compute(predictions=predictions, sources=sources, references=answers)
print('Sari', np.mean(results["sari"]))