In [1]:
import torch
torch.cuda.is_available()

True

In [2]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

In [3]:
import sys
sys.path.insert(0, "../src")

from mlops_project.data import ArxivPapersDataset
from pathlib import Path
import random

# Load the dataset using the project's data module
data_dir = Path("../data")
train_dataset = ArxivPapersDataset(split="train", data_dir=data_dir).dataset
test_dataset = ArxivPapersDataset(split="test", data_dir=data_dir).dataset

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Columns: {train_dataset.column_names}")

Train samples: 2039695
Test samples: 509924
Columns: ['primary_subject', 'subjects', 'abstract', 'title']


In [4]:
from collections import defaultdict
import numpy as np
from datasets import Dataset, load_from_disk

def create_contrastive_pairs(dataset, num_pairs: int = 100000, text_field: str = "abstract", seed: int = 42):
    """
    Create balanced positive and negative pairs for ContrastiveLoss.
    
    Returns a dataset with columns: sentence1, sentence2, label
    - label=1.0 for positive pairs (same subject)
    - label=0.0 for negative pairs (different subjects)
    """
    random.seed(seed)
    np.random.seed(seed)
    
    # Group indices by subject
    subject_to_indices = defaultdict(list)
    for idx, subject in enumerate(dataset["primary_subject"]):
        subject_to_indices[subject].append(idx)
    
    subjects = list(subject_to_indices.keys())
    print(f"Found {len(subjects)} unique subjects")
    
    pairs = {"sentence1": [], "sentence2": [], "label": []}
    num_positive = num_pairs // 2
    num_negative = num_pairs - num_positive
    
    # Create positive pairs (same subject)
    print(f"Creating {num_positive} positive pairs...")
    for _ in range(num_positive):
        subject = random.choice([s for s in subjects if len(subject_to_indices[s]) >= 2])
        idx1, idx2 = random.sample(subject_to_indices[subject], 2)
        pairs["sentence1"].append(dataset[idx1][text_field])
        pairs["sentence2"].append(dataset[idx2][text_field])
        pairs["label"].append(1.0)
    
    # Create negative pairs (different subjects)
    print(f"Creating {num_negative} negative pairs...")
    for _ in range(num_negative):
        subj1, subj2 = random.sample(subjects, 2)
        idx1 = random.choice(subject_to_indices[subj1])
        idx2 = random.choice(subject_to_indices[subj2])
        pairs["sentence1"].append(dataset[idx1][text_field])
        pairs["sentence2"].append(dataset[idx2][text_field])
        pairs["label"].append(0.0)
    
    return Dataset.from_dict(pairs)

def load_or_create_pairs(dataset, save_path: Path, num_pairs: int, text_field: str = "abstract", seed: int = 42):
    """Load pairs from disk if they exist, otherwise create and save them."""
    if save_path.exists():
        print(f"Loading cached pairs from {save_path}")
        return load_from_disk(str(save_path))
    
    print(f"Creating new pairs (will be cached at {save_path})")
    pairs = create_contrastive_pairs(dataset, num_pairs=num_pairs, text_field=text_field, seed=seed)
    pairs.save_to_disk(str(save_path))
    return pairs

# Create/load training pairs
train_pairs_path = data_dir / "train_pairs"
train_pairs = load_or_create_pairs(train_dataset, train_pairs_path, num_pairs=100000)
print(f"\nTraining pairs: {len(train_pairs)}")
print(train_pairs[0])

Creating new pairs (will be cached at ../data/train_pairs)
Found 148 unique subjects
Creating 50000 positive pairs...
Creating 50000 negative pairs...


Saving the dataset (0/1 shards):   0%|          | 0/100000 [00:00<?, ? examples/s]


Training pairs: 100000
{'sentence1': 'The paper describes the project, implementation and test of a C-band (5GHz) Low Noise Amplifier (LNA) using new low noise Pseudomorphic High Electron Mobility Transistors (pHEMTS) from Avago. The amplifier was developed to be used as a cost effective solution in a receiver chain for Galactic Emission Mapping (GEM-P) project in Portugal with the objective of finding affordable solutions not requiring strong cryogenic operation, as is the case of massive projects like the Square Kilometer Array (SKA), in Earth Sensing projects and other niches like microwave reflectometry. The particular application and amplifier requirements are first introduced. Several commercially available low noise devices were selected and the noise performance simulated. An ultra-low noise pHEMT was used for an implementation that achieved a Noise Figure of 0.6 dB with 13 dB gain at 5 GHz. The design, simulation and measured results of the prototype are presented and discuss

In [5]:
# Create/load evaluation pairs
eval_pairs_path = data_dir / "eval_pairs"
eval_pairs = load_or_create_pairs(test_dataset, eval_pairs_path, num_pairs=10000, seed=123)
print(f"Evaluation pairs: {len(eval_pairs)}")

Creating new pairs (will be cached at ../data/eval_pairs)
Found 148 unique subjects
Creating 5000 positive pairs...
Creating 5000 negative pairs...


Saving the dataset (0/1 shards):   0%|          | 0/10000 [00:00<?, ? examples/s]

Evaluation pairs: 10000


In [6]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator
import numpy as np

def create_ir_evaluator(dataset, model, sample_size: int = 5000, name: str = "arxiv-retrieval"):
    """
    Create an InformationRetrievalEvaluator for precision@k evaluation.
    
    For each query, relevant documents are those with the same primary_subject.
    This measures how well the model retrieves papers from the same category.
    """
    np.random.seed(42)
    
    # Sample a subset for evaluation
    indices = np.random.choice(len(dataset), min(sample_size, len(dataset)), replace=False)
    
    queries = {}  # query_id -> query text
    corpus = {}   # corpus_id -> document text
    relevant_docs = {}  # query_id -> set of relevant corpus_ids
    
    # Build corpus and group by subject
    subject_to_corpus_ids = defaultdict(set)
    
    for i, idx in enumerate(indices):
        idx = int(idx)
        corpus_id = f"doc_{i}"
        corpus[corpus_id] = dataset[idx]["abstract"]
        subject = dataset[idx]["primary_subject"]
        subject_to_corpus_ids[subject].add(corpus_id)
    
    # Use a subset as queries, rest as corpus for retrieval
    query_indices = indices[:sample_size // 5]  # 20% as queries
    
    for i, idx in enumerate(query_indices):
        idx = int(idx)
        query_id = f"query_{i}"
        queries[query_id] = dataset[idx]["abstract"]
        subject = dataset[idx]["primary_subject"]
        # Relevant docs are all docs with same subject (excluding self)
        doc_id = f"doc_{i}"
        relevant_docs[query_id] = subject_to_corpus_ids[subject] - {doc_id}
    
    # Filter out queries with no relevant docs
    queries = {qid: q for qid, q in queries.items() if len(relevant_docs.get(qid, set())) > 0}
    relevant_docs = {qid: docs for qid, docs in relevant_docs.items() if qid in queries}
    
    print(f"IR Evaluator: {len(queries)} queries, {len(corpus)} corpus docs")
    
    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=name,
        precision_recall_at_k=[1, 5, 10],
        show_progress_bar=True,
    )

# Create the IR evaluator for precision@k
ir_evaluator = create_ir_evaluator(test_dataset, model, sample_size=5000)
print("IR Evaluator created for precision@k metrics")

IR Evaluator: 998 queries, 5000 corpus docs
IR Evaluator created for precision@k metrics


In [7]:
# Evaluate baseline (before fine-tuning)
print("=== Baseline Precision@k (before fine-tuning) ===")
baseline_results = ir_evaluator(model)
for key, value in baseline_results.items():
    if "precision" in key or "recall" in key or "mrr" in key or "ndcg" in key:
        print(f"{key}: {value:.4f}")

=== Baseline Precision@k (before fine-tuning) ===


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/157 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.54s/it]

arxiv-retrieval_cosine_precision@1: 0.0000
arxiv-retrieval_cosine_precision@5: 0.3741
arxiv-retrieval_cosine_precision@10: 0.3905
arxiv-retrieval_cosine_recall@1: 0.0000
arxiv-retrieval_cosine_recall@5: 0.0323
arxiv-retrieval_cosine_recall@10: 0.0620
arxiv-retrieval_cosine_ndcg@10: 0.3473
arxiv-retrieval_cosine_mrr@10: 0.3461





In [8]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import ContrastiveLoss

# Define training arguments
training_args = SentenceTransformerTrainingArguments(
    output_dir="../models/contrastive-minilm",
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    logging_steps=100,
    fp16=torch.cuda.is_available(),
)

# Initialize loss function
loss = ContrastiveLoss(model)

# Create trainer with IR evaluator for precision@k
trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=train_pairs,
    eval_dataset=eval_pairs,
    loss=loss,
    evaluator=ir_evaluator,
)

print("Trainer initialized with precision@k evaluator. Ready to train!")

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Trainer initialized with precision@k evaluator. Ready to train!


In [9]:
# Start training
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mthorhojhus[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Arxiv-retrieval Cosine Accuracy@1,Arxiv-retrieval Cosine Accuracy@3,Arxiv-retrieval Cosine Accuracy@5,Arxiv-retrieval Cosine Accuracy@10,Arxiv-retrieval Cosine Precision@1,Arxiv-retrieval Cosine Precision@5,Arxiv-retrieval Cosine Precision@10,Arxiv-retrieval Cosine Recall@1,Arxiv-retrieval Cosine Recall@5,Arxiv-retrieval Cosine Recall@10,Arxiv-retrieval Cosine Ndcg@10,Arxiv-retrieval Cosine Mrr@10,Arxiv-retrieval Cosine Map@100
500,0.0141,0.014209,0.0,0.587174,0.728457,0.845691,0.0,0.334469,0.35521,0.0,0.02752,0.056093,0.314754,0.31805,0.175042
1000,0.0137,0.013575,0.0,0.601202,0.729459,0.844689,0.0,0.336874,0.359619,0.0,0.028562,0.057467,0.318682,0.320458,0.182369
1500,0.0131,0.013441,0.0,0.59018,0.724449,0.840681,0.0,0.336273,0.359519,0.0,0.027743,0.057915,0.318382,0.317599,0.184057


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/157 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.61s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/157 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.61s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/157 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.61s/it]


TrainOutput(global_step=1563, training_loss=0.01710545784071021, metrics={'train_runtime': 259.6964, 'train_samples_per_second': 385.065, 'train_steps_per_second': 6.019, 'total_flos': 0.0, 'train_loss': 0.01710545784071021, 'epoch': 1.0})

In [10]:
# Evaluate after fine-tuning
print("=== Precision@k (after fine-tuning) ===")
final_results = ir_evaluator(model)
for key, value in final_results.items():
    if "precision" in key or "recall" in key or "mrr" in key or "ndcg" in key:
        print(f"{key}: {value:.4f}")

# Compare improvement
print("\n=== Improvement ===")
for key in baseline_results:
    if "precision" in key or "ndcg" in key:
        improvement = final_results[key] - baseline_results[key]
        print(f"{key}: {improvement:+.4f}")

=== Precision@k (after fine-tuning) ===


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/157 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.62s/it]

arxiv-retrieval_cosine_precision@1: 0.0000
arxiv-retrieval_cosine_precision@5: 0.3369
arxiv-retrieval_cosine_precision@10: 0.3602
arxiv-retrieval_cosine_recall@1: 0.0000
arxiv-retrieval_cosine_recall@5: 0.0280
arxiv-retrieval_cosine_recall@10: 0.0581
arxiv-retrieval_cosine_ndcg@10: 0.3190
arxiv-retrieval_cosine_mrr@10: 0.3183

=== Improvement ===
arxiv-retrieval_cosine_precision@1: +0.0000
arxiv-retrieval_cosine_precision@5: -0.0373
arxiv-retrieval_cosine_precision@10: -0.0303
arxiv-retrieval_cosine_ndcg@10: -0.0283



