# Finetuning BART

Relevant links:
- https://huggingface.co/mse30/bart-base-finetuned-pubmed
- https://arxiv.org/pdf/2210.09932.pdf

In [None]:
# Directory that holds the eLife JSON-lines splits.
data_path = "./data/"
# Root directory under which checkpoints and the final model are written.
model_path = "./models"

# Name of the resulting model; fill this in before training, because it is
# appended to model_path to build the output directories.
model_name = ""

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, GenerationConfig, DataCollatorForSeq2Seq, Seq2SeqTrainer
from transformers import EarlyStoppingCallback
from nltk import sent_tokenize
import numpy as np
from evaluate import load

from utils import remove_references, extract_abstract

import torch
import random
import transformers

## Load dataset

In [None]:
# Read the train and validation JSON-lines files into a single dataset object
# keyed by split name.
dataset = load_dataset(
    "json",
    data_files={
        "train": f"{data_path}/eLife_train.jsonl",
        "validation": f"{data_path}/eLife_val.jsonl",
    },
)

In [None]:
# Echo the loaded dataset so the notebook shows its splits and columns.
dataset

### Preprocessing articles

In [None]:
# Preprocessing for abstract as input.
# Applies utils.extract_abstract to every example of every split.
# NOTE(review): presumably this adds an "abstract" column - confirm in utils.

dataset = dataset.map(extract_abstract)

In [None]:
# Preprocessing for training with the article without references.
# Applies utils.remove_references to every example of every split.
# NOTE(review): presumably this adds the "article_norefs" column that
# preprocess_function reads - confirm in utils.

dataset = dataset.map(remove_references)

## Tokenization

In [None]:
# Maximum number of input tokens: passed to the tokenizer as model_max_length
# and used as the truncation length for the article text.
max_input_length = 1024

In [None]:
# Pretrained checkpoint from which both the tokenizer and the model are loaded.
model_checkpoint = "facebook/bart-base"

In [None]:
# Load the checkpoint's tokenizer, with its maximum length set to
# max_input_length.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=max_input_length)

In [None]:
# The input column for the tokenization must be changed depending on the
# preprocessing that was applied (see the `input_column` parameter).

def preprocess_function(examples, input_column="article_norefs"):
    """Tokenize a batch of source texts and their lay summaries.

    Args:
        examples: Batch of dataset columns; must contain ``input_column`` and
            ``"lay_summary"``.
        input_column: Name of the column used as model input. Defaults to
            ``"article_norefs"``; pass ``"abstract"`` to train on abstracts,
            e.g. ``dataset.map(preprocess_function, batched=True,
            fn_kwargs={"input_column": "abstract"})``.

    Returns:
        The tokenized inputs, with the summary token ids stored under
        ``"labels"``.
    """
    model_inputs = tokenizer(
        examples[input_column],
        max_length=max_input_length,
        truncation=True,
    )
    # No explicit max_length here: the summaries are truncated to the
    # tokenizer's own maximum length.
    labels = tokenizer(
        examples["lay_summary"],
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Tokenize every split; batched=True passes batches of examples to
# preprocess_function instead of one example at a time.
tokenized_dataset = dataset.map(preprocess_function, batched=True)

In [None]:
# A collator is used later to pad the labels. First drop every column of the
# original dataset (the raw string fields), because the collator cannot pad
# them; only the columns added by preprocess_function remain.
tokenized_dataset = tokenized_dataset.remove_columns(
    dataset["train"].column_names
)

In [None]:
# Echo the tokenized dataset to check which columns are left.
tokenized_dataset

In [None]:
# Drop the reference to the raw text dataset; only tokenized_dataset is used
# from here on.
del dataset

In [None]:
# Load the ROUGE metric used by compute_metrics for evaluation while training.

metric = load("rouge")

In [None]:
def compute_metrics(eval_pred):
    """Compute ROUGE-1 and ROUGE-2 for the summaries generated at evaluation.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of token-id arrays handed
            over by the trainer. ``predictions`` may also be a tuple whose
            first element holds the generated ids.

    Returns:
        dict mapping each ROUGE type to its score, scaled to 0-100 and
        rounded to four decimals.
    """
    predictions, labels = eval_pred

    # Some models return extra outputs next to the generated ids
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 cannot be decoded. It marks ignored label positions and can also
    # appear in the predictions when generated batches of different lengths
    # are gathered, so replace it with the pad token in both arrays.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    # Decode generated and reference summaries into text
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # ROUGE expects a newline after each sentence
    decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]

    # Compute ROUGE scores
    result = metric.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True,
        rouge_types=["rouge1", "rouge2"],
    )

    # Express the scores as percentages, rounded for readable logs
    return {key: round(value * 100, 4) for key, value in result.items()}

## Model training

In [None]:
# Load the pretrained sequence-to-sequence model that will be fine-tuned.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [None]:
def set_seed(seed: int):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Value applied, in order, to Python's ``random``, NumPy,
            PyTorch (CPU, then all CUDA devices) and ``transformers``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        transformers.set_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


set_seed(42)

In [None]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU, and move
# the model there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
# Generation settings used when the trainer generates summaries at evaluation.
# This config is built from scratch, so the special token ids must be copied
# from the model config explicitly. Without eos_token_id, generation cannot
# stop at the end-of-sequence token and always produces max_new_tokens tokens;
# pad_token_id is needed to pad sequences that finish early within a batch.
generation_config = GenerationConfig(
    max_new_tokens=600,
    decoder_start_token_id=model.config.decoder_start_token_id,
    bos_token_id=model.config.bos_token_id,
    eos_token_id=model.config.eos_token_id,
    pad_token_id=model.config.pad_token_id,
)

In [None]:
batch_size = 8
num_train_epochs = 25

# Show the training loss about once per epoch (number of batches per epoch on
# one device). Clamp to at least 1: a train split smaller than one batch would
# otherwise give 0, which is not a valid interval for step-based logging.
logging_steps = max(1, len(tokenized_dataset["train"]) // batch_size)

# Specify training arguments
args = Seq2SeqTrainingArguments(
    output_dir=f"{model_path}/{model_name}",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    seed=42,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,  # generate summaries during evaluation to compute ROUGE scores for each epoch
    logging_steps=logging_steps,
    generation_config=generation_config,
    metric_for_best_model="rouge2",
    greater_is_better=True,  # a higher ROUGE-2 means a better checkpoint
    load_best_model_at_end=True,
)

In [None]:
# Collator that batches the tokenized examples, padding inputs and labels.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
# Build the trainer: model, arguments, both splits, the padding collator, the
# ROUGE callback, and early stopping after 3 evaluations without improvement.

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

In [None]:
# Release cached GPU memory that PyTorch is not currently using before
# training starts.
torch.cuda.empty_cache()

In [None]:
# Run fine-tuning with the arguments and callbacks configured on the trainer.
trainer.train()

In [None]:
# Evaluate the trained model on the validation split and show the metrics.
trainer.evaluate(eval_dataset=tokenized_dataset["validation"])

In [None]:
# Final export directory: "<model_path>/<model_name>-model".
directory = f"{model_path}/{model_name}-model"
# Save the trainer's model to that directory.
trainer.save_model(directory)

# Save the tokenizer next to the model so both can be loaded from one path.
tokenizer.save_pretrained(directory)