In [1]:
# =======================================
# STAGE 4 — Triplet Training with Hard Negatives
# =======================================

import torch
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses
import pandas as pd
from tqdm import tqdm

# Release any cached CUDA allocator blocks, then report whether a GPU is usable.
torch.cuda.empty_cache()
has_cuda = torch.cuda.is_available()
print("CUDA available:", has_cuda)


CUDA available: True


In [2]:
# Path to Stage 1 fine-tuned model
BASE_MODEL_PATH = "models/pg16-minilm-mnrl"

# Use the GPU when one is visible and fall back to CPU otherwise, so this cell
# does not crash on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Stage 1 checkpoint directly onto the selected device.
model = SentenceTransformer(BASE_MODEL_PATH, device=device)

# Optional but helpful for VRAM: enable gradient checkpointing on the underlying
# Hugging Face transformer (the first module of the SentenceTransformer pipeline).
model._first_module().auto_model.gradient_checkpointing_enable()

print("Loaded base model for triplet training from:", BASE_MODEL_PATH)


Loaded base model for triplet training from: models/pg16-minilm-mnrl


In [3]:
# Load Hard Negative Triplets
TRIPLET_CSV = "pg16_hard_negatives.csv"

# Columns that the triplet-building and sanity-check cells read from each row.
REQUIRED_COLUMNS = ["query", "positive_passage", "hard_negative"]

df = pd.read_csv(TRIPLET_CSV)

# Fail fast with a clear message instead of a KeyError deep inside a later loop.
missing_columns = [col for col in REQUIRED_COLUMNS if col not in df.columns]
if missing_columns:
    raise ValueError(f"{TRIPLET_CSV} is missing required columns: {missing_columns}")

# A NaN cell would be formatted as the literal text "nan" when the triplet
# strings are built, silently polluting the training data, so surface it here.
null_counts = df[REQUIRED_COLUMNS].isna().sum()
if null_counts.any():
    print("WARNING: missing values per column:", null_counts[null_counts > 0].to_dict())

print("Loaded hard negatives:", len(df))

# Very low distinct counts (e.g. one query or one negative shared by every row)
# mean the triplets carry little training signal; show them so it is visible.
print("Distinct values per column:", df[REQUIRED_COLUMNS].nunique().to_dict())

df.head()


Loaded hard negatives: 100


Unnamed: 0,query,positive_passage,hard_negative
0,What does this passage describe?,Internals 59.3. Foreign Data Wrapper Helper Fu...,Part VI. Reference The entries in this Referen...
1,What does this passage describe?,sort order. Table 9.53. Array Operators Operat...,Part VI. Reference The entries in this Referen...
2,What does this passage describe?,SQL Syntax more expressions (separated by comm...,Part VI. Reference The entries in this Referen...
3,What does this passage describe?,SQL Key Words Key Word PostgreSQL SQL:2023 SQL...,Part VI. Reference The entries in this Referen...
4,What does this passage describe?,Monitoring Database Activity Whenever VACUUM i...,Part VI. Reference The entries in this Referen...


In [4]:
# Build InputExample triplets: (anchor, positive, negative)
# Each CSV row becomes one example; the query text gets a "query: " prefix and
# both passages get a "passage: " prefix before being handed to the model.
triplet_columns = zip(df["query"], df["positive_passage"], df["hard_negative"])

train_examples = [
    InputExample(
        texts=[
            f"query: {query_text}",
            f"passage: {positive_text}",
            f"passage: {negative_text}",
        ]
    )
    for query_text, positive_text, negative_text in triplet_columns
]

print("Total triplet examples:", len(train_examples))


Total triplet examples: 100


In [5]:
# DataLoader + Triplet Loss
BATCH_SIZE = 8  # adjust if your GPU is comfy

# Serve the triplet examples in shuffled mini-batches of BATCH_SIZE.
loader_options = {"batch_size": BATCH_SIZE, "shuffle": True}
train_dataloader = DataLoader(train_examples, **loader_options)

# TripletLoss works on (anchor, positive, negative) embeddings; here it is
# configured with cosine distance and a margin of 0.3.
cosine_metric = losses.TripletDistanceMetric.COSINE
train_loss = losses.TripletLoss(
    model=model,
    triplet_margin=0.3,
    distance_metric=cosine_metric,
)

print("Dataloader ready with batch size:", BATCH_SIZE)


Dataloader ready with batch size: 8


In [6]:
# Free VRAM before training: ask PyTorch's CUDA caching allocator to release
# memory it is holding but not currently using. Memory owned by live tensors
# (e.g. the loaded model) is not affected.
torch.cuda.empty_cache()
print("GPU cache cleared. Ready to train.")


GPU cache cleared. Ready to train.


In [7]:
# Train with Triplet Loss (Stage 4)
OUTPUT_PATH = "models/pg16-minilm-triplet"

num_epochs = 1  # you can raise to 2–3 if time/VRAM allows

# Warm the learning rate up over the first 10% of one epoch's batches.
steps_per_epoch = len(train_dataloader)
warmup_step_count = int(0.1 * steps_per_epoch)

fit_settings = dict(
    epochs=num_epochs,
    warmup_steps=warmup_step_count,
    show_progress_bar=True,
    use_amp=True,  # mixed precision, saves VRAM
    output_path=OUTPUT_PATH,
)

# Single training objective: the triplet dataloader paired with the triplet loss.
model.fit(train_objectives=[(train_dataloader, train_loss)], **fit_settings)

print("✅ Stage 4 complete! Final model saved to:", OUTPUT_PATH)


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss


✅ Stage 4 complete! Final model saved to: models/pg16-minilm-triplet


In [8]:
# Quick Sanity Check: Similarities
# For a few training rows, compare cos(query, positive) with cos(query, negative)
# using the model that was just saved; the positive should score higher.
from sentence_transformers import util

# Reload the final model from disk so the check exercises the saved artifact.
# Use the GPU when available and fall back to CPU instead of crashing.
device = "cuda" if torch.cuda.is_available() else "cpu"
final_model = SentenceTransformer(OUTPUT_PATH, device=device)

# Fixed seed keeps the inspected rows stable across runs; clamp the sample size
# so a frame with fewer than 3 rows does not raise.
sample = df.sample(min(3, len(df)), random_state=42)

for _, row in sample.iterrows():
    anchor   = f"query: {row['query']}"
    positive = f"passage: {row['positive_passage']}"
    negative = f"passage: {row['hard_negative']}"

    # Encode the three texts in one call; the result has one embedding per
    # input, in the same order as the list.
    emb_a, emb_p, emb_n = final_model.encode(
        [anchor, positive, negative],
        convert_to_tensor=True,
    )

    sim_pos = util.cos_sim(emb_a, emb_p).item()
    sim_neg = util.cos_sim(emb_a, emb_n).item()

    print("\n==============================")
    print("Query:", row['query'])
    print(f"cos(query, positive) = {sim_pos:.3f}")
    print(f"cos(query, negative) = {sim_neg:.3f}")



Query: What does this passage describe?
cos(query, positive) = 0.104
cos(query, negative) = -0.333

Query: What does this passage describe?
cos(query, positive) = 0.297
cos(query, negative) = -0.333

Query: What does this passage describe?
cos(query, positive) = 0.180
cos(query, negative) = -0.333
