1. Create & Prepare embedding dataset

In [1]:
from datasets import load_dataset
 
# Load dataset from the hub
dataset = load_dataset('csv', data_files='../dataset/QA_data.csv', split="train")


# Shuffle trước để chọn ngẫu nhiên
dataset = dataset.shuffle(seed=42)

# Lấy 1/100 số dòng
# dataset = dataset.select(range(len(dataset) // 100))


print("Before:", len(dataset))
dataset = dataset.filter(lambda x: x["Answer"])
print("After:", len(dataset))
# rename columns
dataset = dataset.rename_column("Question", "anchor")
dataset = dataset.rename_column("Answer", "positive")
 
# Add an id column to the dataset
dataset = dataset.add_column("id", range(len(dataset)))
 
# split dataset into a 10% test set
dataset = dataset.train_test_split(test_size=0.1)
 
# save datasets to disk
dataset["train"].to_pandas().to_json("train_dataset.json", orient="records", lines=True, force_ascii=False)
dataset["test"].to_pandas().to_json("test_dataset.json", orient="records", lines=True, force_ascii=False)


  from .autonotebook import tqdm as notebook_tqdm


Before: 197379
After: 159501


2. Create baseline and evaluate pretrained model

In [2]:

from datasets import load_dataset, concatenate_datasets
 
# load test dataset
test_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")
corpus_dataset = concatenate_datasets([train_dataset, test_dataset])
 
# Convert the datasets to dictionaries
corpus = dict(
    zip(corpus_dataset["id"], corpus_dataset["positive"])
)  # Our corpus (cid => document)
queries = dict(
    zip(test_dataset["id"], test_dataset["anchor"])
)  # Our queries (qid => question)
 
# Create a mapping of relevant document (1 in our case) for each query
relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids])
for q_id in queries:
    relevant_docs[q_id] = [q_id]
 
 


Generating train split: 15951 examples [00:00, 361251.10 examples/s]
Generating train split: 143550 examples [00:00, 603611.63 examples/s]


In [3]:
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim

model_id = "bkai-foundation-models/vietnamese-bi-encoder"  # Hugging Face model ID
matryoshka_dimensions = [768, 512, 256, 128, 64] # Important: large to small
 
# Load a model
model = SentenceTransformer(
    model_id, device="cuda" if torch.cuda.is_available() else "cpu"
)

matryoshka_evaluators = []
# Iterate over the different dimensions
for dim in matryoshka_dimensions:
    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,  # Truncate the embeddings to a certain dimension
        score_functions={"cosine": cos_sim},
    )
    matryoshka_evaluators.append(ir_evaluator)
 
# Create a sequential evaluator
evaluator = SequentialEvaluator(matryoshka_evaluators)

In [4]:
# Evaluate the model
results = evaluator(model)
for k,v in results.items():
    print(k, v)

print("=======================")
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print
    print(f"{key}: {results[key]}")

dim_768_cosine_accuracy@1 0.23227383863080683
dim_768_cosine_accuracy@3 0.35847282302050026
dim_768_cosine_accuracy@5 0.4165256096796439
dim_768_cosine_accuracy@10 0.4767726161369193
dim_768_cosine_precision@1 0.23227383863080683
dim_768_cosine_precision@3 0.11949094100683343
dim_768_cosine_precision@5 0.08330512193592877
dim_768_cosine_precision@10 0.04767726161369193
dim_768_cosine_recall@1 0.23227383863080683
dim_768_cosine_recall@3 0.35847282302050026
dim_768_cosine_recall@5 0.4165256096796439
dim_768_cosine_recall@10 0.4767726161369193
dim_768_cosine_ndcg@10 0.34935886247371617
dim_768_cosine_mrr@10 0.309086856275119
dim_768_cosine_map@100 0.31452434828195525
dim_512_cosine_accuracy@1 0.23108269074039245
dim_512_cosine_accuracy@3 0.35571437527427746
dim_512_cosine_accuracy@5 0.4132656259795624
dim_512_cosine_accuracy@10 0.4730110964829791
dim_512_cosine_precision@1 0.23108269074039245
dim_512_cosine_precision@3 0.11857145842475914
dim_512_cosine_precision@5 0.08265312519591249
dim

3. Define loss function with Matryoshka Representation

In [5]:
from sentence_transformers import SentenceTransformerModelCardData, SentenceTransformer
 
# Hugging Face model ID: https://huggingface.co/BAAI/bge-base-en-v1.5
model_id = "bkai-foundation-models/vietnamese-bi-encoder"
 
# load model with SDPA for using Flash Attention 2
model = SentenceTransformer(
    model_id,
    model_kwargs={"attn_implementation": "sdpa"},
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="BGE base Financial Matryoshka",
    ),
)
train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")
train_dataset

Dataset({
    features: ['anchor', 'Context', 'positive', 'Answer_Start', 'Answer_End', 'id'],
    num_rows: 143550
})

In [6]:
train_dataset[0]

{'anchor': 'Tôi bị nhiễm covid, đã hoàn thành điều trị, về cách ly tại nhà dc 4 ngày, giờ ko ho ko sốt, nhưng cổ có mắc đàm, vòm họng đỏ, hơi rát, vì có bị viêm amidan mãn tính, vậy cho e hỏi e có thể uống thuốc gì dc ạ, e còn thuốc điều trị amidan có uống dc ko',
 'Context': 'Chào chị Dạ chào bác sỹ Chị đặt giúp em một lịch tư vấn online miễn phí. Em gọi điện hỏi thêm triệu chứng và tư vấn cụ thể cho chị nhé Chị ấn vào ảnh đại diện của em, ấn tư vấn trực tuyến, ấn xác nhận và đặt lịch ạ Cảm ơn chị',
 'positive': 'Chào chị Dạ chào bác sỹ Chị đặt giúp em một lịch tư vấn online miễn phí. Em gọi điện hỏi thêm triệu chứng và tư vấn cụ thể cho chị nhé Chị ấn vào ảnh đại diện của em, ấn tư vấn trực tuyến, ấn xác nhận và đặt lịch ạ Cảm ơn chị',
 'Answer_Start': 0,
 'Answer_End': 225,
 'id': 88434}

In [7]:
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
 
matryoshka_dimensions = [768, 512, 256, 128, 64]  # Important: large to small
inner_train_loss = MultipleNegativesRankingLoss(model)
train_loss = MatryoshkaLoss(
    model, inner_train_loss, matryoshka_dims=matryoshka_dimensions
)

4. Fine-tune embedding model with SentenceTransformersTrainer

In [8]:
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers
 
# load train dataset again
 
# define training arguments
args = SentenceTransformerTrainingArguments(
    output_dir="sample", # output directory and hugging face model ID
    num_train_epochs=10,                         # number of epochs
    per_device_train_batch_size=16,             # train batch size
    gradient_accumulation_steps=8,             # for a global batch size of 512
    per_device_eval_batch_size=8,    
    #gradient_checkpointing=True,
    warmup_ratio=0.1,                           # warmup ratio
    learning_rate=2e-5,                         # learning rate, 2e-5 is a good value
    lr_scheduler_type="cosine",                 # use constant learning rate scheduler
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    # tf32=True,                                  # use tf32 precision
    bf16=True,                                  # use bf16 precision
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    eval_strategy="epoch",                      # evaluate after each epoch
    save_strategy="epoch",  
    # save_steps = 500,
    logging_steps=10,                           # log every 10 steps
    save_total_limit=3,                         # save only the last 3 models
    load_best_model_at_end=True,                # load the best model when training ends
    metric_for_best_model="eval_dim_128_cosine_ndcg@10",  # Optimizing for the best ndcg@10 score for the 128 dimension
    report_to = "none",
)

In [9]:
from sentence_transformers import SentenceTransformerTrainer
 
trainer = SentenceTransformerTrainer(
    model=model, # bg-base-en-v1
    args=args,  # training arguments
    train_dataset=train_dataset.select_columns(
        ["anchor", "positive"]
    ),  # training dataset
    loss=train_loss,
    evaluator=evaluator,
)

                                                                     

In [10]:
# start training, the model will be automatically saved to the hub and the output directory
trainer.train()
 
# save the best model
trainer.save_model()
 
# # push model to hub
# trainer.model.push_to_hub("bge-base-financial-matryoshka")

Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
1,4.8839,No log,0.372265,0.525547,0.571375,0.620463,0.372265,0.175182,0.114275,0.062046,0.372265,0.525547,0.571375,0.620463,0.496825,0.457113,0.462092,0.368629,0.523729,0.571438,0.618645,0.368629,0.174576,0.114288,0.061864,0.368629,0.523729,0.571438,0.618645,0.494357,0.454411,0.459396,0.363676,0.516958,0.565043,0.613817,0.363676,0.172319,0.113009,0.061382,0.363676,0.516958,0.565043,0.613817,0.489135,0.44909,0.454043,0.35189,0.503354,0.552567,0.603536,0.35189,0.167785,0.110513,0.060354,0.35189,0.503354,0.552567,0.603536,0.477686,0.437332,0.442267,0.324933,0.471569,0.519529,0.571626,0.324933,0.15719,0.103906,0.057163,0.324933,0.471569,0.519529,0.571626,0.447395,0.407659,0.412994,0.447395
2,4.0879,No log,0.388878,0.540593,0.587926,0.63739,0.388878,0.180198,0.117585,0.063739,0.388878,0.540593,0.587926,0.63739,0.513486,0.47372,0.478893,0.387938,0.536518,0.586609,0.634694,0.387938,0.178839,0.117322,0.063469,0.387938,0.536518,0.586609,0.634694,0.511661,0.472181,0.477426,0.382421,0.534136,0.579337,0.627359,0.382421,0.178045,0.115867,0.062736,0.382421,0.534136,0.579337,0.627359,0.505731,0.466632,0.47202,0.371325,0.521723,0.569243,0.617203,0.371325,0.173908,0.113849,0.06172,0.371325,0.521723,0.569243,0.617203,0.494705,0.455363,0.460718,0.346123,0.492822,0.539966,0.590935,0.346123,0.164274,0.107993,0.059093,0.346123,0.492822,0.539966,0.590935,0.468365,0.429123,0.434847,0.468365
3,3.3718,No log,0.398157,0.547991,0.594634,0.642468,0.398157,0.182664,0.118927,0.064247,0.398157,0.547991,0.594634,0.642468,0.521257,0.482317,0.487589,0.395148,0.546047,0.591938,0.6409,0.395148,0.182016,0.118388,0.06409,0.395148,0.546047,0.591938,0.6409,0.519017,0.479852,0.485138,0.391574,0.541283,0.58686,0.634945,0.391574,0.180428,0.117372,0.063494,0.391574,0.541283,0.58686,0.634945,0.514033,0.475204,0.480522,0.379412,0.53213,0.576704,0.626105,0.379412,0.177377,0.115341,0.06261,0.379412,0.53213,0.576704,0.626105,0.503916,0.464651,0.469884,0.363175,0.507429,0.554385,0.601655,0.363175,0.169143,0.110877,0.060166,0.363175,0.507429,0.554385,0.601655,0.482537,0.444313,0.449938,0.482537
4,2.0535,No log,0.40662,0.557896,0.60178,0.648047,0.40662,0.185965,0.120356,0.064805,0.40662,0.557896,0.60178,0.648047,0.528776,0.490377,0.495827,0.402295,0.555326,0.600589,0.646919,0.402295,0.185109,0.120118,0.064692,0.402295,0.555326,0.600589,0.646919,0.526167,0.487295,0.492667,0.395586,0.547928,0.594383,0.640838,0.395586,0.182643,0.118877,0.064084,0.395586,0.547928,0.594383,0.640838,0.519596,0.480574,0.485981,0.387938,0.537772,0.583161,0.628989,0.387938,0.179257,0.116632,0.062899,0.387938,0.537772,0.583161,0.628989,0.509643,0.471224,0.476819,0.364679,0.511379,0.560153,0.610683,0.364679,0.17046,0.112031,0.061068,0.364679,0.511379,0.560153,0.610683,0.487464,0.447995,0.453499,0.487464
5,1.8297,No log,0.401166,0.555702,0.60316,0.652122,0.401166,0.185234,0.120632,0.065212,0.401166,0.555702,0.60316,0.652122,0.527511,0.48745,0.492721,0.402169,0.55495,0.599774,0.649614,0.402169,0.184983,0.119955,0.064961,0.402169,0.55495,0.599774,0.649614,0.526402,0.486844,0.492206,0.395586,0.548994,0.595699,0.643471,0.395586,0.182998,0.11914,0.064347,0.395586,0.548994,0.595699,0.643471,0.520259,0.480653,0.486182,0.389067,0.538023,0.584916,0.634255,0.389067,0.179341,0.116983,0.063425,0.389067,0.538023,0.584916,0.634255,0.512085,0.47287,0.478425,0.36982,0.517146,0.563225,0.613378,0.36982,0.172382,0.112645,0.061338,0.36982,0.517146,0.563225,0.613378,0.491722,0.452716,0.45858,0.491722
6,1.968,No log,0.416212,0.565607,0.610244,0.659833,0.416212,0.188536,0.122049,0.065983,0.416212,0.565607,0.610244,0.659833,0.538639,0.499736,0.504924,0.414457,0.562347,0.607924,0.657514,0.414457,0.187449,0.121585,0.065751,0.414457,0.562347,0.607924,0.657514,0.536384,0.497516,0.502775,0.408501,0.559401,0.602721,0.651871,0.408501,0.186467,0.120544,0.065187,0.408501,0.559401,0.602721,0.651871,0.531091,0.492288,0.497704,0.400539,0.54937,0.594257,0.642468,0.400539,0.183123,0.118851,0.064247,0.400539,0.54937,0.594257,0.642468,0.522478,0.483946,0.48948,0.381606,0.530625,0.573381,0.622281,0.381606,0.176875,0.114676,0.062228,0.381606,0.530625,0.573381,0.622281,0.502864,0.464515,0.47028,0.502864
7,1.1077,No log,0.414958,0.567739,0.612752,0.660523,0.414958,0.189246,0.12255,0.066052,0.414958,0.567739,0.612752,0.660523,0.538843,0.499722,0.504997,0.414833,0.56404,0.609554,0.658705,0.414833,0.188013,0.121911,0.06587,0.414833,0.56404,0.609554,0.658705,0.537331,0.498372,0.503723,0.411636,0.561344,0.60504,0.655382,0.411636,0.187115,0.121008,0.065538,0.411636,0.561344,0.60504,0.655382,0.534153,0.495256,0.500488,0.399975,0.551,0.594947,0.644348,0.399975,0.183667,0.118989,0.064435,0.399975,0.551,0.594947,0.644348,0.523008,0.484046,0.489624,0.381355,0.528807,0.574071,0.623284,0.381355,0.176269,0.114814,0.062328,0.381355,0.528807,0.574071,0.623284,0.502908,0.464277,0.470081,0.502908
8,1.3489,No log,0.419848,0.570309,0.614068,0.662592,0.419848,0.190103,0.122814,0.066259,0.419848,0.570309,0.614068,0.662592,0.542104,0.503388,0.508618,0.417153,0.567174,0.612689,0.660523,0.417153,0.189058,0.122538,0.066052,0.417153,0.567174,0.612689,0.660523,0.539691,0.500849,0.50615,0.412827,0.563914,0.609115,0.656511,0.412827,0.187971,0.121823,0.065651,0.412827,0.563914,0.609115,0.656511,0.535814,0.497006,0.502342,0.40173,0.553382,0.598897,0.648549,0.40173,0.184461,0.119779,0.064855,0.40173,0.553382,0.598897,0.648549,0.5261,0.486762,0.492199,0.385995,0.535076,0.579337,0.629553,0.385995,0.178359,0.115867,0.062955,0.385995,0.535076,0.579337,0.629553,0.508493,0.469663,0.47535,0.508493
9,0.7969,No log,0.420162,0.569557,0.614507,0.662717,0.420162,0.189852,0.122901,0.066272,0.420162,0.569557,0.614507,0.662717,0.542292,0.503598,0.508895,0.417591,0.567363,0.613504,0.661087,0.417591,0.189121,0.122701,0.066109,0.417591,0.567363,0.613504,0.661087,0.540208,0.501358,0.506677,0.412513,0.562974,0.608363,0.656322,0.412513,0.187658,0.121673,0.065632,0.412513,0.562974,0.608363,0.656322,0.535423,0.496574,0.501986,0.403862,0.553758,0.598897,0.648611,0.403862,0.184586,0.119779,0.064861,0.403862,0.553758,0.598897,0.648611,0.526847,0.487761,0.493208,0.387562,0.534637,0.578584,0.62949,0.387562,0.178212,0.115717,0.062949,0.387562,0.534637,0.578584,0.62949,0.508979,0.470347,0.476051,0.508979


5. Evaluate fine-tuned model against baseline

In [11]:
from sentence_transformers import SentenceTransformer
 
fine_tuned_model = SentenceTransformer(
    args.output_dir, device="cuda" if torch.cuda.is_available() else "cpu"
)
# Evaluate the model
results = evaluator(fine_tuned_model)
 
# # COMMENT IN for full results
for k,v in results.items():
    print(k, v)
    
print("=======================")
 
# Print the main score
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print(f"{key}: {results[key]}")

dim_768_cosine_accuracy@1 0.4199109773681901
dim_768_cosine_accuracy@3 0.5694313836123127
dim_768_cosine_accuracy@5 0.6146323114538274
dim_768_cosine_accuracy@10 0.6627797630242618
dim_768_cosine_precision@1 0.4199109773681901
dim_768_cosine_precision@3 0.18981046120410425
dim_768_cosine_precision@5 0.12292646229076545
dim_768_cosine_precision@10 0.06627797630242617
dim_768_cosine_recall@1 0.4199109773681901
dim_768_cosine_recall@3 0.5694313836123127
dim_768_cosine_recall@5 0.6146323114538274
dim_768_cosine_recall@10 0.6627797630242618
dim_768_cosine_ndcg@10 0.5422040066911552
dim_768_cosine_mrr@10 0.5034606229593954
dim_768_cosine_map@100 0.5087533694160415
dim_512_cosine_accuracy@1 0.418218293523917
dim_512_cosine_accuracy@3 0.5677386997680396
dim_512_cosine_accuracy@5 0.6136292395461099
dim_512_cosine_accuracy@10 0.6612751551626858
dim_512_cosine_precision@1 0.418218293523917
dim_512_cosine_precision@3 0.18924623325601322
dim_512_cosine_precision@5 0.12272584790922199
dim_512_cosine