# BioBERT-based model to classify biomedical QA pairs into yes/no/maybe.

Two-stage pipeline:
  1. Fine-tune BioBERT on labeled + (weighted) artificial data.
  2. Evaluate model.

## Importing Dataset

In [1]:
from datasets import load_dataset

# Load the labeled, unlabeled, and artificial subsets of PubMedQA
# The dataset is split into three subsets:

dataset_labeled = load_dataset("qiaojin/PubMedQA", 'pqa_labeled')
dataset_unlabeled = load_dataset("qiaojin/PubMedQA", 'pqa_unlabeled')
dataset_artificial = load_dataset('qiaojin/PubMedQA', 'pqa_artificial')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Convert the datasets to pandas DataFrames for easier manipulation
import pandas as pd

df_labeled_original = pd.DataFrame(dataset_labeled['train'])
df_unlabeled_original = pd.DataFrame(dataset_unlabeled['train'])
df_artificial_original = pd.DataFrame(dataset_artificial['train'])

In [3]:
# Make a copy of the original datasets to work on
df_labeled = df_labeled_original.copy()
df_unlabeled = df_unlabeled_original.copy()
df_artificial = df_artificial_original.copy()

## Preprocessing

In [4]:
# Function to merge Question and Contexts into Input
def merge_fields(row):
    context_text = " ".join(row['context']['contexts'])
    return f"{context_text}"

# Apply function to all datasets
for df in [df_labeled, df_artificial, df_unlabeled]:
    df['context_str'] = df.apply(merge_fields, axis=1)

In [5]:
label_map = {'yes': 0, 'no': 1, 'maybe': 2}
df_labeled['label'] = df_labeled['final_decision'].map(label_map)
df_artificial['label'] = df_artificial['final_decision'].map(label_map)

In [6]:
df_labeled_final = df_labeled[['question','context_str', 'label']]
df_artificial_final = df_artificial[['question','context_str', 'label']]

In [7]:
# Check class distribution
print(df_labeled_final['label'].value_counts())
print(df_artificial_final['label'].value_counts())

label
0    552
1    338
2    110
Name: count, dtype: int64
label
0    196144
1     15125
Name: count, dtype: int64


In [8]:
# Balance Artificial Dataset (Downsampling)
df_artificial_final = df_artificial_final.groupby('label').sample(n=min(df_artificial_final['label'].value_counts()), random_state=42)
# Shuffle the dataset
df_artificial_final = df_artificial_final.sample(frac=1, random_state=42)

In [9]:
len(df_artificial_final)

30250

In [10]:
from sklearn.model_selection import train_test_split

# Split df_labeled_final
labeled_train, labeled_test = train_test_split(
    df_labeled_final,
    test_size=0.2,      # 20% for testing
    random_state=42,    # for reproducibility
    stratify=df_labeled_final['label']  # optional: ensures class distribution is preserved
)

# Split df_artificial_final
artificial_train, artificial_test = train_test_split(
    df_artificial_final,
    test_size=0.2,
    random_state=42,
    stratify=df_artificial_final['label']
)


In [11]:
df_train = pd.concat([labeled_train, artificial_train], ignore_index=True, sort=False)
df_train.head()

Unnamed: 0,question,context_str,label
0,Increased neutrophil migratory activity after ...,Neutrophil infiltration of the lung is charact...,0
1,Are UK radiologists satisfied with the trainin...,A list of telephone numbers of UK hospitals wi...,1
2,Do patients with rheumatoid arthritis establis...,It is postulated that some aspects of methotre...,0
3,A short stay or 23-hour ward in a general and ...,We evaluated the usefulness of a short stay or...,0
4,Do family physicians know the costs of medical...,To determine the cost of 46 commonly used inve...,1


In [12]:
len(df_train)

25000

In [13]:
len(df_train['context_str'][0])

2109

In [14]:
df_test = pd.concat([labeled_test, artificial_test], ignore_index=True, sort=False)
df_test.head()

Unnamed: 0,question,context_str,label
0,Are home sampling kits for sexually transmitte...,There is an urgent need to increase opportunis...,2
1,Scrotal approach to both palpable and impalpab...,To determine the advantages of scrotal incisio...,0
2,Are polymorphisms in oestrogen receptors genes...,Polymorphisms in the oestrogen receptor 1 (ESR...,0
3,Do elderly patients benefit from surgery in ad...,Treatment of elderly cancer patients has gaine...,1
4,Does route of delivery affect maternal and per...,The route of delivery in eclampsia is controve...,1


## Extract most relevant text from context_str

In [None]:
## Extract most relevant text from context_str
from sentence_transformers import SentenceTransformer, util
from nltk.tokenize import sent_tokenize
import torch

# 1. Load sentence-transformers model (small, fast on CPU)
model = SentenceTransformer('all-MiniLM-L6-v2')  # Fast & accurate enough

# 2. Function to extract top-k relevant sentences from context based on question
def get_top_k_sentences(question, context, k=3):
    sentences = sent_tokenize(context)

    if not sentences:
        return ""  # Return empty if context is blank

    # Clamp k to number of available sentences
    k = min(k, len(sentences))

    # Encode question and context sentences
    question_embedding = model.encode(question, convert_to_tensor=True)
    sentence_embeddings = model.encode(sentences, convert_to_tensor=True)

    # Compute semantic similarity
    cosine_scores = util.pytorch_cos_sim(question_embedding, sentence_embeddings)[0]

    # Get top-k most relevant sentences
    top_k_indices = torch.topk(cosine_scores, k=k).indices
    selected_sentences = [sentences[i] for i in top_k_indices]

    return ' '.join(selected_sentences)


# 3. Apply to your dataset
# Make sure df_train has columns: 'question', 'context_str'
df_train['filtered_context'] = df_train.apply(
    lambda row: get_top_k_sentences(row['question'], row['context_str'], k=3), axis=1
)


In [None]:
df_train.to_csv('df_train.csv', index=False)  

In [None]:
print(df_train.head())
len(df_train['filtered_context'][0])

                                            question  \
0  Increased neutrophil migratory activity after ...   
1  Are UK radiologists satisfied with the trainin...   
2  Do patients with rheumatoid arthritis establis...   
3  A short stay or 23-hour ward in a general and ...   
4  Do family physicians know the costs of medical...   

                                         context_str  label  \
0  Neutrophil infiltration of the lung is charact...      0   
1  A list of telephone numbers of UK hospitals wi...      1   
2  It is postulated that some aspects of methotre...      0   
3  We evaluated the usefulness of a short stay or...      0   
4  To determine the cost of 46 commonly used inve...      1   

                                    filtered_context  
0  Neutrophils isolated from major trauma patient...  
1  Only 52% of departments had a dedicated paedia...  
2  To look at the effect of stopping FA supplemen...  
3  We evaluated the usefulness of a short stay or...  
4  Six hu

629

## Finetune BioBERT for QA Classification

Tokenize the data

In [None]:
import pandas as pd

df_train = pd.read_csv('df_train.csv')

In [None]:
from transformers import AutoTokenizer
import torch

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-v1.1')

def truncate_head_tail(text, tokenizer, max_length):
    """Truncate the middle of a text to preserve the beginning and end."""
    tokens = tokenizer.encode(text, add_special_tokens=False)
    
    if len(tokens) <= max_length:
        return tokens
    
    # Reserve space for special tokens [CLS] and [SEP]
    reserved = 2
    half = (max_length - reserved) // 2
    truncated = tokens[:half] + tokens[-half:]
    
    return [tokenizer.cls_token_id] + truncated + [tokenizer.sep_token_id]

def encode_data(tokenizer, texts, max_length):
    input_ids = []
    attention_masks = []

    for text in texts:
        ids = truncate_head_tail(text, tokenizer, max_length)
        mask = [1] * len(ids)

        # Pad to max_length
        padding_length = max_length - len(ids)
        ids += [tokenizer.pad_token_id] * padding_length
        mask += [0] * padding_length

        input_ids.append(ids)
        attention_masks.append(mask)

    return torch.tensor(input_ids), torch.tensor(attention_masks)

# Use this on your dataset
input_ids, attention_mask = encode_data(
    tokenizer,
    df_train['filtered_context'].tolist(),
    max_length=512
)

  from .autonotebook import tqdm as notebook_tqdm


Finetuning BioBERT


In [None]:
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

# Load the pre-trained model
model = AutoModelForSequenceClassification.from_pretrained('dmis-lab/biobert-v1.1', num_labels=3)
# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,
    per_device_train_batch_size=2,  # lower to fit RAM
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,  # effectively 8 if batch size 2
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=100,
    save_strategy="no",  # Don't save checkpoints to avoid I/O overhead
    load_best_model_at_end=False,
    fp16=True,
)

from datasets import Dataset

data = {
    'input_ids': input_ids.tolist(),
    'attention_mask': attention_mask.tolist(),
    'labels': df_train['label'].tolist()
}

train_ds = Dataset.from_dict(data)

# Create the Trainer and start training
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
)
trainer.train()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at dmis-lab/biobert-v1.1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss
100,0.8027
200,0.6548
300,0.6715
400,0.652
500,0.6395
600,0.6113
700,0.5964
800,0.5795
900,0.595
1000,0.6198


TrainOutput(global_step=3125, training_loss=0.5857315985107422, metrics={'train_runtime': 2019.8978, 'train_samples_per_second': 12.377, 'train_steps_per_second': 1.547, 'total_flos': 6577835443200000.0, 'train_loss': 0.5857315985107422, 'epoch': 1.0})