In [3]:
# !pip install --quiet "torch==2.1.2" tensorboard

# # Install Hugging Face libraries
# !pip install --quiet --upgrade \
#   "sentence-transformers>=3" \
#   "datasets==2.19.1"  \
#   "transformers==4.41.2"

In [1]:
import json

from torch.utils.data import DataLoader
from sentence_transformers import InputExample
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers
import torch
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim

  from tqdm.autonotebook import tqdm, trange


In [2]:
# Locations of the generated train/validation splits (both live in ./generated_data).
TRAIN_DATASET_FPATH, VAL_DATASET_FPATH = (
    f"./generated_data/{split}_dataset.json" for split in ("train", "val")
)

In [3]:
def _load_json(path):
    """Read the JSON file at `path` and return the parsed object.

    The encoding is pinned to UTF-8 (the JSON interchange encoding) instead of
    relying on the platform's locale default, which differs e.g. on Windows.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# Each split is used later as a dict with 'queries', 'corpus' and 'relevant_docs'.
train_dataset = _load_json(TRAIN_DATASET_FPATH)
val_dataset = _load_json(VAL_DATASET_FPATH)

In [4]:
# Base checkpoint to fine-tune, loaded by its Hugging Face model ID.
model_id = 'all-mpnet-base-v2'

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    _device = "cuda"
else:
    _device = "cpu"

model = SentenceTransformer(model_id, device=_device)



In [5]:
# Embedding sizes used by MatryoshkaLoss during training and by the per-dimension
# evaluators. They must be strictly decreasing (largest first).
matryoshka_dimensions = [768, 512, 256, 128, 64]

# Enforce the ordering requirement instead of leaving it as a comment.
assert all(
    larger > smaller
    for larger, smaller in zip(matryoshka_dimensions, matryoshka_dimensions[1:])
), "matryoshka_dimensions must be ordered from large to small"

In [6]:
# loader = DataLoader(
#     examples, batch_size=BATCH_SIZE
# )

In [7]:
# Ranking loss over (query, positive passage) pairs using in-batch negatives,
# wrapped by MatryoshkaLoss so it is also applied at each truncated size in
# matryoshka_dimensions.
inner_train_loss = MultipleNegativesRankingLoss(model=model)
train_loss = MatryoshkaLoss(
    model=model,
    loss=inner_train_loss,
    matryoshka_dims=matryoshka_dimensions,
)

In [8]:
# Evaluate retrieval on the validation split at every Matryoshka size. Each
# evaluator truncates embeddings to `dim`, scores with cosine similarity, and
# names its metrics with a `dim_{dim}_` prefix (e.g. dim_128_cosine_ndcg@10).
dataset = val_dataset
corpus, queries, relevant_docs = (
    dataset[key] for key in ("corpus", "queries", "relevant_docs")
)

matryoshka_evaluators = [
    InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,
        score_functions={"cosine": cos_sim},
    )
    for dim in matryoshka_dimensions
]

# Chain the per-dimension evaluators into one callable evaluator.
evaluator = SequentialEvaluator(matryoshka_evaluators)

First, let's measure the baseline retrieval quality (NDCG@10) at each embedding dimension, before any fine-tuning.

In [9]:
# Baseline: score the pre-trained model with the per-dimension evaluators.
results = evaluator(model)

# Print NDCG@10 for each truncated embedding size.
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print(f"{key}: {results[key]}")

dim_768_cosine_ndcg@10: 0.43439555076932584
dim_512_cosine_ndcg@10: 0.4261059206595997
dim_256_cosine_ndcg@10: 0.4280086571502763
dim_128_cosine_ndcg@10: 0.4154250006485585
dim_64_cosine_ndcg@10: 0.3755185794895535


In [None]:
# Compare against an OpenAI embedding model on the same validation split.
#
# `evaluator` drives SentenceTransformer models, so an API-backed model cannot
# be passed to it. Instead, embed the validation queries and corpus through the
# OpenAI API (requires openai>=1.0 and OPENAI_API_KEY in the environment),
# truncate the vectors to each Matryoshka size, and compute binary-relevance
# NDCG@10 of the cosine-similarity ranking.
import math

import openai

# text-embedding-3-* models are trained to support shortened embeddings;
# truncating text-embedding-ada-002 vectors is not a meaningful comparison.
OPENAI_EMBEDDING_MODEL = "text-embedding-3-small"
OPENAI_BATCH_SIZE = 100  # inputs per API request


def embed_with_openai(texts, model=OPENAI_EMBEDDING_MODEL, batch_size=OPENAI_BATCH_SIZE):
    """Embed `texts` via the OpenAI embeddings API.

    Returns a float32 tensor with one row per input text, in input order.
    """
    vectors = []
    for start in range(0, len(texts), batch_size):
        response = openai.embeddings.create(model=model, input=texts[start:start + batch_size])
        # Order by `index` so rows line up with `texts`.
        vectors.extend(item.embedding for item in sorted(response.data, key=lambda item: item.index))
    return torch.tensor(vectors, dtype=torch.float32)


def ndcg_at_k(query_emb, corpus_emb, query_ids, corpus_ids, relevant_docs, k=10):
    """Mean binary-relevance NDCG@k of the cosine-similarity ranking.

    Row i of `query_emb` belongs to query_ids[i]; row j of `corpus_emb` to
    corpus_ids[j]. `relevant_docs` maps a query id to its relevant corpus ids.
    Every query in `query_ids` is expected to have at least one relevant doc.
    """
    scores = cos_sim(query_emb, corpus_emb)  # (n_queries, n_corpus)
    top_hits = torch.topk(scores, k=min(k, scores.shape[1]), dim=1).indices.tolist()
    total = 0.0
    for qid, hits in zip(query_ids, top_hits):
        relevant = set(relevant_docs[qid])
        dcg = sum(
            1.0 / math.log2(rank + 2)
            for rank, idx in enumerate(hits)
            if corpus_ids[idx] in relevant
        )
        # Ideal ranking: all relevant docs (up to k) at the top.
        idcg = sum(1.0 / math.log2(rank + 2) for rank in range(min(len(relevant), k)))
        total += dcg / idcg
    return total / len(query_ids)


val_queries = val_dataset['queries']
val_corpus = val_dataset['corpus']
val_relevant_docs = val_dataset['relevant_docs']

# NDCG is undefined for queries without relevant docs, so skip those.
openai_query_ids = [qid for qid in val_queries if val_relevant_docs.get(qid)]
openai_corpus_ids = list(val_corpus)
openai_query_emb = embed_with_openai([val_queries[qid] for qid in openai_query_ids])
openai_corpus_emb = embed_with_openai([val_corpus[cid] for cid in openai_corpus_ids])

# Each dimension is a prefix truncation of the returned vectors.
for dim in matryoshka_dimensions:
    score = ndcg_at_k(
        openai_query_emb[:, :dim],
        openai_corpus_emb[:, :dim],
        openai_query_ids,
        openai_corpus_ids,
        val_relevant_docs,
    )
    print(f"openai_dim_{dim}_cosine_ndcg@10: {score}")

Next, we fine-tune the embedding model to see whether retrieval quality improves.

In [10]:
# Build (query, positive passage) training pairs. Each query is paired with the
# first of its relevant documents; a query is skipped if it has no entry in
# relevant_docs or if that document is missing from the corpus.
_train_queries = train_dataset['queries']
_train_corpus = train_dataset['corpus']
_train_relevant = train_dataset['relevant_docs']

pairs = [
    (query_text, _train_corpus[_train_relevant[qid][0]])
    for qid, query_text in _train_queries.items()
    if qid in _train_relevant and _train_relevant[qid][0] in _train_corpus
]

# Column-oriented layout: queries[i] is paired with corpus[i].
trainset = {
    "queries": [query_text for query_text, _ in pairs],
    "corpus": [passage for _, passage in pairs],
}

In [11]:
from datasets import Dataset

# Turn the column dict into a Hugging Face Dataset for SentenceTransformerTrainer;
# columns keep their order (queries first, then the paired corpus text).
hf_dataset = Dataset.from_dict(mapping=trainset)

In [12]:
# Inspect the training Dataset (its columns and number of rows).
hf_dataset

Dataset({
    features: ['queries', 'corpus'],
    num_rows: 1090
})

In [17]:
# Training configuration. The model may have been loaded on CPU, so fp16 mixed
# precision and the fused AdamW kernel are only enabled when CUDA is available
# (both fail on CPU-only runs).
_use_cuda = torch.cuda.is_available()

args = SentenceTransformerTrainingArguments(
    output_dir="mpnet-base-v2-matryoshka",      # local output directory (also the model ID if pushed to the Hub)
    num_train_epochs=50,                        # number of epochs
    per_device_train_batch_size=32,             # train batch size (also the pool of in-batch negatives for MNRL)
    gradient_accumulation_steps=16,             # optimizer step every 16 batches (512 examples); does NOT enlarge MNRL's in-batch negatives
    per_device_eval_batch_size=16,              # evaluation batch size
    warmup_ratio=0.1,                           # warm up over the first 10% of optimizer steps
    learning_rate=2e-5,                         # peak learning rate
    lr_scheduler_type="cosine",                 # cosine decay after warmup
    optim="adamw_torch_fused" if _use_cuda else "adamw_torch",  # fused AdamW requires CUDA tensors
    tf32=False,                                 # do not enable TF32 matmuls
    fp16=_use_cuda,                             # fp16 mixed precision, GPU only (bf16=True is an alternative on Ampere+)
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    eval_strategy="epoch",                      # evaluate after each epoch
    save_strategy="epoch",                      # save after each epoch
    logging_steps=10,                           # log every 10 optimizer steps
    save_total_limit=3,                         # keep at most 3 checkpoints on disk
    load_best_model_at_end=True,                # load the best model when training ends
    metric_for_best_model="eval_dim_128_cosine_ndcg@10",  # select the checkpoint by NDCG@10 at 128 dimensions
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [18]:
from sentence_transformers import SentenceTransformerTrainer

# Combine the model, the training arguments, the (query, passage) pairs, the
# Matryoshka loss and the per-dimension IR evaluator. No eval_dataset is given,
# so per-epoch evaluation comes only from `evaluator`.
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    loss=train_loss,
    train_dataset=hf_dataset,
    evaluator=evaluator,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [19]:
# Start training. Checkpoints are written to args.output_dir every epoch; nothing
# is uploaded to the Hugging Face Hub, because push_to_hub is not enabled in `args`.
trainer.train()

# With load_best_model_at_end=True the trainer now holds the best checkpoint
# (by metric_for_best_model); save it to args.output_dir.
trainer.save_model(args.output_dir)

Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
0,No log,No log,0.270073,0.507299,0.682482,0.813869,0.270073,0.1691,0.136496,0.081387,0.270073,0.507299,0.682482,0.813869,0.521848,0.430285,0.439398,0.255474,0.514599,0.682482,0.806569,0.255474,0.171533,0.136496,0.080657,0.255474,0.514599,0.682482,0.806569,0.512656,0.420169,0.429177,0.266423,0.532847,0.660584,0.806569,0.266423,0.177616,0.132117,0.080657,0.266423,0.532847,0.660584,0.806569,0.519174,0.428677,0.437855,0.244526,0.49635,0.653285,0.79927,0.244526,0.16545,0.130657,0.079927,0.244526,0.49635,0.653285,0.79927,0.501313,0.407674,0.417356,0.218978,0.445255,0.587591,0.770073,0.218978,0.148418,0.117518,0.077007,0.218978,0.445255,0.587591,0.770073,0.469252,0.37596,0.387683,0.387683
1,No log,No log,0.262774,0.510949,0.678832,0.817518,0.262774,0.170316,0.135766,0.081752,0.262774,0.510949,0.678832,0.817518,0.521053,0.427959,0.43696,0.270073,0.521898,0.675182,0.806569,0.270073,0.173966,0.135036,0.080657,0.270073,0.521898,0.675182,0.806569,0.518817,0.428309,0.437636,0.266423,0.525547,0.667883,0.810219,0.266423,0.175182,0.133577,0.081022,0.266423,0.525547,0.667883,0.810219,0.52075,0.429703,0.439047,0.248175,0.50365,0.656934,0.80292,0.248175,0.167883,0.131387,0.080292,0.248175,0.50365,0.656934,0.80292,0.504814,0.411115,0.420715,0.226277,0.445255,0.59854,0.784672,0.226277,0.148418,0.119708,0.078467,0.226277,0.445255,0.59854,0.784672,0.478493,0.383871,0.394792,0.394792
2,No log,No log,0.270073,0.532847,0.682482,0.824818,0.270073,0.177616,0.136496,0.082482,0.270073,0.532847,0.682482,0.824818,0.527546,0.434162,0.443042,0.270073,0.525547,0.678832,0.821168,0.270073,0.175182,0.135766,0.082117,0.270073,0.525547,0.678832,0.821168,0.52482,0.431856,0.440375,0.273723,0.536496,0.664234,0.817518,0.273723,0.178832,0.132847,0.081752,0.273723,0.536496,0.664234,0.817518,0.52782,0.436721,0.445723,0.251825,0.507299,0.675182,0.806569,0.251825,0.1691,0.135036,0.080657,0.251825,0.507299,0.675182,0.806569,0.50792,0.414021,0.423834,0.229927,0.456204,0.60219,0.806569,0.229927,0.152068,0.120438,0.080657,0.229927,0.456204,0.60219,0.806569,0.4898,0.392055,0.401294,0.401294
3,No log,No log,0.262774,0.521898,0.678832,0.839416,0.262774,0.173966,0.135766,0.083942,0.262774,0.521898,0.678832,0.839416,0.528316,0.431006,0.439234,0.284672,0.521898,0.678832,0.813869,0.284672,0.173966,0.135766,0.081387,0.284672,0.521898,0.678832,0.813869,0.527598,0.437869,0.447762,0.266423,0.532847,0.671533,0.806569,0.266423,0.177616,0.134307,0.080657,0.266423,0.532847,0.671533,0.806569,0.523521,0.433781,0.444187,0.259124,0.50365,0.667883,0.810219,0.259124,0.167883,0.133577,0.081022,0.259124,0.50365,0.667883,0.810219,0.515794,0.423162,0.433227,0.237226,0.489051,0.620438,0.806569,0.237226,0.163017,0.124088,0.080657,0.237226,0.489051,0.620438,0.806569,0.497244,0.40123,0.410793,0.410793
4,6.434100,No log,0.270073,0.532847,0.671533,0.835766,0.270073,0.177616,0.134307,0.083577,0.270073,0.532847,0.671533,0.835766,0.531802,0.436635,0.445464,0.284672,0.532847,0.678832,0.810219,0.284672,0.177616,0.135766,0.081022,0.284672,0.532847,0.678832,0.810219,0.528797,0.440295,0.450699,0.277372,0.525547,0.686131,0.817518,0.277372,0.175182,0.137226,0.081752,0.277372,0.525547,0.686131,0.817518,0.531178,0.440721,0.450048,0.259124,0.518248,0.660584,0.817518,0.259124,0.172749,0.132117,0.081752,0.259124,0.518248,0.660584,0.817518,0.522312,0.429327,0.438864,0.240876,0.514599,0.638686,0.80292,0.240876,0.171533,0.127737,0.080292,0.240876,0.514599,0.638686,0.80292,0.501495,0.407243,0.416929,0.416929
5,6.434100,No log,0.281022,0.529197,0.678832,0.832117,0.281022,0.176399,0.135766,0.083212,0.281022,0.529197,0.678832,0.832117,0.53724,0.444662,0.45395,0.284672,0.543796,0.675182,0.810219,0.284672,0.181265,0.135036,0.081022,0.284672,0.543796,0.675182,0.810219,0.534216,0.446744,0.457341,0.277372,0.536496,0.693431,0.821168,0.277372,0.178832,0.138686,0.082117,0.277372,0.536496,0.693431,0.821168,0.534284,0.443632,0.453033,0.270073,0.540146,0.664234,0.839416,0.270073,0.180049,0.132847,0.083942,0.270073,0.540146,0.664234,0.839416,0.535542,0.440421,0.448563,0.248175,0.5,0.642336,0.806569,0.248175,0.166667,0.128467,0.080657,0.248175,0.5,0.642336,0.806569,0.507738,0.41416,0.423928,0.423928
6,6.434100,No log,0.281022,0.525547,0.671533,0.846715,0.281022,0.175182,0.134307,0.084672,0.281022,0.525547,0.671533,0.846715,0.542679,0.447489,0.455754,0.277372,0.540146,0.686131,0.821168,0.277372,0.180049,0.137226,0.082117,0.277372,0.540146,0.686131,0.821168,0.535865,0.445599,0.455529,0.291971,0.525547,0.689781,0.846715,0.291971,0.175182,0.137956,0.084672,0.291971,0.525547,0.689781,0.846715,0.547898,0.454265,0.46162,0.273723,0.532847,0.675182,0.857664,0.273723,0.177616,0.135036,0.085766,0.273723,0.532847,0.675182,0.857664,0.542865,0.444772,0.451245,0.255474,0.532847,0.656934,0.821168,0.255474,0.177616,0.131387,0.082117,0.255474,0.532847,0.656934,0.821168,0.523366,0.429785,0.438614,0.438614
7,6.434100,No log,0.288321,0.551095,0.686131,0.868613,0.288321,0.183698,0.137226,0.086861,0.288321,0.551095,0.686131,0.868613,0.554134,0.455957,0.462485,0.277372,0.558394,0.70073,0.868613,0.277372,0.186131,0.140146,0.086861,0.277372,0.558394,0.70073,0.868613,0.553679,0.454897,0.461048,0.29562,0.543796,0.69708,0.864964,0.29562,0.181265,0.139416,0.086496,0.29562,0.543796,0.69708,0.864964,0.559564,0.463662,0.469739,0.270073,0.536496,0.69708,0.864964,0.270073,0.178832,0.139416,0.086496,0.270073,0.536496,0.69708,0.864964,0.547177,0.447687,0.454069,0.277372,0.529197,0.675182,0.832117,0.277372,0.176399,0.135036,0.083212,0.277372,0.529197,0.675182,0.832117,0.53846,0.446327,0.454549,0.454549
8,6.434100,No log,0.288321,0.562044,0.718978,0.872263,0.288321,0.187348,0.143796,0.087226,0.288321,0.562044,0.718978,0.872263,0.560079,0.461929,0.46839,0.281022,0.562044,0.711679,0.864964,0.281022,0.187348,0.142336,0.086496,0.281022,0.562044,0.711679,0.864964,0.557566,0.46051,0.467341,0.288321,0.569343,0.722628,0.868613,0.288321,0.189781,0.144526,0.086861,0.288321,0.569343,0.722628,0.868613,0.56396,0.467512,0.473721,0.273723,0.565693,0.708029,0.868613,0.273723,0.188564,0.141606,0.086861,0.273723,0.565693,0.708029,0.868613,0.555014,0.456159,0.462622,0.284672,0.547445,0.693431,0.832117,0.284672,0.182482,0.138686,0.083212,0.284672,0.547445,0.693431,0.832117,0.543135,0.452107,0.460754,0.460754
9,4.654400,No log,0.288321,0.565693,0.729927,0.875912,0.288321,0.188564,0.145985,0.087591,0.288321,0.565693,0.729927,0.875912,0.564778,0.46672,0.47297,0.288321,0.576642,0.726277,0.875912,0.288321,0.192214,0.145255,0.087591,0.288321,0.576642,0.726277,0.875912,0.565191,0.467192,0.47327,0.284672,0.554745,0.740876,0.879562,0.284672,0.184915,0.148175,0.087956,0.284672,0.554745,0.740876,0.879562,0.567635,0.468855,0.474547,0.288321,0.580292,0.740876,0.861314,0.288321,0.193431,0.148175,0.086131,0.288321,0.580292,0.740876,0.861314,0.564152,0.469707,0.477084,0.291971,0.565693,0.70438,0.843066,0.291971,0.188564,0.140876,0.084307,0.291971,0.565693,0.70438,0.843066,0.550197,0.457915,0.465807,0.465807


In [20]:
# Reload the model saved in args.output_dir and score it with the same
# per-dimension evaluators as the baseline.
_eval_device = "cuda" if torch.cuda.is_available() else "cpu"
fine_tuned_model = SentenceTransformer(args.output_dir, device=_eval_device)
results = evaluator(fine_tuned_model)

# Uncomment to inspect every metric, not just NDCG@10:
# print(results)

ndcg_keys = [f"dim_{d}_cosine_ndcg@10" for d in matryoshka_dimensions]
for ndcg_key in ndcg_keys:
    print(f"{ndcg_key}: {results[ndcg_key]}")

dim_768_cosine_ndcg@10: 0.5962774229140696
dim_512_cosine_ndcg@10: 0.598387254630502
dim_256_cosine_ndcg@10: 0.5941885845319166
dim_128_cosine_ndcg@10: 0.6017626832431484
dim_64_cosine_ndcg@10: 0.5941843312040349


In [None]:
def get_openai_embedding(text, model="text-embedding-ada-002", client=None):
    """Return the embedding of a single `text` from the OpenAI embeddings API.

    Args:
        text: The string to embed.
        model: OpenAI embedding model name.
        client: Object exposing `embeddings.create(input=..., model=...)`, e.g.
            `openai.OpenAI()`. Defaults to the module-level `openai` client
            (openai>=1.0), which reads OPENAI_API_KEY from the environment.

    Returns:
        The embedding as a list of floats.
    """
    if client is None:
        # Imported lazily so the notebook does not need `openai` unless this is called.
        import openai
        client = openai
    # openai>=1.0 returns a typed response object (attribute access), not a dict.
    response = client.embeddings.create(input=text, model=model)
    return response.data[0].embedding