### Project: Question-answering (retrieval) system using the PubMedQA dataset and the pretrained microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext model
## Anna jazayeri 2916723
## Data mining CIS 660

# 1. Setup and imports

In [16]:
# Setup and imports

!pip install -q datasets transformers accelerate

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
import random
import math
import os
import re

In [2]:
# Device and seeds

def set_seed(seed=42):
    """Seed every RNG this notebook relies on.

    Seeds, in order: Python's ``random``, NumPy's legacy global RNG, and PyTorch's
    CPU generator; all CUDA devices are seeded too when CUDA is available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed(42)

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device


device(type='cuda')

# 2. Data prep: build a big unlabeled (question, context) dataset
This combines the labeled, unlabeled, and artificial PubMedQA configs but ignores all yes/no/maybe labels: contrastive training only needs (question, context) pairs.

In [3]:
# 2. Load PubMedQA splits and extract question/context pairs

def extract_qc_pairs(hf_dataset_split):
    """Pull parallel (question, context) string lists out of a PubMedQA split.

    Each example's ``example["context"]["contexts"]`` list of passages is joined
    with single spaces into one context string.

    Returns:
        (questions, contexts): two lists of equal length, aligned by position.
    """
    pairs = [
        (example["question"], " ".join(example["context"]["contexts"]))
        for example in hf_dataset_split
    ]
    questions = [question for question, _ in pairs]
    contexts = [context for _, context in pairs]
    return questions, contexts

# Load the three official PubMedQA configs (train split of each). To drop one,
# remove it from both tuples below and from the concatenation.
ds_labeled, ds_unlab, ds_artificial = (
    load_dataset("pubmed_qa", config_name)["train"]
    for config_name in ("pqa_labeled", "pqa_unlabeled", "pqa_artificial")
)

(q_l, c_l), (q_u, c_u), (q_a, c_a) = (
    extract_qc_pairs(split) for split in (ds_labeled, ds_unlab, ds_artificial)
)

# Concatenate in a fixed order: labeled, unlabeled, artificial.
all_questions = [*q_l, *q_u, *q_a]
all_contexts = [*c_l, *c_u, *c_a]

len(all_questions), len(all_contexts)


(273518, 273518)

In [4]:
# 2.1 Build a single DataFrame, drop exact duplicate (question, context) pairs,
# and give every remaining row a global id equal to its row position. Later cells
# use that id as the row index into the context embedding matrix.
df_all = (
    pd.DataFrame({"question": all_questions, "context": all_contexts})
    .drop_duplicates(subset=["question", "context"])
    .reset_index(drop=True)
)
df_all["id"] = np.arange(len(df_all))

df_all.head(), len(df_all)


(                                            question  \
 0  Do mitochondria play a role in remodelling lac...   
 1  Landolt C and snellen e acuity: differences in...   
 2  Syncope during bathing in infants, a pediatric...   
 3  Are the long-term results of the transanal pul...   
 4  Can tailored interventions increase mammograph...   
 
                                              context  id  
 0  Programmed cell death (PCD) is the regulated d...   0  
 1  Assessment of visual acuity depends on the opt...   1  
 2  Apparent life-threatening events in infants ar...   2  
 3  The transanal endorectal pull-through (TERPT) ...   3  
 4  Telephone counseling and tailored print commun...   4  ,
 273467)

In [5]:
# 2.2 Train/val/test splits (80/10/10): hold out 20% of the rows, then split that
# hold-out in half into validation and test. Both splits use the same seed.
split_kwargs = dict(random_state=42, shuffle=True)

train_df, temp_df = train_test_split(df_all, test_size=0.2, **split_kwargs)
val_df, test_df = train_test_split(temp_df, test_size=0.5, **split_kwargs)

len(train_df), len(val_df), len(test_df)


(218773, 27347, 27347)

# 3. Dataset and DataLoader for dual encoder
We’ll have a dataset that returns question + context strings, and a collate_fn that tokenizes both.

In [6]:
# 3. Dataset and collate_fn

class QCDataset(Dataset):
    """Map-style dataset of (question, context) string pairs from a DataFrame.

    Each item is a dict with ``"question"`` and ``"context"`` keys. Tokenization
    is deferred to the DataLoader's ``collate_fn``.
    """

    def __init__(self, df):
        # Materialize both columns as plain Python lists for cheap positional access.
        self.questions, self.contexts = (
            df[column].tolist() for column in ("question", "context")
        )

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        return dict(question=self.questions[idx], context=self.contexts[idx])

model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Maximum number of tokens kept per question and per context; longer texts are truncated.
max_length = 128 # you can increase to 256 if you have GPU memory


def collate_fn(batch):
    """Tokenize a list of ``{"question", "context"}`` items into padded tensors.

    Questions and contexts are tokenized separately. Each is padded to the longest
    sequence of its own kind in the batch and truncated to ``max_length`` tokens.

    Returns:
        dict with ``q_input_ids``, ``q_attention_mask``, ``c_input_ids`` and
        ``c_attention_mask`` PyTorch tensors.
    """
    def _tokenize(texts):
        return tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )

    q_enc = _tokenize([item["question"] for item in batch])
    c_enc = _tokenize([item["context"] for item in batch])

    return {
        "q_input_ids": q_enc["input_ids"],
        "q_attention_mask": q_enc["attention_mask"],
        "c_input_ids": c_enc["input_ids"],
        "c_attention_mask": c_enc["attention_mask"],
    }

train_dataset = QCDataset(train_df)
val_dataset = QCDataset(val_df)

# With in-batch negatives, each question is contrasted against batch_size - 1 other contexts.
batch_size = 8 # adjust based on GPU memory

# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader, val_loader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
    for dataset, shuffle in ((train_dataset, True), (val_dataset, False))
)


# 4. Dual encoder model + InfoNCE loss

In [7]:
# 4. Dual encoder using shared PubMedBERT

class DualEncoder(nn.Module):
    """Shared-weight dual encoder for question/context retrieval.

    A single pretrained transformer (loaded by name via ``AutoModel``) embeds both
    questions and contexts. The embedding of a sequence is its first-token ([CLS])
    hidden state, L2-normalized, so dot products between embeddings are cosine
    similarities.
    """

    def __init__(self, encoder_name):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        self.hidden_size = self.encoder.config.hidden_size

    def encode(self, input_ids, attention_mask):
        """Return unit-norm [CLS] embeddings of shape [batch_size, hidden_dim]."""
        hidden_states = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state  # [batch_size, seq_len, hidden_dim]
        # eps guards against division by zero for an all-zero vector.
        return F.normalize(hidden_states[:, 0], p=2, dim=1, eps=1e-8)

    def forward(self, q_input_ids, q_attention_mask, c_input_ids, c_attention_mask):
        """Embed a batch of questions and a batch of contexts with the same encoder."""
        return (
            self.encode(q_input_ids, q_attention_mask),
            self.encode(c_input_ids, c_attention_mask),
        )

# Instantiate the shared encoder and move its parameters to the selected device.
dual_encoder = DualEncoder(encoder_name=model_name)
dual_encoder = dual_encoder.to(device)


In [8]:
# 4.1 InfoNCE / contrastive loss with in-batch negatives

def contrastive_loss(q_emb, c_emb, temperature=0.05):
    """Symmetric InfoNCE loss with in-batch negatives.

    Row ``i`` of ``q_emb`` is the positive partner of row ``i`` of ``c_emb``; every
    other row in the batch acts as a negative. The loss averages question->context
    and context->question cross-entropy over the temperature-scaled similarity matrix.

    Args:
        q_emb: [B, H] question embeddings.
        c_emb: [B, H] context embeddings.
        temperature: divisor applied to the similarities before softmax.

    Returns:
        Scalar loss tensor.
    """
    logits = (q_emb @ c_emb.t()) / temperature  # [B, B]
    labels = torch.arange(q_emb.size(0), device=logits.device)  # positives on the diagonal
    question_to_context = F.cross_entropy(logits, labels)
    context_to_question = F.cross_entropy(logits.t(), labels)
    return 0.5 * (question_to_context + context_to_question)


# 5. Training loop (self-supervised)
We train only to minimize the contrastive loss (no labels are used); you can limit the number of batches/epochs to what your GPU can handle.

In [9]:
# 5. Training setup: AdamW, with a linear warmup over the first 10% of steps
# followed by a linear decay.

num_epochs = 2 # start small, we can increase later
learning_rate = 2e-5

optimizer = torch.optim.AdamW(dual_encoder.parameters(), lr=learning_rate)

# The training loop takes one optimizer/scheduler step per batch.
num_update_steps_per_epoch = len(train_loader)
max_train_steps = num_update_steps_per_epoch * num_epochs
warmup_steps = int(0.1 * max_train_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=max_train_steps,
)


In [10]:
# 5.1 Training loop
# A full training run took about 4 hours. The save cell at the end of the notebook
# writes the fine-tuned weights to `dual_encoder_pubmedqa.pt`. If that file exists,
# this cell loads it and skips training, so Restart & Run All stays practical.
# Delete (or rename) the checkpoint to force a fresh run, e.g. after changing
# hyperparameters or the model.

checkpoint_path = "dual_encoder_pubmedqa.pt"

if os.path.exists(checkpoint_path):
    # The file holds only a state_dict (tensors), so weights_only=True is sufficient
    # and avoids unpickling arbitrary Python objects from disk.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    dual_encoder.load_state_dict(state_dict)
    print(f"Loaded fine-tuned weights from {checkpoint_path}; skipping training.")
else:
    # Re-seed right before iterating so the shuffled batch order starts from a known RNG state.
    set_seed(42)

    for epoch in range(num_epochs):
        total_loss = 0.0
        dual_encoder.train()

        for step, batch in enumerate(train_loader):
            q_input_ids = batch["q_input_ids"].to(device)
            q_attention_mask = batch["q_attention_mask"].to(device)
            c_input_ids = batch["c_input_ids"].to(device)
            c_attention_mask = batch["c_attention_mask"].to(device)

            optimizer.zero_grad()

            q_emb, c_emb = dual_encoder(
                q_input_ids=q_input_ids,
                q_attention_mask=q_attention_mask,
                c_input_ids=c_input_ids,
                c_attention_mask=c_attention_mask
            )

            # Symmetric InfoNCE over in-batch negatives.
            loss = contrastive_loss(q_emb, c_emb, temperature=0.05)
            loss.backward()
            optimizer.step()
            scheduler.step()  # the warmup/decay schedule advances once per batch

            total_loss += loss.item()

            if (step + 1) % 50 == 0:
                avg_loss = total_loss / (step + 1)  # running mean over this epoch so far
                print(f"Epoch {epoch+1} | Step {step+1}/{len(train_loader)} | Avg loss: {avg_loss:.4f}")

        # max(..., 1) avoids ZeroDivisionError if the loader is empty.
        avg_epoch_loss = total_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch+1} finished. Average training loss: {avg_epoch_loss:.4f}")


Epoch 1 | Step 50/27347 | Avg loss: 1.6948
Epoch 1 | Step 100/27347 | Avg loss: 1.6557
Epoch 1 | Step 150/27347 | Avg loss: 1.5713
Epoch 1 | Step 200/27347 | Avg loss: 1.4328
Epoch 1 | Step 250/27347 | Avg loss: 1.2528
Epoch 1 | Step 300/27347 | Avg loss: 1.0843
Epoch 1 | Step 350/27347 | Avg loss: 0.9491
Epoch 1 | Step 400/27347 | Avg loss: 0.8409
Epoch 1 | Step 450/27347 | Avg loss: 0.7539
Epoch 1 | Step 500/27347 | Avg loss: 0.6827
Epoch 1 | Step 550/27347 | Avg loss: 0.6245
Epoch 1 | Step 600/27347 | Avg loss: 0.5757
Epoch 1 | Step 650/27347 | Avg loss: 0.5334
Epoch 1 | Step 700/27347 | Avg loss: 0.4972
Epoch 1 | Step 750/27347 | Avg loss: 0.4659
Epoch 1 | Step 800/27347 | Avg loss: 0.4388
Epoch 1 | Step 850/27347 | Avg loss: 0.4139
Epoch 1 | Step 900/27347 | Avg loss: 0.3917
Epoch 1 | Step 950/27347 | Avg loss: 0.3718
Epoch 1 | Step 1000/27347 | Avg loss: 0.3547
Epoch 1 | Step 1050/27347 | Avg loss: 0.3382
Epoch 1 | Step 1100/27347 | Avg loss: 0.3234
Epoch 1 | Step 1150/27347 | Av

# 6. Build context embedding index for retrieval
Now we encode all contexts with the trained dual encoder to build a retrieval index.

In [11]:
# 6. Build context embedding index for all contexts

class ContextDataset(Dataset):
    """Map-style dataset over every context in a DataFrame, paired with its global id.

    Each item is a dict with ``"context"`` (text) and ``"id"`` (the DataFrame's
    ``id`` column value) keys.
    """

    def __init__(self, df):
        self.contexts, self.ids = df["context"].tolist(), df["id"].tolist()

    def __len__(self):
        return len(self.contexts)

    def __getitem__(self, idx):
        return dict(context=self.contexts[idx], id=self.ids[idx])

context_dataset = ContextDataset(df_all)

def context_collate_fn(batch):
    """Tokenize a batch of contexts and carry their global ids alongside.

    Returns a dict with ``c_input_ids`` / ``c_attention_mask`` (padded to the
    longest context in the batch, truncated to ``max_length`` tokens) and ``ids``
    as an int64 tensor.
    """
    texts, ids = [], []
    for item in batch:
        texts.append(item["context"])
        ids.append(item["id"])

    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return {
        "c_input_ids": encoded["input_ids"],
        "c_attention_mask": encoded["attention_mask"],
        "ids": torch.tensor(ids, dtype=torch.long),
    }

# Keep corpus order (no shuffling) while encoding the index.
context_loader = DataLoader(
    context_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=context_collate_fn,
)


In [12]:
# 6.1 Encode all contexts into a dense [num_contexts, hidden_size] float32 matrix.
# Row i holds the unit-norm embedding of the context whose global id is i.

dual_encoder.eval()

hidden_size = dual_encoder.hidden_size
num_contexts = len(df_all)
context_embs = np.zeros((num_contexts, hidden_size), dtype=np.float32)

with torch.no_grad():
    for batch in context_loader:
        batch_embs = dual_encoder.encode(
            batch["c_input_ids"].to(device),
            batch["c_attention_mask"].to(device),
        )  # [B, H]
        # Scatter each embedding into the row named by its global id.
        context_embs[batch["ids"].numpy()] = batch_embs.cpu().numpy()

context_embs.shape


(273467, 768)

# 7. Self-retrieval evaluation on the test set
We check: for each test question, does the model retrieve its own context in top-k?

In [13]:
# 7. Self-retrieval evaluation: Recall@k and MRR
#
# For each held-out test question, compute the 1-based rank of its own context among
# all indexed contexts. Fully sorting each similarity vector (O(N log N) per query,
# N ~ 273k) just to locate one element is unnecessary. The rank equals
# 1 + (number of contexts scoring strictly higher than the true one), an O(N) count.
# Ties are resolved in the true context's favour, whereas argsort broke them arbitrarily.
# Questions are also encoded and scored in batches instead of one at a time.

dual_encoder.eval()

test_questions = test_df["question"].tolist()
test_true_ids = test_df["id"].tolist()

eval_batch_size = 64  # questions scored per step; lower it if GPU/RAM is tight

rank_chunks = []

with torch.no_grad():
    for start in range(0, len(test_questions), eval_batch_size):
        q_batch = test_questions[start:start + eval_batch_size]
        true_batch = np.asarray(test_true_ids[start:start + eval_batch_size])

        enc = tokenizer(
            q_batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        q_embs = dual_encoder.encode(
            enc["input_ids"].to(device),
            enc["attention_mask"].to(device)
        ).cpu().numpy()  # [B, H]

        sims = q_embs @ context_embs.T  # [B, num_contexts]
        true_sims = sims[np.arange(len(q_batch)), true_batch]  # [B]
        rank_chunks.append((sims > true_sims[:, None]).sum(axis=1) + 1)

ranks = np.concatenate(rank_chunks) if rank_chunks else np.array([], dtype=np.int64)
len(ranks)


27347

In [14]:
# 7.1 Compute Recall@k and MRR

def recall_at_k(ranks, k):
    """Fraction of queries whose true context is ranked within the top ``k``.

    ``ranks`` is a NumPy array of 1-based ranks, one per query.
    """
    hits = ranks <= k
    return np.mean(hits)

recall_1, recall_5, recall_10 = (recall_at_k(ranks, k) for k in (1, 5, 10))

# Mean reciprocal rank of the true context.
mrr = np.mean(1.0 / ranks)

for label, value in (
    ("Recall@1", recall_1),
    ("Recall@5", recall_5),
    ("Recall@10", recall_10),
    ("MRR", mrr),
):
    print(f"{label}: {value:.4f}")


Recall@1:  0.7945
Recall@5:  0.9273
Recall@10: 0.9531
MRR:       0.8545


# 8. Qualitative example
Print the top-3 retrieved contexts for one test question, next to its true context, to judge retrieval quality by eye.

In [15]:
# 8. Qualitative example: show top-3 retrieved contexts for one test question

idx = 0  # change this to look at different examples
q_text, true_id = test_questions[idx], test_true_ids[idx]


def _context_by_id(cid):
    """Return the context text of the df_all row whose "id" equals ``cid``."""
    return df_all.loc[df_all["id"] == cid, "context"].values[0]


print("Question:")
print(q_text)
print("\nTrue context:")
print(_context_by_id(true_id)[:800], "...\n")

enc = tokenizer(q_text, padding=True, truncation=True, max_length=max_length, return_tensors="pt")

with torch.no_grad():
    q_emb = dual_encoder.encode(enc["input_ids"].to(device), enc["attention_mask"].to(device))
q_emb = q_emb.cpu().numpy()[0]  # [hidden_size]

# Cosine similarity against every indexed context (embeddings are unit-norm), best first.
sims = context_embs @ q_emb
ranked_indices = np.argsort(-sims)

print("Top-3 retrieved contexts:")
for rank_pos in range(3):
    cid = ranked_indices[rank_pos]
    score = sims[cid]
    print(f"\nRank {rank_pos+1} | id={cid} | sim={score:.4f}")
    print(_context_by_id(cid)[:800], "...")


Question:
Does molecular Screen identify Cardiac Myosin-Binding Protein-C as a Protein Kinase G-Iα Substrate?

True context:
Pharmacological activation of cGMP-dependent protein kinase G I (PKGI) has emerged as a therapeutic strategy for humans with heart failure. However, PKG-activating drugs have been limited by hypotension arising from PKG-induced vasodilation. PKGIα antiremodeling substrates specific to the myocardium might provide targets to circumvent this limitation, but currently remain poorly understood. We performed a screen for myocardial proteins interacting with the PKGIα leucine zipper (LZ)-binding domain to identify myocardial-specific PKGI antiremodeling substrates. Our screen identified cardiac myosin-binding protein-C (cMyBP-C), a cardiac myocyte-specific protein, which has been demonstrated to inhibit cardiac remodeling in the phosphorylated state, and when mutated leads to hypertrophic cardio ...

Top-3 retrieved contexts:

Rank 1 | id=198815 | sim=0.8612
Cardiac co

# EXAMPLE 1. Simple QA function: retrieve + show evidence

This version:

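    * Takes a question string
    * Finds the `top_k_docs` most similar abstracts in the indexed corpus
    * For each one, shows a short “answer-like” snippet (its first two sentences) plus the beginning of the context, both truncated to `snippet_chars` characters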

In [17]:
def qa_retrieve(question, top_k_docs=3, snippet_chars=400):
    """Retrieve and print the ``top_k_docs`` contexts most similar to ``question``.

    The question is embedded with ``dual_encoder`` and scored against every row of
    ``context_embs``. For each hit, the function prints its id and similarity, an
    "answer-like" snippet (the first two sentences), and the start of the context.
    Both texts are truncated to ``snippet_chars`` characters.

    Returns:
        None; all results are printed.
    """
    dual_encoder.eval()

    enc = tokenizer(
        question,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    with torch.no_grad():
        q_vec = dual_encoder.encode(enc["input_ids"].to(device), enc["attention_mask"].to(device))
    q_vec = q_vec.cpu().numpy()[0]  # [hidden_size]

    scores = context_embs @ q_vec  # [num_contexts]
    order = np.argsort(-scores)  # best match first

    print("QUESTION:")
    print(question)
    print("\nTop retrieved documents:\n")

    for position in range(1, top_k_docs + 1):
        cid = order[position - 1]
        score = scores[cid]
        ctx = df_all.loc[df_all["id"] == cid, "context"].values[0]

        # Naive sentence split on whitespace that follows ., ! or ?; drop empty pieces.
        sentences = [piece.strip() for piece in re.split(r'(?<=[.!?])\s+', ctx) if piece.strip()]

        print(f"=== Document rank {position} | id={cid} | similarity={score:.3f} ===")
        print("Answer-like snippet:")
        print(" ".join(sentences[:2])[:snippet_chars], "...")
        print("\nContext (truncated):")
        print(ctx[:snippet_chars], "...")
        print("\n")


# Demo using a real question from the dataset

You can demo with a question from test_df (or any split), or type your own.

In [18]:
# Example 1: ask the first question of the held-out test set. Its own context is
# in the index, so a good retriever should return it near the top.
example_row = test_df.iloc[0]
demo_question = example_row.at["question"]

qa_retrieve(question=demo_question, top_k_docs=3, snippet_chars=400)


QUESTION:
Does molecular Screen identify Cardiac Myosin-Binding Protein-C as a Protein Kinase G-Iα Substrate?

Top retrieved documents:

=== Document rank 1 | id=198815 | similarity=0.861 ===
Answer-like snippet:
Cardiac contractility is regulated by dynamic phosphorylation of sarcomeric proteins by kinases such as cAMP-activated protein kinase A (PKA). Efficient phosphorylation requires that PKA be anchored close to its targets by A-kinase anchoring proteins (AKAPs). ...

Context (truncated):
Cardiac contractility is regulated by dynamic phosphorylation of sarcomeric proteins by kinases such as cAMP-activated protein kinase A (PKA). Efficient phosphorylation requires that PKA be anchored close to its targets by A-kinase anchoring proteins (AKAPs). Cardiac Myosin Binding Protein-C (cMyBPC) and cardiac troponin I (cTNI) are hypertrophic cardiomyopathy (HCM)-causing sarcomeric proteins wh ...


=== Document rank 2 | id=224497 | similarity=0.717 ===
Answer-like snippet:
Smooth muscle myos

# Then for a fully live demo, we can also type any question:

In [20]:
# Example 2: free-form question typed during the presentation (it does not come
# from the dataset, so there is no "true" context to compare against).
custom_question = "Does metformin help diabetes?"
qa_retrieve(question=custom_question, top_k_docs=3, snippet_chars=400)


QUESTION:
Does metformin help diabetes?

Top retrieved documents:

=== Document rank 1 | id=69408 | similarity=0.865 ===
Answer-like snippet:
To know whether metformin improves postprandial hyperglycaemia, we examined the effect of metformin on the glycated albumin (GA) to glycated haemoglobin (HbA1c) ratio (GA/HbA1c ratio) in patients with newly diagnosed type 2 diabetes. Metformin and lifestyle interventions were initiated in 18 patients with newly diagnosed type 2 diabetes. ...

Context (truncated):
To know whether metformin improves postprandial hyperglycaemia, we examined the effect of metformin on the glycated albumin (GA) to glycated haemoglobin (HbA1c) ratio (GA/HbA1c ratio) in patients with newly diagnosed type 2 diabetes. Metformin and lifestyle interventions were initiated in 18 patients with newly diagnosed type 2 diabetes. Metformin was titrated to 1500 mg/day or maximum-tolerated d ...


=== Document rank 2 | id=174263 | similarity=0.858 ===
Answer-like snippet:
"High dos

# A tiny “ask me anything” loop

In [22]:
def interactive_qa():
    """Prompt for questions in a loop and print retrievals until the user quits.

    An empty line ends the session. The loop also ends cleanly when no interactive
    stdin is available. That covers EOFError and IPython's StdinNotImplementedError,
    a NotImplementedError subclass raised e.g. under ``jupyter nbconvert --execute``.
    An interrupted kernel (KeyboardInterrupt) ends it too, so "Run All" does not
    crash on this cell.
    """
    print("Type a biomedical question (or just press Enter to quit).")
    while True:
        try:
            q = input("\nQ: ").strip()
        except (EOFError, KeyboardInterrupt, NotImplementedError):
            # No usable stdin, or the user interrupted: treat it like an empty question.
            q = ""
        if q == "":
            print("Thank you and see you next time.")
            break
        qa_retrieve(q, top_k_docs=2, snippet_chars=350)

interactive_qa()


Type a biomedical question (or just press Enter to quit).



Q:  does smoking cause heart attack?


QUESTION:
does smoking cause heart attack?

Top retrieved documents:

=== Document rank 1 | id=221678 | similarity=0.763 ===
Answer-like snippet:
Smoking accounts for more than 5 million years of potential life lost per year in the US alone. Leading causes of smoking attributable mortality are acute atherothrombotic complications of coronary heart disease (CHD). ...

Context (truncated):
Smoking accounts for more than 5 million years of potential life lost per year in the US alone. Leading causes of smoking attributable mortality are acute atherothrombotic complications of coronary heart disease (CHD). Smoking cessation is a key issue in preventive medicine, but quantitative data on its benefit for the coronary arteries are sparse. ...


=== Document rank 2 | id=57021 | similarity=0.745 ===
Answer-like snippet:
To validate self-report about smoking cessation with biochemical markers of smoking activity amongst patients with ischaemic heart disease. Outpatients at the Division of Cardio


Q:  


Bye.
Gold article rank for this question: 1


# Save everything (deduplicated DataFrame, context embeddings, fine-tuned encoder weights)

In [23]:
# Persist the three artifacts the retrieval demo depends on:
#   - the question/context/id DataFrame (Parquet or pickle handle large text better than CSV;
#     alternative: df_all.to_pickle("df_all_pubmedqa.pkl")),
#   - the [num_contexts, hidden_size] context embedding matrix,
#   - the fine-tuned dual encoder's state_dict.
artifact_paths = {
    "dataframe": "df_all_pubmedqa.parquet",
    "embeddings": "context_embs_pubmedqa.npy",
    "weights": "dual_encoder_pubmedqa.pt",
}

df_all.to_parquet(artifact_paths["dataframe"])
np.save(artifact_paths["embeddings"], context_embs)
torch.save(dual_encoder.state_dict(), artifact_paths["weights"])

print("Saved: df_all_pubmedqa.parquet, context_embs_pubmedqa.npy, dual_encoder_pubmedqa.pt")


Saved: df_all_pubmedqa.parquet, context_embs_pubmedqa.npy, dual_encoder_pubmedqa.pt
