In [1]:
NUM_PAIRS = int(1e6)

In [2]:
import torch
torch.cuda.is_available()

True

In [3]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

In [4]:
import sys
sys.path.insert(0, "../src")

from mlops_project.data import ArxivPapersDataset
from pathlib import Path
import random

data_dir = Path("../data")
train_dataset = ArxivPapersDataset(split="train", data_dir=data_dir).dataset
test_dataset = ArxivPapersDataset(split="test", data_dir=data_dir).dataset

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Columns: {train_dataset.column_names}")

Train samples: 2039695
Test samples: 509924
Columns: ['primary_subject', 'subjects', 'abstract', 'title']


In [5]:
from collections import defaultdict
import numpy as np
from datasets import Dataset, load_from_disk

def create_positive_pairs(dataset, num_pairs: int = 100000, text_field: str = "abstract", seed: int = 42):
    """
    Create positive pairs for MultipleNegativesRankingLoss.
    
    Returns a dataset with columns: anchor, positive
    Each pair contains two abstracts from papers with the same primary_subject.
    MNRL will use in-batch negatives automatically.
    """
    random.seed(seed)
    np.random.seed(seed)
    
    subject_to_indices = defaultdict(list)
    for idx, subject in enumerate(dataset["primary_subject"]):
        subject_to_indices[subject].append(idx)
    
    subjects = [s for s in subject_to_indices.keys() if len(subject_to_indices[s]) >= 2]
    print(f"Found {len(subjects)} subjects with 2+ samples")
    
    pairs = {"anchor": [], "positive": []}
    
    print(f"Creating {num_pairs} positive pairs...")
    for _ in range(num_pairs):
        subject = random.choice(subjects)
        idx1, idx2 = random.sample(subject_to_indices[subject], 2)
        pairs["anchor"].append(dataset[idx1][text_field])
        pairs["positive"].append(dataset[idx2][text_field])
    
    return Dataset.from_dict(pairs)

def load_or_create_pairs(dataset, save_path: Path, num_pairs: int, text_field: str = "abstract", seed: int = 42):
    if save_path.exists():
        print(f"Loading cached pairs from {save_path}")
        return load_from_disk(str(save_path))
    
    print(f"Creating new pairs (will be cached at {save_path})")
    pairs = create_positive_pairs(dataset, num_pairs=num_pairs, text_field=text_field, seed=seed)
    pairs.save_to_disk(str(save_path))
    return pairs

train_pairs_path = data_dir / "train_pairs_mnrl"
train_pairs = load_or_create_pairs(train_dataset, train_pairs_path, num_pairs=NUM_PAIRS)
print(f"\nTraining pairs: {len(train_pairs)}")
print(train_pairs[0])

Loading cached pairs from ../data/train_pairs_mnrl

Training pairs: 100000
{'anchor': 'The paper describes the project, implementation and test of a C-band (5GHz) Low Noise Amplifier (LNA) using new low noise Pseudomorphic High Electron Mobility Transistors (pHEMTS) from Avago. The amplifier was developed to be used as a cost effective solution in a receiver chain for Galactic Emission Mapping (GEM-P) project in Portugal with the objective of finding affordable solutions not requiring strong cryogenic operation, as is the case of massive projects like the Square Kilometer Array (SKA), in Earth Sensing projects and other niches like microwave reflectometry. The particular application and amplifier requirements are first introduced. Several commercially available low noise devices were selected and the noise performance simulated. An ultra-low noise pHEMT was used for an implementation that achieved a Noise Figure of 0.6 dB with 13 dB gain at 5 GHz. The design, simulation and measured re

In [6]:
eval_pairs_path = data_dir / "eval_pairs_mnrl"
eval_pairs = load_or_create_pairs(test_dataset, eval_pairs_path, num_pairs=10000, seed=123)
print(f"Evaluation pairs: {len(eval_pairs)}")

Loading cached pairs from ../data/eval_pairs_mnrl
Evaluation pairs: 10000


In [7]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator

def create_ir_evaluator(dataset, sample_size: int = 5000, name: str = "arxiv-retrieval"):
    np.random.seed(42)
    indices = np.random.choice(len(dataset), min(sample_size, len(dataset)), replace=False)
    
    queries = {}
    corpus = {}
    relevant_docs = {}
    subject_to_corpus_ids = defaultdict(set)
    
    for i, idx in enumerate(indices):
        idx = int(idx)
        corpus_id = f"doc_{i}"
        corpus[corpus_id] = dataset[idx]["abstract"]
        subject = dataset[idx]["primary_subject"]
        subject_to_corpus_ids[subject].add(corpus_id)
    
    query_indices = indices[:sample_size // 5]
    
    for i, idx in enumerate(query_indices):
        idx = int(idx)
        query_id = f"query_{i}"
        queries[query_id] = dataset[idx]["abstract"]
        subject = dataset[idx]["primary_subject"]
        doc_id = f"doc_{i}"
        relevant_docs[query_id] = subject_to_corpus_ids[subject] - {doc_id}
    
    queries = {qid: q for qid, q in queries.items() if len(relevant_docs.get(qid, set())) > 0}
    relevant_docs = {qid: docs for qid, docs in relevant_docs.items() if qid in queries}
    
    print(f"IR Evaluator: {len(queries)} queries, {len(corpus)} corpus docs")
    
    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=name,
        precision_recall_at_k=[1, 5, 10],
        show_progress_bar=True,
    )

ir_evaluator = create_ir_evaluator(test_dataset, sample_size=5000)
print("IR Evaluator created for precision@k metrics")

IR Evaluator: 998 queries, 5000 corpus docs
IR Evaluator created for precision@k metrics


In [8]:
print("=== Baseline Precision@k (before fine-tuning) ===")
baseline_results = ir_evaluator(model)
for key, value in baseline_results.items():
    if "precision" in key or "recall" in key or "mrr" in key or "ndcg" in key:
        print(f"{key}: {value:.4f}")

=== Baseline Precision@k (before fine-tuning) ===


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/157 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.52s/it]

arxiv-retrieval_cosine_precision@1: 0.0000
arxiv-retrieval_cosine_precision@5: 0.3741
arxiv-retrieval_cosine_precision@10: 0.3905
arxiv-retrieval_cosine_recall@1: 0.0000
arxiv-retrieval_cosine_recall@5: 0.0323
arxiv-retrieval_cosine_recall@10: 0.0620
arxiv-retrieval_cosine_ndcg@10: 0.3473
arxiv-retrieval_cosine_mrr@10: 0.3461





In [None]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss

training_args = SentenceTransformerTrainingArguments(
    output_dir="../models/mnrl-minilm",
    num_train_epochs=1,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    learning_rate=3e-4,
    warmup_ratio=0.1,
    eval_strategy="steps",
    eval_steps=500,
    eval_on_start=True,
    save_strategy="steps",
    save_steps=500,
    logging_steps=100,
    torch_compile=torch.cuda.is_available(),
    fp16=torch.cuda.is_available(),
    tf32=torch.cuda.is_available()
)

# MultipleNegativesRankingLoss uses in-batch negatives
loss = MultipleNegativesRankingLoss(model)

trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=train_pairs,
    eval_dataset=eval_pairs,
    loss=loss,
    evaluator=ir_evaluator,
)

print("Trainer initialized with MNRL loss. Ready to train!")

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Trainer initialized with MNRL loss. Ready to train!


In [10]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mthorhojhus[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Arxiv-retrieval Cosine Accuracy@1,Arxiv-retrieval Cosine Accuracy@3,Arxiv-retrieval Cosine Accuracy@5,Arxiv-retrieval Cosine Accuracy@10,Arxiv-retrieval Cosine Precision@1,Arxiv-retrieval Cosine Precision@5,Arxiv-retrieval Cosine Precision@10,Arxiv-retrieval Cosine Recall@1,Arxiv-retrieval Cosine Recall@5,Arxiv-retrieval Cosine Recall@10,Arxiv-retrieval Cosine Ndcg@10,Arxiv-retrieval Cosine Mrr@10,Arxiv-retrieval Cosine Map@100
0,No log,3.641879,0.0,0.656313,0.795591,0.883768,0.0,0.373747,0.390281,0.0,0.032296,0.061957,0.347124,0.346205,0.18355
500,2.585600,2.600233,0.0,0.680361,0.780561,0.867735,0.0,0.402405,0.436774,0.0,0.038716,0.080088,0.387305,0.352894,0.278644


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/157 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.60s/it]
W0109 15:09:17.981000 23123 torch/fx/experimental/symbolic_shapes.py:6823] [0/2] _maybe_guard_rel() was called on non-relation expression Eq(s47, s8) | Eq(s8, 1)


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/157 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.64s/it]
W0109 15:10:30.290000 23123 torch/fx/experimental/symbolic_shapes.py:6823] [0/3] _maybe_guard_rel() was called on non-relation expression Eq(s47, s8) | Eq(s8, 1)


TrainOutput(global_step=782, training_loss=2.6874484479274896, metrics={'train_runtime': 237.0431, 'train_samples_per_second': 421.864, 'train_steps_per_second': 3.299, 'total_flos': 0.0, 'train_loss': 2.6874484479274896, 'epoch': 1.0})

In [11]:
print("=== Precision@k (after fine-tuning) ===")
final_results = ir_evaluator(model)
for key, value in final_results.items():
    if "precision" in key or "recall" in key or "mrr" in key or "ndcg" in key:
        print(f"{key}: {value:.4f}")

print("\n=== Improvement ===")
for key in baseline_results:
    if "precision" in key or "ndcg" in key:
        improvement = final_results[key] - baseline_results[key]
        print(f"{key}: {improvement:+.4f}")

=== Precision@k (after fine-tuning) ===


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/157 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.63s/it]

arxiv-retrieval_cosine_precision@1: 0.0000
arxiv-retrieval_cosine_precision@5: 0.4136
arxiv-retrieval_cosine_precision@10: 0.4480
arxiv-retrieval_cosine_recall@1: 0.0000
arxiv-retrieval_cosine_recall@5: 0.0417
arxiv-retrieval_cosine_recall@10: 0.0836
arxiv-retrieval_cosine_ndcg@10: 0.3978
arxiv-retrieval_cosine_mrr@10: 0.3576

=== Improvement ===
arxiv-retrieval_cosine_precision@1: +0.0000
arxiv-retrieval_cosine_precision@5: +0.0395
arxiv-retrieval_cosine_precision@10: +0.0575
arxiv-retrieval_cosine_ndcg@10: +0.0505





In [12]:
model.save("../models/mnrl-minilm-finetuned")
print("Model saved to ../models/mnrl-minilm-finetuned")

Model saved to ../models/mnrl-minilm-finetuned
