In [1]:
!pip install sentence-transformers datasets transformers[torch]

[0m

In [2]:
import logging
import traceback
from datetime import datetime

import numpy as np
from datasets import DatasetDict, load_dataset

from sentence_transformers import LoggingHandler, SentenceTransformer
from sentence_transformers.evaluation import (
    EmbeddingSimilarityEvaluator,
    MSEEvaluator,
    SequentialEvaluator,
    TranslationEvaluator,
)
from sentence_transformers.losses import MSELoss
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments

# Configure root logging: timestamped INFO-level messages, emitted through the
# LoggingHandler imported from sentence_transformers.
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)
# Logger named after this module, used by later cells.
logger = logging.getLogger(__name__)

In [3]:
# Teacher: a monolingual model; it provides the reference embeddings for the English sentences
teacher_model_name = "sentence-transformers/multi-qa-distilbert-dot-v1"
# Student: a multilingual model, trained so that its embeddings of non-English texts
# mimic the teacher model's English embeddings
student_model_name = "distilbert/distilbert-base-multilingual-cased"

# Input length limit applied to the student model (number of word pieces)
student_max_seq_length = 128
# Batch size used while training
train_batch_size = 64
# Batch size used when encoding / evaluating
inference_batch_size = 64
# Upper bound on the number of parallel sentences used for training
max_sentences_per_language = 1_000_000

# Number of passes over the training data
num_train_epochs = 5
# Intended evaluation interval in steps. The training arguments further down use
# eval_strategy="epoch", so this value is not consumed by the visible cells.
num_evaluation_steps = 5000

# Language codes: the teacher understands the source language(s); the student is
# extended to the target language(s)
source_languages = {"en"}  # English (en)
target_languages = {"id"}  # Indonesian (id)

# Output directory: output/make-multilingual-<sorted sources>-<sorted targets>-<timestamp>
language_tag = "-".join(sorted(source_languages) + sorted(target_languages))
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = "output/make-multilingual-" + language_tag + "-" + run_timestamp

In [4]:
# 1a. Here we define our SentenceTransformer teacher model.
# Its maximum sequence length is left at the model's own setting; only the
# student's length is limited below.
teacher_model = SentenceTransformer(teacher_model_name)
logging.info(f"Teacher model: {teacher_model}")

# 1b. Here we define our SentenceTransformer student model. If not already a Sentence Transformer model,
# it will automatically create one with "mean" pooling.
student_model = SentenceTransformer(student_model_name)
# Limit the student's inputs to the configured number of word pieces.
student_model.max_seq_length = student_max_seq_length
logging.info(f"Student model: {student_model}")


2024-12-05 08:17:57 - Use pytorch device_name: cuda
2024-12-05 08:17:57 - Load pretrained SentenceTransformer: sentence-transformers/multi-qa-distilbert-dot-v1
2024-12-05 08:17:59 - Teacher model: SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: DistilBertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)
2024-12-05 08:17:59 - Use pytorch device_name: cuda
2024-12-05 08:17:59 - Load pretrained SentenceTransformer: distilbert/distilbert-base-multilingual-cased
2024-12-05 08:18:00 - No sentence-transformers model found with name distilbert/distilbert-base-multilingual-cased. Creating a new one with mean pooling.
2024-12-05 08:18:00 - Student model: SentenceTransformer(
  (0): 

In [5]:
# 2. Load the parallel sentences training dataset
dataset_to_use = "carles-undergrad-thesis/en-id-parallel-sentences"
train_dataset_dict = DatasetDict()
eval_dataset_dict = DatasetDict()
subset = "default"
# Fixed seed for the shuffle in train_test_split, so that a fresh kernel run
# reproduces exactly the same train/evaluation partition.
split_seed = 42

# Load the training dataset
train_dataset = load_dataset(dataset_to_use, split="train")

# If the training dataset is too large, keep only the first max_sentences_per_language rows
if len(train_dataset) > max_sentences_per_language:
    train_dataset = train_dataset.select(range(max_sentences_per_language))

# Split the (possibly truncated) dataset into train and evaluation sets,
# setting aside 1% of the rows for evaluation
split_dataset = train_dataset.train_test_split(test_size=0.01, shuffle=True, seed=split_seed)

# Assign the split datasets
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]

# Add the datasets to the DatasetDicts, keyed by the subset name
train_dataset_dict[subset] = train_dataset
eval_dataset_dict[subset] = eval_dataset

logging.info(train_dataset_dict)


2024-12-05 08:18:03 - DatasetDict({
    default: Dataset({
        features: ['text_en', 'text_id'],
        num_rows: 990000
    })
})


In [6]:
# The student is trained so that both its English (EN) embeddings and its
# Indonesian (ID) embeddings are close to the teacher's English embeddings.
def prepare_dataset(batch):
    """Turn a batch of parallel sentences into the columns used for training.

    Args:
        batch: A batched mapping with a "text_en" and a "text_id" entry.

    Returns:
        A dict with the English texts under "english", the Indonesian texts under
        "indonesian", and the teacher model's encoding of the English texts under "label".
    """
    english_texts = batch["text_en"]
    indonesian_texts = batch["text_id"]
    # The teacher only ever sees the English side; its output is the regression target.
    teacher_embeddings = teacher_model.encode(
        english_texts, batch_size=inference_batch_size, show_progress_bar=False
    )
    return {"english": english_texts, "indonesian": indonesian_texts, "label": teacher_embeddings}

# Apply prepare_dataset to the training split in large batches, dropping the
# original columns so only the prepared ones remain.
column_names = list(train_dataset_dict.values())[0].column_names
train_dataset_dict = train_dataset_dict.map(
    prepare_dataset, batched=True, batch_size=30000, remove_columns=column_names
)
# Interpolate the dataset into the message itself. Passing it as a separate
# positional argument makes logging treat it as a %-format argument, and with no
# placeholder in the message the dataset summary is silently dropped from the log.
logging.info(f"Prepared datasets for training: {train_dataset_dict}")

Map:   0%|          | 0/990000 [00:00<?, ? examples/s]

2024-12-05 08:22:41 - Prepared datasets for training:


In [7]:
# 3. Define our training loss
train_loss = MSELoss(model=student_model)

# 4. Define evaluators for use during training.
evaluators = []
for subset, eval_dataset in eval_dataset_dict.items():
    logger.info(f"Creating evaluators for {subset}")

    # Arguments shared by both evaluators of this subset
    parallel_sentence_kwargs = {
        "source_sentences": eval_dataset["text_en"],
        "target_sentences": eval_dataset["text_id"],
        "name": subset,
        "batch_size": inference_batch_size,
    }

    evaluators.extend(
        [
            # Mean Squared Error (MSE) measures the (euclidean) distance between
            # teacher and student embeddings
            MSEEvaluator(teacher_model=teacher_model, **parallel_sentence_kwargs),
            # TranslationEvaluator checks if the embedding of source[i] is the closest
            # to target[i] out of all available target sentences
            TranslationEvaluator(**parallel_sentence_kwargs),
        ]
    )

# Combined evaluator: runs every evaluator in order; the main score is the mean of their scores
evaluator = SequentialEvaluator(evaluators, main_score_function=lambda scores: np.mean(scores))

# Prepare the evaluation datasets for training (same transformation as the training split)
eval_dataset_dict = eval_dataset_dict.map(
    prepare_dataset,
    batched=True,
    batch_size=30000,
    remove_columns=column_names,
)

2024-12-05 08:22:41 - Creating evaluators for default


Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [None]:
from transformers import TrainerCallback, TrainingArguments, Trainer

# 5. Define the training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=output_dir,
    # Optional training parameters:
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=train_batch_size,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    learning_rate=2e-5,
    # Optional tracking/debugging parameters:
    eval_strategy="epoch",  # Evaluate at the end of every epoch
    save_strategy="epoch",  # Save at the end of every epoch
    save_total_limit=2,
    logging_steps=100,
    # The language codes are kept in sets, whose iteration order is not stable across
    # runs; sort them (as output_dir does) so the run name is deterministic.
    run_name=f"multilingual-{'-'.join(sorted(source_languages))}-{'-'.join(sorted(target_languages))}",
)

# 6. Create the trainer: the student model is optimized with train_loss on the prepared
# training data; the prepared evaluation data and the combined evaluator are used
# whenever an evaluation is triggered.
trainer = SentenceTransformerTrainer(
    model=student_model,
    args=args,
    train_dataset=train_dataset_dict,
    eval_dataset=eval_dataset_dict,
    loss=train_loss,
    evaluator=evaluator,
)

# Define a custom callback to ensure validation loss is logged
class ValidationLoggerCallback(TrainerCallback):
    """Print every validation loss reported at the end of an evaluation pass.

    When the evaluation data is a dict of datasets, the loss is reported once per
    dataset under a dataset-specific key (e.g. ``eval_default_loss``) rather than
    under the plain ``eval_loss`` key, so looking up ``eval_loss`` alone never
    prints anything. This callback therefore reports every metric whose key starts
    with ``eval_`` and ends with ``_loss``.
    """

    def on_evaluate(self, args, state, control, **kwargs):
        """Print the validation loss value(s) found in the evaluation metrics.

        Args:
            args: Training arguments passed by the trainer (unused).
            state: Trainer state; ``state.global_step`` is included in the message.
            control: Trainer control object (unused).
            **kwargs: Extra event data; the ``metrics`` entry, when present, maps
                metric names to values.
        """
        # Tolerate a missing or empty metrics entry instead of raising KeyError.
        metrics = kwargs.get("metrics") or {}
        for metric_name, metric_value in metrics.items():
            is_eval_loss = metric_name.startswith("eval_") and metric_name.endswith("_loss")
            if not is_eval_loss or metric_value is None:
                continue
            if metric_name == "eval_loss":
                # Single evaluation dataset: keep the original message format.
                print(f"Validation Loss at step {state.global_step}: {metric_value}")
            else:
                # Named evaluation dataset: include the metric key to tell them apart.
                print(f"Validation Loss ({metric_name}) at step {state.global_step}: {metric_value}")

# Register the validation-loss callback on the trainer
validation_logger = ValidationLoggerCallback()
trainer.add_callback(validation_logger)

# 7. Start training; as the last expression of the cell, the returned training
# summary is displayed as the cell's result
trainer.train()


Epoch,Training Loss,Validation Loss,Default Loss,Default Negative Mse,Default Src2trg Accuracy,Default Trg2src Accuracy,Default Mean Accuracy,Sequential Score
1,0.0485,No log,0.047881,-4.959352,0.9737,0.965,0.96935,-1.995001
2,0.0405,No log,0.039383,-4.152809,0.9863,0.9807,0.9835,-1.584655
3,0.0366,No log,0.035675,-3.799809,0.9893,0.9846,0.98695,-1.40643
4,0.0347,No log,0.033866,-3.628373,0.9895,0.9857,0.9876,-1.320386
5,0.0342,No log,0.033113,-3.555394,0.9894,0.9861,0.98775,-1.283822


2024-12-05 08:51:54 - MSE evaluation (lower = better) on the default dataset:
2024-12-05 08:51:54 - MSE (*100):	4.959352
2024-12-05 08:51:54 - Evaluating translation matching Accuracy of the model on the default dataset:
2024-12-05 08:51:56 - Accuracy src2trg: 97.37
2024-12-05 08:51:56 - Accuracy trg2src: 96.50
2024-12-05 08:51:56 - Saving model checkpoint to output/make-multilingual-en-id-2024-12-05_08-17-57/checkpoint-15469
2024-12-05 08:51:56 - Save model to output/make-multilingual-en-id-2024-12-05_08-17-57/checkpoint-15469


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

2024-12-05 09:20:57 - MSE evaluation (lower = better) on the default dataset:
2024-12-05 09:20:57 - MSE (*100):	4.152809
2024-12-05 09:20:57 - Evaluating translation matching Accuracy of the model on the default dataset:
2024-12-05 09:20:59 - Accuracy src2trg: 98.63
2024-12-05 09:20:59 - Accuracy trg2src: 98.07
2024-12-05 09:20:59 - Saving model checkpoint to output/make-multilingual-en-id-2024-12-05_08-17-57/checkpoint-30938
2024-12-05 09:20:59 - Save model to output/make-multilingual-en-id-2024-12-05_08-17-57/checkpoint-30938
2024-12-05 09:50:01 - MSE evaluation (lower = better) on the default dataset:
2024-12-05 09:50:01 - MSE (*100):	3.799809
2024-12-05 09:50:01 - Evaluating translation matching Accuracy of the model on the default dataset:
2024-12-05 09:50:03 - Accuracy src2trg: 98.93
2024-12-05 09:50:03 - Accuracy trg2src: 98.46
2024-12-05 09:50:03 - Saving model checkpoint to output/make-multilingual-en-id-2024-12-05_08-17-57/checkpoint-46407
2024-12-05 09:50:03 - Save model to 

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



2024-12-05 10:19:04 - MSE evaluation (lower = better) on the default dataset:
2024-12-05 10:19:04 - MSE (*100):	3.628373
2024-12-05 10:19:04 - Evaluating translation matching Accuracy of the model on the default dataset:
2024-12-05 10:19:07 - Accuracy src2trg: 98.95
2024-12-05 10:19:07 - Accuracy trg2src: 98.57
2024-12-05 10:19:07 - Saving model checkpoint to output/make-multilingual-en-id-2024-12-05_08-17-57/checkpoint-61876
2024-12-05 10:19:07 - Save model to output/make-multilingual-en-id-2024-12-05_08-17-57/checkpoint-61876
2024-12-05 10:48:01 - Saving model checkpoint to output/make-multilingual-en-id-2024-12-05_08-17-57/checkpoint-77345
2024-12-05 10:48:01 - Save model to output/make-multilingual-en-id-2024-12-05_08-17-57/checkpoint-77345
2024-12-05 10:48:10 - MSE evaluation (lower = better) on the default dataset:
2024-12-05 10:48:10 - MSE (*100):	3.555394
2024-12-05 10:48:10 - Evaluating translation matching Accuracy of the model on the default dataset:
2024-12-05 10:48:12 - Ac

TrainOutput(global_step=77345, training_loss=0.043607932929886635, metrics={'train_runtime': 8724.2473, 'train_samples_per_second': 567.384, 'train_steps_per_second': 8.866, 'total_flos': 0.0, 'train_loss': 0.043607932929886635, 'epoch': 5.0})

In [9]:
    # Save the trained student model into a "final" subfolder of the run's output directory
    final_output_dir = output_dir + "/final"
    student_model.save(final_output_dir)

2024-12-05 10:48:14 - Save model to output/make-multilingual-en-id-2024-12-05_08-17-57/final
