In [1]:
import torch

torch.cuda.is_available()

True

In [2]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

In [3]:
import sys

sys.path.insert(0, "../src")

from mlops_project.data import ArxivPapersDataset
from pathlib import Path
import random

# Load the dataset using the project's data module
data_dir = Path("../data")
train_dataset = ArxivPapersDataset(split="train", data_dir=data_dir).dataset
test_dataset = ArxivPapersDataset(split="test", data_dir=data_dir).dataset

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Columns: {train_dataset.column_names}")

Train samples: 2039695
Test samples: 509924
Columns: ['primary_subject', 'subjects', 'abstract', 'title']


In [4]:
from collections import defaultdict
import numpy as np
from datasets import Dataset, load_from_disk


def create_contrastive_pairs(dataset, num_pairs: int = 100000, text_field: str = "abstract", seed: int = 42):
    """
    Create balanced positive and negative pairs for ContrastiveLoss.

    Returns a dataset with columns: sentence1, sentence2, label
    - label=1.0 for positive pairs (same subject)
    - label=0.0 for negative pairs (different subjects)
    """
    random.seed(seed)
    np.random.seed(seed)

    # Group indices by subject
    subject_to_indices = defaultdict(list)
    for idx, subject in enumerate(dataset["primary_subject"]):
        subject_to_indices[subject].append(idx)

    subjects = list(subject_to_indices.keys())
    print(f"Found {len(subjects)} unique subjects")

    pairs = {"sentence1": [], "sentence2": [], "label": []}
    num_positive = num_pairs // 2
    num_negative = num_pairs - num_positive

    # Create positive pairs (same subject)
    print(f"Creating {num_positive} positive pairs...")
    for _ in range(num_positive):
        subject = random.choice([s for s in subjects if len(subject_to_indices[s]) >= 2])
        idx1, idx2 = random.sample(subject_to_indices[subject], 2)
        pairs["sentence1"].append(dataset[idx1][text_field])
        pairs["sentence2"].append(dataset[idx2][text_field])
        pairs["label"].append(1.0)

    # Create negative pairs (different subjects)
    print(f"Creating {num_negative} negative pairs...")
    for _ in range(num_negative):
        subj1, subj2 = random.sample(subjects, 2)
        idx1 = random.choice(subject_to_indices[subj1])
        idx2 = random.choice(subject_to_indices[subj2])
        pairs["sentence1"].append(dataset[idx1][text_field])
        pairs["sentence2"].append(dataset[idx2][text_field])
        pairs["label"].append(0.0)

    return Dataset.from_dict(pairs)


def load_or_create_pairs(dataset, save_path: Path, num_pairs: int, text_field: str = "abstract", seed: int = 42):
    """Load pairs from disk if they exist, otherwise create and save them."""
    if save_path.exists():
        print(f"Loading cached pairs from {save_path}")
        return load_from_disk(str(save_path))

    print(f"Creating new pairs (will be cached at {save_path})")
    pairs = create_contrastive_pairs(dataset, num_pairs=num_pairs, text_field=text_field, seed=seed)
    pairs.save_to_disk(str(save_path))
    return pairs

In [5]:
# Create/load training pairs
train_pairs_path = data_dir / "train_pairs_cl_1000000"
train_pairs = load_or_create_pairs(train_dataset, train_pairs_path, num_pairs=1000000, seed=42)
print(f"\nTraining pairs: {len(train_pairs)}")
# print(train_pairs[0])

Creating new pairs (will be cached at ../data/train_pairs_cl_1000000)
Found 148 unique subjects
Creating 500000 positive pairs...
Creating 500000 negative pairs...


Saving the dataset (0/4 shards):   0%|          | 0/1000000 [00:00<?, ? examples/s]


Training pairs: 1000000


In [6]:
# Create/load evaluation pairs
eval_pairs_path = data_dir / "eval_pairs_cl"
eval_pairs = load_or_create_pairs(test_dataset, eval_pairs_path, num_pairs=10000, seed=42)
print(f"Evaluation pairs: {len(eval_pairs)}")

Loading cached pairs from ../data/eval_pairs_cl
Evaluation pairs: 10000


In [7]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator

def create_ir_evaluator(dataset, sample_size: int = 5000, name: str = "arxiv-retrieval"):
    np.random.seed(42)
    indices = np.random.choice(len(dataset), min(sample_size, len(dataset)), replace=False)
    
    # Split indices: first 20% for queries, rest for corpus (no overlap)
    num_queries = sample_size // 5
    query_indices = indices[:num_queries]
    corpus_indices = indices[num_queries:]
    
    queries = {}
    corpus = {}
    relevant_docs = {}
    subject_to_corpus_ids = defaultdict(set)
    
    # Build corpus from corpus_indices only
    for i, idx in enumerate(corpus_indices):
        idx = int(idx)
        corpus_id = f"doc_{i}"
        corpus[corpus_id] = dataset[idx]["abstract"]
        subject = dataset[idx]["primary_subject"]
        subject_to_corpus_ids[subject].add(corpus_id)
    
    # Build queries from query_indices (no overlap with corpus)
    for i, idx in enumerate(query_indices):
        idx = int(idx)
        query_id = f"query_{i}"
        queries[query_id] = dataset[idx]["abstract"]
        subject = dataset[idx]["primary_subject"]
        # Relevant docs are corpus docs with same subject
        relevant_docs[query_id] = subject_to_corpus_ids[subject].copy()
    
    # Filter out queries with no relevant docs
    queries = {qid: q for qid, q in queries.items() if len(relevant_docs.get(qid, set())) > 0}
    relevant_docs = {qid: docs for qid, docs in relevant_docs.items() if qid in queries}
    
    print(f"IR Evaluator: {len(queries)} queries, {len(corpus)} corpus docs (no overlap)")
    
    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=name,
        precision_recall_at_k=[1, 5, 10],
        show_progress_bar=True,
    )

ir_evaluator = create_ir_evaluator(test_dataset, sample_size=5000)
print("IR Evaluator created for precision@k metrics")

IR Evaluator: 996 queries, 4000 corpus docs (no overlap)
IR Evaluator created for precision@k metrics


In [8]:
# Evaluate baseline (before fine-tuning)
print("=== Baseline Precision@k (before fine-tuning) ===")
baseline_results = ir_evaluator(model)
for key, value in baseline_results.items():
    print(f"{key}: {value:.4f}")

=== Baseline Precision@k (before fine-tuning) ===


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.05s/it]

arxiv-retrieval_cosine_accuracy@1: 0.4900
arxiv-retrieval_cosine_accuracy@3: 0.7430
arxiv-retrieval_cosine_accuracy@5: 0.8153
arxiv-retrieval_cosine_accuracy@10: 0.8815
arxiv-retrieval_cosine_precision@1: 0.4900
arxiv-retrieval_cosine_precision@5: 0.4466
arxiv-retrieval_cosine_precision@10: 0.4149
arxiv-retrieval_cosine_recall@1: 0.0115
arxiv-retrieval_cosine_recall@5: 0.0453
arxiv-retrieval_cosine_recall@10: 0.0767
arxiv-retrieval_cosine_ndcg@10: 0.4343
arxiv-retrieval_cosine_mrr@10: 0.6277
arxiv-retrieval_cosine_map@100: 0.2006





In [9]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import ContrastiveLoss, OnlineContrastiveLoss

# Define training arguments
training_args = SentenceTransformerTrainingArguments(
    output_dir="../models/contrastive-minilm",
    num_train_epochs=1,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    learning_rate=3e-4,
    warmup_ratio=0.1,
    eval_strategy="steps",
    eval_steps=500,
    eval_on_start=True,
    save_strategy="steps",
    save_steps=500,
    logging_steps=100,
    torch_compile=torch.cuda.is_available(),
    fp16=torch.cuda.is_available(),
    tf32=torch.cuda.is_available())

# Initialize loss function
loss = OnlineContrastiveLoss(model)

# Create trainer with IR evaluator for precision@k
trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=train_pairs,
    eval_dataset=eval_pairs,
    loss=loss,
    evaluator=ir_evaluator,
)

print("Trainer initialized with precision@k evaluator. Ready to train!")

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Trainer initialized with precision@k evaluator. Ready to train!


In [10]:
# Start training
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mthorhojhus[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Arxiv-retrieval Cosine Accuracy@1,Arxiv-retrieval Cosine Accuracy@3,Arxiv-retrieval Cosine Accuracy@5,Arxiv-retrieval Cosine Accuracy@10,Arxiv-retrieval Cosine Precision@1,Arxiv-retrieval Cosine Precision@5,Arxiv-retrieval Cosine Precision@10,Arxiv-retrieval Cosine Recall@1,Arxiv-retrieval Cosine Recall@5,Arxiv-retrieval Cosine Recall@10,Arxiv-retrieval Cosine Ndcg@10,Arxiv-retrieval Cosine Mrr@10,Arxiv-retrieval Cosine Map@100
0,No log,27.913651,0.48996,0.742972,0.815261,0.881526,0.48996,0.446386,0.414759,0.011522,0.045248,0.076651,0.43423,0.627709,0.200542
500,3.466500,3.126542,0.424699,0.650602,0.737952,0.835341,0.424699,0.398795,0.379116,0.008699,0.040405,0.073222,0.391163,0.557256,0.196756
1000,3.335500,3.298453,0.383534,0.621486,0.710843,0.810241,0.383534,0.360843,0.34759,0.008205,0.036649,0.065868,0.358766,0.522611,0.171785
1500,3.207000,3.070521,0.423695,0.646586,0.721888,0.821285,0.423695,0.378916,0.364357,0.008599,0.036175,0.069692,0.377143,0.552732,0.200713
2000,2.970100,2.994114,0.406627,0.636546,0.723896,0.820281,0.406627,0.384337,0.373193,0.008295,0.037724,0.073362,0.383467,0.543791,0.199251
2500,2.883300,3.033103,0.410643,0.63253,0.714859,0.825301,0.410643,0.38012,0.366265,0.008147,0.039037,0.071679,0.377875,0.545239,0.197523
3000,2.858300,2.785812,0.430723,0.655622,0.74498,0.821285,0.430723,0.395582,0.376205,0.009346,0.041029,0.074624,0.388525,0.55685,0.215924


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.26s/it]
W0109 19:12:51.865000 113476 torch/fx/experimental/symbolic_shapes.py:6823] [0/2] _maybe_guard_rel() was called on non-relation expression Eq(s47, s8) | Eq(s8, 1)


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.11s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.10s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


KeyboardInterrupt: 

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7bd354416060>> (for post_run_cell), with arguments args (<ExecutionResult object at 7bd3c1f4fe60, execution_count=10 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 7bd3c1f4fdd0, raw_cell="# Start training
trainer.train()" transformed_cell="# Start training
trainer.train()
" store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2Bdesktop/home/thorh/MLOpsProject/notebooks/model_train_cl.ipynb#X12sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


ConnectionResetError: Connection lost

In [None]:
# Evaluate after fine-tuning
print("=== Precision@k (after fine-tuning) ===")
final_results = ir_evaluator(model)
for key, value in final_results.items():
    if "precision" in key or "recall" in key or "mrr" in key or "ndcg" in key:
        print(f"{key}: {value:.4f}")

# Compare improvement
print("\n=== Improvement ===")
for key in baseline_results:
    if "precision" in key or "ndcg" in key:
        improvement = final_results[key] - baseline_results[key]
        print(f"{key}: {improvement:+.4f}")

=== Precision@k (after fine-tuning) ===


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:02<00:00,  2.09s/it]

arxiv-retrieval_cosine_precision@1: 0.4538
arxiv-retrieval_cosine_precision@5: 0.4102
arxiv-retrieval_cosine_precision@10: 0.3912
arxiv-retrieval_cosine_recall@1: 0.0096
arxiv-retrieval_cosine_recall@5: 0.0404
arxiv-retrieval_cosine_recall@10: 0.0771
arxiv-retrieval_cosine_ndcg@10: 0.4065
arxiv-retrieval_cosine_mrr@10: 0.5780

=== Improvement ===
arxiv-retrieval_cosine_precision@1: -0.0361
arxiv-retrieval_cosine_precision@5: -0.0363
arxiv-retrieval_cosine_precision@10: -0.0237
arxiv-retrieval_cosine_ndcg@10: -0.0278



