In [1]:
from sentence_transformers import SentenceTransformer, models

word_model = models.Transformer('answerdotai/ModernBERT-base', max_seq_length=256, model_args={'attn_implementation': 'sdpa', 'dtype': 'bfloat16', 'device_map': 'auto'})
pooling = models.Pooling(word_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_model, pooling])
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False, 'architecture': 'ModernBertModel'})
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

In [2]:
NUM_PAIRS = int(1e6)

In [3]:
import sys
sys.path.insert(0, "../src")

from mlops_project.data import ArxivPapersDataset
from pathlib import Path
import random

data_dir = Path("../data")
train_dataset = ArxivPapersDataset(split="train", data_dir=data_dir).dataset
test_dataset = ArxivPapersDataset(split="test", data_dir=data_dir).dataset

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Columns: {train_dataset.column_names}")

Train samples: 2039695
Test samples: 509924
Columns: ['primary_subject', 'subjects', 'abstract', 'title']


In [4]:
from collections import defaultdict
import numpy as np
from datasets import Dataset, load_from_disk

def create_positive_pairs(dataset, num_pairs: int = 100000, text_field: str = "abstract", seed: int = 42):
    """
    Create positive pairs for MultipleNegativesRankingLoss.
    
    Returns a dataset with columns: anchor, positive
    Each pair contains two abstracts from papers with the same primary_subject.
    MNRL will use in-batch negatives automatically.
    """
    random.seed(seed)
    np.random.seed(seed)
    
    subject_to_indices = defaultdict(list)
    for idx, subject in enumerate(dataset["primary_subject"]):
        subject_to_indices[subject].append(idx)
    
    subjects = [s for s in subject_to_indices.keys() if len(subject_to_indices[s]) >= 2]
    print(f"Found {len(subjects)} subjects with 2+ samples")
    
    pairs = {"anchor": [], "positive": []}
    
    print(f"Creating {num_pairs} positive pairs...")
    for _ in range(num_pairs):
        subject = random.choice(subjects)
        idx1, idx2 = random.sample(subject_to_indices[subject], 2)
        pairs["anchor"].append(dataset[idx1][text_field])
        pairs["positive"].append(dataset[idx2][text_field])
    
    return Dataset.from_dict(pairs)

def load_or_create_pairs(dataset, save_path: Path, num_pairs: int, text_field: str = "abstract", seed: int = 42):
    if save_path.exists():
        print(f"Loading cached pairs from {save_path}")
        return load_from_disk(str(save_path))
    
    print(f"Creating new pairs (will be cached at {save_path})")
    pairs = create_positive_pairs(dataset, num_pairs=num_pairs, text_field=text_field, seed=seed)
    pairs.save_to_disk(str(save_path))
    return pairs

In [5]:
train_pairs_path = data_dir / f"train_pairs_mnrl_{NUM_PAIRS}"
train_pairs = load_or_create_pairs(train_dataset, train_pairs_path, num_pairs=NUM_PAIRS)
print(f"\nTraining pairs: {len(train_pairs)}")
print(train_pairs[0])

Loading cached pairs from ../data/train_pairs_mnrl_1000000

Training pairs: 1000000
{'anchor': 'The paper describes the project, implementation and test of a C-band (5GHz) Low Noise Amplifier (LNA) using new low noise Pseudomorphic High Electron Mobility Transistors (pHEMTS) from Avago. The amplifier was developed to be used as a cost effective solution in a receiver chain for Galactic Emission Mapping (GEM-P) project in Portugal with the objective of finding affordable solutions not requiring strong cryogenic operation, as is the case of massive projects like the Square Kilometer Array (SKA), in Earth Sensing projects and other niches like microwave reflectometry. The particular application and amplifier requirements are first introduced. Several commercially available low noise devices were selected and the noise performance simulated. An ultra-low noise pHEMT was used for an implementation that achieved a Noise Figure of 0.6 dB with 13 dB gain at 5 GHz. The design, simulation and me

In [6]:
eval_pairs_path = data_dir / "eval_pairs_mnrl_1000"
eval_pairs = load_or_create_pairs(test_dataset, eval_pairs_path, num_pairs=1000, seed=42)
print(f"Evaluation pairs: {len(eval_pairs)}")

Loading cached pairs from ../data/eval_pairs_mnrl_1000
Evaluation pairs: 1000


In [None]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator

def create_ir_evaluator(dataset, sample_size: int = 5000, name: str = "arxiv-retrieval"):
    np.random.seed(42)
    indices = np.random.choice(len(dataset), min(sample_size, len(dataset)), replace=False)
    
    # Split indices: first 20% for queries, rest for corpus (no overlap)
    num_queries = sample_size // 5
    query_indices = indices[:num_queries]
    corpus_indices = indices[num_queries:]
    
    queries = {}
    corpus = {}
    relevant_docs = {}
    subject_to_corpus_ids = defaultdict(set)
    
    # Build corpus from corpus_indices only
    for i, idx in enumerate(corpus_indices):
        idx = int(idx)
        corpus_id = f"doc_{i}"
        corpus[corpus_id] = dataset[idx]["abstract"]
        subject = dataset[idx]["primary_subject"]
        subject_to_corpus_ids[subject].add(corpus_id)
    
    # Build queries from query_indices (no overlap with corpus)
    for i, idx in enumerate(query_indices):
        idx = int(idx)
        query_id = f"query_{i}"
        queries[query_id] = dataset[idx]["abstract"]
        subject = dataset[idx]["primary_subject"]
        # Relevant docs are corpus docs with same subject
        relevant_docs[query_id] = subject_to_corpus_ids[subject].copy()
    
    # Filter out queries with no relevant docs
    queries = {qid: q for qid, q in queries.items() if len(relevant_docs.get(qid, set())) > 0}
    relevant_docs = {qid: docs for qid, docs in relevant_docs.items() if qid in queries}
    
    print(f"IR Evaluator: {len(queries)} queries, {len(corpus)} corpus docs (no overlap)")
    
    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=name,
        precision_recall_at_k=[1, 5, 10],
    )

ir_evaluator = create_ir_evaluator(test_dataset, sample_size=500)
print("IR Evaluator created for precision@k metrics")

IR Evaluator: 93 queries, 400 corpus docs (no overlap)
IR Evaluator created for precision@k metrics


In [8]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
import torch

training_args = SentenceTransformerTrainingArguments(
    output_dir="../models/mnrl-modernbert",
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    learning_rate=1e-4,
    warmup_ratio=0.1,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=500,
    logging_steps=100,
    torch_compile=torch.cuda.is_available(),
    bf16=torch.cuda.is_available(),
    tf32=torch.cuda.is_available()
)

# MultipleNegativesRankingLoss uses in-batch negatives
loss = MultipleNegativesRankingLoss(model)

trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=train_pairs,
    eval_dataset=eval_pairs,
    loss=loss,
    evaluator=ir_evaluator,
)

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [None]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mthorhojhus[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Arxiv-retrieval Cosine Accuracy@1,Arxiv-retrieval Cosine Accuracy@3,Arxiv-retrieval Cosine Accuracy@5,Arxiv-retrieval Cosine Accuracy@10,Arxiv-retrieval Cosine Precision@1,Arxiv-retrieval Cosine Precision@5,Arxiv-retrieval Cosine Precision@10,Arxiv-retrieval Cosine Recall@1,Arxiv-retrieval Cosine Recall@5,Arxiv-retrieval Cosine Recall@10,Arxiv-retrieval Cosine Ndcg@10,Arxiv-retrieval Cosine Mrr@10,Arxiv-retrieval Cosine Map@100
100,3.8078,3.73609,0.258065,0.430108,0.548387,0.677419,0.258065,0.227957,0.201075,0.022654,0.116692,0.227933,0.248828,0.380581,0.196127
200,3.4296,2.94351,0.301075,0.473118,0.55914,0.731183,0.301075,0.249462,0.216129,0.048292,0.147536,0.266573,0.289983,0.422854,0.257794
300,2.7604,2.66291,0.344086,0.569892,0.698925,0.784946,0.344086,0.311828,0.260215,0.066331,0.208292,0.317767,0.35215,0.476596,0.317388
400,2.6172,2.550896,0.376344,0.623656,0.709677,0.806452,0.376344,0.337634,0.292473,0.067974,0.231361,0.366987,0.393764,0.508227,0.344903
500,2.513,2.463694,0.376344,0.612903,0.72043,0.806452,0.376344,0.35914,0.296774,0.060508,0.23929,0.363079,0.402414,0.51548,0.364719
600,2.4449,2.398434,0.387097,0.666667,0.731183,0.827957,0.387097,0.369892,0.316129,0.077112,0.250668,0.380438,0.425262,0.532493,0.385334
700,2.4257,2.348606,0.44086,0.645161,0.752688,0.827957,0.44086,0.367742,0.336559,0.090704,0.253546,0.402798,0.450808,0.562468,0.402991
800,2.3504,2.317515,0.397849,0.645161,0.731183,0.83871,0.397849,0.372043,0.327957,0.087342,0.251666,0.401662,0.446446,0.542631,0.407853
900,2.2864,2.264133,0.408602,0.688172,0.741935,0.849462,0.408602,0.378495,0.339785,0.085847,0.253227,0.426422,0.462368,0.561521,0.417466
1000,2.2536,2.229873,0.430108,0.666667,0.763441,0.83871,0.430108,0.378495,0.33871,0.083098,0.250467,0.419673,0.456014,0.565288,0.414297


W0109 22:03:41.526000 202793 torch/fx/experimental/symbolic_shapes.py:6823] [0/3] _maybe_guard_rel() was called on non-relation expression Eq(s47, s8) | Eq(s8, 1)


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.21it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.19it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.19it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:04<00:00,  4.43s/it]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.19it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.53it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.68it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.68it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:04<00:00,  4.42s/it]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.18it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:04<00:00,  4.43s/it]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.19it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.19it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.19it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.19it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.15it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.17it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.08it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.08it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.08it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.07it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.07it/s]


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.09it/s]


In [None]:
print("=== Precision@k (after fine-tuning) ===")
final_results = ir_evaluator(model)
for key, value in final_results.items():
    print(f"{key}: {value:.4f}")

=== Precision@k (after fine-tuning) ===


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.16s/it]

arxiv-retrieval_cosine_precision@1: 0.5994
arxiv-retrieval_cosine_precision@5: 0.5817
arxiv-retrieval_cosine_precision@10: 0.5720
arxiv-retrieval_cosine_recall@1: 0.0158
arxiv-retrieval_cosine_recall@5: 0.0694
arxiv-retrieval_cosine_recall@10: 0.1306
arxiv-retrieval_cosine_ndcg@10: 0.5845
arxiv-retrieval_cosine_mrr@10: 0.7019

=== Improvement ===
arxiv-retrieval_cosine_precision@1: +0.1094
arxiv-retrieval_cosine_precision@5: +0.1351
arxiv-retrieval_cosine_precision@10: +0.1571
arxiv-retrieval_cosine_ndcg@10: +0.1502





In [None]:
model.save("../models/mnrl-minilm-finetuned")
print("Model saved to ../models/mnrl-minilm-finetuned")

Model saved to ../models/mnrl-minilm-finetuned
