In [1]:
%pip install transformers
%pip install torch
%pip install datasets



Note: you may need to restart the kernel to use updated packages.


In [2]:
from datasets import load_dataset

ds = load_dataset("qiaojin/PubMedQA", "pqa_labeled")


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import torch

In [6]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

# Load BioBERT tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
model = AutoModelForSequenceClassification.from_pretrained("dmis-lab/biobert-base-cased-v1.1", num_labels=3)  # 3 for yes/no/maybe classification

# Move model to the available device (cuda or cpu)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Load the dataset
ds = load_dataset("qiaojin/PubMedQA", "pqa_labeled")

# Preprocess the dataset
def preprocess_function(examples):
    inputs = examples['question']
    # Use context for more accurate predictions
    context = examples['context']
    # Extract the true answers
    targets = examples['final_decision']
    
    # Combine question and context
    combined_input = [f"Context: {c} Question: {q}" for q, c in zip(inputs, context)]
    
    # Tokenize the inputs and labels
    model_inputs = tokenizer(combined_input, max_length=512, truncation=True, padding="max_length")
    
    # Map the final_decision (yes/no/maybe) to labels
    label_map = {"yes": 0, "no": 1, "maybe": 2}
    labels = [label_map[ans] for ans in targets]
    
    model_inputs["labels"] = labels

    return model_inputs

# Select a small subset of the dataset for demonstration
small_ds = ds['train'].select(range(50))
tokenized_ds = small_ds.map(preprocess_function, batched=True)

# Split the dataset into training, validation, and test sets
train_size = 30
val_size = 10
test_size = 10

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    tokenized_ds, [train_size, val_size, test_size]
)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    weight_decay=0.01,
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

# Train the model
trainer.train()

# Evaluate the model
metrics = trainer.evaluate(eval_dataset=test_dataset)
print("Test set evaluation:", metrics)

# Generate predictions
def prediction_answer(question, context):
    # Combine the question and context into one input
    combined_input = f"Context: {context} Question: {question}"
    
    # Tokenize the input
    inputs = tokenizer(combined_input, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
    inputs = {key: val.to(device) for key, val in inputs.items()}  # Move inputs to the correct device
    
    # Get the model's output
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_label = torch.argmax(logits, dim=1).item()
    
    # Map numeric labels to "yes", "no", or "maybe"
    label_map = {0: "yes", 1: "no", 2: "maybe"}
    return label_map[predicted_label]

num_correct = 0

# Evaluate the Q&A performance
for example in test_dataset:
    question = example['question']
    # Include the context field in the prediction
    context = example['context']
    # Extract true answer from 'final_decision' column
    true_answer = example['final_decision']
    # Use model to predict the answer
    predicted_answer = prediction_answer(question, context)
    
    print(f"Question: {question}")
    print(f"Context: {context}")
    print(f"True Answer: {true_answer}")
    print(f"Predicted Answer: {predicted_answer}")
    print("="*80)
    
    # Calculate accuracy
    if true_answer == predicted_answer:
        num_correct += 1

# Print accuracy
accuracy = num_correct / len(test_dataset)
print(f"Accuracy: {accuracy}")


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at dmis-lab/biobert-base-cased-v1.1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Map: 100%|██████████| 10/10 [00:00<00:00, 209.86 examples/s]
  4%|▍         | 7/180 [10:52<4:28:48, 93.23s/it]
 33%|███▎      | 6/18 [00:11<00:21,  1.78s/it]
[A
[A

[A[A                                       
                                               
 33%|███▎      | 6/18 [00:12<00:21,  1.78s/it]
[A

{'eval_loss': 1.2338898181915283, 'eval_runtime': 0.683, 'eval_samples_per_second': 2.928, 'eval_steps_per_second': 2.928, 'epoch': 1.0}


 67%|██████▋   | 12/18 [00:22<00:10,  1.68s/it]
[A
[A

[A[A                                       
                                               
 67%|██████▋   | 12/18 [00:22<00:10,  1.68s/it]
[A

{'eval_loss': 1.3380873203277588, 'eval_runtime': 0.6236, 'eval_samples_per_second': 3.207, 'eval_steps_per_second': 3.207, 'epoch': 2.0}


100%|██████████| 18/18 [00:32<00:00,  1.62s/it]
[A
[A

[A[A                                       
                                               
100%|██████████| 18/18 [00:34<00:00,  1.62s/it]
[A
100%|██████████| 18/18 [00:34<00:00,  1.91s/it]


{'eval_loss': 1.3831374645233154, 'eval_runtime': 0.6, 'eval_samples_per_second': 3.333, 'eval_steps_per_second': 3.333, 'epoch': 3.0}
{'train_runtime': 34.3542, 'train_samples_per_second': 0.524, 'train_steps_per_second': 0.524, 'train_loss': 1.1432906256781683, 'epoch': 3.0}


100%|██████████| 2/2 [00:00<00:00,  6.37it/s]


Test set evaluation: {'eval_loss': 1.0763216018676758, 'eval_runtime': 0.6505, 'eval_samples_per_second': 3.074, 'eval_steps_per_second': 3.074, 'epoch': 3.0}
Question: Syncope during bathing in infants, a pediatric form of water-induced urticaria?
Context: {'contexts': ['Apparent life-threatening events in infants are a difficult and frequent problem in pediatric practice. The prognosis is uncertain because of risk of sudden infant death syndrome.', 'Eight infants aged 2 to 15 months were admitted during a period of 6 years; they suffered from similar maladies in the bath: on immersion, they became pale, hypotonic, still and unreactive; recovery took a few seconds after withdrawal from the bath and stimulation. Two diagnoses were initially considered: seizure or gastroesophageal reflux but this was doubtful. The hypothesis of an equivalent of aquagenic urticaria was then considered; as for patients with this disease, each infant\'s family contained members suffering from dermographism