# Baselines

In [2]:
from src.data import load_omnimed_dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from datasets import load_dataset
import pandas as pd
import tempfile
from transformers import DataCollatorForMultipleChoice
import evaluate
import numpy as np
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
train_df, val_df, test_df = load_omnimed_dataset()

print("Train size:", len(train_df))
print("Validation size:", len(val_df))
print("Test size:", len(test_df))

# Check for image overlap
print("Overlap train-test:", len(set(train_df['image_path']) & set(test_df['image_path'])))
print("Overlap train-val:", len(set(train_df['image_path']) & set(val_df['image_path'])))


Train size: 42380
Validation size: 7472
Test size: 5535
Overlap train-test: 0
Overlap train-val: 0


In [4]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
print("Using device:", device)

Using device: mps


In [5]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract")
model = AutoModelForMultipleChoice.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract")
model = model.to(device)

Some weights of BertForMultipleChoice were not initialized from the model checkpoint at microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
def to_hf_dataset(df:pd.DataFrame):
    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as tmp:
        df.to_csv(tmp.name, index=False)
        dataset = load_dataset('csv', data_files={'data': tmp.name}, split='data')
    return dataset

def preprocess_function(examples):
    option_cols = ["option_A", "option_B", "option_C", "option_D"]
    first_sentences = []
    second_sentences = []
    labels = []

    for i in range(len(examples["question"])):
        question = str(examples["question"][i])
        options = [str(examples[col][i]) for col in option_cols]
        first_sentences.extend([question] * 4)
        second_sentences.extend(options)
        label = option_cols.index(str(examples["gt_label"][i]))
        labels.append(label)

    tokenized = tokenizer(first_sentences, second_sentences, truncation=True)
    result = {k: [v[i:i+4] for i in range(0, len(v), 4)] for k, v in tokenized.items()}
    result["label"] = labels
    return result

In [7]:
train_df

Unnamed: 0,dataset,question_id,question_type,question,image_path,option_A,option_B,option_C,option_D,modality_type,gt_label
0,JSIEC,JSIEC_0046,Disease Diagnosis,What abnormality is present in this fundus image?,./data/OmniMedVQA/Images/JSIEC/0.0.Normal/1ffa...,Choroidal neovascularization,Central serous retinopathy,No Finding,Diabetic retinopathy,Fundus Photography,option_C
1,JSIEC,JSIEC_0047,Disease Diagnosis,What abnormality is present in this fundus image?,./data/OmniMedVQA/Images/JSIEC/0.0.Normal/1ffa...,Macular degeneration,Diabetic retinopathy,Cataracts,No Finding,Fundus Photography,option_D
2,JSIEC,JSIEC_0048,Disease Diagnosis,Is there any abnormality present in this fundu...,./data/OmniMedVQA/Images/JSIEC/0.0.Normal/1ffa...,No Finding,Retinal detachment,Macular degeneration,Diabetic retinopathy,Fundus Photography,option_A
3,JSIEC,JSIEC_0049,Disease Diagnosis,What is the finding in this fundus image?,./data/OmniMedVQA/Images/JSIEC/0.0.Normal/1ffa...,Choroidal neovascularization,Optic neuritis,Glaucoma,No Finding,Fundus Photography,option_D
4,JSIEC,JSIEC_0050,Disease Diagnosis,What abnormality is present in this fundus image?,./data/OmniMedVQA/Images/JSIEC/0.0.Normal/1ffa...,Macular hole,No Finding,Choroidal neovascularization,Diabetic retinopathy,Fundus Photography,option_B
...,...,...,...,...,...,...,...,...,...,...,...
42375,ISBI2016,ISBI2016_0675,Disease Diagnosis,Does this dermoscopic image suggest a conditio...,./data/OmniMedVQA/Images/ISBI2016/ISBI2016_ISI...,Congenital condition.,Autoimmune condition.,No Finding,Malignant condition.,Dermoscopy,option_D
42376,ISBI2016,ISBI2016_0677,Disease Diagnosis,Does this dermoscopic image indicate a benign ...,./data/OmniMedVQA/Images/ISBI2016/ISBI2016_ISI...,Traumatic condition.,Malignant condition.,No Finding,Indeterminate condition.,Dermoscopy,option_B
42377,ISBI2016,ISBI2016_0678,Disease Diagnosis,Does this dermoscopic image suggest a malignan...,./data/OmniMedVQA/Images/ISBI2016/ISBI2016_ISI...,Malignant condition.,No Finding,Pre-cancerous condition.,Inflammatory condition.,Dermoscopy,option_A
42378,ISBI2016,ISBI2016_0679,Disease Diagnosis,Does this dermoscopic image suggest the presen...,./data/OmniMedVQA/Images/ISBI2016/ISBI2016_ISI...,Autoimmune condition.,Pre-cancerous condition.,Malignant condition.,No Finding,Dermoscopy,option_C


In [8]:
hf_test = to_hf_dataset(test_df)
hf_train = to_hf_dataset(train_df)
hf_val = to_hf_dataset(val_df)

Generating data split: 5535 examples [00:00, 264112.32 examples/s]
Generating data split: 42380 examples [00:00, 550681.88 examples/s]
Generating data split: 7472 examples [00:00, 404698.34 examples/s]


In [9]:
remove_columns = ["dataset","question_id","modality_type","question_type","image_path","option_A","option_B","option_C","option_D","gt_label"]
tokenized_train = hf_train.map(preprocess_function, batched=True,remove_columns=remove_columns)
tokenized_val = hf_val.map(preprocess_function, batched=True,remove_columns=remove_columns)
tokenized_test = hf_test.map(preprocess_function, batched=True,remove_columns=remove_columns)

Map:   0%|          | 0/42380 [00:00<?, ? examples/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Map: 100%|██████████| 42380/42380 [00:02<00:00, 20148.67 examples/s]
Map: 100%|██████████| 7472/7472 [00:00<00:00, 20921.23 examples/s]
Map: 100%|██████████| 5535/5535 [00:00<00:00, 23414.71 examples/s]


In [10]:
collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)

In [11]:
accuracy = evaluate.load("accuracy")

In [14]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [15]:
training_args = TrainingArguments(
    output_dir="bert_text_baseline",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    processing_class=tokenizer,
    data_collator=collator,
    compute_metrics=compute_metrics,
    
)

trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.0361,0.03074,0.993442
2,0.0265,0.020184,0.993041
3,0.0181,0.020443,0.995048




TrainOutput(global_step=7947, training_loss=0.03715920427302467, metrics={'train_runtime': 1039.4732, 'train_samples_per_second': 122.312, 'train_steps_per_second': 7.645, 'total_flos': 6667363389790368.0, 'train_loss': 0.03715920427302467, 'epoch': 3.0})

Inference

In [None]:
tokenizer = AutoTokenizer.from_pretrained("model/bert_text_baseline/checkpoint-2649")
model = AutoModelForMultipleChoice.from_pretrained("model/bert_text_baseline/checkpoint-2649")
model = model.to(device)

In [None]:
def batch_infer_and_evaluate(test_df, model, tokenizer, device, batch_size=32):
    option_cols = ["option_A", "option_B", "option_C", "option_D"]
    questions = []
    options = []
    gt_labels = []

    # Prepare batched inputs
    for idx, row in test_df.iterrows():
        q = str(row["question"])
        opts = [str(row[col]) for col in option_cols]
        questions.append([q] * 4)
        options.append(opts)
        gt_labels.append(option_cols.index(str(row["gt_label"])))

    # Flatten for tokenizer batching
    flat_questions = [q for group in questions for q in group]
    flat_options = [o for group in options for o in group]
    all_inputs = tokenizer(flat_questions, flat_options, return_tensors="pt", padding=True, truncation=True)
    # Reshape to (num_examples, num_choices, seq_len)
    for k in all_inputs:
        all_inputs[k] = all_inputs[k].view(len(test_df), 4, -1).to(device)
    print(all_inputs['input_ids'].shape)
    preds = []
    for start in range(0, len(test_df), batch_size):
        end = start + batch_size
        batch = {k: v[start:end] for k, v in all_inputs.items()}
        with torch.no_grad():
            outputs = model(**batch)
            logits = outputs.logits
            batch_preds = logits.argmax(dim=1).cpu().numpy()
            preds.extend(batch_preds)

    gt_labels = np.array(gt_labels)
    accuracy = (preds == gt_labels).mean()
    return accuracy, preds

In [10]:
acc, preds = batch_infer_and_evaluate(test_df, model, tokenizer, device)
print(f"Test accuracy: {acc:.4f}")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


torch.Size([5535, 4, 46])
{'input_ids': tensor([[[    2,  5521, 10281,  ...,     0,     0,     0],
         [    2,  5521, 10281,  ...,     0,     0,     0],
         [    2,  5521, 10281,  ...,     0,     0,     0],
         [    2,  5521, 10281,  ...,     0,     0,     0]],

        [[    2,  5521,  1744,  ...,     0,     0,     0],
         [    2,  5521,  1744,  ...,     0,     0,     0],
         [    2,  5521,  1744,  ...,     0,     0,     0],
         [    2,  5521,  1744,  ...,     0,     0,     0]],

        [[    2,  5521, 10281,  ...,     0,     0,     0],
         [    2,  5521, 10281,  ...,     0,     0,     0],
         [    2,  5521, 10281,  ...,     0,     0,     0],
         [    2,  5521, 10281,  ...,     0,     0,     0]],

        ...,

        [[    2,  5521,  2264,  ...,     0,     0,     0],
         [    2,  5521,  2264,  ...,     0,     0,     0],
         [    2,  5521,  2264,  ...,     0,     0,     0],
         [    2,  5521,  2264,  ...,     0,     0,     

Obtain a accuracy of 0.9926