In [2]:
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
from nltk.tokenize import TreebankWordTokenizer
import nltk
import torch
import numpy as np
from sklearn.metrics import f1_score
import re

  from .autonotebook import tqdm as notebook_tqdm


<h2>Loading the dataset</h2>

The PubMedQA dataset is loaded, specifically the pqa_labeled dataset that consists of 1000 samples. The dataset is shuffled and split into train and test set. 

The train set consists of 800 samples and the test set consists of 200 samples. 

In [9]:

dataset = load_dataset("pubmed_qa", "pqa_labeled")["train"].shuffle(seed=42)
split_dataset = dataset.train_test_split(test_size=0.2, seed=42)

#Three subsets of data for model training and evaluation
train_dataset = split_dataset["train"]
test_dataset = split_dataset["test"]


This dataset is structured to support question-answering (QA) tasks related to biomedical literature. It contains five columns that provide essential information for each entry.

1. pubid: A unique identifier for each record.
2. question: The medical or scientific question posed.
3. context: Background information related to the question.
4. long_answer: A detailed response to the question based on the provided context.
5. final_decision: yes/no/maybe

In [11]:
split_dataset

DatasetDict({
    train: Dataset({
        features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
        num_rows: 800
    })
    test: Dataset({
        features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
        num_rows: 200
    })
})

<h2>BM25 for document retrieval</h2>

Preparing a corpus from training set to be used with BM25, an efficient (alternative to tf-idf) ranking algorithm for information retrieval. BM25 helps rank documents by their relevance to a given query, making it useful for QA tasks.

The following code achieves the following:
- Retrieves the context field from the training dataset (train_dataset).
- Converts the extracted texts into a list of strings (corpus).
- Joins multiple contexts within each entry into a single string.
- Initializes the Treebank Word Tokenizer, which is optimized for English text.
- Tokenizes each document in the corpus into a list of words
- Creates a BM25 index using the tokenized corpus.
- BM25 ranks documents based on the frequency and importance of words in a query.




In [12]:
# Prepare corpus for BM25 using training set


# Extract just the 'context' field from the train set
corpus_data = train_dataset["context"]
corpus = [' '.join(entry['contexts']) for entry in corpus_data]

# Tokenize the corpus
tokenizer = TreebankWordTokenizer()
tokenized_corpus = [tokenizer.tokenize(doc) for doc in corpus]

# Initialize BM25
bm25 = BM25Okapi(tokenized_corpus)



In [13]:
def retrieve_with_bm25(query, k=3):
    tokenized_query = tokenizer.tokenize(query)
    doc_scores = bm25.get_scores(tokenized_query)
    top_k_indices = sorted(range(len(doc_scores)), key=lambda i: doc_scores[i], reverse=True)[:k]
    return [corpus[i] for i in top_k_indices]


In [14]:
query = "Does aspirin reduce the risk of stroke?"
top_docs = retrieve_with_bm25(query)

for i, doc in enumerate(top_docs, 1):
    print(f"Doc {i}:\n{doc[:500]}\n")


Doc 1:
In primary and secondary prevention trials, statins have been shown to reduce the risk of stroke. In addition to lipid lowering, statins have a number of antiatherothrombotic and neuroprotective properties. In a preliminary observational study, we explored whether clinical outcome is improved in patients who are on treatment with statins when stroke occurs. We conducted a population-based case-referent study of 25- to 74-year-old stroke patients with, for each case of a patient who was on statin

Doc 2:
Type 2 diabetes may be present for several years before diagnosis, by which time many patients have already developed diabetic complications. Earlier detection and treatment may reduce this burden, but evidence to support this approach is lacking. Glycemic control and clinical and surrogate outcomes were compared for 5,088 of 5,102 U.K. Diabetes Prospective Study participants according to whether they had low (<140 mg/dl [<7.8 mmol/l]), intermediate (140 to<180 mg/dl [7.8 to<10.0

<h2>Load BioMedBERT for Biomedical Question Answering</h2>

Load BioMedBERT, a specialized BERT model designed for biomedical text processing, and sets up a question-answering (QA) pipeline.


In [15]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline

# Load model and tokenizer
model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

# Create QA pipeline
qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)



Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Device set to use cpu


In [16]:
# Try using top retrieved doc
context_doc = top_docs[0]
result = qa_pipeline({
    "question": query,
    "context": context_doc
})
print(f"Answer: {result['answer']}")




Answer: referent study


In [17]:
def map_answer_to_label(answer: str) -> str:
    answer = answer.lower()
    if any(phrase in answer for phrase in ["no", "not", "does not", "none", "negative"]):
        return "no"
    elif any(phrase in answer for phrase in ["yes", "does", "can", "reduce", "associated with", "positive", "increased", "decreased"]):
        return "yes"
    else:
        return "maybe"


In [18]:
import random
def evaluate(dataset, k=3, max_samples=100):
    correct = 0
    total = 0
    examples = []

    for example in random.sample(list(dataset), min(max_samples, len(dataset))):
        question = example["question"]
        gold_label = example["final_decision"]

        # Retrieve top documents
        retrieved_docs = retrieve_with_bm25(question, k=k)
        if not retrieved_docs:
            continue

        # Use top-1 doc for answer extraction
        context = retrieved_docs[0]

        try:
            result = qa_pipeline({"question": question, "context": context})
            predicted_answer = result["answer"]
            predicted_label = map_answer_to_label(predicted_answer)
        except:
            continue

        total += 1
        if predicted_label == gold_label:
            correct += 1

        examples.append({
            "question": question,
            "gold_label": gold_label,
            "predicted_label": predicted_label,
            "predicted_answer": predicted_answer,
            "context_snippet": context[:300]
        })

    accuracy = correct / total if total > 0 else 0
    return accuracy, examples


In [19]:
test_acc, test_examples = evaluate(test_dataset, max_samples=100)
print(f"Test Accuracy: {test_acc:.2f}")


Test Accuracy: 0.17
