# BioBERT-based model to classify biomedical QA pairs into yes/no/maybe.

Two-stage pipeline:
  1. Fine-tune BioBERT on labeled + (weighted) artificial data.
  2. Evaluate model.

## Importing Dataset

In [1]:
from datasets import load_dataset

# Load the labeled, unlabeled, and artificial subsets of PubMedQA
# The dataset is split into three subsets:

dataset_labeled = load_dataset("qiaojin/PubMedQA", 'pqa_labeled')
dataset_unlabeled = load_dataset("qiaojin/PubMedQA", 'pqa_unlabeled')
dataset_artificial = load_dataset('qiaojin/PubMedQA', 'pqa_artificial')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Convert the datasets to pandas DataFrames for easier manipulation
import pandas as pd

df_labeled_original = pd.DataFrame(dataset_labeled['train'])
df_unlabeled_original = pd.DataFrame(dataset_unlabeled['train'])
df_artificial_original = pd.DataFrame(dataset_artificial['train'])

In [3]:
# Make a copy of the original datasets to work on
df_labeled = df_labeled_original.copy()
df_unlabeled = df_unlabeled_original.copy()
df_artificial = df_artificial_original.copy()

## Preprocessing

In [4]:
# Function to merge Question and Contexts into Input
def merge_fields(row):
    context_text = " ".join(row['context']['contexts'])
    return f"{context_text}"

# Apply function to all datasets
for df in [df_labeled, df_artificial, df_unlabeled]:
    df['context_str'] = df.apply(merge_fields, axis=1)

In [7]:
df_labeled = df_labeled[df_labeled['final_decision'] != "maybe"]

label_map = {'yes': 0, 'no': 1}
df_labeled['label'] = df_labeled['final_decision'].map(label_map)
df_artificial['label'] = df_artificial['final_decision'].map(label_map)

In [8]:
df_labeled_final = df_labeled[['question','context_str', 'label']]
df_artificial_final = df_artificial[['question','context_str', 'label']]

In [9]:
# Check class distribution
print(df_labeled_final['label'].value_counts())
print(df_artificial_final['label'].value_counts())

label
0    552
1    338
Name: count, dtype: int64
label
0    196144
1     15125
Name: count, dtype: int64


In [10]:
# Balance Artificial Dataset (Downsampling)
df_artificial_final = df_artificial_final.groupby('label').sample(n=min(df_artificial_final['label'].value_counts()), random_state=42)
# Shuffle the dataset
df_artificial_final = df_artificial_final.sample(frac=1, random_state=42)

In [11]:
len(df_artificial_final)

30250

In [12]:
df_train = df_artificial_final
df_train.head()

Unnamed: 0,question,context_str,label
117741,Are common genetic variants in the microRNA bi...,Although the role of miRNA in cancer developme...,1
164514,Do y-SNPs indicate hybridisation between Europ...,Previous genetic studies of modern and ancient...,1
178755,Is genetic polymorphisms ofCYP2A6 andCYP2E1 wi...,To elucidate the association between genetic p...,1
51724,Does a background infusion of morphine enhance...,To compare the effects of patient-controlled a...,1
67278,Do neutrophils promote aerogenous spread of lu...,Adenocarcinoma with bronchioloalveolar carcino...,0


In [16]:
len(df_train['context_str'].iloc[0])

697

In [25]:
df_test = df_labeled_final
df_test.head()

Unnamed: 0,question,context_str,label
0,Do mitochondria play a role in remodelling lac...,Programmed cell death (PCD) is the regulated d...,0
1,Landolt C and snellen e acuity: differences in...,Assessment of visual acuity depends on the opt...,1
2,"Syncope during bathing in infants, a pediatric...",Apparent life-threatening events in infants ar...,0
3,Are the long-term results of the transanal pul...,The transanal endorectal pull-through (TERPT) ...,1
4,Can tailored interventions increase mammograph...,Telephone counseling and tailored print commun...,0


## Extract most relevant text from context_str

In [18]:
## Extract most relevant text from context_str
from sentence_transformers import SentenceTransformer, util
from nltk.tokenize import sent_tokenize
import torch

# 1. Load sentence-transformers model (small, fast on CPU)
model = SentenceTransformer('all-MiniLM-L6-v2')  # Fast & accurate enough

# 2. Function to extract top-k relevant sentences from context based on question
def get_top_k_sentences(question, context, k=3):
    sentences = sent_tokenize(context)

    if not sentences:
        return ""  # Return empty if context is blank

    # Clamp k to number of available sentences
    k = min(k, len(sentences))

    # Encode question and context sentences
    question_embedding = model.encode(question, convert_to_tensor=True)
    sentence_embeddings = model.encode(sentences, convert_to_tensor=True)

    # Compute semantic similarity
    cosine_scores = util.pytorch_cos_sim(question_embedding, sentence_embeddings)[0]

    # Get top-k most relevant sentences
    top_k_indices = torch.topk(cosine_scores, k=k).indices
    selected_sentences = [sentences[i] for i in top_k_indices]

    return ' '.join(selected_sentences)


# 3. Apply to your dataset
# Make sure df_train has columns: 'question', 'context_str'
df_train['filtered_context'] = df_train.apply(
    lambda row: get_top_k_sentences(row['question'], row['context_str'], k=3), axis=1
)


In [19]:
df_train.to_csv('df_train2.csv', index=False)  

In [20]:
print(df_train.head())
len(df_train['context_str'].iloc[0])

                                                 question  \
117741  Are common genetic variants in the microRNA bi...   
164514  Do y-SNPs indicate hybridisation between Europ...   
178755  Is genetic polymorphisms ofCYP2A6 andCYP2E1 wi...   
51724   Does a background infusion of morphine enhance...   
67278   Do neutrophils promote aerogenous spread of lu...   

                                              context_str  label  \
117741  Although the role of miRNA in cancer developme...      1   
164514  Previous genetic studies of modern and ancient...      1   
178755  To elucidate the association between genetic p...      1   
51724   To compare the effects of patient-controlled a...      1   
67278   Adenocarcinoma with bronchioloalveolar carcino...      0   

                                         filtered_context  
117741  Although the role of miRNA in cancer developme...  
164514  Strikingly, our results do not support the hyp...  
178755  To elucidate the association between

697

## Finetune BioBERT for QA Classification

Tokenize the data

In [21]:
import pandas as pd

df_train = pd.read_csv('df_train2.csv')

In [22]:
from transformers import AutoTokenizer
import torch

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-v1.1')

def truncate_head_tail(text, tokenizer, max_length):
    """Truncate the middle of a text to preserve the beginning and end."""
    tokens = tokenizer.encode(text, add_special_tokens=False)
    
    if len(tokens) <= max_length:
        return tokens
    
    # Reserve space for special tokens [CLS] and [SEP]
    reserved = 2
    half = (max_length - reserved) // 2
    truncated = tokens[:half] + tokens[-half:]
    
    return [tokenizer.cls_token_id] + truncated + [tokenizer.sep_token_id]

def encode_data(tokenizer, texts, max_length):
    input_ids = []
    attention_masks = []

    for text in texts:
        ids = truncate_head_tail(text, tokenizer, max_length)
        mask = [1] * len(ids)

        # Pad to max_length
        padding_length = max_length - len(ids)
        ids += [tokenizer.pad_token_id] * padding_length
        mask += [0] * padding_length

        input_ids.append(ids)
        attention_masks.append(mask)

    return torch.tensor(input_ids), torch.tensor(attention_masks)

# Use this on your dataset
input_ids, attention_mask = encode_data(
    tokenizer,
    df_train['filtered_context'].tolist(),
    max_length=512
)

Token indices sequence length is longer than the specified maximum sequence length for this model (544 > 512). Running this sequence through the model will result in indexing errors


Finetuning BioBERT


In [23]:
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

# Load the pre-trained model
model = AutoModelForSequenceClassification.from_pretrained('dmis-lab/biobert-v1.1', num_labels=3)
# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,
    per_device_train_batch_size=2,  # lower to fit RAM
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,  # effectively 8 if batch size 2
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=100,
    save_strategy="no",  # Don't save checkpoints to avoid I/O overhead
    load_best_model_at_end=False,
    fp16=True,
)

from datasets import Dataset

data = {
    'input_ids': input_ids.tolist(),
    'attention_mask': attention_mask.tolist(),
    'labels': df_train['label'].tolist()
}

train_ds = Dataset.from_dict(data)

# Create the Trainer and start training
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
)
trainer.train()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at dmis-lab/biobert-v1.1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss
100,0.7977
200,0.6613
300,0.6065
400,0.6072
500,0.5506
600,0.5904
700,0.566
800,0.6074
900,0.5796
1000,0.5823


TrainOutput(global_step=3781, training_loss=0.5520996703133007, metrics={'train_runtime': 2685.8895, 'train_samples_per_second': 11.263, 'train_steps_per_second': 1.408, 'total_flos': 7958654659436544.0, 'train_loss': 0.5520996703133007, 'epoch': 0.9999338842975206})

## Testing

df_test.to_csv('df_test2.csv', index=False) 

In [24]:
# Apply same preprocessing as training
input_ids_test, attention_mask_test = encode_data(
    tokenizer,
    df_test['filtered_context'].tolist(),
    max_length=512
)

KeyError: 'filtered_context'