In [19]:
import json
import random

import nltk
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from rank_bm25 import BM25Okapi
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

nltk.download('punkt')
from nltk.tokenize import word_tokenize

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\VIDUSHI\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [20]:
# Load the training records. The file may be a single JSON document (a list of
# records, or a dict wrapping such a list) or JSON Lines (one record per line).
TRAIN_PATH = "data/train.json"

with open(TRAIN_PATH, "r", encoding="utf-8") as f:
    raw = f.read()

try:
    dataset = json.loads(raw)
except json.JSONDecodeError:
    # Not a single JSON document: fall back to JSON Lines parsing.
    dataset = []
    skipped_lines = 0
    for line in raw.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            dataset.append(json.loads(line))
        except json.JSONDecodeError:
            # Best-effort parsing: drop malformed lines, but count them so the loss is visible.
            skipped_lines += 1
    if skipped_lines:
        print(f"Warning: skipped {skipped_lines} malformed JSON line(s) in {TRAIN_PATH}")

# A top-level dict is assumed to wrap the record list; use its first list-valued entry.
if isinstance(dataset, dict):
    wrapped_records = next((value for value in dataset.values() if isinstance(value, list)), None)
    if wrapped_records is None:
        # Iterating the dict would yield its keys and crash later in x.get(); fail clearly instead.
        raise ValueError(f"{TRAIN_PATH} is a JSON object without a list of records inside")
    dataset = wrapped_records

if not isinstance(dataset, list):
    raise ValueError(f"Expected a list of records in {TRAIN_PATH}, got {type(dataset).__name__}")

# (question, gold explanation) pairs; missing fields become None and are filtered in a later cell.
positives = [(x.get("question"), x.get("exp")) for x in dataset]

In [21]:
# Number of (question, exp) pairs built from the dataset, including pairs whose fields are None.
len(positives)

182822

In [22]:
# Gold explanation text of every record (raises KeyError if a record has no "exp" key).
# NOTE(review): not referenced by the other cells shown here.
corpus = [record["exp"] for record in dataset]

In [23]:
# Easy negatives: random non-gold passages.
def build_easy_negatives(dataset, gold_key_question="question", gold_key_passage="exp", max_attempts=10, rng=None):
    """Pair each question with a randomly drawn passage whose text differs from its gold passage.

    Args:
        dataset: iterable of dict records (iterated twice).
        gold_key_question: record key holding the question text.
        gold_key_passage: record key holding the gold passage text.
        max_attempts: number of random draws per question before giving up on it.
        rng: optional ``random.Random`` instance for reproducible sampling;
            defaults to the module-level ``random`` functions.

    Returns:
        List of ``(question, negative_passage)`` tuples, at most one per record.
        Records missing a question or gold passage are skipped, as are questions
        for which every draw returned the gold text. ``[]`` when no passages exist.
    """
    rng = random if rng is None else rng
    all_passages = [x.get(gold_key_passage) for x in dataset if x.get(gold_key_passage)]
    if not all_passages:
        return []

    easy_negatives = []
    for item in dataset:
        q = item.get(gold_key_question)
        gold = item.get(gold_key_passage)
        if not q or not gold:
            continue
        # Rejection sampling: keep the first draw whose text differs from the gold passage.
        for _ in range(max_attempts):
            candidate = rng.choice(all_passages)
            if candidate != gold:
                easy_negatives.append((q, candidate))
                break
    return easy_negatives

In [24]:
# At most one random (question, non-gold passage) pair per record that has both fields.
easy_negatives = build_easy_negatives(dataset=dataset)

In [25]:
def build_hard_negatives(dataset, gold_key_question="question", gold_key_passage="exp", sample_size=1000, rng=None):
    """Pair each question with its highest-scoring BM25 passage that is not its gold passage.

    To keep indexing cheap, BM25 is built over a random sample of at most
    ``sample_size`` passages. For most questions the gold passage is therefore
    not in the pool, and the "hard" negative is simply the best lexical match
    among the sampled passages.

    Args:
        dataset: iterable of dict records (iterated twice).
        gold_key_question: record key holding the question text.
        gold_key_passage: record key holding the gold passage text.
        sample_size: maximum number of passages indexed by BM25.
        rng: optional ``random.Random`` instance for reproducible sampling;
            defaults to the module-level ``random`` functions.

    Returns:
        List of ``(question, negative_passage)`` tuples, at most one per record.
        Records missing a question or gold passage are skipped. ``[]`` when the
        dataset has no passages (BM25 cannot index an empty corpus).
    """
    rng = random if rng is None else rng
    all_passages = [x.get(gold_key_passage) for x in dataset if x.get(gold_key_passage)]
    if not all_passages:
        return []

    # Sample to speed up BM25 indexing.
    if len(all_passages) > sample_size:
        sampled_idx = rng.sample(range(len(all_passages)), sample_size)
        all_passages = [all_passages[i] for i in sampled_idx]

    # Lower-cased whitespace tokenisation; queries are tokenised the same way below.
    bm25 = BM25Okapi([p.lower().split() for p in all_passages])

    hard_negatives = []
    for item in dataset:
        q = item.get(gold_key_question)
        gold = item.get(gold_key_passage)
        if not q or not gold:
            continue

        scores = np.asarray(bm25.get_scores(q.lower().split()))
        # Descending by score; "stable" keeps corpus order among ties (earliest passage wins).
        for idx in np.argsort(-scores, kind="stable"):
            passage = all_passages[idx]
            if passage != gold:
                hard_negatives.append((q, passage))
                break

    return hard_negatives

In [26]:
# At most one BM25-ranked (question, non-gold passage) pair per record that has both fields.
hard_negatives = build_hard_negatives(dataset=dataset)

In [27]:
# Counts of each pair type plus one example of each negative kind (None if that list is empty).
sample_easy = easy_negatives[0] if easy_negatives else None
sample_hard = hard_negatives[0] if hard_negatives else None

print("Summary of generated pairs:")
print(f"positives: {len(positives)} | easy_negatives: {len(easy_negatives)} | hard_negatives: {len(hard_negatives)}")
print("sample easy negative:", sample_easy)
print("sample hard negative:", sample_hard)

Summary of generated pairs:
positives: 182822 | easy_negatives: 160869 | hard_negatives: 160869
sample easy negative: ('Chronic urethral obstruction due to benign prismatic hyperplasia can lead to the following change in kidney parenchyma', "Ans. is 'b' i.e., Preformed toxin . Vomiting within 6 hours of ingestion of rice suggests the diagnosis of emetic type of food poisoning caused by B. cereus. . It caused by preformed heat stable enterotoxin. . In emetic form, B. cereus is not found in large numbers in fecal specimens. Therefore food sample is more useful.")
sample hard negative: ('Chronic urethral obstruction due to benign prismatic hyperplasia can lead to the following change in kidney parenchyma', "Ans. a (MCU). (Ref. Sutton Radiology 7th ed. 1017, 1061)- MCU is IOC for PU valve and VUR.Posterior urethral valves (PUV)# Varying degree of chronic urethral obstruction due to fusion and prominence of plicae colliculi, normal concentric folds of urethra.# Usually located in posterior 

In [28]:
# Assemble (query, positive_passage, negative_passage) training triplets.
from collections import defaultdict

# Drop pairs whose question or passage is missing/empty.
positives_clean = [(question, passage) for question, passage in positives if question and passage]

negatives = easy_negatives + hard_negatives
negatives_clean = [(question, passage) for question, passage in negatives if question and passage]

# Index negatives by question text so each lookup below is a dict access.
neg_by_query = defaultdict(list)
for question, passage in negatives_clean:
    neg_by_query[question].append(passage)

print(f"Built negative index for {len(neg_by_query)} unique queries")

triplet_samples = []
for query_text, positive_passage in tqdm(positives_clean, desc="Creating triplets"):
    # Never use a "negative" whose text equals this query's positive passage.
    candidates = [neg for neg in neg_by_query.get(query_text, ()) if neg != positive_passage]
    if not candidates:
        continue
    triplet_samples.append((query_text, positive_passage, random.choice(candidates)))

random.shuffle(triplet_samples)
print(f"Created {len(triplet_samples)} triplet samples (query, positive, negative)")

Built negative index for 160869 unique queries


Creating triplets: 100%|██████████| 160869/160869 [00:00<00:00, 331504.75it/s]


Created 160869 triplet samples (query, positive, negative)


In [29]:
# 90/10 train/validation split over the (already shuffled) triplets.
split = int(0.9 * len(triplet_samples))
train_data = triplet_samples[:split]
val_data = triplet_samples[split:]

# Prototyping subsets: the first 10% of each split.
subset_size = len(train_data) // 10
train_data_subset, val_data_subset = train_data[:subset_size], val_data[: len(val_data) // 10]

print("\n".join([
    f"Full dataset - train: {len(train_data)}, val: {len(val_data)}",
    f"Subset (10%) - train: {len(train_data_subset)}, val: {len(val_data_subset)}",
    "Use train_data_subset for fast prototyping, train_data for full training",
]))


Full dataset - train: 144782, val: 16087
Subset (10%) - train: 14478, val: 1608
Use train_data_subset for fast prototyping, train_data for full training


In [30]:
# Tokenizer and encoder are loaded from the same pretrained checkpoint name.
model_name = "distilbert-base-uncased"
tokenizer, encoder = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModel.from_pretrained(model_name),
)

def collate_batch_triplet(batch, max_len=128):
    """Tokenize a batch of ``(query, positive, negative)`` text triplets.

    Each column is tokenized separately with the module-level ``tokenizer``
    (padded within the column, truncated to ``max_len``), so the three
    encodings may have different sequence lengths.

    Args:
        batch: sequence of 3-tuples of strings.
        max_len: maximum tokenized length for every text.

    Returns:
        ``(enc_q, enc_pos, enc_neg)``: tokenizer outputs with PyTorch tensors.
    """
    query_texts, pos_texts, neg_texts = zip(*batch)

    def encode(texts):
        # Identical tokenizer settings for all three columns.
        return tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            max_length=max_len,
            return_tensors="pt",
        )

    return encode(query_texts), encode(pos_texts), encode(neg_texts)


In [31]:
class RankingScorer(nn.Module):
    """Embed tokenized text with ``encoder`` + attention-masked mean pooling + an MLP head.

    The training loop uses one instance for queries, positives and negatives
    and compares the resulting embeddings with cosine similarity.

    Args:
        encoder: module exposing ``config.hidden_size`` whose forward returns an
            object with ``last_hidden_state`` (e.g. a Hugging Face ``AutoModel``).
        embedding_dim: size of the output embedding.
    """

    def __init__(self, encoder, embedding_dim=128):
        super().__init__()
        self.encoder = encoder
        self.hidden_dim = encoder.config.hidden_size
        # Linear -> ReLU -> Linear; this order fixes the state_dict keys
        # "projection.0.*" / "projection.2.*" used by saved checkpoints.
        head_layers = [
            nn.Linear(self.hidden_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim),
        ]
        self.projection = nn.Sequential(*head_layers)

    def forward(self, enc_ips):
        """Return ``[B, embedding_dim]`` embeddings for a tokenized batch.

        ``enc_ips`` is a mapping forwarded to the encoder as keyword arguments;
        its ``attention_mask`` selects which token states enter the mean.
        """
        hidden = self.encoder(**enc_ips).last_hidden_state      # [B, L, H]
        keep = enc_ips["attention_mask"].unsqueeze(-1)          # [B, L, 1]
        token_sum = torch.sum(hidden * keep, dim=1)             # [B, H]
        token_count = torch.clamp(keep.sum(dim=1), min=1e-9)    # [B, 1], avoids division by zero
        return self.projection(token_sum / token_count)

In [32]:
# Use the GPU when available; otherwise fall back to CPU instead of failing inside .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"
model = RankingScorer(encoder, embedding_dim=128).to(device)

In [33]:
# AdamW over all parameters of `model` (encoder and projection head).
opt = torch.optim.AdamW(
    params=model.parameters(),
    lr=2e-5,
    weight_decay=0.01,
)

# Ranking loss: triplet (margin) loss on cosine similarity
def ranking_loss(emb_pos, emb_neg, margin=0.5, emb_q=None):
    """Mean hinge loss ``max(0, margin - (pos_sim - neg_sim))`` over the batch.

    With ``emb_q`` (query embeddings) as the anchor, ``pos_sim = cos(q, pos)``
    and ``neg_sim = cos(q, neg)``. This is the loss the training loop computes
    inline.

    Without ``emb_q`` (the original two-argument call), the positive embedding
    is its own anchor. ``pos_sim = cos(pos, pos)`` is then always 1, so the
    loss only pushes negatives away from positives and never pulls queries
    toward their positives.

    Args:
        emb_pos: positive passage embeddings, expected ``[B, D]`` (cosine is taken over dim=1).
        emb_neg: negative passage embeddings, same shape as ``emb_pos``.
        margin: required gap between ``pos_sim`` and ``neg_sim``.
        emb_q: optional query (anchor) embeddings, same shape as ``emb_pos``.

    Returns:
        Scalar tensor with the batch-mean loss.
    """
    anchor = emb_pos if emb_q is None else emb_q
    pos_sim = F.cosine_similarity(anchor, emb_pos)  # [B]
    neg_sim = F.cosine_similarity(anchor, emb_neg)  # [B]
    # Want pos_sim > neg_sim by at least `margin`.
    return torch.clamp(margin - (pos_sim - neg_sim), min=0).mean()


In [34]:
# Data loaders: batches of 32 triplets, tokenized on the fly by collate_batch_triplet.
loader_kwargs = dict(batch_size=32, collate_fn=collate_batch_triplet)
train_loader = DataLoader(train_data_subset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_data_subset, shuffle=False, **loader_kwargs)

print(f"Training batches per epoch: {len(train_loader)}")

Training batches per epoch: 453


In [35]:
MARGIN = 0.5  # triplet hinge margin (same value as ranking_loss's default)
n = 5  # epochs

# Derive the autocast device from `device` instead of hard-coding CUDA. The grad
# scaler is only needed for CUDA mixed precision, so it is disabled otherwise.
amp_device = "cuda" if str(device).startswith("cuda") and torch.cuda.is_available() else "cpu"
scaler = torch.amp.GradScaler("cuda", enabled=(amp_device == "cuda"))


def _to_device(enc):
    """Move every tensor of a tokenizer output to `device`."""
    return {k: v.to(device) for k, v in enc.items()}


for epoch in range(n):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")

    for enc_q, enc_pos, enc_neg in pbar:
        enc_q, enc_pos, enc_neg = _to_device(enc_q), _to_device(enc_pos), _to_device(enc_neg)

        with torch.amp.autocast(amp_device):
            emb_q = model(enc_q)       # Query embeddings
            emb_pos = model(enc_pos)   # Positive passage embeddings
            emb_neg = model(enc_neg)   # Negative passage embeddings

            # Cosine similarity: query vs positive, query vs negative
            pos_sim = F.cosine_similarity(emb_q, emb_pos)  # [B]
            neg_sim = F.cosine_similarity(emb_q, emb_neg)  # [B]

            # Triplet hinge: want pos_sim to exceed neg_sim by at least MARGIN.
            loss = torch.clamp(MARGIN - (pos_sim - neg_sim), min=0).mean()

        opt.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    # Validation: fraction of triplets where the query is closer to the positive than to the negative.
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for enc_q, enc_pos, enc_neg in tqdm(val_loader, desc="Validation"):
            enc_q, enc_pos, enc_neg = _to_device(enc_q), _to_device(enc_pos), _to_device(enc_neg)

            with torch.amp.autocast(amp_device):
                emb_q = model(enc_q)
                emb_pos = model(enc_pos)
                emb_neg = model(enc_neg)

            # Compare in float32 so near-ties are not decided by reduced-precision autocast outputs.
            pos_sim = F.cosine_similarity(emb_q.float(), emb_pos.float())
            neg_sim = F.cosine_similarity(emb_q.float(), emb_neg.float())
            correct += (pos_sim > neg_sim).sum().item()
            total += emb_q.shape[0]

    # Guard against an empty validation loader.
    val_acc = correct / total if total else float("nan")
    print(f"epoch {epoch}: val_ranking_acc={val_acc:.3f}")


Epoch 0:   0%|          | 0/453 [00:00<?, ?it/s]

Epoch 0: 100%|██████████| 453/453 [03:16<00:00,  2.31it/s, loss=0.1093]
Validation: 100%|██████████| 51/51 [00:09<00:00,  5.50it/s]


epoch 0: val_ranking_acc=0.889


Epoch 1: 100%|██████████| 453/453 [03:35<00:00,  2.10it/s, loss=0.0971]
Validation: 100%|██████████| 51/51 [00:09<00:00,  5.49it/s]


epoch 1: val_ranking_acc=0.926


Epoch 2: 100%|██████████| 453/453 [03:11<00:00,  2.37it/s, loss=0.0335]
Validation: 100%|██████████| 51/51 [00:06<00:00,  8.06it/s]


epoch 2: val_ranking_acc=0.935


Epoch 3: 100%|██████████| 453/453 [03:11<00:00,  2.37it/s, loss=0.0447]
Validation: 100%|██████████| 51/51 [00:06<00:00,  7.77it/s]


epoch 3: val_ranking_acc=0.935


Epoch 4: 100%|██████████| 453/453 [03:23<00:00,  2.22it/s, loss=0.0230]
Validation: 100%|██████████| 51/51 [00:09<00:00,  5.61it/s]

epoch 4: val_ranking_acc=0.942





In [36]:
# Save only the weights (state_dict of encoder + projection head), not the module object.
scorer_state = model.state_dict()
torch.save(scorer_state, "ranking_scorer.ckpt")

In [38]:
# Ranking inference: compute energy scores
def rank_passages_with_energy(model, tokenizer, query, passages, device="cuda", max_len=128, batch_size=32):
    """Rank ``passages`` for ``query`` by cosine similarity and attach an "energy" score.

    Energy is ``1 - cosine_similarity``. It lies in [0, 2], and lower energy
    means a better match. The query is embedded once; passages are embedded in
    chunks of ``batch_size``.

    Args:
        model: module mapping a tokenizer output (dict of tensors) to embeddings.
        tokenizer: callable with the Hugging Face tokenizer call signature.
        query: query string.
        passages: list of passage strings.
        device: device the tokenized tensors are moved to. Autocast uses CUDA
            only when it is requested and available, otherwise CPU.
        max_len: truncation length for the query and passages.
        batch_size: number of passages per forward pass.

    Returns:
        List of ``(passage, similarity, energy)`` tuples sorted by ascending
        energy (best match first). ``[]`` when ``passages`` is empty. The
        model's train/eval mode is restored before returning.
    """
    if len(passages) == 0:
        return []

    was_training = model.training
    model.eval()
    similarity_scores = []

    autocast_device = "cuda" if str(device).startswith("cuda") and torch.cuda.is_available() else "cpu"
    try:
        with torch.no_grad(), torch.amp.autocast(autocast_device):
            # Encode the query once.
            enc_query = tokenizer(
                query,
                padding=True,
                truncation=True,
                max_length=max_len,
                return_tensors="pt"
            )
            enc_query = {k: v.to(device) for k, v in enc_query.items()}
            query_embedding = model(enc_query)  # single row, broadcast against each passage batch

            for start in range(0, len(passages), batch_size):
                batch = passages[start:start + batch_size]
                enc = tokenizer(
                    batch,  # Encode passages only
                    padding=True,
                    truncation=True,
                    max_length=max_len,
                    return_tensors="pt"
                )
                enc = {k: v.to(device) for k, v in enc.items()}
                passage_embeddings = model(enc)

                # Cosine similarity between the query and each passage in the batch.
                sims = F.cosine_similarity(query_embedding, passage_embeddings)
                similarity_scores.extend(sims.float().cpu().tolist())
    finally:
        # Don't leave a model that was mid-training stuck in eval mode.
        model.train(was_training)

    # Lower energy (closer to 0) = better match.
    energy_scores = [1 - sim for sim in similarity_scores]
    return sorted(zip(passages, similarity_scores, energy_scores), key=lambda row: row[2])


# Demo: rank a random question's gold passage against one easy and one hard negative.
if 'positives_clean' in globals():
    source_pos = positives_clean
else:
    source_pos = [(qq, pp) for (qq, pp) in positives if qq and pp]
q, gold = random.choice(source_pos)


def _pick_negative(pairs, query, gold_passage):
    """Return a negative for ``query``: one built for that exact query when
    available, otherwise any non-gold negative in ``pairs``; None if none qualify."""
    same_query = [p for (qq, p) in pairs if qq == query and p]
    if same_query:
        return random.choice(same_query)
    fallback = [p for (_, p) in pairs if p and p != gold_passage]
    return random.choice(fallback) if fallback else None


# Easy first, then hard (same random-draw order as before).
easy_p = _pick_negative(easy_negatives, q, gold) if globals().get('easy_negatives') else None
hard_p = _pick_negative(hard_negatives, q, gold) if globals().get('hard_negatives') else None

# Gold first, followed by whichever negatives were found.
test_passages = [gold] + [p for p in (easy_p, hard_p) if p]

ranked = rank_passages_with_energy(model, tokenizer, q, test_passages, device=device, max_len=128, batch_size=32)

rule = "=" * 70
print(rule)
print("Query:")
print(q[:300])
print("\n" + rule)
print("Ranked passages (ascending energy = best match):")
print(rule)

for rank, (passage, similarity, energy) in enumerate(ranked, 1):
    snippet = (passage or "").replace("\n", " ")[:220]
    gold_tag = " ✓ GOLD" if passage == gold else ""

    # Confidence bucket from fixed energy thresholds.
    if energy < 0.3:
        confidence = "✓ HIGH"
    elif energy < 0.5:
        confidence = "⚠ MEDIUM"
    else:
        confidence = "✗ LOW"

    ellipsis = "..." if passage and len(passage) > 220 else ""
    print(f"\n{rank}. {confidence}{gold_tag}")
    print(f"   Similarity: {similarity:.4f} | Energy: {energy:.4f}")
    print(f"   {snippet}{ellipsis}")

print("\n" + rule)
print("Energy Score Guide:")
print("  Energy ≈ 0.0-0.3  → ✓ HIGH confidence (good match)")
print("  Energy ≈ 0.3-0.5  → ⚠ MEDIUM confidence (acceptable)")
print("  Energy ≈ 0.5+     → ✗ LOW confidence (poor match)")
print(rule)


Query:
A policemen foo..a a person ln ing unconscious in iglu lateral position on the road with superficial injury to the face, bruises on the right arm, and injury to the lateral aspect of right knee. Nerve most probably injured:

Ranked passages (ascending energy = best match):

1. ✗ LOW ✓ GOLD
   Similarity: 0.4146 | Energy: 0.5854
   Ans. c. Common peroneal nerve Common peroneal nerve (L4, L5, Sl, S2) is the smaller terminal branch of sciatic nerve. The larger terminal branch of sciatic nerve is the tibial nerve. The common peroneal nerve is relative...

2. ✗ LOW
   Similarity: 0.3954 | Energy: 0.6046
   curare notch ref : willey 10th ed

3. ✗ LOW
   Similarity: -0.0587 | Energy: 1.0587
   The area posterior to the sternum is occupied by the right ventricle and hence is most likely to be injured in this case. The convex anterosuperior surface of the right ventricle makes up a large pa of the sternocoastal ...

Energy Score Guide:
  Energy ≈ 0.0-0.3  → ✓ HIGH confidence (good match)