# Initial API retrieval.
### Note: THIS CAN ONLY BE USED FOR THE TRAINING SET AND THE BATCH_1 TEST SET, BECAUSE THEY CONTAIN GOLDEN ANSWERS AND THEREFORE HAVE A DIFFERENT STRUCTURE.
### FOR BATCH_2, BATCH_3 AND BATCH_4 USE TEST_QUESTION_API_RETREVAL.IPYNB

In [None]:
from Bio import Entrez
from keybert import KeyBERT
import json
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
import xml.etree.ElementTree as ET
import re
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
import random
import time
from datetime import datetime

# Disabled: the spaCy setup below is kept only as an inert string literal
# (spaCy is not imported in this notebook). It defined the `nlp` model that the
# disabled string-literal variant of generate_queries_BERT refers to.
'''nlp = spacy.load("en_core_web_sm")

INTENT_TERMS = {
    "synonym", "definition", "define", "meaning",
    "cause", "treatment", "therapy",
    "effect", "role", "impact", "mechanism",
    "list"
}

GENERIC_MEDICAL_TERMS = {
    "disease", "disorder", "condition", "syndrome", "illness", "diagnosis"
}'''



# BioASQ JSON file to process (training / batch_1 files only, see the header note).
data_file_path = "../datasets/test/batch_1/BioASQ-task13bPhaseA-testset1"
# URL prefix removed from each document entry in parse_question to obtain bare PMIDs.
base_pid_url = 'http://www.ncbi.nlm.nih.gov/pubmed/' 
Entrez.email = "kasapovic.m@hotmail.com"  # must be a valid email address

In [None]:
def load_question_from_file(file_path):
    """
    Read a BioASQ-format JSON file and return the entries under "questions".

    Parameters:
        file_path (str): Location of the JSON file.

    Returns:
        List[dict]: The question dictionaries. An empty list is returned when
        the top-level object has no "questions" key, and also when anything
        goes wrong while reading or parsing (the error is printed, not raised).
    """
    try:
        with open(file_path, encoding='utf-8') as handle:
            payload = json.loads(handle.read())
        loaded = payload.get("questions", [])
        print(f"Loaded {len(loaded)} questions from {file_path}")
        return loaded
    except Exception as err:
        # Best-effort loader: report the problem and fall back to no questions.
        print(f"Failed to load questions: {err}")
        return []

In [19]:
# Load the questions and display the first one.
questions = load_question_from_file(data_file_path)
questions[0]

✅ Loaded 85 questions from ../datasets/test/batch_2/BioASQ-task13bPhaseA-testset2


{'id': '67dedd6818b1e36f2e000061',
 'type': 'factoid',
 'body': 'Which ensemble machine-learning framework has been developed harnessing UK biobank data?'}

In [20]:
def parse_question(questions):
    """
    Parse BioASQ questions that carry golden answers into a compact form.

    Each document entry is reduced to its PubMed ID by keeping only the last
    path segment, so "http://www.ncbi.nlm.nih.gov/pubmed/<PMID>",
    "https://pubmed.ncbi.nlm.nih.gov/<PMID>/" and a bare "<PMID>" all map
    to "<PMID>".

    Parameters:
        questions (List[dict]): Question dicts with 'id', 'body', 'documents'
            and 'snippets' keys.

    Returns:
        List[dict]: One dict per question with keys 'id', 'body',
        'documents' (list of PMID strings) and 'snippets' (passed through).

    Raises:
        KeyError: If a question has no 'documents' or 'snippets' field, which
            happens for test batches without golden answers.
    """
    def _pmid_from_url(url):
        # Last path segment of the URL (trailing slash tolerated).
        return url.rstrip('/').rsplit('/', 1)[-1]

    parsed_questions = []
    for question in questions:
        missing = [key for key in ("documents", "snippets") if key not in question]
        if missing:
            raise KeyError(
                f"Question {question.get('id')!r} has no {', '.join(missing)} field(s); "
                "this parser only supports files with golden answers "
                "(training / batch_1). Use the test-question retrieval notebook "
                "for the other batches."
            )
        parsed_questions.append({
            "id": question["id"],
            "body": question["body"],
            "documents": [_pmid_from_url(doc) for doc in question["documents"]],
            "snippets": question["snippets"],
        })
    return parsed_questions




In [21]:
def get_questions_text(questions):
    """
    Reduce each question to just its id and body text.

    Parameters:
        questions (List[dict]): Question dicts containing at least 'id' and 'body'.

    Returns:
        List[dict]: One {'id': ..., 'body': ...} dict per input question,
        in the original order.
    """
    texts = []
    for entry in questions:
        texts.append({'id': entry['id'], 'body': entry["body"]})
    return texts

def get_ground_truth_documents(questions):
    """
    Build a lookup from question id to its ground-truth document ids.

    Parameters:
        questions (List[dict]): Question dicts with 'id' and 'documents' keys.

    Returns:
        Dict[str, Set[str]]: question id -> set of its document ids. When the
        same id occurs more than once, the last occurrence wins.
    """
    mapping = {}
    for entry in questions:
        pmids = set(entry["documents"])
        mapping[entry['id']] = pmids
    return mapping

In [22]:
# Extract id, body, document ids and snippets from every question.
parsed_questions = parse_question(questions)
# parsed_questions[:3]

KeyError: 'documents'

In [None]:
# Map each question id to the set of its ground-truth document ids.
ground_truth_documents = get_ground_truth_documents(parsed_questions)
# ground_truth_documents

In [None]:
questions_texts = get_questions_text(parsed_questions)  # id + body only

In [None]:
# Compare the number of questions with the number of ground-truth entries.
for label, collection in (
    ("questions", questions_texts),
    ("ground truth documents", ground_truth_documents),
):
    print(f"Number of {label}: {len(collection)}")

Number of questions: 85
Number of ground truth documents: 85


In [None]:
def save_to_json(data, file_path):
    """
    Save data as pretty-printed UTF-8 JSON (indent=4, non-ASCII kept as-is).

    The data is serialized before the target file is opened. A serialization
    error (e.g. a ``set`` inside ``data``) therefore leaves any existing file
    untouched instead of truncating it to a half-written document.

    Parameters:
        data (any): JSON-serializable data to save.
        file_path (str): Path to the JSON file.

    Errors are printed rather than raised (best-effort helper).
    """
    try:
        payload = json.dumps(data, ensure_ascii=False, indent=4)
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(payload)
        print(f"Saved data to {file_path}")
    except Exception as e:
        # Deliberately best-effort: report and continue the notebook run.
        print(f"Failed to save data: {e}")

In [None]:
save_to_json(parsed_questions, "parsed_questions.json")  # write parsed questions to disk

✅ Saved data to parsed_questions.json


In [None]:
def parse_pubmed_abstracts_from_xml(xml_string):
    """
    Parse PubMed XML and extract PMID, title, and abstract.

    Inline markup inside titles and abstracts (e.g. <i>, <sup>, <sub>) is
    flattened to plain text. Without this, the text would be cut off at the
    first child element.

    Args:
        xml_string (str | bytes): XML data from Entrez.efetch(..., retmode="xml")

    Returns:
        List[Dict]: List of articles with keys 'pid', 'title', 'abstract'.
            'pid' and 'title' are None when the element is absent; 'abstract'
            is "" when the article has no abstract text.
    """
    def _full_text(element):
        # Element.text only holds the text before the first child tag;
        # itertext() also yields text inside and after nested tags.
        return "".join(element.itertext())

    root = ET.fromstring(xml_string)
    articles = []

    for article in root.findall(".//PubmedArticle"):
        # Prefer the citation's own PMID; fall back to the first PMID anywhere.
        pid = article.findtext("MedlineCitation/PMID")
        if pid is None:
            pid = article.findtext(".//PMID")

        title_el = article.find(".//ArticleTitle")
        title = _full_text(title_el) if title_el is not None else None

        # Structured abstracts have several AbstractText parts; their labels
        # (e.g. "BACKGROUND", "METHODS") are not included in the output.
        part_texts = (_full_text(part) for part in article.findall(".//Abstract/AbstractText"))
        abstract = " ".join(text for text in part_texts if text)

        articles.append({
            "pid": pid,
            "title": title,
            "abstract": abstract
        })

    return articles

In [None]:
# Generate semantic queries
def generate_queries_BERT(question, model, keyphrase_ngram_range=(1, 3), top_n=7):
    """Turn a question into keyphrases to be used as PubMed search queries.

    Args:
        question (str): Question text.
        model: A KeyBERT-style object exposing ``extract_keywords``.
        keyphrase_ngram_range (tuple): (min, max) n-gram length of the phrases.
        top_n (int): Number of phrases to request from the model.

    Returns:
        list: The phrases in the order the model returned them, with their
        scores dropped. English stop words are passed to the model for removal.
    """
    scored_phrases = model.extract_keywords(
        question,
        keyphrase_ngram_range=keyphrase_ngram_range,
        stop_words='english',
        top_n=top_n,
    )
    return [phrase for phrase, _score in scored_phrases]

# NOTE: Disabled alternative of generate_queries_BERT, kept as an inert string
# literal so it never runs. It post-filters KeyBERT phrases with spaCy (drops
# verbs/auxiliaries, stop words and tokens of <= 2 chars, then dedups + sorts).
# Reviving it would require `import spacy` plus the `nlp` model from the
# disabled block near the top of the notebook. Note that its default argument
# `model=KeyBERT(...)` would load the model once, at definition time.
'''
def generate_queries_BERT(question, model=KeyBERT('all-MiniLM-L6-v2'), keyphrase_ngram_range=(1, 3), top_n=7):
    keywords = model.extract_keywords(question, keyphrase_ngram_range=keyphrase_ngram_range, stop_words='english', top_n=top_n)

    cleaned_keywords = []

    for phrase, _ in keywords:
        doc = nlp(phrase)

        # Remove verbs, auxiliaries, stopwords, and short words
        filtered_tokens = [
            token.lemma_.lower()
            for token in doc
            if token.pos_ not in {"VERB", "AUX"}
            and token.lemma_.lower() not in ENGLISH_STOP_WORDS
            and len(token.lemma_) > 2
        ]

        # Reconstruct cleaned phrase
        cleaned_phrase = " ".join(filtered_tokens).strip()

        if cleaned_phrase:  # Only keep non-empty results
            cleaned_keywords.append(cleaned_phrase)

    return list(sorted(set(cleaned_keywords)))  # deduplicate + sort
'''

# Search PubMed for a single query
def search_pubmed(query, retmax=500):
    """Return the PubMed IDs matching ``query`` using Entrez ESearch.

    Args:
        query (str): PubMed search term.
        retmax (int): Maximum number of IDs to request.
            NOTE(review): NCBI documents that PubMed ESearch only returns the
            first 10,000 records; larger values are presumably capped — confirm.

    Returns:
        The ``IdList`` of the ESearch response (presumably PMID strings).
    """
    handle = Entrez.esearch(db="pubmed", term=query, retmax=retmax)
    try:
        record = Entrez.read(handle)
    finally:
        # Close the HTTP handle even if reading/parsing the response fails.
        handle.close()
    return record["IdList"]

# Fetch PubMed records as XML and parse them
def fetch_details(id_list):
    """Fetch PubMed records for the given IDs and parse them.

    Args:
        id_list (list[str]): PubMed IDs to fetch.

    Returns:
        list[dict]: Articles with keys 'pid', 'title', 'abstract', as produced
        by ``parse_pubmed_abstracts_from_xml``. Empty for an empty ``id_list``.
    """
    if not id_list:
        # Nothing to fetch; avoid sending an EFetch request with no ids.
        return []
    ids = ",".join(id_list)
    handle = Entrez.efetch(db="pubmed", id=ids, rettype="medline", retmode="xml")
    try:
        xml_data = handle.read()
    finally:
        # Always release the HTTP handle, even if the read fails.
        handle.close()
    return parse_pubmed_abstracts_from_xml(xml_data)

In [None]:
# Debug helper: print the position(s) of a specific question id in the list.
TARGET_QID = "516c0ebc298dcd4e5100006e"
questions = get_questions_text(questions=parsed_questions)
for position, entry in enumerate(questions):
    if entry['id'] == TARGET_QID:
        print(position)

In [None]:
import random
import time
import json
import os
import shutil
from datetime import datetime

# 🔧 Parameters
BATCH_SIZE = 5
RESULTS_FILE = "retrieved_articles_sampled.json"
CHECKPOINT_FILE = "checkpoint.json"
NEGATIVE_SAMPLE_SIZE = 150  # max. number of non-ground-truth PMIDs sampled per question
# NOTE(review): PubMed ESearch is documented to return at most 10,000 IDs — confirm 20000 has any effect.
SEARCH_RETMAX = 20000
REQUEST_DELAY_S = 0.7  # pause after every ESearch call (presumably to respect NCBI rate limits)
RANDOM_SEED = None  # set to an int for reproducible negative sampling; None = fresh entropy each run


def _save_progress(results, last_index):
    """Write all results atomically, then record `last_index` in the checkpoint file."""
    temp_file = RESULTS_FILE + ".tmp"
    with open(temp_file, "w", encoding="utf-8") as f:
        json.dump({"results": results}, f, indent=2)
    # os.replace is an atomic rename on POSIX and Windows; shutil.move may fall
    # back to copy + delete when the destination already exists.
    os.replace(temp_file, RESULTS_FILE)
    with open(CHECKPOINT_FILE, "w", encoding="utf-8") as f:
        json.dump({"last_index": last_index}, f)


def _search_all(queries):
    """Return the union of PMIDs found by all queries; failed queries are logged and skipped."""
    all_ids = set()
    for q in queries:
        try:
            all_ids.update(search_pubmed(q, retmax=SEARCH_RETMAX))
        except Exception as e:
            print(f"⚠️ Failed to search PubMed with query '{q}': {e}")
        finally:
            # Throttle after failed requests too, not only after successful ones.
            time.sleep(REQUEST_DELAY_S)
    return all_ids


def _build_result(question, model, rng):
    """Generate queries, search PubMed, sample negatives and fetch details for one question."""
    qid = question['id']
    queries = generate_queries_BERT(question['body'], keyphrase_ngram_range=(1, 10), top_n=20, model=model)
    print(f"🧠 Generated Queries: {queries}")

    all_ids = _search_all(queries)
    print(f"✅ Found {len(all_ids)} unique PMIDs for this question.")

    ground_truth = ground_truth_documents[qid]
    found_gt = ground_truth.intersection(all_ids)
    print(f"🎯 Found {len(found_gt)} out of {len(ground_truth)} ground truth PMIDs.")

    # Sort first: set iteration order depends on string hashing, so without it a
    # fixed RANDOM_SEED would still not give repeatable samples.
    non_gt = sorted(all_ids - found_gt)
    sampled_random = rng.sample(non_gt, min(NEGATIVE_SAMPLE_SIZE, len(non_gt)))

    selected_ids = sorted(found_gt) + sampled_random
    print(f"📦 Fetching details for {len(selected_ids)} PMIDs (GT + Random Sample)")

    try:
        details = fetch_details(selected_ids)
    except Exception as e:
        print(f"❌ Error fetching details: {e}")
        details = []

    return {
        "qid": qid,
        "question": question['body'],
        "queries": queries,
        "ground_truth": sorted(ground_truth),
        "ground_truth_total": len(ground_truth),
        "ground_truth_found": sorted(found_gt),
        "ground_truth_found_count": len(found_gt),
        "random_sampled_count": len(sampled_random),
        "random_sampled_pmids": sampled_random,
        "selected_pmids": selected_ids,
        # NOTE: despite its name this value is the recall of ground-truth PMIDs
        # (found / total); the key is left unchanged so result files stay comparable.
        "error_rate": {
            "value": len(found_gt) / len(ground_truth) if len(ground_truth) > 0 else 0,
            "details": f"{len(found_gt)} found out of {len(ground_truth)}"
        },
        "details": details
    }


# 📌 Load checkpoint if it exists
if os.path.exists(CHECKPOINT_FILE):
    with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
        checkpoint = json.load(f)
    # Default -1 so a checkpoint without "last_index" resumes at question 0, not 1.
    start_index = checkpoint.get("last_index", -1) + 1
else:
    start_index = 0

# 🚀 Initialize
questions = get_questions_text(questions=parsed_questions)
kw_model = KeyBERT(model='all-MiniLM-L6-v2')
rng = random.Random(RANDOM_SEED)
results = []

# 📝 Resume from saved results file if available
if os.path.exists(RESULTS_FILE):
    try:
        with open(RESULTS_FILE, "r", encoding="utf-8") as f:
            results = json.load(f)["results"]
    except json.JSONDecodeError:
        print("⚡ Warning: Results file is corrupted. Starting with empty results.")
        results = []
        # The checkpoint refers to results that were just discarded: start over
        # instead of silently skipping every question before it.
        start_index = 0

# Questions already in the results file (e.g. checkpoint missing or stale) are
# skipped so that resuming never produces duplicate entries.
done_qids = {entry.get("qid") for entry in results}
has_unsaved = False

# 🔁 Process from start_index
for i in range(start_index, len(questions)):
    question = questions[i]
    qid = question['id']

    if qid in done_qids:
        print(f"⏭️ [{i+1}/{len(questions)}] Skipping already saved question {qid}")
    else:
        print(f"❓[{i+1}/{len(questions)}] Question: {question['body']}")
        results.append(_build_result(question, kw_model, rng))
        done_qids.add(qid)
        has_unsaved = True
        print(f"✅ Saved result for question \"{question['body']}\"\n")

    # 💾 Save batch and update checkpoint
    if has_unsaved and ((i + 1) % BATCH_SIZE == 0 or i == len(questions) - 1):
        print(f"📁 Saving batch at question {i+1}...")
        _save_progress(results, i)
        has_unsaved = False
        print(f"✅ Batch saved and checkpoint updated at index {i}\n")


❓[1/85] Question: Describe RankMHC
🧠 Generated Queries: ['rankmhc']
✅ Found 1 unique PMIDs for this question.
🎯 Found 1 out of 1 ground truth PMIDs.
📦 Fetching details for 1 PMIDs (GT + Random Sample)
✅ Saved result for question "Describe RankMHC"

❓[2/85] Question: What proportion of colorectal cancer cases are not assignable to any of the consensus molecular subtype (CMS) groups?
🧠 Generated Queries: ['proportion colorectal cancer cases assignable consensus molecular subtype cms groups', 'proportion colorectal cancer cases assignable consensus molecular subtype cms', 'proportion colorectal cancer cases assignable consensus molecular subtype', 'colorectal cancer cases assignable consensus molecular subtype cms groups', 'colorectal cancer cases assignable consensus molecular subtype cms', 'colorectal cancer cases assignable consensus molecular subtype', 'proportion colorectal cancer cases assignable consensus molecular', 'proportion colorectal cancer cases assignable consensus', 'cance