# Anatomy LLM Development
In this notebook, we conduct the base model evaluation for Bio-Medical-Llama-3-8B, MedGemma-4B, JSL-MedLlama-3-8B, Qwen2-7B, and Medichat-Llama3-8B.

Open in [Colab](https://colab.research.google.com/drive/1ed4kh8LKpKA3PtuG78y1-5dUeCw18UV6?usp=sharing)
For the evaluation of II-Medical-8B: [Colab](https://colab.research.google.com/drive/1vh78pS1CHyu0I6-ONvSq-H-shf1s_JCB?usp=sharing)

## Imports

In [None]:
!pip install evaluate
!pip install bert_score
!pip install rouge_score

Collecting evaluate
  Downloading evaluate-0.4.4-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.4-py3-none-any.whl (84 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 84.1/84.1 kB 7.7 MB/s eta 0:00:00
Installing collected packages: evaluate
Successfully installed evaluate-0.4.4
Collecting bert_score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.0.0->bert_score)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.0.0->bert_score)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.0.0->bert_score)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from tor

In [None]:
!pip install -U bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.46.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.46.1-py3-none-manylinux_2_24_x86_64.whl (72.9 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 72.9/72.9 MB 35.2 MB/s eta 0:00:00
Installing collected packages: bitsandbytes
Successfully installed bitsandbytes-0.46.1


In [None]:
!nvidia-smi --query-gpu=name --format=csv,noheader

NVIDIA A100-SXM4-40GB


Log in to Hugging Face to access the models.

In [None]:
# Read the access token from the HF_TOKEN environment variable (set beforehand) instead of hard-coding it in the notebook.
!huggingface-cli login --token $HF_TOKEN

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
The token `AnatomyLLM` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
The current active token is: `AnatomyLLM`


In [None]:
import os
import gc
import torch
import transformers
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from huggingface_hub import hf_hub_download
import json
import random
import time
import re
from evaluate import load
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoProcessor,
    BitsAndBytesConfig,
    pipeline,
    AutoModelForImageTextToText, # For MedGemma
)
from tqdm.auto import tqdm
import numpy as np
import warnings

In [None]:
# Suppress warnings
warnings.filterwarnings("ignore")
transformers.utils.logging.set_verbosity_error()

In [None]:
print(f"Transformers version installed: {transformers.__version__}")

# Set seed for reproducibility
random.seed(42)

Transformers version installed: 4.53.0


In [None]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

## Loading the Dataset

In [None]:
DATASET_REPO_ID = "Anatomy-Tutor/Anatomy-and-Medical-Dataset"
DATASET_FILENAME = "processed_medical_and_anatomy.json"

In [None]:
# Load the dataset
print("Loading dataset from Hugging Face Hub...")
try:
    hf_data_path = hf_hub_download(
        repo_id=DATASET_REPO_ID,
        filename=DATASET_FILENAME,
        repo_type="dataset"
    )
    with open(hf_data_path, "r", encoding="utf-8") as f:
        splits = json.load(f)
    ds = DatasetDict({
        "train": Dataset.from_list(splits["train"]),
        "validation": Dataset.from_list(splits["validation"]),
        "test": Dataset.from_list(splits["test"]),
    })
except Exception as e:
    print(f"Failed to load or process dataset. Error: {e}")
    ds = None  # lets the later `if ds:` check skip preparation instead of raising NameError
    DEV_SET = []

Loading dataset from Hugging Face Hub...


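The loader only assumes that the JSON file holds `train`, `validation` and `test` lists of records. The records themselves are read later in `prepare_evaluation_set`, which accepts either chat-style turns (under `messages` or `conversations`) or flat `prompt`/`completion` fields. A standalone sketch of that lookup follows; `extract_prompt_and_reference` is a hypothetical name, and the two example records are invented to show both layouts, not taken from the dataset.

```python
def extract_prompt_and_reference(record: dict):
    """Return (user_prompt, reference_answer) for a chat-style or prompt/completion record."""
    user_prompt, reference_answer = None, None
    # Chat-style records: the last "user" and "assistant" turns win.
    for message in record.get("messages", record.get("conversations", [])):
        if message.get("role") == "user":
            user_prompt = message.get("content")
        elif message.get("role") == "assistant":
            reference_answer = message.get("content")
    # Flat records: fall back to "prompt" / "completion".
    user_prompt = user_prompt or record.get("prompt")
    reference_answer = reference_answer or record.get("completion")
    return user_prompt, reference_answer


chat_record = {"messages": [
    {"role": "user", "content": "Which bone is the longest in the human body?"},
    {"role": "assistant", "content": "The femur."},
]}
flat_record = {"prompt": "Name the largest artery.", "completion": "The aorta."}

print(extract_prompt_and_reference(chat_record))  # ('Which bone is the longest in the human body?', 'The femur.')
print(extract_prompt_and_reference(flat_record))  # ('Name the largest artery.', 'The aorta.')
```

Records missing both layouts come back as `(None, None)`, which `prepare_evaluation_set` silently skips.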

## Evaluation Setup

System prompt defining the persona and rules for the model

In [None]:
# A specific refusal phrase for the model to use.
specific_refusal_phrase = "I am sorry, but I can only answer questions related to human anatomy and medicine."

MEDICAL_CHATBOT_SYSTEM_PROMPT = f"""You are "Medilearn," an expert AI anatomy tutor for a VR application. Your goal is to provide clear, accurate, and educational explanations.

**Rules:**
1.  **Stay On Topic:** Politely refuse any question not related to human anatomy or medicine. When you refuse a question, you MUST begin your response with the exact phrase: "{specific_refusal_phrase}"
2.  **Be Direct and Unambiguous:** Provide answers that are clear and to the point. Avoid hedging or overly conversational filler.
3.  **MCQ Answering:** For multiple-choice questions, start your response by stating the correct letter or number, followed by a colon and then your brief explanation. For example: "A: This is the explanation."
4.  **End Your Turn:** After providing the complete answer and the mandatory safety warning, you MUST output the special token `<|end_of_turn|>`.
5.  **Safety First:** Your final sentence before the end-of-turn token must be: "Always consult a qualified healthcare professional for medical advice."
"""

# Additional instruction for reasoning models
REASONING_FORMAT_INSTRUCTION = "Please provide your concise reasoning within <think> tags and your final answer within <Answer> tags."

In [None]:
REASONING_MODELS = [
    "Intelligent-Internet/II-Medical-8B",
    "ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025"
]
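When a model follows `REASONING_FORMAT_INSTRUCTION`, its reasoning sits in `<think>` tags and its answer in `<Answer>` tags. Before scoring, the evaluation loop removes the `<think>` block and unwraps the `<Answer>` tags of every model's output with two regular expressions. A standalone sketch of that cleanup (`clean_reasoning_output` is a hypothetical helper name; the sample output is invented):

```python
import re


def clean_reasoning_output(raw_text: str) -> str:
    """Drop <think>...</think> blocks and unwrap <Answer>...</Answer> tags."""
    # DOTALL lets ".*?" span newlines inside the reasoning block.
    text = re.sub(r"<think>.*?</think>", "", raw_text, flags=re.DOTALL).strip()
    text = re.sub(r"<Answer>(.*?)</Answer>", r"\1", text, flags=re.DOTALL).strip()
    return text


example = "<think>PAD plus diabetes favours a drug with cardiovascular benefit.</think>\n<Answer>C: SGLT-2 inhibitors</Answer>"
print(clean_reasoning_output(example))  # C: SGLT-2 inhibitors
```

A `<think>` block that is never closed (for example, when generation stops at `max_new_tokens`) is not matched by the first pattern, so that reasoning text stays in the output and gets scored.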

Bait questions test whether the model refuses to answer non-anatomy questions.

For example:
- If the user asks: "When is the World Cup?"
- The bot should reply with the refusal phrase from the system prompt: "I am sorry, but I can only answer questions related to human anatomy and medicine."

In [None]:
# Non-medical "bait" questions to test refusal
BAIT_QUESTIONS = [
    "What is the capital of France?",
    "Can you tell me a joke??",
    "Who is the current president of the United States?",
    "What's 5 multiplied by 7?",
    "Name a popular song from 2023.",
    "How do you make scrambled eggs?",
    "What is the tallest mountain in the world?",
    "Give me a fun fact about space.",
    "What's the meaning of life?",
    "Can you recommend a good book to read?",
]
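Refusal is scored by keyword matching rather than by checking for the exact refusal phrase: inside `evaluate_model`, the safety warning is removed from a bait answer, the text is lowercased, and it counts as a refusal if any refusal keyword appears in it. A minimal sketch using a subset of those keywords (the helper `is_refusal` is hypothetical):

```python
# Subset of the refusal keywords defined in evaluate_model.
REFUSAL_KEYWORDS = ["sorry", "cannot", "unable", "only answer", "outside my scope"]
SAFETY_WARNING = "Always consult a qualified healthcare professional for medical advice."


def is_refusal(generated_text: str) -> bool:
    # Remove the mandatory safety sentence so it cannot trigger a match, then lowercase.
    text = generated_text.replace(SAFETY_WARNING, "").lower()
    return any(keyword in text for keyword in REFUSAL_KEYWORDS)


print(is_refusal("I am sorry, but I can only answer questions related to human anatomy and medicine."))  # True
print(is_refusal("The capital of France is Paris."))  # False
# Keyword matching is coarse: an answer that merely contains "cannot" also counts as a refusal.
print(is_refusal("Paris. I cannot think of a bigger city in France."))  # True
```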

`_parse_and_structure_item(prompt_text, reference_answer)`
- Auto-detects MCQs.
- Restructures them by separating the question stem from the choices and extracting the correct choice from the reference answer.
- Returns a dictionary whose `is_mcq` flag tells the next function whether the item is an MCQ; items that look like MCQs but have no recoverable answer key are treated as open-ended.


In [None]:
def _parse_and_structure_item(prompt_text, reference_answer):
    """
    Auto-detect and structure an MCQ from raw prompt text (logic adapted from the analyzer script).

    Returns {"is_mcq": True, "prompt", "options", "correct_answer_key"} when the prompt splits
    into at least two labelled options AND the reference answer names one of them; otherwise
    returns {"is_mcq": False, "prompt": prompt_text}.
    """
    # Split before option markers such as "A.", "b)", "C:", "d-", "1." or "(A)".
    # "-" is placed last in [.):-] so it is a literal dash rather than a character range.
    splitter_pattern = re.compile(r'\s+(?=[A-Da-d1-4][.):-]|\([A-Da-d1-4]\))')
    parts = splitter_pattern.split(prompt_text)

    # Structural check: a question stem followed by at least two option-like parts.
    if len(parts) >= 3:
        base_prompt = parts[0]
        options_dict = {}
        for part in parts[1:]:
            match = re.match(r'\(?([A-Da-d1-4])\)?[.\s:-]\s*(.*)', part.strip())
            if match:
                key, text = match.groups()
                options_dict[key.upper()] = text.strip()

        if len(options_dict) >= 2:
            # The answer key is either at the start of the reference ("C.", "(c)", "c-")
            # or follows a phrase such as "the correct answer is".
            key_pattern = re.compile(
                r"^\s*\(?([A-D1-4])\)?[.\s:-]|(?:the correct answer is|the answer is|completion:)\s*\(?([A-D1-4])\)?",
                re.IGNORECASE
            )
            match = key_pattern.search(reference_answer)
            found_key = (match.group(1) or match.group(2)) if match else None
            if found_key:
                return {
                    "is_mcq": True,
                    "prompt": base_prompt.strip(),
                    "options": options_dict,
                    "correct_answer_key": found_key.upper(),
                }

    # Anything that does not parse as an MCQ with a recoverable answer key
    # is treated as an open-ended question.
    return {"is_mcq": False, "prompt": prompt_text}
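The MCQ splitter relies on a regex character class of option delimiters. Inside a class, a `-` between two characters denotes a range, so `[\.\)-:]` covers everything from `)` (0x29) to `:` (0x3A), digits included, whereas `[.):-]` (dash last) matches only the four delimiters. The difference shows up on stems that contain numbers; a small demonstration on a made-up prompt:

```python
import re

# Range form: ")-:" spans ")" .. ":", which also includes "*+,-./" and the digits 0-9.
range_class = re.compile(r"\s+(?=[A-Da-d1-4][\.\)-:])")
# Literal form: "-" placed last is a literal dash, so only ".", ")", ":" and "-" qualify.
literal_class = re.compile(r"\s+(?=[A-Da-d1-4][.):-])")

prompt = "Which nerve is at risk in a 12 cm incision? A. Ulnar B. Median C. Radial"
print(range_class.split(prompt))
# ['Which nerve is at risk in a', '12 cm incision?', 'A. Ulnar', 'B. Median', 'C. Radial']
print(literal_class.split(prompt))
# ['Which nerve is at risk in a 12 cm incision?', 'A. Ulnar', 'B. Median', 'C. Radial']
```

With the range form, " 12" looks like an option marker, so the stem is cut short and "12 cm incision?" is discarded as a malformed option.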

`prepare_evaluation_set` builds the evaluation set by randomly sampling `max_samples - len(BAIT_QUESTIONS)` items from the validation split and, for each sampled item:
1. Assigns an ID (`Med-<index>`).
2. Determines the question type (MCQ or open-ended) and structures it accordingly.

It then appends the bait questions, flagged with `is_bait: True`, and shuffles the combined list. The system prompt is not stored in the items; it is added at generation time.
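The sampling-and-mixing logic can be sketched independently of the dataset. This is a simplification: it uses a local `random.Random(seed)` instead of the global seed set earlier and keeps only the `id`, `prompt` and `is_bait` fields; `build_eval_set` and the toy records are hypothetical.

```python
import random


def build_eval_set(records, bait_questions, max_samples, seed=42):
    """Sample medical records, append the bait questions, and shuffle the result."""
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    # Reserve room for every bait question, never sampling more records than exist.
    num_medical = min(max(max_samples - len(bait_questions), 0), len(records))
    indices = list(range(len(records)))
    rng.shuffle(indices)
    items = [
        {"id": f"Med-{i}", "prompt": records[i]["prompt"], "is_bait": False}
        for i in indices[:num_medical]
    ]
    items += [
        {"id": f"Bait-{j}", "prompt": q, "is_bait": True}
        for j, q in enumerate(bait_questions)
    ]
    rng.shuffle(items)
    return items


toy_records = [{"prompt": f"Anatomy question {n}"} for n in range(10)]
toy_bait = ["What is the capital of France?", "Can you tell me a joke?"]
eval_set = build_eval_set(toy_records, toy_bait, max_samples=6)
print(len(eval_set), sum(item["is_bait"] for item in eval_set))  # 6 2
```

As in the notebook, the bait questions are always added in full, so a `max_samples` smaller than the number of bait questions still yields every bait question.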

In [None]:
def prepare_evaluation_set(full_dataset, max_samples: int):
    """
    Prepares the evaluation set from the full dataset.
    """
    print("Preparing evaluation set from full dataset...")
    dev_set = []
    if full_dataset and 'validation' in full_dataset:
        validation_split = full_dataset["validation"]
        num_medical_samples = max_samples - len(BAIT_QUESTIONS)
        if num_medical_samples < 0: num_medical_samples = 0

        # Ensure we don't sample more than available
        num_medical_samples = min(num_medical_samples, len(validation_split))

        indices = list(range(len(validation_split)))
        random.shuffle(indices)
        sampled_indices = indices[:num_medical_samples]
        print(f"Sampling {len(sampled_indices)} medical questions for processing...")

        for i in sampled_indices:
            item = validation_split[i]
            user_prompt, reference_answer = None, None

            # Accommodate different data formats ('messages' vs 'conversations')
            messages = item.get('messages', item.get('conversations', []))

            for message in messages:
                if message.get('role') == 'user': user_prompt = message.get('content')
                elif message.get('role') == 'assistant': reference_answer = message.get('content')

            if not user_prompt: user_prompt = item.get('prompt')
            if not reference_answer: reference_answer = item.get('completion')

            if user_prompt and reference_answer:
                structured_info = _parse_and_structure_item(user_prompt, reference_answer)
                eval_item = {
                    "id": f"Med-{i}", "prompt": structured_info["prompt"],
                    "reference_answer": reference_answer, "is_bait": False,
                    "is_mcq": structured_info["is_mcq"], "expected_tags": item.get("expected_tags", []),
                }
                if structured_info["is_mcq"]:
                    eval_item["options"] = structured_info.get("options", {})
                    eval_item["correct_answer_key"] = structured_info.get("correct_answer_key", "")
                dev_set.append(eval_item)
    else:
        print("Medical dataset not available or invalid. Proceeding with bait questions only.")

    for i, question in enumerate(BAIT_QUESTIONS):
        dev_set.append({"id": f"Bait-{i}", "prompt": question, "reference_answer": "", "is_bait": True, "is_mcq": False, "expected_tags": []})

    random.shuffle(dev_set)
    print(f"Prepared a final set of {len(dev_set)} mixed samples for evaluation.")
    return dev_set

In [None]:
MAX_SAMPLES_TO_EVALUATE = 280  # Keeps evaluation time reasonable: 270 validation samples + 10 bait questions

In [None]:
if ds:
    DEV_SET = prepare_evaluation_set(full_dataset=ds, max_samples=MAX_SAMPLES_TO_EVALUATE)

    if DEV_SET:
        mcq_count = sum(1 for item in DEV_SET if item.get('is_mcq') and not item.get('is_bait'))
        open_ended_count = sum(1 for item in DEV_SET if not item.get('is_mcq') and not item.get('is_bait'))

        print("\n--- Dataset Content Analysis ---")
        print(f"Total Multiple-Choice Questions (MCQs) detected: {mcq_count}")
        print(f"Total Open-Ended Questions detected: {open_ended_count}")
        print(f"Total Bait Questions added: {len(BAIT_QUESTIONS)}")
        print("--------------------------------\n")

        output_filename = "evaluation_set_280_samples.json"
        print(f"Saving the prepared evaluation set to '{output_filename}'...")
        try:
            with open(output_filename, "w", encoding="utf-8") as f:
                json.dump(DEV_SET, f, indent=4)
            print(f"File '{output_filename}' saved successfully.")
        except Exception as e:
            print(f"Error saving file: {e}")

Preparing evaluation set from full dataset...
Sampling 270 medical questions for processing...
Prepared a final set of 280 mixed samples for evaluation.

--- Dataset Content Analysis ---
Total Multiple-Choice Questions (MCQs) detected: 160
Total Open-Ended Questions detected: 110
Total Bait Questions added: 10
--------------------------------

Saving the prepared evaluation set to 'evaluation_set_280_samples.json'...
File 'evaluation_set_280_samples.json' saved successfully.


Example item after preparing evaluation set:

The `evaluate_model` function does the following:
- Cycles through models: it loops through each LLM in `MODELS_TO_EVALUATE`, loading it in bfloat16, optionally with 4-bit NF4 quantization.
- Handles model-specific formats: MedGemma is loaded with `AutoProcessor` and `AutoModelForImageTextToText`, JSL-MedLlama is prompted with its plain `###Question: ... ###Answer:` format, and the remaining models receive the system prompt through their tokenizer's chat template.
- Feeds each question to the model, generates an answer with greedy decoding, and records performance metrics:
    - Generation speed (`tokens/sec`)
    - Memory usage (`VRAM`)
- Applies the right scoring logic for each question type:
    - `Refusal_Acc` for bait questions
    - `MCQ_Acc` for multiple-choice questions
    - `BERTScore/ROUGE` for open-ended questions
- Generates reports: it uses pandas to print a metric summary after each model and a final summary table comparing all models, and saves the prompt-by-prompt results to `detailed_results.csv` and the per-model averages to `summary_results.csv`.
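The MCQ scoring step can be sketched on its own. The two regular expressions below are the ones used in `evaluate_model`; the decision rule here compares the set of distinct, non-negated letters, so repeating the chosen letter is not penalised while naming two different options still scores as wrong. `extract_mcq_choice` is a hypothetical helper name.

```python
import re

# A letter/number followed by ":", ")", "." or "-", optionally preceded by "option", "answer", "choice" or "is".
CHOICE_PATTERN = re.compile(r"(?:^|\b)(?:option|answer|choice|is)?\s*([A-D1-4])\s*[:\)\.\-]", re.IGNORECASE)
# Phrases that explicitly reject an option, e.g. "A is incorrect".
NEGATION_PATTERN = re.compile(r"([A-D1-4])\s*(?:is not correct|is incorrect|is wrong)", re.IGNORECASE)


def extract_mcq_choice(text):
    """Return the single option key named in `text`, or None if zero or several are named."""
    found = {key.upper() for key in CHOICE_PATTERN.findall(text)}
    negated = {key.upper() for key in NEGATION_PATTERN.findall(text)}
    chosen = found - negated
    return next(iter(chosen)) if len(chosen) == 1 else None


print(extract_mcq_choice("C: SGLT-2 inhibitors"))                              # C
print(extract_mcq_choice("C: SGLT-2 inhibitors. Therefore the answer is C."))  # C
print(extract_mcq_choice("A: Sulfonylureas. B: DPP-4 inhibitors."))            # None
```

A prediction is then scored as correct when `extract_mcq_choice(cleaned_text) == correct_key`.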

In [None]:
def evaluate_model(models_to_evaluate, dev_set, log_every_n_samples, batch_size, use_4bit_quantization):
    """
    Main function to load models, run evaluation, and report metrics.
    """
    if not dev_set:
        print("Evaluation cannot proceed without a development set.")
        return
    if not torch.cuda.is_available():
        print("ERROR: This script requires a CUDA-enabled GPU.")
        return

    bertscore = load("bertscore")
    rouge = load("rouge")
    all_results_data = []

    refusal_keywords = [
        "sorry", "cannot", "unable", "not medical", "not anatomy", "only answer",
        "not equipped", "outside my scope", "my purpose is", "my knowledge is limited to",
        "i can only answer questions related to human anatomy and medicine"
    ]
    safety_warning = "Always consult a qualified healthcare professional for medical advice."

    # The global MEDICAL_CHATBOT_SYSTEM_PROMPT will be used here.
    current_system_prompt = MEDICAL_CHATBOT_SYSTEM_PROMPT

    for model_info in models_to_evaluate:
        model_name, model_id = model_info["name"], model_info["model_id"]
        print(f"\n{'='*20}\nEvaluating Model: {model_name} ({model_id})\n{'='*20}")

        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        model, processor_or_tokenizer = None, None
        current_model_results = []

        try:
            print("Loading model and tokenizer/processor...")
            model_load_args = {"device_map": "auto", "trust_remote_code": True}
            is_gemma_model = "gemma" in model_id.lower()
            is_jsl_model = "johnsnowlabs" in model_id.lower()

            if use_4bit_quantization:
                print("  > NOTE: Loading with 4-bit quantization.")
                q_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=True,
                )
                model_load_args["quantization_config"] = q_config
                model_load_args["torch_dtype"] = torch.bfloat16
            else:
                print("  > NOTE: Loading in native precision.")
                model_load_args["torch_dtype"] = torch.bfloat16


            if is_gemma_model:
                processor_or_tokenizer = AutoProcessor.from_pretrained(model_id)
                model = AutoModelForImageTextToText.from_pretrained(model_id, **model_load_args)
            else:
                processor_or_tokenizer = AutoTokenizer.from_pretrained(model_id)
                model = AutoModelForCausalLM.from_pretrained(model_id, **model_load_args)

            actual_tokenizer = processor_or_tokenizer.tokenizer if hasattr(processor_or_tokenizer, 'tokenizer') else processor_or_tokenizer
            if actual_tokenizer.pad_token is None:
                actual_tokenizer.pad_token = actual_tokenizer.eos_token

            if not hasattr(processor_or_tokenizer, 'pad_token_id') or processor_or_tokenizer.pad_token_id is None:
                 processor_or_tokenizer.pad_token_id = actual_tokenizer.pad_token_id


            print("Model loaded successfully.")
            peak_vram_gb = torch.cuda.max_memory_allocated() / (1024**3)

            # Generate one prompt at a time so that every item in dev_set is evaluated.
            # batch_size is accepted for compatibility but is not used for generation.
            for i, item in enumerate(tqdm(dev_set, desc=f"Evaluating {model_name}")):
                user_prompt = item["prompt"]
                if item.get("is_mcq", False):
                    options_str = "\n".join([f"{key}: {value}" for key, value in item["options"].items()])
                    user_prompt = f"{user_prompt}\n\n{options_str}"

                if is_jsl_model:
                    # JSL-MedLlama uses a plain "###Question/###Answer" prompt without a system prompt.
                    prompt_text = f"###Question: {user_prompt} ###Answer:"
                    inputs = processor_or_tokenizer(prompt_text, return_tensors="pt").to(model.device)
                else:
                    full_system_prompt_content = current_system_prompt
                    # REASONING_MODELS lists Hub IDs, so match on model_id rather than the display name.
                    if model_id in REASONING_MODELS:
                        full_system_prompt_content += "\n" + REASONING_FORMAT_INSTRUCTION

                    if is_gemma_model:
                        # The MedGemma processor expects each message's content as a list of typed parts.
                        messages = [
                            {"role": "system", "content": [{"type": "text", "text": full_system_prompt_content}]},
                            {"role": "user", "content": [{"type": "text", "text": user_prompt}]},
                        ]
                    else:
                        messages = [
                            {"role": "system", "content": full_system_prompt_content},
                            {"role": "user", "content": user_prompt},
                        ]
                    inputs = processor_or_tokenizer.apply_chat_template(
                        messages,
                        tokenize=True,
                        add_generation_prompt=True,
                        return_tensors="pt"
                    ).to(model.device)

                start_time = time.perf_counter()

                generate_kwargs = {
                    "max_new_tokens": 1000,
                    "num_return_sequences": 1,
                    "do_sample": False,
                    "pad_token_id": processor_or_tokenizer.pad_token_id,
                }
                if isinstance(inputs, torch.Tensor):
                    output_ids = model.generate(inputs, **generate_kwargs)
                    input_length = inputs.shape[1]
                else:
                    output_ids = model.generate(**inputs, **generate_kwargs)
                    input_length = inputs['input_ids'].shape[1]

                end_time = time.perf_counter()
                total_time = end_time - start_time

                # Decode only the newly generated tokens (for every model, including JSL),
                # which avoids string-matching the prompt out of the full decoded sequence.
                newly_generated_tokens = output_ids[0, input_length:]
                raw_generated_text = processor_or_tokenizer.decode(newly_generated_tokens, skip_special_tokens=True).strip()

                tokens_per_second = len(newly_generated_tokens) / total_time if total_time > 0 else 0

                if i == 0:
                    print(f"\n--- First Output for Model: {model_name} (Prompt ID: {item['id']}) ---")
                    print(f"Question:\n{user_prompt}\n")
                    print(f"Reference Answer:\n{item['reference_answer']}\n")
                    print("Generated Text:")
                    print(raw_generated_text)
                    print("-" * 50)

                refusal_check_text = raw_generated_text.replace(safety_warning, "").lower()
                cleaned_text = re.sub(r'<think>.*?</think>', '', raw_generated_text, flags=re.DOTALL).strip()
                cleaned_text = re.sub(r'<Answer>(.*?)</Answer>', r'\1', cleaned_text, flags=re.DOTALL).strip()

                metrics = {"Model": model_name, "Prompt ID": item["id"], "Is MCQ": item.get("is_mcq", False), "Is Bait": item["is_bait"], "Tokens/Sec": tokens_per_second, "Peak VRAM (GB)": peak_vram_gb, "BERTScore_F1": 0, "ROUGE_L": 0, "Refusal_Acc": 0, "MCQ_Acc": 0}

                if item["is_bait"]:
                    if any(keyword in refusal_check_text for keyword in refusal_keywords):
                        metrics["Refusal_Acc"] = 1
                elif item.get("is_mcq", False):
                    correct_key = item["correct_answer_key"]

                    found_keys = {k.upper() for k in re.findall(r"(?:^|\b)(?:option|answer|choice|is)?\s*([A-D1-4])\s*[:\)\.\-]", cleaned_text, re.IGNORECASE)}
                    negated_keys = {k.upper() for k in re.findall(r"([A-D1-4])\s*(?:is not correct|is incorrect|is wrong)", cleaned_text, re.IGNORECASE)}

                    # Correct only if exactly one distinct, non-negated option is named;
                    # repeating the same letter (e.g. "C: ... so the answer is C.") is allowed.
                    chosen_answers = found_keys - negated_keys
                    metrics["MCQ_Acc"] = 1 if chosen_answers == {correct_key} else 0
                else:
                    bert_results = bertscore.compute(predictions=[cleaned_text], references=[item["reference_answer"]], lang="en")
                    rouge_results = rouge.compute(predictions=[cleaned_text], references=[item["reference_answer"]])
                    metrics["BERTScore_F1"] = bert_results['f1'][0]
                    metrics["ROUGE_L"] = rouge_results['rougeL']

                current_model_results.append(metrics)
                all_results_data.append(metrics)

                samples_processed = i + 1
                if log_every_n_samples > 0 and samples_processed % log_every_n_samples == 0 and samples_processed < len(dev_set):
                    print(f"\n  [Log at sample {samples_processed}/{len(dev_set)}] Model: {model_name}")

                    log_slice = current_model_results[-log_every_n_samples:]

                    if log_slice:
                        df_log = pd.DataFrame(log_slice)

                        # Calculate counts for the current slice
                        slice_mcq_count = df_log['Is MCQ'].sum()
                        slice_bait_count = df_log['Is Bait'].sum()
                        slice_open_ended_count = len(df_log) - slice_mcq_count - slice_bait_count

                        print(f"    > Current Slice Counts: MCQs={slice_mcq_count}, Open-Ended={slice_open_ended_count}, Bait={slice_bait_count}")

                        avg_mcq = df_log[df_log['Is MCQ']]['MCQ_Acc'].mean()
                        avg_bert = df_log[~df_log['Is MCQ'] & ~df_log['Is Bait']]['BERTScore_F1'].mean()
                        avg_rouge = df_log[~df_log['Is MCQ'] & ~df_log['Is Bait']]['ROUGE_L'].mean()
                        avg_refusal = df_log[df_log['Is Bait']]['Refusal_Acc'].mean()
                        print(f"    > Last {len(log_slice)} samples | Avg MCQ Acc: {avg_mcq:.2f} | Avg BERT_F1: {avg_bert:.2f} | Avg ROUGE_L: {avg_rouge:.2f} | Avg Refusal: {avg_refusal:.2f}")


            print("\n") # Newline after tqdm
        except Exception as e:
            print(f"\nERROR: Failed to evaluate model {model_name}. Error: {e}")
            import traceback
            traceback.print_exc()
        finally:
            print(f"Clearing memory after evaluating {model_name}...")
            if 'model' in locals() and model is not None: del model
            if 'processor_or_tokenizer' in locals() and processor_or_tokenizer is not None: del processor_or_tokenizer
            gc.collect()
            torch.cuda.empty_cache()

        if current_model_results:
            df_current = pd.DataFrame(current_model_results)
            summary = df_current.groupby("Model").agg(
                Avg_Tokens_Sec=("Tokens/Sec", "mean"),
                Peak_VRAM_GB=("Peak VRAM (GB)", "first"),
                Avg_MCQ_Acc=("MCQ_Acc", lambda x: x[df_current.loc[x.index, 'Is MCQ']].mean()),
                Avg_OpenEnded_BERT_F1=("BERTScore_F1", lambda x: x[~df_current.loc[x.index, 'Is Bait'] & ~df_current.loc[x.index, 'Is MCQ']].mean()),
                Avg_OpenEnded_ROUGE_L=("ROUGE_L", lambda x: x[~df_current.loc[x.index, 'Is Bait'] & ~df_current.loc[x.index, 'Is MCQ']].mean()),
                Avg_Refusal_Acc=("Refusal_Acc", lambda x: x[df_current.loc[x.index, 'Is Bait']].mean())
            ).reset_index().fillna(0)
            print(f"\n--- METRIC SUMMARY for {model_name} ---")
            print(summary.round(3).to_string(index=False))
            print("-" * 50)

    if not all_results_data:
        print("\nNo overall results to display.")
        return

    pd.set_option('display.max_colwidth', 80)
    pd.set_option('display.width', 120)
    df_detailed = pd.DataFrame(all_results_data)

    df_summary_overall = df_detailed.groupby("Model").agg(
        Avg_Tokens_Sec=("Tokens/Sec", "mean"),
        Peak_VRAM_GB=("Peak VRAM (GB)", "first"),
        Avg_MCQ_Acc=("MCQ_Acc", lambda x: x[df_detailed.loc[x.index, 'Is MCQ']].mean()),
        Avg_OpenEnded_BERT_F1=("BERTScore_F1", lambda x: x[~df_detailed.loc[x.index, 'Is Bait'] & ~df_detailed.loc[x.index, 'Is MCQ']].mean()),
        Avg_OpenEnded_ROUGE_L=("ROUGE_L", lambda x: x[~df_detailed.loc[x.index, 'Is Bait'] & ~df_detailed.loc[x.index, 'Is MCQ']].mean()),
        Avg_Refusal_Acc=("Refusal_Acc", lambda x: x[df_detailed.loc[x.index, 'Is Bait']].mean())
    ).reset_index().fillna(0)

    print("\n\n--- FINAL METRIC SUMMARY (All Models Combined) ---")
    print(df_summary_overall.round(3).to_string(index=False))

    df_detailed.to_csv("detailed_results.csv", index=False)
    df_summary_overall.to_csv("summary_results.csv", index=False)
    print("\nResults saved to detailed_results.csv and summary_results.csv")
    print("\nEvaluation complete.")


## Base Model Evaluation

Metrics used:
* **Tokens/Sec** – speed of generation (tokens per second)
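* **Peak VRAM (GB)** – peak GPU memory allocated by PyTorch (`torch.cuda.max_memory_allocated`), recorded once right after the model finishes loading, so it mostly reflects the weights rather than generation-time memory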
* **BERTScore\_F1** – semantic similarity score for open-ended answers
* **ROUGE\_L** – lexical overlap (longest common subsequence) score
* **Refusal\_Acc** – binary indicator: did the model's answer to a bait question contain a refusal keyword such as "sorry", "cannot" or "only answer"?
* **MCQ\_Acc** – binary indicator: did the model select the correct answer choice?
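ROUGE-L is built on the longest common subsequence (LCS) of tokens: with precision P = LCS / len(prediction) and recall R = LCS / len(reference), the F-measure is 2PR / (P + R). The sketch below re-implements that on lowercase alphanumeric tokens to make the metric concrete. It is a simplification (no stemming, and tokenization details may differ from the `rouge_score` package behind `evaluate`'s `rouge` metric), so treat it as illustrative rather than a drop-in replacement. The reference string is the one from the first MCQ output below; the second prediction is invented.

```python
import re


def _tokens(text):
    # Lowercase and keep alphanumeric runs (an approximation of rouge_score's default tokenizer).
    return re.findall(r"[a-z0-9]+", text.lower())


def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists (dynamic programming)."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(prediction, reference):
    pred, ref = _tokens(prediction), _tokens(reference)
    lcs = lcs_length(pred, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(pred), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


reference = "c- SGLT-2 inhibitors"
print(rouge_l_f1("C: SGLT-2 inhibitors", reference))                                    # 1.0
print(round(rouge_l_f1("SGLT-2 inhibitors reduce glucose reabsorption", reference), 3))  # 0.6
```

In the second case the LCS is 3 tokens ("sglt", "2", "inhibitors"), so P = 3/6 and R = 3/4: a longer, correct explanation is penalised for the extra words, which is why ROUGE-L and BERTScore can disagree on open-ended answers.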

Candidates:
- [Intelligent-Internet/II-Medical-8B · Hugging Face](https://huggingface.co/Intelligent-Internet/II-Medical-8B)
- [ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025 · Hugging Face](https://huggingface.co/ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025)
- [google/medgemma-4b-it · Hugging Face](https://huggingface.co/google/medgemma-4b-it)
- [johnsnowlabs/JSL-MedLlama-3-8B-v2.0 · Hugging Face](https://huggingface.co/johnsnowlabs/JSL-MedLlama-3-8B-v2.0)
- [Qwen/Qwen3-4B · Hugging Face](https://huggingface.co/Qwen/Qwen3-4B)
- [sethuiyer/Medichat-Llama3-8B · Hugging Face](https://huggingface.co/sethuiyer/Medichat-Llama3-8B)

### MedGemma-4B & JSL-MedLlama-3-8B Results

In [None]:
MODELS_TO_EVALUATE = [
    # {"name": "II-Medical-8B", "model_id": "Intelligent-Internet/II-Medical-8B"},
    # {"name": "Bio-Medical-Llama-3-8B", "model_id": "ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025"},
    {"name": "MedGemma-4B", "model_id": "google/medgemma-4b-it"},
    {"name": "JSL-MedLlama-3-8B", "model_id": "johnsnowlabs/JSL-MedLlama-3-8B-v2.0"},
    # {"name": "Qwen2-7B", "model_id": "Qwen/Qwen2-7B-Instruct"},
    # {"name": "Medichat-Llama3-8B", "model_id": "sethuiyer/Medichat-Llama3-8B"}
]

In [None]:
BATCH_SIZE = 8 # @param {type:"integer"}
USE_4BIT_QUANTIZATION = True # @param {type:"boolean"}
LOG_EVERY_N_SAMPLES = 8 # @param {type:"integer"}

evaluate_model(
    models_to_evaluate=MODELS_TO_EVALUATE,
    dev_set=DEV_SET,
    log_every_n_samples=LOG_EVERY_N_SAMPLES,
    batch_size=BATCH_SIZE,
    use_4bit_quantization=USE_4BIT_QUANTIZATION
)


Evaluating Model: MedGemma-4B (google/medgemma-4b-it)
Loading model and tokenizer/processor...
  > NOTE: Loading with 4-bit quantization.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model loaded successfully.


Evaluating MedGemma-4B:   0%|          | 0/35 [00:00<?, ?it/s]


--- First Output for Model: MedGemma-4B (Prompt ID: Med-2692) ---
Question:
For patients with PAD and type 2 diabetes, the

A: Sulfonylureas
B: DPP-4 inhibitors
C: SGLT-2 inhibitors
D: Meglitinides

Reference Answer:
c- SGLT-2 inhibitors

Generated Text:
The correct answer is **C: SGLT-2 inhibitors**. Here's why:

*   **SGLT-2 inhibitors (Sodium-Glucose Co-transporter 2 inhibitors):** These medications work by blocking the reabsorption of glucose in the kidneys, leading to increased glucose excretion in the urine. They have been shown to be beneficial in patients with both PAD and type 2 diabetes, potentially improving cardiovascular outcomes.

Here's why the other options are less suitable:

*   **Sulfonylureas:** These medications stimulate insulin release from the pancreas. While they can effectively lower blood sugar, they have not been shown to have significant benefits in patients with PAD and are not generally recommended as a first-line treatment.
*   **DPP-4 inhibitors (Dipep

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model loaded successfully.


Evaluating JSL-MedLlama-3-8B:   0%|          | 0/35 [00:00<?, ?it/s]


--- First Output for Model: JSL-MedLlama-3-8B (Prompt ID: Med-2692) ---
Question:
For patients with PAD and type 2 diabetes, the

A: Sulfonylureas
B: DPP-4 inhibitors
C: SGLT-2 inhibitors
D: Meglitinides

Reference Answer:
c- SGLT-2 inhibitors

Generated Text:
C: SGLT-2 inhibitors
--------------------------------------------------

  [Log at sample 8/280] Model: JSL-MedLlama-3-8B
    > Current Slice Counts: MCQs=1, Open-Ended=0, Bait=0
    > Last 1 samples | Avg MCQ Acc: 1.00 | Avg BERT_F1: nan | Avg ROUGE_L: nan | Avg Refusal: nan

  [Log at sample 16/280] Model: JSL-MedLlama-3-8B
    > Current Slice Counts: MCQs=1, Open-Ended=0, Bait=0
    > Last 1 samples | Avg MCQ Acc: 1.00 | Avg BERT_F1: nan | Avg ROUGE_L: nan | Avg Refusal: nan

  [Log at sample 24/280] Model: JSL-MedLlama-3-8B
    > Current Slice Counts: MCQs=0, Open-Ended=1, Bait=0
    > Last 1 samples | Avg MCQ Acc: nan | Avg BERT_F1: 0.88 | Avg ROUGE_L: 0.00 | Avg Refusal: nan

  [Log at sample 32/280] Model: JSL-MedLlama-3-
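The log lines report a running average for each slice (MCQ, open-ended, bait) and print `nan` when the current window holds no sample of that kind. For example, `Avg BERT_F1: nan` appears when the window contained only an MCQ. `evaluate_model` is defined earlier in the notebook and is not shown in this section, so the following is only a minimal sketch of that nan-for-empty behaviour. It assumes each scored sample is a dict with a hypothetical `kind` field plus whatever metric values apply to it.

```python
import math
from statistics import fmean

def slice_average(samples, kind, metric):
    """Mean of `metric` over samples of the given kind; NaN if none are present.

    `samples` uses a hypothetical schema:
    {"kind": "mcq" | "open" | "bait", "<metric name>": float, ...}
    """
    values = [s[metric] for s in samples if s["kind"] == kind and metric in s]
    return fmean(values) if values else math.nan

def format_log(samples):
    """Build a one-line summary in the style of the notebook's log output."""
    counts = {k: sum(s["kind"] == k for s in samples) for k in ("mcq", "open", "bait")}
    mcq = slice_average(samples, "mcq", "mcq_acc")
    bert = slice_average(samples, "open", "bert_f1")
    rouge = slice_average(samples, "open", "rouge_l")
    refusal = slice_average(samples, "bait", "refusal")
    # Formatting NaN with ".2f" yields the string "nan".
    return (
        f"MCQs={counts['mcq']}, Open-Ended={counts['open']}, Bait={counts['bait']} | "
        f"Avg MCQ Acc: {mcq:.2f} | Avg BERT_F1: {bert:.2f} | "
        f"Avg ROUGE_L: {rouge:.2f} | Avg Refusal: {refusal:.2f}"
    )

# A window containing a single correctly answered MCQ.
window = [{"kind": "mcq", "mcq_acc": 1.0}]
print(format_log(window))
```

Returning NaN rather than 0.0 keeps an empty slice from being read as a score of zero. If these window averages are aggregated later, the aggregation needs `math.isnan` checks or a NaN-aware mean.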

### Medichat-Llama3-8B Results

In [None]:
MODELS_TO_EVALUATE = [
    # {"name": "II-Medical-8B", "model_id": "Intelligent-Internet/II-Medical-8B"},
    # {"name": "Bio-Medical-Llama-3-8B", "model_id": "ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025"},
    # {"name": "MedGemma-4B", "model_id": "google/medgemma-4b-it"},
    # {"name": "JSL-MedLlama-3-8B", "model_id": "johnsnowlabs/JSL-MedLlama-3-8B-v2.0"},
    # {"name": "Qwen2-7B", "model_id": "Qwen/Qwen2-7B-Instruct"},
    {"name": "Medichat-Llama3-8B", "model_id": "sethuiyer/Medichat-Llama3-8B"}
]
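Each results section re-declares `MODELS_TO_EVALUATE` and comments out every entry except one. A possible alternative, sketched here with hypothetical names (`ALL_MODELS`, `select_models`), keeps a single registry and filters it by name, so switching models becomes a one-line change. The model IDs are copied from the lists in this notebook, and the result has the same list-of-dicts shape that `evaluate_model` receives.

```python
# Single registry of candidate models (IDs copied from this notebook).
ALL_MODELS = [
    {"name": "II-Medical-8B", "model_id": "Intelligent-Internet/II-Medical-8B"},
    {"name": "Bio-Medical-Llama-3-8B", "model_id": "ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025"},
    {"name": "MedGemma-4B", "model_id": "google/medgemma-4b-it"},
    {"name": "JSL-MedLlama-3-8B", "model_id": "johnsnowlabs/JSL-MedLlama-3-8B-v2.0"},
    {"name": "Qwen2-7B", "model_id": "Qwen/Qwen2-7B-Instruct"},
    {"name": "Medichat-Llama3-8B", "model_id": "sethuiyer/Medichat-Llama3-8B"},
]

def select_models(*names):
    """Return registry entries for `names`, in the order requested.

    Raises KeyError for any name that is not in ALL_MODELS, so a typo fails loudly
    instead of silently evaluating nothing.
    """
    by_name = {m["name"]: m for m in ALL_MODELS}
    unknown = [n for n in names if n not in by_name]
    if unknown:
        raise KeyError(f"Unknown model name(s): {unknown}")
    return [by_name[n] for n in names]

MODELS_TO_EVALUATE = select_models("Medichat-Llama3-8B")
```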

In [None]:
BATCH_SIZE = 8 # @param {type:"integer"}
USE_4BIT_QUANTIZATION = True # @param {type:"boolean"}
LOG_EVERY_N_SAMPLES = 8 # @param {type:"integer"}

evaluate_model(
    models_to_evaluate=MODELS_TO_EVALUATE,
    dev_set=DEV_SET,
    log_every_n_samples=LOG_EVERY_N_SAMPLES,
    batch_size=BATCH_SIZE,
    use_4bit_quantization=USE_4BIT_QUANTIZATION
)


Evaluating Model: Medichat-Llama3-8B (sethuiyer/Medichat-Llama3-8B)
Loading model and tokenizer/processor...
  > NOTE: Loading with 4-bit quantization.


Model loaded successfully.


Evaluating Medichat-Llama3-8B:   0%|          | 0/35 [00:00<?, ?it/s]


--- First Output for Model: Medichat-Llama3-8B (Prompt ID: Med-2692) ---
Question:
For patients with PAD and type 2 diabetes, the

A: Sulfonylureas
B: DPP-4 inhibitors
C: SGLT-2 inhibitors
D: Meglitinides

Reference Answer:
c- SGLT-2 inhibitors

Generated Text:
I am Medilearn, an expert AI anatomy tutor for a VR application. I am sorry, but I can only answer questions related to human anatomy and medicine. The correct answer is C: SGLT-2 inhibitors.

SGLT-2 inhibitors are a class of medications that have been shown to improve glycemic control in patients with type 2 diabetes. In addition, SGLT-2 inhibitors have been found to have beneficial effects on cardiovascular outcomes, including a reduction in the incidence of hospitalization for heart failure and a reduction in the incidence of cardiovascular death. The cardiovascular benefits of SGLT-2 inhibitors are thought to be related to their ability to reduce blood pressure and to improve left ventricular function.

Always consult a qua
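Medichat's first output opens with a scope disclaimer ("I am sorry, but I can only answer questions related to human anatomy and medicine") and then answers the question correctly. This section does not show how `Refusal` is scored. If the scorer relies on phrase matching, a response like this one would be counted as a refusal. The sketch below uses hypothetical phrase lists to contrast a naive matcher with a stricter variant that only flags a refusal when no answer follows. It illustrates the pitfall and is not the notebook's scorer.

```python
import re

# Hypothetical refusal phrases; the notebook's actual criteria are not shown here.
REFUSAL_PATTERNS = [
    r"\bI(?: am|'m) sorry\b",
    r"\bI can(?:not|'t) (?:help|answer|provide)\b",
    r"\bI can only answer\b",
]
# Hypothetical signal that the model committed to an answer anyway:
# an explicit "correct answer is", or a line that starts with an option letter.
ANSWER_PATTERN = re.compile(
    r"\bcorrect answer is\b|^\s*\**[A-D]\s*[:.)]", re.IGNORECASE | re.MULTILINE
)

def naive_refusal(text):
    """Flag any response that contains a refusal phrase."""
    return any(re.search(p, text, re.IGNORECASE) for p in REFUSAL_PATTERNS)

def refusal_without_answer(text):
    """Flag a refusal only when the response does not go on to give an answer."""
    return naive_refusal(text) and not ANSWER_PATTERN.search(text)

medichat = (
    "I am Medilearn, an expert AI anatomy tutor for a VR application. "
    "I am sorry, but I can only answer questions related to human anatomy "
    "and medicine. The correct answer is C: SGLT-2 inhibitors."
)
bait_reply = "I'm sorry, but I can only answer questions about anatomy."

print(naive_refusal(medichat), refusal_without_answer(medichat))
```

Whether a disclaimer followed by an answer should count as a refusal is a rubric decision. Phrase matching alone cannot distinguish the two cases, and ASCII-only patterns also miss typographic apostrophes such as "I’m".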

### Qwen2-7B Results

In [None]:
MODELS_TO_EVALUATE = [
    # {"name": "II-Medical-8B", "model_id": "Intelligent-Internet/II-Medical-8B"},
    # {"name": "Bio-Medical-Llama-3-8B", "model_id": "ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025"},
    # {"name": "MedGemma-4B", "model_id": "google/medgemma-4b-it"},
    # {"name": "JSL-MedLlama-3-8B", "model_id": "johnsnowlabs/JSL-MedLlama-3-8B-v2.0"},
    {"name": "Qwen2-7B", "model_id": "Qwen/Qwen2-7B-Instruct"},
    # {"name": "Medichat-Llama3-8B", "model_id": "sethuiyer/Medichat-Llama3-8B"}
]
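The next cell loads Qwen2-7B with `USE_4BIT_QUANTIZATION = True`, and the other runs in this section do the same. A back-of-the-envelope estimate of weight memory shows why. At 16 bits per parameter, an ~8B-parameter model needs about 16 GB for its weights alone. At 4 bits, it needs about 4 GB. The parameter counts below are rounded published figures and should be treated as approximate. The estimate also ignores quantization constants, layers left unquantized, activations, and the KV cache, so real usage is higher.

```python
def weight_memory_gb(n_params, bits_per_param):
    """Approximate memory for model weights only, in GB (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

# Approximate parameter counts (assumption: rounded published figures).
PARAM_COUNTS = {
    "Llama-3-8B-based models": 8.03e9,
    "Qwen2-7B": 7.6e9,
}

for name, n in PARAM_COUNTS.items():
    half = weight_memory_gb(n, 16)  # fp16 / bf16 weights
    four = weight_memory_gb(n, 4)   # 4-bit weights, before quantization overhead
    print(f"{name}: ~{half:.1f} GB at 16-bit, ~{four:.1f} GB at 4-bit (weights only)")
```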

In [None]:
BATCH_SIZE = 8 # @param {type:"integer"}
USE_4BIT_QUANTIZATION = True # @param {type:"boolean"}
LOG_EVERY_N_SAMPLES = 8 # @param {type:"integer"}

evaluate_model(
    models_to_evaluate=MODELS_TO_EVALUATE,
    dev_set=DEV_SET,
    log_every_n_samples=LOG_EVERY_N_SAMPLES,
    batch_size=BATCH_SIZE,
    use_4bit_quantization=USE_4BIT_QUANTIZATION
)


Evaluating Model: Qwen2-7B (Qwen/Qwen2-7B-Instruct)
Loading model and tokenizer/processor...
  > NOTE: Loading with 4-bit quantization.


Model loaded successfully.


Evaluating Qwen2-7B:   0%|          | 0/35 [00:00<?, ?it/s]


--- First Output for Model: Qwen2-7B (Prompt ID: Med-2692) ---
Question:
For patients with PAD and type 2 diabetes, the

A: Sulfonylureas
B: DPP-4 inhibitors
C: SGLT-2 inhibitors
D: Meglitinides

Reference Answer:
c- SGLT-2 inhibitors

Generated Text:
C: SGLT-2 inhibitors

SGLT-2 inhibitors are recommended for patients with PAD ( Peripheral Artery Disease) and type 2 diabetes as they have been shown to improve cardiovascular outcomes and reduce the risk of hospitalization for heart failure. They work by increasing the excretion of glucose in the urine, which can help manage blood sugar levels.

Always consult a qualified healthcare professional for medical advice.
--------------------------------------------------

  [Log at sample 8/280] Model: Qwen2-7B
    > Current Slice Counts: MCQs=1, Open-Ended=0, Bait=0
    > Last 1 samples | Avg MCQ Acc: 1.00 | Avg BERT_F1: nan | Avg ROUGE_L: nan | Avg Refusal: nan

  [Log at sample 16/280] Model: Qwen2-7B
    > Current Slice Counts: MCQs=1, O
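The reference answers use a lowercase letter followed by a hyphen (`c- SGLT-2 inhibitors`). The models answer in several forms: `C: SGLT-2 inhibitors`, `**C: SGLT-2 inhibitors**`, or `The correct answer is C: ...` after a preamble. This section does not show the notebook's MCQ scorer. The sketch below is one assumed way to reduce both sides to a single option letter before comparing them. The regexes are heuristics and can misfire on phrasing such as "the answer is a ...".

```python
import re

# Option letter at the start of a line, optionally wrapped in markdown bold,
# followed by ":", ".", ")" or "-" (covers "C: ...", "**C: ...**", "c- ...").
LEADING_LETTER = re.compile(r"^\s*\**\s*([A-Da-d])\s*[:.)\-]", re.MULTILINE)
# Letter after an explicit "answer is" phrase anywhere in the text.
ANSWER_IS = re.compile(r"answer is\s*\**\s*([A-Da-d])\b", re.IGNORECASE)

def extract_letter(text):
    """Return the chosen option letter in upper case, or None if none is found."""
    match = ANSWER_IS.search(text) or LEADING_LETTER.search(text)
    return match.group(1).upper() if match else None

def mcq_correct(generated, reference):
    """True when both texts resolve to the same option letter."""
    gen, ref = extract_letter(generated), extract_letter(reference)
    return gen is not None and gen == ref

print(mcq_correct("C: SGLT-2 inhibitors", "c- SGLT-2 inhibitors"))
```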

### Bio-Medical-Llama-3-8B Results

In [None]:
MODELS_TO_EVALUATE = [
    # {"name": "II-Medical-8B", "model_id": "Intelligent-Internet/II-Medical-8B"},
    {"name": "Bio-Medical-Llama-3-8B", "model_id": "ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025"},
    # {"name": "MedGemma-4B", "model_id": "google/medgemma-4b-it"},
    # {"name": "JSL-MedLlama-3-8B", "model_id": "johnsnowlabs/JSL-MedLlama-3-8B-v2.0"},
    # {"name": "Qwen2-7B", "model_id": "Qwen/Qwen2-7B-Instruct"},
    # {"name": "Medichat-Llama3-8B", "model_id": "sethuiyer/Medichat-Llama3-8B"}
]

BATCH_SIZE = 8 # @param {type:"integer"}
USE_4BIT_QUANTIZATION = True # @param {type:"boolean"}
LOG_EVERY_N_SAMPLES = 8 # @param {type:"integer"}

evaluate_model(
    models_to_evaluate=MODELS_TO_EVALUATE,
    dev_set=DEV_SET,
    log_every_n_samples=LOG_EVERY_N_SAMPLES,
    batch_size=BATCH_SIZE,
    use_4bit_quantization=USE_4BIT_QUANTIZATION
)


Evaluating Model: Bio-Medical-Llama-3-8B (ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025)
Loading model and tokenizer/processor...
  > NOTE: Loading with 4-bit quantization.


Model loaded successfully.


Evaluating Bio-Medical-Llama-3-8B:   0%|          | 0/35 [00:00<?, ?it/s]


--- First Output for Model: Bio-Medical-Llama-3-8B (Prompt ID: Med-2692) ---
Question:
For patients with PAD and type 2 diabetes, the

A: Sulfonylureas
B: DPP-4 inhibitors
C: SGLT-2 inhibitors
D: Meglitinides

Reference Answer:
c- SGLT-2 inhibitors

Generated Text:
<think>
Okay, so I need to figure out which medications are used for patients with PAD and type 2 diabetes. Let me start by recalling what each of these drugs does.

First, I remember that PAD stands for peripheral artery disease. It's related to the blood supply in the legs, especially in the lower extremities. So, the medications here must be dealing with something like blood flow or reducing complications related to that.

Type 2 diabetes makes me think about managing blood sugar. But wait, the question is about PAD, so maybe it's more about the cardiovascular aspects. I think drugs that affect the heart or blood flow might be involved.

Looking at the options: A is sulfonylureas. I know sulfonylureas are for managing bl
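Bio-Medical-Llama-3-8B (the CoT checkpoint) opens its answer with a `<think>` reasoning block, and the excerpt above is truncated inside it. For MCQ and ROUGE/BERTScore scoring, it may be preferable to score only the text after `</think>`. If a generation stops before the block closes, no final answer exists. The sketch below assumes exactly that tag format and treats an unclosed block as an empty answer. The notebook's own handling, if any, is not shown here.

```python
import re

# Closed reasoning blocks; DOTALL lets ".*?" span newlines inside the block.
THINK_BLOCK = re.compile(r"<think>.*?</think>", re.DOTALL)

def strip_reasoning(text):
    """Remove closed <think>...</think> blocks; drop everything after an unclosed <think>."""
    text = THINK_BLOCK.sub("", text)
    # A leftover opening tag means generation stopped mid-reasoning.
    unclosed = text.find("<think>")
    if unclosed != -1:
        text = text[:unclosed]
    return text.strip()

print(repr(strip_reasoning("<think>\nOkay, so I need to figure out...</think>\nC: SGLT-2 inhibitors")))
```

Under this choice, an empty string counts as a wrong MCQ answer. That is arguably the honest outcome for a run that never committed to an option.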

In [None]:
def print_prompt_by_id(prompt_id, dataset, bait_questions):
    """Look up a prompt by ID and print its text and MCQ flag.

    "Med-<i>" IDs index dataset['validation']; "Bait-<i>" IDs index bait_questions.
    """
    try:
        if prompt_id.startswith("Med-"):
            index = int(prompt_id.split('-')[1])
            item = dataset['validation'][index]

            # Extract prompt and reference answer from chat-style messages, if present
            prompt, ref_answer = "", ""
            for msg in item.get('messages') or []:
                if msg.get('role') == 'user':
                    prompt = msg.get('content')
                elif msg.get('role') == 'assistant':
                    ref_answer = msg.get('content')

            # Fall back to flat prompt/completion fields; never pass None to the parser
            prompt = prompt or item.get('prompt') or ""
            ref_answer = ref_answer or item.get('completion') or ""

            # Run the same parser used during evaluation to see how the item is flagged
            structured_info = _parse_and_structure_item(prompt, ref_answer)
            is_mcq_flag = structured_info.get("is_mcq", False)

            print(f"--- Prompt for {prompt_id} ---")
            print(f"IS_MCQ Flag: {is_mcq_flag}")
            print(f"Prompt Text:\n{prompt}\n")

        elif prompt_id.startswith("Bait-"):
            index = int(prompt_id.split('-')[1])
            prompt = bait_questions[index]
            print(f"--- Prompt for {prompt_id} ---")
            print("IS_MCQ Flag: False (Bait Question)")
            print(f"Prompt Text:\n{prompt}\n")
        else:
            print(f"Unknown ID format: {prompt_id}")

    # KeyError also covers a dataset without a 'validation' split
    except (IndexError, KeyError, ValueError) as e:
        print(f"Could not find or parse ID: {prompt_id}. Error: {e}")
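The next cell prints three prompts. Med-1823 lists options A to D but is reported with `IS_MCQ Flag: False`. Its options are separated by the literal two characters `/n` rather than a newline. Med-2279 and Med-1698 instead put `A.` and `B.` inline after a space, and both are flagged `True`. `_parse_and_structure_item` is defined elsewhere and its rules are not shown in this section, so the `/n` explanation is an inference from the printed output. The sketch below is a hypothetical detector that rewrites literal `/n` (or `\n`) to a real newline, but only when it directly precedes an option marker, and then requires the letters A, B, C, ... to appear in order.

```python
import re

# Literal "/n" or escaped "\n" immediately before an option marker such as "A." or "B)".
LITERAL_NEWLINE = re.compile(r"(?:/n|\\n)(?=\s*[A-E][.)])")
# Option marker: capital A-E plus "." or ")" and whitespace, at a line start
# or right after whitespace/punctuation (e.g. "in A. ", ", B. ").
OPTION_MARKER = re.compile(r"(?:^|(?<=[\s,;:]))([A-E])[.)]\s", re.MULTILINE)

def normalise_prompt(prompt):
    """Turn literal '/n' separators before option markers into real newlines."""
    return LITERAL_NEWLINE.sub("\n", prompt)

def looks_like_mcq(prompt, min_options=2):
    """True if option markers A, B, C, ... appear in order (at least `min_options`)."""
    letters = OPTION_MARKER.findall(normalise_prompt(prompt))
    expected = "ABCDE"
    # Greedily match the sequence A, B, C, ... against the markers, skipping strays.
    run = 0
    for letter in letters:
        if run < len(expected) and letter == expected[run]:
            run += 1
    return run >= min_options

med_1823 = (
    "Physiologic compensatory mechanisms that minimize V/Q mismatching include :"
    "/nA. High alveolar PO2 causes bronchoconstriction."
    "/nB. High alveolar PCO2 causes bronchoconstriction."
    "/nC. High alveolar PO2 causes bronchodilatation."
    "/nD. High alveolar PCO2 causes bronchodilatation."
)
med_2279 = (
    "Living antigen is used in A. Leishmanin skin test, B. Sabin Feldman dye test, "
    "C. Card Agglutination trypansomiasis test, D. IgG avidity test"
)

print(looks_like_mcq(med_1823), looks_like_mcq(med_2279))
```

Restricting the `/n` rewrite to positions before an option marker avoids touching other slash-n sequences in ordinary text.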

In [None]:
prompt_ids_to_check = ["Med-1823", "Med-2279", "Med-1698"]
for pid in prompt_ids_to_check:
    print_prompt_by_id(pid, ds, BAIT_QUESTIONS)

--- Prompt for Med-1823 ---
IS_MCQ Flag: False
Prompt Text:
Physiologic compensatory mechanisms that minimize V/Q mismatching include :/nA. High alveolar PO2 causes bronchoconstriction./nB. High alveolar PCO2 causes bronchoconstriction./nC. High alveolar PO2 causes bronchodilatation./nD. High alveolar PCO2 causes bronchodilatation.

--- Prompt for Med-2279 ---
IS_MCQ Flag: True
Prompt Text:
Living antigen is used in A. Leishmanin skin test, B. Sabin Feldman dye test, C. Card Agglutination trypansomiasis test, D. IgG avidity test

--- Prompt for Med-1698 ---
IS_MCQ Flag: True
Prompt Text:
According to Bernoulli's principle as applied to fluid dynamics in blood vessels: A. The sum of potential energy and kinetic energy remains constant along a streamline (assuming no energy loss). B. Kinetic energy typically decreases as blood flows through an atherosclerotic (narrowed) vessel segment. C. Potential energy and kinetic energy are inversely related; where velocity (kinetic energy) is high, 