In [17]:
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim
from datasets import load_dataset, concatenate_datasets
from sentence_transformers import SentenceTransformerModelCardData, SentenceTransformer
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers import SentenceTransformerTrainer

In [18]:
df = pd.read_csv("./input/train-data/train.csv")
df.rename(columns={"Unnamed: 0":"id","question":"anchor","answer":"positive"},inplace=True)
df.drop("text",axis=1,inplace=True)
df.head()

Unnamed: 0,id,anchor,positive
0,0,What significant challenges does the rapid exp...,"Balancing scalability and security, computatio..."
1,1,How does the proposed framework address the in...,By employing edge aggregating servers and Ethe...
2,2,What are the primary benefits of using blockch...,"Data integrity, device authentication, and pro..."
3,3,Why are traditional blockchain-based solutions...,"Due to scalability, cost issues, and computati..."
4,4,How does the proposed framework ensure data pr...,Through the use of Zero-Knowledge Proofs (ZKPs...


In [19]:
train, test = train_test_split(df, test_size=0.1, random_state=42)

train.to_json("train_dataset.json", orient="records")
test.to_json("test_dataset.json", orient="records")

In [20]:
model_id = "BAAI/bge-base-en-v1.5" 
matryoshka_dimensions = [768, 512, 256, 128, 64]

In [21]:
model = SentenceTransformer(model_id, device="cuda" if torch.cuda.is_available() else "cpu")

In [22]:
test_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")
combined_dataset = concatenate_datasets([train_dataset, test_dataset])

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [23]:
for i in train_dataset:
    print(i)
    break

{'id': 622, 'anchor': 'What is the significance of the study period chosen for the analysis?', 'positive': 'The study period captures the initial impact of COVID-19 on global stock markets.'}


In [24]:
corpus = dict(zip(combined_dataset["id"], combined_dataset["positive"])) 

queries = dict(zip(test_dataset["id"], test_dataset["anchor"])) 

In [25]:
for i in corpus.items():
    print(i)
    break
for i in queries.items():
    print(i)
    break

(622, 'The study period captures the initial impact of COVID-19 on global stock markets.')
(192, 'How do adversarial examples challenge the reliability of deep learning models?')


In [26]:
relevant_docs = {} #  Query ID to relevant documents (qid => set([relevant_cids])
for q_id in queries:
    relevant_docs[q_id] = [q_id]

In [27]:
for i in relevant_docs.items():
    print(i)
    break

(192, [192])


In [28]:
'''Given a set of queries and a large corpus set. It will retrieve for each query the top-k most similar document. 
It measures Mean Reciprocal Rank (MRR), Recall@k, and Normalized Discounted Cumulative Gain (NDCG)
https://sbert.net/docs/package_reference/sentence_transformer/evaluation.html'''

matryoshka_evaluators = []

for dim in matryoshka_dimensions:
    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,  # Truncate the embeddings to a certain dimension
        score_functions={"cosine": cos_sim},
    )
    matryoshka_evaluators.append(ir_evaluator)

matryoshka_evaluators

[<sentence_transformers.evaluation.InformationRetrievalEvaluator.InformationRetrievalEvaluator at 0x7c5b540f1e70>,
 <sentence_transformers.evaluation.InformationRetrievalEvaluator.InformationRetrievalEvaluator at 0x7c5b540f10f0>,
 <sentence_transformers.evaluation.InformationRetrievalEvaluator.InformationRetrievalEvaluator at 0x7c5b540f2a40>,
 <sentence_transformers.evaluation.InformationRetrievalEvaluator.InformationRetrievalEvaluator at 0x7c5b540f1150>,
 <sentence_transformers.evaluation.InformationRetrievalEvaluator.InformationRetrievalEvaluator at 0x7c5b540f0fa0>]

In [29]:
evaluator = SequentialEvaluator(matryoshka_evaluators)
evaluator

<sentence_transformers.evaluation.SequentialEvaluator.SequentialEvaluator at 0x7c5b540f33d0>

In [30]:
results = evaluator(model)

results

{'dim_768_cosine_accuracy@1': 0.19753086419753085,
 'dim_768_cosine_accuracy@3': 0.32098765432098764,
 'dim_768_cosine_accuracy@5': 0.38271604938271603,
 'dim_768_cosine_accuracy@10': 0.4691358024691358,
 'dim_768_cosine_precision@1': 0.19753086419753085,
 'dim_768_cosine_precision@3': 0.10699588477366255,
 'dim_768_cosine_precision@5': 0.07654320987654321,
 'dim_768_cosine_precision@10': 0.04691358024691358,
 'dim_768_cosine_recall@1': 0.19753086419753085,
 'dim_768_cosine_recall@3': 0.32098765432098764,
 'dim_768_cosine_recall@5': 0.38271604938271603,
 'dim_768_cosine_recall@10': 0.4691358024691358,
 'dim_768_cosine_ndcg@10': 0.3220467006858192,
 'dim_768_cosine_mrr@10': 0.2764060356652949,
 'dim_768_cosine_map@100': 0.29134803459402797,
 'dim_512_cosine_accuracy@1': 0.1728395061728395,
 'dim_512_cosine_accuracy@3': 0.32098765432098764,
 'dim_512_cosine_accuracy@5': 0.35802469135802467,
 'dim_512_cosine_accuracy@10': 0.43209876543209874,
 'dim_512_cosine_precision@1': 0.1728395061728

In [31]:
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print
    print(f"{key}: {results[key]}")

dim_768_cosine_ndcg@10: 0.3220467006858192
dim_512_cosine_ndcg@10: 0.2997438851350533
dim_256_cosine_ndcg@10: 0.28755556990920456
dim_128_cosine_ndcg@10: 0.28250088889476005
dim_64_cosine_ndcg@10: 0.23287559095217322


In [32]:
model_id = "BAAI/bge-base-en-v1.5"

model = SentenceTransformer(
    model_id,
    model_kwargs={"attn_implementation": "sdpa"},
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="BGE base blockchain Matryoshka",
    ),
).to("cuda")

In [33]:
matryoshka_dimensions = [768, 512, 256, 128, 64]  

inner_train_loss = MultipleNegativesRankingLoss(model)

train_loss = MatryoshkaLoss(
                        model, 
                        inner_train_loss,
                        matryoshka_dims=matryoshka_dimensions
                        )

In [34]:
train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")

# define training arguments
args = SentenceTransformerTrainingArguments(
    output_dir="bge-base-blockchain-matryoshka", 
    num_train_epochs=4,                         
    per_device_train_batch_size=16,             # train batch size
    gradient_accumulation_steps=8,             # for a global batch size of 512
    per_device_eval_batch_size=8,              # evaluation batch size
    warmup_ratio=0.1,                           # warmup ratio
    learning_rate=2e-5,                         # learning rate, 2e-5 is a good value
    lr_scheduler_type="cosine",                 # use constant learning rate scheduler
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    tf32= False,                                  # use tf32 precision
    bf16=True,                                  # use bf16 precision
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    eval_strategy="epoch",                      # evaluate after each epoch
    save_strategy="epoch",                      # save after each epoch
    logging_steps=10,                           # log every 10 steps
    save_total_limit=3,                         # save only the last 3 models
    load_best_model_at_end=True,                # load the best model when training ends
    metric_for_best_model="eval_dim_128_cosine_ndcg@10",  # Optimizing for the best ndcg@10 score for the 128 dimension
    report_to="none")

In [35]:
trainer = SentenceTransformerTrainer(
    model=model, 
    args=args,  
    train_dataset=train_dataset.select_columns(["anchor", "positive"]),  
    loss=train_loss,
    evaluator=evaluator,
)

In [36]:
trainer.train()

Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
0,No log,No log,0.222222,0.37037,0.45679,0.592593,0.222222,0.123457,0.091358,0.059259,0.222222,0.37037,0.45679,0.592593,0.385446,0.321786,0.335582,0.222222,0.358025,0.432099,0.567901,0.222222,0.119342,0.08642,0.05679,0.222222,0.358025,0.432099,0.567901,0.379629,0.321669,0.336368,0.222222,0.345679,0.407407,0.580247,0.222222,0.115226,0.081481,0.058025,0.222222,0.345679,0.407407,0.580247,0.378679,0.317519,0.331103,0.185185,0.320988,0.407407,0.54321,0.185185,0.106996,0.081481,0.054321,0.185185,0.320988,0.407407,0.54321,0.33882,0.276705,0.289858,0.148148,0.259259,0.283951,0.469136,0.148148,0.08642,0.05679,0.046914,0.148148,0.259259,0.283951,0.469136,0.283907,0.228762,0.242168,0.283907
1,5.403100,No log,0.222222,0.419753,0.506173,0.679012,0.222222,0.139918,0.101235,0.067901,0.222222,0.419753,0.506173,0.679012,0.423446,0.345277,0.354727,0.222222,0.395062,0.506173,0.641975,0.222222,0.131687,0.101235,0.064198,0.222222,0.395062,0.506173,0.641975,0.415034,0.344459,0.356421,0.246914,0.358025,0.45679,0.641975,0.246914,0.119342,0.091358,0.064198,0.246914,0.358025,0.45679,0.641975,0.41069,0.341402,0.353561,0.197531,0.345679,0.45679,0.567901,0.197531,0.115226,0.091358,0.05679,0.197531,0.345679,0.45679,0.567901,0.365083,0.301778,0.317183,0.160494,0.296296,0.382716,0.518519,0.160494,0.098765,0.076543,0.051852,0.160494,0.296296,0.382716,0.518519,0.320024,0.259289,0.273385,0.320024
2,5.403100,No log,0.234568,0.444444,0.506173,0.703704,0.234568,0.148148,0.101235,0.07037,0.234568,0.444444,0.506173,0.703704,0.442395,0.362542,0.370762,0.222222,0.407407,0.518519,0.654321,0.222222,0.135802,0.103704,0.065432,0.222222,0.407407,0.518519,0.654321,0.423616,0.351778,0.363392,0.234568,0.382716,0.469136,0.679012,0.234568,0.127572,0.093827,0.067901,0.234568,0.382716,0.469136,0.679012,0.426578,0.350098,0.360351,0.222222,0.382716,0.469136,0.567901,0.222222,0.127572,0.093827,0.05679,0.222222,0.382716,0.469136,0.567901,0.385192,0.327871,0.344512,0.160494,0.308642,0.419753,0.555556,0.160494,0.102881,0.083951,0.055556,0.160494,0.308642,0.419753,0.555556,0.336319,0.269028,0.282068,0.336319
3,3.573800,No log,0.246914,0.444444,0.518519,0.716049,0.246914,0.148148,0.103704,0.071605,0.246914,0.444444,0.518519,0.716049,0.455199,0.375578,0.382917,0.222222,0.407407,0.518519,0.666667,0.222222,0.135802,0.103704,0.066667,0.222222,0.407407,0.518519,0.666667,0.428979,0.355242,0.365873,0.234568,0.382716,0.469136,0.691358,0.234568,0.127572,0.093827,0.069136,0.234568,0.382716,0.469136,0.691358,0.429834,0.351004,0.360267,0.222222,0.395062,0.481481,0.567901,0.222222,0.131687,0.096296,0.05679,0.222222,0.395062,0.481481,0.567901,0.386709,0.329605,0.347031,0.160494,0.308642,0.432099,0.555556,0.160494,0.102881,0.08642,0.055556,0.160494,0.308642,0.432099,0.555556,0.338634,0.271894,0.285165,0.338634


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

TrainOutput(global_step=20, training_loss=4.488420486450195, metrics={'train_runtime': 65.6946, 'train_samples_per_second': 44.022, 'train_steps_per_second': 0.304, 'total_flos': 0.0, 'train_loss': 4.488420486450195, 'epoch': 3.869565217391304})

In [37]:
trainer.save_model()

In [38]:
from sentence_transformers import SentenceTransformer

fine_tuned_model = SentenceTransformer(args.output_dir, device="cuda")

results = evaluator(fine_tuned_model)

for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    print(f"{key}: {results[key]}")

dim_768_cosine_ndcg@10: 0.45347804212937676
dim_512_cosine_ndcg@10: 0.42930530454496996
dim_256_cosine_ndcg@10: 0.4295385601742736
dim_128_cosine_ndcg@10: 0.3863302926159187
dim_64_cosine_ndcg@10: 0.3387819186395353


In [40]:
'''
Dimension	Baseline	Fine-tuned	Improvement
768	        0.322	    0.453	   40.68 %
512	        0.299	    0.429	   43.47 %
256      	0.287	    0.429	   49.47 %
128	        0.282	    0.386	   36.88 %
64	        0.232	    0.338	   45.68 %
'''

'\nDimension\tBaseline\tFine-tuned\tImprovement\n768\t        0.322\t    0.453\t   40.68 %\n512\t        0.299\t    0.429\t   43.47 %\n256      \t0.287\t    0.429\t   49.47 %\n128\t        0.282\t    0.386\t   36.88 %\n64\t        0.232\t    0.338\t   45.68 %\n'