In [18]:
import json
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
from rank_bm25 import BM25Okapi
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize
from tqdm import tqdm
from collections import defaultdict

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\VIDUSHI\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [5]:
with open("../data/train.json", "r", encoding="utf-8") as f:
    raw = f.read()

try:
    dataset = json.loads(raw)
except json.JSONDecodeError:
    dataset = []
    for line in raw.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            dataset.append(json.loads(line))
        except json.JSONDecodeError:
            # Skip malformed lines; adjust if you need strict validation
            continue

# Handle if dataset is a dict with a list inside, or directly a list
if isinstance(dataset, dict):
    for key, value in dataset.items():
        if isinstance(value, list):
            dataset = value
            break

positives = []
for x in dataset:
    positives.append((x.get("question"), x.get("exp")))

In [6]:
len(positives)

182822

In [7]:
corpus = [x["exp"] for x in dataset]

In [8]:
# easy -ves 
def build_easy_negatives(dataset, gold_key_question="question", gold_key_passage="exp", max_attempts=10):
    all_passages = [x.get(gold_key_passage) for x in dataset if x.get(gold_key_passage)]
    easy_negatives = []
    for item in dataset:
        q = item.get(gold_key_question)
        gold = item.get(gold_key_passage)
        if not q or not gold or not all_passages:
            continue
        candidate = gold
        tries = 0
        while candidate == gold and tries < max_attempts:
            candidate = random.choice(all_passages)
            tries += 1
        if candidate != gold:
            easy_negatives.append((q, candidate))
    return easy_negatives

In [9]:
easy_negatives = build_easy_negatives(dataset)

In [10]:
def build_hard_negatives(dataset, gold_key_question="question", gold_key_passage="exp", sample_size=1000):
    all_passages = [x.get(gold_key_passage) for x in dataset if x.get(gold_key_passage)]
    
    # Sample to speed up BM25 indexing
    if len(all_passages) > sample_size:
        sampled_idx = random.sample(range(len(all_passages)), sample_size)
        all_passages = [all_passages[i] for i in sampled_idx]
    
    # Build BM25 index
    corpus_tokens = [p.lower().split() for p in all_passages]
    bm25 = BM25Okapi(corpus_tokens)
    
    hard_negatives = []
    for item in dataset:
        q = item.get(gold_key_question)
        gold = item.get(gold_key_passage)
        if not q or not gold:
            continue
        
        # BM25 rank
        q_tokens = q.lower().split()
        scores = bm25.get_scores(q_tokens)
        ranked = sorted(zip(all_passages, scores), key=lambda x: x[1], reverse=True)
        
        # Take top non-gold passage
        for passage, _ in ranked:
            if passage != gold:
                hard_negatives.append((q, passage))
                break
    
    return hard_negatives

In [11]:
hard_negatives = build_hard_negatives(dataset)

In [12]:
def has_contradiction(text1, text2):
    """Check if two passages contradict each other.
    Simple heuristic - improve with NLI model for production.
    """
    contradiction_pairs = [
        ('safe', 'contraindicated'),
        ('effective', 'ineffective'),
        ('recommended', 'not recommended'),
        ('increases', 'decreases'),
        ('approved', 'not approved'),
        ('use', 'avoid'),
        ('beneficial', 'harmful'),
        ('normal', 'abnormal')
    ]
    
    text1_lower = text1.lower()
    text2_lower = text2.lower()
    
    for pos, neg in contradiction_pairs:
        if (pos in text1_lower and neg in text2_lower) or \
           (neg in text1_lower and pos in text2_lower):
            return True
    return False

In [13]:
def extract_medical_topics(text):
    """Extract medical entities - simple keyword matching.
    TODO: Upgrade to scispacy or BioBERT NER for production.
    """
    medical_keywords = {
        'diabetes', 'hypertension', 'aspirin', 'metformin', 
        'surgery', 'pregnancy', 'cancer', 'antibiotics',
        'heart', 'blood pressure', 'cholesterol', 'infection',
        'pain', 'fever', 'asthma', 'copd', 'stroke', 'mi',
        'myocardial infarction', 'coronary', 'cardiovascular'
    }
    
    text_lower = text.lower()
    return [kw for kw in medical_keywords if kw in text_lower]

In [16]:
# Build contradiction negatives (FAST - with hard limits)
def build_contradiction_negatives(dataset, max_candidates_per_topic=50):    
    from collections import defaultdict
    
    # Step 1: Index passages by medical topics - O(n)
    topic_to_passages = defaultdict(list)
    
    for item in dataset:
        passage = item.get("exp")
        if not passage:
            continue
        
        topics = extract_medical_topics(passage)
        for topic in topics:
            topic_to_passages[topic].append(passage)
    
    print(f"Indexed {len(topic_to_passages)} topics for contradiction detection")
    
    # Step 2: For each question, check LIMITED candidates per topic
    contradiction_negatives = []
    
    for item in tqdm(dataset, desc="Building contradiction negatives"):
        q = item.get("question")
        gold = item.get("exp")
        
        if not q or not gold:
            continue
        
        # Get topics from gold passage
        gold_topics = extract_medical_topics(gold)
        if not gold_topics:
            continue
        
        # LIMIT: Only check first max_candidates_per_topic passages per topic
        candidates_to_check = []
        for topic in gold_topics:
            topic_passages = topic_to_passages[topic]
            # Sample randomly if too many
            if len(topic_passages) > max_candidates_per_topic:
                sampled = random.sample(topic_passages, max_candidates_per_topic)
                candidates_to_check.extend(sampled)
            else:
                candidates_to_check.extend(topic_passages)
        
        # Remove duplicates and gold passage
        candidates_to_check = list(set(candidates_to_check))
        candidates_to_check = [p for p in candidates_to_check if p != gold and p]
        
        # Check for contradiction (max 100 candidates per question)
        for candidate_passage in candidates_to_check[:100]:
            if has_contradiction(gold, candidate_passage):
                contradiction_negatives.append((q, candidate_passage))
                break  # Found one, move to next question
    
    print(f"Built {len(contradiction_negatives)} contradiction negatives")
    return contradiction_negatives

contradiction_negatives = build_contradiction_negatives(dataset, max_candidates_per_topic=50)

Indexed 21 topics for contradiction detection


Building contradiction negatives: 100%|██████████| 182822/182822 [00:21<00:00, 8576.20it/s] 

Built 39596 contradiction negatives





In [19]:
# Build medical hard negatives (most valuable for medical domain!)
def build_medical_hard_negatives(dataset, sample_size=1000):    
    # Group passages by medical topic/specialty
    topic_groups = defaultdict(list)
    for item in dataset:
        question = item.get("question", "")
        passage = item.get("exp", "")
        
        if not passage:
            continue
        
        # Extract medical topics from question
        topics = extract_medical_topics(question)
        
        for topic in topics:
            topic_groups[topic].append(passage)
    
    print(f"Grouped passages into {len(topic_groups)} medical topics")
    
    medical_hard_negatives = []
    for item in dataset:
        q = item.get("question")
        gold = item.get("exp")
        if not q or not gold:
            continue
        
        topics = extract_medical_topics(q)
        
        # Sample from SAME medical topic but different passage
        for topic in topics:
            candidates = [p for p in topic_groups.get(topic, []) if p != gold]
            if candidates:
                hard_neg = random.choice(candidates)
                medical_hard_negatives.append((q, hard_neg))
                break
    
    print(f"Built {len(medical_hard_negatives)} medical hard negatives")
    return medical_hard_negatives

medical_hard_negatives = build_medical_hard_negatives(dataset)

Grouped passages into 21 medical topics
Built 35069 medical hard negatives


In [20]:
# Combine negatives with strategic weighting
all_negatives = (
    easy_negatives * 1 +  # 1x easy (baseline)
    hard_negatives * 2 +  # 2x hard (BM25 challenging)
    medical_hard_negatives * 3 +  # 3x medical hard (domain-specific!)
    contradiction_negatives * 2  # 2x contradictions (hallucination prevention!)
)

print(f"Summary of weighted negatives:")
print(f"  Easy: {len(easy_negatives)} → {len(easy_negatives) * 1}")
print(f"  Hard (BM25): {len(hard_negatives)} → {len(hard_negatives) * 2}")
print(f"  Medical Hard: {len(medical_hard_negatives)} → {len(medical_hard_negatives) * 3}")
print(f"  Contradictions: {len(contradiction_negatives)} → {len(contradiction_negatives) * 2}")
print(f"  Total weighted: {len(all_negatives)}")

Summary of weighted negatives:
  Easy: 160869 → 160869
  Hard (BM25): 160869 → 321738
  Medical Hard: 35069 → 105207
  Contradictions: 39596 → 79192
  Total weighted: 667006


In [21]:
# (query, positive_passage, negative_passage)
positives_clean = [(q, p) for (q, p) in positives if q and p]

# Use all weighted negatives
negatives = all_negatives
negatives_clean = [(q, p) for (q, p) in negatives if q and p]

# Build a dict mapping query -> list of negative passages
from collections import defaultdict

neg_by_query = defaultdict(list)
for q, p in negatives_clean:
    neg_by_query[q].append(p)

print(f"Built negative index for {len(neg_by_query)} unique queries")

# samples: (query, positive_passage, negative_passage)
triplet_samples = []
for q, pos_p in tqdm(positives_clean, desc="Creating triplets"):
    # Fast lookup: get negatives for this query
    neg_for_q = neg_by_query.get(q, [])
    
    # Filter out the positive passage
    neg_for_q = [p for p in neg_for_q if p != pos_p]
    
    if neg_for_q:
        neg_p = random.choice(neg_for_q)
        triplet_samples.append((q, pos_p, neg_p))

random.shuffle(triplet_samples)
print(f"Created {len(triplet_samples)} triplet samples (query, positive, negative)")

Built negative index for 160869 unique queries


Creating triplets: 100%|██████████| 160869/160869 [00:00<00:00, 626474.88it/s]


Created 160869 triplet samples (query, positive, negative)


In [22]:
class DataAugmentation:    
    @staticmethod
    def augment_medical_query(query, num_augments=2):
        """Augment queries with medical synonyms and paraphrasing."""
        augmented = [query]
        
        # Medical synonym replacement
        medical_synonyms = {
            'heart attack': ['myocardial infarction', 'MI', 'cardiac event'],
            'high blood pressure': ['hypertension', 'elevated BP'],
            'diabetes': ['diabetes mellitus', 'DM'],
            'medicine': ['medication', 'drug', 'pharmaceutical'],
            'doctor': ['physician', 'clinician'],
            'symptoms': ['signs', 'manifestations'],
            'treatment': ['therapy', 'management'],
            'side effects': ['adverse effects', 'complications'],
        }
        
        query_lower = query.lower()
        for original, synonyms in medical_synonyms.items():
            if original in query_lower:
                for syn in synonyms[:num_augments]:
                    new_query = query_lower.replace(original, syn)
                    if new_query != query_lower:
                        augmented.append(new_query.capitalize())
        
        # Question reformulation
        reformulations = {
            'What is': ['What are', 'Can you explain', 'Define'],
            'How does': ['How do', 'What is the mechanism of'],
            'What are': ['What is', 'List the'],
        }
        
        for original, alternatives in reformulations.items():
            if original in query:
                for alt in alternatives[:num_augments]:
                    new_query = query.replace(original, alt, 1)
                    if new_query != query:
                        augmented.append(new_query)
        
        return list(set(augmented))[:num_augments + 1]
    
    @staticmethod
    def augment_passage(passage, drop_rate=0.1):
        """Augment passages by randomly dropping words (noise injection)."""
        words = passage.split()
        num_keep = int(len(words) * (1 - drop_rate))
        
        if num_keep < len(words) and num_keep > 0:
            indices = random.sample(range(len(words)), num_keep)
            indices.sort()
            return ' '.join([words[i] for i in indices])
        
        return passage

# Test augmentation
sample_query = "What is diabetes?"
augmented = DataAugmentation.augment_medical_query(sample_query)
print(f"Original: {sample_query}")
print(f"Augmented versions:")
for i, aug in enumerate(augmented[1:], 1):
    print(f"  {i}. {aug}")

Original: What is diabetes?
Augmented versions:
  1. What is dm?
  2. What is diabetes?


In [23]:
# Create triplets WITH augmentation (optional - set to True to use)
USE_AUGMENTATION = False  # Set to True for 3x more data

if USE_AUGMENTATION:
    print("Creating augmented triplets...")
    triplet_samples_aug = []
    
    for q, pos_p in tqdm(positives_clean, desc="Creating augmented triplets"):
        neg_for_q = neg_by_query.get(q, [])
        neg_for_q = [p for p in neg_for_q if p != pos_p]
        
        if not neg_for_q:
            continue
        
        # Original triplet
        neg_p = random.choice(neg_for_q)
        triplet_samples_aug.append((q, pos_p, neg_p))
        
        # Augmented versions
        aug_queries = DataAugmentation.augment_medical_query(q, num_augments=2)
        for aug_q in aug_queries[1:]:  # Skip first (original)
            aug_pos = DataAugmentation.augment_passage(pos_p)
            neg_p = random.choice(neg_for_q)
            triplet_samples_aug.append((aug_q, aug_pos, neg_p))
    
    random.shuffle(triplet_samples_aug)
    triplet_samples = triplet_samples_aug
    print(f"Created {len(triplet_samples)} augmented triplet samples")
else:
    print(f"Using {len(triplet_samples)} original triplets (set USE_AUGMENTATION=True for 3x more)")

Using 160869 original triplets (set USE_AUGMENTATION=True for 3x more)


In [24]:
# 90-10 train-val split
split = int(0.9 * len(triplet_samples))
train_data, val_data = triplet_samples[:split], triplet_samples[split:]

'''
# Use subset for faster iteration (10% of data) during prototyping
size = len(train_data) 
train_data_subset = train_data[:size//10]
val_data_subset = val_data[:len(val_data)//10]

print(f"Subset (10%) - train: {len(train_data_subset)}, val: {len(val_data_subset)}")
print("Use train_data_subset for fast prototyping, train_data for full training")
'''
print(f"Full dataset - train: {len(train_data)}, val: {len(val_data)}")

Full dataset - train: 144782, val: 16087


In [25]:
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
encoder = AutoModel.from_pretrained(model_name)

def collate_batch_triplet(batch, max_len=128):
    qs, pos_ps, neg_ps = zip(*batch)
    
    # Encode queries
    enc_q = tokenizer(
        list(qs),
        padding=True,
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )
    
    # Encode positive passages
    enc_pos = tokenizer(
        list(pos_ps),
        padding=True,
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )
    
    # Encode negative passages
    enc_neg = tokenizer(
        list(neg_ps),
        padding=True,
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )
    
    return enc_q, enc_pos, enc_neg


In [26]:
class RankingScorer(nn.Module):
    def __init__(self, encoder, embedding_dim=128):
        super().__init__()
        self.encoder = encoder
        self.hidden_dim = encoder.config.hidden_size
        
        # Project to embedding space for ranking
        self.projection = nn.Sequential(
            nn.Linear(self.hidden_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )
    
    def forward(self, enc_ips):
        out = self.encoder(**enc_ips).last_hidden_state
        
        # Mean pooling over tokens
        mask = enc_ips["attention_mask"].unsqueeze(-1)  # [B, L, 1]
        pooled = (out * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)  # [B, H]
        
        # Project to embedding space
        embedding = self.projection(pooled)  # [B, embedding_dim]
        
        return embedding

In [27]:
device = "cuda"
model = RankingScorer(encoder, embedding_dim=128).to(device)

In [28]:
opt = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

# Ranking loss: Triplet Loss with margin
def ranking_loss(emb_pos, emb_neg, margin=0.5):
    """Triplet loss: max(0, margin - (pos_sim - neg_sim))
    Push negative embeddings away from positive embeddings.
    """
    # Cosine similarity
    pos_sim = F.cosine_similarity(emb_pos, emb_pos)  # [B], should be ~1.0
    neg_sim = F.cosine_similarity(emb_pos, emb_neg)  # [B], should be < pos_sim
    
    # Triplet loss: max(0, margin - (pos_sim - neg_sim))
    # Want pos_sim > neg_sim by at least margin
    loss = torch.clamp(margin - (pos_sim - neg_sim), min=0).mean()
    return loss


In [29]:
# Data loaders
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_batch_triplet)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False, collate_fn=collate_batch_triplet)

print(f"Training batches per epoch: {len(train_loader)}")

Training batches per epoch: 4525


In [34]:
BATCH_SIZE = 16  # Reduced for 4-6GB VRAM
GRADIENT_ACCUMULATION_STEPS = 4  # Simulate batch_size=64
NUM_WORKERS = 0  # ⚠️ Set to 0 on Windows (fixes multiprocessing crash)
PIN_MEMORY = True  # Faster CPU->GPU transfer
FREEZE_ENCODER_EPOCHS = 2  # Train only projection layer first (faster!)
MAX_LENGTH = 96  # Reduce from 128 to save memory

# Learning rate schedule
from torch.optim.lr_scheduler import OneCycleLR, ReduceLROnPlateau

print(f"✓ Batch Size: {BATCH_SIZE} (effective: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS})")
print(f"✓ Max Length: {MAX_LENGTH} tokens")
print(f"✓ Workers: {NUM_WORKERS} (single-process for Windows stability)")
print(f"✓ Freeze Encoder: First {FREEZE_ENCODER_EPOCHS} epochs")

✓ Batch Size: 16 (effective: 64)
✓ Max Length: 96 tokens
✓ Workers: 0 (single-process for Windows stability)
✓ Freeze Encoder: First 2 epochs


In [35]:
# Optimized collate function with reduced max length
def collate_batch_optimized(batch, max_len=MAX_LENGTH):
    qs, pos_ps, neg_ps = zip(*batch)
    
    # Encode with shorter max_length to save memory
    enc_q = tokenizer(
        list(qs),
        padding=True,
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )
    
    enc_pos = tokenizer(
        list(pos_ps),
        padding=True,
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )
    
    enc_neg = tokenizer(
        list(neg_ps),
        padding=True,
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )
    
    return enc_q, enc_pos, enc_neg

# Optimized data loaders - single process (Windows compatible)
train_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_batch_optimized,
    num_workers=NUM_WORKERS,  # 0 for Windows
    pin_memory=PIN_MEMORY
)

val_loader = DataLoader(
    val_data,
    batch_size=BATCH_SIZE * 2,  # Larger batch for validation (no gradients)
    shuffle=False,
    collate_fn=collate_batch_optimized,
    num_workers=NUM_WORKERS,  # 0 for Windows
    pin_memory=PIN_MEMORY
)

print(f"✓ Training batches per epoch: {len(train_loader)}")
print(f"✓ Validation batches: {len(val_loader)}")
print(f"✓ Total training samples: {len(train_data)}")
print(f"✓ Single-process data loading (Windows compatible)")

✓ Training batches per epoch: 9049
✓ Validation batches: 503
✓ Total training samples: 144782
✓ Single-process data loading (Windows compatible)


In [36]:
EPOCHS = 5
EARLY_STOP_PATIENCE = 2  # Stop if no improvement
USE_MNR = True
# Optimizer with weight decay
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

# Learning rate scheduler - warmup + cosine decay
total_steps = len(train_loader) * EPOCHS // GRADIENT_ACCUMULATION_STEPS
scheduler = OneCycleLR(
    optimizer,
    max_lr=5e-5,
    total_steps=total_steps,
    pct_start=0.1,  # 10% warmup
    anneal_strategy='cos'
)

scaler = torch.amp.GradScaler('cuda')

# Tracking
best_val_acc = 0
no_improve_count = 0

print(f"Starting efficient training:")
print(f"  Total steps: {total_steps}")
print(f"  Gradient accumulation: every {GRADIENT_ACCUMULATION_STEPS} batches")
print(f"  Early stopping patience: {EARLY_STOP_PATIENCE} epochs\n")

Starting efficient training:
  Total steps: 11311
  Gradient accumulation: every 4 batches
  Early stopping patience: 2 epochs



In [37]:
for epoch in range(EPOCHS):
    # Freeze encoder for first few epochs (train projection only = faster!)
    if epoch < FREEZE_ENCODER_EPOCHS:
        for param in model.encoder.parameters():
            param.requires_grad = False
        print(f"Epoch {epoch}: Encoder FROZEN (training projection only)")
    else:
        for param in model.encoder.parameters():
            param.requires_grad = True
        print(f"Epoch {epoch}: Encoder UNFROZEN (full training)")
    
    # Training
    model.train()
    total_loss = 0
    optimizer.zero_grad()
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch} {'[MNR]' if USE_MNR else '[Triplet]'}")
    
    for batch_idx, (enc_q, enc_pos, enc_neg) in enumerate(pbar):
        enc_q = {k: v.to(device) for k, v in enc_q.items()}
        enc_pos = {k: v.to(device) for k, v in enc_pos.items()}
        enc_neg = {k: v.to(device) for k, v in enc_neg.items()}
        
        batch_size = enc_q['input_ids'].shape[0]
        
        with torch.amp.autocast('cuda'):
            emb_q = model(enc_q)
            emb_pos = model(enc_pos)
            emb_neg = model(enc_neg)
            
            if USE_MNR:
                temperature = 0.05
                scores = torch.matmul(emb_q, emb_pos.T) / temperature
                labels = torch.arange(batch_size).to(device)
                loss = F.cross_entropy(scores, labels)
                
                # Add explicit negatives
                neg_scores = F.cosine_similarity(emb_q, emb_neg) / temperature
                pos_scores = scores.diagonal()
                margin_loss = F.relu(0.2 - (pos_scores - neg_scores)).mean()
                loss = loss + 0.5 * margin_loss
            else:
                pos_sim = F.cosine_similarity(emb_q, emb_pos)
                neg_sim = F.cosine_similarity(emb_q, emb_neg)
                loss = torch.clamp(0.5 - (pos_sim - neg_sim), min=0).mean()
        
        # Gradient accumulation
        loss = loss / GRADIENT_ACCUMULATION_STEPS
        scaler.scale(loss).backward()
        
        # Update every N steps
        if (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
            scaler.unscale_(optimizer)  # Unscale before step
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()  # Scheduler AFTER optimizer (fixes warning)
        
        total_loss += loss.item() * GRADIENT_ACCUMULATION_STEPS
        pbar.set_postfix({
            "loss": f"{loss.item() * GRADIENT_ACCUMULATION_STEPS:.4f}",
            "lr": f"{scheduler.get_last_lr()[0]:.2e}"
        })
    
    avg_loss = total_loss / len(train_loader)
    
    # Validation
    model.eval()
    metrics = {'correct': 0, 'total': 0, 'mrr': []}
    
    with torch.no_grad():
        for enc_q, enc_pos, enc_neg in tqdm(val_loader, desc="Validation"):
            enc_q = {k: v.to(device) for k, v in enc_q.items()}
            enc_pos = {k: v.to(device) for k, v in enc_pos.items()}
            enc_neg = {k: v.to(device) for k, v in enc_neg.items()}
            
            with torch.amp.autocast('cuda'):
                emb_q = model(enc_q)
                emb_pos = model(enc_pos)
                emb_neg = model(enc_neg)
            
            batch_size = emb_q.shape[0]
            
            for i in range(batch_size):
                q_emb = emb_q[i:i+1]
                candidates = torch.cat([emb_pos[i:i+1], emb_neg[i:i+1]], dim=0)
                sims = F.cosine_similarity(q_emb, candidates, dim=1)
                
                rank = 2 - (sims[0] > sims[1]).long().item()
                
                if rank == 1:
                    metrics['correct'] += 1
                
                metrics['mrr'].append(1.0 / rank)
                metrics['total'] += 1
    
    val_acc = metrics['correct'] / metrics['total']
    val_mrr = sum(metrics['mrr']) / len(metrics['mrr'])
    
    print(f"\n Epoch {epoch} Summary:")
    print(f"   Train Loss: {avg_loss:.4f}")
    print(f"   Val Accuracy: {val_acc:.3f}")
    print(f"   Val MRR: {val_mrr:.3f}")
    
    # Early stopping check
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        no_improve_count = 0
        torch.save(model.state_dict(), "../checkpoints/best_EBM_scorer.ckpt")
        print(f"   ✓ New best model saved! (acc={val_acc:.3f})")
    else:
        no_improve_count += 1
        print(f"   ⚠ No improvement for {no_improve_count} epoch(s)")
        
        if no_improve_count >= EARLY_STOP_PATIENCE:
            print(f"\n⏹ Early stopping! No improvement for {EARLY_STOP_PATIENCE} epochs")
            break
    
    print()

print(f"\n✅ Training complete!")
print(f"   Best validation accuracy: {best_val_acc:.3f}")
print(f"   Best model saved at: ../checkpoints/best_EBM_scorer.ckpt")

Epoch 0: Encoder FROZEN (training projection only)


Epoch 0 [MNR]:   0%|          | 0/9049 [00:00<?, ?it/s]

Epoch 0 [MNR]: 100%|██████████| 9049/9049 [10:53<00:00, 13.85it/s, loss=0.8143, lr=4.85e-05]
Validation: 100%|██████████| 503/503 [00:58<00:00,  8.60it/s]



 Epoch 0 Summary:
   Train Loss: 1.4934
   Val Accuracy: 0.826
   Val MRR: 0.913
   ✓ New best model saved! (acc=0.826)

Epoch 1: Encoder FROZEN (training projection only)


Epoch 1 [MNR]: 100%|██████████| 9049/9049 [17:54<00:00,  8.42it/s, loss=0.5602, lr=3.75e-05]  
Validation: 100%|██████████| 503/503 [01:16<00:00,  6.54it/s]



 Epoch 1 Summary:
   Train Loss: 1.0127
   Val Accuracy: 0.838
   Val MRR: 0.919
   ✓ New best model saved! (acc=0.838)

Epoch 2: Encoder UNFROZEN (full training)


Epoch 2 [MNR]: 100%|██████████| 9049/9049 [33:23<00:00,  4.52it/s, loss=0.3618, lr=2.07e-05]
Validation: 100%|██████████| 503/503 [01:28<00:00,  5.67it/s]



 Epoch 2 Summary:
   Train Loss: 0.5051
   Val Accuracy: 0.942
   Val MRR: 0.971
   ✓ New best model saved! (acc=0.942)

Epoch 3: Encoder UNFROZEN (full training)


Epoch 3 [MNR]: 100%|██████████| 9049/9049 [53:58<00:00,  2.79it/s, loss=0.1594, lr=5.85e-06]     
Validation: 100%|██████████| 503/503 [01:44<00:00,  4.80it/s]



 Epoch 3 Summary:
   Train Loss: 0.2245
   Val Accuracy: 0.950
   Val MRR: 0.975
   ✓ New best model saved! (acc=0.950)

Epoch 4: Encoder UNFROZEN (full training)


Epoch 4 [MNR]: 100%|██████████| 9049/9049 [35:19<00:00,  4.27it/s, loss=0.2397, lr=2.00e-10]
Validation: 100%|██████████| 503/503 [01:20<00:00,  6.23it/s]



 Epoch 4 Summary:
   Train Loss: 0.1465
   Val Accuracy: 0.951
   Val MRR: 0.976
   ✓ New best model saved! (acc=0.951)


 Training complete!
   Best validation accuracy: 0.951
   Model saved to: ../checkpoints/EBM_scorer.ckpt


In [39]:
# Train with MNR Loss (recommended!)
USE_MNR = True  # Set to False for original triplet loss

scaler = torch.amp.GradScaler('cuda')
n = 5  # epochs

for epoch in range(n):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch} {'[MNR]' if USE_MNR else '[Triplet]'}")
    
    for enc_q, enc_pos, enc_neg in pbar:
        enc_q = {k: v.to(device) for k, v in enc_q.items()}
        enc_pos = {k: v.to(device) for k, v in enc_pos.items()}
        enc_neg = {k: v.to(device) for k, v in enc_neg.items()}
        
        batch_size = enc_q['input_ids'].shape[0]
        
        with torch.amp.autocast('cuda'):
            emb_q = model(enc_q)       # [B, D]
            emb_pos = model(enc_pos)   # [B, D]
            emb_neg = model(enc_neg)   # [B, D]
            
            if USE_MNR:
                # MNR Loss: use in-batch negatives
                temperature = 0.05
                
                # In-batch negatives: all positives as negatives for other queries
                scores = torch.matmul(emb_q, emb_pos.T) / temperature  # [B, B]
                
                # Labels: diagonal elements are the positives
                labels = torch.arange(batch_size).to(device)
                
                # Cross-entropy: model should rank correct positive highest
                loss = F.cross_entropy(scores, labels)
                
                # Add explicit hard negatives
                neg_scores = F.cosine_similarity(emb_q, emb_neg) / temperature
                pos_scores = scores.diagonal()
                
                # Margin loss between positive and explicit negative
                margin_loss = F.relu(0.2 - (pos_scores - neg_scores)).mean()
                loss = loss + 0.5 * margin_loss
            else:
                # Original Triplet Loss
                pos_sim = F.cosine_similarity(emb_q, emb_pos)
                neg_sim = F.cosine_similarity(emb_q, emb_neg)
                loss = torch.clamp(0.5 - (pos_sim - neg_sim), min=0).mean()
        
        opt.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
    
    # Comprehensive validation
    model.eval()
    metrics = {'correct': 0, 'total': 0, 'mrr': [], 'recall@5': []}
    
    with torch.no_grad():
        for enc_q, enc_pos, enc_neg in tqdm(val_loader, desc="Validation"):
            enc_q = {k: v.to(device) for k, v in enc_q.items()}
            enc_pos = {k: v.to(device) for k, v in enc_pos.items()}
            enc_neg = {k: v.to(device) for k, v in enc_neg.items()}
            
            with torch.amp.autocast('cuda'):
                emb_q = model(enc_q)
                emb_pos = model(enc_pos)
                emb_neg = model(enc_neg)
            
            batch_size = emb_q.shape[0]
            
            for i in range(batch_size):
                q_emb = emb_q[i:i+1]
                candidates = torch.cat([emb_pos[i:i+1], emb_neg[i:i+1]], dim=0)
                sims = F.cosine_similarity(q_emb, candidates, dim=1)
                
                # Rank: 1 if positive > negative, 2 otherwise
                rank = 2 - (sims[0] > sims[1]).long().item()
                
                if rank == 1:
                    metrics['correct'] += 1
                
                metrics['mrr'].append(1.0 / rank)
                metrics['recall@5'].append(1.0 if rank <= 5 else 0.0)
                metrics['total'] += 1
    
    # Print comprehensive metrics
    accuracy = metrics['correct'] / metrics['total'] if metrics['total'] > 0 else 0
    mrr = sum(metrics['mrr']) / len(metrics['mrr']) if metrics['mrr'] else 0
    recall5 = sum(metrics['recall@5']) / len(metrics['recall@5']) if metrics['recall@5'] else 0
    
    print(f"Epoch {epoch}: accuracy={accuracy:.3f}, MRR={mrr:.3f}, recall@5={recall5:.3f}")

Epoch 0 [MNR]:   0%|          | 44/9049 [00:10<35:27,  4.23it/s, loss=0.2191] 


KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), "EBM_scorer.ckpt")

In [40]:
# Ranking inference: compute energy scores
def rank_passages_with_energy(model, tokenizer, query, passages, device="cuda", max_len=128, batch_size=32):
    """Rank passages and compute energy scores.
    
    Energy = 1 - similarity (lower energy = better match, closer to 0)
    Higher similarity = passage more relevant to query.
    """
    model.eval()
    similarity_scores = []
    
    autocast_device = "cuda" if str(device).startswith("cuda") and torch.cuda.is_available() else "cpu"
    with torch.no_grad(), torch.amp.autocast(autocast_device):
        # Encode query separately
        enc_query = tokenizer(
            query,
            padding=True,
            truncation=True,
            max_length=max_len,
            return_tensors="pt"
        )
        enc_query = {k: v.to(device) for k, v in enc_query.items()}
        query_embedding = model(enc_query)  # [1, embedding_dim]
        
        # Score each passage
        for i in range(0, len(passages), batch_size):
            batch = passages[i : i + batch_size]
            enc = tokenizer(
                batch,  # Encode passages only
                padding=True,
                truncation=True,
                max_length=max_len,
                return_tensors="pt"
            )
            enc = {k: v.to(device) for k, v in enc.items()}
            passage_embeddings = model(enc)  # [batch_size, embedding_dim]
            
            # Cosine similarity between query and each passage
            sims = F.cosine_similarity(query_embedding, passage_embeddings)  # [batch_size]
            similarity_scores.extend(sims.detach().cpu().tolist())
    
    # Convert similarity to energy: energy = 1 - similarity
    # Lower energy (closer to 0) = better match
    energy_scores = [1 - sim for sim in similarity_scores]
    
    # Sort by energy (ascending = best first, lowest energy = highest confidence)
    ranked = sorted(zip(passages, similarity_scores, energy_scores), key=lambda x: x[2])
    return ranked


# Test on a random positive example
source_pos = positives_clean if 'positives_clean' in globals() else [(q, p) for (q, p) in positives if q and p]
q, gold = random.choice(source_pos)

# Find easy/hard negatives for the same question
easy_p = None
hard_p = None

if 'easy_negatives' in globals() and easy_negatives:
    same_q_easy = [p for (qq, p) in easy_negatives if qq == q and p]
    if same_q_easy:
        easy_p = random.choice(same_q_easy)
    else:
        pool = [p for (_, p) in easy_negatives if p and p != gold]
        easy_p = random.choice(pool) if pool else None

if 'hard_negatives' in globals() and hard_negatives:
    same_q_hard = [p for (qq, p) in hard_negatives if qq == q and p]
    if same_q_hard:
        hard_p = random.choice(same_q_hard)
    else:
        pool = [p for (_, p) in hard_negatives if p and p != gold]
        hard_p = random.choice(pool) if pool else None

# Build passages to rank
test_passages = [gold]
if easy_p: test_passages.append(easy_p)
if hard_p: test_passages.append(hard_p)

# Rank them with energy scores
ranked = rank_passages_with_energy(model, tokenizer, q, test_passages, device=device, max_len=128, batch_size=32)

print("="*70)
print("Query:")
print(q[:300])
print("\n" + "="*70)
print("Ranked passages (ascending energy = best match):")
print("="*70)

for i, (passage, similarity, energy) in enumerate(ranked, 1):
    snippet = (passage or "").replace("\n", " ")[:220]
    is_gold = " ✓ GOLD" if passage == gold else ""
    
    # Color-code energy: green if low (good), red if high (bad)
    confidence = "✓ HIGH" if energy < 0.3 else "⚠ MEDIUM" if energy < 0.5 else "✗ LOW"
    
    print(f"\n{i}. {confidence}{is_gold}")
    print(f"   Similarity: {similarity:.4f} | Energy: {energy:.4f}")
    print(f"   {snippet}{'...' if passage and len(passage) > 220 else ''}")

print("\n" + "="*70)
print("Energy Score Guide:")
print("  Energy ≈ 0.0-0.3  → ✓ HIGH confidence (good match)")
print("  Energy ≈ 0.3-0.5  → ⚠ MEDIUM confidence (acceptable)")
print("  Energy ≈ 0.5+     → ✗ LOW confidence (poor match)")
print("="*70)


Query:
Which of the following is not a usual feature of right middle cerebral aery territory infarct?

Ranked passages (ascending energy = best match):

1. ✓ HIGH ✓ GOLD
   Similarity: 0.7577 | Energy: 0.2423
   The coical branches of the MCA supply the lateral surface of the hemisphere except for (1) the frontal pole and a strip along the superomedial border of the frontal and parietal lobes supplied by the ACA, and (2) the low...

2. ✓ HIGH
   Similarity: 0.7155 | Energy: 0.2845
   Basilar aery (Ref: B.D. Chaurasia, 3rd Ed, Vol lll/Pg 300) Basilar aery & its branches: The basilar aery is formed by the union of the right & left veebral aeries, at the lower border of pons. It ascends in the midline, ...

3. ⚠ MEDIUM
   Similarity: 0.5047 | Energy: 0.4953
   The symptoms in this question are suggestive of nihilistic delusions (intestine are rotting away) and pathological guilt (belief that patient is responsible for death). Both these symptoms are usually seen in patients wi...

Energy 