# Finetuning

I'm going to be finetuning Qwen2-7B-Instruct with GRPO, using Unsloth.

Open in [Colab](https://colab.research.google.com/drive/1x9rrEn2_c-I-V4ThzdYqL2dwBSJJ8USA?usp=sharing)

### Installation

In [None]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    !pip install --no-deps unsloth vllm==0.8.5.post1

In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm==0.8.5.post1
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt
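
The `re.sub` call in the cell above removes the `transformers`, `numpy`, and `xformers` lines from vLLM's requirements file so that pip does not reinstall them. Here is a minimal offline sketch of the same pattern, run on a made-up requirements snippet (the package pins are illustrative, not the real file contents):

```python
import re

# Same pattern as in the install cell: the package name, at least one more
# character on that line, and the trailing newline.
PATTERN = rb"(transformers|numpy|xformers)[^\n]{1,}\n"

# Hypothetical requirements text, for illustration only.
sample = (
    b"requests >= 2.26.0\n"
    b"numpy < 2.0.0\n"
    b"transformers >= 4.51.1\n"
    b"tqdm\n"
    b"numpy\n"
)

filtered = re.sub(PATTERN, b"", sample)
print(filtered.decode())
```

One detail this makes visible: a bare `numpy` line with no version specifier is not removed, because the pattern requires at least one character after the package name.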

In [None]:
#@title Libraries Extra Install { display-mode: "form" }
!pip install bert_score

Collecting bert_score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Downloading bert_score-0.3.13-py3-none-any.whl (61 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 61.1/61.1 kB 2.4 MB/s eta 0:00:00
Installing collected packages: bert_score
Successfully installed bert_score-0.3.13


# Setup Qwen

In [None]:
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
import os
import gc
import transformers
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from huggingface_hub import hf_hub_download
import json
import random
import time
import re
import numpy as np
import warnings
import bert_score

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 07-05 09:22:29 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 07-05 09:22:29 [__init__.py:239] Automatically detected platform cuda.


In [None]:
max_seq_length = 1500 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "Qwen/Qwen2-7B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.8, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

==((====))==  Unsloth 2025.6.12: Fast Qwen2 patching. Transformers: 4.53.0. vLLM: 0.8.5.post1.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/qwen2-7b-instruct-bnb-4bit with actual GPU utilization = 79.08%
Unsloth: Your GPU has CUDA compute capability 8.0 with VRAM = 39.56 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 1500. Num Sequences = 320.
Unsloth: vLLM's KV Cache can use up to 25.15 GB. Also swap space = 6 GB.
INFO 07-05 04:46:57 [config.py:717] This model supports multiple tasks: {'classify', 'generate', 'score', 'reward', 'embed'}. Defaulting to 'generate'.
INFO 07-05 04:46:57 [config.py

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 07-05 04:47:03 [punica_selector.py:18] Using PunicaWrapperGPU.
INFO 07-05 04:47:03 [gpu_model_runner.py:1347] Model loading took 5.5142 GiB and 4.292228 seconds
INFO 07-05 04:47:20 [backends.py:420] Using cache directory: /root/.cache/vllm/torch_compile_cache/e581ad8dfd/rank_0_0 for vLLM's torch.compile
INFO 07-05 04:47:20 [backends.py:430] Dynamo bytecode transform time: 16.26 s
INFO 07-05 04:47:30 [backends.py:118] Directly load the compiled graph(s) for shape None from the cache, took 7.210 s
INFO 07-05 04:47:33 [monitor.py:33] torch.compile takes 16.26 s in total
INFO 07-05 04:47:34 [kv_cache_utils.py:634] GPU KV cache size: 439,360 tokens
INFO 07-05 04:47:34 [kv_cache_utils.py:637] Maximum concurrency for 1,500 tokens per request: 292.91x
INFO 07-05 04:48:43 [gpu_model_runner.py:1686] Graph capturing finished in 69 secs, took 1.33 GiB
INFO 07-05 04:48:43 [core.py:159] init engine (profile, create kv cache, warmup model) took 99.44 seconds
Unsloth: Just some info: will skip pa

Unsloth 2025.6.12 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.
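
The LoRA settings above determine how many parameters are trained. As a rough cross-check, the count can be reproduced by hand: each adapted linear layer of shape (in, out) adds `r * (in + out)` weights. The layer dimensions below are assumptions taken from the public Qwen2-7B configuration, not values read from the loaded model.

```python
# Qwen2-7B layer shapes (assumed from the public model config).
HIDDEN = 3584          # hidden_size
INTERMEDIATE = 18944   # intermediate_size (MLP width)
KV_DIM = 4 * 128       # 4 key/value heads x head_dim 128
NUM_LAYERS = 28
RANK = 64              # lora_rank from the cell above

# (in_features, out_features) of each targeted projection
shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_params(in_features: int, out_features: int, rank: int) -> int:
    # LoRA adds A (rank x in_features) and B (out_features x rank).
    return rank * (in_features + out_features)

per_layer = sum(lora_params(i, o, RANK) for i, o in shapes.values())
total = per_layer * NUM_LAYERS
print(f"{total:,}")  # 161,480,704
```

This agrees with the `Trainable parameters = 161,480,704` figure that the trainer banner prints in the Training section.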


# Data Prep

In [None]:
DATASET_REPO_ID = "Anatomy-Tutor/Anatomy-and-Medical-Dataset"
DATASET_FILENAME = "processed_medical_and_anatomy.json"

In [None]:
# Load the dataset
print("Loading dataset from Hugging Face Hub...")
try:
    hf_data_path = hf_hub_download(
        repo_id=DATASET_REPO_ID,
        filename=DATASET_FILENAME,
        repo_type="dataset"
    )
    with open(hf_data_path, "r", encoding="utf-8") as f:
        splits = json.load(f)
    ds = DatasetDict({
        "train": Dataset.from_list(splits["train"]),
        "validation": Dataset.from_list(splits["validation"]),
        "test": Dataset.from_list(splits["test"]),
    })
except Exception as e:
    print(f"Failed to load or process dataset. Error: {e}")
    raise  # Re-raise so later cells do not run with `ds` undefined

Loading dataset from Hugging Face Hub...


In [None]:
specific_refusal_phrase = "I am sorry, but I can only answer questions related to human anatomy and medicine."

# f-string so that {specific_refusal_phrase} in rule 1 is actually substituted
SYSTEM_PROMPT = f"""
You are "Medilearn," an expert AI anatomy tutor for a VR application. Your goal is to provide clear, accurate, and educational explanations.

**Rules:**
1.  **Stay On Topic:** Politely refuse any question not related to human anatomy or medicine. When you refuse a question, you MUST begin your response with the exact phrase: "{specific_refusal_phrase}"
2.  **Be Direct and Unambiguous:** Provide answers that are clear and to the point. Avoid hedging or overly conversational filler.
3.  **MCQ Answering:** For multiple-choice questions, start your response by stating the correct letter or number, followed by a colon and then your brief explanation. For example: "A: This is the explanation."
4.  **End Your Turn:** After providing the complete answer and the mandatory safety warning, you MUST output the special token `<|end_of_turn|>`.
5.  **Safety First:** Your final sentence before the end-of-turn token must be: "Always consult a qualified healthcare professional for medical advice."

**Formatting Requirement:**
Respond in the following format:
<reasoning>
... your explanation goes here ...
</reasoning>
<answer>
... final answer here ...
</answer>
"""


XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()
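
The split-based helper above behaves differently from the regex-based `extract_xml_answer` that the Reward Functions section defines later (and which replaces it). A small sketch comparing the two on a well-formed completion and on one without tags; the function names `extract_by_split` and `extract_by_regex` are just labels for this comparison:

```python
import re

def extract_by_split(text: str) -> str:
    # Same logic as the helper above.
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_by_regex(text: str) -> str:
    # Same logic as the later, regex-based definition.
    match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    return match.group(1).strip() if match else ""

well_formed = "<reasoning>\nThigh bone.\n</reasoning>\n<answer>\nA: Femur\n</answer>\n"
no_tags = "The femur is the thigh bone."

print(repr(extract_by_split(well_formed)))  # 'A: Femur'
print(repr(extract_by_regex(well_formed)))  # 'A: Femur'
print(repr(extract_by_split(no_tags)))      # 'The femur is the thigh bone.'
print(repr(extract_by_regex(no_tags)))      # ''
```

When the tags are missing, the split version returns the whole completion, while the regex version returns an empty string. Which of the two is in effect matters for any reward that scores the extracted answer.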

In [None]:
# Common refusal templates
refusal_responses = [
    f"{specific_refusal_phrase} My focus is strictly on human health and body systems.",
    f"{specific_refusal_phrase} Therefore, I cannot provide information on this topic.",
    f"{specific_refusal_phrase} My knowledge base is specialized in the medical field.",
    f"{specific_refusal_phrase} I'm designed to assist with anatomical and medical inquiries only.",
    f"{specific_refusal_phrase} This question falls outside my area of expertise.",
    f"{specific_refusal_phrase} I recommend seeking information from a source specialized in that subject.",
    f"{specific_refusal_phrase} My purpose is to educate on human anatomy and medicine.",
    f"{specific_refusal_phrase} I'm unable to discuss non-health related matters.",
    f"{specific_refusal_phrase} Please ensure your questions pertain to the human body or medical science.",
    f"{specific_refusal_phrase} I can only engage with topics within the scope of human anatomy and medicine."
]

# Sample non-medical questions
refusal_questions = [
    "What's the capital of France?",
    "Tell me a joke.",
    "Who won the World Cup in 2018?",
    "How do I fix my car's engine?",
    "What's the weather like today?",
    "How do I start a business?",
    "Who painted the Mona Lisa?",
    "What year did the Titanic sink?",
    "Which Roman emperor made Christianity the state religion?",
    "What is the highest mountain in Africa?",
    "What is the currency of Japan?",
    "When was the Declaration of Independence signed?",
    "Who was the first person to walk on the moon?",
    "What is the longest river in the world?",
    "Which country is famous for the Eiffel Tower?",
    "What is the smallest continent by land area?",
    "How does Wi-Fi work?",
    "What is an algorithm?",
    "Explain the concept of cloud computing.",
    "What does CPU stand for?",
    "How do I clear my computer's cache?",
    "What's the difference between RAM and ROM?",
    "How can I improve my phone's battery life?",
    "What is cybersecurity?",
    "What programming language is used for web development?",
    "How do search engines rank websites?",
    "Who wrote 'Romeo and Juliet'?",
    "What is a sonnet?",
    "Name a famous opera composer.",
    "What is the primary art form of ballet?",
    "Who composed the 'Moonlight Sonata'?",
    "What is the meaning of 'carpe diem'?",
    "Name a dystopian novel.",
    "What is the difference between prose and poetry?",
    "Who is considered the father of English literature?",
    "What is impressionism in art?",
    "How do I bake a chocolate cake?",
    "What are the rules of chess?",
    "How to change a car tire?",
    "What's the best way to grow tomatoes?",
    "How do I knit a scarf?",
    "What are some good tips for budgeting money?",
    "How to train a puppy?",
    "What are common misconceptions about sleep?",
    "How do you compost kitchen waste?",
    "What are the basic steps to learning a new language?",
    "What is the meaning of life?",
    "Define 'justice.'",
    "What is free will?",
    "Explain the concept of infinity.",
    "What is the purpose of art?",
    "How do we know what is real?",
    "What is happiness?",
    "Discuss the ethics of artificial intelligence.",
    "What is the nature of time?",
    "What is the role of government in society?"
]

In [None]:
def make_refusal_example(question):
    return {
        "messages": [
            {"role": "user", "content": question},
            {"role": "assistant", "content": random.choice(refusal_responses)},
        ]
    }

# Generate examples
new_refusal_examples = [make_refusal_example(q) for q in refusal_questions]
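
Because `random.choice` is unseeded here, the question-to-refusal pairing changes on every run. If reproducibility matters, one option is a dedicated seeded generator. This is a sketch only; `build_refusal_examples` and the seed value are hypothetical and not part of the original notebook:

```python
import random

def build_refusal_examples(questions, responses, seed=3407):
    # A private Random instance keeps the pairing stable without
    # touching the global random state.
    rng = random.Random(seed)
    return [
        {
            "messages": [
                {"role": "user", "content": q},
                {"role": "assistant", "content": rng.choice(responses)},
            ]
        }
        for q in questions
    ]

# Tiny illustrative inputs.
questions = ["What's the capital of France?", "Tell me a joke."]
responses = ["Refusal A.", "Refusal B.", "Refusal C."]

first = build_refusal_examples(questions, responses)
second = build_refusal_examples(questions, responses)
print(first == second)  # True
```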

In [None]:
# Add to the existing train dataset
original_train_data = ds["train"].to_list()
combined_data = original_train_data + new_refusal_examples

# Convert back to HF Dataset
ds["train"] = Dataset.from_list(combined_data)


In [None]:
def prepare_dataset(initial_dataset: Dataset) -> list[dict]:
    """
    Reformats an already loaded dataset with a 'messages' structure for GRPO training,
    returning a list of dictionaries. Each dictionary will contain a 'prompt'
    (in chatML format) and a 'reference_answer'.

    Args:
        initial_dataset (Dataset): The already loaded dataset (e.g., a specific split like train or validation).

    Returns:
        list[dict]: A list of dictionaries, where each dictionary represents a reformatted data sample.
    """
    formatted_data = []
    # Iterate through the dataset to create a list of dictionaries
    for item in initial_dataset:
        user_prompt_content = None
        assistant_answer_content = None

        # Extract user prompt and assistant answer from the 'messages' list
        # This assumes 'messages' is a list of dictionaries within each item.
        for message in item.get("messages", []):
            if message.get("role") == "user":
                user_prompt_content = message.get("content")
            elif message.get("role") == "assistant":
                assistant_answer_content = message.get("content")

        # Ensure both user prompt and assistant answer are found
        if user_prompt_content is not None and assistant_answer_content is not None:
            # Create the prompt in chatML format, including the general SYSTEM_PROMPT
            prompt_messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt_content},
            ]

            # Create a dictionary for the current item
            formatted_item = {
                "prompt": prompt_messages,
                "answer": assistant_answer_content,
            }
            formatted_data.append(formatted_item)
        else:
            print(f"Skipping item due to missing user prompt or assistant answer in messages: {item}")

    return formatted_data

In [None]:
my_prepared_dataset = prepare_dataset(initial_dataset=ds["train"])
print(my_prepared_dataset[0])
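
Before training, it can help to confirm that every prepared record has the shape the reward functions rely on: a `prompt` list ending in the user message, plus a non-empty `answer` string. A minimal sketch of such a check; `find_problems` and the sample records are hypothetical:

```python
def find_problems(record: dict) -> list[str]:
    """Return a list of human-readable issues found in one prepared record."""
    problems = []
    roles = [m.get("role") for m in record.get("prompt", [])]
    if roles != ["system", "user"]:
        problems.append(f"unexpected prompt roles: {roles}")
    answer = record.get("answer")
    if not isinstance(answer, str) or not answer.strip():
        problems.append("missing or empty reference answer")
    return problems

# Hypothetical records, for illustration only.
ok = {
    "prompt": [
        {"role": "system", "content": "..."},
        {"role": "user", "content": "Name the thigh bone."},
    ],
    "answer": "The femur.",
}
bad = {"prompt": [{"role": "user", "content": "Hi"}], "answer": ""}

print(find_problems(ok))   # []
print(find_problems(bad))
```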



# Reward Functions


In [None]:
reasoning_start = "<reasoning>"
reasoning_end   = "</reasoning>"
solution_start  = "<answer>"
solution_end    = "</answer>"

In [None]:
match_format = re.compile(
    rf"^[\s]{{0,}}"\
    rf"{reasoning_start}.+?{reasoning_end}.*?"\
    rf"{solution_start}(.+?){solution_end}"\
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL
)
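
To see what this pattern accepts, here is a self-contained sketch that rebuilds the same regex and runs it on three made-up completions:

```python
import re

reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"

# Same pattern as in the cell above.
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

good = (
    "<reasoning>\nThe femur is the thigh bone.\n</reasoning>\n"
    "<answer>\nA: Femur\n</answer>\n"
)
missing_reasoning = "<answer>A: Femur</answer>"
trailing_text = good + "Extra text after the closing tag."

print(bool(match_format.search(good)))               # True
print(match_format.search(good).group(1).strip())    # A: Femur
print(bool(match_format.search(missing_reasoning)))  # False
print(bool(match_format.search(trailing_text)))      # True
```

The last case is worth noting: with `re.MULTILINE`, `$` also matches at the end of any line, so text on the lines after `</answer>` does not prevent a match. If a strict whole-string match is wanted, dropping `re.MULTILINE` or using `\A` and `\Z` would be one way to get it.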

In [None]:
def extract_xml_answer(text: str) -> str:
    match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return ""

def extract_xml_section(text: str, section: str) -> str:
    m = re.search(f"<{section}>(.*?)</{section}>", text, re.DOTALL)
    return m.group(1).strip() if m else ""


def answer_semantic_similarity(completions, prompts, answer, **kwargs):
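    """
    Rewards based on the semantic similarity (BERTScore F1) between
    the model's generated answer and the reference answer.
    Top priority: the baseline-rescaled BERTScore F1 is multiplied by 10.
    Because of `rescale_with_baseline=True`, the result is at most +10 but
    can be negative (the logged means in this run are roughly -20 to -40).
    """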
    import inspect

    # Try to extract current step
    step = kwargs.get("step", None)
    if step is None:
        for frame_info in inspect.stack():
            frame = frame_info.frame
            if "self" in frame.f_locals and hasattr(frame.f_locals["self"], "state"):
                step = getattr(frame.f_locals["self"].state, "global_step", None)
                break

    # 1) Extract model‐generated answers
    responses = [c[0]["content"] for c in completions]
    preds = [extract_xml_answer(r) for r in responses]

    # 2) Debug print ONLY if step was found and is a multiple of 100
    if step is not None and step % 100 == 0:
        print("\n" + "=" * 40)
        print(f"[Step {step}] Semantic Similarity Debug")
        for i, (prompt, r, p) in enumerate(zip(prompts, responses, preds)):
            print(f"Prompt:\n{prompt[-1]['content']}")
            print(f"Model Output:\n{r}")
            print(f"Extracted Answer:\n{p}")
            print("-" * 30)
        print("=" * 40 + "\n")

    # 3) Align with references
    refs = answer[:len(preds)]

    # 4) Compute BERTScore
    P, R, F1 = bert_score.score(
        preds,
        refs,
        lang="en",
        rescale_with_baseline=True
    )

    # 5) Scale each F1 value by 10
    return [f1 * 10.0 for f1 in F1.tolist()]



def match_format_exactly(completions, **kwargs):
    """
    Rewards for exact adherence to the specified XML output format.
    High priority: exact format match → +8, otherwise 0.
    """
    scores = []
    for comp in completions:
        scores.append(8.0 if match_format.search(comp[0]["content"]) else 0.0)
    return scores


def match_format_approximately(completions, **kwargs):
    """
    Rewards for approximate adherence to the XML output format by counting tags.
    Medium priority: each correct tag +1, each missing/extra –1.
    Total range: 4 tags → [–4, +4]
    """
    scores = []
    for comp in completions:
        r = comp[0]["content"]
        score = 0
        score += 1.0 if r.count(reasoning_start) == 1 else -1.0
        score += 1.0 if r.count(reasoning_end) == 1 else -1.0
        score += 1.0 if r.count(solution_start) == 1 else -1.0
        score += 1.0 if r.count(solution_end) == 1 else -1.0
        scores.append(score)
    return scores


def on_topic_refusal(completions, **kwargs):
    """
    Rewards answers that do not contain refusal or uncertainty phrases.
    Any answer containing one of the phrases below scores 0; all others score +2.
    Note: the check does not look at the prompt, so a correct refusal of an
    out-of-scope question (which begins with "I am sorry, ...") also scores 0.
    Medium-low priority: range [0, 2].
    """
    bad_kw = [
        "i don't know", "dont know", "not sure", "no idea", "beyond my knowledge",
        "outside my knowledge", "off-topic", "irrelevant", "i have no information",
        "i'm sorry", "i am sorry", "apologies", "i apologize", "my apologies",
        "cannot answer", "can’t answer", "unable to answer", "cannot provide",
        "outside my scope", "knowledge cutoff", "cutoff date",
    ]
    scores = []
    for comp in completions:
        ans = extract_xml_section(comp[0]["content"], "answer").lower()
        scores.append(0.0 if any(k in ans for k in bad_kw) else 2.0)
    return scores


def response_pacing_and_length(completions, **kwargs):
    """
    Rewards for the length and pacing of the reasoning and answer sections.
    Low priority: raw pacing score [0,1], scale to [0, 1.5].
    """
    scores = []
    for comp in completions:
        txt = comp[0]["content"]
        r = extract_xml_section(txt, "reasoning").split()
        a = extract_xml_section(txt, "answer").split()
        mid_r, mid_a = (20 + 300) / 2, (5 + 200) / 2
        sc_r = max(0, 1 - abs(len(r) - mid_r) / (300 - 20))
        sc_a = max(0, 1 - abs(len(a) - mid_a) / (200 - 5))
        # average [0,1], then *1.5
        scores.append(((sc_r + sc_a) / 2) * 1.5)
    return scores


def disclaimer_presence(completions, **kwargs):
    """
    Rewards for including the mandatory safety disclaimer as the final sentence
    within the <answer> tag. Penalizes if missing.
    Medium priority: +3 for exact final placement, +2 for presence, -1 for missing.
    """
    scores = []
    mandatory_disclaimer = "Always consult a qualified healthcare professional for medical advice."

    for comp in completions:
        assistant_content = comp[0]["content"]
        extracted_answer = extract_xml_answer(assistant_content)
        cleaned_answer = extracted_answer.replace("<|end_of_turn|>", "").strip()

        sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', cleaned_answer) if s.strip()]

        if sentences and sentences[-1] == mandatory_disclaimer:
            scores.append(3.0)
        elif mandatory_disclaimer in cleaned_answer:
            scores.append(2.0)
        else:
            scores.append(-1.0)
    return scores
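
The pacing reward in `response_pacing_and_length` is easier to reason about with concrete word counts. The sketch below restates its arithmetic as a standalone function (`pacing_score` is a hypothetical name) that takes the two word counts directly:

```python
def pacing_score(n_reasoning_words: int, n_answer_words: int) -> float:
    # Targets are the midpoints of [20, 300] and [5, 200] words.
    mid_r, mid_a = (20 + 300) / 2, (5 + 200) / 2   # 160.0 and 102.5
    sc_r = max(0, 1 - abs(n_reasoning_words - mid_r) / (300 - 20))
    sc_a = max(0, 1 - abs(n_answer_words - mid_a) / (200 - 5))
    return ((sc_r + sc_a) / 2) * 1.5

print(pacing_score(20, 5))            # 0.75  (both at the lower bound)
print(round(pacing_score(0, 0), 3))   # 0.677 (both sections empty)
print(pacing_score(1000, 1000))       # 0.0   (both far too long)
```

An empty response still earns about 0.68, which is close to the 0.70 to 0.82 means logged for this reward during training. In practice the term mostly penalizes very long outputs and does little to separate short ones.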

In [None]:
reward_funcs = [
    answer_semantic_similarity, # at most 10; can be negative after baseline rescaling
    match_format_exactly, # [0,8]
    match_format_approximately, # [–4,+4]
    on_topic_refusal, # [0,2]
    response_pacing_and_length, # [0,1.5]
    disclaimer_presence # [-1, 3]
]
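
The `answer_semantic_similarity` means logged during training are far below zero, even though a raw BERTScore F1 lies in [0, 1]. A likely explanation is the `rescale_with_baseline=True` option, which maps a raw score `x` to `(x - b) / (1 - b)` for a model-specific baseline `b`. The sketch below illustrates the effect; the baseline value is an illustrative assumption, not the figure bert_score actually uses for `roberta-large`:

```python
def rescale_with_baseline(raw: float, baseline: float) -> float:
    # Linear rescaling: a perfect score stays at 1, the baseline maps to 0,
    # and anything below the baseline becomes negative.
    return (raw - baseline) / (1.0 - baseline)

BASELINE = 0.83  # illustrative assumption only

for raw in (1.0, 0.9, 0.83, 0.5, 0.0):
    reward = 10.0 * rescale_with_baseline(raw, BASELINE)
    print(f"raw F1 = {raw:.2f} -> reward = {reward:+.1f}")
```

With a high baseline, a raw score near 0 turns into a reward of several tens of points below zero, which is large enough to outweigh all the other reward functions combined.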


# Training

In [None]:
max_prompt_length = 500

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 5,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 500,
    save_steps = 10,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",

    loss_type = "dr_grpo",
    epsilon = 0.2,
    epsilon_high = 0.28,
    delta = 1.5,
    mask_truncated_completions = True,

)

Unsloth: The Dr GRPO paper recommends setting `scale_rewards` to False! Will override. Set it to `None` to force False.
Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 4
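
The `scale_rewards` message above refers to how GRPO turns the rewards of the `num_generations` completions for one prompt into advantages. The sketch below is a simplified illustration in plain Python, not TRL's implementation; `group_advantages` and the `eps` value are assumptions made for the example:

```python
from statistics import mean, stdev

def group_advantages(rewards, scale_rewards=False, eps=1e-4):
    """Group-relative advantages for the completions of a single prompt."""
    mu = mean(rewards)
    centered = [r - mu for r in rewards]
    if not scale_rewards:
        # Dr. GRPO style: subtract the group mean only.
        return centered
    # Standard GRPO style: also divide by the group's standard deviation.
    sigma = stdev(rewards)
    return [c / (sigma + eps) for c in centered]

# Hypothetical summed rewards for 4 completions of the same prompt.
rewards = [4.0, 2.0, 0.0, -2.0]
print(group_advantages(rewards))                      # [3.0, 1.0, -1.0, -3.0]
print(group_advantages(rewards, scale_rewards=True))
```

Without scaling, the size of the advantage follows the raw reward spread within the group, so a reward function with a wide range has a proportionally larger influence on the update.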


In [None]:
# Initialize trainer
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = reward_funcs,
    args = training_args,
    train_dataset = my_prepared_dataset,
)

# Start training
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 21,724 | Num Epochs = 1 | Total steps = 500
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 1 x 1) = 4
 "-____-"     Trainable parameters = 161,480,704 of 7,000,000,000 (2.31% trained)



[Step 0] Semantic Similarity Debug
Prompt:
A substance that is freely filtered, completely reabsorbed and not secreted has a renal plasma clearance:/na) Equal to GFR./nb) Equal to effective renal plasma flow./nc) Higher than GFR./nd) Equal to zero.
Model Output:
<reasoning>When a substance is freely filtered, completely reabsorbed, and not secreted by the kidneys, it indicates that the substance does not undergo any net loss once filtered from the glomerular filtrate. The substance exits the nephrons unchanged, suggesting no net alteration in plasma substance concentration before it's filtered compared to reabsorbed. This scenario directly correlates with a renal plasma clearance of zero, as there's no net loss of the substance to gauge the clearance rate.</reasoning>
<answer>nd) Equal to zero.</answer>
<end_of_turn> Always consult a qualified healthcare professional for medical advice.
Extracted Answer:
nd) Equal to zero.
------------------------------
Prompt:
A substance that is fre

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,kl,rewards / answer_semantic_similarity / mean,rewards / answer_semantic_similarity / std,rewards / match_format_exactly / mean,rewards / match_format_exactly / std,rewards / match_format_approximately / mean,rewards / match_format_approximately / std,rewards / on_topic_refusal / mean,rewards / on_topic_refusal / std,rewards / response_pacing_and_length / mean,rewards / response_pacing_and_length / std,rewards / disclaimer_presence / mean,rewards / disclaimer_presence / std
5,-0.8297,-27.136355,29.232609,148.5,84.0,219.8,0.0,148.5,84.0,219.8,0.000619,-30.950413,23.54901,2.4,3.447521,-0.4,3.08956,2.0,0.0,0.814059,0.122864,-1.0,0.0
10,-0.208,-42.028995,12.4705,75.95,29.0,127.2,0.0,75.95,29.0,127.2,0.001458,-41.431081,11.177933,0.0,0.0,-2.3,1.858506,2.0,0.0,0.702088,0.041827,-1.0,0.0
15,-0.5731,-32.782944,27.58119,179.25,125.2,244.4,0.0,179.25,125.2,244.4,0.000962,-34.346278,21.688522,1.6,2.52376,-2.0,3.058762,2.0,0.0,0.763334,0.143301,-0.8,0.4
20,-0.3793,-41.772778,18.809571,95.4,34.0,167.6,0.0,95.4,34.0,167.6,0.001491,-41.378151,15.745195,0.4,0.8,-2.5,2.424621,2.0,0.0,0.705374,0.044788,-1.0,0.0
25,0.3213,-35.128761,19.197273,170.7,108.0,254.4,0.0,170.7,108.0,254.4,0.001181,-36.891934,15.216666,1.2,0.8,-1.6,2.850008,2.0,0.0,0.763173,0.066522,-0.6,0.46188
30,-0.3077,-31.664046,28.692642,96.35,55.8,152.2,0.0,96.35,55.8,152.2,0.001112,-33.374114,23.088895,1.6,2.52376,-1.8,3.148382,2.0,0.0,0.710069,0.045857,-0.8,0.4
35,-0.5595,-26.055703,27.024432,111.55,55.0,185.6,0.0,111.55,55.0,185.6,0.002084,-28.294968,22.521996,1.6,1.6,-1.1,3.190563,2.0,0.0,0.739265,0.07093,-1.0,0.0
40,-0.1257,-28.904487,25.887299,158.0,99.0,251.8,0.0,158.0,99.0,251.8,0.002159,-31.466403,20.787637,1.6,2.52376,-1.0,2.86051,2.0,0.0,0.761916,0.120953,-0.8,0.4
45,-0.2755,-30.653377,20.535104,146.0,84.4,205.8,0.0,146.0,84.4,205.8,0.027291,-34.0776,15.990381,2.0,2.647521,-0.4,2.421503,2.0,0.0,0.824224,0.089053,-1.0,0.0
50,-0.4234,-18.695454,32.491791,142.4,71.4,237.8,0.0,142.4,71.4,237.8,0.00486,-22.86705,26.883811,2.4,3.2,0.0,3.400571,2.0,0.0,0.771597,0.114168,-1.0,0.0


Unsloth: Will smartly offload gradients to save VRAM!


[Step 100] Semantic Similarity Debug
Prompt:
Sites & types of teratomas (mature teratoma & immature teratoma)
Model Output:
<reasoning>
Teratomas, also known as成熟的畸胎瘤, are complex tumors that can contain a mixture of tissue from two or more germ layers that developed during embryonic life. They typically consist of living tissues that can be derived from all three germ layers: ectoderm, mesoderm, and endoderm.

There are two main categories of teratomas:

1. Mature Teratomas:
   Mature teratomas consist mostly of developed tissue that resembles typical adult structures. These tumors usually contain mature cell organizations like skin, hair, teeth, bone, cartilage, and often neural tissue. Due to their development into recognizable adult structures, the prognosis for mature teratomas is typically positive, and they can often be successfully removed surgically.

2. Immature Teratomas:
   Immature teratomas, on the other hand, contain less differentiated tissues and may include embryonic

[Step 200] Semantic Similarity Debug
Prompt:
Examples of: Gene amplification
Model Output:
<reasoning>Gene amplification refers to the process by which a particular gene or a set of genes is duplicated within the genome. This process increases the number of gene copies within a cell. Genetic amplification typically occurs in two main contexts:
1. During DNA replication, errors can lead to the duplication of specific gene segments.
2. In cancer cells, gene amplification can occur as a mechanism to enhance the expression of genes that promote cell growth, survival, or metabolism.

Gene amplification is often studied in the field of oncology, particularly in understanding cancer mechanisms and developing targeted therapies.</reasoning>
<answer>Gene amplification</answer>
</answer>
Extracted Answer:
Gene amplification
------------------------------
Prompt:
Examples of: Gene amplification
Model Output:
<reasoning>
Gene amplification is a type of genetic event where multiple copies of a spe

[Step 300] Semantic Similarity Debug
Prompt:
Centro acinar cells , all true except: a. Beginning of pancreatic duct system. b. Secrete bicarbonate ion ( with intercalated ducts) c. Its secretion is controlled by secretin hormone d. Beginning of biliary duct system
Model Output:
<reasoning>
Centro acinar cells are specialized cells found in the pancreas, specifically in the acinar units. They are involved in the secretion of digestive enzymes. Let's evaluate each statement:

a. Conditional acinar cells indeed constitute a part of the pancreatic secretion system, but the entire duct system is initiated by ductal cells, not centro acinar cells. This statement is incorrect.

b. Centro acinar cells do secrete bicarbonate ions, alongside ductal cells, which forms part of their function to neutralize acidic food in the stomach before it enters the intestine. This statement is correct.

c. The control of secretion by the centro acinar cells is indeed regulated by the hormone secretin, along w



[Step 400] Semantic Similarity Debug
Prompt:
Nerve that runs between trachea & oesophagus in superior mediastinum (A) Rt phrenic (B) Lt vagus (C) Rt recurrent laryngeal (D) Lt recurrent laryngeal
Model Output:
<reasoning>
In the superior mediastinum, between the trachea and oesophagus, the Lt vagus nerve (C) is the correct answer. The Lt vagus nerve enters the neck above the manubrium sterni and passes behind the trachea and oesophagus before dividing into two parts: the external and internal laryngeal nerves. The external laryngeal nerve travels down from the ventral surface of the neck, while the internal laryngeal nerve runs anterolaterally to the trachea and oesophagus.</reasoning>
<answer>
(C) Lt recurrent laryngeal</answer>

Always consult a qualified healthcare professional for medical advice.
Extracted Answer:
(C) Lt recurrent laryngeal
------------------------------
Prompt:
Nerve that runs between trachea & oesophagus in superior mediastinum (A) Rt phrenic (B) Lt vagus (C) Rt


TrainOutput(global_step=500, training_loss=-0.03930371666606516, metrics={'train_runtime': 6505.6981, 'train_samples_per_second': 0.307, 'train_steps_per_second': 0.077, 'total_flos': 0.0, 'train_loss': -0.03930371666606516})
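
As a quick sanity check on the `TrainOutput` above, the throughput figures can be recomputed from the runtime and step count. This is only a sketch: the three input values are copied from the log, while the samples-per-step figure is derived here and is not reported by the trainer.

```python
# Values copied from the TrainOutput log above.
global_step = 500
train_runtime_s = 6505.6981
samples_per_second = 0.307

hours = train_runtime_s / 3600
steps_per_second = global_step / train_runtime_s
# Derived, not logged: approximate number of samples consumed per step.
samples_per_step = samples_per_second / steps_per_second

print(f"runtime: {hours:.2f} h")
print(f"steps/s: {steps_per_second:.3f}")
print(f"samples per step: {samples_per_step:.1f}")
```

The recomputed steps-per-second value agrees with the logged `train_steps_per_second` after rounding.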

<a name="Inference"></a>
# Inference
Now let's try the model we just trained! First, let's try the model without any GRPO training:

In [None]:
prompt = """Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>"""

prompt2 = """
You are MediLearn, a knowledgeable medical tutor AI.

Respond in the following format:
<reasoning>
...your full explanation goes here...
</reasoning>
<answer>
...your final answer goes here...
</answer>

Follow these rules strictly:

1. **Always Include Reasoning:** For every question (whether multiple-choice or not), first think through the problem carefully. Place this full explanation inside the `<reasoning>` tag.
2. **Final Answer in Tag:** After reasoning, give your final answer clearly inside the `<answer>` tag. Be direct and specific. If the question is multiple choice, begin the answer with the correct letter followed by a colon.
For example:
<reasoning> Option A refers to... Option B is correct because... etc. </reasoning> <answer> B: This is the correct answer because it inhibits the correct enzyme involved. </answer>
3. Stay On Topic: If a question is not related to human anatomy or medicine, refuse it politely. You must begin such responses with the exact phrase: "{specific_refusal_phrase}" — and still include the <reasoning> and <answer> tags.
4. Safety First: End every response with this exact sentence: "Always consult a qualified healthcare professional for medical advice."
5. After the safety disclaimer, finish your reply with this special token: "<|end_of_turn|>"
6. Never omit any part of this structure. Even refusals must include all required tags and formatting.
"""

text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : prompt2},
    {"role" : "user", "content" : "Which of the following hormones does NOT increase cardiac output? A. epinephrine B. thyroid hormones C. glucagon D. acetyl choline"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 5000,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output


'<reasoning> Cardiac output is the product of heart rate and stroke volume. Hormones affect cardiac output by increasing heart rate, stroke volume, or both. \n\nA. Epinephrine increases both heart rate and contractility, thus increasing cardiac output.\nB. Thyroid hormones also increase heart rate and contractility, thus increasing cardiac output.\nC. Glucagon primarily affects the liver, stimulating glycogenolysis, but has direct effects on heart muscle cells, thus it increases cardiac output indirectly.\nD. Acetylcholine, on the other hand, acts as a parasympathetic neurotransmitter, which generally decreases heart rate and, when considering its systemic effects, might lead to a slight decrease in cardiac output.\n\nTherefore, acetylcholine does NOT increase cardiac output directly. </reasoning>\n<answer> D: Acetyl choline decreases cardiac output because it acts as a parasympathetic neurotransmitter, leading to decreased heart rate and, consequently, reduced cardiac output. </answer

Our reasoning model is much better. It's not always correct, since we only trained it for 500 steps (about 1.8 hours); it should improve if we extend the sequence length and train for longer!
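
To judge outputs like the one above more systematically, a small checker can verify that a response follows the layout requested in the system prompt. The sketch below is a hypothetical helper, not part of the training code: `check_response` and its returned keys are names introduced here for illustration, and the sample text is a shortened stand-in for a real generation.

```python
import re

# Hypothetical helper: checks the <reasoning>/<answer> layout from the system prompt.
RESPONSE_PATTERN = re.compile(
    r"<reasoning>(?P<reasoning>.*?)</reasoning>\s*<answer>(?P<answer>.*?)</answer>",
    re.DOTALL,
)
SAFETY_SENTENCE = "Always consult a qualified healthcare professional for medical advice."

def check_response(text: str) -> dict:
    match = RESPONSE_PATTERN.search(text)
    answer = match.group("answer").strip() if match else ""
    # A multiple-choice answer is expected to start with a letter and a colon, e.g. "D: ...".
    letter = re.match(r"([A-D]):", answer)
    return {
        "has_tags": match is not None,
        "answer_letter": letter.group(1) if letter else None,
        "has_safety_sentence": SAFETY_SENTENCE in text,
    }

sample = (
    "<reasoning> Acetylcholine slows the heart. </reasoning>\n"
    "<answer> D: Acetylcholine decreases heart rate. </answer>\n"
    "Always consult a qualified healthcare professional for medical advice."
)
print(check_response(sample))
```

A response that is cut off before `</answer>` fails the `has_tags` check, which makes truncated generations easy to spot.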

# Saving

In [None]:
# Merge to 16bit
import os  # the Hugging Face token is read from the HF_TOKEN environment variable, never hard-coded
if True: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if True: model.push_to_hub_merged("noureldinayman/MediLearn-Qwen2-7B-GRPO-500Steps-vLLM-MergedSave", tokenizer, save_method = "merged_16bit", token = os.environ["HF_TOKEN"])

# # Merge to 4bit
# if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
# if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# # Just LoRA adapters
# if False:
#     model.save_pretrained("model")
#     tokenizer.save_pretrained("model")
# if False:
#     model.push_to_hub("hf/model", token = "")
#     tokenizer.push_to_hub("hf/model", token = "")


Found HuggingFace hub cache directory: /root/.cache/huggingface/hub
Checking cache directory for required files...
Cache check failed: model-00001-of-00004.safetensors not found in local cache.
Not all required files found in cache. Will proceed with downloading.
Downloading safetensors index for unsloth/qwen2-7b-instruct...


Unsloth: Merging weights into 16bit: 100%|██████████| 4/4 [01:54<00:00, 28.56s/it]


Unsloth: Merging weights into 16bit: 100%|██████████| 4/4 [04:42<00:00, 70.60s/it]


### GGUF / llama.cpp Conversion
Saving to `GGUF` / `llama.cpp` is now supported natively! We clone `llama.cpp` and save to `q8_0` by default. Other methods such as `q4_k_m` are also allowed. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.

Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):
* `q8_0` - Fast conversion. High resource use, but generally acceptable.
* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.
* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.

[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
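
To get a feel for how the quant method affects file size, a rough estimate is parameters times bits per weight divided by 8. The sketch below rests on assumptions: a parameter count of about 7.6e9 for this model, and nominal bits-per-weight figures, where the `q4_k_m` and `q5_k_m` values are approximate averages rather than exact numbers. The function name `estimate_size_gb` is made up for this example.

```python
# Rough GGUF size estimate: parameters x bits-per-weight / 8.
PARAMS = 7.6e9  # assumed parameter count for the 7B model
BITS_PER_WEIGHT = {
    "f16": 16.0,
    "q8_0": 8.5,     # 32 int8 weights plus one fp16 scale per block
    "q5_k_m": 5.7,   # approximate average (assumption)
    "q4_k_m": 4.85,  # approximate average (assumption)
}

def estimate_size_gb(params: float, bits_per_weight: float) -> float:
    """Return the estimated file size in GB (1 GB = 1e9 bytes)."""
    return params * bits_per_weight / 8 / 1e9

for name, bpw in BITS_PER_WEIGHT.items():
    print(f"{name:7s} ~{estimate_size_gb(PARAMS, bpw):.1f} GB")
```

Under these assumptions the `f16` estimate is about 15.2 GB; the quantized figures are ballpark values only.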

In [None]:
# Save to 8bit Q8_0
if False: model.save_pretrained_gguf("model", tokenizer,)
# Remember to go to https://huggingface.co/settings/tokens for a token!
# And change hf to your username!
if False: model.push_to_hub_gguf("hf/model", tokenizer, token = "")

# Save to 16bit GGUF
# if True: model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
import os  # the Hugging Face token is read from the HF_TOKEN environment variable, never hard-coded
if True: model.push_to_hub_gguf("noureldinayman/MediLearn-Qwen2-7B-GRPO-500Steps-GGUF", tokenizer, quantization_method = "f16", token = os.environ["HF_TOKEN"])

# Save to q4_k_m GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")

# Save to multiple GGUF options - much faster if you want multiple!
if False:
    model.push_to_hub_gguf(
        "hf/model", # Change hf to your username!
        tokenizer,
        quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
        token = "",
    )

Unsloth: Merging 4bit and LoRA weights to 16bit...
Unsloth: Will use up to 48.66 out of 83.48 RAM for saving.
Unsloth: Saving model... This might take 5 minutes ...


100%|██████████| 28/28 [00:24<00:00,  1.15it/s]


Unsloth: Saving tokenizer... Done.
Done.
==((====))==  Unsloth: Conversion from QLoRA to GGUF information
   \\   /|    [0] Installing llama.cpp might take 3 minutes.
O^O/ \_/ \    [1] Converting HF to GGUF 16bits might take 3 minutes.
\        /    [2] Converting GGUF 16bits to ['f16'] might take 10 minutes each.
 "-____-"     In total, you will have to wait at least 16 minutes.

Unsloth: Installing llama.cpp. This might take 3 minutes...
Unsloth: [1] Converting model at noureldinayman/MediLearn-Qwen2-7B-GRPO-500Steps-GGUF into f16 GGUF format.
The output location will be /content/noureldinayman/MediLearn-Qwen2-7B-GRPO-500Steps-GGUF/unsloth.F16.gguf
This might take 3 minutes...
INFO:hf-to-gguf:Loading model: MediLearn-Qwen2-7B-GRPO-500Steps-GGUF
INFO:hf-to-gguf:Model architecture: Qwen2ForCausalLM
INFO:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
INFO:hf-to-gguf:Exporting model...
INFO:hf-to-gguf:gguf: loading model weight map from 'model.safetensors.index.json'

Saved GGUF to https://huggingface.co/noureldinayman/MediLearn-Qwen2-7B-GRPO-500Steps-GGUF


In [None]:
import gc
import torch

# Delete all model references
del model
del tokenizer
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()


In [None]:
import os
from transformers import AutoModelForCausalLM, AutoTokenizer

# `model` and `tokenizer` were deleted in the previous cell, so reload the merged
# 16-bit checkpoint that save_pretrained_merged wrote to the "model" directory.
model = AutoModelForCausalLM.from_pretrained("model", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("model")

# Step 1: Save locally
model.save_pretrained("MediLearn-Qwen2-7B-transformers")
tokenizer.save_pretrained("MediLearn-Qwen2-7B-transformers")

# Step 2: Push to Hugging Face Hub (token comes from the HF_TOKEN environment variable)
model.push_to_hub("noureldinayman/MediLearn-Qwen2-7B-GRPO-500Steps-TransformersSave", token=os.environ["HF_TOKEN"])
tokenizer.push_to_hub("noureldinayman/MediLearn-Qwen2-7B-GRPO-500Steps-TransformersSave", token=os.environ["HF_TOKEN"])


# Evaluation On 0.1% (270) Validation Set

In [None]:
!pip install evaluate
!pip install bert_score
!pip install rouge_score

Collecting evaluate
  Downloading evaluate-0.4.4-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.4-py3-none-any.whl (84 kB)
Installing collected packages: evaluate
Successfully installed evaluate-0.4.4
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... done
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=decd9a191a5138e5c50ece59f91bdacbcc018a0b5cff7c5a2aa4b54bb7568e2f
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [None]:
!pip install -U bitsandbytes



In [None]:
!nvidia-smi --query-gpu=name --format=csv,noheader

NVIDIA A100-SXM4-40GB


In [None]:
# Never hard-code the token; export HF_TOKEN in the environment first.
!huggingface-cli login --token $HF_TOKEN

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
The token `AnatomyLLM` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
The current active token is: `AnatomyLLM`


In [None]:
import os
import gc
import torch
import transformers
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from huggingface_hub import hf_hub_download
import json
import random
import time
import re
from evaluate import load
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoProcessor,
    BitsAndBytesConfig,
    pipeline,
    AutoModelForImageTextToText, # For MedGemma
)
from tqdm.auto import tqdm
import numpy as np
import warnings

# Suppress warnings
warnings.filterwarnings("ignore")
transformers.utils.logging.set_verbosity_error()

print(f"Transformers version installed: {transformers.__version__}")

# Set seed for reproducibility
random.seed(42)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

DATASET_REPO_ID = "Anatomy-Tutor/Anatomy-and-Medical-Dataset"
DATASET_FILENAME = "processed_medical_and_anatomy.json"

# Load the dataset
print("Loading dataset from Hugging Face Hub...")
try:
    hf_data_path = hf_hub_download(
        repo_id=DATASET_REPO_ID,
        filename=DATASET_FILENAME,
        repo_type="dataset"
    )
    with open(hf_data_path, "r", encoding="utf-8") as f:
        splits = json.load(f)
    ds = DatasetDict({
        "train": Dataset.from_list(splits["train"]),
        "validation": Dataset.from_list(splits["validation"]),
        "test": Dataset.from_list(splits["test"]),
    })
except Exception as e:
    print(f"Failed to load or process dataset. Error: {e}")
    ds = None  # Signals downstream code that the dataset is unavailable

# A specific refusal phrase for the model to use.
specific_refusal_phrase = "I am sorry, but I can only answer questions related to human anatomy and medicine."
safety_warning = "Always consult a qualified healthcare professional for medical advice."
end_of_turn_token = "<|end_of_turn|>" # Define this token explicitly

MEDICAL_CHATBOT_SYSTEM_PROMPT = f"""You are "Medilearn," an expert AI anatomy tutor for a VR application. Your goal is to provide clear, accurate, and educational explanations.
Respond in the following format:
<reasoning>
...your full explanation goes here...
</reasoning>
<answer>
...your final answer goes here...
</answer>

**Rules:**
1. **Always Include Reasoning:** For every question (whether multiple-choice or not), first think through the problem carefully. Place this full explanation inside the `<reasoning>` tag.
2. **Stay On Topic:** Politely refuse any question not related to human anatomy or medicine. When you refuse a question, you MUST begin your response with the exact phrase: "{specific_refusal_phrase}"
3. **Final Answer in Tag:** After reasoning, give your final answer clearly inside the `<answer>` tag. Be direct and specific. If the question is multiple choice, begin the answer with the correct letter followed by a colon.
   For example: <reasoning> Option A refers to... Option B is correct because... etc. </reasoning> <answer> B: This is the correct answer because it inhibits the correct enzyme involved. </answer>
4. **End Your Turn:** After providing the complete answer and the mandatory safety warning, you MUST output the special token `{end_of_turn_token}`.
5. **Safety First:** Your final sentence before the end-of-turn token must be: "{safety_warning}"
"""

# Additional instruction for reasoning models
REASONING_FORMAT_INSTRUCTION = "Please provide your concise reasoning within <think> tags and your final answer within <Answer> tags."

REASONING_MODELS = [
    "Intelligent-Internet/II-Medical-8B",
    "ContactDoctor/Bio-Medical-Llama-3-8B-CoT-012025"
]

# Non-medical "bait" questions to test refusal
BAIT_QUESTIONS = [
    "What is the capital of France?",
    "Can you tell me a joke??",
    "Who is the current president of the United States?",
    "What's 5 multiplied by 7?",
    "Name a popular song from 2023.",
    "How do you make scrambled eggs?",
    "What is the tallest mountain in the world?",
    "Give me a fun fact about space.",
    "What's the meaning of life?",
    "Can you recommend a good book to read?",
]

def _parse_and_structure_item(prompt_text, reference_answer):
    """
    Helper function to auto-detect and structure MCQs from raw text.
    This version includes the more robust logic from the analyzer script.
    """
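    # Option markers such as "A.", "b)", "3-" or "C:". The hyphen sits last in the
    # character class so it is a literal and not a range (")-:" would also match digits).
    option_marker_pattern = re.compile(r'\b[A-Da-d1-4][.):-]')
    splitter_pattern = re.compile(r'\s+(?=[A-Da-d1-4][.):-]|\([A-Da-d1-4]\))')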
    parts = splitter_pattern.split(prompt_text)

    if len(parts) >= 3:
        base_prompt = parts[0]
        option_parts = parts[1:]
        options_dict = {}
        for part in option_parts:
            match = re.match(r'\(?([A-Da-d1-4])\)?[.\s:-]\s*(.*)', part.strip())
            if match:
                key, text = match.groups()
                options_dict[key.upper()] = text.strip()

        if len(options_dict) >= 2:
            structured_item = {
                "is_mcq": True,
                "prompt": base_prompt.strip(),
                "options": options_dict
            }
            key_pattern = re.compile(
                r"^\s*\(?([A-D1-4])\)?[.\s:-]|(?:the correct answer is|the answer is|completion:)\s*\(?([A-D1-4])\)?",
                re.IGNORECASE
            )
            match = key_pattern.search(reference_answer)
            if match:
                found_key = (match.group(1) or match.group(2))
                if found_key:
                    structured_item["correct_answer_key"] = found_key.upper()
                    return structured_item

    # No reliable MCQ structure (options plus an answer key) was found,
    # so treat the item as an open-ended question.
    return {"is_mcq": False, "prompt": prompt_text}

def prepare_evaluation_set(full_dataset, max_samples: int):
    """
    Prepares the evaluation set from the full dataset.
    """
    print("Preparing evaluation set from full dataset...")
    dev_set = []
    if full_dataset and 'validation' in full_dataset:
        validation_split = full_dataset["validation"]
        num_medical_samples = max_samples - len(BAIT_QUESTIONS)
        if num_medical_samples < 0: num_medical_samples = 0

        num_medical_samples = min(num_medical_samples, len(validation_split))

        indices = list(range(len(validation_split)))
        random.shuffle(indices)
        sampled_indices = indices[:num_medical_samples]
        print(f"Sampling {len(sampled_indices)} medical questions for processing...")

        for i in sampled_indices:
            item = validation_split[i]
            user_prompt, reference_answer = None, None

            messages = item.get('messages', item.get('conversations', []))

            for message in messages:
                if message.get('role') == 'user': user_prompt = message.get('content')
                elif message.get('role') == 'assistant': reference_answer = message.get('content')

            if not user_prompt: user_prompt = item.get('prompt')
            if not reference_answer: reference_answer = item.get('completion')

            if user_prompt and reference_answer:
                structured_info = _parse_and_structure_item(user_prompt, reference_answer)
                eval_item = {
                    "id": f"Med-{i}", "prompt": structured_info["prompt"],
                    "reference_answer": reference_answer, "is_bait": False,
                    "is_mcq": structured_info["is_mcq"], "expected_tags": item.get("expected_tags", []),
                }
                if structured_info["is_mcq"]:
                    eval_item["options"] = structured_info.get("options", {})
                    eval_item["correct_answer_key"] = structured_info.get("correct_answer_key", "")
                dev_set.append(eval_item)
    else:
        print("Medical dataset not available or invalid. Proceeding with bait questions only.")

    for i, question in enumerate(BAIT_QUESTIONS):
        dev_set.append({"id": f"Bait-{i}", "prompt": question, "reference_answer": "", "is_bait": True, "is_mcq": False, "expected_tags": []})

    random.shuffle(dev_set)
    print(f"Prepared a final set of {len(dev_set)} mixed samples for evaluation.")
    return dev_set

def extract_answer_from_model_output(generated_text: str, is_mcq: bool) -> str:
    """
    Extracts the answer part from the model's generated text based on the defined format.
    Prioritizes text within <answer> tags.
    """
    # Define common ending patterns for generated text
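    ending_patterns = [
        # Literal strings are escaped so that "|" and "." are not read as regex syntax.
        re.escape(end_of_turn_token),
        re.escape(safety_warning),
        re.escape(specific_refusal_phrase),
        r'System:\s*.*', # Catches accidental regeneration of system prompt
        r'User:\s*.*', # Catches accidental regeneration of user prompt
    ]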

    # Remove safety warning and end-of-turn token first
    cleaned_temp = generated_text.replace(safety_warning, "").replace(end_of_turn_token, "").strip()

    # 1. Try to find content within <answer>...</answer> tags
    answer_match = re.search(r'<answer>(.*?)</answer>', cleaned_temp, re.DOTALL | re.IGNORECASE)
    if answer_match:
        extracted_answer = answer_match.group(1).strip()
        return extracted_answer

    # 2. If no full <answer> tag, try to find an opening <answer> tag and take everything after it
    open_answer_tag_match = re.search(r'<answer>(.*)', cleaned_temp, re.DOTALL | re.IGNORECASE)
    if open_answer_tag_match:
        extracted_answer = open_answer_tag_match.group(1).strip()
        # Clean any remaining ending patterns from this fallback
        for pattern in ending_patterns:
            extracted_answer = re.sub(pattern, '', extracted_answer, flags=re.DOTALL | re.IGNORECASE).strip()
        return extracted_answer

    # 3. If no <answer> tag at all, try to find text after </reasoning>
    reasoning_end_match = re.search(r'</reasoning>(.*)', cleaned_temp, re.DOTALL | re.IGNORECASE)
    if reasoning_end_match:
        # Take everything after </reasoning>
        post_reasoning_text = reasoning_end_match.group(1).strip()
        # Clean any remaining ending patterns from this fallback
        for pattern in ending_patterns:
            post_reasoning_text = re.sub(pattern, '', post_reasoning_text, flags=re.DOTALL | re.IGNORECASE).strip()
        return post_reasoning_text

    # 4. As a last resort, if no structure is found, take the whole generated text
    # and try to remove common unwanted phrases/patterns.
    final_cleaned_text = generated_text.strip()
    for pattern in ending_patterns:
        final_cleaned_text = re.sub(pattern, '', final_cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()

    # Also remove system prompt content if it got accidentally regenerated at the start
    final_cleaned_text = re.sub(r'^System:\s*.*', '', final_cleaned_text, flags=re.DOTALL | re.IGNORECASE).strip()

    return final_cleaned_text


MAX_SAMPLES_TO_EVALUATE = 280 # A limit to keep evaluation time reasonable. 270 from the eval and 10 from bait

DEV_SET = [] # Initialize DEV_SET to an empty list
if ds: # Check if dataset was loaded successfully
    DEV_SET = prepare_evaluation_set(full_dataset=ds, max_samples=MAX_SAMPLES_TO_EVALUATE)

    # --- Analyze and print final counts ---
    if DEV_SET:
        mcq_count = sum(1 for item in DEV_SET if item.get('is_mcq') and not item.get('is_bait'))
        open_ended_count = sum(1 for item in DEV_SET if not item.get('is_mcq') and not item.get('is_bait'))

        print("\n--- Dataset Content Analysis ---")
        print(f"Total Multiple-Choice Questions (MCQs) detected: {mcq_count}")
        print(f"Total Open-Ended Questions detected: {open_ended_count}")
        print(f"Total Bait Questions added: {len(BAIT_QUESTIONS)}")
        print("--------------------------------\n")

        # --- Save the generated DEV_SET to a file ---
        output_filename = "evaluation_set_280_samples.json"
        print(f"Saving the prepared evaluation set to '{output_filename}'...")
        try:
            with open(output_filename, "w", encoding="utf-8") as f:
                json.dump(DEV_SET, f, indent=4)
            print(f"File '{output_filename}' saved successfully.")
        except Exception as e:
            print(f"Error saving file: {e}")
else:
    print("Dataset not loaded, so DEV_SET is empty and evaluation will be skipped.")


def evaluate_model(models_to_evaluate, dev_set, log_every_n_samples, batch_size, use_4bit_quantization):
    """
    Main function to load models, run evaluation, and report metrics.
    """
    if not dev_set:
        print("Evaluation cannot proceed without a development set.")
        return
    if not torch.cuda.is_available():
        print("ERROR: This script requires a CUDA-enabled GPU.")
        return

    bertscore = load("bertscore")
    rouge = load("rouge")
    all_results_data = []

    refusal_keywords = [
        "sorry", "cannot", "unable", "not medical", "not anatomy", "only answer",
        "not equipped", "outside my scope", "my purpose is", "my knowledge is limited to",
        specific_refusal_phrase.lower() # Use the exact phrase for checking
    ]

    current_system_prompt = MEDICAL_CHATBOT_SYSTEM_PROMPT

    # Define a custom max_seq_length for Unsloth models
    UNSLOTH_MAX_SEQ_LENGTH = 1500

    for model_info in models_to_evaluate:
        model_name, model_id = model_info["name"], model_info["model_id"]
        print(f"\n{'='*20}\nEvaluating Model: {model_name} ({model_id})\n{'='*20}")

        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        model, processor_or_tokenizer = None, None
        current_model_results = []

        try:
            print("Loading model and tokenizer/processor...")

            # Determine if it's your specific Unsloth model
            is_unsloth_finetuned_qwen = "medilearn-qwen2-7b-grpo-500steps" in model_name.lower()
            is_gemma_model = "gemma" in model_id.lower()
            is_jsl_model = "johnsnowlabs" in model_id.lower()

            # Define common quantization config if 4-bit quantization is used
            q_config = None
            if use_4bit_quantization:
                print("    > NOTE: Loading with 4-bit quantization.")
                q_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=True,
                )
            else:
                print("    > NOTE: Loading in native precision.")

            if is_unsloth_finetuned_qwen:
                print(f"    > Loading Unsloth fine-tuned model: {model_id}")
                # Import FastLanguageModel here to avoid error if not used
                try:
                    from unsloth import FastLanguageModel
                except ImportError:
                    raise ImportError("Please install unsloth to use FastLanguageModel: pip install unsloth")

                model, processor_or_tokenizer = FastLanguageModel.from_pretrained(
                    model_name = model_id, # This is the Hugging Face repo ID
                    max_seq_length = UNSLOTH_MAX_SEQ_LENGTH,
                    load_in_4bit = use_4bit_quantization,
                    fast_inference = True,
                )
                actual_tokenizer = processor_or_tokenizer # For Unsloth, processor_or_tokenizer IS the tokenizer
            elif is_gemma_model:
                processor_or_tokenizer = AutoProcessor.from_pretrained(model_id)
                model = AutoModelForImageTextToText.from_pretrained(
                    model_id,
                    device_map="auto",
                    trust_remote_code=True,
                    quantization_config=q_config,
                    torch_dtype=torch.bfloat16 if not use_4bit_quantization else None
                )
                actual_tokenizer = processor_or_tokenizer.tokenizer if hasattr(processor_or_tokenizer, 'tokenizer') else processor_or_tokenizer
            else: # Generic AutoModel loading for all other models
                processor_or_tokenizer = AutoTokenizer.from_pretrained(model_id)
                model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    device_map="auto",
                    trust_remote_code=True,
                    quantization_config=q_config,
                    torch_dtype=torch.bfloat16 if not use_4bit_quantization else None
                )
                actual_tokenizer = processor_or_tokenizer.tokenizer if hasattr(processor_or_tokenizer, 'tokenizer') else processor_or_tokenizer

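            if actual_tokenizer.pad_token is None:
                actual_tokenizer.pad_token = actual_tokenizer.eos_token

            # Decoder-only models need left padding for batched generation;
            # with right padding, shorter prompts would be followed by pad tokens before the new text.
            actual_tokenizer.padding_side = "left"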

            if not hasattr(processor_or_tokenizer, 'pad_token_id') or processor_or_tokenizer.pad_token_id is None:
                processor_or_tokenizer.pad_token_id = actual_tokenizer.pad_token_id

            print("Model loaded successfully.")
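            # NOTE: sampled right after loading, so this reflects memory allocated up to this point,
            # not the peak reached later during generation.
            peak_vram_gb = torch.cuda.max_memory_allocated() / (1024**3)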

            for i in tqdm(range(0, len(dev_set), batch_size), desc=f"Evaluating {model_name}"):
                batch = dev_set[i:i + batch_size]

                batch_prompts = []
                for item in batch:
                    user_prompt = item["prompt"]
                    if item.get("is_mcq", False):
                        options_str = "\n".join([f"{key}: {value}" for key, value in item["options"].items()])
                        user_prompt = f"{user_prompt}\n\n{options_str}"
                    batch_prompts.append(user_prompt)

                batch_inputs_text = []
                for prompt_text in batch_prompts:
                    if is_jsl_model:
                        formatted_prompt = f"###Question: {prompt_text} ###Answer:"
                        batch_inputs_text.append(formatted_prompt)
                    else:
                        full_system_prompt_content = current_system_prompt
                        if model_name in REASONING_MODELS:
                            full_system_prompt_content += "\n" + REASONING_FORMAT_INSTRUCTION

                        if is_gemma_model:
                            messages = [{"role": "user", "content": [{"type": "text", "text": prompt_text}]}]
                        else:
                            messages = [
                                {"role": "system", "content": full_system_prompt_content},
                                {"role": "user", "content": prompt_text},
                            ]
                        formatted_prompt = processor_or_tokenizer.apply_chat_template(
                            messages, tokenize=False, add_generation_prompt=True
                        )
                        batch_inputs_text.append(formatted_prompt)

                inputs = processor_or_tokenizer(
                    batch_inputs_text, return_tensors="pt", padding=True, truncation=True
                ).to(model.device)

                start_time = time.perf_counter()

                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=1000,
                    num_return_sequences=1,
                    do_sample=False,
                    pad_token_id=processor_or_tokenizer.pad_token_id,
                )

                end_time = time.perf_counter()
                total_time = end_time - start_time

                input_lengths = inputs.input_ids.shape[1]
                newly_generated_ids = output_ids[:, input_lengths:]
                raw_generated_texts = processor_or_tokenizer.batch_decode(newly_generated_ids, skip_special_tokens=True)

                for j, item in enumerate(batch):
                    raw_generated_text = raw_generated_texts[j].strip()

                    # === EXTRACT ANSWER USING THE NEW FUNCTION ===
                    cleaned_text = extract_answer_from_model_output(raw_generated_text, item.get("is_mcq", False))
                    # ============================================

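                    # NOTE: every row in the batch has the same length, so this count includes
                    # any padding appended after a sequence finished early.
                    num_new_tokens = newly_generated_ids[j].shape[0]
                    # total_time covers the whole batch, so this is a per-sequence rate, not batch throughput.
                    tokens_per_second = num_new_tokens / total_time if total_time > 0 else 0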

                    if i == 0 and j == 0:
                        print(f"\n--- First Output for Model: {model_name} (Prompt ID: {item['id']}) ---")
                        print(f"Question:\n{batch_prompts[j]}\n")
                        print(f"Reference Answer:\n{item['reference_answer']}\n")
                        print("Raw Generated Text (for debugging format issues):")
                        print(raw_generated_text)
                        print("\nCleaned Extracted Answer:")
                        print(cleaned_text)
                        print("-" * 50)

                    metrics = {
                        "Model": model_name,
                        "Prompt ID": item["id"],
                        "Is MCQ": item.get("is_mcq", False),
                        "Is Bait": item["is_bait"],
                        "Tokens/Sec": tokens_per_second,
                        "Peak VRAM (GB)": peak_vram_gb,
                        "BERTScore_F1": 0,
                        "ROUGE_L": 0,
                        "Refusal_Acc": 0,
                        "MCQ_Acc": 0,
                    }

                    # Refusal check - cleaned for safety warning and end token
                    refusal_check_text_raw = raw_generated_text.lower()
                    for pattern in [safety_warning.lower(), end_of_turn_token.lower()]:
                        refusal_check_text_raw = refusal_check_text_raw.replace(pattern, "")

                    # Now check if the specific refusal phrase is present in the raw output
                    if item["is_bait"]:
                        if specific_refusal_phrase.lower() in refusal_check_text_raw:
                            metrics["Refusal_Acc"] = 1
                        else:
                            print(f"\n[Refusal Check Fail] Prompt ID: {item['id']}")
                            print(f"Raw Generated Text: {raw_generated_text[:200]}...") # Print beginning of output
                            print(f"Expected Refusal: '{specific_refusal_phrase}'")
                    elif item.get("is_mcq", False):
                        correct_key = item["correct_answer_key"]

                        # Look for answers in the extracted cleaned_text
                        # This pattern is more flexible for "A: text", "A. text", "(A) text", "A text"
                        found_keys_in_cleaned = re.findall(r"(?:^|\b)(?:option|answer|choice)?\s*([A-D1-4])\s*[:\)\.\-\s]", cleaned_text, re.IGNORECASE)
                        found_keys_in_cleaned = [k.upper() for k in found_keys_in_cleaned]

                        # Also check the raw generated text for robustness if needed, but primarily rely on cleaned
                        found_keys_in_raw = re.findall(r"(?:^|\b)(?:option|answer|choice)?\s*([A-D1-4])\s*[:\)\.\-\s]", raw_generated_text, re.IGNORECASE)
                        found_keys_in_raw = [k.upper() for k in found_keys_in_raw]

                        # Combine and de-duplicate keys from both sources (set() keeps no order or priority)
                        all_found_keys = list(set(found_keys_in_cleaned + found_keys_in_raw))

                        # Refine chosen_answers based on explicit negation (if model explicitly says "X is not correct")
                        negated_keys = re.findall(r"([A-D1-4])\s*(?:is not correct|is incorrect|is wrong)", raw_generated_text, re.IGNORECASE)
                        negated_keys = [k.upper() for k in negated_keys]

                        chosen_answers = [k for k in all_found_keys if k not in negated_keys]

                        # Primary rule: correct when the correct key is the only non-negated key found.
                        if correct_key and correct_key in chosen_answers and len(chosen_answers) == 1:
                            metrics["MCQ_Acc"] = 1
                        # Fallback: the output often names several keys while reasoning about each option,
                        # so also accept the answer when the correct key appears in cleaned_text followed by a delimiter.
                        elif correct_key and re.search(r'(?:^|\b)' + re.escape(correct_key) + r'\s*[:\)\.\-]', cleaned_text, re.IGNORECASE):
                            # This catches "C: SGLT-2 inhibitors" where 'C' is the correct key
                            metrics["MCQ_Acc"] = 1
                        else:
                            metrics["MCQ_Acc"] = 0
                            # Optional: Log specific MCQ failures for manual review
                            # print(f"\n[MCQ Fail] Prompt ID: {item['id']}")
                            # print(f"Reference Answer: {item['reference_answer']}")
                            # print(f"Raw Generated Text: {raw_generated_text}")
                            # print(f"Extracted Cleaned Text: {cleaned_text}")
                            # print(f"Correct Key: {correct_key}, Found Keys: {all_found_keys}, Chosen: {chosen_answers}")

                    else: # Open-ended questions
                        if cleaned_text and item["reference_answer"]: # Ensure both are non-empty for BERTScore/ROUGE
                            try:
                                bert_results = bertscore.compute(predictions=[cleaned_text], references=[item["reference_answer"]], lang="en")
                                rouge_results = rouge.compute(predictions=[cleaned_text], references=[item["reference_answer"]])
                                metrics["BERTScore_F1"] = bert_results['f1'][0]
                                metrics["ROUGE_L"] = rouge_results['rougeL']

                                # Log low BERTScore for open-ended questions for debugging
                                if metrics["BERTScore_F1"] < 0.5: # Arbitrary threshold for logging
                                    print(f"\n[Low BERTScore] Prompt ID: {item['id']}")
                                    print(f"Reference Answer: {item['reference_answer']}")
                                    print(f"Cleaned Extracted Answer: {cleaned_text}")
                                    print(f"BERTScore_F1: {metrics['BERTScore_F1']:.2f}")

                            except Exception as bert_rouge_e:
                                print(f"Error computing BERTScore/ROUGE for Prompt ID {item['id']}: {bert_rouge_e}")
                                print(f"Cleaned text: {cleaned_text}")
                                print(f"Reference answer: {item['reference_answer']}")
                        else:
                            print(f"Skipping BERTScore/ROUGE for Prompt ID {item['id']} due to empty prediction or reference.")


                    current_model_results.append(metrics)
                    all_results_data.append(metrics)

                samples_processed = i + batch_size
                if log_every_n_samples > 0 and samples_processed % log_every_n_samples == 0 and samples_processed < len(dev_set):
                    print(f"\n    [Log at sample {samples_processed}/{len(dev_set)}] Model: {model_name}")

                    # Ensure enough results are available for the log slice
                    num_results_to_log = min(log_every_n_samples, len(current_model_results))
                    log_slice = current_model_results[-num_results_to_log:]

                    if log_slice:
                        df_log = pd.DataFrame(log_slice)

                        slice_mcq_count = df_log['Is MCQ'].sum()
                        slice_bait_count = df_log['Is Bait'].sum()
                        slice_open_ended_count = len(df_log) - slice_mcq_count - slice_bait_count

                        print(f"      > Current Slice Counts: MCQs={slice_mcq_count}, Open-Ended={slice_open_ended_count}, Bait={slice_bait_count}")

                        # Calculate means only if there are relevant samples
                        avg_mcq = df_log[df_log['Is MCQ']]['MCQ_Acc'].mean() if slice_mcq_count > 0 else np.nan
                        avg_bert = df_log[~df_log['Is MCQ'] & ~df_log['Is Bait']]['BERTScore_F1'].mean() if slice_open_ended_count > 0 else np.nan
                        avg_rouge = df_log[~df_log['Is MCQ'] & ~df_log['Is Bait']]['ROUGE_L'].mean() if slice_open_ended_count > 0 else np.nan
                        avg_refusal = df_log[df_log['Is Bait']]['Refusal_Acc'].mean() if slice_bait_count > 0 else np.nan

                        print(f"      > Last {len(log_slice)} samples | Avg MCQ Acc: {avg_mcq:.2f} | Avg BERT_F1: {avg_bert:.2f} | Avg ROUGE_L: {avg_rouge:.2f} | Avg Refusal: {avg_refusal:.2f}")


            print("\n")
        except Exception as e:
            print(f"\nERROR: Failed to evaluate model {model_name}. Error: {e}")
            import traceback
            traceback.print_exc()
        finally:
            print(f"Clearing memory after evaluating {model_name}...")
            # Drop every reference (including actual_tokenizer) so the memory can actually be released
            model, processor_or_tokenizer, actual_tokenizer = None, None, None
            gc.collect()
            torch.cuda.empty_cache()

        if current_model_results:
            df_current = pd.DataFrame(current_model_results)
            # Use 'first' for Peak_VRAM_GB as it's constant for the model run
            summary = df_current.groupby("Model").agg(
                Avg_Tokens_Sec=("Tokens/Sec", "mean"),
                Peak_VRAM_GB=("Peak VRAM (GB)", "first"),
                Avg_MCQ_Acc=("MCQ_Acc", lambda x: x[df_current.loc[x.index, 'Is MCQ']].mean() if df_current.loc[x.index, 'Is MCQ'].any() else np.nan),
                Avg_OpenEnded_BERT_F1=("BERTScore_F1", lambda x: x[~df_current.loc[x.index, 'Is Bait'] & ~df_current.loc[x.index, 'Is MCQ']].mean() if (~df_current.loc[x.index, 'Is Bait'] & ~df_current.loc[x.index, 'Is MCQ']).any() else np.nan),
                Avg_OpenEnded_ROUGE_L=("ROUGE_L", lambda x: x[~df_current.loc[x.index, 'Is Bait'] & ~df_current.loc[x.index, 'Is MCQ']].mean() if (~df_current.loc[x.index, 'Is Bait'] & ~df_current.loc[x.index, 'Is MCQ']).any() else np.nan),
                Avg_Refusal_Acc=("Refusal_Acc", lambda x: x[df_current.loc[x.index, 'Is Bait']].mean() if df_current.loc[x.index, 'Is Bait'].any() else np.nan)
            ).reset_index().fillna(0) # Fill NaN from empty groups with 0 for display
            print(f"\n--- METRIC SUMMARY for {model_name} ---")
            print(summary.round(3).to_string(index=False))
            print("-" * 50)

    if not all_results_data:
        print("\nNo overall results to display.")
        return

    pd.set_option('display.max_colwidth', 80)
    pd.set_option('display.width', 120)
    df_detailed = pd.DataFrame(all_results_data)

    df_summary_overall = df_detailed.groupby("Model").agg(
        Avg_Tokens_Sec=("Tokens/Sec", "mean"),
        Peak_VRAM_GB=("Peak VRAM (GB)", "first"),
        Avg_MCQ_Acc=("MCQ_Acc", lambda x: x[df_detailed.loc[x.index, 'Is MCQ']].mean() if df_detailed.loc[x.index, 'Is MCQ'].any() else np.nan),
        Avg_OpenEnded_BERT_F1=("BERTScore_F1", lambda x: x[~df_detailed.loc[x.index, 'Is Bait'] & ~df_detailed.loc[x.index, 'Is MCQ']].mean() if (~df_detailed.loc[x.index, 'Is Bait'] & ~df_detailed.loc[x.index, 'Is MCQ']).any() else np.nan),
        Avg_OpenEnded_ROUGE_L=("ROUGE_L", lambda x: x[~df_detailed.loc[x.index, 'Is Bait'] & ~df_detailed.loc[x.index, 'Is MCQ']].mean() if (~df_detailed.loc[x.index, 'Is Bait'] & ~df_detailed.loc[x.index, 'Is MCQ']).any() else np.nan),
        Avg_Refusal_Acc=("Refusal_Acc", lambda x: x[df_detailed.loc[x.index, 'Is Bait']].mean() if df_detailed.loc[x.index, 'Is Bait'].any() else np.nan)
    ).reset_index().fillna(0)

    print("\n\n--- FINAL METRIC SUMMARY (All Models Combined) ---")
    print(df_summary_overall.round(3).to_string(index=False))

    df_detailed.to_csv("detailed_results.csv", index=False)
    df_summary_overall.to_csv("summary_results.csv", index=False)
    print("\nResults saved to detailed_results.csv and summary_results.csv")
    print("\nEvaluation complete.")

MODELS_TO_EVALUATE = [
    {"name": "MediLearn-Qwen2-7B-GRPO-500Steps", "model_id": "noureldinayman/MediLearn-Qwen2-7B-GRPO-500Steps-vLLM-MergedSave"},
    # You can add more models here as needed for comparison
]

BATCH_SIZE = 8
USE_4BIT_QUANTIZATION = True
LOG_EVERY_N_SAMPLES = 8 # Changed to 8 to log every batch

evaluate_model(
    models_to_evaluate=MODELS_TO_EVALUATE,
    dev_set=DEV_SET,
    log_every_n_samples=LOG_EVERY_N_SAMPLES,
    batch_size=BATCH_SIZE,
    use_4bit_quantization=USE_4BIT_QUANTIZATION
)

Transformers version installed: 4.53.0
Loading dataset from Hugging Face Hub...
Preparing evaluation set from full dataset...
Sampling 270 medical questions for processing...
Prepared a final set of 280 mixed samples for evaluation.

--- Dataset Content Analysis ---
Total Multiple-Choice Questions (MCQs) detected: 160
Total Open-Ended Questions detected: 110
Total Bait Questions added: 10
--------------------------------

Saving the prepared evaluation set to 'evaluation_set_280_samples.json'...
File 'evaluation_set_280_samples.json' saved successfully.

Evaluating Model: MediLearn-Qwen2-7B-GRPO-500Steps (noureldinayman/MediLearn-Qwen2-7B-GRPO-500Steps-vLLM-MergedSave)
Loading model and tokenizer/processor...
    > NOTE: Loading with 4-bit quantization.
    > Loading Unsloth fine-tuned model: noureldinayman/MediLearn-Qwen2-7B-GRPO-500Steps-vLLM-MergedSave
==((====))==  Unsloth 2025.6.12: Fast Qwen2 patching. Transformers: 4.53.0. vLLM: 0.8.5.post1.
   \\   /|    NVIDIA A100-SXM4-40GB. 

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]


INFO 07-05 09:23:00 [punica_selector.py:18] Using PunicaWrapperGPU.
INFO 07-05 09:23:00 [gpu_model_runner.py:1347] Model loading took 5.5142 GiB and 5.195378 seconds
INFO 07-05 09:23:17 [backends.py:420] Using cache directory: /root/.cache/vllm/torch_compile_cache/70f353c6e1/rank_0_0 for vLLM's torch.compile
INFO 07-05 09:23:17 [backends.py:430] Dynamo bytecode transform time: 16.41 s
INFO 07-05 09:23:27 [backends.py:118] Directly load the compiled graph(s) for shape None from the cache, took 7.372 s
INFO 07-05 09:23:31 [monitor.py:33] torch.compile takes 16.41 s in total
INFO 07-05 09:23:32 [kv_cache_utils.py:634] GPU KV cache size: 232,432 tokens
INFO 07-05 09:23:32 [kv_cache_utils.py:637] Maximum concurrency for 1,500 tokens per request: 154.95x
INFO 07-05 09:24:43 [gpu_model_runner.py:1686] Graph capturing finished in 72 secs, took 1.33 GiB
INFO 07-05 09:24:43 [core.py:159] init engine (profile, create kv cache, warmup model) took 103.11 seconds
Unsloth: Just some info: will skip p

Evaluating MediLearn-Qwen2-7B-GRPO-500Steps:   0%|          | 0/35 [00:00<?, ?it/s]


--- First Output for Model: MediLearn-Qwen2-7B-GRPO-500Steps (Prompt ID: Med-2692) ---
Question:
For patients with PAD and type 2 diabetes, the

A: Sulfonylureas
B: DPP-4 inhibitors
C: SGLT-2 inhibitors
D: Meglitinides

Reference Answer:
c- SGLT-2 inhibitors

Raw Generated Text (for debugging format issues):
<reasoning> Option A (Sulfonylureas) is not typically recommended for patients with PAD (Peripheral Artery Disease) and type 2 diabetes because they can cause hypoglycemia, which can worsen PAD symptoms. Option B (DPP-4 inhibitors) are not specifically designed to address PAD but can help manage type 2 diabetes. Option C (SGLT-2 inhibitors) are effective in managing type 2 diabetes and have shown benefits in reducing cardiovascular risk, which can be beneficial for patients with PAD. Option D (Meglitinides) are insulin secretagogues that can cause hypoglycemia, which is not ideal for patients with PAD. </reasoning> <answer> C: SGLT-2 inhibitors: These are effective in managing typ
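
The first logged output above puts the final choice inside an `<answer>` tag after the reasoning. The key regex in the evaluation loop runs case-insensitively and accepts whitespace as a delimiter, so it can also pick up ordinary words such as the article "a", as well as every "Option X" mentioned while reasoning. The following is a minimal sketch of a stricter alternative, not the notebook's own `extract_answer_from_model_output`. The names `extract_mcq_key`, `ANSWER_TAG` and `KEY_AT_START` are hypothetical, and the sample string is abridged from the output above.

```python
import re

# Pattern copied from the evaluation loop, kept here only for comparison.
LOOSE_PATTERN = re.compile(
    r"(?:^|\b)(?:option|answer|choice)?\s*([A-D1-4])\s*[:\)\.\-\s]", re.IGNORECASE
)

# Hypothetical stricter patterns (assumptions, not part of the notebook).
ANSWER_TAG = re.compile(r"<answer>\s*(.*?)(?:</answer>|$)", re.DOTALL | re.IGNORECASE)
KEY_AT_START = re.compile(r"^\(?([A-Da-d])\)?\s*[:.)\-]")


def extract_mcq_key(raw_text):
    """Return the chosen option letter in uppercase, or None if none is found."""
    tag_match = ANSWER_TAG.search(raw_text)
    # Fall back to the whole text when the model did not emit an <answer> tag.
    answer_body = tag_match.group(1) if tag_match else raw_text
    key_match = KEY_AT_START.match(answer_body.strip())
    return key_match.group(1).upper() if key_match else None


sample = (
    "<reasoning> Option A (Sulfonylureas) is not typically recommended. "
    "Option C (SGLT-2 inhibitors) have shown benefits. </reasoning> "
    "<answer> C: SGLT-2 inhibitors"
)

print("Loose pattern on plain prose:", LOOSE_PATTERN.findall("Metformin is a drug"))
print("Strict key from sample:", extract_mcq_key(sample))
```

Anchoring the match to the start of the `<answer>` body trades recall for precision: outputs that never state a key at the start of the answer return `None` and would need manual review or a separate fallback.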

# Results
- Avg_Tokens_Sec: 19.048
- Peak_VRAM_GB: 19.056
- Avg_MCQ_Acc: 0.5
- Avg_OpenEnded_BERT_F1: 0.735
- Avg_OpenEnded_ROUGE_L: 0.094
- Avg_Refusal_Acc: 0.0
- Avg_Refusal_Acc: 0.0
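
Each accuracy figure above is averaged only over the rows of its own question type, which is what the `groupby(...).agg(...)` lambdas in the evaluation cell do with boolean masks. The sketch below reproduces that idea in plain Python on made-up rows (the values are illustrative, not results from this run). `masked_mean` is a hypothetical helper name.

```python
from statistics import mean

# Toy per-sample rows with made-up values; field names mirror the metrics dict.
rows = [
    {"Is MCQ": True,  "Is Bait": False, "MCQ_Acc": 1, "BERTScore_F1": 0.0, "Refusal_Acc": 0},
    {"Is MCQ": True,  "Is Bait": False, "MCQ_Acc": 0, "BERTScore_F1": 0.0, "Refusal_Acc": 0},
    {"Is MCQ": False, "Is Bait": False, "MCQ_Acc": 0, "BERTScore_F1": 0.8, "Refusal_Acc": 0},
    {"Is MCQ": False, "Is Bait": True,  "MCQ_Acc": 0, "BERTScore_F1": 0.0, "Refusal_Acc": 1},
]


def masked_mean(samples, keep, column):
    """Mean of `column` over the samples for which keep(sample) is True, else None."""
    values = [sample[column] for sample in samples if keep(sample)]
    return mean(values) if values else None


def is_open_ended(sample):
    # Open-ended means neither multiple-choice nor a bait question.
    return not sample["Is MCQ"] and not sample["Is Bait"]


summary = {
    "Avg_MCQ_Acc": masked_mean(rows, lambda s: s["Is MCQ"], "MCQ_Acc"),
    "Avg_OpenEnded_BERT_F1": masked_mean(rows, is_open_ended, "BERTScore_F1"),
    "Avg_Refusal_Acc": masked_mean(rows, lambda s: s["Is Bait"], "Refusal_Acc"),
    # Averaging over all rows would dilute the score with the default zeros.
    "Naive_BERT_F1": mean(row["BERTScore_F1"] for row in rows),
}
print(summary)
```

Note that the evaluation cell calls `.fillna(0)` on its summary, so a category with no samples is displayed as 0 rather than as missing; the sketch returns `None` in that case to keep the two situations apart.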