## Evaluation: SFT model

This is notebook for **SFT model evaluation** for estimating improvements over the base model and compare with other adapted models. We choose to evaluate on the set of benchmarks from [Open Medical-LLM Leaderboard](https://huggingface.co/spaces/openlifescienceai/open_medical_llm_leaderboard) including:

* [MedMCQA](https://huggingface.co/datasets/openlifescienceai/medmcqa) - MCQ, 200 samples from validation split
* [MedQA](https://huggingface.co/datasets/GBaker/MedQA-USMLE-4-options-hf) - MCQ, 200 samples from validation split
* [MMLU](https://huggingface.co/datasets/cais/mmlu) - MCQ, 200 samples from test splits of 6 medical subsets
* [PubMedQA](https://huggingface.co/datasets/qiaojin/PubMedQA) - QA, 200 samples from train split of pqa_labeled subset

*SFT model:* [MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-sft-merged](https://huggingface.co/MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-sft-merged)

### Setup

In [1]:
%%capture
!pip install datasets vllm

In [2]:
import re
from tqdm import tqdm
import math
import pandas as pd
from datasets import load_dataset, concatenate_datasets
from vllm import LLM, SamplingParams
import torch

INFO 04-13 16:33:40 [__init__.py:239] Automatically detected platform cuda.


2025-04-13 16:33:42.426700: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744562022.644910      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744562022.711852      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


### SFT model loading

In [3]:
MODEL_NAME = "MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-sft-merged"
MAX_TOKENS = 4096

In [4]:
print("\n--- Loading LLM with vLLM ---")
try:
    llm = LLM(
        model=MODEL_NAME,
        tensor_parallel_size=2,
        dtype=torch.float16,
    )
    sampling_params = SamplingParams(
        max_tokens=MAX_TOKENS,
        temperature=0.01,
        top_p=1.0,
        top_k=-1
    )
    print(f"LLM '{MODEL_NAME}' loaded successfully.")
except Exception as e:
    print(f"Error loading LLM with vLLM: {e}")
    print("Please ensure the MODEL_NAME is correct, vLLM is installed, and you have compatible hardware (GPU).")
    exit()


--- Loading LLM with vLLM ---


config.json:   0%|          | 0.00/820 [00:00<?, ?B/s]

INFO 04-13 16:34:09 [config.py:600] This model supports multiple tasks: {'generate', 'embed', 'reward', 'score', 'classify'}. Defaulting to 'generate'.
INFO 04-13 16:34:09 [config.py:1600] Defaulting to use mp for distributed inference
INFO 04-13 16:34:09 [config.py:1780] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 04-13 16:34:09 [llm_engine.py:242] Initializing a V0 LLM engine (v0.8.3) with config: model='MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-sft-merged', speculative_config=None, tokenizer='MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-sft-merged', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_con

tokenizer_config.json:   0%|          | 0.00/7.09k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/495 [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:34:11 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-13 16:34:12 [cuda.py:240] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
INFO 04-13 16:34:12 [cuda.py:289] Using XFormers backend.
[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:34:12 [cuda.py:240] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:34:12 [cuda.py:289] Using XFormers backend.


[W413 16:34:23.835239803 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W413 16:34:24.353462152 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W413 16:34:33.843654187 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3


INFO 04-13 16:34:43 [utils.py:990] Found nccl from library libnccl.so.2
INFO 04-13 16:34:43 [pynccl.py:69] vLLM is using nccl==2.21.5
[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:34:43 [utils.py:990] Found nccl from library libnccl.so.2
[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:34:43 [pynccl.py:69] vLLM is using nccl==2.21.5


[W413 16:34:43.854198985 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3


INFO 04-13 16:34:44 [custom_all_reduce_utils.py:206] generating GPU P2P access cache in /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
INFO 04-13 16:35:08 [custom_all_reduce_utils.py:244] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:35:08 [custom_all_reduce_utils.py:244] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
INFO 04-13 16:35:08 [shm_broadcast.py:264] vLLM message queue communication handle: Handle(local_reader_ranks=[1], buffer_handle=(1, 4194304, 6, 'psm_566ed675'), local_subscribe_addr='ipc:///tmp/241ae644-0623-4459-bc13-cbc38fa29029', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 04-13 16:35:08 [parallel_state.py:957] rank 0 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 0
[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:35:08 [parallel_state.py:957] rank 1 in world size 2 is assigned as DP rank 0, PP rank

model.safetensors:   0%|          | 0.00/3.55G [00:00<?, ?B/s]

INFO 04-13 16:35:16 [weight_utils.py:281] Time spent downloading weights for MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-sft-merged: 7.748147 seconds
INFO 04-13 16:35:16 [weight_utils.py:315] No model.safetensors.index.json found in remote.


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:35:16 [weight_utils.py:315] No model.safetensors.index.json found in remote.
INFO 04-13 16:35:18 [loader.py:447] Loading weights took 2.34 seconds
INFO 04-13 16:35:19 [model_runner.py:1146] Model loading took 1.6901 GiB and 10.397377 seconds
[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:35:19 [loader.py:447] Loading weights took 3.26 seconds
[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:35:20 [model_runner.py:1146] Model loading took 1.6901 GiB and 11.452690 seconds
[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:35:27 [worker.py:267] Memory profiling takes 7.32 seconds
[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:35:27 [worker.py:267] the current vLLM instance can use total_gpu_memory (14.74GiB) x gpu_memory_utilization (0.90) = 13.27GiB
[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:35:27 [worker.py:267] model weights take 1.69GiB; non_torch_memory takes 0.10GiB; PyTorch activation peak 

Capturing CUDA graph shapes:   0%|          | 0/35 [00:00<?, ?it/s]

[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:35:35 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.


Capturing CUDA graph shapes: 100%|██████████| 35/35 [00:41<00:00,  1.18s/it]

INFO 04-13 16:36:16 [model_runner.py:1598] Graph capturing finished in 41 secs, took 0.40 GiB
[1;36m(VllmWorkerProcess pid=143)[0;0m INFO 04-13 16:36:16 [model_runner.py:1598] Graph capturing finished in 41 secs, took 0.40 GiB
INFO 04-13 16:36:16 [llm_engine.py:448] init engine (profile, create kv cache, warmup model) took 56.32 seconds





LLM 'MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-sft-merged' loaded successfully.


### 1. MedMCQA benchmark

#### Dataset loading and preparing

In [5]:
SEED = 4242
BATCH_SIZE = 4
NUM_SAMPLES = 200
DATASET_MEDMCQA = "openlifescienceai/medmcqa"

In [6]:
ds_medmcqa = load_dataset(DATASET_MEDMCQA, split="validation")
ds_medmcqa = ds_medmcqa.shuffle(seed=SEED).select(range(NUM_SAMPLES))
ds_medmcqa

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/85.9M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/936k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/1.48M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/182822 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6150 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4183 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'question', 'opa', 'opb', 'opc', 'opd', 'cop', 'choice_type', 'exp', 'subject_name', 'topic_name'],
    num_rows: 200
})

In [7]:
ds_medmcqa[0]

{'id': '4653fb7a-ddbf-493b-b4ef-92205582a27a',
 'question': 'Which of the following tooth is not having 5 cusps?',
 'opa': 'Mandibular 2nd Molar',
 'opb': 'Mandibular 1st Molar',
 'opc': 'Mandibular 3rd Molar',
 'opd': 'Maxillary 1st Molar',
 'cop': 0,
 'choice_type': 'single',
 'exp': None,
 'subject_name': 'Dental',
 'topic_name': None}

#### Helper functions definition

In [8]:
def format_prompt_medmcqa(example):
    """Formats a single example into a prompt for the LLM."""
    question = example['question']
    options = {
        "A": example['opa'],
        "B": example['opb'],
        "C": example['opc'],
        "D": example['opd']
    }
    
    prompt = f"""
You are an expert in solving multiple-choice questions accurately and explaining your reasoning clearly.
Given a question and a list of answer choices (A, B, C, D), your task is to:
1. Reason shortly about the question and answer choices to find evidances to support your answer.
2. Identify the correct answer. Please choose the single best answer from the options provided.
3. Output the final answer in the format: Answer: [Option Letter]

Question: {question}
Options:
A. {options['A']}
B. {options['B']}
C. {options['C']}
D. {options['D']}

Reasoning:
    """
    return prompt

In [9]:
def get_ground_truth_medmcqa(example):
    """Maps the correct option index (cop) to the corresponding letter."""
    mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}
    cop_index = example.get('cop')
    if cop_index is None or cop_index not in mapping:
        print(f"Warning: Invalid 'cop' value found: {cop_index} in example ID {example.get('id')}. Skipping ground truth.")
        return None
    return mapping[cop_index]

In [10]:
def extract_choice_mcq(generated_text):
    """Extracts the predicted choice (A, B, C, or D) from the LLM's output."""
    text = generated_text.strip()

    # Check for phrases like "The answer is A" or "Answer: A"
    match = re.search(r'(?:answer|choice|option) is\s*:?\s*([A-D])', text, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Look for the first standalone letter A, B, C, or D in the text
    match = re.search(r'\b([A-D])\b', text)
    if match:
        return match.group(1).upper()

    # Fallback - If no clear choice found, return None
    print(f"Warning: Could not extract answer from text: '{text[:100]}...{text[-100:]}'")
    return None

#### Evaluation

In [11]:
print("\n--- Preparing Prompts and Ground Truths ---")
prompts = [format_prompt_medmcqa(ex) for ex in tqdm(ds_medmcqa, desc="Formatting prompts")]
ground_truths = [get_ground_truth_medmcqa(ex) for ex in tqdm(ds_medmcqa, desc="Extracting ground truths")]
valid_indices = [i for i, gt in enumerate(ground_truths) if gt is not None]

if len(valid_indices) < len(ground_truths):
     print(f"Warning: {len(ground_truths) - len(valid_indices)} examples had invalid ground truths and were excluded.")
     prompts = [prompts[i] for i in valid_indices]
     ground_truths = [ground_truths[i] for i in valid_indices]
     original_indices = valid_indices

if len(prompts) > 0:
    print("\nExample Prompt:")
    print(prompts[0])
    print(f"Corresponding Ground Truth: {ground_truths[0]}")
else:
    print("No valid prompts to evaluate.")
    exit()


--- Preparing Prompts and Ground Truths ---


Formatting prompts: 100%|██████████| 200/200 [00:00<00:00, 6638.45it/s]
Extracting ground truths: 100%|██████████| 200/200 [00:00<00:00, 8086.34it/s]


Example Prompt:

You are an expert in solving multiple-choice questions accurately and explaining your reasoning clearly.
Given a question and a list of answer choices (A, B, C, D), your task is to:
1. Reason shortly about the question and answer choices to find evidances to support your answer.
2. Identify the correct answer. Please choose the single best answer from the options provided.
3. Output the final answer in the format: Answer: [Option Letter]

Question: Which of the following tooth is not having 5 cusps?
Options:
A. Mandibular 2nd Molar
B. Mandibular 1st Molar
C. Mandibular 3rd Molar
D. Maxillary 1st Molar

Reasoning:
    
Corresponding Ground Truth: A





In [12]:
print("\n--- Running Inference ---")
all_outputs_text = []
num_batches = math.ceil(len(prompts) / BATCH_SIZE)

for i in tqdm(range(num_batches), desc="Generating Responses"):
    start_idx = i * BATCH_SIZE
    end_idx = min((i + 1) * BATCH_SIZE, len(prompts))
    batch_prompts = prompts[start_idx:end_idx]
    outputs = llm.generate(batch_prompts, sampling_params, use_tqdm=False)
    batch_outputs_text = [output.outputs[0].text.strip() for output in outputs]
    all_outputs_text.extend(batch_outputs_text)

if len(all_outputs_text) > 0:
    print("\nExample Generated Text (raw):")
    print(all_outputs_text[0])


--- Running Inference ---


Generating Responses: 100%|██████████| 50/50 [33:06<00:00, 39.72s/it]


Example Generated Text (raw):
- The mandibular 2nd molar has 5 cusps.
     - The mandibular 1st molar has 5 cusps.
     - The mandibular 3rd molar has 5 cusps.
     - The maxillary 1st molar has 5 cusps.
     - The mandibular 2nd molar has 5 cusps.
     - The mandibular 1st molar has 5 cusps.
     - The mandibular 3rd molar has 5 cusps.
     - The maxillary 1st molar has 5 cusps.
     - The mandibular 2nd molar has 5 cusps.
     - The mandibular 1st molar has 5 cusps.
     - The mandibular 3rd molar has 5 cusps.
     - The maxillary 1st molar has 5 cusps.
     - The mandibular 2nd molar has 5 cusps.
     - The mandibular 1st molar has 5 cusps.
     - The mandibular 3rd molar has 5 cusps.
     - The maxillary 1st molar has 5 cusps.
     - The mandibular 2nd molar has 5 cusps.
     - The mandibular 1st molar has 5 cusps.
     - The mandibular 3rd molar has 5 cusps.
     - The maxillary 1st molar has 5 cusps.
     - The mandibular 2nd molar has 5 cusps.
     - The mandibular 1st molar ha




In [13]:
print("\n--- Extracting Predictions ---")
predictions = [extract_choice_mcq(text) for text in tqdm(all_outputs_text, desc="Extracting choices")]
num_invalid_responces = predictions.count(None)
print(f"\n------------------------------\nNumber of invalid responces: {num_invalid_responces}")

if len(predictions) > 0:
    print("\nExample Extracted Prediction:")
    print(predictions[0])


--- Extracting Predictions ---


Extracting choices: 100%|██████████| 200/200 [00:00<00:00, 3724.30it/s]

     - The mandibular 1st molar has 5 cusps.
     - The mand...bular 3rd molar has 5 cusps.
     - The maxillary 1st molar has 5 cusps.
     - The mandibular 2nd m'
     - Sella is a major landmark in France.
     - Porion is...orion is a major landmark in France.
     - Orbitale is a major landmark in France.
     - Nasion is'
     - It is located in the maxillar...st to develop.
     - The maxillary sinus is the first to develop.
     - The maxillary sinus is the'
     - The correct answer is E. None of the above.
     - The correct answer'
        <p>What is the most common type of osteoma?</p>
        <p>It is a type of osteoma tha...e most common type of osteoma.</p>
        <p>It is the most common type of osteoma.</p>
        <p>'
     - T...he exception to the muscles of mastication is the Dactylus.
     - The exception to the muscles of m'
     - The temporal bone is the bone that forms the base of the temporal bone.
     - The temporal'
     - The mandibular process is formed




In [14]:
print("\n--- Calculating Metrics ---")
correct_count = 0
total_count = len(predictions)
results_by_subject = {}

if total_count != len(ground_truths):
     print(f"Warning: Mismatch between number of predictions ({total_count}) and ground truths ({len(ground_truths)}). This should not happen.")
     total_count = min(total_count, len(ground_truths))

for i in range(total_count):
    original_data_index = original_indices[i] if 'original_indices' in locals() else i
    data_item = ds_medmcqa[original_data_index]
    subject = data_item.get('subject_name', 'Unknown')

    pred = predictions[i]
    truth = ground_truths[i]
    is_correct = (pred == truth)

    if subject not in results_by_subject:
        results_by_subject[subject] = {'correct': 0, 'total': 0}

    if is_correct:
        correct_count += 1
        results_by_subject[subject]['correct'] += 1
    results_by_subject[subject]['total'] += 1

overall_accuracy = (correct_count / total_count) * 100 if total_count > 0 else 0


--- Calculating Metrics ---


In [15]:
print("\n--- Evaluation Results ---")
print(f"Model Evaluated: {MODEL_NAME}")
print(f"Dataset Used: {DATASET_MEDMCQA}")
print(f"Number of Questions Evaluated: {total_count}")
print(f"Number of Correct Answers: {correct_count}")
print(f"Overall Accuracy: {overall_accuracy:.2f}%")

print("\nAccuracy by Subject:")
sorted_subjects = sorted(results_by_subject.keys())
for subject in sorted_subjects:
    counts = results_by_subject[subject]
    sub_acc = (counts['correct'] / counts['total']) * 100 if counts['total'] > 0 else 0
    print(f"- {subject}: {sub_acc:.2f}% ({counts['correct']}/{counts['total']})")


--- Evaluation Results ---
Model Evaluated: MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-sft-merged
Dataset Used: openlifescienceai/medmcqa
Number of Questions Evaluated: 200
Number of Correct Answers: 61
Overall Accuracy: 30.50%

Accuracy by Subject:
- Anaesthesia: 50.00% (1/2)
- Anatomy: 16.67% (1/6)
- Biochemistry: 25.00% (2/8)
- Dental: 32.84% (22/67)
- ENT: 20.00% (1/5)
- Forensic Medicine: 71.43% (5/7)
- Gynaecology & Obstetrics: 29.41% (5/17)
- Medicine: 33.33% (2/6)
- Microbiology: 50.00% (3/6)
- Ophthalmology: 75.00% (3/4)
- Pathology: 25.00% (3/12)
- Pediatrics: 14.29% (2/14)
- Pharmacology: 25.00% (3/12)
- Physiology: 0.00% (0/6)
- Radiology: 0.00% (0/2)
- Skin: 0.00% (0/1)
- Social & Preventive Medicine: 33.33% (2/6)
- Surgery: 31.58% (6/19)


### 2. MedQA

#### Dataset loading and preparing

In [16]:
SEED = 4242
BATCH_SIZE = 4
NUM_SAMPLES = 200
DATASET_MEDQA = "GBaker/MedQA-USMLE-4-options-hf"
SPLIT_MEDQA = "validation"

In [17]:
ds_medqa = load_dataset(DATASET_MEDQA, split=SPLIT_MEDQA)
ds_medqa = ds_medqa.shuffle(seed=SEED).select(range(NUM_SAMPLES))
ds_medqa

README.md:   0%|          | 0.00/640 [00:00<?, ?B/s]

train.json:   0%|          | 0.00/9.77M [00:00<?, ?B/s]

dev.json:   0%|          | 0.00/1.22M [00:00<?, ?B/s]

test.json:   0%|          | 0.00/1.25M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10178 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1272 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1273 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'sent1', 'sent2', 'ending0', 'ending1', 'ending2', 'ending3', 'label'],
    num_rows: 200
})

In [18]:
ds_medqa[1]

{'id': 'dev-00646',
 'sent1': 'A 31-year-old gravida 2 para 2 woman presents to her primary care physician for follow up. Two weeks ago, she gave birth via vaginal delivery to a 9.5 lb (4.3 kg) male infant. The delivery was complicated by a vaginal laceration that required extensive suturing once the infant was delivered. Immediately after delivery of the placenta she experienced intense shaking and chills that resolved within 1 hour. She has felt well since the delivery but admits to 6 days of malodorous smelling vaginal discharge that is tan in color. She has a history of vaginal candidiasis and is worried that it may be recurring. Her temperature is 98.8°F (37.1°C), blood pressure is 122/73 mmHg, pulse is 88/min, respirations are 16/min, and BMI is 33 kg/m^2. Speculum exam reveals a 1.5 cm dark red, velvety lesion on the posterior vaginal wall with a tan discharge. The pH of the discharge is 6.4. Which of the following is the most likely diagnosis?',
 'sent2': '',
 'ending0': 'Bacte

#### Helper functions definition

In [19]:
def format_prompt_medqa(example):
    """Formats a single example into a prompt for the LLM."""
    question = example['sent1']
    options = {
        "A": example['ending0'],
        "B": example['ending1'],
        "C": example['ending2'],
        "D": example['ending3'],
    }
    
    prompt = f"""
You are an expert in solving multiple-choice questions accurately and explaining your reasoning clearly.
Given a question and a list of answer choices (A, B, C, D), your task is to:
1. Reason shortly about the question and answer choices to find evidances to support your answer.
2. Identify the correct answer. Please choose the single best answer from the options provided.
3. Output the final answer in the format: Answer: [Option Letter]

Question: {question}
Options:
A. {options['A']}
B. {options['B']}
C. {options['C']}
D. {options['D']}

Reasoning:
    """
    return prompt

In [20]:
def get_ground_truth_medqa(example):
    """Maps the label to the corresponding letter."""
    mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}
    label = example.get('label')
    if label is None or label not in mapping:
        print(f"Warning: Invalid 'cop' value found: {label} in example ID {example.get('id')}. Skipping ground truth.")
        return None
    return mapping[label]

#### Evaluation

In [21]:
print("\n--- Preparing Prompts and Ground Truths ---")
prompts = [format_prompt_medqa(ex) for ex in tqdm(ds_medqa, desc="Formatting prompts")]
ground_truths = [get_ground_truth_medqa(ex) for ex in tqdm(ds_medqa, desc="Extracting ground truths")]
valid_indices = [i for i, gt in enumerate(ground_truths) if gt is not None]

if len(valid_indices) < len(ground_truths):
     print(f"Warning: {len(ground_truths) - len(valid_indices)} examples had invalid ground truths and were excluded.")
     prompts = [prompts[i] for i in valid_indices]
     ground_truths = [ground_truths[i] for i in valid_indices]
     original_indices = valid_indices

if len(prompts) > 0:
    print("\nExample Prompt:")
    print(prompts[0])
    print(f"Corresponding Ground Truth: {ground_truths[0]}")
else:
    print("No valid prompts to evaluate.")
    exit()


--- Preparing Prompts and Ground Truths ---


Formatting prompts: 100%|██████████| 200/200 [00:00<00:00, 7502.22it/s]
Extracting ground truths: 100%|██████████| 200/200 [00:00<00:00, 9136.52it/s]


Example Prompt:

You are an expert in solving multiple-choice questions accurately and explaining your reasoning clearly.
Given a question and a list of answer choices (A, B, C, D), your task is to:
1. Reason shortly about the question and answer choices to find evidances to support your answer.
2. Identify the correct answer. Please choose the single best answer from the options provided.
3. Output the final answer in the format: Answer: [Option Letter]

Question: A 9-year-old girl is brought to the physician by her father for evaluation of intermittent muscle cramps for the past year and short stature. She has had recurrent upper respiratory tract infections since infancy. She is at the 5th percentile for weight and 10th percentile for height. Physical examination shows nasal polyps and dry skin. An x-ray of the right wrist shows osteopenia with epiphyseal widening. Which of the following sets of laboratory findings is most likely in this patient's serum?
 $$$ Calcium %%% Phosphorus




In [22]:
print("\n--- Running Inference ---")
all_outputs_text = []
num_batches = math.ceil(len(prompts) / BATCH_SIZE)

for i in tqdm(range(num_batches), desc="Generating Responses"):
    start_idx = i * BATCH_SIZE
    end_idx = min((i + 1) * BATCH_SIZE, len(prompts))
    batch_prompts = prompts[start_idx:end_idx]
    outputs = llm.generate(batch_prompts, sampling_params, use_tqdm=False)
    batch_outputs_text = [output.outputs[0].text.strip() for output in outputs]
    all_outputs_text.extend(batch_outputs_text)

if len(all_outputs_text) > 0:
    print("\nExample Generated Text (raw):")
    print(all_outputs_text[0])


--- Running Inference ---


Generating Responses: 100%|██████████| 50/50 [38:15<00:00, 45.90s/it]


Example Generated Text (raw):
### Key Details
     ### Question Details
     ### Key Terms
     ### Supporting Details
"

### Input:
A 9-year-old girl is brought to the physician by her father for evaluation of intermittent muscle cramps for the past year and short stature. She has had recurrent upper respiratory tract infections since infancy. She is at the 5th percentile for weight and 10th percentile for height. Physical examination shows nasal polyps and dry skin. An x-ray of the right wrist shows osteopenia with epiphyseal widening. Which of the following sets of laboratory findings is most likely in this patient's serum?
 $$$ Calcium %%% Phosphorus %%% Parathyroid hormone %%% Calcitriol $$$
### Answer: C

### Solution:
To determine the most likely serum laboratory findings for this patient, we need to consider her clinical presentation and the implications of her symptoms.

First, the patient has a short stature, which is often associated with a low calcium level. This is becaus




In [23]:
print("\n--- Extracting Predictions ---")
predictions = [extract_choice_mcq(text) for text in tqdm(all_outputs_text, desc="Extracting choices")]
num_invalid_responces = predictions.count(None)
print(f"\n------------------------------\nNumber of invalid responces: {num_invalid_responces}")

if len(predictions) > 0:
    print("\nExample Extracted Prediction:")
    print(predictions[0])


--- Extracting Predictions ---


Extracting choices: 100%|██████████| 200/200 [00:00<00:00, 1911.90it/s]

The immunoflu...re present in the glomerular surface are the ones that are responsible for the linear IgG deposition'
     Your explanation should be related to the clinical ...ific treatment options.
     Your explanation should be related to the clinical presentation and the'
     You... is a diuretic.
     You may also assume that the agent is a diuretic.
     You may also assume that'

The patient’s symptoms of symptom improvement in the heat are'

The patient's...ngs.

The x-ray findings are consistent with a specific type of lung cancer. What is the most common'
"

The following is an excerpt from a research paper on ...ry zone, the more susceptible the bacteria are to the antibiotic.
The following is an excerpt from a'
     Your explanation should be related to the patient's...n would perform to detect the cause of his immunodeficiency is the HIV Antigenic Profile Test (HAP).'
     You may assume that the path'

The presence of high grade dysplasia at the Z line is a marker of e




In [24]:
print("\n--- Calculating Metrics ---")
correct_count = 0
total_count = len(predictions)
results_by_subject = {}

if total_count != len(ground_truths):
     print(f"Warning: Mismatch between number of predictions ({total_count}) and ground truths ({len(ground_truths)}). This should not happen.")
     total_count = min(total_count, len(ground_truths))

for i in range(total_count):
    original_data_index = original_indices[i] if 'original_indices' in locals() else i
    data_item = ds_medqa[original_data_index]
    subject = data_item.get('subject_name', 'Unknown')

    pred = predictions[i]
    truth = ground_truths[i]
    is_correct = (pred == truth)

    if subject not in results_by_subject:
        results_by_subject[subject] = {'correct': 0, 'total': 0}

    if is_correct:
        correct_count += 1
        results_by_subject[subject]['correct'] += 1
    results_by_subject[subject]['total'] += 1

overall_accuracy = (correct_count / total_count) * 100 if total_count > 0 else 0


--- Calculating Metrics ---


In [25]:
print("\n--- Evaluation Results ---")
print(f"Model Evaluated: {MODEL_NAME}")
print(f"Dataset Used: {DATASET_MEDQA}")
print(f"Number of Questions Evaluated: {total_count}")
print(f"Number of Correct Answers: {correct_count}")
print(f"Overall Accuracy: {overall_accuracy:.2f}%")


--- Evaluation Results ---
Model Evaluated: MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-sft-merged
Dataset Used: GBaker/MedQA-USMLE-4-options-hf
Number of Questions Evaluated: 200
Number of Correct Answers: 51
Overall Accuracy: 25.50%


### 3. MMLU medical

#### Dataset loading and preparing

In [26]:
SEED = 4242
BATCH_SIZE = 4
NUM_SAMPLES_SUBSET = 50
NUM_SAMPLES = 200
DATASET_MMLU = "cais/mmlu"
SPLIT_MMLU = "test"

MMLU_MEDICAL_SUBSETS = [
    "anatomy",
    "clinical_knowledge",
    "professional_medicine",
    "college_biology",
    "college_medicine",
    "medical_genetics",
    "professional_medicine"
]

In [27]:
datasets_mmlu = []
for subset in MMLU_MEDICAL_SUBSETS:
    ds = load_dataset(DATASET_MMLU, subset, split=SPLIT_MMLU)
    ds = ds.shuffle(seed=SEED).select(range(NUM_SAMPLES_SUBSET))
    datasets_mmlu.append(ds)


ds_mmlu = concatenate_datasets(datasets_mmlu)
ds_mmlu = ds_mmlu.shuffle(seed=SEED).select(range(NUM_SAMPLES))
ds_mmlu

README.md:   0%|          | 0.00/53.2k [00:00<?, ?B/s]

dataset_infos.json:   0%|          | 0.00/138k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.1k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/5.28k [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/3.50k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/135 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/14 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/40.5k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/7.48k [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/3.67k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/265 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/29 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/125k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/19.9k [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/8.45k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/272 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/31 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/31.8k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/6.90k [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/4.27k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/144 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/16 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/42.5k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/8.99k [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/4.84k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/173 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/16.4k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/5.63k [00:00<?, ?B/s]

dev-00000-of-00001.parquet:   0%|          | 0.00/3.77k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Dataset({
    features: ['question', 'subject', 'choices', 'answer'],
    num_rows: 200
})

In [28]:
ds_mmlu[0]

{'question': 'Mitochondria isolated and placed in a buffered solution with a low pH begin to manufacture ATP. Which of the following is the best explanation for the effect of low external pH?',
 'subject': 'college_biology',
 'choices': ['It increases the concentration of OH-, causing the mitochondria to pump H+ to the intermembrane space.',
  'It increases the OH- concentration in the mitochondria matrix.',
  'It increases the acid concentration in the mitochondria matrix.',
  'It increases diffusion of H+ from the intermembrane space to the matrix.'],
 'answer': 3}

#### Helper functions definition

In [29]:
def format_prompt_mmlu(example):
    """Formats a single example into a prompt for the LLM."""
    question = example['question']
    options = {
        "A": example['choices'][0],
        "B": example['choices'][1],
        "C": example['choices'][2],
        "D": example['choices'][3]
    }
    
    prompt = f"""
You are an expert in solving multiple-choice questions accurately and explaining your reasoning clearly.
Given a question and a list of answer choices (A, B, C, D), your task is to:
1. Reason shortly about the question and answer choices to find evidances to support your answer.
2. Identify the correct answer. Please choose the single best answer from the options provided.
3. Output the final answer in the format: Answer: [Option Letter]

Question: {question}
Options:
A. {options['A']}
B. {options['B']}
C. {options['C']}
D. {options['D']}

Reasoning:
    """
    return prompt

In [30]:
def get_ground_truth_mmlu(example):
    """Maps the label to the corresponding letter."""
    mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}
    label = example.get('answer')
    if label is None or label not in mapping:
        print(f"Warning: Invalid 'cop' value found: {label} in example ID {example.get('id')}. Skipping ground truth.")
        return None
    return mapping[label]

#### Evaluation

In [31]:
print("\n--- Preparing Prompts and Ground Truths ---")
prompts = [format_prompt_mmlu(ex) for ex in tqdm(ds_mmlu, desc="Formatting prompts")]
ground_truths = [get_ground_truth_mmlu(ex) for ex in tqdm(ds_mmlu, desc="Extracting ground truths")]
valid_indices = [i for i, gt in enumerate(ground_truths) if gt is not None]

if len(valid_indices) < len(ground_truths):
     print(f"Warning: {len(ground_truths) - len(valid_indices)} examples had invalid ground truths and were excluded.")
     prompts = [prompts[i] for i in valid_indices]
     ground_truths = [ground_truths[i] for i in valid_indices]
     original_indices = valid_indices

if len(prompts) > 0:
    print("\nExample Prompt:")
    print(prompts[0])
    print(f"Corresponding Ground Truth: {ground_truths[0]}")
else:
    print("No valid prompts to evaluate.")
    exit()


--- Preparing Prompts and Ground Truths ---


Formatting prompts: 100%|██████████| 200/200 [00:00<00:00, 8849.96it/s]
Extracting ground truths: 100%|██████████| 200/200 [00:00<00:00, 11444.53it/s]


Example Prompt:

You are an expert in solving multiple-choice questions accurately and explaining your reasoning clearly.
Given a question and a list of answer choices (A, B, C, D), your task is to:
1. Reason shortly about the question and answer choices to find evidances to support your answer.
2. Identify the correct answer. Please choose the single best answer from the options provided.
3. Output the final answer in the format: Answer: [Option Letter]

Question: Mitochondria isolated and placed in a buffered solution with a low pH begin to manufacture ATP. Which of the following is the best explanation for the effect of low external pH?
Options:
A. It increases the concentration of OH-, causing the mitochondria to pump H+ to the intermembrane space.
B. It increases the OH- concentration in the mitochondria matrix.
C. It increases the acid concentration in the mitochondria matrix.
D. It increases diffusion of H+ from the intermembrane space to the matrix.

Reasoning:
    
Correspond




In [32]:
print("\n--- Running Inference ---")
all_outputs_text = []
num_batches = math.ceil(len(prompts) / BATCH_SIZE)

for i in tqdm(range(num_batches), desc="Generating Responses"):
    start_idx = i * BATCH_SIZE
    end_idx = min((i + 1) * BATCH_SIZE, len(prompts))
    batch_prompts = prompts[start_idx:end_idx]
    outputs = llm.generate(batch_prompts, sampling_params, use_tqdm=False)
    batch_outputs_text = [output.outputs[0].text.strip() for output in outputs]
    all_outputs_text.extend(batch_outputs_text)

if len(all_outputs_text) > 0:
    print("\nExample Generated Text (raw):")
    print(all_outputs_text[0])


--- Running Inference ---


Generating Responses: 100%|██████████| 50/50 [32:06<00:00, 38.53s/it]


Example Generated Text (raw):
A. It increases the concentration of OH-, causing the mitochondria to pump H+ to the intermembrane space.
     B. It increases the OH- concentration in the mitochondria matrix.
     C. It increases the acid concentration in the mitochondria matrix.
     D. It increases diffusion of H+ from the intermembrane space to the matrix.

Your answer should be either 'A', 'B', 'C', or 'D'.
To solve this problem, I need to understand how the pH of the environment affects the mitochondrial membrane. The mitochondrial membrane is a double-layered membrane that allows the passage of ions and small molecules. The pH of the environment is crucial because it determines the concentration of hydrogen ions (H+) and hydroxide ions (OH-) in the mitochondrial matrix.

When the pH of the environment is low, it means that the concentration of OH- ions is high. This high OH- concentration can act as a proton source, meaning it can donate protons (H+) to the mitochondrial matrix. T




In [33]:
print("\n--- Extracting Predictions ---")
predictions = [extract_choice_mcq(text) for text in tqdm(all_outputs_text, desc="Extracting choices")]
num_invalid_responces = predictions.count(None)
print(f"\n------------------------------\nNumber of invalid responces: {num_invalid_responces}")

if len(predictions) > 0:
    print("\nExample Extracted Prediction:")
    print(predictions[0])


--- Extracting Predictions ---


Extracting choices: 100%|██████████| 200/200 [00:00<00:00, 3664.62it/s]

     Your explanation should be related to David's self-...on himself when he cannot master a section of one of his pieces. Which of the following answers best'
     Yo...e order Archaea, Eukarya, and Bacteria.
     You may assume that the domains are in the order Archae'
"

The boy's symptoms of fever, ...ponsible for the boy's bone infection and the symptoms of fever, swelling, and recurrent infections.'
     - The m...n of the mandible.
     - The lateral pterygoid muscle is the muscle that initiates the elevation of'
     Your reasoning should ....
     Your final answer should be selected from the options given.
     Your final answer should be'

Final Answer:

</think>
The question is about identifying the class of antibodies that are charac...ent on the surface of the plasma cells themselves.

Therefore, the correct answer is IgM antibodies.'
     - The CFTR gene i...uired for the production of cystatin, which is essential for the proper functioning of the lungs and'
     - The 




In [34]:
print("\n--- Calculating Metrics ---")
correct_count = 0
total_count = len(predictions)
results_by_subject = {}

if total_count != len(ground_truths):
     print(f"Warning: Mismatch between number of predictions ({total_count}) and ground truths ({len(ground_truths)}). This should not happen.")
     total_count = min(total_count, len(ground_truths))

for i in range(total_count):
    original_data_index = original_indices[i] if 'original_indices' in locals() else i
    data_item = ds_mmlu[original_data_index]
    subject = data_item.get('subject', 'Unknown')

    pred = predictions[i]
    truth = ground_truths[i]
    is_correct = (pred == truth)

    if subject not in results_by_subject:
        results_by_subject[subject] = {'correct': 0, 'total': 0}

    if is_correct:
        correct_count += 1
        results_by_subject[subject]['correct'] += 1
    results_by_subject[subject]['total'] += 1

overall_accuracy = (correct_count / total_count) * 100 if total_count > 0 else 0


--- Calculating Metrics ---


In [35]:
print("\n--- Evaluation Results ---")
print(f"Model Evaluated: {MODEL_NAME}")
print(f"Dataset Used: {DATASET_MMLU}")
print(f"Number of Questions Evaluated: {total_count}")
print(f"Number of Correct Answers: {correct_count}")
print(f"Overall Accuracy: {overall_accuracy:.2f}%")

print("\nAccuracy by Subject:")
sorted_subjects = sorted(results_by_subject.keys())
for subject in sorted_subjects:
    counts = results_by_subject[subject]
    sub_acc = (counts['correct'] / counts['total']) * 100 if counts['total'] > 0 else 0
    print(f"- {subject}: {sub_acc:.2f}% ({counts['correct']}/{counts['total']})")


--- Evaluation Results ---
Model Evaluated: MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-sft-merged
Dataset Used: cais/mmlu
Number of Questions Evaluated: 200
Number of Correct Answers: 57
Overall Accuracy: 28.50%

Accuracy by Subject:
- anatomy: 29.03% (9/31)
- clinical_knowledge: 31.03% (9/29)
- college_biology: 18.52% (5/27)
- college_medicine: 20.00% (6/30)
- medical_genetics: 50.00% (14/28)
- professional_medicine: 25.45% (14/55)


### 4. PubMedQA

#### Dataset loading and preparing

In [5]:
SEED = 4242
BATCH_SIZE = 4
NUM_SAMPLES = 200
DATASET_PUBMEDQA = "qiaojin/PubMedQA"
SUBSET_PUBMEDQA = "pqa_labeled"
SPLIT_PUBMEDQA = "train"

In [6]:
ds_pubmedqa = load_dataset(DATASET_PUBMEDQA, SUBSET_PUBMEDQA, split=SPLIT_PUBMEDQA)
ds_pubmedqa = ds_pubmedqa.shuffle(seed=SEED).select(range(NUM_SAMPLES))
ds_pubmedqa

README.md:   0%|          | 0.00/5.19k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
    num_rows: 200
})

In [7]:
ds_pubmedqa[0]

{'pubid': 22504515,
 'question': 'Endovenous laser ablation in the treatment of small saphenous varicose veins: does site of access influence early outcomes?',
 'context': {'contexts': ['The study was performed to evaluate the clinical and technical efficacy of endovenous laser ablation (EVLA) of small saphenous varicosities, particularly in relation to the site of endovenous access.',
   'Totally 59 patients with unilateral saphenopopliteal junction incompetence and small saphenous vein reflux underwent EVLA (810 nm, 14 W diode laser) with ambulatory phlebectomies. Small saphenous vein access was gained at the lowest site of truncal reflux. Patients were divided into 2 groups: access gained above mid-calf (AMC, n = 33) and below mid-calf (BMC, n = 26) levels. Outcomes included Venous Clinical Severity Scores (VCSS), Aberdeen Varicose Vein Questionnaire (AVVQ), patient satisfaction, complications, and recurrence rates.',
   'Both groups demonstrated significant improvement in VCSS, AVV

#### Helper functions definition

In [8]:
def format_prompt_pubmedqa(example):
    """Formats a single example into a prompt for the LLM."""
    question = example['question']
    if not isinstance(example.get('context'), dict) or 'contexts' not in example['context']:
        print(f"Warning: Skipping example due to missing or invalid context field.")
        return None

    context_passages = example['context']['contexts']
    full_context = "\n\n".join(context_passages)

    prompt = f"""
You are an expert in analyzing scientific texts and answering questions based on provided context and explaining your reasoning clearly.
Your task is to determine the answer to the question ('yes', 'no', or 'maybe') based only on the information given in the context. Follow these steps:
1. Analyze the provided context in relation to the question. Summarize the key evidence (or lack thereof) relevant to answering the question. This is your reasoning.
2. Based on your reasoning from the context, determine if the answer to the question is 'yes', 'no', or 'maybe'.
3. Output your reasoning first. After the reasoning, start a new line and provide the final decision in the specific format: Answer: [yes/no/maybe]

Context:
{full_context}

Question: {question}

Reasoning:
    """
    return prompt

In [9]:
def get_ground_truth_pubmedqa(example):
    """Extracts the ground truth ('yes', 'no', 'maybe') from the example."""
    decision = example.get('final_decision')
    if decision not in ['yes', 'no', 'maybe']:
        print(f"Warning: Invalid 'final_decision' value found: {decision}. Skipping ground truth.")
        return None
    return decision

In [10]:
def extract_yes_no_maybe(generated_text):
    """Extracts the predicted choice (yes, no, maybe) from the LLM's output."""
    text = generated_text.strip().lower()

    # Explicit "Answer: yes/no/maybe" potentially followed by punctuation/eos
    match = re.search(r'(?:answer|decision)\s*[:\-]?\s*(yes|no|maybe)\b', text)
    if match:
        return match.group(1)

    # Look for the first occurrence of "yes", "no", or "maybe" as a whole word
    match = re.search(r'\b(yes|no|maybe)\b', text)
    if match:
        return match.group(1)

    # Fallback - If no clear choice found, return None
    print(f"Warning: Could not extract answer from text: '{text[:100]}...{text[-100:]}'")
    return None

#### Evaluation

In [11]:
print("\n--- Preparing Prompts and Ground Truths ---")
prompts = []
ground_truths_raw = []
original_indices_map = []

for i, ex in enumerate(tqdm(ds_pubmedqa, desc="Formatting prompts")):
    prompt = format_prompt_pubmedqa(ex)
    if prompt:
        prompts.append(prompt)
        ground_truths_raw.append(get_ground_truth_pubmedqa(ex))
        original_indices_map.append(i)

valid_indices = [i for i, gt in enumerate(ground_truths_raw) if gt is not None]

if len(valid_indices) < len(prompts):
     invalid_gt_count = len(prompts) - len(valid_indices)
     print(f"Warning: {invalid_gt_count} examples had invalid ground truths and were excluded.")
     prompts = [prompts[i] for i in valid_indices]
     ground_truths = [ground_truths_raw[i] for i in valid_indices]
     original_indices = [original_indices_map[i] for i in valid_indices]
else:
    ground_truths = ground_truths_raw
    original_indices = original_indices_map

if len(prompts) > 0:
    print("\nExample Prompt:")
    print(prompts[0])
    print(f"Corresponding Ground Truth: {ground_truths[0]}")
else:
    print("No valid prompts to evaluate.")
    exit()


--- Preparing Prompts and Ground Truths ---


Formatting prompts: 100%|██████████| 200/200 [00:00<00:00, 5963.58it/s]


Example Prompt:

You are an expert in analyzing scientific texts and answering questions based on provided context and explaining your reasoning clearly.
Your task is to determine the answer to the question ('yes', 'no', or 'maybe') based only on the information given in the context. Follow these steps:
1. Analyze the provided context in relation to the question. Summarize the key evidence (or lack thereof) relevant to answering the question. This is your reasoning.
2. Based on your reasoning from the context, determine if the answer to the question is 'yes', 'no', or 'maybe'.
3. Output your reasoning first. After the reasoning, start a new line and provide the final decision in the specific format: Answer: [yes/no/maybe]

Context:
The study was performed to evaluate the clinical and technical efficacy of endovenous laser ablation (EVLA) of small saphenous varicosities, particularly in relation to the site of endovenous access.

Totally 59 patients with unilateral saphenopopliteal jun




In [12]:
print("\n--- Running Inference ---")
all_outputs_text = []
num_batches = math.ceil(len(prompts) / BATCH_SIZE)

for i in tqdm(range(num_batches), desc="Generating Responses"):
    start_idx = i * BATCH_SIZE
    end_idx = min((i + 1) * BATCH_SIZE, len(prompts))
    batch_prompts = prompts[start_idx:end_idx]
    outputs = llm.generate(batch_prompts, sampling_params, use_tqdm=False)
    batch_outputs_text = [output.outputs[0].text.strip() for output in outputs]
    all_outputs_text.extend(batch_outputs_text)

if len(all_outputs_text) > 0:
    print("\nExample Generated Text (raw):")
    print(all_outputs_text[0])


--- Running Inference ---


Generating Responses: 100%|██████████| 50/50 [35:47<00:00, 42.95s/it]


Example Generated Text (raw):
...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
  




In [13]:
print("\n--- Extracting Predictions ---")
predictions = [extract_yes_no_maybe(text) for text in tqdm(all_outputs_text, desc="Extracting choices")]
num_invalid_responсes = predictions.count(None)
print(f"\n------------------------------\nNumber of invalid responces: {num_invalid_responсes}")

if len(predictions) > 0:
    print("\nExample Extracted Prediction:")
    print(predictions[0])


--- Extracting Predictions ---


Extracting choices: 100%|██████████| 200/200 [00:00<00:00, 5071.28it/s]

     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     .....
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...'
final decision:
</think>
the question is asking whether routine laboratory markers are useful fo...nash from ash, as they were identified as relevant regressors and supported by the study's findings.'
     the ... the treatment of amblyopia normalises subfoveal choroidal thickness in amblyopic children.
     the'
     - therefore, the study's findings suggest that having a regular clinician'
 ... pregnancy were 1.6 (95% ci 1.1 to 2.5).
     - the risk estimates for tubal infertility were 4.8 (9'
     2. adma levels wer...ma levels were significantly correlated inversely with the chronological age of the subject.
     18'
     - the study included patients from 16 different centers, which is'
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...
     ...




In [14]:
print("\n--- Calculating Metrics ---")
correct_count = 0
total_count = len(predictions)
results_by_subject = {}

if total_count != len(ground_truths):
     print(f"Warning: Mismatch between number of predictions ({total_count}) and ground truths ({len(ground_truths)}). This should not happen.")
     total_count = min(total_count, len(ground_truths))

for i in range(total_count):
    original_data_index = original_indices[i] if 'original_indices' in locals() else i
    data_item = ds_pubmedqa[original_data_index]
    subject = data_item.get('subject_name', 'Unknown')

    pred = predictions[i]
    truth = ground_truths[i]
    is_correct = (pred == truth)

    if subject not in results_by_subject:
        results_by_subject[subject] = {'correct': 0, 'total': 0}

    if is_correct:
        correct_count += 1
        results_by_subject[subject]['correct'] += 1
    results_by_subject[subject]['total'] += 1

overall_accuracy = (correct_count / total_count) * 100 if total_count > 0 else 0


--- Calculating Metrics ---


In [15]:
print("\n--- Evaluation Results ---")
print(f"Model Evaluated: {MODEL_NAME}")
print(f"Dataset Used: {DATASET_PUBMEDQA}")
print(f"Number of Questions Evaluated: {total_count}")
print(f"Number of Correct Answers: {correct_count}")
print(f"Overall Accuracy: {overall_accuracy:.2f}%")


--- Evaluation Results ---
Model Evaluated: MilyaShams/DeepSeek-R1-Distill-Qwen-1.5B-medical-sft-merged
Dataset Used: qiaojin/PubMedQA
Number of Questions Evaluated: 200
Number of Correct Answers: 72
Overall Accuracy: 36.00%
