In [None]:
%%capture
!pip install --upgrade sentence-transformers datasets transformers torch tensorboard

In [None]:
import torch

from sentence_transformers import SentenceTransformer, SentenceTransformerModelCardData, SentenceTransformerTrainingArguments, SentenceTransformerTrainer
from sentence_transformers.evaluation import InformationRetrievalEvaluator, SequentialEvaluator
from sentence_transformers.util import cos_sim
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

from datasets import load_dataset, concatenate_datasets

In [None]:
from huggingface_hub import login
from google.colab import userdata

login(token=userdata.get('HF_TOKEN'), add_to_git_credential=True)

In [None]:
dataset = load_dataset("AdamLucek/legal-rag-positives-synthetic", split="train")

README.md:   0%|          | 0.00/2.88k [00:00<?, ?B/s]

train.csv:   0%|          | 0.00/4.18M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/6469 [00:00<?, ? examples/s]

In [None]:
dataset = dataset.rename_column("question", "anchor")
dataset = dataset.rename_column("text", "positive")
dataset = dataset.remove_columns(["chunk_id", "case_name", "date_filed", "court", "question_id", "answer_location"])


dataset = dataset.add_column("id", range(len(dataset)))

In [None]:
dataset = dataset.shuffle()


dataset = dataset.train_test_split(test_size=0.1)


dataset["train"].to_json("train_dataset.json", orient="records")
dataset["test"].to_json("test_dataset.json", orient="records")

Creating json from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

339016

In [None]:
model_id = "nomic-ai/modernbert-embed-base"


model = SentenceTransformer(
    model_id, device="cuda" if torch.cuda.is_available() else "cpu"
)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/210 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/445k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/54.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.26k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/596M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/20.8k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.58M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

In [None]:
test_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
train_dataset = load_dataset("json", data_files="train_dataset.json", split="train")


corpus_dataset = concatenate_datasets([train_dataset, test_dataset])


corpus = dict(
    zip(corpus_dataset["id"], corpus_dataset["positive"])
)


queries = dict(
    zip(test_dataset["id"], test_dataset["anchor"])
)


relevant_docs = {}
for q_id, global_chunk_id in zip(test_dataset["id"], test_dataset["global_chunk_id"]):

    if q_id not in relevant_docs:
        relevant_docs[q_id] = []


    matching_corpus_ids = [
        cid for cid, chunk in zip(corpus_dataset["id"], corpus_dataset["global_chunk_id"])
        if chunk == global_chunk_id
    ]

    relevant_docs[q_id].extend(matching_corpus_ids)

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
matryoshka_dimensions = [768, 512, 256, 128, 64]


matryoshka_evaluators = []


for dim in matryoshka_dimensions:

    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,
        score_functions={"cosine": cos_sim},
    )

    matryoshka_evaluators.append(ir_evaluator)


evaluator = SequentialEvaluator(matryoshka_evaluators)

In [None]:
base_results = evaluator(model)


print("\nBase Model Evaluation Results")
print("-" * 85)
print(f"{'Metric':15} {'768d':>12} {'512d':>12} {'256d':>12} {'128d':>12} {'64d':>12}")
print("-" * 85)


metrics = [
    'ndcg@10',
    'mrr@10',
    'map@100',
    'accuracy@1',
    'accuracy@3',
    'accuracy@5',
    'accuracy@10',
    'precision@1',
    'precision@3',
    'precision@5',
    'precision@10',
    'recall@1',
    'recall@3',
    'recall@5',
    'recall@10'
]


for metric in metrics:
    values = []
    for dim in matryoshka_dimensions:
        key = f"dim_{dim}_cosine_{metric}"
        values.append(base_results[key])


    metric_name = f"=={metric}==" if metric == "ndcg@10" else metric
    print(f"{metric_name:15}", end="  ")
    for val in values:
        print(f"{val:12.4f}", end=" ")
    print()


print("-" * 85)
print(f"{'seq_score:'} {base_results['sequential_score']:1f}")

W0328 10:17:40.267000 605 torch/_inductor/utils.py:1137] [1/0] Not enough SMs to use max_autotune_gemm mode



Base Model Evaluation Results
-------------------------------------------------------------------------------------
Metric                  768d         512d         256d         128d          64d
-------------------------------------------------------------------------------------
==ndcg@10==            0.4482       0.4352       0.4161       0.3805       0.3025 
mrr@10                 0.3997       0.3865       0.3675       0.3344       0.2622 
map@100                0.4458       0.4318       0.4121       0.3746       0.3001 
accuracy@1             0.3586       0.3462       0.3277       0.2998       0.2303 
accuracy@3             0.3988       0.3849       0.3617       0.3184       0.2566 
accuracy@5             0.4652       0.4451       0.4405       0.3895       0.3199 
accuracy@10            0.5379       0.5348       0.5131       0.4760       0.3849 
precision@1            0.3586       0.3462       0.3277       0.2998       0.2303 
precision@3            0.3503       0.3354       0.3

In [None]:
model = SentenceTransformer(
    model_id,
    model_kwargs={"attn_implementation": "sdpa"},
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="ModernBERT Embed base Legal Matryoshka",
    ),
)

In [None]:
base_loss = MultipleNegativesRankingLoss(model)


train_loss = MatryoshkaLoss(
    model, base_loss, matryoshka_dims=matryoshka_dimensions
)

In [None]:
args = SentenceTransformerTrainingArguments(
    output_dir="modernbert-embed-base-legal-matryoshka-lucek",
    num_train_epochs=4,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=16,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    optim="adamw_torch_fused",
    tf32=False,
    bf16=True,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    save_strategy="epoch",
    logging_steps=10,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_dim_128_cosine_ndcg@10",
    report_to="none"
)

In [None]:
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.select_columns(
        ["positive", "anchor"]
    ),
    loss=train_loss,
    evaluator=evaluator,
)

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [None]:
trainer.train()


trainer.save_model()

dataset = dataset.select_columns(['anchor', 'positive', 'negative'])


Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
1,91.6964,No log,0.554869,0.598145,0.673879,0.748068,0.554869,0.537352,0.405255,0.234467,0.186631,0.515714,0.633308,0.732226,0.648303,0.597782,0.636885,0.550232,0.593509,0.670788,0.74034,0.550232,0.53323,0.404328,0.234003,0.184441,0.509145,0.629057,0.72849,0.64446,0.593166,0.632197,0.516229,0.539413,0.613601,0.698609,0.516229,0.489438,0.36847,0.218238,0.175554,0.47102,0.577924,0.679031,0.600399,0.55308,0.591522,0.426584,0.465224,0.556414,0.619784,0.426584,0.411128,0.325502,0.19459,0.145415,0.396574,0.512622,0.608063,0.523197,0.47003,0.513547,0.316847,0.346213,0.414219,0.489954,0.316847,0.306543,0.244513,0.152705,0.107161,0.294693,0.383823,0.475528,0.400133,0.354134,0.396381,0.400133
2,39.6429,No log,0.587326,0.629057,0.703246,0.772798,0.587326,0.562597,0.42442,0.242349,0.198867,0.539928,0.662416,0.755796,0.676391,0.627796,0.666361,0.584235,0.621329,0.690881,0.768161,0.584235,0.55899,0.418547,0.241577,0.197965,0.537996,0.653143,0.749614,0.67156,0.623446,0.661821,0.554869,0.587326,0.658423,0.723338,0.554869,0.533745,0.400618,0.228594,0.184441,0.50747,0.625322,0.711618,0.63608,0.591197,0.628391,0.477589,0.506955,0.602782,0.670788,0.477589,0.45492,0.354869,0.212828,0.162674,0.438305,0.553709,0.65662,0.573582,0.519917,0.561577,0.360124,0.392581,0.459042,0.522411,0.360124,0.347759,0.272643,0.162597,0.121458,0.333205,0.426584,0.506698,0.437446,0.395217,0.438222,0.437446
3,26.8879,No log,0.591963,0.63524,0.703246,0.766615,0.591963,0.568264,0.426275,0.240804,0.201185,0.547012,0.666409,0.750773,0.677417,0.63165,0.670654,0.585781,0.616692,0.690881,0.766615,0.585781,0.557445,0.41762,0.241731,0.198351,0.535291,0.651468,0.751803,0.672159,0.623629,0.662036,0.567233,0.587326,0.664606,0.731066,0.567233,0.538382,0.400927,0.230757,0.190623,0.515198,0.626352,0.720505,0.645374,0.600886,0.637734,0.499227,0.530139,0.613601,0.678516,0.499227,0.474498,0.365379,0.21592,0.1695,0.45814,0.570582,0.668341,0.589191,0.538561,0.578291,0.363215,0.401855,0.472952,0.527048,0.363215,0.351365,0.278207,0.16507,0.123648,0.339129,0.436373,0.514297,0.444439,0.400301,0.446199,0.444439


In [None]:
trainer.model.push_to_hub("modernbert-embed-base-legal-matryoshka-2")

model.safetensors:   0%|          | 0.00/596M [00:00<?, ?B/s]

'https://huggingface.co/manishh16/modernbert-embed-base-legal-matryoshka-2/commit/0c85e900c59e1ebdf68ce3cce59e478db6c52671'

In [None]:
fine_tuned_model = SentenceTransformer(
    args.output_dir, device="cuda" if torch.cuda.is_available() else "cpu"
)


ft_results = evaluator(fine_tuned_model)


print("Fine Tuned Model Evaluation Results")
print("-" * 85)
print(f"{'Metric':15} {'768d':>12} {'512d':>12} {'256d':>12} {'128d':>12} {'64d':>12}")
print("-" * 85)


metrics = [
    'ndcg@10',
    'mrr@10',
    'map@100',
    'accuracy@1',
    'accuracy@3',
    'accuracy@5',
    'accuracy@10',
    'precision@1',
    'precision@3',
    'precision@5',
    'precision@10',
    'recall@1',
    'recall@3',
    'recall@5',
    'recall@10'
]


for metric in metrics:
    values = []
    for dim in matryoshka_dimensions:
        key = f"dim_{dim}_cosine_{metric}"
        values.append(ft_results[key])


    metric_name = f"=={metric}==" if metric == "ndcg@10" else metric
    print(f"{metric_name:15}", end="  ")
    for val in values:
        print(f"{val:12.4f}", end=" ")
    print()


print("-" * 85)
print(f"{'seq_score:'} {ft_results['sequential_score']:1f}")

Fine Tuned Model Evaluation Results
-------------------------------------------------------------------------------------
Metric                  768d         512d         256d         128d          64d
-------------------------------------------------------------------------------------
==ndcg@10==            0.6774       0.6725       0.6451       0.5897       0.4442 
mrr@10                 0.6306       0.6231       0.6012       0.5382       0.4000 
map@100                0.6700       0.6616       0.6380       0.5784       0.4455 
accuracy@1             0.5904       0.5842       0.5672       0.4977       0.3632 
accuracy@3             0.6352       0.6167       0.5889       0.5301       0.4003 
accuracy@5             0.7048       0.6924       0.6662       0.6151       0.4714 
accuracy@10            0.7666       0.7682       0.7311       0.6801       0.5270 
precision@1            0.5904       0.5842       0.5672       0.4977       0.3632 
precision@3            0.5667       0.5569     

In [1]:
%%capture
!pip install --upgrade sentence-transformers
!pip install git+https://github.com/huggingface/transformers

In [3]:
from sentence_transformers import SentenceTransformer


model = SentenceTransformer("manishh16/modernbert-embed-base-legal-matryoshka-2", truncate_dim=256)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/205 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/31.6k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/54.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.24k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/596M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/20.8k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.58M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

In [4]:
sentences = [
    'Which organization is Carmody Gaba Daman associated with?',
    'Assistant General Counsel, U.S. General Services Administration, Washington, D.C.; Carmody Gaba Daman, Assistant General Counsel, U.S. General Services Administration, Washington, D.C.; Michael Blumenthal, Trial Attorney, U.S. Small Business Administration, Office of General Counsel, Washington, D.C. MEMORANDUM AND ORDER', # Corresponding Positive
    'certain Solicitation requirements violate federal procurement statutes and agency regulations governing procurements involving small business offerors. See generally SHS MJAR at 14; VCH MJAR at 14. Having considered the parties’ arguments, applicable law, and the Administrative Record, this Court GRANTS in part and DENIES in part Plaintiffs’ Motions for Judgment on the', # Random Excerpt
]

embeddings = model.encode(sentences)
print(embeddings.shape)

W0328 12:02:12.413000 1886 torch/_inductor/utils.py:1137] [1/0] Not enough SMs to use max_autotune_gemm mode


(3, 256)


In [5]:
similarities = model.similarity(embeddings, embeddings)
print(similarities[0])

tensor([1.0000, 0.5956, 0.0074])
