## Evaluation: MSFT model

This is notebook for **MSFT model evaluation** for estimating improvements over the base model and compare with other adapted models. We choose to evaluate on the set of benchmarks from [Open Medical-LLM Leaderboard](https://huggingface.co/spaces/openlifescienceai/open_medical_llm_leaderboard) including:

* [MedMCQA](https://huggingface.co/datasets/openlifescienceai/medmcqa) - MCQ, 200 samples from validation split
* [MedQA](https://huggingface.co/datasets/GBaker/MedQA-USMLE-4-options-hf) - MCQ, 200 samples from validation split
* [MMLU](https://huggingface.co/datasets/cais/mmlu) - MCQ, 200 samples from test splits of 6 medical subsets
* [PubMedQA](https://huggingface.co/datasets/qiaojin/PubMedQA) - QA, 200 samples from train split of pqa_labeled subset

*MSFT model:* [MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-msft-merged](https://huggingface.co/MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-msft-merged)

### Setup

In [1]:
%%capture
!pip install datasets vllm

In [2]:
import re
from tqdm import tqdm
import math
import pandas as pd
from datasets import load_dataset, concatenate_datasets
from vllm import LLM, SamplingParams
import torch

INFO 04-13 22:07:24 [__init__.py:239] Automatically detected platform cuda.


2025-04-13 22:07:26.879644: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744582047.119015      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744582047.187703      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


### MSFT model loading

In [3]:
MODEL_NAME = "MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-msft-merged"
MAX_TOKENS = 4096

In [4]:
print("\n--- Loading LLM with vLLM ---")
try:
    llm = LLM(
        model=MODEL_NAME,
        tensor_parallel_size=2,
        dtype=torch.float16,
    )
    sampling_params = SamplingParams(
        max_tokens=MAX_TOKENS,
        temperature=0.01,
        top_p=1.0,
        top_k=-1
    )
    print(f"LLM '{MODEL_NAME}' loaded successfully.")
except Exception as e:
    print(f"Error loading LLM with vLLM: {e}")
    print("Please ensure the MODEL_NAME is correct, vLLM is installed, and you have compatible hardware (GPU).")
    exit()


--- Loading LLM with vLLM ---


config.json:   0%|          | 0.00/820 [00:00<?, ?B/s]

INFO 04-13 22:07:54 [config.py:600] This model supports multiple tasks: {'generate', 'reward', 'score', 'embed', 'classify'}. Defaulting to 'generate'.
INFO 04-13 22:07:54 [config.py:1600] Defaulting to use mp for distributed inference
INFO 04-13 22:07:54 [config.py:1780] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 04-13 22:07:54 [llm_engine.py:242] Initializing a V0 LLM engine (v0.8.3) with config: model='MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-msft-merged', speculative_config=None, tokenizer='MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-msft-merged', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_c

tokenizer_config.json:   0%|          | 0.00/7.09k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/495 [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

[1;36m(VllmWorkerProcess pid=145)[0;0m INFO 04-13 22:07:56 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-13 22:07:58 [cuda.py:240] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
INFO 04-13 22:07:58 [cuda.py:289] Using XFormers backend.
[1;36m(VllmWorkerProcess pid=145)[0;0m INFO 04-13 22:07:58 [cuda.py:240] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
[1;36m(VllmWorkerProcess pid=145)[0;0m INFO 04-13 22:07:58 [cuda.py:289] Using XFormers backend.


[W413 22:08:09.948266599 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W413 22:08:10.409179607 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W413 22:08:19.958967107 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3


INFO 04-13 22:08:29 [utils.py:990] Found nccl from library libnccl.so.2
INFO 04-13 22:08:29 [pynccl.py:69] vLLM is using nccl==2.21.5
[1;36m(VllmWorkerProcess pid=145)[0;0m INFO 04-13 22:08:29 [utils.py:990] Found nccl from library libnccl.so.2
[1;36m(VllmWorkerProcess pid=145)[0;0m INFO 04-13 22:08:29 [pynccl.py:69] vLLM is using nccl==2.21.5


[W413 22:08:29.969502460 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3


INFO 04-13 22:08:30 [custom_all_reduce_utils.py:206] generating GPU P2P access cache in /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
INFO 04-13 22:08:54 [custom_all_reduce_utils.py:244] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
[1;36m(VllmWorkerProcess pid=145)[0;0m INFO 04-13 22:08:54 [custom_all_reduce_utils.py:244] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
INFO 04-13 22:08:54 [shm_broadcast.py:264] vLLM message queue communication handle: Handle(local_reader_ranks=[1], buffer_handle=(1, 4194304, 6, 'psm_059d96f6'), local_subscribe_addr='ipc:///tmp/691c578d-0c95-45c7-9457-e56e736c8811', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 04-13 22:08:54 [parallel_state.py:957] rank 0 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 0
[1;36m(VllmWorkerProcess pid=145)[0;0m INFO 04-13 22:08:54 [parallel_state.py:957] rank 1 in world size 2 is assigned as DP rank 0, PP rank

model.safetensors:   0%|          | 0.00/3.55G [00:00<?, ?B/s]

INFO 04-13 22:09:03 [weight_utils.py:281] Time spent downloading weights for MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-msft-merged: 8.771985 seconds
INFO 04-13 22:09:03 [weight_utils.py:315] No model.safetensors.index.json found in remote.


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 04-13 22:09:05 [loader.py:447] Loading weights took 2.17 seconds
INFO 04-13 22:09:06 [model_runner.py:1146] Model loading took 1.6901 GiB and 11.254775 seconds
[1;36m(VllmWorkerProcess pid=145)[0;0m INFO 04-13 22:09:10 [weight_utils.py:281] Time spent downloading weights for MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-msft-merged: 6.806367 seconds
[1;36m(VllmWorkerProcess pid=145)[0;0m INFO 04-13 22:09:10 [weight_utils.py:315] No model.safetensors.index.json found in remote.
[1;36m(VllmWorkerProcess pid=145)[0;0m INFO 04-13 22:09:13 [loader.py:447] Loading weights took 3.17 seconds
[1;36m(VllmWorkerProcess pid=145)[0;0m INFO 04-13 22:09:14 [model_runner.py:1146] Model loading took 1.6901 GiB and 19.162172 seconds
[1;36m(VllmWorkerProcess pid=145)[0;0m INFO 04-13 22:09:24 [worker.py:267] Memory profiling takes 9.39 seconds
[1;36m(VllmWorkerProcess pid=145)[0;0m INFO 04-13 22:09:24 [worker.py:267] the current vLLM instance can use total_gpu_memory (14.74GiB) x gpu_m

Capturing CUDA graph shapes:  97%|█████████▋| 34/35 [00:40<00:01,  1.07s/it]

[1;36m(VllmWorkerProcess pid=145)[0;0m INFO 04-13 22:10:13 [model_runner.py:1598] Graph capturing finished in 43 secs, took 0.40 GiB


Capturing CUDA graph shapes: 100%|██████████| 35/35 [00:42<00:00,  1.21s/it]

INFO 04-13 22:10:13 [model_runner.py:1598] Graph capturing finished in 42 secs, took 0.40 GiB
INFO 04-13 22:10:13 [llm_engine.py:448] init engine (profile, create kv cache, warmup model) took 59.67 seconds





LLM 'MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-msft-merged' loaded successfully.


### 1. MedMCQA benchmark

#### Dataset loading and preparing

In [5]:
SEED = 4242
BATCH_SIZE = 4
NUM_SAMPLES = 200
DATASET_MEDMCQA = "openlifescienceai/medmcqa"

In [6]:
ds_medmcqa = load_dataset(DATASET_MEDMCQA, split="validation")
ds_medmcqa = ds_medmcqa.shuffle(seed=SEED).select(range(NUM_SAMPLES))
ds_medmcqa

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/85.9M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/936k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/1.48M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/182822 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6150 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4183 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'question', 'opa', 'opb', 'opc', 'opd', 'cop', 'choice_type', 'exp', 'subject_name', 'topic_name'],
    num_rows: 200
})

In [7]:
ds_medmcqa[0]

{'id': '4653fb7a-ddbf-493b-b4ef-92205582a27a',
 'question': 'Which of the following tooth is not having 5 cusps?',
 'opa': 'Mandibular 2nd Molar',
 'opb': 'Mandibular 1st Molar',
 'opc': 'Mandibular 3rd Molar',
 'opd': 'Maxillary 1st Molar',
 'cop': 0,
 'choice_type': 'single',
 'exp': None,
 'subject_name': 'Dental',
 'topic_name': None}

#### Helper functions definition

In [8]:
def format_prompt_medmcqa(example):
    """Formats a single example into a prompt for the LLM."""
    question = example['question']
    options = {
        "A": example['opa'],
        "B": example['opb'],
        "C": example['opc'],
        "D": example['opd']
    }
    
    prompt = f"""
You are an expert in solving multiple-choice questions accurately and explaining your reasoning clearly.
Given a question and a list of answer choices (A, B, C, D), your task is to:
1. Reason shortly about the question and answer choices to find evidances to support your answer.
2. Identify the correct answer. Please choose the single best answer from the options provided.
3. Output the final answer in the format: Answer: [Option Letter]

Question: {question}
Options:
A. {options['A']}
B. {options['B']}
C. {options['C']}
D. {options['D']}

Reasoning:
    """
    return prompt

In [9]:
def get_ground_truth_medmcqa(example):
    """Maps the correct option index (cop) to the corresponding letter."""
    mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}
    cop_index = example.get('cop')
    if cop_index is None or cop_index not in mapping:
        print(f"Warning: Invalid 'cop' value found: {cop_index} in example ID {example.get('id')}. Skipping ground truth.")
        return None
    return mapping[cop_index]

In [5]:
def extract_choice_mcq(generated_text):
    """Extracts the predicted choice (A, B, C, or D) from the LLM's output."""
    text = generated_text.strip()

    # Check for phrases like "The answer is A" or "Answer: A"
    match = re.search(r'(?:answer|choice|option) is\s*:?\s*([A-D])', text, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Look for the first standalone letter A, B, C, or D in the text
    match = re.search(r'\b([A-D])\b', text)
    if match:
        return match.group(1).upper()

    # Fallback - If no clear choice found, return None
    print(f"Warning: Could not extract answer from text: '{text[:100]}...{text[-100:]}'")
    return None

#### Evaluation

In [11]:
print("\n--- Preparing Prompts and Ground Truths ---")
prompts = [format_prompt_medmcqa(ex) for ex in tqdm(ds_medmcqa, desc="Formatting prompts")]
ground_truths = [get_ground_truth_medmcqa(ex) for ex in tqdm(ds_medmcqa, desc="Extracting ground truths")]
valid_indices = [i for i, gt in enumerate(ground_truths) if gt is not None]

if len(valid_indices) < len(ground_truths):
     print(f"Warning: {len(ground_truths) - len(valid_indices)} examples had invalid ground truths and were excluded.")
     prompts = [prompts[i] for i in valid_indices]
     ground_truths = [ground_truths[i] for i in valid_indices]
     original_indices = valid_indices

if len(prompts) > 0:
    print("\nExample Prompt:")
    print(prompts[0])
    print(f"Corresponding Ground Truth: {ground_truths[0]}")
else:
    print("No valid prompts to evaluate.")
    exit()


--- Preparing Prompts and Ground Truths ---


Formatting prompts: 100%|██████████| 200/200 [00:00<00:00, 6592.43it/s]
Extracting ground truths: 100%|██████████| 200/200 [00:00<00:00, 7645.12it/s]


Example Prompt:

You are an expert in solving multiple-choice questions accurately and explaining your reasoning clearly.
Given a question and a list of answer choices (A, B, C, D), your task is to:
1. Reason shortly about the question and answer choices to find evidances to support your answer.
2. Identify the correct answer. Please choose the single best answer from the options provided.
3. Output the final answer in the format: Answer: [Option Letter]

Question: Which of the following tooth is not having 5 cusps?
Options:
A. Mandibular 2nd Molar
B. Mandibular 1st Molar
C. Mandibular 3rd Molar
D. Maxillary 1st Molar

Reasoning:
    
Corresponding Ground Truth: A





In [12]:
print("\n--- Running Inference ---")
all_outputs_text = []
num_batches = math.ceil(len(prompts) / BATCH_SIZE)

for i in tqdm(range(num_batches), desc="Generating Responses"):
    start_idx = i * BATCH_SIZE
    end_idx = min((i + 1) * BATCH_SIZE, len(prompts))
    batch_prompts = prompts[start_idx:end_idx]
    outputs = llm.generate(batch_prompts, sampling_params, use_tqdm=False)
    batch_outputs_text = [output.outputs[0].text.strip() for output in outputs]
    all_outputs_text.extend(batch_outputs_text)

if len(all_outputs_text) > 0:
    print("\nExample Generated Text (raw):")
    print(all_outputs_text[0])


--- Running Inference ---


Generating Responses: 100%|██████████| 50/50 [34:00<00:00, 40.80s/it]


Example Generated Text (raw):
- The mandibular 2nd molar is the most common type of molar in the lower jaw.
     - The mandibular 1st molar is the smallest molar in the lower jaw.
     - The mandibular 3rd molar is the largest molar in the lower jaw.
     - The maxillary 1st molar is the smallest molar in the upper jaw.

Options:
A. Mandibular 2nd Molar
B. Mandibular 1st Molar
C. Mandibular 3rd Molar
D. Maxillary 1st Molar

Options:
A. Mandibular 2nd Molar
B. Mandibular 1st Molar
C. Mandibular 3rd Molar
D. Maxillary 1st Molar

Options:
A. Mandibular 2nd Molar
B. Mandibular 1st Molar
C. Mandibular 3rd Molar
D. Maxillary 1st Molar

Options:
A. Mandibular 2nd Molar
B. Mandibular 1st Molar
C. Mandibular 3rd Molar
D. Maxillary 1st Molar

Options:
A. Mandibular 2nd Molar
B. Mandibular 1st Molar
C. Mandibular 3rd Molar
D. Maxillary 1st Molar

Options:
A. Mandibular 2nd Molar
B. Mandibular 1st Molar
C. Mandibular 3rd Molar
D. Maxillary 1st Molar

Options:
A. Mandibular 2nd Molar
B. Mandibular




In [13]:
print("\n--- Extracting Predictions ---")
predictions = [extract_choice_mcq(text) for text in tqdm(all_outputs_text, desc="Extracting choices")]
num_invalid_responces = predictions.count(None)
print(f"\n------------------------------\nNumber of invalid responces: {num_invalid_responces}")

if len(predictions) > 0:
    print("\nExample Extracted Prediction:")
    print(predictions[0])


--- Extracting Predictions ---


Extracting choices: 100%|██████████| 200/200 [00:00<00:00, 3733.45it/s]

     - Sella is a city in France.
     - Porion is a city in Fra...med after the horizontal plane of the earth.
     - The plane is used to create a horizontal surface'
     - The maxillary sinus is the first to develop in the face.
     - The ethmoidal sinus is the'
     - Interleukins are encoded by the IL-1 ...me 19.
     - The ODA gene is located on chromosome 19.
     - The IgG gene is located on chromosome'
     - They ... - The muscles of mastication are not responsible for the production of mucus.
     - The muscles of'
...both.
     - The study was conducted by Turku, so it's likely related to both conditions.
     - The'
     - In Goodpasture's...eart is enlarged and the lungs are enlarged.
     - In Goodpasture's syndrome, the heart is enlarged'
     - Candida...illus.
     - The most common fungal infection in the eye in an HIV positive patient is Aspergillus.'
     - The oculomotor nerve is respo...sed by a lesion of the oculomotor nerve.
     - Ptosis is caused by a lesio




In [14]:
print("\n--- Calculating Metrics ---")
correct_count = 0
total_count = len(predictions)
results_by_subject = {}

if total_count != len(ground_truths):
     print(f"Warning: Mismatch between number of predictions ({total_count}) and ground truths ({len(ground_truths)}). This should not happen.")
     total_count = min(total_count, len(ground_truths))

for i in range(total_count):
    original_data_index = original_indices[i] if 'original_indices' in locals() else i
    data_item = ds_medmcqa[original_data_index]
    subject = data_item.get('subject_name', 'Unknown')

    pred = predictions[i]
    truth = ground_truths[i]
    is_correct = (pred == truth)

    if subject not in results_by_subject:
        results_by_subject[subject] = {'correct': 0, 'total': 0}

    if is_correct:
        correct_count += 1
        results_by_subject[subject]['correct'] += 1
    results_by_subject[subject]['total'] += 1

overall_accuracy = (correct_count / total_count) * 100 if total_count > 0 else 0


--- Calculating Metrics ---


In [15]:
print("\n--- Evaluation Results ---")
print(f"Model Evaluated: {MODEL_NAME}")
print(f"Dataset Used: {DATASET_MEDMCQA}")
print(f"Number of Questions Evaluated: {total_count}")
print(f"Number of Correct Answers: {correct_count}")
print(f"Overall Accuracy: {overall_accuracy:.2f}%")

print("\nAccuracy by Subject:")
sorted_subjects = sorted(results_by_subject.keys())
for subject in sorted_subjects:
    counts = results_by_subject[subject]
    sub_acc = (counts['correct'] / counts['total']) * 100 if counts['total'] > 0 else 0
    print(f"- {subject}: {sub_acc:.2f}% ({counts['correct']}/{counts['total']})")


--- Evaluation Results ---
Model Evaluated: MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-msft-merged
Dataset Used: openlifescienceai/medmcqa
Number of Questions Evaluated: 200
Number of Correct Answers: 46
Overall Accuracy: 23.00%

Accuracy by Subject:
- Anaesthesia: 0.00% (0/2)
- Anatomy: 0.00% (0/6)
- Biochemistry: 25.00% (2/8)
- Dental: 22.39% (15/67)
- ENT: 40.00% (2/5)
- Forensic Medicine: 57.14% (4/7)
- Gynaecology & Obstetrics: 47.06% (8/17)
- Medicine: 16.67% (1/6)
- Microbiology: 16.67% (1/6)
- Ophthalmology: 0.00% (0/4)
- Pathology: 25.00% (3/12)
- Pediatrics: 7.14% (1/14)
- Pharmacology: 16.67% (2/12)
- Physiology: 0.00% (0/6)
- Radiology: 50.00% (1/2)
- Skin: 0.00% (0/1)
- Social & Preventive Medicine: 33.33% (2/6)
- Surgery: 21.05% (4/19)


### 2. MedQA

#### Dataset loading and preparing

In [5]:
SEED = 4242
BATCH_SIZE = 4
NUM_SAMPLES = 200
DATASET_MEDQA = "GBaker/MedQA-USMLE-4-options-hf"
SPLIT_MEDQA = "validation"

In [6]:
ds_medqa = load_dataset(DATASET_MEDQA, split=SPLIT_MEDQA)
ds_medqa = ds_medqa.shuffle(seed=SEED).select(range(NUM_SAMPLES))
ds_medqa

README.md:   0%|          | 0.00/640 [00:00<?, ?B/s]

train.json:   0%|          | 0.00/9.77M [00:00<?, ?B/s]

dev.json:   0%|          | 0.00/1.22M [00:00<?, ?B/s]

test.json:   0%|          | 0.00/1.25M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10178 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1272 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1273 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'sent1', 'sent2', 'ending0', 'ending1', 'ending2', 'ending3', 'label'],
    num_rows: 200
})

In [7]:
ds_medqa[1]

{'id': 'dev-00646',
 'sent1': 'A 31-year-old gravida 2 para 2 woman presents to her primary care physician for follow up. Two weeks ago, she gave birth via vaginal delivery to a 9.5 lb (4.3 kg) male infant. The delivery was complicated by a vaginal laceration that required extensive suturing once the infant was delivered. Immediately after delivery of the placenta she experienced intense shaking and chills that resolved within 1 hour. She has felt well since the delivery but admits to 6 days of malodorous smelling vaginal discharge that is tan in color. She has a history of vaginal candidiasis and is worried that it may be recurring. Her temperature is 98.8°F (37.1°C), blood pressure is 122/73 mmHg, pulse is 88/min, respirations are 16/min, and BMI is 33 kg/m^2. Speculum exam reveals a 1.5 cm dark red, velvety lesion on the posterior vaginal wall with a tan discharge. The pH of the discharge is 6.4. Which of the following is the most likely diagnosis?',
 'sent2': '',
 'ending0': 'Bacte

#### Helper functions definition

In [8]:
def format_prompt_medqa(example):
    """Formats a single example into a prompt for the LLM."""
    question = example['sent1']
    options = {
        "A": example['ending0'],
        "B": example['ending1'],
        "C": example['ending2'],
        "D": example['ending3'],
    }
    
    prompt = f"""
You are an expert in solving multiple-choice questions accurately and explaining your reasoning clearly.
Given a question and a list of answer choices (A, B, C, D), your task is to:
1. Reason shortly about the question and answer choices to find evidances to support your answer.
2. Identify the correct answer. Please choose the single best answer from the options provided.
3. Output the final answer in the format: Answer: [Option Letter]

Question: {question}
Options:
A. {options['A']}
B. {options['B']}
C. {options['C']}
D. {options['D']}

Reasoning:
    """
    return prompt

In [9]:
def get_ground_truth_medqa(example):
    """Maps the label to the corresponding letter."""
    mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}
    label = example.get('label')
    if label is None or label not in mapping:
        print(f"Warning: Invalid 'cop' value found: {label} in example ID {example.get('id')}. Skipping ground truth.")
        return None
    return mapping[label]

#### Evaluation

In [10]:
print("\n--- Preparing Prompts and Ground Truths ---")
prompts = [format_prompt_medqa(ex) for ex in tqdm(ds_medqa, desc="Formatting prompts")]
ground_truths = [get_ground_truth_medqa(ex) for ex in tqdm(ds_medqa, desc="Extracting ground truths")]
valid_indices = [i for i, gt in enumerate(ground_truths) if gt is not None]

if len(valid_indices) < len(ground_truths):
     print(f"Warning: {len(ground_truths) - len(valid_indices)} examples had invalid ground truths and were excluded.")
     prompts = [prompts[i] for i in valid_indices]
     ground_truths = [ground_truths[i] for i in valid_indices]
     original_indices = valid_indices

if len(prompts) > 0:
    print("\nExample Prompt:")
    print(prompts[0])
    print(f"Corresponding Ground Truth: {ground_truths[0]}")
else:
    print("No valid prompts to evaluate.")
    exit()


--- Preparing Prompts and Ground Truths ---


Formatting prompts: 100%|██████████| 200/200 [00:00<00:00, 8167.27it/s]
Extracting ground truths: 100%|██████████| 200/200 [00:00<00:00, 9634.33it/s]


Example Prompt:

You are an expert in solving multiple-choice questions accurately and explaining your reasoning clearly.
Given a question and a list of answer choices (A, B, C, D), your task is to:
1. Reason shortly about the question and answer choices to find evidances to support your answer.
2. Identify the correct answer. Please choose the single best answer from the options provided.
3. Output the final answer in the format: Answer: [Option Letter]

Question: A 9-year-old girl is brought to the physician by her father for evaluation of intermittent muscle cramps for the past year and short stature. She has had recurrent upper respiratory tract infections since infancy. She is at the 5th percentile for weight and 10th percentile for height. Physical examination shows nasal polyps and dry skin. An x-ray of the right wrist shows osteopenia with epiphyseal widening. Which of the following sets of laboratory findings is most likely in this patient's serum?
 $$$ Calcium %%% Phosphorus




In [11]:
print("\n--- Running Inference ---")
all_outputs_text = []
num_batches = math.ceil(len(prompts) / BATCH_SIZE)

for i in tqdm(range(num_batches), desc="Generating Responses"):
    start_idx = i * BATCH_SIZE
    end_idx = min((i + 1) * BATCH_SIZE, len(prompts))
    batch_prompts = prompts[start_idx:end_idx]
    outputs = llm.generate(batch_prompts, sampling_params, use_tqdm=False)
    batch_outputs_text = [output.outputs[0].text.strip() for output in outputs]
    all_outputs_text.extend(batch_outputs_text)

if len(all_outputs_text) > 0:
    print("\nExample Generated Text (raw):")
    print(all_outputs_text[0])


--- Running Inference ---


Generating Responses: 100%|██████████| 50/50 [38:55<00:00, 46.72s/it]


Example Generated Text (raw):
**Step 1:** The patient is a 9-year-old girl with recurrent muscle cramps, short stature, and a history of upper respiratory infections. These symptoms are suggestive of a condition known as hyperparathyroidism. Hyperparathyroidism is characterized by increased parathyroid hormone (PTH) levels, which leads to increased calcitonin production and subsequent hypercalcemia. This condition is often associated with conditions like hyperparathyroidism, which can also cause bone density issues, such as osteopenia.

The x-ray findings of the right wrist show osteopenia with epiphyseal widening. This suggests that the bone is not being replaced as it should be, which is consistent with hyperparathyroidism. In hyperparathyroidism, PTH is increased, leading to the accumulation of calcitriol, which is a calcium phosphate buffer. This buffer helps maintain calcium levels in the blood, but its accumulation can lead to hypercalcemia.

Now, let's consider the serum levels




In [14]:
print("\n--- Extracting Predictions ---")
predictions = [extract_choice_mcq(text) for text in tqdm(all_outputs_text, desc="Extracting choices")]
num_invalid_responces = predictions.count(None)
print(f"\n------------------------------\nNumber of invalid responces: {num_invalid_responces}")

if len(predictions) > 0:
    print("\nExample Extracted Prediction:")
    print(predictions[0])


--- Extracting Predictions ---



Extracting choices: 100%|██████████| 200/200 [00:00<00:00, 2246.60it/s]

     - This inhi...sm of action of amiodarone is to inhibit CYP2C9, which is responsible for the conversion of prothrom'
     <backing up>
     <backing up>
     <backing up>
     <backing up>
     <backing u...backing up>
     <backing up>
     <backing up>
     <backing up>
     <backing up>
     <backing up'
     - The thoracentesis'
     <p>As the patient's blood pressure increases, the heart's ability to pump blood effective...d through the ventricular contractility and the ventricular filling time.</p>
     <p>As the patient'
...nal pelvises.
     - The major and minor calyces are also responsible for the formation of the renal'
         <div>
             <div>
                 <div>
                     <div>
          ...                         <div>
                                 <div>
                         <div>'
     - Mesenchymal cells ar...ic epithelial cells.

The HLA-DP and HLA-DQ genes are expressed on Eosinophils.
The HLA-DRα and HLA-'
...entation and lab result




In [15]:
print("\n--- Calculating Metrics ---")
correct_count = 0
total_count = len(predictions)
results_by_subject = {}

if total_count != len(ground_truths):
     print(f"Warning: Mismatch between number of predictions ({total_count}) and ground truths ({len(ground_truths)}). This should not happen.")
     total_count = min(total_count, len(ground_truths))

for i in range(total_count):
    original_data_index = original_indices[i] if 'original_indices' in locals() else i
    data_item = ds_medqa[original_data_index]
    subject = data_item.get('subject_name', 'Unknown')

    pred = predictions[i]
    truth = ground_truths[i]
    is_correct = (pred == truth)

    if subject not in results_by_subject:
        results_by_subject[subject] = {'correct': 0, 'total': 0}

    if is_correct:
        correct_count += 1
        results_by_subject[subject]['correct'] += 1
    results_by_subject[subject]['total'] += 1

overall_accuracy = (correct_count / total_count) * 100 if total_count > 0 else 0


--- Calculating Metrics ---


In [16]:
print("\n--- Evaluation Results ---")
print(f"Model Evaluated: {MODEL_NAME}")
print(f"Dataset Used: {DATASET_MEDQA}")
print(f"Number of Questions Evaluated: {total_count}")
print(f"Number of Correct Answers: {correct_count}")
print(f"Overall Accuracy: {overall_accuracy:.2f}%")


--- Evaluation Results ---
Model Evaluated: MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-msft-merged
Dataset Used: GBaker/MedQA-USMLE-4-options-hf
Number of Questions Evaluated: 200
Number of Correct Answers: 56
Overall Accuracy: 28.00%


### 3. MMLU medical

#### Dataset loading and preparing

In [6]:
SEED = 4242
BATCH_SIZE = 4
NUM_SAMPLES_SUBSET = 50
NUM_SAMPLES = 200
DATASET_MMLU = "cais/mmlu"
SPLIT_MMLU = "test"

MMLU_MEDICAL_SUBSETS = [
    "anatomy",
    "clinical_knowledge",
    "professional_medicine",
    "college_biology",
    "college_medicine",
    "medical_genetics",
    "professional_medicine"
]

In [7]:
datasets_mmlu = []
for subset in MMLU_MEDICAL_SUBSETS:
    ds = load_dataset(DATASET_MMLU, subset, split=SPLIT_MMLU)
    ds = ds.shuffle(seed=SEED).select(range(NUM_SAMPLES_SUBSET))
    datasets_mmlu.append(ds)


ds_mmlu = concatenate_datasets(datasets_mmlu)
ds_mmlu = ds_mmlu.shuffle(seed=SEED).select(range(NUM_SAMPLES))
ds_mmlu

README.md:   0%|          | 0.00/53.2k [00:00<?, ?B/s]

dataset_infos.json:   0%|          | 0.00/138k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.1k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/5.28k [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/3.50k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/135 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/14 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/40.5k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/7.48k [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/3.67k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/265 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/29 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/125k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/19.9k [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/8.45k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/272 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/31 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/31.8k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/6.90k [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/4.27k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/144 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/16 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/42.5k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/8.99k [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/4.84k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/173 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/16.4k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/5.63k [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/3.77k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Dataset({
    features: ['question', 'subject', 'choices', 'answer'],
    num_rows: 200
})

In [8]:
ds_mmlu[0]

{'question': 'Mitochondria isolated and placed in a buffered solution with a low pH begin to manufacture ATP. Which of the following is the best explanation for the effect of low external pH?',
 'subject': 'college_biology',
 'choices': ['It increases the concentration of OH-, causing the mitochondria to pump H+ to the intermembrane space.',
  'It increases the OH- concentration in the mitochondria matrix.',
  'It increases the acid concentration in the mitochondria matrix.',
  'It increases diffusion of H+ from the intermembrane space to the matrix.'],
 'answer': 3}

#### Helper functions definition

In [9]:
def format_prompt_mmlu(example):
    """Formats a single example into a prompt for the LLM."""
    question = example['question']
    options = {
        "A": example['choices'][0],
        "B": example['choices'][1],
        "C": example['choices'][2],
        "D": example['choices'][3]
    }
    
    prompt = f"""
You are an expert in solving multiple-choice questions accurately and explaining your reasoning clearly.
Given a question and a list of answer choices (A, B, C, D), your task is to:
1. Reason shortly about the question and answer choices to find evidances to support your answer.
2. Identify the correct answer. Please choose the single best answer from the options provided.
3. Output the final answer in the format: Answer: [Option Letter]

Question: {question}
Options:
A. {options['A']}
B. {options['B']}
C. {options['C']}
D. {options['D']}

Reasoning:
    """
    return prompt

In [10]:
def get_ground_truth_mmlu(example):
    """Maps the label to the corresponding letter."""
    mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}
    label = example.get('answer')
    if label is None or label not in mapping:
        print(f"Warning: Invalid 'cop' value found: {label} in example ID {example.get('id')}. Skipping ground truth.")
        return None
    return mapping[label]

#### Evaluation

In [11]:
print("\n--- Preparing Prompts and Ground Truths ---")
prompts = [format_prompt_mmlu(ex) for ex in tqdm(ds_mmlu, desc="Formatting prompts")]
ground_truths = [get_ground_truth_mmlu(ex) for ex in tqdm(ds_mmlu, desc="Extracting ground truths")]
valid_indices = [i for i, gt in enumerate(ground_truths) if gt is not None]

if len(valid_indices) < len(ground_truths):
     print(f"Warning: {len(ground_truths) - len(valid_indices)} examples had invalid ground truths and were excluded.")
     prompts = [prompts[i] for i in valid_indices]
     ground_truths = [ground_truths[i] for i in valid_indices]
     original_indices = valid_indices

if len(prompts) > 0:
    print("\nExample Prompt:")
    print(prompts[0])
    print(f"Corresponding Ground Truth: {ground_truths[0]}")
else:
    print("No valid prompts to evaluate.")
    exit()


--- Preparing Prompts and Ground Truths ---


Formatting prompts: 100%|██████████| 200/200 [00:00<00:00, 8890.29it/s]
Extracting ground truths: 100%|██████████| 200/200 [00:00<00:00, 11075.24it/s]


Example Prompt:

You are an expert in solving multiple-choice questions accurately and explaining your reasoning clearly.
Given a question and a list of answer choices (A, B, C, D), your task is to:
1. Reason shortly about the question and answer choices to find evidances to support your answer.
2. Identify the correct answer. Please choose the single best answer from the options provided.
3. Output the final answer in the format: Answer: [Option Letter]

Question: Mitochondria isolated and placed in a buffered solution with a low pH begin to manufacture ATP. Which of the following is the best explanation for the effect of low external pH?
Options:
A. It increases the concentration of OH-, causing the mitochondria to pump H+ to the intermembrane space.
B. It increases the OH- concentration in the mitochondria matrix.
C. It increases the acid concentration in the mitochondria matrix.
D. It increases diffusion of H+ from the intermembrane space to the matrix.

Reasoning:
    
Correspond




In [12]:
print("\n--- Running Inference ---")
all_outputs_text = []
num_batches = math.ceil(len(prompts) / BATCH_SIZE)

for i in tqdm(range(num_batches), desc="Generating Responses"):
    start_idx = i * BATCH_SIZE
    end_idx = min((i + 1) * BATCH_SIZE, len(prompts))
    batch_prompts = prompts[start_idx:end_idx]
    outputs = llm.generate(batch_prompts, sampling_params, use_tqdm=False)
    batch_outputs_text = [output.outputs[0].text.strip() for output in outputs]
    all_outputs_text.extend(batch_outputs_text)

if len(all_outputs_text) > 0:
    print("\nExample Generated Text (raw):")
    print(all_outputs_text[0])


--- Running Inference ---


Generating Responses: 100%|██████████| 50/50 [41:07<00:00, 49.36s/it]


Example Generated Text (raw):
- The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matri




In [13]:
print("\n--- Extracting Predictions ---")
predictions = [extract_choice_mcq(text) for text in tqdm(all_outputs_text, desc="Extracting choices")]
num_invalid_responces = predictions.count(None)
print(f"\n------------------------------\nNumber of invalid responces: {num_invalid_responces}")

if len(predictions) > 0:
    print("\nExample Extracted Prediction:")
    print(predictions[0])


--- Extracting Predictions ---


Extracting choices: 100%|██████████| 200/200 [00:00<00:00, 2561.63it/s]

     - The mitochondrial matrix is a protonic ...ironment.
     - The mitochondrial matrix is a protonic environment.
     - The mitochondrial matrix'
     - The ...a.
     - The left ventricle is more efficient at pumping blood, so it can pump more blood with less'
     - The DNA is the template for the translation process.
     - The'
     - Hurler's syndrome is cause...anine, not ornithine.
     - Galactose metabolism is impaired, leading to high blood ammonia.
     -'
     - He is very hard on ...as an internal locus of control.
     - He is a perfectionist, so he likely has an external locus of'
...ns, the answer is that none of the options raise suspicion of a chromosome abnormality. However, the'
     - The DNA is the template for the translation process.
     - The'
     - It is lo...or the bilateral contraction of the contralateral muscle.
     - It is responsible for the bilateral'
     - The prostate is a gl...ctum.
     - The seminal vesicle is a structure located in the ur




In [14]:
print("\n--- Calculating Metrics ---")
correct_count = 0
total_count = len(predictions)
results_by_subject = {}

if total_count != len(ground_truths):
     print(f"Warning: Mismatch between number of predictions ({total_count}) and ground truths ({len(ground_truths)}). This should not happen.")
     total_count = min(total_count, len(ground_truths))

for i in range(total_count):
    original_data_index = original_indices[i] if 'original_indices' in locals() else i
    data_item = ds_mmlu[original_data_index]
    subject = data_item.get('subject', 'Unknown')

    pred = predictions[i]
    truth = ground_truths[i]
    is_correct = (pred == truth)

    if subject not in results_by_subject:
        results_by_subject[subject] = {'correct': 0, 'total': 0}

    if is_correct:
        correct_count += 1
        results_by_subject[subject]['correct'] += 1
    results_by_subject[subject]['total'] += 1

overall_accuracy = (correct_count / total_count) * 100 if total_count > 0 else 0


--- Calculating Metrics ---


In [15]:
print("\n--- Evaluation Results ---")
print(f"Model Evaluated: {MODEL_NAME}")
print(f"Dataset Used: {DATASET_MMLU}")
print(f"Number of Questions Evaluated: {total_count}")
print(f"Number of Correct Answers: {correct_count}")
print(f"Overall Accuracy: {overall_accuracy:.2f}%")

print("\nAccuracy by Subject:")
sorted_subjects = sorted(results_by_subject.keys())
for subject in sorted_subjects:
    counts = results_by_subject[subject]
    sub_acc = (counts['correct'] / counts['total']) * 100 if counts['total'] > 0 else 0
    print(f"- {subject}: {sub_acc:.2f}% ({counts['correct']}/{counts['total']})")


--- Evaluation Results ---
Model Evaluated: MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-msft-merged
Dataset Used: cais/mmlu
Number of Questions Evaluated: 200
Number of Correct Answers: 58
Overall Accuracy: 29.00%

Accuracy by Subject:
- anatomy: 25.81% (8/31)
- clinical_knowledge: 31.03% (9/29)
- college_biology: 29.63% (8/27)
- college_medicine: 36.67% (11/30)
- medical_genetics: 28.57% (8/28)
- professional_medicine: 25.45% (14/55)


### 4. PubMedQA

#### Dataset loading and preparing

In [16]:
SEED = 4242
BATCH_SIZE = 4
NUM_SAMPLES = 200
DATASET_PUBMEDQA = "qiaojin/PubMedQA"
SUBSET_PUBMEDQA = "pqa_labeled"
SPLIT_PUBMEDQA = "train"

In [17]:
ds_pubmedqa = load_dataset(DATASET_PUBMEDQA, SUBSET_PUBMEDQA, split=SPLIT_PUBMEDQA)
ds_pubmedqa = ds_pubmedqa.shuffle(seed=SEED).select(range(NUM_SAMPLES))
ds_pubmedqa

README.md:   0%|          | 0.00/5.19k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
    num_rows: 200
})

In [18]:
ds_pubmedqa[0]

{'pubid': 22504515,
 'question': 'Endovenous laser ablation in the treatment of small saphenous varicose veins: does site of access influence early outcomes?',
 'context': {'contexts': ['The study was performed to evaluate the clinical and technical efficacy of endovenous laser ablation (EVLA) of small saphenous varicosities, particularly in relation to the site of endovenous access.',
   'Totally 59 patients with unilateral saphenopopliteal junction incompetence and small saphenous vein reflux underwent EVLA (810 nm, 14 W diode laser) with ambulatory phlebectomies. Small saphenous vein access was gained at the lowest site of truncal reflux. Patients were divided into 2 groups: access gained above mid-calf (AMC, n = 33) and below mid-calf (BMC, n = 26) levels. Outcomes included Venous Clinical Severity Scores (VCSS), Aberdeen Varicose Vein Questionnaire (AVVQ), patient satisfaction, complications, and recurrence rates.',
   'Both groups demonstrated significant improvement in VCSS, AVV

#### Helper functions definition

In [19]:
def format_prompt_pubmedqa(example):
    """Formats a single example into a prompt for the LLM."""
    question = example['question']
    if not isinstance(example.get('context'), dict) or 'contexts' not in example['context']:
        print(f"Warning: Skipping example due to missing or invalid context field.")
        return None

    context_passages = example['context']['contexts']
    full_context = "\n\n".join(context_passages)

    prompt = f"""
You are an expert in analyzing scientific texts and answering questions based on provided context and explaining your reasoning clearly.
Your task is to determine the answer to the question ('yes', 'no', or 'maybe') based only on the information given in the context. Follow these steps:
1. Analyze the provided context in relation to the question. Summarize the key evidence (or lack thereof) relevant to answering the question. This is your reasoning.
2. Based on your reasoning from the context, determine if the answer to the question is 'yes', 'no', or 'maybe'.
3. Output your reasoning first. After the reasoning, start a new line and provide the final decision in the specific format: Answer: [yes/no/maybe]

Context:
{full_context}

Question: {question}

Reasoning:
    """
    return prompt

In [20]:
def get_ground_truth_pubmedqa(example):
    """Extracts the ground truth ('yes', 'no', 'maybe') from the example."""
    decision = example.get('final_decision')
    if decision not in ['yes', 'no', 'maybe']:
        print(f"Warning: Invalid 'final_decision' value found: {decision}. Skipping ground truth.")
        return None
    return decision

In [21]:
def extract_yes_no_maybe(generated_text):
    """Extracts the predicted choice (yes, no, maybe) from the LLM's output."""
    text = generated_text.strip().lower()

    # Explicit "Answer: yes/no/maybe" potentially followed by punctuation/eos
    match = re.search(r'(?:answer|decision)\s*[:\-]?\s*(yes|no|maybe)\b', text)
    if match:
        return match.group(1)

    # Look for the first occurrence of "yes", "no", or "maybe" as a whole word
    match = re.search(r'\b(yes|no|maybe)\b', text)
    if match:
        return match.group(1)

    # Fallback - If no clear choice found, return None
    print(f"Warning: Could not extract answer from text: '{text[:100]}...{text[-100:]}'")
    return None

#### Evaluation

In [22]:
print("\n--- Preparing Prompts and Ground Truths ---")
prompts = []
ground_truths_raw = []
original_indices_map = []

for i, ex in enumerate(tqdm(ds_pubmedqa, desc="Formatting prompts")):
    prompt = format_prompt_pubmedqa(ex)
    if prompt:
        prompts.append(prompt)
        ground_truths_raw.append(get_ground_truth_pubmedqa(ex))
        original_indices_map.append(i)

valid_indices = [i for i, gt in enumerate(ground_truths_raw) if gt is not None]

if len(valid_indices) < len(prompts):
     invalid_gt_count = len(prompts) - len(valid_indices)
     print(f"Warning: {invalid_gt_count} examples had invalid ground truths and were excluded.")
     prompts = [prompts[i] for i in valid_indices]
     ground_truths = [ground_truths_raw[i] for i in valid_indices]
     original_indices = [original_indices_map[i] for i in valid_indices]
else:
    ground_truths = ground_truths_raw
    original_indices = original_indices_map

if len(prompts) > 0:
    print("\nExample Prompt:")
    print(prompts[0])
    print(f"Corresponding Ground Truth: {ground_truths[0]}")
else:
    print("No valid prompts to evaluate.")
    exit()


--- Preparing Prompts and Ground Truths ---


Formatting prompts: 100%|██████████| 200/200 [00:00<00:00, 5497.01it/s]


Example Prompt:

You are an expert in analyzing scientific texts and answering questions based on provided context and explaining your reasoning clearly.
Your task is to determine the answer to the question ('yes', 'no', or 'maybe') based only on the information given in the context. Follow these steps:
1. Analyze the provided context in relation to the question. Summarize the key evidence (or lack thereof) relevant to answering the question. This is your reasoning.
2. Based on your reasoning from the context, determine if the answer to the question is 'yes', 'no', or 'maybe'.
3. Output your reasoning first. After the reasoning, start a new line and provide the final decision in the specific format: Answer: [yes/no/maybe]

Context:
The study was performed to evaluate the clinical and technical efficacy of endovenous laser ablation (EVLA) of small saphenous varicosities, particularly in relation to the site of endovenous access.

Totally 59 patients with unilateral saphenopopliteal jun




In [23]:
print("\n--- Running Inference ---")
all_outputs_text = []
num_batches = math.ceil(len(prompts) / BATCH_SIZE)

for i in tqdm(range(num_batches), desc="Generating Responses"):
    start_idx = i * BATCH_SIZE
    end_idx = min((i + 1) * BATCH_SIZE, len(prompts))
    batch_prompts = prompts[start_idx:end_idx]
    outputs = llm.generate(batch_prompts, sampling_params, use_tqdm=False)
    batch_outputs_text = [output.outputs[0].text.strip() for output in outputs]
    all_outputs_text.extend(batch_outputs_text)

if len(all_outputs_text) > 0:
    print("\nExample Generated Text (raw):")
    print(all_outputs_text[0])


--- Running Inference ---


Generating Responses: 100%|██████████| 50/50 [15:10<00:00, 18.20s/it]


Example Generated Text (raw):
The study examined the effect of endovenous laser ablation (EVLA) on small saphenous varicose veins, particularly focusing on the site of endovenous access. The study involved 59 patients with unilateral saphenopopliteal junction incompetence and small saphenous vein reflux, who underwent EVLA with ambulatory phlebectomies. The patients were divided into two groups based on the level of endovenous access: one group gained access above the mid-calf (AMC, n = 33) and the other below the mid-calf (BMC, n = 26). The study evaluated several outcomes, including clinical severity scores, the Aberdeen Varicose Vein Questionnaire, patient satisfaction, complications, and recurrence rates.

The study found that both groups showed significant improvement in VCSS, AVVQ, generic quality of life Short Form 36, and EuroQol scores, with no differences between the groups for complications and recurrence rates. This suggests that the site of endovenous access did not signi




In [24]:
print("\n--- Extracting Predictions ---")
predictions = [extract_yes_no_maybe(text) for text in tqdm(all_outputs_text, desc="Extracting choices")]
num_invalid_responсes = predictions.count(None)
print(f"\n------------------------------\nNumber of invalid responces: {num_invalid_responсes}")

if len(predictions) > 0:
    print("\nExample Extracted Prediction:")
    print(predictions[0])


--- Extracting Predictions ---


Extracting choices: 100%|██████████| 200/200 [00:00<00:00, 13327.94it/s]

 ...ence of a direct correlation between adma levels and cardiovascular events or cardiac death.
     15'

     the study also'
     - the study used caro...relevant to the question.
     - the study's findings are relevant to the question.
     - the study'
     b. the use ... aaa. the use of radiographic images is not a new approach.
     bbb. the use of radiographic images'
     - the context does not mention any evidence of a specific serovar'
     b. perimount: 0.95/2.12 = 0.45
     c. mosaic: 2.37/0.95 = 2.50
   ...2/0.95 = 2.23
     q. mosaic: 0.78/2.12 = 0.37
     r. perimount: 0.95/2.37 = 0.40
     s. mosaic: 2'
     - the study concluded that'
     - the...t surgical management is not necessary for the vanishing testes syndrome.
     - the study concluded'
final answer:
</think>
the question is asking whether the introduction of elec...study does not provide evidence of behavioral adaptation or significant changes in driving behavior.'

     the question is'
answer:

</think>




In [25]:
print("\n--- Calculating Metrics ---")
correct_count = 0
total_count = len(predictions)
results_by_subject = {}

if total_count != len(ground_truths):
     print(f"Warning: Mismatch between number of predictions ({total_count}) and ground truths ({len(ground_truths)}). This should not happen.")
     total_count = min(total_count, len(ground_truths))

for i in range(total_count):
    original_data_index = original_indices[i] if 'original_indices' in locals() else i
    data_item = ds_pubmedqa[original_data_index]
    subject = data_item.get('subject_name', 'Unknown')

    pred = predictions[i]
    truth = ground_truths[i]
    is_correct = (pred == truth)

    if subject not in results_by_subject:
        results_by_subject[subject] = {'correct': 0, 'total': 0}

    if is_correct:
        correct_count += 1
        results_by_subject[subject]['correct'] += 1
    results_by_subject[subject]['total'] += 1

overall_accuracy = (correct_count / total_count) * 100 if total_count > 0 else 0


--- Calculating Metrics ---


In [26]:
print("\n--- Evaluation Results ---")
print(f"Model Evaluated: {MODEL_NAME}")
print(f"Dataset Used: {DATASET_PUBMEDQA}")
print(f"Number of Questions Evaluated: {total_count}")
print(f"Number of Correct Answers: {correct_count}")
print(f"Overall Accuracy: {overall_accuracy:.2f}%")


--- Evaluation Results ---
Model Evaluated: MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-msft-merged
Dataset Used: qiaojin/PubMedQA
Number of Questions Evaluated: 200
Number of Correct Answers: 83
Overall Accuracy: 41.50%
