In [None]:
!pip install datasets

In [None]:
!pip install unsloth vllm
!pip install --upgrade pillow

In [None]:
import pandas as pd
import numpy as np
import os

from datasets import load_dataset

In [None]:
import wandb

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mashaduzzaman2505[0m ([33mashaduzzaman_sarker[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
from huggingface_hub import notebook_login

notebook_login()

## Load Datasets (all splits)

In [None]:
import re
from datasets import load_dataset
from datasets import concatenate_datasets

In [None]:
# Load MedQA-USMLE from the Hugging Face Hub (all splits the repository provides).
# load_dataset raises if the repository cannot be resolved, so reaching the print
# below means the load succeeded.
medqa_usmle = load_dataset("Neelectric/MedQA-USMLE")
print("MedQA-USMLE loaded successfully.")
medqa_usmle  # bare expression: shows the DatasetDict summary as the cell output

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


MedQA-USMLE loaded successfully.


DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'options', 'meta_info', 'answer_idx'],
        num_rows: 10178
    })
    validation: Dataset({
        features: ['question', 'answer', 'options', 'meta_info', 'answer_idx'],
        num_rows: 1272
    })
    test: Dataset({
        features: ['question', 'answer', 'options', 'meta_info', 'answer_idx'],
        num_rows: 1273
    })
})

In [None]:
# Load MedMCQA from the Hugging Face Hub (all splits the repository provides).
medmcqa = load_dataset("openlifescienceai/medmcqa")
print("MedMCQA loaded successfully.")
medmcqa  # bare expression: shows the DatasetDict summary as the cell output

MedMCQA loaded successfully.


DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'opa', 'opb', 'opc', 'opd', 'cop', 'choice_type', 'exp', 'subject_name', 'topic_name'],
        num_rows: 182822
    })
    test: Dataset({
        features: ['id', 'question', 'opa', 'opb', 'opc', 'opd', 'cop', 'choice_type', 'exp', 'subject_name', 'topic_name'],
        num_rows: 6150
    })
    validation: Dataset({
        features: ['id', 'question', 'opa', 'opb', 'opc', 'opd', 'cop', 'choice_type', 'exp', 'subject_name', 'topic_name'],
        num_rows: 4183
    })
})

In [None]:
# Load PubMedQA for additional medical domain data.
# No configuration name is passed, so whichever configuration the repository
# exposes by default is what gets loaded.
pubmed_qa = load_dataset("bigbio/pubmed_qa")
print("PubMedQA loaded successfully.")
pubmed_qa  # bare expression: shows the DatasetDict summary as the cell output

PubMedQA loaded successfully.


DatasetDict({
    train: Dataset({
        features: ['QUESTION', 'CONTEXTS', 'LABELS', 'MESHES', 'YEAR', 'reasoning_required_pred', 'reasoning_free_pred', 'final_decision', 'LONG_ANSWER'],
        num_rows: 200000
    })
    validation: Dataset({
        features: ['QUESTION', 'CONTEXTS', 'LABELS', 'MESHES', 'YEAR', 'reasoning_required_pred', 'reasoning_free_pred', 'final_decision', 'LONG_ANSWER'],
        num_rows: 11269
    })
})

In [None]:
# Load MMLU-Pro (the TIGER-Lab repository) for general-domain verification questions.
# The variable name and the status message below say "MMLU", but the dataset
# actually loaded is MMLU-Pro.
mmlu = load_dataset("TIGER-Lab/MMLU-Pro")
print("MMLU loaded successfully.")
mmlu  # bare expression: shows the DatasetDict summary as the cell output

MMLU loaded successfully.


DatasetDict({
    test: Dataset({
        features: ['question_id', 'question', 'options', 'answer', 'answer_index', 'cot_content', 'category', 'src'],
        num_rows: 12032
    })
    validation: Dataset({
        features: ['question_id', 'question', 'options', 'answer', 'answer_index', 'cot_content', 'category', 'src'],
        num_rows: 70
    })
})

## Create Unified Verifiable Dataset: Merge Splits

In [None]:
# Merge all splits of a DatasetDict

def merge_all_splits(dataset_dict):
    """
    Collapse every split of a DatasetDict into one Dataset.

    The splits are gathered in the order the DatasetDict yields its keys and
    then handed to concatenate_datasets in a single call.
    """
    splits = []
    for split_name in dataset_dict:
        splits.append(dataset_dict[split_name])
    return concatenate_datasets(splits)


# Merge splits for each dataset into one flat Dataset per source.
# NOTE(review): this folds each source's validation/test splits in with train, so a
# later benchmark run on those held-out splits would overlap with the training
# pool — confirm this is intended.
merged_medqa = merge_all_splits(medqa_usmle)
merged_medmcqa = merge_all_splits(medmcqa)
merged_pubmed = merge_all_splits(pubmed_qa)
merged_mmlu = merge_all_splits(mmlu)

In [None]:
merged_medqa[13]

{'question': 'A 23-year-old man comes to the physician for evaluation of decreased hearing, dizziness, and ringing in his right ear for the past 6 months. Physical examination shows multiple soft, yellow plaques and papules on his arms, chest, and back. There is sensorineural hearing loss and weakness of facial muscles bilaterally. His gait is unsteady. An MRI of the brain shows a 3-cm mass near the right internal auditory meatus and a 2-cm mass at the left cerebellopontine angle. The abnormal cells in these masses are most likely derived from which of the following embryological structures?',
 'answer': 'Neural crest',
 'options': {'A': 'Neural tube',
  'B': 'Surface ectoderm',
  'C': 'Neural crest',
  'D': 'Notochord',
  'E': 'Mesoderm'},
 'meta_info': 'step1',
 'answer_idx': 'C'}

## Reformatting functions for each dataset

In [None]:
def reformat_medqa(item):
    """
    Convert a MedQA-USMLE record into an ``{"x", "y_star"}`` pair.

    Args:
        item: Mapping with 'question' and, optionally, 'answer', 'options'
            and 'answer_idx'.

    Returns:
        dict: ``x`` is the question text. ``y_star`` is the stripped answer text
        when 'answer' is non-empty; otherwise the option letter derived from
        'answer_idx'; otherwise None.

    The fallback accepts 'answer_idx' either as an option letter (the form this
    dataset stores, e.g. 'C') or as a zero-based integer position. An index that
    does not point at an existing option yields None instead of an exception or
    a bogus character.
    """
    x = item["question"]
    y_star = None
    if "answer" in item and item["answer"]:
        y_star = item["answer"].strip()
    elif "answer_idx" in item and "options" in item:
        idx = item["answer_idx"]
        options = item["options"]
        if isinstance(idx, str):
            # Letter form: valid only if it names an existing option.
            letter = idx.strip().upper()
            if isinstance(options, dict):
                is_valid = letter in options
            else:
                is_valid = len(letter) == 1 and 0 <= ord(letter) - ord('A') < len(options)
            if is_valid:
                y_star = letter
        elif isinstance(idx, int) and not isinstance(idx, bool) and 0 <= idx < len(options):
            # Positional form: the lower bound matters, a negative index would
            # otherwise map to a character before 'A'.
            y_star = chr(ord('A') + idx)
    return {"x": x, "y_star": y_star}

def reformat_medmcqa(item, include_options=False):
    """
    Reformat a MedMCQA record into a verifiable medical problem.

    Args:
        item: Mapping with 'question', the four choices 'opa'..'opd', and
            optionally 'cop', the zero-based integer index of the correct choice.
        include_options: When True, the four choices are appended to the
            question as lettered lines ("A. ...", "B. ...", ...). The default
            (False) keeps ``x`` equal to the bare question text.

    Returns:
        dict: ``x`` (question, optionally followed by the lettered choices),
        ``y_star`` (the letter 'A'..'D' for a valid 'cop', else None) and
        ``options`` (the four choices in order).

    NOTE(review): ``y_star`` is an option letter, so a prompt built from the bare
    question gives the model nothing to map that letter to. Callers that train
    or verify against ``y_star`` probably want ``include_options=True`` — confirm
    against the prompt-building code.
    """
    x = item["question"]
    options = [item["opa"], item["opb"], item["opc"], item["opd"]]

    if include_options:
        lettered = [f"{chr(ord('A') + pos)}. {choice}" for pos, choice in enumerate(options)]
        x = x + "\n" + "\n".join(lettered)

    # Convert integer answer index to a letter (0->'A', 1->'B', ...); anything
    # that is not an in-range integer means no usable label.
    cop = item["cop"] if "cop" in item else None
    if isinstance(cop, int) and 0 <= cop < len(options):
        y_star = chr(ord('A') + cop)
    else:
        y_star = None

    return {"x": x, "y_star": y_star, "options": options}


def reformat_pubmed(item):
    """
    Convert a PubMedQA record into an ``{"x", "y_star"}`` pair.

    ``x`` is the 'QUESTION' text. ``y_star`` is the stripped 'final_decision'
    when that field is present and non-empty, otherwise None.
    """
    question_text = item["QUESTION"]
    decision = item["final_decision"] if "final_decision" in item else None
    if decision:
        label = decision.strip()
    else:
        label = None
    return {"x": question_text, "y_star": label}

def reformat_mmlu(item):
    """
    Convert an MMLU-style record into an ``{"x", "y_star"}`` pair.

    Args:
        item: Mapping with 'question' and, optionally, 'answer', 'options'
            and 'answer_index'.

    Returns:
        dict: ``x`` is the question text. ``y_star`` is the stripped 'answer'
        when it is non-empty; otherwise the letter for 'answer_index' when that
        is an integer pointing at an existing option; otherwise None.

    NOTE(review): ``x`` carries only the question text, not the options —
    confirm downstream prompts do not need them to make ``y_star`` answerable.
    """
    x = item["question"]
    y_star = None
    if "answer" in item and item["answer"]:
        y_star = item["answer"].strip()
    elif "answer_index" in item and "options" in item:
        idx = item["answer_index"]
        # Require a real in-range position: a negative index would map to a
        # character before 'A', and a non-integer cannot be compared to a length.
        if isinstance(idx, int) and not isinstance(idx, bool) and 0 <= idx < len(item["options"]):
            y_star = chr(ord('A') + idx)
    return {"x": x, "y_star": y_star}

## Filtering Function

In [None]:
def filter_challenging(item, min_word_count=30):
    """
    Decide whether an example's question is long enough to keep.

    The question is read from ``item["x"]`` (empty string when the key is
    absent), split on whitespace, and kept when it has at least
    ``min_word_count`` words.
    """
    word_total = len(item.get("x", "").split())
    if word_total < min_word_count:
        return False
    return True

## Create Unified Verifiable Dataset

In [None]:
def create_verifiable_dataset():
    """
    Build the unified list of verifiable examples from the four merged datasets.

    Each merged dataset is paired with its reformatting function. Records are
    processed in the order MedQA-USMLE, MedMCQA, PubMedQA, MMLU; a reformatted
    record is kept only when its ``y_star`` is truthy. The collected examples
    are then passed through filter_challenging to drop very short questions.
    """
    sources = (
        (merged_medqa, reformat_medqa),
        (merged_medmcqa, reformat_medmcqa),
        (merged_pubmed, reformat_pubmed),
        (merged_mmlu, reformat_mmlu),
    )

    collected = []
    for records, reformat in sources:
        for record in records:
            example = reformat(record)
            if example["y_star"]:
                collected.append(example)

    # Drop examples whose question is too short to be considered challenging.
    return [example for example in collected if filter_challenging(example)]

# Build the unified list once (it iterates over every record of all four merged
# datasets) and report how many examples survived the label and length filters.
verifiable_dataset = create_verifiable_dataset()
print("Unified verifiable dataset size:", len(verifiable_dataset))

Unified verifiable dataset size: 33929


In [None]:
verifiable_dataset[12]

{'x': 'A 9-month-old female is brought to the emergency department after experiencing a seizure. She was born at home and was normal at birth according to her parents. Since then, they have noticed that she does not appear to be achieving developmental milestones as quickly as her siblings, and often appears lethargic. Physical exam reveals microcephaly, very light pigmentation (as compared to her family), and a "musty" body odor. The varied manifestations of this disease can most likely be attributed to which of the following genetic principles?',
 'y_star': 'Pleiotropy'}

## Develop a Medical Verifier

In [None]:
import re

class MedicalVerifier:
    """
    Rule-based checker for model completions.

    A completion passes when it is well-formed (exactly one reasoning block,
    exactly one <answer>...</answer> block, nothing but whitespace elsewhere)
    and the text inside <answer> equals the ground truth.

    The reasoning block may be written as <think>...</think> or as
    <reasoning>...</reasoning>; the opening and closing tag must use the same name.
    """

    def __init__(self):
        # Regex patterns to extract the reasoning and answer blocks.
        # Group 1 of think_pattern is the tag name, group 2 the block content;
        # the backreference forces the closing tag to match the opening one.
        self.think_pattern = r"<(think|reasoning)>(.*?)</\1>"
        self.answer_pattern = r"<answer>(.*?)</answer>"

    def verify_format(self, model_output):
        """
        Verify that the model output follows the required format:
         - Exactly one reasoning block (<think>...</think> or <reasoning>...</reasoning>).
         - Exactly one <answer>...</answer> block.
         - No extra non-whitespace text exists outside these blocks.

        Returns:
            (bool, str): A tuple containing whether the format is correct and an error message (if any).
        """
        # Find all occurrences of the required blocks.
        think_matches = re.findall(self.think_pattern, model_output, re.DOTALL)
        answer_matches = re.findall(self.answer_pattern, model_output, re.DOTALL)

        if len(think_matches) != 1 or len(answer_matches) != 1:
            return False, "Incorrect number of <think> or <answer> tags."

        # Remove both blocks, then trim surrounding whitespace.
        cleaned_output = re.sub(self.think_pattern, "", model_output, flags=re.DOTALL)
        cleaned_output = re.sub(self.answer_pattern, "", cleaned_output, flags=re.DOTALL)
        cleaned_output = cleaned_output.strip()

        # Anything left over is text outside the two blocks.
        if cleaned_output:
            return False, "Extra text found outside of the required tags."

        return True, ""

    def extract_answer(self, model_output):
        """
        Extract the final answer from the first <answer>...</answer> block.

        Returns:
            str or None: The extracted answer with whitespace stripped, or None if not found.
        """
        match = re.search(self.answer_pattern, model_output, re.DOTALL)
        if match:
            return match.group(1).strip()
        return None

    def verify(self, model_output, ground_truth):
        """
        Verify the model's output by:
          1. Checking that the format is correct (exactly one reasoning block and one <answer> block with no extra text).
          2. Extracting and comparing the final answer against the ground-truth.

        A diagnostic line is printed for every failure.

        Args:
            model_output (str): The complete output from the model.
            ground_truth (str): The expected correct answer (e.g., "C" or "Cerebral edema").
                The comparison is an exact, case-sensitive string match.

        Returns:
            bool: True if both format and answer match the ground truth, otherwise False.
        """
        # Step 1: Verify format.
        format_ok, error_message = self.verify_format(model_output)
        if not format_ok:
            print("Format error:", error_message)
            return False

        # Step 2: Extract and verify answer.
        extracted_answer = self.extract_answer(model_output)
        if extracted_answer is None:
            print("No answer found in the output.")
            return False

        if extracted_answer == ground_truth:
            return True
        else:
            print(f"Answer mismatch: extracted '{extracted_answer}', expected '{ground_truth}'")
            return False

# Example usage:
# In a notebook cell __name__ is "__main__", so this demo runs whenever the cell is executed.
if __name__ == "__main__":
    verifier = MedicalVerifier()

    # Correctly formatted output with extra whitespace around tags.
    # (The reasoning text here is just the question copied in as filler.)
    model_output = (
        "   <think>A 23-year-old man comes to the physician for evaluation of decreased hearing, dizziness, and ringing in his right ear for the past 6 months. Physical examination shows multiple soft, yellow plaques and papules on his arms, chest, and back. There is sensorineural hearing loss and weakness of facial muscles bilaterally. His gait is unsteady. An MRI of the brain shows a 3-cm mass near the right internal auditory meatus and a 2-cm mass at the left cerebellopontine angle. The abnormal cells in these masses are most likely derived from which of the following embryological structures?'</think>   "
        "   <answer>C</answer>   "
    )
    ground_truth = "C"
    result = verifier.verify(model_output, ground_truth)
    print("Verification result (should be True):", result)

    # Example with extra non-whitespace text outside of the tags.
    # verify() prints a "Format error" line before returning False for this one.
    model_output_extra = (
        "Intro text that should not be here. "
        "<think>Detailed reasoning steps.</think><answer>C</answer>"
    )
    result_extra = verifier.verify(model_output_extra, ground_truth)
    print("Verification result with extra text (should be False):", result_extra)


Verification result (should be True): True
Format error: Extra text found outside of the required tags.
Verification result with extra text (should be False): False


## GRPO with Unsloth

In [None]:
from unsloth import FastLanguageModel
import torch

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [None]:
# Define hyperparameters
# max_seq_length is passed to the model loader below; the GRPO setup cell later
# assigns the same name again with the same value.
max_seq_length = 1024  # Increase if you need longer reasoning traces
lora_rank = 32         # Larger rank may improve performance at the cost of speed

# Step 1: Load the model and tokenizer using Unsloth's FastLanguageModel.
# This loads the model with settings for fast inference (vLLM) and 4-bit quantization.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length=max_seq_length,
    load_in_4bit=True,            # Set to False if using LoRA with 16-bit precision instead
    fast_inference=True,          # Enable vLLM fast inference for speed
    max_lora_rank=lora_rank,      # Kept equal to the rank used in get_peft_model below
    gpu_memory_utilization=0.6,   # Adjust if you're running out of GPU memory
)

# Step 2: Configure the model for fine-tuning using PEFT (LoRA).
# We attach LoRA modules to key projection layers to enable efficient fine-tuning.
# Note that `model` is rebound to the LoRA-wrapped model here.
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,                  # LoRA rank, suggested values: 8, 16, 32, etc.
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],                            # You may remove some modules (like QKVO) if you experience memory issues
    lora_alpha=lora_rank,         # alpha is set equal to the rank
    use_gradient_checkpointing="unsloth",  # Enables gradient checkpointing for long-context fine-tuning
    random_state=3407,                     # For reproducibility
)

print("Model and tokenizer are successfully loaded and configured for GRPO fine-tuning!")


INFO 03-17 15:37:52 __init__.py:207] Automatically detected platform cuda.
==((====))==  Unsloth 2025.3.14: Fast Llama patching. Transformers: 4.48.3. vLLM: 0.7.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/meta-llama-3.1-8b-instruct-unsloth-bnb-4bit with actual GPU utilization = 59.43%
Unsloth: Your GPU has CUDA compute capability 7.5 with VRAM = 14.74 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 1024. Num Sequences = 160.
Unsloth: vLLM's KV Cache can use up to 2.59 GB. Also swap space = 2 GB.
INFO 03-17 15:38:06 config.py:549] This model supports multiple tasks: {'embed', 'reward', 'generate', 'score

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 03-17 15:38:57 model_runner.py:1115] Loading model weights took 5.5976 GB
INFO 03-17 15:38:57 punica_selector.py:18] Using PunicaWrapperGPU.
INFO 03-17 15:39:03 worker.py:267] Memory profiling takes 5.51 seconds
INFO 03-17 15:39:03 worker.py:267] the current vLLM instance can use total_gpu_memory (14.74GiB) x gpu_memory_utilization (0.59) = 8.76GiB
INFO 03-17 15:39:03 worker.py:267] model weights take 5.60GiB; non_torch_memory takes 0.03GiB; PyTorch activation peak memory takes 0.74GiB; the rest of the memory reserved for KV Cache is 2.39GiB.
INFO 03-17 15:39:03 executor_base.py:111] # cuda blocks: 1224, # CPU blocks: 1024
INFO 03-17 15:39:03 executor_base.py:116] Maximum concurrency for 1024 tokens per request: 19.12x
INFO 03-17 15:39:05 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occ

Capturing CUDA graph shapes: 100%|██████████| 23/23 [00:39<00:00,  1.70s/it]

INFO 03-17 15:39:44 model_runner.py:1562] Graph capturing finished in 39 secs, took 0.59 GiB
INFO 03-17 15:39:44 llm_engine.py:436] init engine (profile, create kv cache, warmup model) took 47.64 seconds



Unsloth 2025.3.14 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


Model and tokenizer are successfully loaded and configured for GRPO fine-tuning!


## Reformat the Verifiable Dataset into Prompts and Targets

In [None]:
# Define the system prompt that instructs the model to use a specific format.
# The reasoning block is named <reasoning> here.
# NOTE(review): any checker that scores completions must look for this same tag
# name (not e.g. <think>) or compliant outputs will be scored as malformed — verify.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Define the XML Chain-of-Thought (CoT) format template.
# Filled via str.format with `reasoning` and `answer`; the leading backslash
# suppresses the newline right after the opening quotes.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""


In [None]:
def format_verifiable_item(item):
    """
    Turn one verifiable item into a ``{"prompt", "target"}`` training example.

    Expected keys on ``item``:
      - 'x': the problem statement.
      - 'y_star': the ground-truth answer.

    The prompt is the stripped SYSTEM_PROMPT, a blank line, then the problem
    text, all as one plain string. The target is XML_COT_FORMAT filled with a
    fixed reasoning placeholder and the ground-truth answer, so it is the whole
    XML wrapper rather than the bare answer.
    """
    # The dataset has no reasoning annotations, so a fixed placeholder stands in.
    reasoning_placeholder = "(explain your reasoning here)"

    instruction = SYSTEM_PROMPT.strip()
    prompt_text = instruction + "\n\n" + item['x']
    target_text = XML_COT_FORMAT.format(
        reasoning=reasoning_placeholder,
        answer=item['y_star'],
    )

    return {"prompt": prompt_text, "target": target_text}

# Reformat the whole verifiable dataset (list of dicts with 'x' and 'y_star' keys)
# into a plain Python list of {"prompt", "target"} dicts.
# Requires verifiable_dataset to have been built by an earlier cell.
formatted_dataset = [format_verifiable_item(item) for item in verifiable_dataset]

# Print one example to inspect the formatting.
example = formatted_dataset[0]
print("Prompt:\n", example["prompt"])
print("Target Output:\n", example["target"])

Prompt:
 Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>

A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?
Target Output:
 <reasoning>
(explain your reasoning here)
</reasoning>
<answer>
Nitrofurantoin
</answer>



## Reward Functions

In [None]:
import re
from typing import List, Union

# Helper: Extract answer from the <answer> tag.
def extract_answer_from_output(model_output: str) -> Union[str, None]:
    """Return the stripped text of the first <answer>...</answer> block, or None if absent."""
    match = re.search(r"<answer>(.*?)</answer>", model_output, re.DOTALL)
    return match.group(1).strip() if match else None

# 1. Reward function that checks if the answer is correct.
def reward_correct_answer(prompts: List[str], completions: List[str], ground_truths: List[str], **kwargs) -> List[float]:
    """
    Score each completion 1.0 when its extracted answer equals the ground truth, else 0.0.

    Args:
        prompts: Prompts for the batch (unused; kept for the reward-function signature).
        completions: Model outputs, one per example.
        ground_truths: Expected answers, one per example. Each may be the bare
            answer ("C") or a full target string that wraps the answer in
            <answer>...</answer>; in the latter case the wrapped answer is used.
        **kwargs: Ignored.

    Returns:
        One float per (completion, ground truth) pair. A completion without an
        <answer> block always scores 0.0.
    """
    rewards = []
    for comp, gt in zip(completions, ground_truths):
        answer = extract_answer_from_output(comp)

        # The training targets in this notebook are whole XML strings, so unwrap
        # the expected answer before comparing; bare answers are only stripped.
        expected = gt
        if isinstance(gt, str):
            wrapped = extract_answer_from_output(gt)
            expected = wrapped if wrapped is not None else gt.strip()

        # Guard on `answer is not None` so a missing answer never matches a
        # missing ground truth.
        rewards.append(1.0 if answer is not None and answer == expected else 0.0)
    return rewards

# 2. Reward function that checks if the answer is an integer.
def reward_integer_answer(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """
    Score each completion 1.0 when the text inside its <answer> block parses
    with ``int()``, and 0.0 when it does not or when no answer block is found.
    """
    def _score(completion: str) -> float:
        candidate = extract_answer_from_output(completion)
        if candidate is None:
            return 0.0
        try:
            int(candidate)
        except ValueError:
            return 0.0
        return 1.0

    return [_score(completion) for completion in completions]

# 3. Reward function that checks if the completion follows the strict format.
def reward_strict_format(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """
    Score each completion 1.0 when it follows the strict format, else 0.0.

    Strict format means: exactly one reasoning block, exactly one
    <answer>...</answer> block, and nothing but whitespace outside them.
    The reasoning block may be <think>...</think> or <reasoning>...</reasoning>
    (the tag the notebook's system prompt asks for); opening and closing tag
    must use the same name.

    Args:
        prompts: Prompts for the batch (unused).
        completions: Model outputs, one per example.
        **kwargs: Ignored.
    """
    # The backreference ties the closing tag to whichever name opened the block.
    reasoning_re = re.compile(r"<(think|reasoning)>.*?</\1>", re.DOTALL)
    answer_re = re.compile(r"<answer>.*?</answer>", re.DOTALL)

    rewards = []
    for comp in completions:
        reasoning_count = len(reasoning_re.findall(comp))
        answer_count = len(answer_re.findall(comp))
        # Whatever survives after deleting both blocks is stray text.
        leftover = answer_re.sub("", reasoning_re.sub("", comp)).strip()
        well_formed = reasoning_count == 1 and answer_count == 1 and leftover == ""
        rewards.append(1.0 if well_formed else 0.0)
    return rewards

# 4. Reward function that checks if the completion follows a more relaxed format.
def reward_relaxed_format(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """
    Score each completion 1.0 when it contains exactly one reasoning block and
    exactly one <answer>...</answer> block, regardless of any text around them;
    0.0 otherwise.

    The reasoning block may be <think>...</think> or <reasoning>...</reasoning>
    (the tag the notebook's system prompt asks for); opening and closing tag
    must use the same name.

    Args:
        prompts: Prompts for the batch (unused).
        completions: Model outputs, one per example.
        **kwargs: Ignored.
    """
    # The backreference ties the closing tag to whichever name opened the block.
    reasoning_re = re.compile(r"<(think|reasoning)>.*?</\1>", re.DOTALL)
    answer_re = re.compile(r"<answer>.*?</answer>", re.DOTALL)

    rewards = []
    for comp in completions:
        has_one_reasoning = len(reasoning_re.findall(comp)) == 1
        has_one_answer = len(answer_re.findall(comp)) == 1
        rewards.append(1.0 if (has_one_reasoning and has_one_answer) else 0.0)
    return rewards

# 5. Reward function that counts XML tags and penalizes extra content.
def reward_xml_tag_penalty(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """
    Graded format reward: 1.0 for a clean completion, reduced for clutter.

    A completion earns the base reward of 1.0 only when <answer> and </answer>
    each occur exactly once and a reasoning tag pair occurs exactly once. The
    reasoning pair may be <think>/</think> or <reasoning>/</reasoning> (the tag
    the notebook's system prompt asks for); <think> is preferred when both
    qualify. Otherwise the reward is 0.0.

    From the base reward, 0.1 is subtracted for every angle-bracket tag beyond
    the four required ones, and a further 0.1 if any non-whitespace text remains
    outside the reasoning and answer blocks. The result is floored at 0.0.

    Args:
        prompts: Prompts for the batch (unused).
        completions: Model outputs, one per example.
        **kwargs: Ignored.
    """
    rewards = []
    for comp in completions:
        answer_ok = comp.count("<answer>") == 1 and comp.count("</answer>") == 1
        # Pick the reasoning tag name whose open and close tags each appear once.
        reasoning_tag = next(
            (
                name
                for name in ("think", "reasoning")
                if comp.count(f"<{name}>") == 1 and comp.count(f"</{name}>") == 1
            ),
            None,
        )
        if not answer_ok or reasoning_tag is None:
            rewards.append(0.0)
            continue
        base_reward = 1.0

        all_tags = re.findall(r"<.*?>", comp)
        extra_tags = len(all_tags) - 4  # Expect exactly 4 required tags.
        penalty_tags = 0.1 * max(extra_tags, 0)

        # Strip the chosen reasoning block and the answer block; anything left is stray text.
        cleaned_output = re.sub(rf"<{reasoning_tag}>.*?</{reasoning_tag}>", "", comp, flags=re.DOTALL)
        cleaned_output = re.sub(r"<answer>.*?</answer>", "", cleaned_output, flags=re.DOTALL).strip()
        penalty_text = 0.1 if cleaned_output != "" else 0.0

        total_penalty = penalty_tags + penalty_text
        final_reward = max(base_reward - total_penalty, 0.0)
        rewards.append(final_reward)
    return rewards

## Training with GRPO

In [None]:
# GRPO Trainer Setup
from trl import GRPOConfig, GRPOTrainer

# Define maximum prompt length and derive completion length.
# NOTE(review): every prompt is the system instruction plus a question of at least
# 30 words; confirm 256 tokens is enough, and check how over-long prompts are handled.
max_prompt_length = 256
max_seq_length = 1024  # same value as in the model setup cell; reassigned here
max_completion_length = max_seq_length - max_prompt_length  # 1024 - 256 = 768 tokens

# Set up the GRPO configuration.
training_args = GRPOConfig(
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",  # Use 8-bit optimizer to reduce memory usage.
    logging_steps=1,
    per_device_train_batch_size=6,  # Must be a multiple of num_generations.
    gradient_accumulation_steps=1,
    num_generations=6,             # Number of candidates generated per prompt.
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    max_steps=10,   # For testing purposes; adjust as needed.
    save_steps=10,  # For testing purposes; adjust as needed.
    max_grad_norm=0.1,
    report_to="wandb",             # Change to "none" if not using Weights & Biases.
    output_dir="outputs",
)

# Define the list of reward functions.
# NOTE(review): reward_integer_answer rewards answers that parse as integers, but the
# ground truths built in this notebook are option letters, answer text, or PubMedQA
# decisions — confirm this reward belongs in the mix.
reward_funcs = [
    reward_xml_tag_penalty,
    reward_relaxed_format,
    reward_strict_format,
    reward_integer_answer,
    # Wrap the correctness reward to pass ground truths: the wrapper takes a `target`
    # keyword and forwards it as `ground_truths`. Presumably the trainer supplies the
    # dataset's 'target' column under that name — confirm against the TRL docs.
    # `target` holds the whole XML-formatted string, not the bare answer.
    lambda prompts, completions, target, **kwargs: reward_correct_answer(prompts, completions, ground_truths=target, **kwargs),
]

In [None]:
# GRPO Trainer Initialization
trainer = GRPOTrainer(
    model=model,                    # The LoRA-wrapped model returned by get_peft_model.
    processing_class=tokenizer,     # The associated tokenizer.
    reward_funcs=reward_funcs,
    args=training_args,
    train_dataset=formatted_dataset,  # Plain list of {"prompt", "target"} dicts; prompts are single strings, not role-based messages.
)

# Start Training (runs for training_args.max_steps optimizer steps).
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 33,929 | Num Epochs = 1 | Total steps = 10
O^O/ \_/ \    Batch size per device = 6 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (6 x 1 x 1) = 6
 "-____-"     Trainable parameters = 83,886,080/8,000,000,000 (1.05% trained)


Step,Training Loss,reward,reward_std,completion_length,kl,rewards / reward_xml_tag_penalty,rewards / reward_relaxed_format,rewards / reward_strict_format,rewards / reward_integer_answer,rewards /
1,0.0,0.0,0.0,729.5,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,765.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.166667,0.408248,678.833374,0.0,0.0,0.0,0.0,0.166667,0.0
4,0.0,0.0,0.0,686.666687,0.000481,0.0,0.0,0.0,0.0,0.0
5,0.0,0.0,0.0,709.833374,7e-06,0.0,0.0,0.0,0.0,0.0
6,0.0,0.0,0.0,768.0,1e-05,0.0,0.0,0.0,0.0,0.0
7,0.0,0.0,0.0,523.833374,1.1e-05,0.0,0.0,0.0,0.0,0.0
8,0.0,0.0,0.0,768.0,6e-06,0.0,0.0,0.0,0.0,0.0
9,0.0,0.0,0.0,516.833374,1.3e-05,0.0,0.0,0.0,0.0,0.0
10,0.0,0.0,0.0,625.833374,2e-05,0.0,0.0,0.0,0.0,0.0


Unsloth: Will smartly offload gradients to save VRAM!


TrainOutput(global_step=10, training_loss=2.19791091851107e-06, metrics={'train_runtime': 1844.5457, 'train_samples_per_second': 0.033, 'train_steps_per_second': 0.005, 'total_flos': 0.0, 'train_loss': 2.19791091851107e-06})

In [None]:
model.save_lora("grpo_saved_lora")

## Testing the Model

In [None]:
formatted_dataset[:2]

[{'prompt': 'Respond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n\nA 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?',
  'target': '<reasoning>\n(explain your reasoning here)\n</reasoning>\n<answer>\nNitrofurantoin\n</answer>\n'},
 {'prompt': 'Respond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n\nA 3-month-old baby died suddenly at night while asleep. His mother noticed that he had d

In [None]:
import random
from vllm import SamplingParams

# Select a random example from the formatted dataset.
# NOTE(review): no random seed is set, so a different example is drawn on every run.
sample = random.choice(formatted_dataset)

# Prepare the test input.
# The chat template is applied only when the prompt is a list of messages. The prompts
# built by format_verifiable_item are plain strings, so the else branch is the one
# taken here and the raw prompt text is sent to the model as-is.
if isinstance(sample["prompt"], list):
    test_text = tokenizer.apply_chat_template(
        sample["prompt"],
        tokenize=False,
        add_generation_prompt=True,
    )
else:
    test_text = sample["prompt"]

# Set up sampling parameters for generation.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=1024,
)

# Generate the model output using the saved LoRA weights.
# fast_generate returns one result per prompt; take the first result's first
# candidate and keep only its text.
output = (
    model.fast_generate(
        test_text,
        sampling_params=sampling_params,
        lora_request=model.load_lora("grpo_saved_lora"),
    )[0]
    .outputs[0]
    .text
)

# Print the test prompt, the ground-truth target, and the model output.
# The "target" printed here is the whole XML template with the placeholder reasoning.
print("Test Prompt:")
print(test_text)
print("\nGround Truth Target:")
print(sample["target"])
print("\nModel Output:")
print(output)


Processed prompts: 100%|██████████| 1/1 [00:52<00:00, 52.28s/it, est. speed input: 1.51 toks/s, output: 18.09 toks/s]

Test Prompt:
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>

A person 'X' hits another person 'Y' with a wooden stick on provocation. This leads to the formation of a bruise 3 cm × 3 cm on the forearm. No other injuries are noted. Which of the following is true, regarding his punishment -

Ground Truth Target:
<reasoning>
(explain your reasoning here)
</reasoning>
<answer>
A
</answer>


Model Output:
 

A) He should be given bail.
B) Cognizable and non-bailable offence
C) Cognizable and bailable offence
D) The offence should be tried as Magistrate's Court.

<reasoning>
The criminal law relating to the punishment of the person 'X' depends on whether the offence falls under the categories of 'Cognizable or non-cognizable', 'Cognizable and bailable or non-bailable' and the relevant court to try the case. The given situation is hitting another person with a wooden stick on provocation. In criminal law, such a situation can be classified under Section 3




## Save the model

In [None]:
# # Save to 16-bit precision
# model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")

## Benchmark Evaluation