In [None]:
import transformers
from datasets import load_dataset, load_metric
import os

In [None]:
# Read the whole CSV with the generic "csv" loader; the cells below access the
# result through its "train" split.
csv_path = "data/all_languages_40000.csv"
dataset = load_dataset(path="csv", data_files=csv_path)

In [None]:
# Show the splits, column names and row counts. A bare last expression is used
# (instead of print) to match how the later cells display `dataset` and
# `tokenized_dataset`.
dataset

## Dataset train/validation split

In [None]:
# Hold out 10000 rows of the "train" split as a validation set.
# - A fixed seed makes the shuffle (and therefore the split) reproducible across
#   kernel restarts.
# - The guard keeps the cell idempotent: without it, re-running the cell would
#   split the already-reduced "train" split again and drop another 10000 rows.
if "validation" not in dataset:
    dataset_train_validation = dataset["train"].train_test_split(test_size=10000, seed=42)

    dataset["train"] = dataset_train_validation["train"]
    dataset["validation"] = dataset_train_validation["test"]

dataset

In [None]:
# Report how the rows are distributed between the two splits, as percentages
# of their combined size.
n_samples_train, n_samples_validation = (
    len(dataset[split_name]) for split_name in ("train", "validation")
)
n_samples_total = n_samples_train + n_samples_validation

for split_label, split_size in (
    ("Training", n_samples_train),
    ("Validation", n_samples_validation),
):
    print(f"- {split_label} set: {split_size*100/n_samples_total:.2f}%")

## Data preprocessing

In [None]:
import nltk
nltk.download('punkt')
import string
from transformers import AutoTokenizer

In [None]:
# Hub identifier of the pretrained checkpoint; the same name is used further
# down by model_init() to load the model weights.
model_checkpoint = "google/flan-t5-base"

# Tokenizer matching the checkpoint above; used by preprocess_data() and passed
# to the data collator and the trainer.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
# Task prefix prepended to every cleaned source text before tokenization.
# NOTE(review): the target column is named "simplifications" while the prefix
# says "paraphrase" - confirm this is the intended prompt.
prefix = "paraphrase: "

# Truncation limits (in tokens) for the source texts and the target texts.
max_input_length = 128
max_target_length = 128

def clean_text(text):
  """Keep only the punctuation-terminated pieces of `text`.

  The text is stripped, split into sentences with NLTK, and every sentence is
  further split on newlines. Pieces that are empty or whose last character is
  not in `string.punctuation` are dropped (presumably titles/headings, judging
  by the original variable names); the remaining pieces are joined with "\n".

  Args:
    text: Raw text, or None (e.g. an empty CSV cell).

  Returns:
    The cleaned text. None and empty/whitespace-only input yield "" instead
    of raising.
  """
  # Missing values would otherwise fail on `.strip()`.
  if text is None:
    return ""

  stripped = text.strip()
  if not stripped:
    # Nothing to tokenize; same result the full pipeline gives for blank input.
    return ""

  pieces = (piece
            for sentence in nltk.sent_tokenize(stripped)
            for piece in sentence.split("\n"))
  kept = [piece for piece in pieces
          if piece and piece[-1] in string.punctuation]
  return "\n".join(kept)

def preprocess_data(examples):
  """Tokenize one batch of examples for sequence-to-sequence training.

  Sources come from the "original" column: each is cleaned with clean_text(),
  prefixed with `prefix`, and tokenized with truncation at `max_input_length`.
  Targets come from the "simplifications" column and are tokenized as-is (they
  are not cleaned) with truncation at `max_target_length`.

  Args:
    examples: Batch mapping column names to lists of values (the function is
      applied with `batched=True`).

  Returns:
    The tokenizer output for the sources, with an added "labels" entry holding
    the target token ids.
  """
  texts_cleaned = [clean_text(text) for text in examples["original"]]
  inputs = [prefix + text for text in texts_cleaned]
  model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

  # `text_target=` replaces the deprecated `as_target_tokenizer()` context
  # manager for tokenizing the targets.
  labels = tokenizer(text_target=examples["simplifications"],
                     max_length=max_target_length, truncation=True)

  model_inputs["labels"] = labels["input_ids"]
  return model_inputs

In [None]:
# Apply the preprocessing to every split, one batch of rows per call.
tokenized_dataset = dataset.map(function=preprocess_data, batched=True)
tokenized_dataset

## Fine-tune T5

In [None]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

In [None]:
batch_size = 16
model_name = "best_model"
model_dir = f"models/{model_name}"  # checkpoints and the final model go here

args = Seq2SeqTrainingArguments(
    output_dir=model_dir,
    # Optimisation
    num_train_epochs=8,
    learning_rate=4e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=False,
    # Evaluation and logging, both every 500 steps
    evaluation_strategy="steps",
    eval_steps=500,
    logging_strategy="steps",
    logging_steps=500,
    predict_with_generate=True,
    # Checkpointing every 1000 steps, i.e. on every second evaluation
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=1,
    load_best_model_at_end=True,
)

In [None]:
# Collator that batches the tokenized examples for the trainer, using the same
# tokenizer as the preprocessing step.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

In [None]:

def model_init():
    """Return a freshly loaded seq2seq model for `model_checkpoint`.

    Passed to the trainer as a factory, so the model is created by calling
    this function rather than being built up front.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
    return model

# Wire the model factory, training arguments, tokenized splits, collator and
# tokenizer together.
trainer = Seq2SeqTrainer(
    args=args,
    model_init=model_init,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
)

In [None]:
# Run fine-tuning; the returned training summary is displayed as the cell output.
train_output = trainer.train()
train_output

In [None]:
# Write the trained model to the trainer's configured output directory.
trainer.save_model(output_dir=trainer.args.output_dir)