In [10]:
import os
import shutil
import tempfile
from tqdm import tqdm
import json
import re
from collections import defaultdict
from rank_bm25 import BM25Okapi


In [11]:
# load json file
def load_data(file_path):
    """Read a UTF-8 JSON file and return the value under its top-level 'data' key.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        Whatever is stored under the 'data' key of the parsed document.

    Raises:
        KeyError: If the parsed document has no 'data' key.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        payload = json.loads(handle.read())
    return payload['data']

# Location of the parsed training set, relative to the notebook's directory.
parsed_training_data = '../dataset/parsed_data_final.json'

# Load every entry once; later cells read from `data`.
data = load_data(parsed_training_data)

# Display the first record to check what a single entry looks like.
data[0]

{'qid': '55031181e9bde69634000014',
 'question': 'Is Hirschsprung disease a mendelian or a multifactorial disorder?',
 'ground_truth_documents_pid': ['http://www.ncbi.nlm.nih.gov/pubmed/6650562',
  'http://www.ncbi.nlm.nih.gov/pubmed/15829955',
  'http://www.ncbi.nlm.nih.gov/pubmed/12239580',
  'http://www.ncbi.nlm.nih.gov/pubmed/20598273',
  'http://www.ncbi.nlm.nih.gov/pubmed/23001136',
  'http://www.ncbi.nlm.nih.gov/pubmed/8896569',
  'http://www.ncbi.nlm.nih.gov/pubmed/15858239',
  'http://www.ncbi.nlm.nih.gov/pubmed/15617541',
  'http://www.ncbi.nlm.nih.gov/pubmed/21995290'],
 'error_rate': {'value': 1.0, 'details': '9 found out of 9'},
 'ground_truth_snippets': [{'offsetInBeginSection': 131,
   'offsetInEndSection': 358,
   'text': 'Hirschsprung disease (HSCR) is a multifactorial, non-mendelian disorder in which rare high-penetrance coding sequence mutations in the receptor tyrosine kinase RET contribute to risk in combination with mutations at other genes',
   'beginSection': 'a

In [12]:
def tokenize(text):
    """Lower-case `text`, delete everything except a-z, 0-9 and whitespace, and split.

    Punctuation is removed rather than replaced by a space, so a hyphenated
    word such as "non-mendelian" becomes the single token "nonmendelian".

    Args:
        text: The string to tokenize.

    Returns:
        A list of whitespace-separated tokens (empty when nothing survives).
    """
    cleaned = re.sub(r'[^a-z0-9\s]', '', text.lower())
    return cleaned.split()


In [13]:
def prepare_corpus(articles):
    """Build the tokenized BM25 corpus from each article's title and abstract.

    The title and abstract are joined with a single space and passed through
    `tokenize`. Articles that yield no tokens are dropped, so the two returned
    lists always have the same length and line up index by index.

    Args:
        articles: Iterable of dicts; the 'title' and 'abstract' keys are read
            and a missing key is treated as an empty string.

    Returns:
        (tokenized_corpus, article_refs): the token lists, and the original
        article dicts they were built from, in the same order.
    """
    tokenized_docs, kept_articles = [], []

    for entry in articles:
        title = entry.get('title', '')
        abstract = entry.get('abstract', '')
        combined = f"{title} {abstract}".strip()

        # Nothing to tokenize when both fields are blank.
        doc_tokens = tokenize(combined) if combined else []
        if not doc_tokens:
            continue

        tokenized_docs.append(doc_tokens)
        kept_articles.append(entry)  # keep the source dict for later lookup

    return tokenized_docs, kept_articles

In [14]:
def rank_articles_bm25(question, articles, top_k=10):
    """Rank `articles` against `question` with BM25 and return the best hits.

    The corpus is built by `prepare_corpus` (title + abstract per article) and
    the question is tokenized with the same `tokenize` function, so query and
    documents share one normalisation.

    Args:
        question: The question text used as the query.
        articles: Iterable of article dicts (see `prepare_corpus`).
        top_k: Maximum number of articles to return. Defaults to 10, which is
            the cutoff this function previously had hard-coded.

    Returns:
        A list of at most `top_k` dicts with keys 'pid', 'title', 'abstract'
        and 'score' (a plain float), ordered by descending BM25 score. Empty
        when `top_k` is not positive or no article produced any tokens.
    """
    if top_k <= 0:
        return []

    corpus, article_refs = prepare_corpus(articles)
    if not corpus:
        return []

    bm25 = BM25Okapi(corpus)
    scores = bm25.get_scores(tokenize(question))

    # sorted() is stable even with reverse=True, so equally scored articles
    # keep their original corpus order.
    ranked_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

    return [
        {
            'pid': article_refs[i].get('pid', ''),
            'title': article_refs[i].get('title', ''),
            'abstract': article_refs[i].get('abstract', ''),
            'score': float(scores[i]),  # plain float so json.dump can serialise it
        }
        for i in ranked_indices[:top_k]
    ]

In [15]:
def save_results(results, output_file='bm25_results_training.json'):
    """Write `results` to `output_file` as indented UTF-8 JSON.

    Non-ASCII characters are written as-is rather than as \\uXXXX escapes.
    An existing file at `output_file` is overwritten.

    Args:
        results: Any JSON-serialisable object.
        output_file: Destination path.
    """
    with open(output_file, mode='w', encoding='utf-8') as sink:
        json.dump(obj=results, fp=sink, ensure_ascii=False, indent=2)

In [16]:
def _normalize_with_offsets(text):
    """Return (normalized_text, norm_to_orig) for `text`.

    `normalized_text` keeps only ASCII letters (lower-cased), digits and
    whitespace. `norm_to_orig[i]` is the index in `text` of the character that
    became `normalized_text[i]`, so matches found in the normalized text can
    be mapped back to exact positions in the original.
    """
    pieces = []
    norm_to_orig = []
    for run in re.finditer(r'[A-Za-z0-9\s]+', text):
        # Lower-casing ASCII letters never changes the length of a run, so
        # positions inside a kept run map one-to-one onto the original text.
        pieces.append(run.group().lower())
        norm_to_orig.extend(range(run.start(), run.end()))
    return ''.join(pieces), norm_to_orig


def extract_snippets(question, top_articles):
    """
    Extract BioASQ-style snippets from the top-ranked articles.

    For each article and each of its 'title' and 'abstract' sections, the
    snippet is the span of the section running from the first to the last
    whole-word occurrence of any query term. Matching is done on a
    punctuation-stripped, lower-cased copy of the section, and the match
    positions are mapped back to the original text, so that
    `section_text[offsetInBeginSection:offsetInEndSection] == text` always
    holds. Sections with no matching term produce no snippet.

    NOTE: query terms are not filtered for stop words, so a common word in the
    question can stretch the span across most of a section.

    Args:
        question: The question text; its `\\w+` tokens, lower-cased and reduced
            to a-z/0-9, are the query terms.
        top_articles: Iterable of article dicts; 'pid', 'title' and 'abstract'
            are read. Missing or None sections are treated as empty.

    Returns:
        A list of dicts with keys 'beginSection', 'endSection', 'text',
        'document', 'offsetInBeginSection' and 'offsetInEndSection'. Offsets
        are slice indices (start inclusive, end exclusive) into the section.
    """
    query_terms = {
        re.sub(r'[^a-z0-9]', '', token.lower())
        for token in re.findall(r'\w+', question)
    }
    # A token such as "_" or a purely non-ASCII word normalizes to '', and an
    # empty term would match at every word boundary; drop it.
    query_terms.discard('')
    if not query_terms:
        return []

    # One alternation compiled once per question; sorted for a deterministic
    # pattern. \b on both sides restricts matches to whole words.
    term_pattern = re.compile(
        r'\b(?:' + '|'.join(re.escape(term) for term in sorted(query_terms)) + r')\b'
    )

    snippets = []

    for article in top_articles:
        pid = article.get('pid', '')

        for section in ('title', 'abstract'):
            field_text = article.get(section) or ''
            text_norm, norm_to_orig = _normalize_with_offsets(field_text)

            # finditer yields matches left to right, so the first match has
            # the smallest start and the last match the largest end.
            first_match = last_match = None
            for match in term_pattern.finditer(text_norm):
                if first_match is None:
                    first_match = match
                last_match = match

            if first_match is None:
                continue

            # Translate normalized positions back to the original section.
            # The end is exclusive, so map the last matched character and
            # step one past it.
            snippet_start = norm_to_orig[first_match.start()]
            snippet_end = norm_to_orig[last_match.end() - 1] + 1

            snippets.append({
                "beginSection": section,
                "endSection": section,
                "text": field_text[snippet_start:snippet_end],
                "document": pid,
                "offsetInBeginSection": snippet_start,
                "offsetInEndSection": snippet_end
            })

    return snippets

In [17]:
def rank_all_questions_bm25(data):
    """Run BM25 ranking and snippet extraction for every entry in `data`.

    Args:
        data: Iterable of question entries. Each entry must have a 'question'
            key; candidate articles are read from 'all_retreived_articles'
            (spelling as in the dataset) and default to an empty list.

    Returns:
        One dict per entry, in input order, with keys 'question',
        'top_10_articles' (the output of `rank_articles_bm25`) and 'snippets'
        (the output of `extract_snippets` on those articles).
    """
    def _process(entry):
        # Rank first; snippets are extracted only from the ranked articles.
        question = entry['question']
        ranked = rank_articles_bm25(question, entry.get('all_retreived_articles', []))
        return {
            'question': question,
            'top_10_articles': ranked,
            'snippets': extract_snippets(question, ranked),
        }

    return [_process(entry) for entry in tqdm(data, desc="Processing questions...")]

In [20]:
# Rank articles and extract snippets for every loaded question, then write
# the results to save_results' default output file.
results_bm25 = rank_all_questions_bm25(data)
save_results(results=results_bm25)

Processing questions...: 100%|██████████| 5390/5390 [02:14<00:00, 40.06it/s]
