In [1]:
# %%
"""
Contrastive fine-tuning notebook for multi-qa-mpnet on Mahabharata dataset.
Inputs required:
 - train_pairs.jsonl : lines of {"anchor","positive","negatives","meta"}
 - passages.jsonl (optional) : passages with "text" fields (for indexing/eval)
Outputs:
 - saved model under out_dir (default ./mpnet_ft)
 - passage embeddings: out_dir/passage_embeddings.npy
 - faiss index: out_dir/faiss_index.idx
"""
# %% imports
import os
import json
import math
from pathlib import Path
from tqdm import tqdm
import random
import numpy as np

# huggingface / sentence-transformers
from sentence_transformers import SentenceTransformer, InputExample, losses, util, evaluation
from torch.utils.data import DataLoader

# optional faiss (install faiss-cpu or faiss-gpu)
try:
    import faiss
except Exception:
    faiss = None
    print("faiss not available. Hard-negative mining will be skipped unless you install faiss-cpu or faiss-gpu.")

# %% user params - adjust as needed
train_pairs_path = "./dataset_out_simple/train_pairs_clean.jsonl"  # path to your train pairs
passages_path = "./dataset_out_simple/passages_clean.jsonl"        # optional, used for indexing
out_dir = "./mpnet_ft"
base_model = "sentence-transformers/multi-qa-mpnet-base-dot-v1"

# training hyperparams
batch_size = 64           # try 64, increase if GPU allows
num_epochs = 2               # 1-3 is usually fine for contrastive
learning_rate = 2e-5
max_pairs_to_use = None      # set int to subsample, or None to use all
warmup_ratio = 0.1           # proportion of steps for warmup

# hard-negative mining params
do_hard_negative_mining = True if faiss is not None else False
mine_top_k = 64
hard_negatives_per_anchor = 2
hard_retrain_epochs = 1

# device/mixed precision handled by sentence-transformers automatically if installed with pytorch+cuda

# %% helper functions
def read_jsonl(path):
    items = []
    with open(path, 'r', encoding='utf8') as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            items.append(json.loads(line))
    return items

def save_model_and_embeddings(model, passages_texts, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    model.save(out_dir)
    print("Saved model to", out_dir)
    if passages_texts is not None and len(passages_texts) > 0:
        emb = model.encode(passages_texts, batch_size=256, convert_to_numpy=True, show_progress_bar=True)
        np.save(os.path.join(out_dir, "passage_embeddings.npy"), emb)
        print("Saved passage embeddings (shape {})".format(emb.shape))
        if faiss is not None:
            # normalize for inner product search
            faiss.normalize_L2(emb)
            d = emb.shape[1]
            index = faiss.IndexFlatIP(d)
            index.add(emb)
            faiss.write_index(index, os.path.join(out_dir, "faiss_index.idx"))
            print("Built & saved FAISS index at", os.path.join(out_dir, "faiss_index.idx"))
    else:
        print("No passages provided; skipping embedding/index save.")

# %% Load training pairs
print("Loading train pairs from:", train_pairs_path)
pairs = read_jsonl(train_pairs_path)
print("Total raw pairs:", len(pairs))
if max_pairs_to_use is not None and max_pairs_to_use < len(pairs):
    pairs = random.sample(pairs, max_pairs_to_use)
    print("Subsampled pairs to:", len(pairs))

# Create InputExamples (anchor, positive) for MultipleNegativesRankingLoss
examples = []
for p in pairs:
    anchor = p.get("anchor")
    positive = p.get("positive")
    if not anchor or not positive:
        continue
    examples.append(InputExample(texts=[anchor, positive]))

print("Prepared {} InputExample pairs (anchor, positive)".format(len(examples)))

# %% split small dev set for monitoring (sample some pairs)
random.shuffle(examples)
dev_size = min(512, max(32, int(0.02 * len(examples))))  # 2% or at most 512
dev_examples = examples[:dev_size]
train_examples = examples[dev_size:]
print("Train examples:", len(train_examples), "Dev examples:", len(dev_examples))

# %% prepare dataloaders & model
model = SentenceTransformer(base_model)
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
train_loss = losses.MultipleNegativesRankingLoss(model=model)

# evaluator: use EmbeddingSimilarityEvaluator-like approach with dev set
# we'll build index of positives and check recall by searching anchor embedding & checking its positive
# For convenience, create dev mapping: dev_anchor_text -> dev_positive_text
dev_anchor_texts = [ex.texts[0] for ex in dev_examples]
dev_positive_texts = [ex.texts[1] for ex in dev_examples]

# %%

# compute training steps & warmup
num_train_steps = int(len(train_dataloader) * num_epochs)
warmup_steps = max(1, int(warmup_ratio * num_train_steps))
print(f"Training {num_epochs} epochs, {len(train_dataloader)} steps/epoch, total steps {num_train_steps}, warmup {warmup_steps}")

# Train
os.makedirs(out_dir, exist_ok=True)
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    optimizer_params={"lr": learning_rate},
    output_path=os.path.join(out_dir, "mpnet_stage1"),
    show_progress_bar=True
)

print("Stage 1 training complete. Model saved at:", os.path.join(out_dir, "mpnet_stage1"))

# %% Evaluation: simple Recall@k on dev set using model + FAISS
def recall_at_k_for_dev(model, anchor_texts, positive_texts, k=5):
    # Encode dev positive texts as "passage corpus"
    corpus_emb = model.encode(positive_texts, convert_to_numpy=True, batch_size=128, show_progress_bar=False)
    query_emb = model.encode(anchor_texts, convert_to_numpy=True, batch_size=128, show_progress_bar=False)
    # normalize for IP
    faiss.normalize_L2(corpus_emb)
    faiss.normalize_L2(query_emb)
    d = corpus_emb.shape[1]
    index = faiss.IndexFlatIP(d)
    index.add(corpus_emb)
    D, I = index.search(query_emb, k)
    hits = 0
    for qi, inds in enumerate(I):
        # check if the gold positive (index qi) is among retrieved (note we built corpus of positives in same order)
        # this simplistic check requires that the corpus ordering matches dev positives; since we encoded positives in that order, gold index is qi
        if qi in inds:
            hits += 1
    return hits / len(anchor_texts)

if faiss is not None and len(dev_anchor_texts) > 0:
    r1 = recall_at_k_for_dev(model, dev_anchor_texts, dev_positive_texts, k=1)
    r5 = recall_at_k_for_dev(model, dev_anchor_texts, dev_positive_texts, k=5)
    print(f"Dev Recall@1 = {r1:.4f}, Recall@5 = {r5:.4f}")
else:
    print("Skipping dev Recall@k (faiss not available or no dev examples).")

# %% Save initial model & (optional) encode and save passages embeddings if provided
passage_texts = []
if os.path.exists(passages_path):
    passages = read_jsonl(passages_path)
    passage_texts = [p.get("text") for p in passages if p.get("text")]
else:
    print("No passages file found at", passages_path)
save_model_and_embeddings(model, passage_texts, out_dir=os.path.join(out_dir, "mpnet_stage1"))

# %% Optional: Hard-negative mining & retrain with BatchHardTripletLoss
if do_hard_negative_mining and faiss is not None:
    print("Starting hard negative mining using FAISS (top_k =", mine_top_k, ") ...")
    # Build index over passages_texts (if not available, we mine from positives in train pairs)
    if len(passage_texts) == 0:
        # fallback: build pool from positive texts in pairs
        pool_texts = [p['positive'] for p in pairs if p.get('positive')]
    else:
        pool_texts = passage_texts

    pool_emb = model.encode(pool_texts, convert_to_numpy=True, batch_size=256, show_progress_bar=True)
    faiss.normalize_L2(pool_emb)
    d = pool_emb.shape[1]
    index = faiss.IndexFlatIP(d)
    index.add(pool_emb)

    # For each anchor, retrieve top_k and pick negatives (skip identical/overlapping by exact text match)
    hard_triplets = []
    anchors = [p['anchor'] for p in pairs if p.get('anchor') and p.get('positive')]
    positives = [p['positive'] for p in pairs if p.get('anchor') and p.get('positive')]

    # For efficiency, sample a subset for mining if dataset large
    max_mine = min(len(anchors), 20000)  # cap mining to 20k anchors for speed; adjust if you like
    sample_idx = random.sample(range(len(anchors)), max_mine)

    for idx in tqdm(sample_idx):
        a_text = anchors[idx]
        pos_text = positives[idx]
        q_emb = model.encode(a_text, convert_to_numpy=True)
        faiss.normalize_L2(q_emb)
        D, I = index.search(q_emb.reshape(1,-1), mine_top_k)
        # iterate candidates and pick hard negatives that are not the positive and not equal to anchor
        picked = 0
        for cand_i in I[0]:
            cand_text = pool_texts[cand_i]
            if cand_text == pos_text or cand_text == a_text:
                continue
            # optionally ensure negative from other chapter by checking meta — skipped here for simplicity
            hard_triplets.append((a_text, pos_text, cand_text))
            picked += 1
            if picked >= hard_negatives_per_anchor:
                break

    print("Mined hard triplets:", len(hard_triplets))

    # Convert to InputExamples for BatchHardTripletLoss (texts=[anchor, pos, neg])
    triplet_examples = [InputExample(texts=[a,p,n]) for (a,p,n) in hard_triplets]

    # dataloader and loss
    triplet_dataloader = DataLoader(triplet_examples, shuffle=True, batch_size=max(16, batch_size//4))
    triplet_loss = losses.BatchHardTripletLoss(model=model, distance_metric=losses.BatchHardTripletLoss.distance_cosine, margin=0.2)

    # retrain
    num_steps = int(len(triplet_dataloader) * hard_retrain_epochs)
    warmup_steps = max(1, int(0.1 * num_steps))
    model.fit(
        train_objectives=[(triplet_dataloader, triplet_loss)],
        epochs=hard_retrain_epochs,
        warmup_steps=warmup_steps,
        optimizer_params={"lr": learning_rate},
        output_path=os.path.join(out_dir, "mpnet_hardneg"),
        show_progress_bar=True,
        num_items_in_batch=None
    )
    print("Hard-negative retrain complete. Saved at", os.path.join(out_dir, "mpnet_hardneg"))
    # save embeddings again
    save_model_and_embeddings(model, passage_texts, out_dir=os.path.join(out_dir, "mpnet_hardneg"))

else:
    print("Skipping hard-negative mining (faiss not available or disabled).")

# %% Done
print("All done. Models stored under:", out_dir)


  from tqdm.autonotebook import tqdm, trange


faiss not available. Hard-negative mining will be skipped unless you install faiss-cpu or faiss-gpu.
Loading train pairs from: ./dataset_out_simple/train_pairs_clean.jsonl
Total raw pairs: 29489
Prepared 29489 InputExample pairs (anchor, positive)
Train examples: 28977 Dev examples: 512
Training 2 epochs, 453 steps/epoch, total steps 906, warmup 90


Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information.
  super().__init__(


TypeError: compute_loss() got an unexpected keyword argument 'num_items_in_batch'

In [None]:
# %%
"""
Contrastive fine-tuning notebook for multi-qa-mpnet on Mahabharata dataset.
Inputs required:
 - train_pairs.jsonl : lines of {"anchor","positive","negatives","meta"}
 - passages.jsonl (optional) : passages with "text" fields (for indexing/eval)
Outputs:
 - saved model under out_dir (default ./mpnet_ft)
 - passage embeddings: out_dir/passage_embeddings.npy
 - faiss index: out_dir/faiss_index.idx
"""
# %% imports
import os
import json
import math
from pathlib import Path
from tqdm import tqdm
import random
import numpy as np

# huggingface / sentence-transformers
from sentence_transformers import SentenceTransformer, InputExample, losses, util, evaluation
from torch.utils.data import DataLoader

# optional faiss (install faiss-cpu or faiss-gpu)
try:
    import faiss
except Exception:
    faiss = None
    print("faiss not available. Hard-negative mining will be skipped unless you install faiss-cpu or faiss-gpu.")

# %% user params - adjust as needed
train_pairs_path = "./dataset_out_simple/train_pairs_clean.jsonl"  # path to your train pairs
passages_path = "./dataset_out_simple/passages_clean.jsonl"        # optional, used for indexing
out_dir = "./mpnet_ft"
base_model = "sentence-transformers/multi-qa-mpnet-base-dot-v1"

# training hyperparams
batch_size = 64           # try 64, increase if GPU allows
num_epochs = 2               # 1-3 is usually fine for contrastive
learning_rate = 2e-5
max_pairs_to_use = None      # set int to subsample, or None to use all
warmup_ratio = 0.1           # proportion of steps for warmup

# hard-negative mining params
do_hard_negative_mining = True if faiss is not None else False
mine_top_k = 64
hard_negatives_per_anchor = 2
hard_retrain_epochs = 1

# device/mixed precision handled by sentence-transformers automatically if installed with pytorch+cuda

# %% helper functions
def read_jsonl(path):
    items = []
    with open(path, 'r', encoding='utf8') as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            items.append(json.loads(line))
    return items

def save_model_and_embeddings(model, passages_texts, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    model.save(out_dir)
    print("Saved model to", out_dir)
    if passages_texts is not None and len(passages_texts) > 0:
        emb = model.encode(passages_texts, batch_size=256, convert_to_numpy=True, show_progress_bar=True)
        np.save(os.path.join(out_dir, "passage_embeddings.npy"), emb)
        print("Saved passage embeddings (shape {})".format(emb.shape))
        if faiss is not None:
            # normalize for inner product search
            faiss.normalize_L2(emb)
            d = emb.shape[1]
            index = faiss.IndexFlatIP(d)
            index.add(emb)
            faiss.write_index(index, os.path.join(out_dir, "faiss_index.idx"))
            print("Built & saved FAISS index at", os.path.join(out_dir, "faiss_index.idx"))
    else:
        print("No passages provided; skipping embedding/index save.")

# %% Load training pairs
print("Loading train pairs from:", train_pairs_path)
pairs = read_jsonl(train_pairs_path)
print("Total raw pairs:", len(pairs))
if max_pairs_to_use is not None and max_pairs_to_use < len(pairs):
    pairs = random.sample(pairs, max_pairs_to_use)
    print("Subsampled pairs to:", len(pairs))

# Create InputExamples (anchor, positive) for MultipleNegativesRankingLoss
examples = []
for p in pairs:
    anchor = p.get("anchor")
    positive = p.get("positive")
    if not anchor or not positive:
        continue
    examples.append(InputExample(texts=[anchor, positive]))

print("Prepared {} InputExample pairs (anchor, positive)".format(len(examples)))

# %% split small dev set for monitoring (sample some pairs)
random.shuffle(examples)
dev_size = min(512, max(32, int(0.02 * len(examples))))  # 2% or at most 512
dev_examples = examples[:dev_size]
train_examples = examples[dev_size:]
print("Train examples:", len(train_examples), "Dev examples:", len(dev_examples))

# %% prepare dataloaders & model
model = SentenceTransformer(base_model)
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
train_loss = losses.MultipleNegativesRankingLoss(model=model)

# evaluator: use EmbeddingSimilarityEvaluator-like approach with dev set
# we'll build index of positives and check recall by searching anchor embedding & checking its positive
# For convenience, create dev mapping: dev_anchor_text -> dev_positive_text
dev_anchor_texts = [ex.texts[0] for ex in dev_examples]
dev_positive_texts = [ex.texts[1] for ex in dev_examples]

# %%

# compute training steps & warmup
num_train_steps = int(len(train_dataloader) * num_epochs)
warmup_steps = max(1, int(warmup_ratio * num_train_steps))
print(f"Training {num_epochs} epochs, {len(train_dataloader)} steps/epoch, total steps {num_train_steps}, warmup {warmup_steps}")

# Train
os.makedirs(out_dir, exist_ok=True)
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    optimizer_params={"lr": learning_rate},
    output_path=os.path.join(out_dir, "mpnet_stage1"),
    show_progress_bar=True
)

print("Stage 1 training complete. Model saved at:", os.path.join(out_dir, "mpnet_stage1"))

# %% Evaluation: simple Recall@k on dev set using model + FAISS
def recall_at_k_for_dev(model, anchor_texts, positive_texts, k=5):
    # Encode dev positive texts as "passage corpus"
    corpus_emb = model.encode(positive_texts, convert_to_numpy=True, batch_size=128, show_progress_bar=False)
    query_emb = model.encode(anchor_texts, convert_to_numpy=True, batch_size=128, show_progress_bar=False)
    # normalize for IP
    faiss.normalize_L2(corpus_emb)
    faiss.normalize_L2(query_emb)
    d = corpus_emb.shape[1]
    index = faiss.IndexFlatIP(d)
    index.add(corpus_emb)
    D, I = index.search(query_emb, k)
    hits = 0
    for qi, inds in enumerate(I):
        # check if the gold positive (index qi) is among retrieved (note we built corpus of positives in same order)
        # this simplistic check requires that the corpus ordering matches dev positives; since we encoded positives in that order, gold index is qi
        if qi in inds:
            hits += 1
    return hits / len(anchor_texts)

if faiss is not None and len(dev_anchor_texts) > 0:
    r1 = recall_at_k_for_dev(model, dev_anchor_texts, dev_positive_texts, k=1)
    r5 = recall_at_k_for_dev(model, dev_anchor_texts, dev_positive_texts, k=5)
    print(f"Dev Recall@1 = {r1:.4f}, Recall@5 = {r5:.4f}")
else:
    print("Skipping dev Recall@k (faiss not available or no dev examples).")

# %% Save initial model & (optional) encode and save passages embeddings if provided
passage_texts = []
if os.path.exists(passages_path):
    passages = read_jsonl(passages_path)
    passage_texts = [p.get("text") for p in passages if p.get("text")]
else:
    print("No passages file found at", passages_path)
save_model_and_embeddings(model, passage_texts, out_dir=os.path.join(out_dir, "mpnet_stage1"))

# %% Optional: Hard-negative mining & retrain with BatchHardTripletLoss
if do_hard_negative_mining and faiss is not None:
    print("Starting hard negative mining using FAISS (top_k =", mine_top_k, ") ...")
    # Build index over passages_texts (if not available, we mine from positives in train pairs)
    if len(passage_texts) == 0:
        # fallback: build pool from positive texts in pairs
        pool_texts = [p['positive'] for p in pairs if p.get('positive')]
    else:
        pool_texts = passage_texts

    pool_emb = model.encode(pool_texts, convert_to_numpy=True, batch_size=256, show_progress_bar=True)
    faiss.normalize_L2(pool_emb)
    d = pool_emb.shape[1]
    index = faiss.IndexFlatIP(d)
    index.add(pool_emb)

    # For each anchor, retrieve top_k and pick negatives (skip identical/overlapping by exact text match)
    hard_triplets = []
    anchors = [p['anchor'] for p in pairs if p.get('anchor') and p.get('positive')]
    positives = [p['positive'] for p in pairs if p.get('anchor') and p.get('positive')]

    # For efficiency, sample a subset for mining if dataset large
    max_mine = min(len(anchors), 20000)  # cap mining to 20k anchors for speed; adjust if you like
    sample_idx = random.sample(range(len(anchors)), max_mine)

    for idx in tqdm(sample_idx):
        a_text = anchors[idx]
        pos_text = positives[idx]
        q_emb = model.encode(a_text, convert_to_numpy=True)
        faiss.normalize_L2(q_emb)
        D, I = index.search(q_emb.reshape(1,-1), mine_top_k)
        # iterate candidates and pick hard negatives that are not the positive and not equal to anchor
        picked = 0
        for cand_i in I[0]:
            cand_text = pool_texts[cand_i]
            if cand_text == pos_text or cand_text == a_text:
                continue
            # optionally ensure negative from other chapter by checking meta — skipped here for simplicity
            hard_triplets.append((a_text, pos_text, cand_text))
            picked += 1
            if picked >= hard_negatives_per_anchor:
                break

    print("Mined hard triplets:", len(hard_triplets))

    # Convert to InputExamples for BatchHardTripletLoss (texts=[anchor, pos, neg])
    triplet_examples = [InputExample(texts=[a,p,n]) for (a,p,n) in hard_triplets]

    # dataloader and loss
    triplet_dataloader = DataLoader(triplet_examples, shuffle=True, batch_size=max(16, batch_size//4))
    triplet_loss = losses.BatchHardTripletLoss(model=model, distance_metric=losses.BatchHardTripletLoss.distance_cosine, margin=0.2)

    # retrain
    num_steps = int(len(triplet_dataloader) * hard_retrain_epochs)
    warmup_steps = max(1, int(0.1 * num_steps))
    model.fit(
        train_objectives=[(triplet_dataloader, triplet_loss)],
        epochs=hard_retrain_epochs,
        warmup_steps=warmup_steps,
        optimizer_params={"lr": learning_rate},
        output_path=os.path.join(out_dir, "mpnet_hardneg"),
        show_progress_bar=True,
        num_items_in_batch=None
    )
    print("Hard-negative retrain complete. Saved at", os.path.join(out_dir, "mpnet_hardneg"))
    # save embeddings again
    save_model_and_embeddings(model, passage_texts, out_dir=os.path.join(out_dir, "mpnet_hardneg"))

else:
    print("Skipping hard-negative mining (faiss not available or disabled).")

# %% Done
print("All done. Models stored under:", out_dir)


In [2]:
# Manual contrastive training loop (safe & robust) - paste and run
import os, random, gc, torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
from tqdm.auto import tqdm

# CONFIG
base_model = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 16          # lower if OOM (try 2 or 1)
accumulation_steps = 1
num_epochs = 2
learning_rate = 2e-5
max_length = 256          # reduce if inputs are long
temperature = .05
out_dir = "./mpnet_manual_ft"
use_amp = True if device.type == "cuda" else False

print("Device:", device, "batch_size:", batch_size, "use_amp:", use_amp)

# Prepare tokenizer + model
tokenizer = AutoTokenizer.from_pretrained(base_model)
encoder = AutoModel.from_pretrained(base_model).to(device)
encoder.train()

def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output.last_hidden_state
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts

class PairDataset(Dataset):
    def __init__(self, pairs): self.pairs = pairs
    def __len__(self): return len(self.pairs)
    def __getitem__(self, idx):
        p = self.pairs[idx]; return p['anchor'], p['positive']

# shuffle and create dataloader
random.shuffle(pairs)
dataset = PairDataset(pairs)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0)

optimizer = optim.AdamW(encoder.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler() if (use_amp and device.type == "cuda") else None

os.makedirs(out_dir, exist_ok=True)
torch.cuda.empty_cache()
gc.collect()

for epoch in range(num_epochs):
    encoder.train()
    total_loss = 0.0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    optimizer.zero_grad()
    for step, (anchors, positives) in enumerate(pbar):
        encoded_a = tokenizer(list(anchors), padding=True, truncation=True, return_tensors="pt", max_length=max_length).to(device)
        encoded_p = tokenizer(list(positives), padding=True, truncation=True, return_tensors="pt", max_length=max_length).to(device)

        if scaler is not None:
            with torch.cuda.amp.autocast():
                out_a = encoder(**encoded_a)
                out_p = encoder(**encoded_p)
                emb_a = mean_pooling(out_a, encoded_a['attention_mask'])
                emb_p = mean_pooling(out_p, encoded_p['attention_mask'])
                emb_a = nn.functional.normalize(emb_a, p=2, dim=1)
                emb_p = nn.functional.normalize(emb_p, p=2, dim=1)
                logits = torch.matmul(emb_a, emb_p.t()) / temperature
                labels = torch.arange(logits.size(0), device=logits.device)
                loss = criterion(logits, labels) / accumulation_steps
            scaler.scale(loss).backward()
        else:
            out_a = encoder(**encoded_a)
            out_p = encoder(**encoded_p)
            emb_a = mean_pooling(out_a, encoded_a['attention_mask'])
            emb_p = mean_pooling(out_p, encoded_p['attention_mask'])
            emb_a = nn.functional.normalize(emb_a, p=2, dim=1)
            emb_p = nn.functional.normalize(emb_p, p=2, dim=1)
            logits = torch.matmul(emb_a, emb_p.t()) / temperature
            labels = torch.arange(logits.size(0), device=logits.device)
            loss = criterion(logits, labels) / accumulation_steps
            loss.backward()

        if (step + 1) % accumulation_steps == 0:
            if scaler is not None:
                scaler.step(optimizer); scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        total_loss += (loss.item() * accumulation_steps)
        pbar.set_postfix({'loss': total_loss / (step+1)})

    avg_loss = total_loss / (step+1)
    print(f"Epoch {epoch+1} done — avg loss: {avg_loss:.4f}")

    # Save checkpoint
    ck = os.path.join(out_dir, f"epoch_{epoch+1}")
    os.makedirs(ck, exist_ok=True)
    encoder.save_pretrained(ck); tokenizer.save_pretrained(ck)
    torch.cuda.empty_cache()
    gc.collect()
print("Training finished. Checkpoints in", out_dir)


Device: cuda batch_size: 16 use_amp: True


  scaler = torch.cuda.amp.GradScaler() if (use_amp and device.type == "cuda") else None
  with torch.cuda.amp.autocast():
Epoch 1/2: 100%|██████████| 1843/1843 [27:19<00:00,  1.12it/s, loss=0.464]


Epoch 1 done — avg loss: 0.4642


Epoch 2/2: 100%|██████████| 1843/1843 [27:24<00:00,  1.12it/s, loss=0.138]


Epoch 2 done — avg loss: 0.1384
Training finished. Checkpoints in ./mpnet_manual_ft


In [3]:
# ===== Evaluate dev Recall@K & check grads =====
import numpy as np, torch, gc
from tqdm.auto import tqdm
from torch.nn.functional import normalize

# 1) Quick eval function (uses in-memory dev_pairs)
def encode_texts(texts, tokenizer, model, device='cuda', batch_size=32, max_length=256):
    model.eval()
    embs = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            encoded = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=max_length).to(device)
            out = model(**encoded)
            emb = out.last_hidden_state.mean(dim=1)
            emb = normalize(emb, p=2, dim=1)
            embs.append(emb.cpu().numpy())
    return np.vstack(embs)

# prepare dev set (if not already)
dev_pairs = dev_pairs if 'dev_pairs' in globals() else pairs[:min(1000, max(200, int(0.05*len(pairs))))]
dev_anchors = [p['anchor'] for p in dev_pairs]
dev_positives = [p['positive'] for p in dev_pairs]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Eval device:", device)

# encode dev anchors & positives (corpus == positives here for simplicity)
q_embs = encode_texts(dev_anchors, tokenizer, encoder, device=device, batch_size=32, max_length=256)
p_embs = encode_texts(dev_positives, tokenizer, encoder, device=device, batch_size=64, max_length=256)

# simple CPU retrieval
sims = q_embs @ p_embs.T   # (Q, N)
topk = np.argsort(-sims, axis=1)[:, :10]

# compute Recall@K and MRR
K_list = [1,5,10]
Q = len(q_embs)
recalls = {k:0 for k in K_list}
mrr = 0.0
for i in range(Q):
    gold = i   # since dev_positives are aligned
    ranks = np.where(topk[i]==gold)[0]
    for k in K_list:
        if gold in topk[i,:k]:
            recalls[k]+=1
    if ranks.size>0:
        mrr += 1.0/(ranks[0]+1)
mrr /= Q
recalls = {k: recalls[k]/Q for k in K_list}
print("Dev Recall@K:", recalls, "MRR:", mrr)

# 2) Check gradient norms (run after a backward step or on last step)
grads = [p.grad.norm().item() for p in encoder.parameters() if p.grad is not None]
if len(grads)>0:
    print("Grad stats: mean {:.3e}, max {:.3e}, min {:.3e}".format(np.mean(grads), np.max(grads), np.min(grads)))
else:
    print("No gradients present right now. (Run one backward or check training loop)")


Eval device: cuda
Dev Recall@K: {1: 0.533, 5: 0.912, 10: 0.973} MRR: 0.696156349206349
No gradients present right now. (Run one backward or check training loop)
