In [1]:
# ================================
# Stage 3 — Hard Negative Mining
# ================================

import chromadb
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch

# Ask PyTorch's CUDA caching allocator to release any unused cached memory
# before the model is loaded.
torch.cuda.empty_cache()

# Informational only: print which device type is available. The value is
# printed, not stored.
print(
    "Device:",
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
)


Device: cuda


In [2]:
#2. Load Fine-Tuned Model (same one from Stage 1)
FINETUNED_MODEL_PATH = "models/pg16-minilm-mnrl"

# Choose the device at runtime rather than assuming a GPU is present:
# moving the model to a hard-coded "cuda" raises on CPU-only machines.
MODEL_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the fine-tuned checkpoint directly onto the selected device.
model = SentenceTransformer(FINETUNED_MODEL_PATH, device=MODEL_DEVICE)

print("Loaded fine-tuned model:", FINETUNED_MODEL_PATH)


Loaded fine-tuned model: models/pg16-minilm-mnrl


In [3]:
#3. Connect to Chroma + Load All Documents
CHROMA_DIR = "chroma_pg16_minilm"
COLLECTION_NAME = "pg16_minilm"

# Open the on-disk Chroma store and fetch the named collection from it.
client = chromadb.PersistentClient(path=CHROMA_DIR)
collection = client.get_collection(COLLECTION_NAME)

# Fetch the stored records. No ids or filter are passed, so this is expected
# to return the whole collection. "ids" is read from the result alongside the
# two explicitly requested fields.
docs = collection.get(include=["documents", "metadatas"])
documents, metadatas, ids = (
    docs[field] for field in ("documents", "metadatas", "ids")
)

print("Chroma loaded:")
print(" • Total passages:", len(documents))


Chroma loaded:
 • Total passages: 6865


In [4]:
#4. Load Stage 1 Training Pairs
TRAIN_CSV = "pg16_train_pairs.csv"

# Read the training pairs; later cells read the "query" and
# "positive_passage" columns of this frame.
df = pd.read_csv(filepath_or_buffer=TRAIN_CSV)

print("Loaded training pairs:", df.shape[0])
df.head(n=5)


Loaded training pairs: 100


Unnamed: 0,query,positive_passage
0,What does this passage describe?,Internals 59.3. Foreign Data Wrapper Helper Fu...
1,What does this passage describe?,sort order. Table 9.53. Array Operators Operat...
2,What does this passage describe?,SQL Syntax more expressions (separated by comm...
3,What does this passage describe?,SQL Key Words Key Word PostgreSQL SQL:2023 SQL...
4,What does this passage describe?,Monitoring Database Activity Whenever VACUUM i...


In [5]:
#5. Embed Queries in Batches
def embed_batch(texts, batch_size=64):
    """Encode `texts` with the module-level `model`, one slice at a time.

    Args:
        texts: Sequence of inputs; it must support `len()` and slicing.
        batch_size: Number of items per slice. The same value is forwarded
            to `model.encode` as its own `batch_size`.

    Returns:
        A NumPy array built from the per-text vectors, in input order.
    """
    vectors = []
    slice_starts = range(0, len(texts), batch_size)
    # tqdm tracks progress over slices; encode's own bar is turned off so the
    # two do not interleave.
    for start in tqdm(slice_starts):
        chunk = texts[start : start + batch_size]
        chunk_vectors = model.encode(
            chunk,
            convert_to_numpy=True,
            batch_size=batch_size,
            show_progress_bar=False
        )
        vectors.extend(chunk_vectors)
    return np.array(vectors)

# Prefix every training query with "query: " before encoding.
# NOTE(review): presumably this matches the input format used when the model
# was fine-tuned and when the Chroma index was built — confirm.
queries = list(map(lambda q: f"query: {q}", df["query"].tolist()))
query_embeddings = embed_batch(queries, 64)

print("Query embedding shape:", query_embeddings.shape)


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  3.90it/s]

Query embedding shape: (100, 384)





In [6]:
#6. Hard Negative Mining
# How many nearest passages to retrieve per query before the positive is
# filtered out of the candidate list.
N_CANDIDATES = 10

hard_negatives = []

# NOTE(review): the previewed results of this stage show the same passage
# selected as the hard negative for every displayed row, and the displayed
# training queries all share identical text. Identical queries produce
# identical embeddings and therefore identical retrievals — confirm this is
# intended before using the output for triplet training.

# `pos` is the row's position and `i` is its DataFrame index label.
# query_embeddings was built from df["query"] in row order, so it must be
# indexed by position; the label only matches the position when df has the
# default 0..n-1 index.
for pos, (i, row) in enumerate(tqdm(df.iterrows(), total=len(df))):
    q_emb = query_embeddings[pos].tolist()
    positive = row["positive_passage"]
    # Normalise once per row instead of once per retrieved candidate.
    positive_text = positive.strip()

    # Query Chroma for the nearest stored passages to this query embedding.
    result = collection.query(
        query_embeddings=[q_emb],
        n_results=N_CANDIDATES,
        include=["documents"]
    )

    # One query embedding was sent, so the first (only) result list is ours.
    retrieved_docs = result["documents"][0]

    # Drop the positive itself (compared on whitespace-stripped text) and skip
    # any record that came back without document text.
    candidates = [
        d for d in retrieved_docs
        if d is not None and d.strip() != positive_text
    ]

    if not candidates:
        # Every retrieved passage matched the positive text, or nothing
        # usable was retrieved.
        print("Warning: no negative candidate found for index:", i)
        continue

    # Hard negative = highest-ranked retrieval that is not the positive.
    hard_neg = candidates[0]

    hard_negatives.append({
        "query": row["query"],
        "positive_passage": positive,
        "hard_negative": hard_neg
    })


100%|████████████████████████████████████████| 100/100 [00:00<00:00, 465.53it/s]


In [7]:
#7. Save Hard Negatives for Stage 4 Triplet Training
# Turn the mined records into a table with one row per
# (query, positive, hard negative) triple.
df_hard = pd.DataFrame(data=hard_negatives)

# Write without the index column so the CSV holds only the record fields.
df_hard.to_csv(path_or_buf="pg16_hard_negatives.csv", index=False)

print("Saved:", df_hard.shape[0], "hard negatives → pg16_hard_negatives.csv")
df_hard.head(n=5)


Saved: 100 hard negatives → pg16_hard_negatives.csv


Unnamed: 0,query,positive_passage,hard_negative
0,What does this passage describe?,Internals 59.3. Foreign Data Wrapper Helper Fu...,Part VI. Reference The entries in this Referen...
1,What does this passage describe?,sort order. Table 9.53. Array Operators Operat...,Part VI. Reference The entries in this Referen...
2,What does this passage describe?,SQL Syntax more expressions (separated by comm...,Part VI. Reference The entries in this Referen...
3,What does this passage describe?,SQL Key Words Key Word PostgreSQL SQL:2023 SQL...,Part VI. Reference The entries in this Referen...
4,What does this passage describe?,Monitoring Database Activity Whenever VACUUM i...,Part VI. Reference The entries in this Referen...


In [8]:
#8. Quick Sanity Check
# Print up to three mined triples for a visual check. The count is clamped to
# the table size so a short or empty df_hard does not raise IndexError from
# .iloc.
for i in range(min(3, len(df_hard))):
    sample = df_hard.iloc[i]  # look the row up once and reuse it below
    print("\n============== SAMPLE", i, "==============")
    print("Query:", sample["query"])
    print("--- POSITIVE ---")
    # Only the first 300 characters of each passage are shown.
    print(sample["positive_passage"][:300], "...")
    print("--- HARD NEGATIVE ---")
    print(sample["hard_negative"][:300], "...")



Query: What does this passage describe?
--- POSITIVE ---
Internals 59.3. Foreign Data Wrapper Helper Functions .......................................................2424 59.4. Foreign Data Wrapper Query Planning .........................................................2425 59.5. Row Locking in Foreign Data Wrappers ....................................... ...
--- HARD NEGATIVE ---
Part VI. Reference The entries in this Reference are meant to provide in reasonable length an authoritative, complete, and formal summary about their respective subjects. More information about the use of PostgreSQL, in narrative, tutorial, or example form, can be found in other parts of this book.  ...

Query: What does this passage describe?
--- POSITIVE ---
sort order. Table 9.53. Array Operators Operator Description Example(s) anyarray @> anyarray → boolean Does the first array contain the second, that is, does each element appearing in the second array equal some element of the first array? (Duplicates 