In [1]:
# Install dependencies
!pip install datasets rank_bm25 transformers nltk scikit-learn sentence-transformers


Collecting datasets
  Using cached datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting rank_bm25
  Using cached rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Collecting transformers
  Using cached transformers-4.51.3-py3-none-any.whl.metadata (38 kB)
Collecting nltk
  Using cached nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)
Collecting scikit-learn
  Downloading scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (18 kB)
Collecting sentence-transformers
  Using cached sentence_transformers-4.1.0-py3-none-any.whl.metadata (13 kB)
Collecting filelock (from datasets)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting numpy>=1.17 (from datasets)
  Downloading numpy-2.2.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Using cached pyarrow-20.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)

In [None]:
from datasets import load_dataset
from rank_bm25 import BM25Okapi        # still used in preprocessing
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from nltk.tokenize import TreebankWordTokenizer
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import torch

### NEW for semantic similarity
from sentence_transformers import SentenceTransformer, util

In [None]:
# Sentence-embedding models already loaded in this kernel, keyed by model name,
# so repeated evaluations do not reload the weights.
_SENTENCE_MODELS = {}


def _load_sentence_model(model_name):
    """Return the SentenceTransformer for ``model_name``, loading it at most once."""
    if model_name not in _SENTENCE_MODELS:
        _SENTENCE_MODELS[model_name] = SentenceTransformer(model_name)
    return _SENTENCE_MODELS[model_name]


def evaluate_semantic_similarity(preds, refs, model_name="all-MiniLM-L6-v2"):
    """Print and return the mean cosine similarity between paired texts.

    Predictions and references are paired positionally; if one sequence is
    longer than the other, the surplus items are ignored (``zip`` semantics).

    Args:
        preds: Iterable of predicted strings.
        refs: Iterable of reference strings.
        model_name: Sentence-Transformers model used to embed the texts.

    Returns:
        The average pairwise cosine similarity as a float, or ``nan`` when
        there are no pairs to compare.
    """
    # Materialise the pairs once so generators work and lengths are aligned.
    pairs = list(zip(preds, refs))
    if not pairs:
        # Nothing to average: report it instead of dividing by zero.
        print("Average Semantic Similarity: n/a (no prediction/reference pairs)")
        return float("nan")

    pred_batch = [p for p, _ in pairs]
    ref_batch = [r for _, r in pairs]

    encoder = _load_sentence_model(model_name)
    # Encode each side as one batch instead of one encode() call per string.
    pred_emb = encoder.encode(pred_batch, convert_to_tensor=True)
    ref_emb = encoder.encode(ref_batch, convert_to_tensor=True)

    # Row-wise similarity: entry i compares pred_batch[i] with ref_batch[i].
    avg_sim = util.pairwise_cos_sim(pred_emb, ref_emb).mean().item()
    print(f"Average Semantic Similarity: {avg_sim:.2f}")
    return avg_sim

In [None]:
# ——— load & split
# Take the "train" split of the pqa_labeled configuration and shuffle it with a fixed seed.
dataset = load_dataset("pubmed_qa", "pqa_labeled")["train"]
dataset = dataset.shuffle(seed=42)
# Hold out 20 % of the examples for testing (seeded as well).
split = dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = split["train"]
test_dataset = split["test"]


In [None]:
# ——— BM25 retriever (unchanged)
# One retrieval document per training example: its context passages joined by spaces.
# NOTE(review): the corpus is built from train_dataset only, so questions from
# test_dataset are matched against training documents — confirm this is intended.
corpus = [
    ' '.join(context_entry['contexts'])
    for context_entry in train_dataset["context"]
]
tokenizer_bm25 = TreebankWordTokenizer()
# Index the corpus on Treebank tokens; queries are tokenized with the same tokenizer.
bm25 = BM25Okapi(list(map(tokenizer_bm25.tokenize, corpus)))
def retrieve_with_bm25(q, k=1):
    """Return the ``k`` corpus documents with the highest BM25 score for ``q``.

    Args:
        q: Query string; tokenized with ``tokenizer_bm25``.
        k: Number of documents to return.

    Returns:
        List of document strings from ``corpus``, best match first. Documents
        with equal scores keep their corpus order.
    """
    query_tokens = tokenizer_bm25.tokenize(q)
    doc_scores = bm25.get_scores(query_tokens)
    # Stable descending sort on the score, carrying the corpus position along.
    ranked = sorted(enumerate(doc_scores), key=lambda pair: pair[1], reverse=True)
    return [corpus[position] for position, _ in ranked[:k]]

In [None]:
# ——— preprocessing
# Answer label -> integer class id, and the inverse lookup.
label2id = {
    'no': 0,
    'yes': 1,
    'maybe': 2,
}
id2label = dict((idx, name) for name, idx in label2id.items())
# Tokenizer of the same PubMedBERT checkpoint that is loaded for classification below.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
)

In [None]:
def preprocess(examples):
    """Tokenize (question, retrieved document) pairs for sequence classification.

    For every question the top BM25 document is fetched with
    ``retrieve_with_bm25`` and encoded together with the question, padded or
    truncated to 512 tokens. Questions for which nothing is retrieved are
    skipped.

    Args:
        examples: Mapping-like object exposing parallel ``'question'`` and
            ``'final_decision'`` sequences.

    Returns:
        dict of equally long lists under ``'input_ids'``, ``'attention_mask'``
        and ``'labels'``. ``'token_type_ids'`` is included as well when the
        tokenizer produces them.
    """
    in_ids, attn, seg_ids, labs = [], [], [], []
    for q, lbl in zip(examples['question'], examples['final_decision']):
        docs = retrieve_with_bm25(q)
        if not docs:
            continue
        enc = tokenizer(q, docs[0], truncation=True, padding='max_length', max_length=512)
        in_ids.append(enc['input_ids'])
        attn.append(enc['attention_mask'])
        # Keep the segment ids of the pair encoding: they mark which tokens
        # belong to the question and which to the document. Dropping them
        # leaves the model to fall back to its default segment ids.
        if 'token_type_ids' in enc:
            seg_ids.append(enc['token_type_ids'])
        labs.append(label2id[lbl.lower()])

    features = {'input_ids': in_ids, 'attention_mask': attn, 'labels': labs}
    # Only expose segment ids when every kept example has them, so all
    # feature lists stay aligned.
    if seg_ids and len(seg_ids) == len(labs):
        features['token_type_ids'] = seg_ids
    return features

# Encode both splits up front: the training split first, then the held-out split.
train_enc, test_enc = map(preprocess, (train_dataset, test_dataset))

In [None]:
class PubMedQADataset(torch.utils.data.Dataset):
    """Map-style dataset over a dict of parallel, per-example feature lists."""

    def __init__(self, enc):
        # enc maps a feature name to a list with one entry per example; it
        # must contain a 'labels' entry, which defines the dataset length.
        self.enc = enc

    def __len__(self):
        return len(self.enc['labels'])

    def __getitem__(self, i):
        # Build one example: every feature's i-th entry converted to a tensor.
        example = {}
        for feature_name, column in self.enc.items():
            example[feature_name] = torch.tensor(column[i])
        return example

# Wrap the encoded splits so the Trainer can index them example by example.
train_ds = PubMedQADataset(train_enc)
test_ds = PubMedQADataset(test_enc)



In [None]:
# ——— classification metrics
def compute_metrics(pred):
    """Compute accuracy and macro-averaged F1 for a prediction object.

    Args:
        pred: Object exposing ``predictions`` (per-class scores, one row per
            example) and ``label_ids`` (true class ids).

    Returns:
        dict with the keys ``'accuracy'`` and ``'f1'``.
    """
    # The predicted class is the index of the highest score in each row.
    predicted_ids = np.argmax(pred.predictions, axis=1)
    true_ids = pred.label_ids
    return {
        'accuracy': accuracy_score(true_ids, predicted_ids),
        'f1': f1_score(true_ids, predicted_ids, average='macro'),
    }


In [None]:
# — PART 1: pretrained evaluation
print("=== PART 1: Pretrained evaluation ===")
# Seed torch before building the model so that any randomly initialised
# weights (e.g. the new 3-way classification head) are the same on every run
# and this baseline is reproducible.
torch.manual_seed(42)
pre_model = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    num_labels=3, id2label=id2label, label2id=label2id
)
pre_trainer = Trainer(model=pre_model, compute_metrics=compute_metrics, eval_dataset=test_ds)

# A single inference pass yields both the metrics and the raw predictions.
# metric_key_prefix="eval" keeps the 'eval_accuracy' / 'eval_f1' keys that the
# final comparison reads from pre_res.
pred_out = pre_trainer.predict(test_ds, metric_key_prefix="eval")
pre_res = pred_out.metrics
print(pre_res)

# get raw preds & refs as texts
pred_ids = np.argmax(pred_out.predictions, axis=1)
ref_ids  = pred_out.label_ids
pred_texts = [id2label[int(i)] for i in pred_ids]
ref_texts  = [id2label[int(i)] for i in ref_ids]

print("Semantic similarity (pretrained):")
evaluate_semantic_similarity(pred_texts, ref_texts)

In [None]:
# — PART 2: finetune & re-evaluate
print("\n=== PART 2: Finetune & evaluate ===")
# Seed torch before building the model so that any randomly initialised
# weights (e.g. the new 3-way classification head) start from the same values
# on every run; the Trainer's own seed is only applied once the Trainer exists.
torch.manual_seed(42)
ft_model = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    num_labels=3, id2label=id2label, label2id=label2id
)
training_args = TrainingArguments(
    output_dir='./results', learning_rate=2e-5,
    per_device_train_batch_size=8, per_device_eval_batch_size=8,
    num_train_epochs=3, weight_decay=0.01, logging_steps=50
)
ft_trainer = Trainer(
    model=ft_model, args=training_args,
    train_dataset=train_ds, compute_metrics=compute_metrics
)
ft_trainer.train()

# A single inference pass yields both the metrics and the raw predictions.
# metric_key_prefix="eval" keeps the 'eval_accuracy' / 'eval_f1' keys that the
# final comparison reads from ft_res.
pred_out2 = ft_trainer.predict(test_ds, metric_key_prefix="eval")
ft_res = pred_out2.metrics
print(ft_res)

# get finetuned preds & refs
pred_ids2 = np.argmax(pred_out2.predictions, axis=1)
ref_ids2  = pred_out2.label_ids
pred_texts2 = [id2label[int(i)] for i in pred_ids2]
ref_texts2  = [id2label[int(i)] for i in ref_ids2]

print("Semantic similarity (finetuned):")
evaluate_semantic_similarity(pred_texts2, ref_texts2)


In [None]:
# — FINAL comparison
# Side-by-side report of the two runs on the same held-out split.
print("\n=== Comparison on test set ===")
pre_acc, ft_acc = pre_res['eval_accuracy'], ft_res['eval_accuracy']
print(f"Accuracy → pretrained: {pre_acc:.4f}, finetuned: {ft_acc:.4f}")
pre_f1, ft_f1 = pre_res['eval_f1'], ft_res['eval_f1']
print(f"   F1    → pretrained: {pre_f1:.4f}, finetuned: {ft_f1:.4f}")
