In [1]:
# Number of (anchor, positive) training pairs to sample; the value is also
# embedded in the name of the on-disk cache directory for those pairs.
NUM_PAIRS = 1_000_000

In [2]:
import torch
# Show whether a CUDA device is visible to PyTorch; the training-arguments
# cell later switches its GPU-only options on the same check.
torch.cuda.is_available()

True

In [3]:
from sentence_transformers import SentenceTransformer

# Pretrained MiniLM sentence encoder. This same `model` object is handed to the
# IR evaluator, the loss, and the trainer in later cells, and is saved at the end.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

In [4]:
import sys
sys.path.insert(0, "../src")

from mlops_project.data import ArxivPapersDataset
from pathlib import Path
import random

# Root folder holding the raw splits and, later, the cached pair datasets.
data_dir = Path("..") / "data"

# Build both splits the same way; each wrapper exposes its rows via `.dataset`.
train_dataset, test_dataset = (
    ArxivPapersDataset(split=split_name, data_dir=data_dir).dataset
    for split_name in ("train", "test")
)

# Quick sanity summary of what was loaded.
print(
    f"Train samples: {len(train_dataset)}",
    f"Test samples: {len(test_dataset)}",
    f"Columns: {train_dataset.column_names}",
    sep="\n",
)

Train samples: 2039695
Test samples: 509924
Columns: ['primary_subject', 'subjects', 'abstract', 'title']


In [5]:
from collections import defaultdict
import numpy as np
from datasets import Dataset, load_from_disk

def create_positive_pairs(dataset, num_pairs: int = 100000, text_field: str = "abstract", seed: int = 42):
    """
    Create positive pairs for MultipleNegativesRankingLoss.

    Returns a dataset with columns: anchor, positive
    Each pair contains two abstracts from papers with the same primary_subject.
    MNRL will use in-batch negatives automatically.

    Sampling is two-stage: a subject is drawn uniformly from the subjects that
    have at least two samples, then two distinct rows of that subject are drawn.
    Because subjects (not rows) are drawn uniformly, small subjects contribute
    more pairs than their share of the data.

    Args:
        dataset: Object indexable by column name (``dataset["primary_subject"]``
            must iterate over subject labels) and by integer row index (each
            row must be a mapping containing ``text_field``).
        num_pairs: Number of pairs to draw.
        text_field: Row field copied into the ``anchor`` / ``positive`` columns.
        seed: Seed of the private random generator used for sampling.

    Raises:
        ValueError: If pairs are requested but no subject has two or more samples.
    """
    # Private generator: produces the same stream as `random.seed(seed)` followed
    # by module-level calls, but does not reseed the process-wide RNGs as a side
    # effect (which previously happened only on a cache miss).
    rng = random.Random(seed)

    subject_to_indices = defaultdict(list)
    for idx, subject in enumerate(dataset["primary_subject"]):
        subject_to_indices[subject].append(idx)

    # Only subjects with two or more rows can yield a pair of distinct papers.
    subjects = [s for s, members in subject_to_indices.items() if len(members) >= 2]
    print(f"Found {len(subjects)} subjects with 2+ samples")

    if num_pairs > 0 and not subjects:
        # Fail with an actionable message instead of IndexError from choice([]).
        raise ValueError(
            f"Cannot create {num_pairs} pairs: no primary_subject has at least 2 samples"
        )

    pairs = {"anchor": [], "positive": []}

    print(f"Creating {num_pairs} positive pairs...")
    for _ in range(num_pairs):
        subject = rng.choice(subjects)
        # sample() draws without replacement, so anchor and positive are different rows.
        idx1, idx2 = rng.sample(subject_to_indices[subject], 2)
        pairs["anchor"].append(dataset[idx1][text_field])
        pairs["positive"].append(dataset[idx2][text_field])

    return Dataset.from_dict(pairs)

def load_or_create_pairs(dataset, save_path: Path, num_pairs: int, text_field: str = "abstract", seed: int = 42):
    """
    Return the pair dataset stored at ``save_path``, creating and caching it first if absent.

    Args:
        dataset: Source dataset forwarded to ``create_positive_pairs`` on a cache miss.
        save_path: Directory used as the on-disk cache.
        num_pairs: Number of pairs to create on a cache miss; on a cache hit it is
            only used to sanity-check the size of the cached dataset.
        text_field: Forwarded to ``create_positive_pairs`` on a cache miss.
        seed: Forwarded to ``create_positive_pairs`` on a cache miss.

    Returns:
        The cached or freshly created pair dataset.
    """
    import warnings

    if save_path.exists():
        print(f"Loading cached pairs from {save_path}")
        cached = load_from_disk(str(save_path))
        # The cache is keyed by path only, so a directory written with a different
        # num_pairs would otherwise be returned without any indication.
        if len(cached) != num_pairs:
            warnings.warn(
                f"Cached pairs at {save_path} contain {len(cached)} rows but "
                f"{num_pairs} were requested; delete the cache to regenerate.",
                stacklevel=2,
            )
        return cached

    print(f"Creating new pairs (will be cached at {save_path})")
    pairs = create_positive_pairs(dataset, num_pairs=num_pairs, text_field=text_field, seed=seed)
    pairs.save_to_disk(str(save_path))
    return pairs

# The pair count is part of the directory name, so a different NUM_PAIRS
# resolves to a different cache location.
train_pairs_path = data_dir.joinpath(f"train_pairs_mnrl_{NUM_PAIRS}")
train_pairs = load_or_create_pairs(
    dataset=train_dataset,
    save_path=train_pairs_path,
    num_pairs=NUM_PAIRS,
)
print(f"\nTraining pairs: {len(train_pairs)}")
# Show one example so the anchor/positive layout can be inspected.
print(train_pairs[0])

Creating new pairs (will be cached at ../data/train_pairs_mnrl_1000000)
Found 148 subjects with 2+ samples
Creating 1000000 positive pairs...


Saving the dataset (0/4 shards):   0%|          | 0/1000000 [00:00<?, ? examples/s]


Training pairs: 1000000
{'anchor': 'The paper describes the project, implementation and test of a C-band (5GHz) Low Noise Amplifier (LNA) using new low noise Pseudomorphic High Electron Mobility Transistors (pHEMTS) from Avago. The amplifier was developed to be used as a cost effective solution in a receiver chain for Galactic Emission Mapping (GEM-P) project in Portugal with the objective of finding affordable solutions not requiring strong cryogenic operation, as is the case of massive projects like the Square Kilometer Array (SKA), in Earth Sensing projects and other niches like microwave reflectometry. The particular application and amplifier requirements are first introduced. Several commercially available low noise devices were selected and the noise performance simulated. An ultra-low noise pHEMT was used for an implementation that achieved a Noise Figure of 0.6 dB with 13 dB gain at 5 GHz. The design, simulation and measured results of the prototype are presented and discussed

In [6]:
# NOTE(review): unlike the training cache, this directory name encodes neither
# the pair count nor the seed, so an existing cache is loaded regardless of the
# num_pairs/seed passed here.
eval_pairs_path = data_dir.joinpath("eval_pairs_mnrl")
eval_pairs = load_or_create_pairs(
    dataset=test_dataset,
    save_path=eval_pairs_path,
    num_pairs=10000,
    seed=123,
)
print(f"Evaluation pairs: {len(eval_pairs)}")

Loading cached pairs from ../data/eval_pairs_mnrl
Evaluation pairs: 10000


In [7]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator

def create_ir_evaluator(dataset, sample_size: int = 5000, name: str = "arxiv-retrieval", seed: int = 42):
    """
    Build an InformationRetrievalEvaluator from a random sample of ``dataset``.

    Up to ``sample_size`` rows are drawn without replacement. The first 20% of
    the drawn rows become queries and the remaining rows become the corpus, so
    no row is both a query and a corpus document. A corpus document is relevant
    to a query when both rows share the same ``primary_subject``; queries with
    no relevant corpus document are dropped.

    Args:
        dataset: Object with ``len()`` whose integer-indexed rows are mappings
            containing ``"abstract"`` and ``"primary_subject"``.
        sample_size: Maximum number of rows to draw (capped at ``len(dataset)``).
        name: Name passed to the evaluator.
        seed: Seed of the private generator used to draw the sample.

    Returns:
        An ``InformationRetrievalEvaluator`` configured with
        ``precision_recall_at_k=[1, 5, 10]``.
    """
    # Private legacy generator: draws the same sample as `np.random.seed(seed)`
    # followed by `np.random.choice`, without reseeding numpy's global state.
    rng = np.random.RandomState(seed)
    indices = rng.choice(len(dataset), min(sample_size, len(dataset)), replace=False)

    # Split indices: first 20% for queries, rest for corpus (no overlap).
    # The split uses the number of rows actually drawn, so a dataset smaller
    # than sample_size still keeps 80% of its sample as corpus.
    num_queries = len(indices) // 5
    query_indices = indices[:num_queries]
    corpus_indices = indices[num_queries:]

    queries = {}
    corpus = {}
    relevant_docs = {}
    subject_to_corpus_ids = defaultdict(set)

    # Build corpus from corpus_indices only
    for i, idx in enumerate(corpus_indices):
        row = dataset[int(idx)]  # fetch the row once, then read both fields
        corpus_id = f"doc_{i}"
        corpus[corpus_id] = row["abstract"]
        subject_to_corpus_ids[row["primary_subject"]].add(corpus_id)

    # Build queries from query_indices (no overlap with corpus)
    for i, idx in enumerate(query_indices):
        row = dataset[int(idx)]
        # Relevant docs are corpus docs with same subject; .get avoids creating
        # empty entries in the defaultdict for subjects absent from the corpus.
        same_subject_ids = subject_to_corpus_ids.get(row["primary_subject"])
        if not same_subject_ids:
            # Skip queries with no relevant docs
            continue
        query_id = f"query_{i}"
        queries[query_id] = row["abstract"]
        relevant_docs[query_id] = set(same_subject_ids)

    print(f"IR Evaluator: {len(queries)} queries, {len(corpus)} corpus docs (no overlap)")

    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=name,
        precision_recall_at_k=[1, 5, 10],
        show_progress_bar=True,
    )

# Evaluator built from the held-out test split; it is reused for the baseline
# run, the in-training evaluations, and the final comparison.
ir_evaluator = create_ir_evaluator(
    dataset=test_dataset,
    sample_size=5000,
)
print("IR Evaluator created for precision@k metrics")

IR Evaluator: 996 queries, 4000 corpus docs (no overlap)
IR Evaluator created for precision@k metrics


In [8]:
print("=== Baseline Precision@k (before fine-tuning) ===")
# Keep the full result dict: the final cell subtracts these values from the
# post-training scores.
baseline_results = ir_evaluator(model)
for key, value in baseline_results.items():
    # Report only the ranking metrics of interest.
    if any(tag in key for tag in ("precision", "recall", "mrr", "ndcg")):
        print(f"{key}: {value:.4f}")

=== Baseline Precision@k (before fine-tuning) ===


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.05s/it]

arxiv-retrieval_cosine_precision@1: 0.4900
arxiv-retrieval_cosine_precision@5: 0.4466
arxiv-retrieval_cosine_precision@10: 0.4149
arxiv-retrieval_cosine_recall@1: 0.0115
arxiv-retrieval_cosine_recall@5: 0.0453
arxiv-retrieval_cosine_recall@10: 0.0767
arxiv-retrieval_cosine_ndcg@10: 0.4343
arxiv-retrieval_cosine_mrr@10: 0.6277





In [9]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss

# Probe the hardware once so every GPU-only option below uses the same answer.
use_cuda = torch.cuda.is_available()
# TF32 matmuls exist only on GPUs with compute capability >= 8.0 (Ampere or
# newer); requesting tf32 on an older CUDA device makes the training arguments
# raise at construction time, so it is gated separately from fp16/compile.
use_tf32 = use_cuda and torch.cuda.get_device_capability()[0] >= 8

training_args = SentenceTransformerTrainingArguments(
    output_dir="../models/mnrl-minilm",
    num_train_epochs=1,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    learning_rate=5e-4,
    warmup_ratio=0.1,
    eval_strategy="steps",
    eval_steps=500,
    eval_on_start=True,  # step-0 evaluation gives the untuned reference row in the log
    save_strategy="steps",
    save_steps=500,
    logging_steps=100,
    torch_compile=use_cuda,
    fp16=use_cuda,
    tf32=use_tf32,
)

# MultipleNegativesRankingLoss uses in-batch negatives
loss = MultipleNegativesRankingLoss(model)

# eval_dataset drives the validation loss; the evaluator adds the retrieval metrics.
trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=train_pairs,
    eval_dataset=eval_pairs,
    loss=loss,
    evaluator=ir_evaluator,
)

print("Trainer initialized with MNRL loss. Ready to train!")

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Trainer initialized with MNRL loss. Ready to train!


In [10]:
# Run the configured training (one epoch); the returned TrainOutput is displayed
# as the cell result.
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mthorhojhus[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Arxiv-retrieval Cosine Accuracy@1,Arxiv-retrieval Cosine Accuracy@3,Arxiv-retrieval Cosine Accuracy@5,Arxiv-retrieval Cosine Accuracy@10,Arxiv-retrieval Cosine Precision@1,Arxiv-retrieval Cosine Precision@5,Arxiv-retrieval Cosine Precision@10,Arxiv-retrieval Cosine Recall@1,Arxiv-retrieval Cosine Recall@5,Arxiv-retrieval Cosine Recall@10,Arxiv-retrieval Cosine Ndcg@10,Arxiv-retrieval Cosine Mrr@10,Arxiv-retrieval Cosine Map@100
0,No log,4.406813,0.48996,0.742972,0.815261,0.881526,0.48996,0.446386,0.414759,0.011522,0.045248,0.076651,0.43423,0.627709,0.200542
500,3.282200,3.292161,0.528112,0.72992,0.806225,0.870482,0.528112,0.495582,0.46988,0.012958,0.056736,0.100652,0.486534,0.642937,0.303753
1000,3.057200,3.137934,0.558233,0.736948,0.807229,0.869478,0.558233,0.516466,0.500201,0.014051,0.058331,0.108223,0.516726,0.664268,0.346137
1500,2.904100,3.0699,0.559237,0.757028,0.815261,0.88253,0.559237,0.53253,0.512149,0.015271,0.063711,0.113336,0.528436,0.668968,0.367955
2000,2.812600,3.021035,0.573293,0.764056,0.828313,0.879518,0.573293,0.550803,0.53002,0.015501,0.065618,0.119627,0.545805,0.681281,0.391664
2500,2.691900,3.013168,0.592369,0.763052,0.825301,0.884538,0.592369,0.558635,0.545281,0.015379,0.067399,0.124925,0.559553,0.690444,0.409758
3000,2.638300,2.964381,0.594378,0.7751,0.830321,0.880522,0.594378,0.568273,0.555422,0.015767,0.068319,0.124997,0.570045,0.693858,0.416804
3500,2.546800,2.968331,0.608434,0.784137,0.833333,0.893574,0.608434,0.579518,0.568173,0.016196,0.069611,0.129839,0.581859,0.706495,0.434496


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:05<00:00,  5.03s/it]
W0109 15:38:53.148000 31725 torch/fx/experimental/symbolic_shapes.py:6823] [0/2] _maybe_guard_rel() was called on non-relation expression Eq(s47, s8) | Eq(s8, 1)


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.14s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.15s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.15s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.37s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.15s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.17s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.16s/it]
W0109 16:03:14.471000 31725 torch/fx/experimental/symbolic_shapes.py:6823] [0/3] _maybe_guard_rel() was called on non-relation expression Eq(s47, s8) | Eq(s8, 1)


TrainOutput(global_step=3907, training_loss=2.8652673794053416, metrics={'train_runtime': 1695.6604, 'train_samples_per_second': 589.741, 'train_steps_per_second': 2.304, 'total_flos': 0.0, 'train_loss': 2.8652673794053416, 'epoch': 1.0})

In [11]:
print("=== Precision@k (after fine-tuning) ===")
# Same evaluator as the baseline run, so the two result dicts share their keys.
final_results = ir_evaluator(model)
for key, value in final_results.items():
    if any(tag in key for tag in ("precision", "recall", "mrr", "ndcg")):
        print(f"{key}: {value:.4f}")

print("\n=== Improvement ===")
for key, baseline_value in baseline_results.items():
    # Only precision and nDCG deltas are reported.
    if "precision" not in key and "ndcg" not in key:
        continue
    improvement = final_results[key] - baseline_value
    print(f"{key}: {improvement:+.4f}")

=== Precision@k (after fine-tuning) ===


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.16s/it]

arxiv-retrieval_cosine_precision@1: 0.5994
arxiv-retrieval_cosine_precision@5: 0.5817
arxiv-retrieval_cosine_precision@10: 0.5720
arxiv-retrieval_cosine_recall@1: 0.0158
arxiv-retrieval_cosine_recall@5: 0.0694
arxiv-retrieval_cosine_recall@10: 0.1306
arxiv-retrieval_cosine_ndcg@10: 0.5845
arxiv-retrieval_cosine_mrr@10: 0.7019

=== Improvement ===
arxiv-retrieval_cosine_precision@1: +0.1094
arxiv-retrieval_cosine_precision@5: +0.1351
arxiv-retrieval_cosine_precision@10: +0.1571
arxiv-retrieval_cosine_ndcg@10: +0.1502





In [12]:
# Write the fine-tuned model to its own directory, separate from the trainer's
# output_dir (../models/mnrl-minilm) configured above.
model.save("../models/mnrl-minilm-finetuned")
print("Model saved to ../models/mnrl-minilm-finetuned")

Model saved to ../models/mnrl-minilm-finetuned
