In [55]:
import json
from pathlib import Path
from tqdm.auto import tqdm

import nltk
from nltk.tokenize import sent_tokenize
import re

from datasets import Dataset
from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sklearn.metrics.pairwise import cosine_similarity
from accelerate import Accelerator
from rank_bm25 import BM25Okapi
import numpy as np


from data_manipulation import DataManipulator
from evaluation_metrices import Evaluator

In [None]:
BASE_DIR = Path('.')
DATA_ROOT = BASE_DIR / 'final_correct_datasets'
TRAIN_DIR = DATA_ROOT / 'training'
TEST_DIR = DATA_ROOT / 'test'
CANDIDATES_J = BASE_DIR / 'final_correct_datasets' / 'all_retrieved_articles_70k.json'

MODEL_NAME = 'all-MiniLM-L6-v2'
BATCH_SIZE = 16
EPOCHS = 1
WARMUP_STEPS = 100
TOP_K_ARTICLES = 10
TOP_K_SNIPPETS = 10

OUTPUT_COMBINED = BASE_DIR / 'predicted_phaseA_combined.json'

dm = DataManipulator()
all_docs = dm.get_all_articles(str(CANDIDATES_J))  
doc_lookup = {d['pid']: d for d in all_docs}

In [None]:
def tokenize(text: str):
    text = text.lower()
    text = re.sub(r'[^a-z0-9\s]', '', text)
    return text.split()

pid_list = []
doc_corpus = []
for art in all_docs:
    pid_list.append(art['pid'])
    tokens = tokenize(art['title'] + " " + art['abstract'])
    doc_corpus.append(tokens)


bm25 = BM25Okapi(doc_corpus)

In [63]:
train_questions = []
for fp in TRAIN_DIR.glob('*.json'):
    train_questions += json.loads(fp.read_text())['data']

test_batch_files = sorted(TEST_DIR.glob('*_test_batch_*.json'))

gt_train = {'data': dm.get_ground_truth_from_all_files(str(TRAIN_DIR))}
gt_test  = {'data': dm.get_ground_truth_from_all_files(str(TEST_DIR))} 

print(f"Train Qs: {len(train_questions)}, Test batches: {len(test_batch_files)}")

Train Qs: 5390, Test batches: 4


In [None]:
model = SentenceTransformer(MODEL_NAME)
doc_emb_cache = {}

In [None]:

train_pairs = []
for q in train_questions:
    q_text = q['question']
    for pid in q['ground_truth_documents_pid']:
        art = doc_lookup.get(pid)
        if art:
            train_pairs.append({
                'query': q_text,
                'doc': f"{art['title']}. {art['abstract']}"
            })


hf_train = Dataset.from_list(train_pairs)
hf_train = hf_train.map(lambda x: {'text_0': x['query'], 'text_1': x['doc']}, remove_columns=['query','doc'])


Map:   0%|          | 0/46339 [00:00<?, ? examples/s]

In [31]:
DEBUG = False

train_args = SentenceTransformerTrainingArguments(
    output_dir='sbert_finetuned',
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    warmup_steps=WARMUP_STEPS,
    learning_rate=2e-5,
    eval_strategy='no'
)

if DEBUG:
    N = min(100, len(hf_train))
    print(f"Debug mode: sampling first {N} training pairs")
    train_dataset_small = hf_train.select(range(N))
    
    debug_args = SentenceTransformerTrainingArguments(
        output_dir='sbert_debug',
        num_train_epochs=1,
        per_device_train_batch_size=min(8, BATCH_SIZE),
        warmup_steps=10,
        learning_rate=2e-5,
        eval_strategy='no'
    )
    train_dataset = train_dataset_small
else:
    debug_args    = train_args
    train_dataset = hf_train  


train_loss = losses.MultipleNegativesRankingLoss(model)

trainer = SentenceTransformerTrainer(
    model=model,
    args=train_args,
    train_dataset=train_dataset,
    loss=train_loss
)

trainer.train()

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.1596
1000,0.1318
1500,0.1224
2000,0.1053
2500,0.1119


TrainOutput(global_step=2897, training_loss=0.1222051871329042, metrics={'train_runtime': 6173.4553, 'train_samples_per_second': 7.506, 'train_steps_per_second': 0.469, 'total_flos': 0.0, 'train_loss': 0.1222051871329042, 'epoch': 1.0})

In [None]:
import re

def encode_query(text: str):
    return model.encode(text, convert_to_numpy=True)

def encode_document(pid: str):
    if pid in doc_emb_cache:
        return doc_emb_cache[pid]
    art = doc_lookup.get(pid, {'title':'','abstract':''})
    emb = model.encode(f"{art['title']}. {art['abstract']}", convert_to_numpy=True)
    doc_emb_cache[pid] = emb
    return emb


def extract_snippets(q_emb, abstract: str, pid: str):
    """
    Split the abstract into sentences, embed them in one batch, 
    rank by cosine similarity to q_emb, and return the top-K snippets.
    Falls back to a regex split if NLTK’s punkt isn’t available.
    """
    try:
        sentences = sent_tokenize(abstract)
    except LookupError:
        sentences = re.split(r'(?<=[\.!\?])\s+', abstract)

    if not sentences:
        return []

    sent_embs = model.encode(sentences, convert_to_numpy=True)
    sims = cosine_similarity([q_emb], sent_embs)[0]
    top_idxs = sims.argsort()[::-1][:TOP_K_SNIPPETS]

    snippets = []
    for i in top_idxs:
        txt = sentences[i]
        off = abstract.find(txt)
        snippets.append({
            'beginSection': 'abstract',
            'endSection': 'abstract',
            'text': txt,
            'document': pid,
            'offsetInBeginSection': off,
            'offsetInEndSection': off + len(txt)
        })
    return snippets


In [39]:
print(
    "Total preds:", len(train_preds),
    "| Unique qids:", len({p['id'] for p in train_preds}),
    "| Ground truth qids:", len(gt_train['data'])
)


print(train_preds)


Total preds: 15 | Unique qids: 15 | Ground truth qids: 5390
[{'id': '55031181e9bde69634000014', 'question': 'Is Hirschsprung disease a mendelian or a multifactorial disorder?', 'top_10_articles': [{'pid': 'http://www.ncbi.nlm.nih.gov/pubmed/13417627', 'title': "HIRSCHSPRUNG'S disease.", 'abstract': '', 'score': 0.7891457676887512}, {'pid': 'http://www.ncbi.nlm.nih.gov/pubmed/5934032', 'title': "The cone-shaped segment in the diagnosis of Hirschsprung's disease.", 'abstract': '', 'score': 0.7072323560714722}, {'pid': 'http://www.ncbi.nlm.nih.gov/pubmed/13002467', 'title': "The diagnosis and surgical treatment of Hirschsprung's disease.", 'abstract': '', 'score': 0.7060225009918213}, {'pid': 'http://www.ncbi.nlm.nih.gov/pubmed/23001136', 'title': "Chromosomal and related Mendelian syndromes associated with Hirschsprung's disease.", 'abstract': "Hirschsprung's disease (HSCR) is a fairly frequent cause of intestinal obstruction in children. It is characterized as a sex-linked heterogonous 

In [None]:
def rerank(q):
    q_emb = encode_query(q['question'])
    scored = []
    for art in q.get('all_retreived_articles', []):
        pid   = art['pid']
        d_emb = encode_document(pid)
        score = float(cosine_similarity([q_emb], [d_emb])[0][0])
        scored.append((score, art))
    scored.sort(key=lambda x: x[0], reverse=True)
    top = scored[:TOP_K_ARTICLES]

    top_articles = [{
        'pid': art['pid'],
        'title': art.get('title',''),
        'abstract': art.get('abstract',''),
        'score': score
    } for score, art in top]

    snippets = []
    for _, art in top:
        snippets += extract_snippets(q_emb, art.get('abstract',''), art['pid'])

    return {
        'id': q['qid'],
        'question': q['question'],
        'top_10_articles': top_articles,
        'snippets': snippets
    }

In [66]:
def bm25_rank(query: str, top_n: int = 100):
   
    q_tokens = query.lower().split()
    scores  = bm25.get_scores(q_tokens)
    top_idxs = np.argsort(scores)[::-1][:top_n]
    return [pid_list[i] for i in top_idxs if scores[i] > 0]

In [None]:


gt_test_all = dm.get_ground_truth_from_all_files(str(TEST_DIR))

for batch_fp in test_batch_files:
    raw   = json.loads(batch_fp.read_text())
    batch = raw.get('data', raw)
    print(f"\n=== Batch {batch_fp.name}: {len(batch)} questions ===")

    preds = []
    for q in tqdm(batch, desc="  reranking"):
        q_tokens = tokenize(q['question'])
        scores  = bm25.get_scores(q_tokens)
        top_idxs = np.argsort(scores)[::-1][:100]
        candidate_pids = [pid_list[i] for i in top_idxs if scores[i] > 0]

        q_rec = {
            'qid': q['qid'],
            'question': q['question'],
            'all_retreived_articles': [
                doc_lookup[pid] for pid in candidate_pids if pid in doc_lookup
            ]
        }
        preds.append(rerank(q_rec))

    batch_qids = {q['qid'] for q in batch}
    gt_entries = [e for e in gt_test_all if e['qid'] in batch_qids]
    gt_batch   = {'data': gt_entries}


    evaluator = Evaluator(ground_truth_data=gt_batch, predicted_data={'data': preds})
    art_m     = evaluator.evaluate_metrics_for_articles()
    snip_m    = evaluator.evaluate_metrics_for_snippets()

    print("  • Article‐level metrics:")
    evaluator.print_results(art_m)
    print("  • Snippet‐level metrics:")
    evaluator.print_results(snip_m)

# with open(OUTPUT_COMBINED, 'w') as f:
#     json.dump(combined, f, indent=2)
# print(f"Combined output saved to {OUTPUT_COMBINED}, total entries: {len(combined['data'])}")


=== Batch parsed_data_final_test_batch_1.json: 85 questions ===


  reranking:   0%|          | 0/85 [00:00<?, ?it/s]

Processing questions...: 85it [00:00, ?it/s]
Evaluating snippets...: 85it [00:00, 542.67it/s]


  • Article‐level metrics:
MRR: 70.25
MAP: 59.31
nDCG@10: 66.21
P_article: 16.73
R_article: 73.33
F1_article: 25.75
GMAP: 10.67
  • Snippet‐level metrics:
P_snip: 0.63
R_snip: 14.9
F1_snip: 1.18
MAP_snip: 3.29
GMAP_snip: 0.0

=== Batch parsed_data_final_test_batch_2.json: 85 questions ===


  reranking:   0%|          | 0/85 [00:00<?, ?it/s]

Processing questions...: 85it [00:00, 33127.28it/s]
Evaluating snippets...: 85it [00:00, 478.29it/s]


  • Article‐level metrics:
MRR: 64.59
MAP: 52.3
nDCG@10: 59.56
P_article: 16.03
R_article: 66.14
F1_article: 24.22
GMAP: 5.56
  • Snippet‐level metrics:
P_snip: 0.47
R_snip: 10.45
F1_snip: 0.9
MAP_snip: 2.56
GMAP_snip: 0.0

=== Batch parsed_data_final_test_batch_3.json: 85 questions ===


  reranking:   0%|          | 0/85 [00:00<?, ?it/s]

Processing questions...: 85it [00:00, 40139.14it/s]
Evaluating snippets...: 85it [00:00, 498.69it/s]


  • Article‐level metrics:
MRR: 68.32
MAP: 54.3
nDCG@10: 62.22
P_article: 18.71
R_article: 68.85
F1_article: 27.61
GMAP: 10.09
  • Snippet‐level metrics:
P_snip: 0.78
R_snip: 16.29
F1_snip: 1.47
MAP_snip: 4.8
GMAP_snip: 0.01

=== Batch parsed_data_final_test_batch_4.json: 85 questions ===


  reranking:   0%|          | 0/85 [00:00<?, ?it/s]

Processing questions...: 85it [00:00, 43998.01it/s]
Evaluating snippets...: 85it [00:00, 492.15it/s]

  • Article‐level metrics:
MRR: 63.62
MAP: 46.6
nDCG@10: 55.86
P_article: 19.45
R_article: 63.58
F1_article: 28.32
GMAP: 6.36
  • Snippet‐level metrics:
P_snip: 0.6
R_snip: 13.54
F1_snip: 1.14
MAP_snip: 3.43
GMAP_snip: 0.0





In [69]:
train_preds = []
for q in tqdm(train_questions, desc="Train Inference"):
    train_preds.append(rerank(q))


evaluator_train = Evaluator(
    ground_truth_data=gt_train,
    predicted_data   ={'data': train_preds}
)

art_m_train  = evaluator_train.evaluate_metrics_for_articles()
snip_m_train = evaluator_train.evaluate_metrics_for_snippets()

print("\n=== Training Article Retrieval Metrics ===")
evaluator_train.print_results(art_m_train)

print("\n=== Training Snippet Extraction Metrics ===")
evaluator_train.print_results(snip_m_train)

Train Inference:   0%|          | 0/5390 [00:00<?, ?it/s]

Processing questions...: 5390it [00:00, 22941.21it/s]
Evaluating snippets...: 5390it [00:06, 802.12it/s]


=== Training Article Retrieval Metrics ===
MRR: 94.38
MAP: 74.9
nDCG@10: 90.77
P_article: 54.61
R_article: 77.5
F1_article: 52.09
GMAP: 44.61

=== Training Snippet Extraction Metrics ===
P_snip: 4.32
R_snip: 31.7
F1_snip: 6.92
MAP_snip: 13.37
GMAP_snip: 0.81





dict_keys(['pid', 'title', 'abstract'])
dict_keys(['pid', 'title', 'abstract'])
dict_keys(['pid', 'title', 'abstract'])
dict_keys(['pid', 'title', 'abstract'])
dict_keys(['pid', 'title', 'abstract'])
dict_keys(['pid', 'title', 'abstract'])
dict_keys(['pid', 'title', 'abstract'])
dict_keys(['pid', 'title', 'abstract'])
dict_keys(['pid', 'title', 'abstract'])
dict_keys(['pid', 'title', 'abstract'])
5390
