<a href="https://colab.research.google.com/github/QiaoLin22/MASTER-LLM-DL/blob/main/FT_Embedding_Models_on_Domain_Specific_Data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install --upgrade sentence-transformers datasets transformers torch tensorboard

In [2]:
import torch

from sentence_transformers import SentenceTransformer, SentenceTransformerModelCardData, SentenceTransformerTrainingArguments, SentenceTransformerTrainer
from sentence_transformers.evaluation import InformationRetrievalEvaluator, SequentialEvaluator
from sentence_transformers.util import cos_sim
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

from datasets import load_dataset, concatenate_datasets

In [3]:
from huggingface_hub import login
from google.colab import userdata

login(token=userdata.get('HF_TOKEN'), add_to_git_credential=True)

In [4]:
# Load dataset from the hub
dataset = load_dataset("AdamLucek/legal-rag-positives-synthetic", split="train")

README.md:   0%|          | 0.00/2.88k [00:00<?, ?B/s]

train.csv:   0%|          | 0.00/4.18M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/6469 [00:00<?, ? examples/s]

In [5]:
# Clean & Format Columns
dataset = dataset.rename_column("question", "anchor")
dataset = dataset.rename_column("text", "positive")
dataset = dataset.remove_columns(["chunk_id", "case_name", "date_filed", "court", "question_id", "answer_location"]) # keep global_chunk_id

# Add an id column to the dataset
dataset = dataset.add_column("id", range(len(dataset)))

In [6]:
# Shuffle Dataset
dataset = dataset.shuffle()

# Split Dataset Into a 90/10 Train/Test split
dataset = dataset.train_test_split(test_size=0.1)

# Save Datasets to Disk
dataset["train"].to_json("train_dataset.json", orient="records")
dataset["test"].to_json("test_dataset.json", orient="records")

Creating json from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

337356

In [7]:
# Hugging Face model ID
model_id = "nomic-ai/modernbert-embed-base"

# Loading via SentenceTransformer
model = SentenceTransformer(
    model_id, device="cuda" if torch.cuda.is_available() else "cpu"
)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/210 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/445k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/54.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.26k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/596M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/20.8k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.58M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

In [8]:

# Load train and test datasets from their respective JSON files
# These contain pairs of questions (anchors) and text chunks (positives)
test_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")

# Combine train and test datasets into a single corpus
# This ensures we have all possible text chunks available for retrieval evaluation
corpus_dataset = concatenate_datasets([train_dataset, test_dataset])

# Convert datasets into dictionary format required by the InformationRetrievalEvaluator
# corpus: maps corpus IDs to their text chunks (documents)
# Format: {corpus_id: text_chunk}
corpus = dict(
    zip(corpus_dataset["id"], corpus_dataset["positive"])
)

# queries: maps query IDs to their questions
# Format: {query_id: question_text}
queries = dict(
    zip(test_dataset["id"], test_dataset["anchor"])
)

# Create a mapping between queries and their relevant documents
# This tells the evaluator which documents are correct matches for each query
relevant_docs = {}
for q_id, global_chunk_id in zip(test_dataset["id"], test_dataset["global_chunk_id"]):
    # Initialize empty list for each query if not already present
    if q_id not in relevant_docs:
        relevant_docs[q_id] = []

    # Find all corpus entries that share the same global_chunk_id
    # This handles cases where multiple questions can refer to the same text chunk
    matching_corpus_ids = [
        cid for cid, chunk in zip(corpus_dataset["id"], corpus_dataset["global_chunk_id"])
        if chunk == global_chunk_id
    ]
    # Add the matching corpus IDs to the relevant documents for this query
    relevant_docs[q_id].extend(matching_corpus_ids)

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [10]:
# Dimensions of interest
matryoshka_dimensions = [768, 512, 256, 128, 64] # Important: large to small

# Create empty list to hold evaluators
matryoshka_evaluators = []

# Create an evaluator for each above dimension
for dim in matryoshka_dimensions:
    # Define the evaluator
    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,  # Truncate the embeddings to the respective dimension
        score_functions={"cosine": cos_sim},
    )
    # Add to list
    matryoshka_evaluators.append(ir_evaluator)

# Create a sequential evaluator
# Able to run all our dimension specific InformationRetrievalEvaluators sequentially.
evaluator = SequentialEvaluator(matryoshka_evaluators)

In [11]:
# Evaluate the model
base_results = evaluator(model)

# Print header
print("\nBase Model Evaluation Results")
print("-" * 85)
print(f"{'Metric':15} {'768d':>12} {'512d':>12} {'256d':>12} {'128d':>12} {'64d':>12}")
print("-" * 85)

# List of metrics to display
metrics = [
    'ndcg@10',
    'mrr@10',
    'map@100',
    'accuracy@1',
    'accuracy@3',
    'accuracy@5',
    'accuracy@10',
    'precision@1',
    'precision@3',
    'precision@5',
    'precision@10',
    'recall@1',
    'recall@3',
    'recall@5',
    'recall@10'
]

# Print each metric
for metric in metrics:
    values = []
    for dim in matryoshka_dimensions:
        key = f"dim_{dim}_cosine_{metric}"
        values.append(base_results[key])

    # Highlight NDCG@10
    metric_name = f"=={metric}==" if metric == "ndcg@10" else metric
    print(f"{metric_name:15}", end="  ")
    for val in values:
        print(f"{val:12.4f}", end=" ")
    print()

# Print sequential score
print("-" * 85)
print(f"{'seq_score:'} {base_results['sequential_score']:1f}")




Base Model Evaluation Results
-------------------------------------------------------------------------------------
Metric                  768d         512d         256d         128d          64d
-------------------------------------------------------------------------------------
==ndcg@10==            0.3884       0.3779       0.3628       0.3264       0.2512 
mrr@10                 0.3325       0.3221       0.3074       0.2713       0.2103 
map@100                0.3790       0.3704       0.3543       0.3151       0.2491 
accuracy@1             0.2875       0.2751       0.2628       0.2303       0.1762 
accuracy@3             0.3199       0.3199       0.3076       0.2658       0.2056 
accuracy@5             0.4096       0.4019       0.3849       0.3400       0.2689 
accuracy@10            0.5054       0.4915       0.4745       0.4281       0.3369 
precision@1            0.2875       0.2751       0.2628       0.2303       0.1762 
precision@3            0.2808       0.2715       0.2

In [12]:
# load model with SDPA for using Flash Attention 2
model = SentenceTransformer(
    model_id,
    model_kwargs={"attn_implementation": "sdpa"},
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="ModernBERT Embed base Legal Matryoshka",
    ),
)

In [13]:
# Initial Loss
base_loss = MultipleNegativesRankingLoss(model)

# Matryoshka Loss Wrapper
train_loss = MatryoshkaLoss(
    model, base_loss, matryoshka_dims=matryoshka_dimensions
)

In [15]:
# Training Arguments
args = SentenceTransformerTrainingArguments(
    output_dir="modernbert-embed-base-legal-matryoshka-qlin", # output directory and hugging face model ID
    num_train_epochs=4,                                        # number of epochs
    per_device_train_batch_size=32,                            # train batch size
    gradient_accumulation_steps=16,                            # for a global batch size of 512
    per_device_eval_batch_size=16,                             # evaluation batch size
    warmup_ratio=0.1,                                          # warmup ratio
    learning_rate=2e-5,                                        # learning rate, 2e-5 is a good value
    lr_scheduler_type="cosine",                                # use cosine learning rate scheduler
    optim="adamw_torch_fused",                                 # use fused adamw optimizer
    tf32=True,                                                 # use tf32 precision
    bf16=True,                                                 # use bf16 precision
    batch_sampler=BatchSamplers.NO_DUPLICATES,                 # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    eval_strategy="epoch",                                     # evaluate after each epoch
    save_strategy="epoch",                                     # save after each epoch
    logging_steps=10,                                          # log every 10 steps
    save_total_limit=3,                                        # save only the last 3 models
    load_best_model_at_end=True,                               # load the best model when training ends
    metric_for_best_model="eval_dim_128_cosine_ndcg@10",       # Optimizing for the best ndcg@10 score for the 128 dimension
    report_to="none"                                           # Turning off training logging for now, input 'wandb' etc. if desired.
)

In [16]:
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.select_columns(
        ["positive", "anchor"]
    ),  # training dataset
    loss=train_loss,
    evaluator=evaluator,
)

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [17]:
# Start training
trainer.train()

# Save the best model based on our eval_dim_128_cosine_ndcg@10 criteria
trainer.save_model()

dataset = dataset.select_columns(['anchor', 'positive', 'negative'])


Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
1,90.1248,No log,0.48068,0.510046,0.61051,0.695518,0.48068,0.452859,0.352396,0.21592,0.169243,0.441783,0.559505,0.68135,0.584133,0.525708,0.568138,0.463679,0.499227,0.595054,0.700155,0.463679,0.441525,0.345286,0.217465,0.162416,0.428259,0.546883,0.687146,0.578587,0.513868,0.556685,0.420402,0.460587,0.556414,0.644513,0.420402,0.40237,0.320556,0.198764,0.147347,0.39168,0.510819,0.625837,0.528261,0.469121,0.51429,0.352396,0.391036,0.462133,0.554869,0.352396,0.334364,0.263679,0.17187,0.126224,0.331273,0.4195,0.540572,0.450204,0.394557,0.440507,0.264297,0.295209,0.36476,0.454405,0.264297,0.252962,0.209274,0.137713,0.093251,0.247553,0.335523,0.439593,0.354084,0.303995,0.34398,0.354084
2,39.4375,No log,0.503864,0.544049,0.656878,0.727975,0.503864,0.477589,0.374652,0.227821,0.177743,0.469346,0.597501,0.718315,0.616957,0.554135,0.598794,0.50541,0.547141,0.642968,0.717156,0.50541,0.479134,0.372798,0.223029,0.178001,0.46883,0.593122,0.702602,0.609663,0.552323,0.59545,0.459042,0.499227,0.5966,0.676971,0.459042,0.438434,0.345595,0.209274,0.160355,0.427743,0.547527,0.662159,0.565286,0.507164,0.551563,0.409583,0.446677,0.516229,0.5966,0.409583,0.386914,0.301391,0.184853,0.146445,0.379186,0.474369,0.582561,0.500142,0.449534,0.494156,0.321484,0.35085,0.429675,0.499227,0.321484,0.303452,0.244513,0.153632,0.114245,0.297656,0.390649,0.485961,0.407152,0.359317,0.39958,0.407152
3,24.1188,No log,0.51932,0.55796,0.661515,0.738794,0.51932,0.493045,0.382689,0.230912,0.182251,0.480809,0.609093,0.730551,0.629337,0.567947,0.610524,0.514683,0.554869,0.655332,0.731066,0.514683,0.487893,0.379598,0.225966,0.181607,0.477331,0.605873,0.716383,0.6205,0.562632,0.604529,0.476043,0.513138,0.615147,0.678516,0.476043,0.451829,0.356414,0.210046,0.166924,0.440108,0.564786,0.664863,0.575198,0.521251,0.564764,0.417311,0.460587,0.53323,0.607419,0.417311,0.395672,0.310665,0.18779,0.14915,0.388846,0.491628,0.593509,0.510438,0.45922,0.504006,0.332303,0.36476,0.452859,0.510046,0.332303,0.315817,0.257187,0.158423,0.116821,0.308733,0.410355,0.500258,0.42059,0.371699,0.412323,0.42059


In [18]:
# Upload model to hub
trainer.model.push_to_hub("modernbert-embed-base-legal-matryoshka-qlin")

model.safetensors:   0%|          | 0.00/596M [00:00<?, ?B/s]

'https://huggingface.co/stardriver007/modernbert-embed-base-legal-matryoshka-qlin/commit/b260efc8e423e2df80597fb141e56dbe459a5ae3'

Evaluating Trained Model

In [19]:
fine_tuned_model = SentenceTransformer(
    args.output_dir, device="cuda" if torch.cuda.is_available() else "cpu"
)

# Evaluate the model
ft_results = evaluator(fine_tuned_model)

# Print header
print("Fine Tuned Model Evaluation Results")
print("-" * 85)
print(f"{'Metric':15} {'768d':>12} {'512d':>12} {'256d':>12} {'128d':>12} {'64d':>12}")
print("-" * 85)

# List of metrics to display
metrics = [
    'ndcg@10',
    'mrr@10',
    'map@100',
    'accuracy@1',
    'accuracy@3',
    'accuracy@5',
    'accuracy@10',
    'precision@1',
    'precision@3',
    'precision@5',
    'precision@10',
    'recall@1',
    'recall@3',
    'recall@5',
    'recall@10'
]

# Print each metric
for metric in metrics:
    values = []
    for dim in matryoshka_dimensions:
        key = f"dim_{dim}_cosine_{metric}"
        values.append(ft_results[key])

    # Highlight NDCG@10
    metric_name = f"=={metric}==" if metric == "ndcg@10" else metric
    print(f"{metric_name:15}", end="  ")
    for val in values:
        print(f"{val:12.4f}", end=" ")
    print()

# Print sequential score
print("-" * 85)
print(f"{'seq_score:'} {ft_results['sequential_score']:1f}")

Fine Tuned Model Evaluation Results
-------------------------------------------------------------------------------------
Metric                  768d         512d         256d         128d          64d
-------------------------------------------------------------------------------------
==ndcg@10==            0.6266       0.6191       0.5742       0.5105       0.4183 
mrr@10                 0.5655       0.5602       0.5198       0.4597       0.3711 
map@100                0.6090       0.6027       0.5633       0.5043       0.4110 
accuracy@1             0.5162       0.5116       0.4745       0.4189       0.3338 
accuracy@3             0.5611       0.5533       0.5116       0.4606       0.3632 
accuracy@5             0.6615       0.6538       0.6105       0.5301       0.4467 
accuracy@10            0.7326       0.7295       0.6785       0.6074       0.5070 
precision@1            0.5162       0.5116       0.4745       0.4189       0.3338 
precision@3            0.4920       0.4848     

In [20]:
%%capture
!pip install --upgrade sentence-transformers
!pip install git+https://github.com/huggingface/transformers

In [21]:
from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("stardriver007/modernbert-embed-base-legal-matryoshka-qlin", truncate_dim=256)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/205 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/31.4k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/54.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.24k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/596M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/20.8k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.58M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

In [22]:
# Run inference
sentences = [
    'Which organization is Carmody Gaba Daman associated with?',
    'Assistant General Counsel, U.S. General Services Administration, Washington, D.C.; Carmody Gaba Daman, Assistant General Counsel, U.S. General Services Administration, Washington, D.C.; Michael Blumenthal, Trial Attorney, U.S. Small Business Administration, Office of General Counsel, Washington, D.C. MEMORANDUM AND ORDER', # Corresponding Positive
    'certain Solicitation requirements violate federal procurement statutes and agency regulations governing procurements involving small business offerors. See generally SHS MJAR at 14; VCH MJAR at 14. Having considered the parties’ arguments, applicable law, and the Administrative Record, this Court GRANTS in part and DENIES in part Plaintiffs’ Motions for Judgment on the', # Random Excerpt
]

embeddings = model.encode(sentences)
print(embeddings.shape)

(3, 256)


In [23]:
# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings, embeddings)
print(similarities[0])

tensor([1.0000, 0.5927, 0.0143])
