In [1]:
# Number of (anchor, positive) training pairs to sample; also embedded in the pair-cache directory name.
NUM_PAIRS = 1_000_000

In [2]:
import torch

# Sanity check: the training arguments further down enable torch_compile/fp16/tf32 only when this is True.
torch.cuda.is_available()

True

In [3]:
from sentence_transformers import SentenceTransformer

# Pretrained base encoder. The same `model` object is later handed to the trainer and
# evaluated again "after fine-tuning", so it is updated in place by training.
BASE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(BASE_MODEL_NAME)

In [4]:
import sys

# Make the in-repo package importable without installing it.
sys.path.insert(0, "../src")

from mlops_project.data import ArxivPapersDataset
from pathlib import Path
import random

data_dir = Path("../data")

# Build both splits from the same data directory (train first, then test). `.dataset` is the
# object the rest of the notebook indexes into — presumably a Hugging Face `datasets.Dataset`,
# since it exposes `column_names`.
train_dataset, test_dataset = (
    ArxivPapersDataset(split=split_name, data_dir=data_dir).dataset
    for split_name in ("train", "test")
)

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Columns: {train_dataset.column_names}")

Train samples: 2039695
Test samples: 509924
Columns: ['primary_subject', 'subjects', 'abstract', 'title']


In [5]:
from collections import defaultdict
import numpy as np
from datasets import Dataset, load_from_disk


def create_positive_pairs(dataset, num_pairs: int = 100000, text_field: str = "abstract", seed: int = 42):
    """
    Create positive pairs for MultipleNegativesRankingLoss.

    Each pair holds the ``text_field`` of two *different* rows that share the same
    ``primary_subject``. A subject is drawn uniformly at random (NOT proportionally to its
    size), then two distinct rows of that subject are drawn, so subjects with only a few
    rows repeat the same pairs often. Subjects with fewer than two rows are skipped.
    MNRL will use in-batch negatives automatically.

    Args:
        dataset: Dataset with a ``primary_subject`` column, a ``text_field`` column and
            ``select(indices)`` (e.g. a Hugging Face ``datasets.Dataset``).
        num_pairs: Number of pairs to generate.
        text_field: Column whose text becomes ``anchor`` / ``positive``.
        seed: Seed for a private RNG; the global ``random``/``numpy`` state is left untouched.

    Returns:
        A ``Dataset`` with columns ``anchor`` and ``positive``.

    Raises:
        ValueError: If pairs are requested but no subject has at least two rows.
    """
    # random.Random(seed) yields exactly the sequence that `random.seed(seed)` followed by
    # module-level random.choice/sample calls would, so previously cached pairs remain
    # reproducible — without mutating the interpreter-wide RNG.
    rng = random.Random(seed)

    subject_to_indices = defaultdict(list)
    for idx, subject in enumerate(dataset["primary_subject"]):
        subject_to_indices[subject].append(idx)

    subjects = [s for s, idxs in subject_to_indices.items() if len(idxs) >= 2]
    print(f"Found {len(subjects)} subjects with 2+ samples")
    if num_pairs > 0 and not subjects:
        raise ValueError("Need at least one primary_subject with 2+ samples to build positive pairs")

    print(f"Creating {num_pairs} positive pairs...")
    anchor_indices, positive_indices = [], []
    for _ in range(num_pairs):
        subject = rng.choice(subjects)
        idx1, idx2 = rng.sample(subject_to_indices[subject], 2)
        anchor_indices.append(idx1)
        positive_indices.append(idx2)

    # Two batched column reads instead of 2 * num_pairs individual row lookups.
    return Dataset.from_dict({
        "anchor": list(dataset.select(anchor_indices)[text_field]),
        "positive": list(dataset.select(positive_indices)[text_field]),
    })


def load_or_create_pairs(dataset, save_path: Path, num_pairs: int, text_field: str = "abstract", seed: int = 42):
    """
    Return positive pairs cached at ``save_path``, or build them and cache them there.

    The cache is keyed only by ``save_path``: ``seed`` and ``text_field`` are not recorded
    on disk, so encode them in the path if you vary them. The row count IS checked against
    ``num_pairs`` so a cache built with a different size is never silently reused.

    Args:
        dataset: Source dataset passed to ``create_positive_pairs`` when building.
        save_path: Directory used with ``save_to_disk`` / ``load_from_disk``.
        num_pairs: Expected number of pairs.
        text_field: Column used for the pair texts when building.
        seed: RNG seed used when building.

    Returns:
        Dataset with ``anchor`` / ``positive`` columns.

    Raises:
        ValueError: If a cache exists at ``save_path`` with a row count other than ``num_pairs``.
    """
    if save_path.exists():
        print(f"Loading cached pairs from {save_path}")
        pairs = load_from_disk(str(save_path))
        if len(pairs) != num_pairs:
            raise ValueError(
                f"Cached pairs at {save_path} have {len(pairs)} rows but num_pairs={num_pairs}; "
                "delete the cache or use a different save_path"
            )
        return pairs

    print(f"Creating new pairs (will be cached at {save_path})")
    pairs = create_positive_pairs(dataset, num_pairs=num_pairs, text_field=text_field, seed=seed)
    pairs.save_to_disk(str(save_path))
    return pairs

In [None]:
# The cache directory name carries the pair count, so changing NUM_PAIRS builds a fresh cache.
train_pairs_path = data_dir / f"train_pairs_mnrl_{NUM_PAIRS}"
train_pairs = load_or_create_pairs(
    train_dataset,
    train_pairs_path,
    num_pairs=NUM_PAIRS,
)
print(f"\nTraining pairs: {len(train_pairs)}")

Loading cached pairs from ../data/train_pairs_mnrl_1000000

Training pairs: 1000000
{'anchor': 'The paper describes the project, implementation and test of a C-band (5GHz) Low Noise Amplifier (LNA) using new low noise Pseudomorphic High Electron Mobility Transistors (pHEMTS) from Avago. The amplifier was developed to be used as a cost effective solution in a receiver chain for Galactic Emission Mapping (GEM-P) project in Portugal with the objective of finding affordable solutions not requiring strong cryogenic operation, as is the case of massive projects like the Square Kilometer Array (SKA), in Earth Sensing projects and other niches like microwave reflectometry. The particular application and amplifier requirements are first introduced. Several commercially available low noise devices were selected and the noise performance simulated. An ultra-low noise pHEMT was used for an implementation that achieved a Noise Figure of 0.6 dB with 13 dB gain at 5 GHz. The design, simulation and me

In [None]:
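# NOTE: unused alternative that samples a fresh pair on every __getitem__ (kept for reference).
# Careful: re-enabling this import rebinds the name `Dataset`, which create_positive_pairs
# uses as `datasets.Dataset.from_dict`; import it under an alias (e.g. `TorchDataset`) instead.
# from torch.utils.data import Dataset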

# class DynamicPairDataset(Dataset):
#     def __init__(self, dataset, text_field="abstract", size=100000):
#         self.dataset = dataset
#         self.text_field = text_field
#         self.size = size
        
#         # Build category index once
#         self.subject_to_indices = defaultdict(list)
#         for idx, subject in enumerate(dataset["primary_subject"]):
#             self.subject_to_indices[subject].append(idx)
        
#         self.subjects = [s for s in self.subject_to_indices 
#                         if len(self.subject_to_indices[s]) >= 2]
    
#     def __len__(self):
#         return self.size  # Controls "epoch" length
    
#     def __getitem__(self, idx):
#         # Ignore idx — sample fresh every time
#         subject = random.choice(self.subjects)
#         i, j = random.sample(self.subject_to_indices[subject], 2)
#         return {
#             "anchor": self.dataset[i][self.text_field],
#             "positive": self.dataset[j][self.text_field]
#         }

# train_pairs_dynamic = DynamicPairDataset(train_dataset, size=NUM_PAIRS)

In [None]:
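# NOTE: unused alternative that streams an endless sequence of freshly sampled pairs (kept for reference).
# NOTE(review): a dataset with no length that never ends likely requires `max_steps` in the
# training arguments, because epochs cannot be counted — confirm before re-enabling.
# from datasets import IterableDataset, Features, Value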

# def pair_generator(dataset, subject_to_indices, subjects, text_field):
#     while True:  # Infinite generator
#         subject = random.choice(subjects)
#         i, j = random.sample(subject_to_indices[subject], 2)
#         yield {
#             "anchor": dataset[i][text_field],
#             "positive": dataset[j][text_field]
#         }

# # Build index
# subject_to_indices = defaultdict(list)
# for idx, subject in enumerate(train_dataset["primary_subject"]):
#     subject_to_indices[subject].append(idx)
# subjects = [s for s in subject_to_indices if len(subject_to_indices[s]) >= 2]


In [None]:
# train_pairs_dynamic = IterableDataset.from_generator(
#     pair_generator,
#     features=Features({
#         "anchor": Value("string"),
#         "positive": Value("string")
#     }),
#     gen_kwargs={
#         "dataset": train_dataset,
#         "subject_to_indices": subject_to_indices,
#         "subjects": subjects,
#         "text_field": "abstract"
#     }
# )

In [7]:
# Held-out pairs from the test split (seed differs from training); used as the trainer's eval_dataset.
eval_pairs_path = data_dir / "eval_pairs_mnrl"
eval_pairs = load_or_create_pairs(
    test_dataset,
    eval_pairs_path,
    num_pairs=10_000,
    seed=123,
)
print(f"Evaluation pairs: {len(eval_pairs)}")

Loading cached pairs from ../data/eval_pairs_mnrl
Evaluation pairs: 10000


In [None]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator


def create_ir_evaluator(dataset, sample_size: int = 5000, name: str = "arxiv-retrieval", seed: int = 42):
    """
    Build a subject-level retrieval evaluator from a random sample of ``dataset``.

    Up to ``sample_size`` rows are drawn without replacement. The first 20% become queries
    and the rest the corpus, so no abstract is both a query and a document. A corpus document
    is relevant to a query when both share the same ``primary_subject``. Queries whose subject
    has no document in the corpus are dropped. Because every same-subject document counts as
    relevant, queries have many relevant documents and recall@k is correspondingly small.

    Args:
        dataset: Indexable rows with ``abstract`` and ``primary_subject`` fields, supporting ``len()``.
        sample_size: Number of rows to sample (capped at ``len(dataset)``).
        name: Evaluator name; used as the metric-key prefix.
        seed: Seed for a private NumPy RNG (draws the same rows as ``np.random.seed(seed)`` did).

    Returns:
        An ``InformationRetrievalEvaluator`` reporting precision/recall at 1, 5 and 10.
    """
    # A local RandomState reproduces the legacy global-seed draw without touching np.random globally.
    rng = np.random.RandomState(seed)
    indices = rng.choice(len(dataset), min(sample_size, len(dataset)), replace=False)

    # Split the rows actually drawn (not the requested size): 20% queries, rest corpus.
    # Basing this on `sample_size` turned every row into a query when the dataset was smaller.
    num_queries = max(1, len(indices) // 5)
    query_indices = indices[:num_queries]
    corpus_indices = indices[num_queries:]

    corpus = {}
    subject_to_corpus_ids = defaultdict(set)
    for i, idx in enumerate(corpus_indices):
        row = dataset[int(idx)]
        corpus_id = f"doc_{i}"
        corpus[corpus_id] = row["abstract"]
        subject_to_corpus_ids[row["primary_subject"]].add(corpus_id)

    queries = {}
    relevant_docs = {}
    for i, idx in enumerate(query_indices):
        row = dataset[int(idx)]
        relevant = subject_to_corpus_ids.get(row["primary_subject"], set())
        if not relevant:
            continue  # no same-subject document in the corpus -> the query is unanswerable
        query_id = f"query_{i}"
        queries[query_id] = row["abstract"]
        relevant_docs[query_id] = set(relevant)

    print(f"IR Evaluator: {len(queries)} queries, {len(corpus)} corpus docs (no overlap)")

    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=name,
        precision_recall_at_k=[1, 5, 10],
    )


# 5,000 sampled test rows -> 1,000 candidate queries + 4,000 corpus documents.
ir_evaluator = create_ir_evaluator(dataset=test_dataset, sample_size=5000)
print("IR Evaluator created for precision@k metrics")

IR Evaluator: 996 queries, 4000 corpus docs (no overlap)
IR Evaluator created for precision@k metrics


In [9]:
print("=== Baseline Precision@k (before fine-tuning) ===")
# Training has not run yet, so these are the reference numbers for the improvement report.
baseline_results = ir_evaluator(model)
for key, value in baseline_results.items():
    if any(tag in key for tag in ("precision", "recall", "mrr", "ndcg")):
        print(f"{key}: {value:.4f}")

=== Baseline Precision@k (before fine-tuning) ===


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.06s/it]

arxiv-retrieval_cosine_precision@1: 0.4900
arxiv-retrieval_cosine_precision@5: 0.4466
arxiv-retrieval_cosine_precision@10: 0.4149
arxiv-retrieval_cosine_recall@1: 0.0115
arxiv-retrieval_cosine_recall@5: 0.0453
arxiv-retrieval_cosine_recall@10: 0.0767
arxiv-retrieval_cosine_ndcg@10: 0.4343
arxiv-retrieval_cosine_mrr@10: 0.6277





In [10]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss

use_cuda = torch.cuda.is_available()  # compile / mixed precision are only switched on with a GPU

training_args = SentenceTransformerTrainingArguments(
    output_dir="../models/mnrl-minilm",
    # One pass over the cached pairs.
    num_train_epochs=1,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    learning_rate=3e-4,
    warmup_ratio=0.1,
    # Evaluate and checkpoint every 500 steps, plus one evaluation before any training step.
    eval_strategy="steps",
    eval_steps=500,
    eval_on_start=True,
    save_strategy="steps",
    save_steps=500,
    logging_steps=100,
    torch_compile=use_cuda,
    fp16=use_cuda,
    tf32=use_cuda,
)

# MultipleNegativesRankingLoss uses in-batch negatives.
# NOTE(review): the test split has ~148 subjects while a batch holds 256 pairs, so some in-batch
# "negatives" necessarily share the anchor's subject (false negatives) — keep in mind when
# reading the loss curve.
loss = MultipleNegativesRankingLoss(model)

trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=train_pairs,
    eval_dataset=eval_pairs,
    loss=loss,
    evaluator=ir_evaluator,
)

print("Trainer initialized with MNRL loss. Ready to train!")

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Trainer initialized with MNRL loss. Ready to train!


In [11]:
# Trains `model` (the same object later evaluated "after fine-tuning"); the TrainOutput is the cell's display.
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mthorhojhus[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Arxiv-retrieval Cosine Accuracy@1,Arxiv-retrieval Cosine Accuracy@3,Arxiv-retrieval Cosine Accuracy@5,Arxiv-retrieval Cosine Accuracy@10,Arxiv-retrieval Cosine Precision@1,Arxiv-retrieval Cosine Precision@5,Arxiv-retrieval Cosine Precision@10,Arxiv-retrieval Cosine Recall@1,Arxiv-retrieval Cosine Recall@5,Arxiv-retrieval Cosine Recall@10,Arxiv-retrieval Cosine Ndcg@10,Arxiv-retrieval Cosine Mrr@10,Arxiv-retrieval Cosine Map@100
0,No log,4.406813,0.48996,0.742972,0.815261,0.881526,0.48996,0.446386,0.414759,0.011522,0.045248,0.076651,0.43423,0.627709,0.200542
500,3.252900,3.227887,0.538153,0.728916,0.797189,0.870482,0.538153,0.499398,0.479418,0.012977,0.055967,0.102437,0.496657,0.648392,0.316256
1000,3.024300,3.11339,0.574297,0.772088,0.827309,0.875502,0.574297,0.534137,0.508434,0.015347,0.062078,0.111604,0.52813,0.679397,0.36186
1500,2.873900,3.035457,0.59739,0.761044,0.817269,0.868474,0.59739,0.545181,0.526908,0.015089,0.065725,0.117867,0.545232,0.689752,0.380117
2000,2.786500,2.991063,0.594378,0.792169,0.833333,0.893574,0.594378,0.562851,0.538454,0.01583,0.067467,0.120198,0.556081,0.699645,0.398593
2500,2.679100,2.980253,0.598394,0.788153,0.849398,0.889558,0.598394,0.576305,0.5501,0.015469,0.069449,0.124311,0.567519,0.704266,0.413639
3000,2.631100,2.946059,0.614458,0.804217,0.842369,0.881526,0.614458,0.581124,0.563554,0.016309,0.069954,0.129002,0.581056,0.713853,0.425566
3500,2.567200,2.947726,0.61747,0.803213,0.850402,0.89257,0.61747,0.590763,0.569378,0.016457,0.071804,0.130638,0.586713,0.71805,0.43386


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.08s/it]
W0109 20:22:03.239000 139213 torch/fx/experimental/symbolic_shapes.py:6823] [0/2] _maybe_guard_rel() was called on non-relation expression Eq(s47, s8) | Eq(s8, 1)


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.15s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.28s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:05<00:00,  5.08s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.14s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.14s/it]
W0109 20:55:30.649000 139213 torch/fx/experimental/symbolic_shapes.py:6823] [0/3] _maybe_guard_rel() was called on non-relation expression Eq(s47, s8) | Eq(s8, 1)


TrainOutput(global_step=3907, training_loss=2.8535287982519857, metrics={'train_runtime': 2304.459, 'train_samples_per_second': 433.941, 'train_steps_per_second': 1.695, 'total_flos': 0.0, 'train_loss': 2.8535287982519857, 'epoch': 1.0})

In [12]:
print("=== Precision@k (after fine-tuning) ===")
final_results = ir_evaluator(model)
for key, value in final_results.items():
    if any(tag in key for tag in ("precision", "recall", "mrr", "ndcg")):
        print(f"{key}: {value:.4f}")

print("\n=== Improvement ===")
# Absolute change vs. the pre-training run, for precision and NDCG only.
# Requires `baseline_results` to still hold the metrics from the baseline cell.
for key in baseline_results:
    if any(tag in key for tag in ("precision", "ndcg")):
        improvement = final_results[key] - baseline_results[key]
        print(f"{key}: {improvement:+.4f}")

=== Precision@k (after fine-tuning) ===


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.12s/it]

arxiv-retrieval_cosine_precision@1: 0.6084
arxiv-retrieval_cosine_precision@5: 0.5878
arxiv-retrieval_cosine_precision@10: 0.5700
arxiv-retrieval_cosine_recall@1: 0.0168
arxiv-retrieval_cosine_recall@5: 0.0718
arxiv-retrieval_cosine_recall@10: 0.1304
arxiv-retrieval_cosine_ndcg@10: 0.5869
arxiv-retrieval_cosine_mrr@10: 0.7142

=== Improvement ===
arxiv-retrieval_cosine_precision@1: +0.1185
arxiv-retrieval_cosine_precision@5: +0.1412
arxiv-retrieval_cosine_precision@10: +0.1551
arxiv-retrieval_cosine_ndcg@10: +0.1526





In [13]:
# Persist the fine-tuned encoder next to the trainer's checkpoints.
FINETUNED_MODEL_DIR = "../models/mnrl-minilm-finetuned"
model.save(FINETUNED_MODEL_DIR)
print("Model saved to ../models/mnrl-minilm-finetuned")

Model saved to ../models/mnrl-minilm-finetuned


In [18]:
import numpy as np
import random
from collections import defaultdict
from sentence_transformers.evaluation import InformationRetrievalEvaluator

def create_stratified_ir_evaluator(dataset, samples_per_subject: int = 10, name: str = "arxiv-stratified", seed: int = 42):
    """
    Create an IR evaluator that samples every subject, so rare subjects are represented.

    For each ``primary_subject`` with at least two rows, up to ``samples_per_subject`` rows are
    drawn. 20% of them (at least one) become queries and the rest corpus documents. A corpus
    document is relevant to a query when both share the same subject.

    Args:
        dataset: Dataset with a ``primary_subject`` column and row access ``dataset[i]["abstract"]``.
        samples_per_subject: Max rows drawn per subject (split between queries and corpus).
            Subjects with fewer rows use all of them.
        name: Evaluator name; used as the metric-key prefix.
        seed: Seed for a private RNG so the sample is reproducible across runs
            (the previous version shuffled with the unseeded global RNG).

    Returns:
        An ``InformationRetrievalEvaluator`` with precision/recall at 1, 5 and 10.
    """
    rng = random.Random(seed)

    # 1. Group all indices by subject.
    subject_to_indices = defaultdict(list)
    for idx, subject in enumerate(dataset["primary_subject"]):
        subject_to_indices[subject].append(idx)

    queries = {}
    corpus = {}
    query_subject = {}  # query id -> subject, kept out of the query-text dict
    subject_to_corpus_ids = defaultdict(set)

    print(f"Stratifying across {len(subject_to_indices)} subjects...")

    # 2. Per subject: cap large subjects so they don't dominate; tiny ones contribute everything.
    for subject, indices in subject_to_indices.items():
        if len(indices) < 2:
            continue  # need at least one query and one corpus document

        selected = rng.sample(indices, min(len(indices), samples_per_subject))
        n_queries = max(1, int(len(selected) * 0.2))

        for idx in selected[n_queries:]:
            corpus_id = f"doc_{idx}"
            corpus[corpus_id] = dataset[int(idx)]["abstract"]
            subject_to_corpus_ids[subject].add(corpus_id)

        for idx in selected[:n_queries]:
            query_id = f"query_{idx}"
            queries[query_id] = dataset[int(idx)]["abstract"]
            query_subject[query_id] = subject

    # 3. A query is relevant to every corpus document of its subject; drop queries with none.
    final_queries = {}
    relevant_docs = {}
    for query_id, text in queries.items():
        rel_docs = subject_to_corpus_ids.get(query_subject[query_id], set())
        if rel_docs:
            final_queries[query_id] = text
            relevant_docs[query_id] = rel_docs

    print(f"Stratified Evaluator: {len(final_queries)} queries, {len(corpus)} corpus docs")
    print(f"Coverage: {len(subject_to_corpus_ids)} subjects represented")

    return InformationRetrievalEvaluator(
        queries=final_queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=name,
        precision_recall_at_k=[1, 5, 10],
        show_progress_bar=True
    )

# Usage
# Cap each subject at 1,000 sampled papers (20% of the sample become queries).
stratified_evaluator = create_stratified_ir_evaluator(dataset=test_dataset, samples_per_subject=1000)

Stratifying across 148 subjects...
Stratified Evaluator: 24163 queries, 96763 corpus docs
Coverage: 148 subjects represented


In [19]:
# `model` was already trained by `trainer.train()` above, so these are post-fine-tuning
# numbers. Keep them under their own name instead of overwriting `baseline_results`,
# which holds the pre-training metrics from the first IR evaluator.
stratified_results = stratified_evaluator(model)
for key, value in stratified_results.items():
    print(f"{key}: {value:.4f}")

Batches:   0%|          | 0/756 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1563 [00:00<?, ?it/s]

Corpus Chunks:  50%|█████     | 1/2 [00:29<00:29, 29.99s/it]

Batches:   0%|          | 0/1462 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 2/2 [00:57<00:00, 28.97s/it]


arxiv-stratified_cosine_accuracy@1: 0.5356
arxiv-stratified_cosine_accuracy@3: 0.7338
arxiv-stratified_cosine_accuracy@5: 0.7972
arxiv-stratified_cosine_accuracy@10: 0.8579
arxiv-stratified_cosine_precision@1: 0.5356
arxiv-stratified_cosine_precision@5: 0.5260
arxiv-stratified_cosine_precision@10: 0.5220
arxiv-stratified_cosine_recall@1: 0.0008
arxiv-stratified_cosine_recall@5: 0.0038
arxiv-stratified_cosine_recall@10: 0.0076
arxiv-stratified_cosine_ndcg@10: 0.5246
arxiv-stratified_cosine_mrr@10: 0.6470
arxiv-stratified_cosine_map@100: 0.3816


In [32]:
import pandas as pd
import numpy as np
from collections import defaultdict
from sentence_transformers import util
import torch
import random

def run_full_stratified_eval(model, dataset, samples_per_subject=10, k: int = 10, seed: int = 42):
    """
    Per-subject retrieval report: Precision@k, Recall@k, MRR@k and NDCG@k.

    Up to ``samples_per_subject`` papers are drawn from every subject with at least two papers.
    20% of each sample (at least one) are queries; the rest join a corpus shared by all subjects.
    A retrieved document is relevant when it has the query's subject (binary relevance).
    Metrics are averaged per subject.

    Args:
        model: Encoder exposing ``encode(texts, convert_to_tensor=..., show_progress_bar=...)``.
        dataset: Dataset with a ``primary_subject`` column and row access ``dataset[i]["abstract"]``.
        samples_per_subject: Max papers sampled per subject (e.g. 50 -> 10 queries + 40 docs).
        k: Retrieval cut-off; the metric column names use it (``P@10`` etc. by default).
        seed: Seed for a private RNG so repeated runs sample the same papers.

    Returns:
        DataFrame with one row per subject and columns Subject, Count, P@k, R@k, MRR, NDCG@k,
        sorted by NDCG@k descending. ``Count`` is the number of queries for that subject.
    """
    queries, query_subjects, corpus_texts, subject_to_corpus_ids = _sample_stratified_queries_and_corpus(
        dataset, samples_per_subject, random.Random(seed)
    )

    # Encode once, then one vectorized top-k search for all queries.
    corpus_embeddings = model.encode(corpus_texts, convert_to_tensor=True, show_progress_bar=True)
    query_embeddings = model.encode(queries, convert_to_tensor=True, show_progress_bar=True)
    results = util.semantic_search(query_embeddings, corpus_embeddings, top_k=k)

    subject_metrics = defaultdict(lambda: {"p": [], "r": [], "mrr": [], "ndcg": []})
    for subject, hits in zip(query_subjects, results):
        precision, recall, reciprocal_rank, ndcg = _binary_ranking_metrics(hits, subject_to_corpus_ids[subject], k)
        bucket = subject_metrics[subject]
        bucket["p"].append(precision)
        bucket["r"].append(recall)
        bucket["mrr"].append(reciprocal_rank)
        bucket["ndcg"].append(ndcg)

    # Explicit columns keep the frame well-formed (and sortable) even when no subject qualifies.
    columns = ["Subject", "Count", f"P@{k}", f"R@{k}", "MRR", f"NDCG@{k}"]
    rows = [
        {
            "Subject": subject,
            "Count": len(m["p"]),
            f"P@{k}": np.mean(m["p"]),
            f"R@{k}": np.mean(m["r"]),
            "MRR": np.mean(m["mrr"]),
            f"NDCG@{k}": np.mean(m["ndcg"]),
        }
        for subject, m in subject_metrics.items()
    ]
    return pd.DataFrame(rows, columns=columns).sort_values(f"NDCG@{k}", ascending=False).reset_index(drop=True)


def _sample_stratified_queries_and_corpus(dataset, samples_per_subject, rng):
    """Sample each subject with 2+ papers and split the sample ~20/80 into queries and corpus."""
    subject_to_indices = defaultdict(list)
    for idx, subject in enumerate(dataset["primary_subject"]):
        subject_to_indices[subject].append(idx)

    queries, query_subjects, corpus_texts = [], [], []
    subject_to_corpus_ids = defaultdict(set)
    for subject, indices in subject_to_indices.items():
        if len(indices) < 2:
            continue  # need at least one query and one corpus document
        selected = rng.sample(indices, min(len(indices), samples_per_subject))
        n_queries = max(1, int(len(selected) * 0.2))

        for idx in selected[n_queries:]:
            # A document's position in `corpus_texts` is the `corpus_id` semantic_search reports.
            subject_to_corpus_ids[subject].add(len(corpus_texts))
            corpus_texts.append(dataset[int(idx)]["abstract"])

        for idx in selected[:n_queries]:
            queries.append(dataset[int(idx)]["abstract"])
            query_subjects.append(subject)

    return queries, query_subjects, corpus_texts, subject_to_corpus_ids


def _binary_ranking_metrics(hits, relevant_ids, k):
    """Precision@k, recall@k, reciprocal rank and NDCG@k for one ranked hit list (binary relevance)."""
    is_rel = [hit["corpus_id"] in relevant_ids for hit in hits[:k]]
    n_found = sum(is_rel)

    precision = n_found / k
    recall = n_found / len(relevant_ids) if relevant_ids else 0.0
    # Reciprocal rank of the first relevant hit, 0 if none made the top k.
    reciprocal_rank = next((1.0 / (rank + 1) for rank, found in enumerate(is_rel) if found), 0.0)

    dcg = sum(1.0 / np.log2(rank + 2) for rank, found in enumerate(is_rel) if found)
    # Ideal ranking: every one of the first min(|relevant|, k) slots is relevant.
    idcg = sum(1.0 / np.log2(rank + 2) for rank in range(min(len(relevant_ids), k)))
    ndcg = dcg / idcg if idcg > 0 else 0.0
    return precision, recall, reciprocal_rank, ndcg

# Run full evaluation
# 50 papers per subject -> 10 queries and 40 corpus documents per subject.
full_stats_df = run_full_stratified_eval(model=model, dataset=test_dataset, samples_per_subject=50)
full_stats_df

Batches:   0%|          | 0/185 [00:00<?, ?it/s]

Batches:   0%|          | 0/47 [00:00<?, ?it/s]

Unnamed: 0,Subject,Count,P@10,R@10,MRR,NDCG@10
0,Exactly Solvable and Integrable Systems (nlin.SI),10,0.95,0.2375,0.950000,0.942379
1,High Energy Physics - Lattice (hep-lat),10,0.93,0.2325,0.933333,0.926135
2,Digital Libraries (cs.DL),10,0.91,0.2275,0.933333,0.911005
3,Robotics (cs.RO),10,0.89,0.2225,1.000000,0.908445
4,Earth and Planetary Astrophysics (astro-ph.EP),10,0.86,0.2150,0.870000,0.856385
...,...,...,...,...,...,...
143,Applications (stat.AP),10,0.11,0.0275,0.208333,0.116422
144,Other Computer Science (cs.OH),10,0.10,0.0250,0.190278,0.102946
145,"Computational Engineering, Finance, and Scienc...",10,0.06,0.0150,0.171111,0.066202
146,Quantitative Methods (q-bio.QM),10,0.04,0.0100,0.114286,0.045768


In [33]:
# Show every subject row and keep wide frames on a single line.
for option_name, option_value in (("display.max_rows", None), ("display.expand_frame_repr", False)):
    pd.set_option(option_name, option_value)

In [34]:
# Re-display the per-subject table now that pandas is configured to show every row.
full_stats_df

Unnamed: 0,Subject,Count,P@10,R@10,MRR,NDCG@10
0,Exactly Solvable and Integrable Systems (nlin.SI),10,0.95,0.2375,0.95,0.942379
1,High Energy Physics - Lattice (hep-lat),10,0.93,0.2325,0.933333,0.926135
2,Digital Libraries (cs.DL),10,0.91,0.2275,0.933333,0.911005
3,Robotics (cs.RO),10,0.89,0.2225,1.0,0.908445
4,Earth and Planetary Astrophysics (astro-ph.EP),10,0.86,0.215,0.87,0.856385
5,Differential Geometry (math.DG),10,0.81,0.2025,1.0,0.834632
6,Superconductivity (cond-mat.supr-con),10,0.82,0.205,0.861111,0.815075
7,Accelerator Physics (physics.acc-ph),10,0.79,0.1975,0.861111,0.801827
8,Information Retrieval (cs.IR),10,0.78,0.195,0.9,0.794021
9,Plasma Physics (physics.plasm-ph),10,0.77,0.1925,0.85,0.778866
