In [1]:
from datasets import load_from_disk
 
# Load dataset from the Disk
dataset = load_from_disk("APN_dataset")
 

# Add an id column to the dataset
dataset = dataset.add_column("id", range(len(dataset)))
print (f"Dataset: {dataset[1]}") 
# split dataset into a 10% test set
dataset = dataset.train_test_split(test_size=0.1)
 
# save datasets to disk
dataset["train"].to_json("train_dataset.json", orient="records")
dataset["test"].to_json("test_dataset.json", orient="records")


  from .autonotebook import tqdm as notebook_tqdm


Dataset: {'anchor': 'بله، به نقطه تاس۱۰ توجه کنید: «امکان ضمانت متقابل انواع وام\u200cها در بانک ملت وجود ندارد».', 'positive': 'آیا امکان ضمانت متقابل در وام\u200cهای مختلف در بانک ملت وجود دارد؟', 'negetive': 'نحوه محاسبه توان بازپرداخت چگونه است؟', 'id': 1}


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 134.10ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1024.00ba/s]


123317

In [2]:
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim
from datasets import load_dataset, concatenate_datasets
 
model_id = "halong_finetuned_frezzed"  

# model_id = "hiieu/halong_embedding" 

matryoshka_dimensions = [768, 512, 256, 128, 64] 
 
# Load a model
model = SentenceTransformer(
    model_id, device="cuda" if torch.cuda.is_available() else "cpu"
)
 
# load test dataset
test_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")
corpus_dataset = concatenate_datasets([train_dataset, test_dataset])
 
# Convert the datasets to dictionaries
corpus = dict(
    zip(corpus_dataset["id"], corpus_dataset["positive"])
)  # Our corpus (cid => document)
queries = dict(
    zip(test_dataset["id"], test_dataset["anchor"])
)  # Our queries (qid => question)
 
# Create a mapping of relevant document (1 in our case) for each query
relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids])
for q_id in queries:
    relevant_docs[q_id] = [q_id]
 
 
matryoshka_evaluators = []
# Iterate over the different dimensions
for dim in matryoshka_dimensions:
    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,  # Truncate the embeddings to a certain dimension
        score_functions={"cosine": cos_sim},
    )
    matryoshka_evaluators.append(ir_evaluator)
 
# Create a sequential evaluator
evaluator = SequentialEvaluator(matryoshka_evaluators)

No sentence-transformers model found with name halong_finetuned_frezzed. Creating a new one with mean pooling.
Some weights of XLMRobertaModel were not initialized from the model checkpoint at halong_finetuned_frezzed and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Generating train split: 69 examples [00:00, 45184.54 examples/s]
Generating train split: 617 examples [00:00, 58422.56 examples/s]


In [3]:
# Evaluate the model
results = evaluator(model)
 
# Print the main score
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print
    print(f"{key}: {results[key]}")

dim_768_cosine_ndcg@10: 0.39444353944796706
dim_512_cosine_ndcg@10: 0.38574244041851047
dim_256_cosine_ndcg@10: 0.35502801651499005
dim_128_cosine_ndcg@10: 0.3725521848226817
dim_64_cosine_ndcg@10: 0.35926891916256354


In [4]:
# Evaluate the model
results = evaluator(model)
 
# Print the main score
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print
    print(f"{key}: {results[key]}")

dim_768_cosine_ndcg@10: 0.39444353944796706
dim_512_cosine_ndcg@10: 0.38574244041851047
dim_256_cosine_ndcg@10: 0.35502801651499005
dim_128_cosine_ndcg@10: 0.3725521848226817
dim_64_cosine_ndcg@10: 0.35926891916256354


In [5]:
from sentence_transformers import SentenceTransformerModelCardData, SentenceTransformer
 

 
# load model with SDPA for using Flash Attention 2
model = SentenceTransformer(
    model_id,
    model_kwargs={"attn_implementation": "sdpa"},
    model_card_data=SentenceTransformerModelCardData(
        language="multilingual",
        license="apache-2.0",
        model_name="Halang",
    ),
)

No sentence-transformers model found with name halong_finetuned_frezzed. Creating a new one with mean pooling.
Some weights of XLMRobertaModel were not initialized from the model checkpoint at halong_finetuned_frezzed and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
 
matryoshka_dimensions = [768, 512, 256, 128, 64]  # Important: large to small
inner_train_loss = MultipleNegativesRankingLoss(model)
train_loss = MatryoshkaLoss(
    model, inner_train_loss, matryoshka_dims=matryoshka_dimensions
)

In [7]:
from transformers import EarlyStoppingCallback
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=20,  
    early_stopping_threshold=0.001  
)


In [8]:
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers
 
# load train dataset again
train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")
 
# define training arguments
args = SentenceTransformerTrainingArguments(
    output_dir="halong_finetuned_APN", # output directory and hugging face model ID
    num_train_epochs=300,                         # number of epochs
    per_device_train_batch_size=32,             # train batch size
    gradient_accumulation_steps=16,             # for a global batch size of 512
    per_device_eval_batch_size=16,              # evaluation batch size
    warmup_ratio=0.1,                           # warmup ratio
    learning_rate=2e-5,                         # learning rate, 2e-5 is a good value
    lr_scheduler_type="cosine",                 # use constant learning rate scheduler
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    tf32=True,                                  # use tf32 precision
    bf16=True,                                  # use bf16 precision
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    eval_strategy="epoch",                      # evaluate after each epoch
    save_strategy="epoch",                      # save after each epoch
    logging_steps=10,                           # log every 10 steps
    save_total_limit=3,                         # save only the last 3 models
    load_best_model_at_end=True,                # load the best model when training ends
    metric_for_best_model="eval_dim_128_cosine_ndcg@10",  # Optimizing for the best ndcg@10 score for the 128 dimension
    greater_is_better=True
    
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [9]:
from sentence_transformers import SentenceTransformerTrainer
 
trainer = SentenceTransformerTrainer(
    model=model,
    args=args, 
    train_dataset=train_dataset.select_columns(
        ["anchor" , "positive" , "negetive"]
    ), 
    loss=train_loss,
    evaluator=evaluator,callbacks=[early_stopping]
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
                                                                     

In [10]:
# start training, 
trainer.train()
 


Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
1,No log,No log,0.217391,0.347826,0.434783,0.637681,0.217391,0.115942,0.086957,0.063768,0.217391,0.347826,0.434783,0.637681,0.394703,0.321871,0.337919,0.202899,0.362319,0.42029,0.637681,0.202899,0.120773,0.084058,0.063768,0.202899,0.362319,0.42029,0.637681,0.392909,0.319053,0.335141,0.188406,0.304348,0.42029,0.57971,0.188406,0.101449,0.084058,0.057971,0.188406,0.304348,0.42029,0.57971,0.353764,0.285507,0.304095,0.217391,0.304348,0.391304,0.594203,0.217391,0.101449,0.078261,0.05942,0.217391,0.304348,0.391304,0.594203,0.367462,0.300075,0.31617,0.202899,0.318841,0.376812,0.57971,0.202899,0.10628,0.075362,0.057971,0.202899,0.318841,0.376812,0.57971,0.3605,0.294531,0.310899,0.3605
2,No log,No log,0.217391,0.347826,0.434783,0.652174,0.217391,0.115942,0.086957,0.065217,0.217391,0.347826,0.434783,0.652174,0.39856,0.322976,0.338323,0.202899,0.362319,0.42029,0.623188,0.202899,0.120773,0.084058,0.062319,0.202899,0.362319,0.42029,0.623188,0.388084,0.316879,0.334309,0.202899,0.318841,0.42029,0.57971,0.202899,0.10628,0.084058,0.057971,0.202899,0.318841,0.42029,0.57971,0.360586,0.294421,0.313626,0.231884,0.304348,0.391304,0.608696,0.231884,0.101449,0.078261,0.06087,0.231884,0.304348,0.391304,0.608696,0.377636,0.309495,0.324688,0.202899,0.318841,0.362319,0.594203,0.202899,0.10628,0.072464,0.05942,0.202899,0.318841,0.362319,0.594203,0.366192,0.29797,0.313531,0.366192
3,No log,No log,0.217391,0.347826,0.449275,0.666667,0.217391,0.115942,0.089855,0.066667,0.217391,0.347826,0.449275,0.666667,0.405885,0.328232,0.342802,0.231884,0.347826,0.434783,0.666667,0.231884,0.115942,0.086957,0.066667,0.231884,0.347826,0.434783,0.666667,0.410205,0.333972,0.348385,0.202899,0.318841,0.391304,0.608696,0.202899,0.10628,0.078261,0.06087,0.202899,0.318841,0.391304,0.608696,0.370406,0.299189,0.31653,0.217391,0.333333,0.405797,0.608696,0.217391,0.111111,0.081159,0.06087,0.217391,0.333333,0.405797,0.608696,0.376396,0.307241,0.32338,0.188406,0.318841,0.42029,0.57971,0.188406,0.10628,0.084058,0.057971,0.188406,0.318841,0.42029,0.57971,0.358279,0.291115,0.30881,0.358279
4,No log,No log,0.246377,0.347826,0.449275,0.681159,0.246377,0.115942,0.089855,0.068116,0.246377,0.347826,0.449275,0.681159,0.420065,0.343007,0.356867,0.231884,0.333333,0.405797,0.681159,0.231884,0.111111,0.081159,0.068116,0.231884,0.333333,0.405797,0.681159,0.413598,0.334317,0.348443,0.231884,0.333333,0.405797,0.623188,0.231884,0.111111,0.081159,0.062319,0.231884,0.333333,0.405797,0.623188,0.389455,0.32002,0.33749,0.217391,0.347826,0.391304,0.608696,0.217391,0.115942,0.078261,0.06087,0.217391,0.347826,0.391304,0.608696,0.379336,0.310858,0.328076,0.202899,0.333333,0.391304,0.608696,0.202899,0.111111,0.078261,0.06087,0.202899,0.333333,0.391304,0.608696,0.375593,0.305521,0.322161,0.375593
5,70.975200,No log,0.231884,0.376812,0.449275,0.695652,0.231884,0.125604,0.089855,0.069565,0.231884,0.376812,0.449275,0.695652,0.422688,0.341396,0.354917,0.231884,0.362319,0.434783,0.695652,0.231884,0.120773,0.086957,0.069565,0.231884,0.362319,0.434783,0.695652,0.419168,0.337129,0.350594,0.246377,0.333333,0.42029,0.637681,0.246377,0.111111,0.084058,0.063768,0.246377,0.333333,0.42029,0.637681,0.400404,0.330268,0.347535,0.217391,0.333333,0.405797,0.608696,0.217391,0.111111,0.081159,0.06087,0.217391,0.333333,0.405797,0.608696,0.378798,0.310162,0.329054,0.217391,0.333333,0.42029,0.608696,0.217391,0.111111,0.084058,0.06087,0.217391,0.333333,0.42029,0.608696,0.383597,0.315614,0.332863,0.383597
6,70.975200,No log,0.26087,0.376812,0.449275,0.681159,0.26087,0.125604,0.089855,0.068116,0.26087,0.376812,0.449275,0.681159,0.431722,0.357683,0.373715,0.26087,0.376812,0.463768,0.652174,0.26087,0.125604,0.092754,0.065217,0.26087,0.376812,0.463768,0.652174,0.42417,0.35563,0.373957,0.26087,0.347826,0.434783,0.652174,0.26087,0.115942,0.086957,0.065217,0.26087,0.347826,0.434783,0.652174,0.417471,0.3482,0.365773,0.246377,0.362319,0.434783,0.637681,0.246377,0.120773,0.086957,0.063768,0.246377,0.362319,0.434783,0.637681,0.404589,0.335179,0.352198,0.231884,0.347826,0.42029,0.608696,0.231884,0.115942,0.084058,0.06087,0.231884,0.347826,0.42029,0.608696,0.391401,0.325794,0.343841,0.391401
7,70.975200,No log,0.289855,0.376812,0.449275,0.695652,0.289855,0.125604,0.089855,0.069565,0.289855,0.376812,0.449275,0.695652,0.446616,0.373522,0.389488,0.289855,0.376812,0.449275,0.666667,0.289855,0.125604,0.089855,0.066667,0.289855,0.376812,0.449275,0.666667,0.438749,0.371245,0.389396,0.275362,0.347826,0.449275,0.652174,0.275362,0.115942,0.089855,0.065217,0.275362,0.347826,0.449275,0.652174,0.428486,0.361853,0.381924,0.246377,0.376812,0.449275,0.666667,0.246377,0.125604,0.089855,0.066667,0.246377,0.376812,0.449275,0.666667,0.41919,0.345503,0.361329,0.289855,0.391304,0.449275,0.594203,0.289855,0.130435,0.089855,0.05942,0.289855,0.391304,0.449275,0.594203,0.416989,0.363826,0.384708,0.416989
8,70.975200,No log,0.275362,0.376812,0.478261,0.710145,0.275362,0.125604,0.095652,0.071014,0.275362,0.376812,0.478261,0.710145,0.450107,0.372872,0.389882,0.289855,0.362319,0.463768,0.724638,0.289855,0.120773,0.092754,0.072464,0.289855,0.362319,0.463768,0.724638,0.456356,0.377059,0.39197,0.275362,0.362319,0.449275,0.710145,0.275362,0.120773,0.089855,0.071014,0.275362,0.362319,0.449275,0.710145,0.447412,0.369393,0.38545,0.275362,0.391304,0.434783,0.666667,0.275362,0.130435,0.086957,0.066667,0.275362,0.391304,0.434783,0.666667,0.432174,0.362934,0.380011,0.26087,0.362319,0.478261,0.652174,0.26087,0.120773,0.095652,0.065217,0.26087,0.362319,0.478261,0.652174,0.424489,0.355912,0.373072,0.424489
9,70.975200,No log,0.304348,0.42029,0.536232,0.73913,0.304348,0.140097,0.107246,0.073913,0.304348,0.42029,0.536232,0.73913,0.477905,0.400127,0.415801,0.289855,0.434783,0.492754,0.73913,0.289855,0.144928,0.098551,0.073913,0.289855,0.434783,0.492754,0.73913,0.471709,0.392161,0.407738,0.275362,0.376812,0.463768,0.724638,0.275362,0.125604,0.092754,0.072464,0.275362,0.376812,0.463768,0.724638,0.455087,0.375247,0.391193,0.275362,0.391304,0.463768,0.666667,0.275362,0.130435,0.092754,0.066667,0.275362,0.391304,0.463768,0.666667,0.439424,0.37144,0.390363,0.275362,0.391304,0.478261,0.666667,0.275362,0.130435,0.095652,0.066667,0.275362,0.391304,0.478261,0.666667,0.43693,0.368007,0.3863,0.43693
10,50.569800,No log,0.289855,0.42029,0.507246,0.753623,0.289855,0.140097,0.101449,0.075362,0.289855,0.42029,0.507246,0.753623,0.475239,0.3928,0.407661,0.304348,0.449275,0.521739,0.73913,0.304348,0.149758,0.104348,0.073913,0.304348,0.449275,0.521739,0.73913,0.483909,0.407695,0.423486,0.289855,0.391304,0.478261,0.753623,0.289855,0.130435,0.095652,0.075362,0.289855,0.391304,0.478261,0.753623,0.47043,0.386767,0.401136,0.26087,0.405797,0.492754,0.710145,0.26087,0.135266,0.098551,0.071014,0.26087,0.405797,0.492754,0.710145,0.449777,0.371998,0.389418,0.289855,0.376812,0.449275,0.695652,0.289855,0.125604,0.089855,0.069565,0.289855,0.376812,0.449275,0.695652,0.449666,0.37718,0.395095,0.449666


TrainOutput(global_step=76, training_loss=30.202686811748304, metrics={'train_runtime': 222.146, 'train_samples_per_second': 833.236, 'train_steps_per_second': 1.35, 'total_flos': 0.0, 'train_loss': 30.202686811748304, 'epoch': 38.0})

In [12]:
# save the best model
trainer.save_model()
 


In [11]:
from sentence_transformers import SentenceTransformer
 
fine_tuned_model = SentenceTransformer(
    args.output_dir, device="cuda" if torch.cuda.is_available() else "cpu"
)
# Evaluate the model
results = evaluator(fine_tuned_model)
 

# Print the main score
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print(f"{key}: {results[key]}")

dim_768_cosine_ndcg@10: 0.7340257807397909
dim_512_cosine_ndcg@10: 0.7632895558112947
dim_256_cosine_ndcg@10: 0.7748767678236386
dim_128_cosine_ndcg@10: 0.8145157578527988
dim_64_cosine_ndcg@10: 0.7742869662818314
