# Translation Error Detection for Ubuntu - Fine-tuning for an STS task.
This notebook will be used to fine-tune the two models we want to compare :
- Original model : ``sentence-transformers/distiluse-base-multilingual-cased-v2`` using mean pooling to compute sentence embeddings,
- Generalized model : Based on the architecture of the previous one, it uses a MultiHead Generalized Pooling method that will be learnt while fine-tuning the model.

<br>

The code below is built using the following sources :
- Multidataset training documentation for sentence transformer models : https://sbert.net/docs/sentence_transformer/training_overview.html#multi-dataset-training
- The fine-tuning file for STS task : https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/sts/training_stsbenchmark.py


## Downloading and importing modules

In [1]:
import logging
import sys
import traceback
from datetime import datetime
from typing import List, Union
import torch
import os

# HuggingFace libraries
from datasets import load_dataset, DatasetDict, Dataset, IterableDatasetDict, IterableDataset

# Sentence transformer libraries
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.evaluation import SentenceEvaluator


# Pooling libraries
from sentence_pooling import GeneralizedSentenceTransformerMaker



## Loading the data

In [None]:
# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# Load the multilingual STS benchmark: https://huggingface.co/datasets/PhilipMay/stsb_multi_mt

### DatasetDicts used for training, evaluating and testing.
### All three share the same key scheme: "multi_stsb_<language code>".
train_dataset = DatasetDict()
eval_dataset = DatasetDict()
test_dataset = DatasetDict()


## Multi STS Benchmark dataset (monolingual version)
for config in ['de', 'es', 'fr', 'it', 'nl', 'pl', 'pt', 'ru', 'zh'] :
    multi_stsb = load_dataset("PhilipMay/stsb_multi_mt", config)

    for split in multi_stsb.keys() :
        # The evaluators and the loss read the label from a column named "score".
        multi_stsb[split] = multi_stsb[split].rename_column("similarity_score", "score")
        # Rescale the label by 1/5 (presumably from a 0-5 scale to 0-1 — verify against the dataset card).
        multi_stsb[split] = multi_stsb[split].map(lambda x: {"score": x["score"] / 5})

    # Make the language code safe to embed in a dataset key (a no-op for the codes listed above).
    config = config.replace('-', '_')

    train_dataset[f"multi_stsb_{config}"] = multi_stsb["train"]
    eval_dataset[f"multi_stsb_{config}"] = multi_stsb["dev"]
    # Same key prefix as the train/eval dicts (was misspelled "muti_stsb_").
    test_dataset[f"multi_stsb_{config}"] = multi_stsb["test"]

    logging.info(multi_stsb["train"])



## Helper functions

In [3]:
def make_evaluators(dataset_dict: DatasetDict, split: str) -> list[SentenceEvaluator] :
    """Build one cosine-similarity evaluator per dataset of a DatasetDict.

    Each evaluator is named ``sts-{split}-{key}`` where ``key`` is the dataset's
    key in ``dataset_dict``. Including the key keeps the evaluator names distinct,
    so the results reported for different datasets can be told apart.

    A dataset for which the evaluator cannot be built (for example because a
    column is missing) is reported on stdout and skipped; the remaining datasets
    are still processed.

    Args:
        dataset_dict (DatasetDict): An evaluation or test DatasetDict gathering several datasets.
            Every dataset is read through its "sentence1", "sentence2" and "score" columns.
        split (str): Label of the split (the notebook uses 'eval' and 'test'); it is only
            used to build the evaluator names.

    Returns:
        list[SentenceEvaluator]: One EmbeddingSimilarityEvaluator per dataset that could be
            built, in the iteration order of ``dataset_dict``. Empty if ``dataset_dict`` is empty.
    """
    evaluators = []
    for key, dataset in dataset_dict.items() :
        # Keep the try body per dataset so one faulty dataset does not drop all the following ones.
        try :
            evaluator = EmbeddingSimilarityEvaluator(
                sentences1=dataset["sentence1"], # type: ignore
                sentences2=dataset["sentence2"], # type: ignore
                scores=dataset["score"], # type: ignore
                main_similarity=SimilarityFunction.COSINE,
                name=f"sts-{split}-{key}",
            )
        except Exception as e :
            # Best-effort: report the failing dataset and carry on with the others.
            print(f"An exception occured : {e}")
            print(f"Dataset : {key}")
            continue
        evaluators.append(evaluator)

    return evaluators


def setup_training(model: SentenceTransformer, output_dir, train_batch_size, num_epochs, step_nb, eval_dataset=eval_dataset) -> SentenceTransformerTrainer :
    """Create the trainer used to fine-tune ``model`` on several datasets at once.

    The training data is the module-level ``train_dataset``; one CoSENTLoss is
    attached to each of its datasets. Evaluators built from ``eval_dataset`` are
    run during training alongside the evaluation loss.

    Args:
        model (SentenceTransformer): The model to train.
        output_dir: Directory handed to the training arguments as ``output_dir``.
        train_batch_size: Per-device batch size, used for both training and evaluation.
        num_epochs: Number of training epochs.
        step_nb: Number of steps between two evaluations and between two log entries;
            a checkpoint is saved every ``2 * step_nb`` steps.
        eval_dataset: DatasetDict used for evaluation (defaults to the module-level one).

    Returns:
        SentenceTransformerTrainer: The trainer of the model.
    """
    # One CoSENTLoss instance per training dataset, keyed by the dataset name.
    loss_per_dataset = {name: losses.CoSENTLoss(model=model) for name in train_dataset}

    # Evaluators run during training, one per evaluation dataset.
    training_evaluators = make_evaluators(eval_dataset, 'eval')

    # Checkpoints are written half as often as evaluations are run.
    checkpoint_interval = 2 * step_nb

    training_args = SentenceTransformerTrainingArguments(
        output_dir=output_dir,
        # Optimisation settings
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        warmup_ratio=0.1,
        fp16=False,  # FP16 mixed precision is turned off
        bf16=False,  # BF16 mixed precision is turned off
        # Evaluation, checkpointing and logging cadence
        eval_strategy="steps",
        eval_steps=step_nb,
        save_strategy="steps",
        save_steps=checkpoint_interval,
        save_total_limit=2,
        logging_steps=step_nb,
        run_name="sts",  # Will be used in W&B if `wandb` is installed
    )

    return SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss_per_dataset,
        evaluator=training_evaluators,
    )




## GPU connection

In [None]:
# Number of GPUs available, as reported by PyTorch.
# NOTE(review): this runs before CUDA_VISIBLE_DEVICES is set further down in the
# notebook, so the count is presumably not yet restricted by that variable — confirm.
num_gpus = torch.cuda.device_count()
print(f"Number of GPUs: {num_gpus}")

In [None]:
# Expose only GPU 0 to CUDA for this process, then pick the device string
# handed to the model: the first GPU when CUDA is usable, the CPU otherwise.
# NOTE(review): the variable is presumably only honoured if set before CUDA is initialised — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if torch.cuda.is_available() :
    device = 'cuda:0'
else :
    device = 'cpu'
print(f"Using device : {device}")


## Local model initialization
After the first pre-training step, the model produced by that step must be loaded before fine-tuning continues. This is done here.

In [None]:
# Load the checkpoint saved by the previous training step onto the selected device.
model = SentenceTransformer(
    './output/training_original_multilingual_sentence_embedder_wmt_mlqe-2024-11-27_10-34-02/checkpoint-16725',
    device=device,
)
# Total number of scalar values over all parameter tensors of the model.
params = sum(map(lambda tensor: tensor.numel(), model.parameters()))
print(f"Cuurent model : {params} parameters")
print(model)

In [None]:
# Identifier of the base model. In this notebook the model itself is loaded
# from a local checkpoint; this name is only used to derive the Hub repository name.
model_name = "sentence-transformers/distiluse-base-multilingual-cased-v2"
# Output directory suffixed with the current timestamp, so each run gets its own folder.
output_dir = "-".join((
    "output/training_original_multilingual_sentence_embedder_wmt_mlqe_multists",
    datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
))

## Setting up the model training

In [None]:
# Hyper-parameters (same as in the research paper)
num_epochs = 4
train_batch_size = 16

# Evaluate and log every 12,960 steps.
trainer = setup_training(
    model,
    output_dir=output_dir,
    train_batch_size=train_batch_size,
    num_epochs=num_epochs,
    step_nb=12_960,
)
print("Setup of the trainer for the model successful !")

## Model training

In [None]:
# Run the fine-tuning.
trainer.train()

# Run every test-set evaluator on the fine-tuned model.
test_evaluators = make_evaluators(test_dataset, 'test')
for test_evaluator in test_evaluators :
    test_evaluator(model)

# 8. Save the trained & evaluated model locally, in the run's output directory
final_output_dir = str(output_dir)
model.save(final_output_dir)

## Saving the models online

In [None]:
# 9. (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
!huggingface-cli login

model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
try:
    model.push_to_hub(f"{model_name}-sts")
except Exception:
    logging.error(
        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
        f"and saving it using `model.push_to_hub('{model_name}-sts')`.")