In [None]:
from datasets import load_dataset

# Work on the first 2000 training rows of the PubMedQA "pqa_artificial" configuration.
dataset = load_dataset("pubmed_qa", "pqa_artificial")["train"].select(range(2000))

# One (question, passage) pair for every context attached to a question.
# Rows whose context list is empty or missing contribute nothing.
multi_positive_pairs = [
    (row["question"], ctx)
    for row in dataset
    for ctx in (row["context"]["contexts"] or ())
]

# De-duplicated pool of every passage seen above; used later to draw negatives.
all_passages = {ctx for _, ctx in multi_positive_pairs}


In [6]:
import random

# Build (question, positive, negative) triplets: one randomly drawn negative per
# (question, positive) pair.

NEGATIVE_SAMPLING_SEED = 42  # fixed so the sampled negatives are identical on every run

# Sort the pool: iteration order of a set of strings changes between kernel sessions,
# which would make the seeded draws below irreproducible. Re-running this cell is
# idempotent because sorting an already de-duplicated list yields the same list.
all_passages = sorted(set(all_passages))
passage_set = set(all_passages)

# A question can have several gold passages. None of them may be used as its negative,
# otherwise the same (question, passage) pair would later be labelled both relevant
# and irrelevant.
positives_by_question = {}
for question, pos in multi_positive_pairs:
    positives_by_question.setdefault(question, set()).add(pos)

rng = random.Random(NEGATIVE_SAMPLING_SEED)  # local RNG: does not disturb global random state
triplets = []
for question, pos in multi_positive_pairs:
    excluded = positives_by_question[question]
    if passage_set <= excluded:
        # Same exception type random.choice raises on an empty candidate list.
        raise IndexError(f"No negative passage available for question: {question!r}")

    # Rejection sampling instead of rebuilding a filtered copy of the whole pool for
    # every pair; the guard above guarantees the loop terminates.
    neg = rng.choice(all_passages)
    while neg in excluded:
        neg = rng.choice(all_passages)
    triplets.append((question, pos, neg))


In [None]:
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader

# Start from the pretrained MiniLM sentence encoder and fine-tune it in place.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Each triplet yields two labelled pairs, positive first:
#   (question, positive) -> 1.0   and   (question, negative) -> 0.0
train_samples = [
    InputExample(texts=[query_text, passage], label=target)
    for query_text, positive, negative in triplets
    for passage, target in ((positive, 1.0), (negative, 0.0))
]

train_dataloader = DataLoader(train_samples, batch_size=16, shuffle=True)
train_loss = losses.CosineSimilarityLoss(model=model)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=25,
    warmup_steps=100,
)


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss


In [56]:
import faiss
import numpy as np

# Retrieval corpus: every passage that appears as a positive in the training triplets.
# sorted() gives a stable passage -> row-id mapping; the order of list(set(...)) changes
# between kernel sessions because string hashing is randomized.
unique_contexts = sorted({pos for _, pos, _ in triplets})

# L2-normalize the embeddings so the inner product computed by IndexFlatIP equals
# cosine similarity, the quantity CosineSimilarityLoss was trained on. This is a no-op
# if the model already emits unit-length vectors.
context_embeddings = model.encode(
    unique_contexts,
    convert_to_numpy=True,
    normalize_embeddings=True,
    show_progress_bar=True,
)

# Exact (brute-force) inner-product index; row i corresponds to unique_contexts[i].
index = faiss.IndexFlatIP(context_embeddings.shape[1])
index.add(context_embeddings)


Batches:   0%|          | 0/97 [00:00<?, ?it/s]

In [57]:
# Qualitative check: show the top-5 passages retrieved for a single hand-written query.
query = "What is the role of mitochondria in plant programmed cell death?"
query_embedding = model.encode([query], convert_to_numpy=True)

# D holds the similarity scores, I the corpus row ids; both have one row per query.
D, I = index.search(query_embedding, k=5)

# FAISS pads the id row with -1 when the index holds fewer than k vectors. Drop those
# entries so Python's negative indexing cannot silently return the last passage.
retrieved = [unique_contexts[i] for i in I[0] if i >= 0]

for i, ctx in enumerate(retrieved):
    print(f"Context {i+1}:\n{ctx[:300]}...\n")


Context 1:
Recently, pluripotency of induced pluripotent stem (iPS) cells has been displayed after producing adult mice, in tetraploid complementation assays. These studies lead us to the last piece of the puzzle for reprogramming somatic cells into fully pluripotent cells which function as embryonic stem cell...

Context 2:
We succeeded to the birth of viable and fertile adult mice derived entirely from reprogrammed ASC, indicating cell types other than fibroblasts can also be restored to the embryonic level of pluripotency....

Context 3:
A library of potential oxymatrine binding peptides was generated. Ubiquinol-cytochrome c reductase binding protein (UQCRB) was one of the candidate binding proteins of oxymatrine. UQCRB-displaying T7 phage binding numbers in the oxymatrine group were significantly higher than that in the control gro...

Context 4:
A T7 phage cDNA library of human CHB was biopanned by affinity selection using oxymatrine as bait. The interaction of oxymatrine with its

In [58]:
from collections import defaultdict
import numpy as np

# Map each question to all of its ground-truth contexts.
# extend() rather than assignment: a question that occurs in several dataset rows keeps
# the gold contexts of every row instead of only the last one.
gold_contexts_by_question = defaultdict(list)
for item in dataset:
    q = item["question"]
    gold_contexts_by_question[q].extend(item["context"]["contexts"] or [])

top_k = 5
correct = 0
total = 0

# Only questions with at least one gold context can be scored.
eval_questions = [q for q, golds in gold_contexts_by_question.items() if golds]

if eval_questions:
    # Encode and search all questions in one batched call instead of one model
    # forward pass and one index lookup per question.
    query_embeddings = model.encode(eval_questions, convert_to_numpy=True)
    D, I = index.search(query_embeddings, top_k)

    for question, ids in zip(eval_questions, I):
        gold_stripped = {gold.strip() for gold in gold_contexts_by_question[question]}
        # Skip FAISS's -1 padding (returned when the index has fewer than top_k rows);
        # indexing with -1 would silently pick the last passage.
        retrieved = [unique_contexts[i] for i in ids if i >= 0]

        # A question counts as correct when at least one retrieved passage is gold.
        if any(retr.strip() in gold_stripped for retr in retrieved):
            correct += 1
        total += 1

# NOTE: the printed "Recall@k" is a hit rate (share of questions with >= 1 gold passage
# in the top k), not the share of gold passages recovered.
# NOTE: the questions come from the same `dataset` rows that produced the fine-tuning
# pairs, so this number measures fit on seen data, not generalization.
recall_at_k = correct / total if total else float("nan")
print(f"Recall@{top_k}: {recall_at_k:.3f}")


Recall@5: 0.999


In [30]:
from sentence_transformers import SentenceTransformer

# Untuned reference retriever: the pretrained checkpoint loaded fresh, with no
# fine-tuning applied, to compare against the fine-tuned model.
BASELINE_CHECKPOINT = "sentence-transformers/all-MiniLM-L6-v2"
baseline_model = SentenceTransformer(BASELINE_CHECKPOINT)


In [31]:
import faiss
import numpy as np

# Unique corpus of passages from the dataset.
# sorted() gives a stable passage -> row-id mapping; the iteration order of a set of
# strings changes between kernel sessions.
# NOTE: this rebinds the notebook-global `all_passages`, which an earlier cell also uses.
all_passages = sorted({p for item in dataset for p in (item["context"]["contexts"] or ())})

# L2-normalize so the inner product computed by IndexFlatIP equals cosine similarity.
# This is a no-op if the model already emits unit-length vectors.
baseline_passage_embeddings = baseline_model.encode(
    all_passages,
    convert_to_numpy=True,
    normalize_embeddings=True,
    show_progress_bar=True,
)

# Build FAISS index; row i corresponds to all_passages[i].
# NOTE: this rebinds the notebook-global `index`, which the fine-tuned retriever cells
# also use. Rebuild that index before re-running them, or they will search this one.
index = faiss.IndexFlatIP(baseline_passage_embeddings.shape[1])
index.add(baseline_passage_embeddings)


Batches:   0%|          | 0/105 [00:00<?, ?it/s]

In [32]:
from collections import defaultdict

# Map each question to all of its gold contexts.
# extend() rather than assignment: a question that occurs in several dataset rows keeps
# the gold contexts of every row instead of only the last one.
gold_contexts_by_question = defaultdict(list)
for item in dataset:
    q = item["question"]
    gold_contexts_by_question[q].extend(item["context"]["contexts"] or [])

top_k = 5
correct = 0
total = 0

# Only questions with at least one gold context can be scored.
eval_questions = [q for q, golds in gold_contexts_by_question.items() if golds]

if eval_questions:
    # Encode and search all questions in one batched call instead of one model
    # forward pass and one index lookup per question.
    query_embeddings = baseline_model.encode(eval_questions, convert_to_numpy=True)
    D, I = index.search(query_embeddings, top_k)

    for question, ids in zip(eval_questions, I):
        gold_stripped = {gold.strip() for gold in gold_contexts_by_question[question]}
        # Skip FAISS's -1 padding (returned when the index has fewer than top_k rows);
        # indexing with -1 would silently pick the last passage.
        retrieved = [all_passages[i] for i in ids if i >= 0]

        # A question counts as correct when at least one retrieved passage is gold.
        if any(retr.strip() in gold_stripped for retr in retrieved):
            correct += 1
        total += 1

# NOTE: the printed "Recall@k" is a hit rate (share of questions with >= 1 gold passage
# in the top k), not the share of gold passages recovered.
recall_at_k = correct / total if total else float("nan")
print(f"[Baseline Retriever] Recall@{top_k}: {recall_at_k:.3f}")


[Baseline Retriever] Recall@5: 0.991
