In [53]:
import json
import random
import json, random, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
from rank_bm25 import BM25Okapi
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize
from tqdm import tqdm

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\VIDUSHI\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
# +ve pairs: create (query, passage) pairs for training
# Robustly load JSON (array) or JSONL (one object per line)
with open("data/train.json", "r", encoding="utf-8") as f:
    raw = f.read()

try:
    dataset = json.loads(raw)
except json.JSONDecodeError:
    dataset = []
    for line in raw.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            dataset.append(json.loads(line))
        except json.JSONDecodeError:
            # Skip malformed lines; adjust if you need strict validation
            continue

# Handle if dataset is a dict with a list inside, or directly a list
if isinstance(dataset, dict):
    for key, value in dataset.items():
        if isinstance(value, list):
            dataset = value
            break

positives = []
for x in dataset:
    positives.append((x.get("question"), x.get("exp")))

In [4]:
positives

[('Chronic urethral obstruction due to benign prismatic hyperplasia can lead to the following change in kidney parenchyma',
  'Chronic urethral obstruction because of urinary calculi, prostatic hyperophy, tumors, normal pregnancy, tumors, uterine prolapse or functional disorders cause hydronephrosis which by definition is used to describe dilatation of renal pelvis and calculus associated with progressive atrophy of the kidney due to obstruction to the outflow of urine Refer Robbins 7yh/9,1012,9/e. P950'),
 ('Which vitamin is supplied from only animal source:',
  "Ans. (c) Vitamin B12 Ref: Harrison's 19th ed. P 640* Vitamin B12 (Cobalamin) is synthesized solely by microorganisms.* In humans, the only source for humans is food of animal origin, e.g., meat, fish, and dairy products.* Vegetables, fruits, and other foods of nonanimal origin doesn't contain Vitamin B12 .* Daily requirements of vitamin Bp is about 1-3 pg. Body stores are of the order of 2-3 mg, sufficient for 3-4 years if su

In [6]:
corpus = [x["exp"] for x in dataset]

In [9]:
# easy -ves 
def build_easy_negatives(dataset, gold_key_question="question", gold_key_passage="exp", max_attempts=10):
    all_passages = [x.get(gold_key_passage) for x in dataset if x.get(gold_key_passage)]
    easy_negatives = []
    for item in dataset:
        q = item.get(gold_key_question)
        gold = item.get(gold_key_passage)
        if not q or not gold or not all_passages:
            continue
        candidate = gold
        tries = 0
        while candidate == gold and tries < max_attempts:
            candidate = random.choice(all_passages)
            tries += 1
        if candidate != gold:
            easy_negatives.append((q, candidate))
    return easy_negatives

In [10]:
easy_negatives = build_easy_negatives(dataset)

In [21]:
def build_hard_negatives(dataset, gold_key_question="question", gold_key_passage="exp", sample_size=1000):
    all_passages = [x.get(gold_key_passage) for x in dataset if x.get(gold_key_passage)]
    
    # Sample to speed up BM25 indexing
    if len(all_passages) > sample_size:
        sampled_idx = random.sample(range(len(all_passages)), sample_size)
        all_passages = [all_passages[i] for i in sampled_idx]
    
    # Build BM25 index
    corpus_tokens = [p.lower().split() for p in all_passages]
    bm25 = BM25Okapi(corpus_tokens)
    
    hard_negatives = []
    for item in dataset:
        q = item.get(gold_key_question)
        gold = item.get(gold_key_passage)
        if not q or not gold:
            continue
        
        # BM25 rank
        q_tokens = q.lower().split()
        scores = bm25.get_scores(q_tokens)
        ranked = sorted(zip(all_passages, scores), key=lambda x: x[1], reverse=True)
        
        # Take top non-gold passage
        for passage, _ in ranked:
            if passage != gold:
                hard_negatives.append((q, passage))
                break
    
    return hard_negatives

In [22]:
hard_negatives = build_hard_negatives(dataset)

In [23]:
print("Summary of generated pairs:")
print(f"positives: {len(positives)} | easy_negatives: {len(easy_negatives)} | hard_negatives: {len(hard_negatives)}")
print("sample easy negative:", easy_negatives[0] if easy_negatives else None)
print("sample hard negative:", hard_negatives[0] if hard_negatives else None)

Summary of generated pairs:
positives: 182822 | easy_negatives: 160869 | hard_negatives: 160869
sample easy negative: ('Chronic urethral obstruction due to benign prismatic hyperplasia can lead to the following change in kidney parenchyma', "Ans. is 'b' i.e., 3 years Medical Council recommends that the medical records of the indoor patient to be aintained for atleast a period of 3 years from the date of commencement of treatment.")
sample hard negative: ('Chronic urethral obstruction due to benign prismatic hyperplasia can lead to the following change in kidney parenchyma', "(Lymphatic obstruction) (837-LB) (839-B & L 25th)Phenomena resulting from lymphatic obstruction In advanced breast cancer(i) Peaud's orange (ii) Cancer-en-cuirasse (iii) LymphangiosarcomaPeaud' orange - is due to cutaneous lymphatic oedema, where the infiltrate skin is fethered by the sweat ducts, it cannot swell leading to an appearance like orange skin. Occasionally the same phenomenon is seen over a chronic absc

In [None]:
# split and label the data
# tokenize pairs
# scorer head
# train with bce loss

In [35]:
# Filter out None values before creating samples
positives_clean = [(q, p) for (q, p) in positives if q and p]

negatives = easy_negatives + hard_negatives
negatives_clean = [(q, p) for (q, p) in negatives if q and p]
samples = [(q, p, 0) for (q, p) in positives_clean] + [(q, p, 1) for (q, p) in negatives_clean]
random.shuffle(samples)
print(f"Created {len(samples)} valid samples (filtered out None values)")

Created 482607 valid samples (filtered out None values)


In [57]:
split= int(0.9 * len(samples))
train_data, val_data = samples[:split], samples[split:]

# Use subset for faster iteration (10% of data)
subset_size = len(train_data) // 10
train_data_subset = train_data[:subset_size]
val_data_subset = val_data[:len(val_data)//10]

print(f"Full dataset - train: {len(train_data)}, val: {len(val_data)}")
print(f"Subset (10%) - train: {len(train_data_subset)}, val: {len(val_data_subset)}")
print("Use train_data_subset for fast prototyping, train_data for full training")

Full dataset - train: 434346, val: 48261
Subset (10%) - train: 43434, val: 4826
Use train_data_subset for fast prototyping, train_data for full training


In [58]:
# encode 
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
encoder = AutoModel.from_pretrained(model_name)

def collate_batch(batch, max_len = 128):  # Reduced from 256 for speed
    qs, ps, ys = zip(*batch)
    enc = tokenizer(
        list(qs),
        list(ps),
        padding = True,
        truncation = True,
        max_length= max_len, 
        return_tensors="pt"
    )
    labels = torch.tensor(ys, dtype=torch.float32)
    return enc, labels

In [59]:
class TransformerScorer(nn.Module):
    def __init__(self, encoder, hidden = 256):
        super().__init__()
        self.encoder = encoder 
        dim = encoder.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(dim*2, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1)
        )
    
    def forward(self, enc_ips):
        out = self.encoder(**enc_ips).last_hidden_state
        # Mean pooling over tokens
        mask = enc_ips["attention_mask"].unsqueeze(-1)  # [B, L, 1]
        pooled = (out * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)  # [B, H]

        # Split back into q/p representations: we encoded pairs, so duplicate pooled
        # Alternatively, encode separately. Here we encode concatenated pair: use pooled as joint repr.
        joint = pooled
        # If you prefer separate encodes, run encoder twice (for q and p) and concat. This is faster (single pass).
        x = torch.cat([joint, joint], dim=1)  # simple pass-through for joint repr
        return self.classifier(x).squeeze(1)

In [60]:
device = "cuda" 
model = TransformerScorer(encoder).to(device)

In [61]:
opt = torch.optim.AdamW(model.parameters(), lr=2e-5)
bce = nn.BCEWithLogitsLoss()

In [62]:
# data loader - use subset for fast iteration
train_loader = DataLoader(train_data_subset, batch_size=32, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(val_data_subset, batch_size=32, shuffle=False, collate_fn=collate_batch)

print(f"Training batches per epoch: {len(train_loader)}")

Training batches per epoch: 1358


In [63]:
# Train with mixed precision and progress tracking
scaler = torch.amp.GradScaler('cuda')

for epoch in range(3):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for enc, y in pbar:
        enc = {k: v.to(device) for k, v in enc.items()}
        y = y.to(device)
        
        with torch.amp.autocast('cuda'):
            logits = model(enc)
            loss = bce(logits, y)
        
        opt.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
    
    model.eval()
    total = correct = 0
    with torch.no_grad():
        for enc, y in tqdm(val_loader, desc="Validation"):
            enc = {k: v.to(device) for k, v in enc.items()}
            y = y.to(device)
            
            with torch.amp.autocast('cuda'):
                logits = model(enc)
            
            preds = (torch.sigmoid(logits) > 0.5).float()
            correct += (preds == y).sum().item()
            total += y.numel()
    
    print(f"epoch {epoch}: val_acc={correct/total:.3f}")

Epoch 0: 100%|██████████| 1358/1358 [04:28<00:00,  5.06it/s, loss=0.5228]
Validation: 100%|██████████| 151/151 [00:07<00:00, 20.24it/s]


epoch 0: val_acc=0.863


Epoch 1: 100%|██████████| 1358/1358 [04:21<00:00,  5.20it/s, loss=0.1521]
Validation: 100%|██████████| 151/151 [00:07<00:00, 20.71it/s]


epoch 1: val_acc=0.898


Epoch 2: 100%|██████████| 1358/1358 [04:22<00:00,  5.17it/s, loss=0.0241]
Validation: 100%|██████████| 151/151 [00:07<00:00, 20.66it/s]

epoch 2: val_acc=0.903





In [64]:
torch.save(model.state_dict(), "scorer.ckpt")

In [67]:
# EBM scorer sanity check: single energy score per candidate (lower = better)
import random, torch


def ebm_energy_batch(model, tokenizer, query, passages, device="cuda", max_len=128, batch_size=64):
    """Return energy scores in [0,1] where lower = better.
    Energy = sigmoid(logit), since model trained with 0=positive, 1=negative.
    """
    model.eval()
    energies = []
    autocast_device = "cuda" if str(device).startswith("cuda") and torch.cuda.is_available() else "cpu"
    with torch.no_grad(), torch.amp.autocast(autocast_device):
        for i in range(0, len(passages), batch_size):
            batch = passages[i : i + batch_size]
            enc = tokenizer([query] * len(batch), batch,
                            padding=True, truncation=True, max_length=max_len,
                            return_tensors="pt")
            enc = {k: v.to(device) for k, v in enc.items()}
            logits = model(enc)                     # higher logit => more negative
            energy = torch.sigmoid(logits)          # energy in [0,1]; lower = better
            energies.extend(energy.detach().cpu().tolist())
    return energies


# Pick a random positive (q, gold passage)
source_pos = positives_clean if 'positives_clean' in globals() else [(q, p) for (q, p) in positives if q and p]
q, gold = random.choice(source_pos)

# Try to find easy/hard negatives for the SAME question
easy_p = None
hard_p = None

if 'easy_negatives' in globals() and easy_negatives:
    same_q_easy = [p for (qq, p) in easy_negatives if qq == q and p]
    if same_q_easy:
        easy_p = random.choice(same_q_easy)
    else:
        pool = [p for (_, p) in easy_negatives if p and p != gold]
        easy_p = random.choice(pool) if pool else None

if 'hard_negatives' in globals() and hard_negatives:
    same_q_hard = [p for (qq, p) in hard_negatives if qq == q and p]
    if same_q_hard:
        hard_p = random.choice(same_q_hard)
    else:
        pool = [p for (_, p) in hard_negatives if p and p != gold]
        hard_p = random.choice(pool) if pool else None

# Build set to score
pairs = [("positive", gold)]
if easy_p: pairs.append(("easy_negative", easy_p))
if hard_p: pairs.append(("hard_negative", hard_p))

labels, passages = zip(*pairs)
energies = ebm_energy_batch(model, tokenizer, q, list(passages), device=device, max_len=128, batch_size=32)

print("Query:\n", q)
print("\nEnergy scores (lower = better; near 0 = likely correct):")
for lbl, txt, e in zip(labels, passages, energies):
    snippet = (txt or "").replace("\n", " ")[:220]
    print(f"- {lbl:14s} | energy={e:.3f} | {snippet}{'...' if txt and len(txt)>220 else ''}")

# Gate using a single threshold on energy
energy_threshold = 0.30  # accept if energy <= threshold
accepted = [(lbl, p, e) for (lbl, p, e) in zip(labels, passages, energies) if e <= energy_threshold]
print(f"\nAccepted (energy <= {energy_threshold}): {len(accepted)}")
for lbl, p, e in accepted:
    tag = " (GOLD)" if lbl == 'positive' else ""
    print(f"  - {lbl}{tag}: energy={e:.3f}")

Query:
 Carrier is defined as -

Energy scores (lower = better; near 0 = likely correct):
- positive       | energy=0.724 | Ans. is 'c' i.e., Infected person harbouring infectious agent without clinical features and acts as source of infection Sources and reservoirs* Source is 'the person, animal, object or substance from which infectious age...
- easy_negative  | energy=0.995 | To prevent NTDs, it is recommended that all women of child-bearing age, who are capable of becoming pregnant should take 0.4 mg (400 mg) of folic acid daily. If a pregnancy is planned in high risk women (with previously ...
- hard_negative  | energy=0.996 | Ans. is 'b' i.e., INH and Rifampicin o Multi-Drug Resistant Tuberculosis is defined by resistance to INH and Rifampicin.o 'Multi-Drug Resistant Tuberculosis (MDR-TB) is defined as disease caused by strain of M Tuberculos...

Accepted (energy <= 0.3): 0
