In [1]:
from Bio import Entrez
from keybert import KeyBERT
import time
import json
import os
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
import xml.etree.ElementTree as ET
import spacy
import re
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
import random
import time
from datetime import datetime

# Load spaCy model only once
'''nlp = spacy.load("en_core_web_sm")

INTENT_TERMS = {
    "synonym", "definition", "define", "meaning",
    "cause", "treatment", "therapy",
    "effect", "role", "impact", "mechanism",
    "list"
}

GENERIC_MEDICAL_TERMS = {
    "disease", "disorder", "condition", "syndrome", "illness", "diagnosis"
}'''



# Path to the BioASQ-style training file (JSON with a top-level "questions" list).
training_data_file_path = "training13b.json"
# URL prefix that is stripped from document links to obtain bare PubMed ids.
base_pid_url = 'http://www.ncbi.nlm.nih.gov/pubmed/'
# Contact e-mail attached to Entrez requests. Set the ENTREZ_EMAIL environment
# variable to override it without editing (and re-sharing) the notebook; the
# previous hard-coded address stays as the fallback so existing runs are unchanged.
Entrez.email = os.environ.get("ENTREZ_EMAIL", "kasapovic.m@hotmail.com")  # must be a valid e-mail

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_question_from_file(file_path):
    """
    Read a BioASQ-style JSON file and return its list of questions.

    Parameters:
        file_path (str): Location of the JSON file on disk.

    Returns:
        List[dict]: The value stored under the top-level "questions" key,
        or an empty list when that key is absent or when anything goes
        wrong while reading/parsing (the error is printed, not raised).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)
        loaded = payload.get("questions", [])
        print(f"‚úÖ Loaded {len(loaded)} questions from {file_path}")
        return loaded
    except Exception as err:
        # Best-effort loader: report the problem and fall back to "no questions".
        print(f"‚ùå Failed to load questions: {err}")
        return []

In [3]:
# Load every question from the training file into memory.
questions = load_question_from_file(file_path=training_data_file_path)
# Display the first record to inspect the question schema.
questions[0]

‚úÖ Loaded 5389 questions from training13b.json


{'body': 'Is Hirschsprung disease a mendelian or a multifactorial disorder?',
 'documents': ['http://www.ncbi.nlm.nih.gov/pubmed/15858239',
  'http://www.ncbi.nlm.nih.gov/pubmed/20598273',
  'http://www.ncbi.nlm.nih.gov/pubmed/6650562',
  'http://www.ncbi.nlm.nih.gov/pubmed/12239580',
  'http://www.ncbi.nlm.nih.gov/pubmed/21995290',
  'http://www.ncbi.nlm.nih.gov/pubmed/23001136',
  'http://www.ncbi.nlm.nih.gov/pubmed/15617541',
  'http://www.ncbi.nlm.nih.gov/pubmed/8896569',
  'http://www.ncbi.nlm.nih.gov/pubmed/15829955'],
 'ideal_answer': ["Coding sequence mutations in RET, GDNF, EDNRB, EDN3, and SOX10 are involved in the development of Hirschsprung disease. The majority of these genes was shown to be related to Mendelian syndromic forms of Hirschsprung's disease, whereas the non-Mendelian inheritance of sporadic non-syndromic Hirschsprung disease proved to be complex; involvement of multiple loci was demonstrated in a multiplicative model."],
 'concepts': ['http://www.disease-ontol

In [4]:
def parse_question(questions):
    """
    Reduce raw question records to the fields used by the retrieval pipeline.

    Parameters:
        questions (List[dict]): Raw question dictionaries; each must provide
            the keys "id", "body", "documents" and "snippets".

    Returns:
        List[dict]: One dictionary per input question with the keys "id",
        "body", "documents" (each link with the `base_pid_url` prefix
        removed) and "snippets" (passed through unchanged).
    """
    def _strip_prefix(link):
        # Turn a full PubMed URL into the bare id by dropping the shared prefix.
        return link.replace(base_pid_url, '')

    return [
        {
            "id": item["id"],
            "body": item["body"],
            "documents": [_strip_prefix(link) for link in item["documents"]],
            "snippets": item["snippets"],
        }
        for item in questions
    ]




In [5]:
def get_questions_text(questions):
    """
    Keep only the identifier and the text of each question.

    Parameters:
        questions (List[dict]): Question dictionaries providing at least
            the keys "id" and "body".

    Returns:
        List[dict]: One {'id': ..., 'body': ...} dictionary per question,
        in input order.
    """
    texts = []
    for entry in questions:
        texts.append({'id': entry['id'], 'body': entry["body"]})
    return texts

def get_ground_truth_documents(questions):
    """
    Build a lookup from question id to its set of ground-truth documents.

    Parameters:
        questions (List[dict]): Question dictionaries providing the keys
            "id" and "documents".

    Returns:
        Dict[str, set]: Maps each question id to the set of entries in its
        "documents" list. If an id occurs more than once, the last
        occurrence wins.
    """
    return {entry['id']: set(entry["documents"]) for entry in questions}

In [6]:
# Keep only id, body, prefix-stripped document ids and snippets for each question.
parsed_questions = parse_question(questions=questions)
#parsed_questions[0:3]

In [7]:
# Map each question id to the set of its ground-truth document ids.
ground_truth_documents = get_ground_truth_documents(questions=parsed_questions)
#ground_truth_documents

In [8]:
# Reduce each parsed question to its id and body text.
questions_texts = get_questions_text(questions=parsed_questions)

In [9]:
# Sanity check: ground_truth_documents is keyed by question id, so its length
# is the number of distinct question ids (not the number of documents).
# A smaller second count would mean duplicate ids were collapsed.
print(f"Number of questions: {len(questions_texts)}")
print(f"Number of ground truth documents: {len(ground_truth_documents)}")

Number of questions: 5389
Number of ground truth documents: 5389


In [10]:
def save_to_json(data, file_path):
    """
    Write `data` to `file_path` as pretty-printed UTF-8 JSON.

    Non-ASCII characters are written as-is (not escaped) and the output is
    indented by four spaces. Any failure is printed instead of raised.

    Parameters:
        data (any): JSON-serializable object to store.
        file_path (str): Destination path of the JSON file.
    """
    try:
        with open(file_path, 'w', encoding='utf-8') as out:
            json.dump(data, out, ensure_ascii=False, indent=4)
        print(f"‚úÖ Saved data to {file_path}")
    except Exception as err:
        # Best-effort writer: report the problem and carry on.
        print(f"‚ùå Failed to save data: {err}")

In [11]:
# Persist the parsed questions to disk as JSON.
save_to_json(data=parsed_questions, file_path="parsed_questions.json")

‚úÖ Saved data to parsed_questions.json


In [12]:
def parse_pubmed_abstracts_from_xml(xml_string):
    """
    Parse PubMed XML and extract PMID, title, and abstract of every article.

    Titles and abstracts may contain inline markup (e.g. <i>, <sub>, <sup>).
    The text of such nested elements is kept, so a title like
    "Role of <i>RET</i> in disease" becomes "Role of RET in disease"
    instead of being cut off at the first tag.

    Args:
        xml_string (str | bytes): XML data from Entrez.efetch(..., retmode="xml")

    Returns:
        List[Dict]: One dictionary per <PubmedArticle> with the keys
            'pid'      - text of the first PMID element (None if absent),
            'title'    - flattened ArticleTitle text (None if absent),
            'abstract' - all non-empty AbstractText parts joined by a single
                         space ("" when the article has no abstract).

    Raises:
        xml.etree.ElementTree.ParseError: If `xml_string` is not well-formed XML.
    """
    def _flatten(element):
        # Element.text only holds the text before the first child tag;
        # itertext() also yields the text inside and after nested tags.
        return "".join(element.itertext())

    root = ET.fromstring(xml_string)
    articles = []

    for article in root.findall(".//PubmedArticle"):
        title_element = article.find(".//ArticleTitle")
        title = _flatten(title_element) if title_element is not None else None

        # An abstract can be split into several AbstractText parts
        # (e.g. labelled "BACKGROUND", "METHODS", ...); empty parts are skipped.
        abstract_texts = (
            _flatten(part) for part in article.findall(".//Abstract/AbstractText")
        )

        articles.append({
            "pid": article.findtext(".//PMID"),
            "title": title,
            "abstract": " ".join(text for text in abstract_texts if text),
        })

    return articles

In [13]:
# Generate semantic queries
def generate_queries_BERT(question, model, keyphrase_ngram_range=(1, 3), top_n=7):
    """
    Extract keyphrases from a question to use as search queries.

    Parameters:
        question (str): The question text.
        model: Object exposing `extract_keywords(text, keyphrase_ngram_range=...,
            stop_words=..., top_n=...)` that returns (phrase, score) pairs,
            e.g. a KeyBERT instance.
        keyphrase_ngram_range (tuple): Forwarded to `extract_keywords`.
        top_n (int): Forwarded to `extract_keywords`.

    Returns:
        List[str]: The phrases in the order the model returned them; the
        scores are discarded.
    """
    scored_phrases = model.extract_keywords(
        question,
        keyphrase_ngram_range=keyphrase_ngram_range,
        stop_words='english',
        top_n=top_n,
    )
    phrases = []
    for phrase, _score in scored_phrases:
        phrases.append(phrase)
    return phrases

'''
def generate_queries_BERT(question, model=KeyBERT('all-MiniLM-L6-v2'), keyphrase_ngram_range=(1, 3), top_n=7):
    keywords = model.extract_keywords(question, keyphrase_ngram_range=keyphrase_ngram_range, stop_words='english', top_n=top_n)

    cleaned_keywords = []

    for phrase, _ in keywords:
        doc = nlp(phrase)

        # Remove verbs, auxiliaries, stopwords, and short words
        filtered_tokens = [
            token.lemma_.lower()
            for token in doc
            if token.pos_ not in {"VERB", "AUX"}
            and token.lemma_.lower() not in ENGLISH_STOP_WORDS
            and len(token.lemma_) > 2
        ]

        # Reconstruct cleaned phrase
        cleaned_phrase = " ".join(filtered_tokens).strip()

        if cleaned_phrase:  # Only keep non-empty results
            cleaned_keywords.append(cleaned_phrase)

    return list(sorted(set(cleaned_keywords)))  # deduplicate + sort
'''

# Search PubMed for a single query
def search_pubmed(query, retmax=500):
    """
    Run one ESearch request against PubMed and return the matching ids.

    Parameters:
        query (str): Search term passed to `Entrez.esearch` as `term`.
        retmax (int): Passed to `Entrez.esearch` as `retmax`.

    Returns:
        The "IdList" entry of the parsed ESearch record.

    Raises:
        Whatever `Entrez.esearch` or `Entrez.read` raise; the response
        handle is closed before the exception propagates.
    """
    handle = Entrez.esearch(db="pubmed", term=query, retmax=retmax)
    try:
        record = Entrez.read(handle)
    finally:
        # Close the handle even when parsing the response fails, so a bad
        # reply does not leak the open response.
        handle.close()
    return record["IdList"]

# Fetch article records as XML and parse them
def fetch_details(id_list):
    """
    Download the PubMed records for the given ids and parse them.

    Parameters:
        id_list (iterable): PubMed ids; they are joined with commas into a
            single EFetch request.

    Returns:
        List[Dict]: Result of `parse_pubmed_abstracts_from_xml` on the
        downloaded XML, or an empty list when `id_list` is empty (no
        request is sent in that case).

    Raises:
        Whatever `Entrez.efetch`, `handle.read` or the XML parser raise;
        the response handle is closed before the exception propagates.
    """
    if not id_list:
        # Nothing to fetch: skip the request instead of sending an empty id.
        return []

    ids = ",".join(map(str, id_list))
    handle = Entrez.efetch(db="pubmed", id=ids, rettype="medline", retmode="xml")
    try:
        xml_data = handle.read()
    finally:
        # Close the handle even when reading the response fails.
        handle.close()
    return parse_pubmed_abstracts_from_xml(xml_data)

In [14]:
'''questions = get_questions_text(questions=parsed_questions)[0:3]
kw_model = KeyBERT(model='all-MiniLM-L6-v2')

all_results = {}
results = []

for question in questions:
    print(f"‚ùì Question: {question['body']}")
    
    # === KeyBERT Phase ===
    queries_bert = generate_queries_BERT(question['body'], model=kw_model)
    print(f"üß† [KeyBERT] Queries: {queries_bert}")
    
    all_ids_bert = set()
    for q in queries_bert:
        ids = search_pubmed(q, retmax=200)
        all_ids_bert.update(ids)
        time.sleep(0.7)

    print(f"‚úÖ [KeyBERT] Found {len(all_ids_bert)} PMIDs.")
    ground_truth = ground_truth_documents[question['id']]
    found_bert = ground_truth.intersection(all_ids_bert)
    print(f"üéØ [KeyBERT] Found {len(found_bert)} / {len(ground_truth)} ground truth PMIDs.")
    
    # === Fallback Phase: Smart Keyword Extraction ===
    all_ids_final = all_ids_bert.copy()
    used_fallback = False
    queries_simple = []
    newly_found_ground_truth = set()

    if len(found_bert) < len(ground_truth) / 2:
        used_fallback = True
        query_simple = generate_queries_SIMPLE(question['body'])
        print(f"üîÑ [Fallback] SIMPLE Queries: {query_simple}")
        queries_simple = [query_simple]
        print(f"‚ö†Ô∏è [Fallback] Using SIMPLE extraction: {query_simple}")
        
        fallback_ids = search_pubmed(query_simple, retmax=200)
        print(f"‚úÖ [Fallback] Found {len(fallback_ids)} PMIDs.")
        time.sleep(0.7)
        found_fallback = ground_truth.intersection(fallback_ids)

        print(f"üéØ [Fallback] Found {len(found_fallback)} / {len(ground_truth)} ground truth PMIDs.")

        # ‚¨áÔ∏è Keep only NEW ground truth hits from fallback
        newly_found_ground_truth = found_fallback - found_bert
        if newly_found_ground_truth:
            print(f"üìå Adding {len(newly_found_ground_truth)} NEW ground truth PMIDs from fallback.")
            all_ids_final.update(newly_found_ground_truth)


    # === Fetch Details ===
    details = fetch_details(list(all_ids_final))

    # === Save Results per Question ===
    results.append({
        "qid": question['id'],
        "question": question['body'],
        "queries_bert": queries_bert,
        "queries_simple": queries_simple,
        "all_found_pids": list(all_ids_final),
        "number_of_found_pids": len(all_ids_final),
        "ground_truth_pids": list(ground_truth),
        "found_ground_truth": f"{len(ground_truth.intersection(all_ids_final))}/{len(ground_truth)}",
        "used_fallback": used_fallback,
        "details": details
    })

    print(f"‚úÖ Results for question '{question['body']}' saved.\n")

# === Final Save ===
all_results = {
    "results": results,
}

save_to_json(data=all_results, file_path="retreived_articles.json")
print("üíæ All results saved to 'retreived_articles.json'")'''


'questions = get_questions_text(questions=parsed_questions)[0:3]\nkw_model = KeyBERT(model=\'all-MiniLM-L6-v2\')\n\nall_results = {}\nresults = []\n\nfor question in questions:\n    print(f"‚ùì Question: {question[\'body\']}")\n    \n    # === KeyBERT Phase ===\n    queries_bert = generate_queries_BERT(question[\'body\'], model=kw_model)\n    print(f"üß† [KeyBERT] Queries: {queries_bert}")\n    \n    all_ids_bert = set()\n    for q in queries_bert:\n        ids = search_pubmed(q, retmax=200)\n        all_ids_bert.update(ids)\n        time.sleep(0.7)\n\n    print(f"‚úÖ [KeyBERT] Found {len(all_ids_bert)} PMIDs.")\n    ground_truth = ground_truth_documents[question[\'id\']]\n    found_bert = ground_truth.intersection(all_ids_bert)\n    print(f"üéØ [KeyBERT] Found {len(found_bert)} / {len(ground_truth)} ground truth PMIDs.")\n    \n    # === Fallback Phase: Smart Keyword Extraction ===\n    all_ids_final = all_ids_bert.copy()\n    used_fallback = False\n    queries_simple = []\n    newl

In [19]:
# Locate the position of one specific question id in the question list
# (prints every index whose id matches).
questions = get_questions_text(questions=parsed_questions)
target_qid = "516c0ebc298dcd4e5100006e"
for i, candidate in enumerate(questions):
    if candidate['id'] == target_qid:
        print(i)
       

808


In [23]:
import random
import time
import json
import os
import shutil
from datetime import datetime

# üîß Parameters
BATCH_SIZE = 5                                    # questions processed between two saves
RESULTS_FILE = "retrieved_articles_sampled.json"  # accumulated per-question results
CHECKPOINT_FILE = "checkpoint.json"               # index of the last saved question
# Per-query limit on the number of PMIDs requested from PubMed.
# NOTE(review): NCBI's E-utilities documentation describes a 10,000-record
# ceiling for retmax - confirm whether a larger value has any effect.
SEARCH_RETMAX = 20000
SEARCH_DELAY_SECONDS = 0.7                        # pause after every search request
RANDOM_SAMPLE_SIZE = 150                          # non-ground-truth PMIDs kept per question
SAMPLE_SEED = 13                                  # base seed for the per-question sampling


def _write_json_atomically(payload, path, **dump_kwargs):
    """Write payload as JSON to a temporary sibling file, then swap it into place."""
    temp_path = path + ".tmp"
    with open(temp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, **dump_kwargs)
    # os.replace overwrites an existing destination in a single step, so an
    # interrupted run leaves either the old or the new file, never a partial one.
    os.replace(temp_path, path)


# Load checkpoint if it exists
if os.path.exists(CHECKPOINT_FILE):
    with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
        checkpoint = json.load(f)
    # Default of -1 so that a checkpoint without "last_index" restarts at
    # question 0 instead of silently skipping it.
    start_index = checkpoint.get("last_index", -1) + 1
else:
    start_index = 0

# Initialize
questions = get_questions_text(questions=parsed_questions)
kw_model = KeyBERT(model='all-MiniLM-L6-v2')
results = []

# Resume from saved results file if available
if os.path.exists(RESULTS_FILE):
    try:
        with open(RESULTS_FILE, "r", encoding="utf-8") as f:
            results = json.load(f)["results"]
    except (json.JSONDecodeError, KeyError):
        # Unparseable file or one without the "results" key: same fallback.
        print("‚ö° Warning: Results file is corrupted. Starting with empty results.")
        results = []

if start_index > 0 and not results:
    # The checkpoint and the results file disagree; make that visible because
    # the final results file will not contain the earlier questions.
    print(f"Warning: resuming at index {start_index} but no earlier results were loaded; "
          f"delete {CHECKPOINT_FILE} to reprocess from the beginning.")

# Process from start_index
for i in range(start_index, len(questions)):
    question = questions[i]
    qid = question['id']
    print(f"‚ùì[{i+1}/{len(questions)}] Question: {question['body']}")

    queries = generate_queries_BERT(question['body'], keyphrase_ngram_range=(1, 10), top_n=20, model=kw_model)
    print(f"üß† Generated Queries: {queries}")

    all_ids = set()
    for q in queries:
        try:
            ids = search_pubmed(q, retmax=SEARCH_RETMAX)
            all_ids.update(ids)
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to search PubMed with query '{q}': {e}")
        finally:
            # Pause after failed requests too, so a run of errors does not
            # turn into a burst of back-to-back requests.
            time.sleep(SEARCH_DELAY_SECONDS)

    print(f"‚úÖ Found {len(all_ids)} unique PMIDs for this question.")

    ground_truth = ground_truth_documents[qid]
    found_gt = ground_truth.intersection(all_ids)
    print(f"üéØ Found {len(found_gt)} out of {len(ground_truth)} ground truth PMIDs.")

    # Sort before sampling and seed per question: set iteration order is not
    # stable across runs, so this makes the sample reproducible, also after
    # resuming from a checkpoint.
    non_gt = sorted(all_ids - found_gt)
    rng = random.Random(f"{SAMPLE_SEED}:{qid}")
    sampled_random = rng.sample(non_gt, min(RANDOM_SAMPLE_SIZE, len(non_gt)))

    selected_ids = sorted(found_gt) + sampled_random
    print(f"üì¶ Fetching details for {len(selected_ids)} PMIDs (GT + Random Sample)")

    try:
        details = fetch_details(selected_ids)
    except Exception as e:
        # Best effort: keep the question in the results with empty details.
        print(f"‚ùå Error fetching details: {e}")
        details = []

    results.append({
        "qid": qid,
        "question": question['body'],
        "queries": queries,
        "ground_truth": sorted(ground_truth),
        "ground_truth_total": len(ground_truth),
        "ground_truth_found": sorted(found_gt),
        "ground_truth_found_count": len(found_gt),
        "random_sampled_count": len(sampled_random),
        "random_sampled_pmids": sampled_random,
        "selected_pmids": selected_ids,
        # Despite its name this is the fraction of ground-truth PMIDs that
        # were found (higher is better); the key is kept for compatibility
        # with already written result files.
        "error_rate": {
            "value": len(found_gt) / len(ground_truth) if len(ground_truth) > 0 else 0,
            "details": f"{len(found_gt)} found out of {len(ground_truth)}"
        },
        "details": details
    })

    print(f"‚úÖ Saved result for question \"{question['body']}\"\n")

    # Save batch and update checkpoint
    if (i + 1) % BATCH_SIZE == 0 or i == len(questions) - 1:
        print(f"üìÅ Saving batch at question {i+1}...")

        # Results first, checkpoint second: if the run dies in between, the
        # checkpoint still points at the previous batch and nothing is lost.
        _write_json_atomically({"results": results}, RESULTS_FILE, indent=2)
        _write_json_atomically({"last_index": i}, CHECKPOINT_FILE)

        print(f"‚úÖ Batch saved and checkpoint updated at index {i}\n")


‚ùì[4026/5389] Question: Is Lanabecestat effective for Alzheimer's disease?
üß† Generated Queries: ['lanabecestat effective alzheimer', 'lanabecestat effective alzheimer disease', 'effective alzheimer disease', 'effective alzheimer', 'alzheimer disease', 'lanabecestat effective', 'alzheimer', 'lanabecestat', 'disease', 'effective']
‚úÖ Found 33601 unique PMIDs for this question.
üéØ Found 2 out of 2 ground truth PMIDs.
üì¶ Fetching details for 152 PMIDs (GT + Random Sample)
‚úÖ Saved result for question "Is Lanabecestat effective for Alzheimer's disease?"

‚ùì[4027/5389] Question: How is the STING protein activated?
üß† Generated Queries: ['sting protein activated', 'sting protein', 'protein activated', 'sting', 'protein', 'activated']
‚úÖ Found 38035 unique PMIDs for this question.
üéØ Found 5 out of 5 ground truth PMIDs.
üì¶ Fetching details for 155 PMIDs (GT + Random Sample)
‚úÖ Saved result for question "How is the STING protein activated?"

‚ùì[4028/5389] Question: Explain t