In [1]:
from datasets import load_dataset

# Load dataset from the hub
dataset = load_dataset("json", data_files="dataset.json", split="train")

# Add an id column to the dataset
dataset = dataset.add_column("id", range(len(dataset)))

# split dataset into a 10% eval set
dataset = dataset.train_test_split(test_size=0.1)

# save datasets to disk
dataset["train"].to_json("train_dataset.json", orient="records")
dataset["eval"].to_json("eval_dataset.json", orient="records")

Creating json from Arrow format:   0%|          | 0/7 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

243039

In [2]:
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim
from datasets import load_dataset, concatenate_datasets

model_id = "models/vietnamese-embedding"  # Hugging Face model ID
matryoshka_dimensions = [768, 512, 256, 128, 64] # Important: large to small

# Load a model
model = SentenceTransformer(
    model_id, device="cuda" if torch.cuda.is_available() else "cpu"
)

# load eval dataset
eval_dataset = load_dataset("json", data_files="eval_dataset.json", split="train")
train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")
corpus_dataset = concatenate_datasets([train_dataset, eval_dataset])

# Convert the datasets to dictionaries
corpus = dict(
    zip(corpus_dataset["id"], corpus_dataset["answer"])
)  # Our corpus (cid => document)
queries = dict(
    zip(eval_dataset["id"], eval_dataset["question"])
)  # Our queries (qid => question)

# Create a mapping of relevant document (1 in our case) for each query
relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids])
for q_id in queries:
    relevant_docs[q_id] = [q_id]


matryoshka_evaluators = []
# Iterate over the different dimensions
for dim in matryoshka_dimensions:
    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,  # Truncate the embeddings to a certain dimension
        score_functions={"cosine": cos_sim},
    )
    matryoshka_evaluators.append(ir_evaluator)

# Create a sequential evaluator
evaluator = SequentialEvaluator(matryoshka_evaluators)



Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [15]:
corpus_dataset

Dataset({
    features: ['anchor', 'positive', 'id'],
    num_rows: 7000
})

In [16]:
dataset

DatasetDict({
    train: Dataset({
        features: ['anchor', 'positive', 'id'],
        num_rows: 6300
    })
    test: Dataset({
        features: ['anchor', 'positive', 'id'],
        num_rows: 700
    })
})

In [3]:
# Evaluate the model
results = evaluator(model)

# # COMMENT IN for full results
# print(results)

# Print the main score
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print
    print(f"{key}: {results[key]}")

dim_768_cosine_ndcg@10: 0.7516061491368686
dim_512_cosine_ndcg@10: 0.7444379636635219
dim_256_cosine_ndcg@10: 0.7347367618134649
dim_128_cosine_ndcg@10: 0.7060806393737453
dim_64_cosine_ndcg@10: 0.6350737168891617


In [4]:
from sentence_transformers import SentenceTransformerModelCardData, SentenceTransformer

# load model with SDPA for using Flash Attention 2
model = SentenceTransformer(
    model_id,
    model_kwargs={"attn_implementation": "sdpa"},
    model_card_data=SentenceTransformerModelCardData(
        language=["en","vn"],
        model_name="Vietnamese embeddings Matryoshka",
    ),
)

In [5]:
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

matryoshka_dimensions = [768, 512, 256, 128, 64]  # Important: large to small
inner_train_loss = MultipleNegativesRankingLoss(model)
train_loss = MatryoshkaLoss(
    model, inner_train_loss, matryoshka_dims=matryoshka_dimensions
)

In [9]:
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

# define training arguments
args = SentenceTransformerTrainingArguments(
    output_dir="models/finetune-vietnamese-embeddings", # output directory and hugging face model ID
    num_train_epochs=100,                         # number of epochs
    per_device_train_batch_size=32,             # train batch size
    gradient_accumulation_steps=16,             # for a global batch size of 512
    per_device_eval_batch_size=16,              # evaluation batch size
    warmup_ratio=0.1,                           # warmup ratio
    learning_rate=2e-5,                         # learning rate, 2e-5 is a good value
    lr_scheduler_type="cosine",                 # use constant learning rate scheduler
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    # tf32=True,                                  # use tf32 precision
    bf16=True,                                  # use bf16 precision
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    eval_strategy="epoch",                      # evaluate after each epoch
    save_strategy="epoch",                      # save after each epoch
    logging_steps=10,                           # log every 10 steps
    save_total_limit=2,                         # save only the last 2 models
    load_best_model_at_end=True,                # load the best model when training ends
    metric_for_best_model="eval_dim_128_cosine_ndcg@10",  # Optimizing for the best ndcg@10 score for the 128 dimension
)

In [11]:
from sentence_transformers import SentenceTransformerTrainer

trainer = SentenceTransformerTrainer(
    model=model, # bg-base-en-v1
    args=args,  # training arguments
    train_dataset=train_dataset.select_columns(
        ["answer", "question"]
    ),
    eval_dataset=eval_dataset.select_columns(
        ["answer", "question"]
    ),
    loss=train_loss,
    evaluator=evaluator,
)
trainer.train()

Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
0,1.5244,No log,0.68,0.81,0.841429,0.901429,0.68,0.27,0.168286,0.090143,0.68,0.81,0.841429,0.901429,0.790951,0.755882,0.759713,0.677143,0.805714,0.845714,0.894286,0.677143,0.268571,0.169143,0.089429,0.677143,0.805714,0.845714,0.894286,0.786448,0.752056,0.75638,0.668571,0.8,0.831429,0.877143,0.668571,0.266667,0.166286,0.087714,0.668571,0.8,0.831429,0.877143,0.775548,0.742885,0.748016,0.664286,0.785714,0.818571,0.874286,0.664286,0.261905,0.163714,0.087429,0.664286,0.785714,0.818571,0.874286,0.767728,0.733961,0.738187,0.614286,0.74,0.784286,0.841429,0.614286,0.246667,0.156857,0.084143,0.614286,0.74,0.784286,0.841429,0.726217,0.689673,0.694407,0.614286


Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
0,1.5244,No log,0.68,0.81,0.841429,0.901429,0.68,0.27,0.168286,0.090143,0.68,0.81,0.841429,0.901429,0.790951,0.755882,0.759713,0.677143,0.805714,0.845714,0.894286,0.677143,0.268571,0.169143,0.089429,0.677143,0.805714,0.845714,0.894286,0.786448,0.752056,0.75638,0.668571,0.8,0.831429,0.877143,0.668571,0.266667,0.166286,0.087714,0.668571,0.8,0.831429,0.877143,0.775548,0.742885,0.748016,0.664286,0.785714,0.818571,0.874286,0.664286,0.261905,0.163714,0.087429,0.664286,0.785714,0.818571,0.874286,0.767728,0.733961,0.738187,0.614286,0.74,0.784286,0.841429,0.614286,0.246667,0.156857,0.084143,0.614286,0.74,0.784286,0.841429,0.726217,0.689673,0.694407,0.614286
1,0.6654,No log,0.691429,0.815714,0.85,0.905714,0.691429,0.271905,0.17,0.090571,0.691429,0.815714,0.85,0.905714,0.797937,0.763721,0.767703,0.688571,0.818571,0.851429,0.908571,0.688571,0.272857,0.170286,0.090857,0.688571,0.818571,0.851429,0.908571,0.798328,0.763282,0.76681,0.692857,0.811429,0.84,0.888571,0.692857,0.270476,0.168,0.088857,0.692857,0.811429,0.84,0.888571,0.791992,0.761062,0.765751,0.671429,0.81,0.84,0.888571,0.671429,0.27,0.168,0.088857,0.671429,0.81,0.84,0.888571,0.782393,0.748266,0.752168,0.637143,0.764286,0.804286,0.862857,0.637143,0.254762,0.160857,0.086286,0.637143,0.764286,0.804286,0.862857,0.747796,0.711442,0.716057,0.637143


Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
0,1.5244,No log,0.68,0.81,0.841429,0.901429,0.68,0.27,0.168286,0.090143,0.68,0.81,0.841429,0.901429,0.790951,0.755882,0.759713,0.677143,0.805714,0.845714,0.894286,0.677143,0.268571,0.169143,0.089429,0.677143,0.805714,0.845714,0.894286,0.786448,0.752056,0.75638,0.668571,0.8,0.831429,0.877143,0.668571,0.266667,0.166286,0.087714,0.668571,0.8,0.831429,0.877143,0.775548,0.742885,0.748016,0.664286,0.785714,0.818571,0.874286,0.664286,0.261905,0.163714,0.087429,0.664286,0.785714,0.818571,0.874286,0.767728,0.733961,0.738187,0.614286,0.74,0.784286,0.841429,0.614286,0.246667,0.156857,0.084143,0.614286,0.74,0.784286,0.841429,0.726217,0.689673,0.694407,0.614286
1,0.6654,No log,0.691429,0.815714,0.85,0.905714,0.691429,0.271905,0.17,0.090571,0.691429,0.815714,0.85,0.905714,0.797937,0.763721,0.767703,0.688571,0.818571,0.851429,0.908571,0.688571,0.272857,0.170286,0.090857,0.688571,0.818571,0.851429,0.908571,0.798328,0.763282,0.76681,0.692857,0.811429,0.84,0.888571,0.692857,0.270476,0.168,0.088857,0.692857,0.811429,0.84,0.888571,0.791992,0.761062,0.765751,0.671429,0.81,0.84,0.888571,0.671429,0.27,0.168,0.088857,0.671429,0.81,0.84,0.888571,0.782393,0.748266,0.752168,0.637143,0.764286,0.804286,0.862857,0.637143,0.254762,0.160857,0.086286,0.637143,0.764286,0.804286,0.862857,0.747796,0.711442,0.716057,0.637143
2,0.4665,No log,0.694286,0.82,0.861429,0.904286,0.694286,0.273333,0.172286,0.090429,0.694286,0.82,0.861429,0.904286,0.799897,0.766518,0.770857,0.698571,0.824286,0.855714,0.908571,0.698571,0.274762,0.171143,0.090857,0.698571,0.824286,0.855714,0.908571,0.802969,0.769353,0.773098,0.7,0.821429,0.845714,0.892857,0.7,0.27381,0.169143,0.089286,0.7,0.821429,0.845714,0.892857,0.797374,0.766839,0.77147,0.68,0.814286,0.841429,0.89,0.68,0.271429,0.168286,0.089,0.68,0.814286,0.841429,0.89,0.78719,0.754099,0.758251,0.648571,0.77,0.808571,0.864286,0.648571,0.256667,0.161714,0.086429,0.648571,0.77,0.808571,0.864286,0.754671,0.719931,0.724814,0.648571


Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
0,1.5244,No log,0.68,0.81,0.841429,0.901429,0.68,0.27,0.168286,0.090143,0.68,0.81,0.841429,0.901429,0.790951,0.755882,0.759713,0.677143,0.805714,0.845714,0.894286,0.677143,0.268571,0.169143,0.089429,0.677143,0.805714,0.845714,0.894286,0.786448,0.752056,0.75638,0.668571,0.8,0.831429,0.877143,0.668571,0.266667,0.166286,0.087714,0.668571,0.8,0.831429,0.877143,0.775548,0.742885,0.748016,0.664286,0.785714,0.818571,0.874286,0.664286,0.261905,0.163714,0.087429,0.664286,0.785714,0.818571,0.874286,0.767728,0.733961,0.738187,0.614286,0.74,0.784286,0.841429,0.614286,0.246667,0.156857,0.084143,0.614286,0.74,0.784286,0.841429,0.726217,0.689673,0.694407,0.614286
1,0.6654,No log,0.691429,0.815714,0.85,0.905714,0.691429,0.271905,0.17,0.090571,0.691429,0.815714,0.85,0.905714,0.797937,0.763721,0.767703,0.688571,0.818571,0.851429,0.908571,0.688571,0.272857,0.170286,0.090857,0.688571,0.818571,0.851429,0.908571,0.798328,0.763282,0.76681,0.692857,0.811429,0.84,0.888571,0.692857,0.270476,0.168,0.088857,0.692857,0.811429,0.84,0.888571,0.791992,0.761062,0.765751,0.671429,0.81,0.84,0.888571,0.671429,0.27,0.168,0.088857,0.671429,0.81,0.84,0.888571,0.782393,0.748266,0.752168,0.637143,0.764286,0.804286,0.862857,0.637143,0.254762,0.160857,0.086286,0.637143,0.764286,0.804286,0.862857,0.747796,0.711442,0.716057,0.637143
2,0.4665,No log,0.694286,0.82,0.861429,0.904286,0.694286,0.273333,0.172286,0.090429,0.694286,0.82,0.861429,0.904286,0.799897,0.766518,0.770857,0.698571,0.824286,0.855714,0.908571,0.698571,0.274762,0.171143,0.090857,0.698571,0.824286,0.855714,0.908571,0.802969,0.769353,0.773098,0.7,0.821429,0.845714,0.892857,0.7,0.27381,0.169143,0.089286,0.7,0.821429,0.845714,0.892857,0.797374,0.766839,0.77147,0.68,0.814286,0.841429,0.89,0.68,0.271429,0.168286,0.089,0.68,0.814286,0.841429,0.89,0.78719,0.754099,0.758251,0.648571,0.77,0.808571,0.864286,0.648571,0.256667,0.161714,0.086429,0.648571,0.77,0.808571,0.864286,0.754671,0.719931,0.724814,0.648571
3,0.3673,No log,0.695714,0.821429,0.861429,0.907143,0.695714,0.27381,0.172286,0.090714,0.695714,0.821429,0.861429,0.907143,0.801089,0.767289,0.77136,0.701429,0.824286,0.855714,0.907143,0.701429,0.274762,0.171143,0.090714,0.701429,0.824286,0.855714,0.907143,0.803386,0.770365,0.77423,0.7,0.822857,0.847143,0.89,0.7,0.274286,0.169429,0.089,0.7,0.822857,0.847143,0.89,0.796736,0.766776,0.771743,0.682857,0.814286,0.842857,0.888571,0.682857,0.271429,0.168571,0.088857,0.682857,0.814286,0.842857,0.888571,0.788116,0.755736,0.760055,0.652857,0.771429,0.808571,0.864286,0.652857,0.257143,0.161714,0.086429,0.652857,0.771429,0.808571,0.864286,0.756586,0.722484,0.72738,0.652857


Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
0,1.5244,No log,0.68,0.81,0.841429,0.901429,0.68,0.27,0.168286,0.090143,0.68,0.81,0.841429,0.901429,0.790951,0.755882,0.759713,0.677143,0.805714,0.845714,0.894286,0.677143,0.268571,0.169143,0.089429,0.677143,0.805714,0.845714,0.894286,0.786448,0.752056,0.75638,0.668571,0.8,0.831429,0.877143,0.668571,0.266667,0.166286,0.087714,0.668571,0.8,0.831429,0.877143,0.775548,0.742885,0.748016,0.664286,0.785714,0.818571,0.874286,0.664286,0.261905,0.163714,0.087429,0.664286,0.785714,0.818571,0.874286,0.767728,0.733961,0.738187,0.614286,0.74,0.784286,0.841429,0.614286,0.246667,0.156857,0.084143,0.614286,0.74,0.784286,0.841429,0.726217,0.689673,0.694407,0.614286
1,0.6654,No log,0.691429,0.815714,0.85,0.905714,0.691429,0.271905,0.17,0.090571,0.691429,0.815714,0.85,0.905714,0.797937,0.763721,0.767703,0.688571,0.818571,0.851429,0.908571,0.688571,0.272857,0.170286,0.090857,0.688571,0.818571,0.851429,0.908571,0.798328,0.763282,0.76681,0.692857,0.811429,0.84,0.888571,0.692857,0.270476,0.168,0.088857,0.692857,0.811429,0.84,0.888571,0.791992,0.761062,0.765751,0.671429,0.81,0.84,0.888571,0.671429,0.27,0.168,0.088857,0.671429,0.81,0.84,0.888571,0.782393,0.748266,0.752168,0.637143,0.764286,0.804286,0.862857,0.637143,0.254762,0.160857,0.086286,0.637143,0.764286,0.804286,0.862857,0.747796,0.711442,0.716057,0.637143
2,0.4665,No log,0.694286,0.82,0.861429,0.904286,0.694286,0.273333,0.172286,0.090429,0.694286,0.82,0.861429,0.904286,0.799897,0.766518,0.770857,0.698571,0.824286,0.855714,0.908571,0.698571,0.274762,0.171143,0.090857,0.698571,0.824286,0.855714,0.908571,0.802969,0.769353,0.773098,0.7,0.821429,0.845714,0.892857,0.7,0.27381,0.169143,0.089286,0.7,0.821429,0.845714,0.892857,0.797374,0.766839,0.77147,0.68,0.814286,0.841429,0.89,0.68,0.271429,0.168286,0.089,0.68,0.814286,0.841429,0.89,0.78719,0.754099,0.758251,0.648571,0.77,0.808571,0.864286,0.648571,0.256667,0.161714,0.086429,0.648571,0.77,0.808571,0.864286,0.754671,0.719931,0.724814,0.648571
3,0.3673,No log,0.695714,0.821429,0.861429,0.907143,0.695714,0.27381,0.172286,0.090714,0.695714,0.821429,0.861429,0.907143,0.801089,0.767289,0.77136,0.701429,0.824286,0.855714,0.907143,0.701429,0.274762,0.171143,0.090714,0.701429,0.824286,0.855714,0.907143,0.803386,0.770365,0.77423,0.7,0.822857,0.847143,0.89,0.7,0.274286,0.169429,0.089,0.7,0.822857,0.847143,0.89,0.796736,0.766776,0.771743,0.682857,0.814286,0.842857,0.888571,0.682857,0.271429,0.168571,0.088857,0.682857,0.814286,0.842857,0.888571,0.788116,0.755736,0.760055,0.652857,0.771429,0.808571,0.864286,0.652857,0.257143,0.161714,0.086429,0.652857,0.771429,0.808571,0.864286,0.756586,0.722484,0.72738,0.652857


TrainOutput(global_step=48, training_loss=0.6937383562326431, metrics={'train_runtime': 323.653, 'train_samples_per_second': 77.861, 'train_steps_per_second': 0.148, 'total_flos': 0.0, 'train_loss': 0.6937383562326431, 'epoch': 3.8984771573604062})

In [14]:
from sentence_transformers import SentenceTransformer

fine_tuned_model = SentenceTransformer(
    args.output_dir, device="cuda" if torch.cuda.is_available() else "cpu"
)
# Evaluate the model
results = evaluator(fine_tuned_model)

# # COMMENT IN for full results
# print(results)

# Print the main score
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print(f"{key}: {results[key]}")

dim_768_cosine_ndcg@10: 0.8006141326599298
dim_512_cosine_ndcg@10: 0.8010131384546558
dim_256_cosine_ndcg@10: 0.7961564309588393
dim_128_cosine_ndcg@10: 0.7875840341168918
dim_64_cosine_ndcg@10: 0.754755355270157
