In [5]:
# manual_grpo_pubmedqa.py

# ==============================================================================
# SECTION 1: IMPORTS AND SETUP
# ==============================================================================

# Basic Python libraries
import random
import copy
import re
import os
import numpy as np
# import wandb # Optional, for logging
import json
import csv
# PyTorch and related libraries
import torch
import torch.nn as nn
torch.cuda.empty_cache()
# Hugging Face libraries
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

def set_random_seed(seed: int = 42):
    """Sets the random seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set the seed for consistent results
set_random_seed(42)

# Set environment variables for Weights & Biases (wandb) logging
# Replace with your own key or comment out if not using wandb
# os.environ["WANDB_API_KEY"] = "YOUR_WANDB_API_KEY"
# os.environ["WANDB_PROJECT"] = "GRPO-Qwen-PubMedQA-Manual"

# ==============================================================================
# SECTION 2: PROMPT AND DATA PREPARATION (ADAPTED FOR PUBMEDQA)
# ==============================================================================
# ---- SUPERVISOR SIDE ----
SUPERVISOR_SYSTEM_PROMPT = """You are a supervisor routing requests to specialized agents.

Reply STRICTLY in the following format:
Agent: <name>
<one short sentence explaining why>

The agent name must be exactly one of:
question_understanding, context_analysis, reasoning, answering

IMPORTANT: Keep your response SHORT. Only output the agent name and one brief explanation sentence. Do not add any additional reasoning, answers, or explanations beyond this format."""

def build_supervisor_prompt(example: dict) -> str:
    """Build the supervisor routing prompt from a dataset row."""
    return (
        SUPERVISOR_SYSTEM_PROMPT
        + "\n\n"
        + f"Given the problem:\n{example['problem']}\n\n"
        + f"Context:\n{example['context']}\n\n"
        + "Please choose ONE next agent to call "
          "from: question_understanding, context_analysis, reasoning, answering.\n\n"
        + "Reply STRICTLY in the form:\nAgent: <name>\nThen explain why.\n\n"
        + "REMEMBER: Only output the agent name and one brief explanation sentence. Do not add reasoning sections, answers, or any other content."
    )


VALID_AGENTS = {
    "question_understanding",
    "context_analysis",
    "reasoning",
    "answering",
}

AGENT_NAME_PATTERN = re.compile(
    r"^\s*Agent:\s*(question_understanding|context_analysis|reasoning|answering)\b",
    flags=re.IGNORECASE | re.MULTILINE,
)

def parse_supervisor_choice(text: str):
    """Return normalized agent name (or None) and the (start,end) char span of the 'Agent: <name>' line."""
    m = AGENT_NAME_PATTERN.search(text)
    if not m:
        return "Answering", None
    raw = m.group(1).lower().strip()
    if raw not in VALID_AGENTS:
        return "Answering", None
    return raw  # (name, (start,end) of the Agent: line)


VALID_AGENTS = {
    "question_understanding",
    "context_analysis",
    "reasoning",
    "answering",
}

# Strict: must appear as a line that begins with "Agent:" then a valid name
_AGENT_LINE_RE = re.compile(
    r"(?im)^\s*Agent:\s*(question_understanding|context_analysis|reasoning|answering)\b"
)

def parse_supervisor_choice(supervisor_msg: str, fallback_names=None):
    """
    Returns chosen agent name (normalized lowercase) or None.
    1) Strictly parse the 'Agent: <name>' line.
    2) Fallback: keyword presence from allowed list in message body.
    """
    m = _AGENT_LINE_RE.search(supervisor_msg)
    if m:
        return m.group(1).lower()

    # Fallback: scan allowed names if the strict line is missing
    names = fallback_names or VALID_AGENTS
    msg_low = supervisor_msg.lower()
    for name in names:
        # match whole token (avoid partial overlaps)
        if re.search(rf"\b{name}\b", msg_low):
            return name
    return None

# Define the structured prompt format for our task
SYSTEM_PROMPT = """You are an expert biomedical researcher. Your task is to answer a question based on a provided context.
First, write out a step-by-step reasoning process within <reasoning> tags.
Then, provide the final answer (either "yes" or "no") within <answer> tags.
"""

def build_prompt(messages):
    """Builds a single prompt string from a list of messages."""
    return "\n".join([msg["content"].strip() for msg in messages])

def prepare_pubmedqa_dataset(csv_file_path="pubmedqa.csv"):
    """Loads and prepares the PubMedQA dataset from a CSV file."""
    formatted_data = []
    with open(csv_file_path, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for row in reader:
            problem = row["question"].strip()
            context_info = row["context"].strip()
            answer = row["final_decision"].strip().lower()

            user_content = f"Context:\n{context_info}\n\nQuestion:\n{problem}"
            prompt_str = build_prompt([
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content}
            ])

            formatted_data.append({
                "problem": problem,
                "context": context_info,
                "prompt": prompt_str,
                "answer": answer
            })
    return formatted_data

# ==============================================================================
# SECTION 2.5: Ollama calling the subagent specified by the supervisor
# ==============================================================================
import requests
import json as _json

OLLAMA_MODEL = "qwen2.5:0.5b-instruct"   # or your tag e.g. "qwen2.5-0.5b-instruct"

AGENT_PROMPTS = {
    "question_understanding": """You are the Question-Understanding agent.
Given the following problem and context, clarify the exact question to be answered and restate it precisely.
Based on your analysis, provide your best answer to the question.

Format your response as:
Answer: yes/no/maybe

IMPORTANT: Always end with "Answer: yes/no/maybe" based on your understanding of the question.""",

    "context_analysis": """You are the Context-Analysis agent.
Analyze the provided context: extract key facts, contradictions, and relevance to the question.
Based on your analysis of the context, provide your best answer to the question.

Format your response as:
<analysis>Analysis of the context and key facts</analysis>
Answer: yes/no/maybe

IMPORTANT: Always end with "Answer: yes/no/maybe" based on your analysis of the context.""",

    "reasoning": """You are the Reasoning agent.
Do step-by-step reasoning combining the question and the context to reach a conclusion.
Based on your reasoning, provide your best answer to the question.

Format your response as:
<reasoning>Step-by-step reasoning combining question and context</reasoning>
Answer: yes/no/maybe

IMPORTANT: Always end with "Answer: yes/no/maybe" based on your reasoning.""",

    "answering": """You are the Answering agent.
Given the problem and the context, provide your final answer.
This is the final decision-making agent.

Format your response as:
Answer: yes/no/maybe

IMPORTANT: Always end with "Answer: yes/no/maybe" as your final answer to the question."""
}

def build_subagent_prompt(agent_name: str, example: dict) -> str:
    """Build the prompt for a specific subagent from a dataset row."""
    sys = AGENT_PROMPTS[agent_name]
    return (
        f"{sys}\n\n"
        f"Problem:\n{example['problem']}\n\n"
        f"Context:\n{example['context']}\n\n"
        "Follow the specified XML output strictly."
    )



# def call_ollama_chat(model: str, messages: list, temperature: float = 1.0,
#                      top_p: float = 0.95, max_tokens: int = 256,
#                      url: str = "http://localhost:11434/api/chat") -> str:
#     payload = {
#         "model": model,
#         "messages": messages,
#         "stream": False,
#         "options": {
#             "temperature": temperature,
#             "top_p": top_p,
#             "num_predict": max_tokens,
#         }
#     }
#     r = requests.post(url, json=payload, timeout=120)
#     # If you're on an older Ollama that lacks /api/chat, this may 404:
#     if r.status_code == 404:
#         raise RuntimeError("This Ollama version does not support /api/chat. Use /api/generate or upgrade Ollama.")
#     r.raise_for_status()
#     data = r.json()
#     return data.get("message", {}).get("content", "")  # chat returns 'message'
def call_ollama(model, messages, url="http://localhost:11434/api/chat",
                     temperature=0.8, top_p=0.95, max_tokens=256):
    # Validate messages
    assert isinstance(messages, list) and messages, "messages must be a non-empty list"
    for m in messages:
        assert isinstance(m, dict), "each message must be a dict"
        assert m.get("role") in {"system", "user", "assistant"}, f"bad role: {m}"
        assert isinstance(m.get("content"), str) and m["content"], f"empty content: {m}"

    # Minimal payload first (no options). If this works, add options next.
    payload = {
        "model": model,
        "messages": messages,
        "stream": False
    }

    r = requests.post(url, json=payload, timeout=120)
    if r.status_code >= 400:
        raise RuntimeError(f"Ollama chat error {r.status_code}: {r.text}\nPayload: {payload}")
    data = r.json()
    return data.get("message", {}).get("content", "")
# def call_ollama(model: str, prompt: str, temperature: float = 0.8, top_p: float = 0.95, max_tokens: int = 256):
#     """Calls Ollama local server. Assumes `ollama serve` is running."""
#     url = "http://localhost:11434/api/chat"
#     payload = {
#         "model": model,
#         "prompt": prompt,
#         "stream": False,
#         "options": {
#             "temperature": temperature,
#             "top_p": top_p,
#             "num_predict": max_tokens,
#         }
#     }
#     r = requests.post(url, data=_json.dumps(payload), timeout=120)
#     r.raise_for_status()
#     data = r.json()
#     # Ollama returns {"response": "...", ...}
#     return data.get("response", "")

def run_subagent(agent_name: str, example: dict):
    """Execute the chosen agent with Ollama and extract final yes/no if present."""
    print(f"\n[DEBUG] Running sub-agent: {agent_name}")
    print(f"[DEBUG] Example keys: {list(example.keys())}")
    
    prompt = build_subagent_prompt(agent_name, example)
    print(f"[DEBUG] Generated prompt length: {len(prompt)}")
    
    # Convert string prompt to message list format for call_ollama
    messages = [{"role": "user", "content": prompt}]
    print(f"[DEBUG] Calling Ollama with model: {OLLAMA_MODEL}")
    
    try:
        text = call_ollama(OLLAMA_MODEL, messages)
        print(f"[DEBUG] Ollama response length: {len(text)}")
        print(f"[DEBUG] Ollama response: {text[:200]}...")  # First 200 chars
        
        # Reuse your stricter extraction:
        ans = extract_answer_from_model_output(text)  # returns "yes"/"no"/None
        print(f"[DEBUG] Extracted answer: {ans}")
        return text, ans
    except Exception as e:
        print(f"[ERROR] Failed to call Ollama: {e}")
        return "", None



# ==============================================================================
# SECTION 3: REWARD FUNCTIONS (ADAPTED FOR PUBMEDQA)
# ==============================================================================

def extract_answer_from_model_output(text):
    """Extracts the value from the 'Answer: yes/no/maybe' format in the text."""
    import re
    
    # First try to find "Answer: yes/no/maybe" pattern
    answer_pattern = r'Answer:\s*(yes|no|maybe)\b'
    matches = re.findall(answer_pattern, text, re.IGNORECASE)
    if matches:
        return matches[-1].lower()
    
    # Fallback: try to find answer tags (for backward compatibility)
    answer_tag_pattern = r'<answer>\s*(yes|no|maybe)\s*</answer>'
    matches = re.findall(answer_tag_pattern, text, re.IGNORECASE)
    if matches:
        return matches[-1].lower()
    
    # Final fallback: look for yes/no/maybe in the text
    text_lower = text.lower()
    if "maybe" in text_lower:
        return "maybe"
    elif "yes" in text_lower and "no" not in text_lower:
        return "yes"
    elif "no" in text_lower and "yes" not in text_lower:
        return "no"
    elif "yes" in text_lower and "no" in text_lower:
        # If both are present, look for the last occurrence
        yes_pos = text_lower.rfind("yes")
        no_pos = text_lower.rfind("no")
        return "yes" if yes_pos > no_pos else "no"
    
    return None

def pubmedqa_correctness_reward(completions, answer, **kwargs):
    """Assigns a reward based on the correctness of the 'yes'/'no' answer."""
    responses = [comp[0]['content'] for comp in completions]
    extracted_answers = [extract_answer_from_model_output(r) for r in responses]
    rewards = []
    for extracted, expected in zip(extracted_answers, answer):
        if extracted and extracted == expected:
            rewards.append(2.0)  # High reward for an exact match
        else:
            rewards.append(0.0)  # No reward for wrong or missing answer
    return rewards

def format_reward(completions, **kwargs):
    """Assigns a reward for adhering to the desired XML format."""
    responses = [comp[0]['content'] for comp in completions]
    rewards = []
    for response in responses:
        score = 0.0
        if "<reasoning>" in response: score += 0.2
        if "</reasoning>" in response: score += 0.2
        if "<answer>" in response: score += 0.2
        if "</answer>" in response: score += 0.2
        rewards.append(score) # Max format score = 0.8
    return rewards

def combined_reward(prompts, completions, answer, **kwargs):
    """Combines correctness and format rewards."""
    correctness_scores = pubmedqa_correctness_reward(completions=completions, answer=answer)
    format_scores = format_reward(completions=completions)

    combined_rewards = [c_score + f_score for c_score, f_score in zip(correctness_scores, format_scores)]
    return combined_rewards
# ==============================================================================
# SECTION 3.5: CORE GRPO/PPO LOGIC Supervisor
# ==============================================================================

# ==============================================================================
# SECTION 4: CORE GRPO/PPO LOGIC (IMITATED FROM EXAMPLE)
# ==============================================================================

def selective_log_softmax(logits, input_ids):
    """Computes log probabilities for specific tokens."""
    log_probs = nn.functional.log_softmax(logits, dim=-1)
    return log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

def compute_log_probs(model, input_ids, attention_mask, logits_to_keep):
    """Computes the log probabilities for a batch of tokens."""
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits[:, :-1, :]
    input_ids = input_ids[:, -logits_to_keep:]
    logits = logits[:, -logits_to_keep:, :]
    return selective_log_softmax(logits, input_ids)

def create_completion_mask(completion_ids, eos_token_id):
    """Creates a mask for completion tokens, stopping after the first EOS token."""
    is_eos = completion_ids == eos_token_id
    # Find the index of the first EOS token for each sequence
    eos_indices = torch.argmax(is_eos.int(), dim=1)
    # If no EOS is found, argmax returns 0. We need to handle this.
    # We set the index to max_length if no EOS is found.
    eos_indices[~is_eos.any(dim=1)] = completion_ids.size(1)

    # Create a range tensor to compare with indices
    seq_indices = torch.arange(completion_ids.size(1), device=completion_ids.device).expand_as(completion_ids)
    
    # The mask is 1 for all tokens up to and including the first EOS
    mask = (seq_indices <= eos_indices.unsqueeze(1)).int()
    return mask

def generate_completions(model, tokenizer, prompts, num_generations=4, max_completion_length=128):
    """Generates multiple completions for each prompt."""
    device = model.device if hasattr(model, 'device') else next(model.parameters()).device
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
    # set once after loading the tokenizer
    tokenizer.padding_side = "left"
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    prompt_ids = inputs["input_ids"].to(device)
    prompt_mask = inputs["attention_mask"].to(device)
    
    prompt_length = prompt_ids.size(1)
    
    # Repeat prompts to generate multiple completions in one batch
    repeated_prompt_ids = prompt_ids.repeat_interleave(num_generations, dim=0)
    repeated_prompt_mask = prompt_mask.repeat_interleave(num_generations, dim=0)
    
    outputs = model.generate(
        repeated_prompt_ids,
        attention_mask=repeated_prompt_mask,
        max_new_tokens=max_completion_length,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    
    completion_ids = outputs[:, prompt_length:]
    completion_mask = create_completion_mask(completion_ids, tokenizer.eos_token_id)
    
    return prompt_ids, prompt_mask, completion_ids, completion_mask

def generate_rollout_data(model, ref_model, tokenizer, batch_samples, num_generations, max_completion_length):
    """Generates data for GRPO rollouts including completions and log probabilities."""
    prompts = [sample["prompt"] for sample in batch_samples]
    answers = [sample["answer"] for sample in batch_samples]
    
    with torch.no_grad():
        prompt_ids, prompt_mask, completion_ids, completion_mask = generate_completions(
            model, tokenizer, prompts, num_generations, max_completion_length
        )
        
        # We need the original prompts repeated for log prob calculation
        repeated_prompt_ids = prompt_ids.repeat_interleave(num_generations, dim=0)
        repeated_prompt_mask = prompt_mask.repeat_interleave(num_generations, dim=0)
        
        completion_attn = (completion_ids != tokenizer.pad_token_id).long()
        input_ids = torch.cat([repeated_prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([repeated_prompt_mask, completion_attn], dim=1)
        logits_to_keep = completion_ids.size(1)
        # compute_log_probs needs a model on a single device, so we use .module
        # if it is wrapped in DataParallel
        policy_model = model.module if isinstance(model, nn.DataParallel) else model
        reference_model = ref_model.module if isinstance(ref_model, nn.DataParallel) else ref_model

        old_log_probs = compute_log_probs(policy_model, input_ids, attention_mask, logits_to_keep)
        ref_log_probs = compute_log_probs(reference_model, input_ids, attention_mask, logits_to_keep)
    texts = tokenizer.batch_decode(completion_ids.detach().cpu(), skip_special_tokens=True)

    # chosen agents + subagent outputs (to feed your reward fn)
    chosen_agents = []
    sub_texts = []

    # simple action mask: first few non-pad tokens (focus gradients on routing decision)
    action_masks = []
    for i, text in enumerate(texts):
        chosen = parse_supervisor_choice(text)
        chosen_agents.append(chosen if chosen in VALID_AGENTS else None)

        # Map to base sample (which problem/context)
        base_idx = i // num_generations
        sample_i = batch_samples[base_idx]   # get the full sample dict

        if chosen in VALID_AGENTS:
            sub_out_text, _ = run_subagent(chosen, sample_i)  # pass the full sample dict
            sub_texts.append(sub_out_text)
        else:
            sub_texts.append("")

        # Build a small front-span mask over completion tokens
        comp_ids_row = completion_ids[i]
        valid_len = int((comp_ids_row != tokenizer.pad_token_id).sum().item())
        L = min(8, valid_len)   # first 8 tokens
        m = torch.zeros_like(comp_ids_row, dtype=torch.long)
        if L > 0:
            m[:L] = 1
        action_masks.append(m)

    action_mask = torch.stack(action_masks, dim=0).to(input_ids.device)

    # Your reward function expects this shape:
    formatted_subagent_completions = [[{'content': t}] for t in sub_texts]
    repeated_prompts = [p for p in prompts for _ in range(num_generations)]
    repeated_answers = [a for a in answers for _ in range(num_generations)]

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "completion_mask": completion_mask,             # kept (unused for loss now)
        "action_mask": action_mask,                     # <-- NEW: use this in loss
        "old_log_probs": old_log_probs,
        "ref_log_probs": ref_log_probs,
        "formatted_completions": formatted_subagent_completions,  # <-- feed to your reward
        "repeated_prompts": repeated_prompts,
        "repeated_answers": repeated_answers,
        "logits_to_keep": logits_to_keep,
        "batch_size": len(prompts),
        "num_generations": num_generations
    }

def grpo_loss(model, ref_model, rollout_data, reward_function, beta=0.01, epsilon=0.2):
    """Computes the GRPO loss for updating the policy model."""
    device = next(model.parameters()).device
    
    # Unpack rollout data
    input_ids = rollout_data["input_ids"]
    attention_mask = rollout_data["attention_mask"]
    action_mask = rollout_data["action_mask"]
    completion_mask = rollout_data["completion_mask"]
    logits_to_keep = rollout_data["logits_to_keep"]
    old_log_probs = rollout_data["old_log_probs"]
    ref_log_probs = rollout_data["ref_log_probs"]
    
    # Compute current log probs
    policy_model = model.module if isinstance(model, nn.DataParallel) else model
    token_log_probs = compute_log_probs(policy_model, input_ids, attention_mask, logits_to_keep)
    
    # Calculate ratio and rewards
    ratio = torch.exp(token_log_probs - old_log_probs)
    rewards = torch.tensor(
        reward_function(
            prompts=rollout_data["repeated_prompts"], 
            completions=rollout_data["formatted_completions"], # subagent outputs
            answer=rollout_data["repeated_answers"]
        ),
        dtype=torch.float32,
        device=device
    )
    
    # Standardize rewards at the group level (GRPO)
    batch_size = rollout_data["batch_size"]
    num_generations = rollout_data["num_generations"]
    rewards_grouped = rewards.view(batch_size, num_generations)
    
    mean_rewards = rewards_grouped.mean(dim=1, keepdim=True)
    std_rewards = rewards_grouped.std(dim=1, keepdim=True)
    advantages = (rewards_grouped - mean_rewards) / (std_rewards + 1e-8)
    advantages = advantages.view(-1).unsqueeze(1) # Flatten back for token-wise multiplication
    
    # PPO Clipped Surrogate Objective
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    surrogate_loss = torch.min(surr1, surr2)
    
    # KL Penalty
    kl_div = torch.exp(ref_log_probs - token_log_probs) - (ref_log_probs - token_log_probs) - 1
    
    # Combine and mask the loss
    per_token_loss = surrogate_loss - beta * kl_div
    # We only care about the loss for the completion tokens
    masked_loss = per_token_loss * action_mask.to(per_token_loss.dtype)
    loss = -(masked_loss.sum() / action_mask.sum().clamp_min(1))
    
    avg_reward = rewards.mean().item()
    return loss, avg_reward

# ==============================================================================
# SECTION 5: TRAINING LOOP (IMITATED AND ADAPTED FOR SINGLE/MULTI GPU)
# ==============================================================================
import torch
from torch.cuda.amp import autocast, GradScaler
import copy
import random

def train_with_grpo(
    model,
    tokenizer,
    train_data,
    num_iterations=1,
    num_steps=100,
    batch_size=2,                 # reduce batch_size for GPU memory
    num_generations=2,            # reduce generations
    max_completion_length=128,    # reduce completion length
    beta=0.1,
    learning_rate=5e-6,
    mu=3,
    epsilon=0.2,
    reward_function=pubmedqa_correctness_reward,
    device=None,
    use_lora=True                # optional flag to enable LoRA
):
    """Memory-safe GRPO training loop with mixed precision and optional LoRA."""

    # 1️⃣ Device setup
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        torch.cuda.empty_cache()
    print(f"Using device: {device}")

    # 2️⃣ Optional LoRA
    if use_lora:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["q_proj","k_proj","v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(model, lora_config)
        print("LoRA applied to q/k/v projections.")

    model.to(device)

    # 3️⃣ Mixed precision scaler
    scaler = GradScaler()

    # Outer loop for updating the reference model
    for iteration in range(num_iterations):
        print(f"\n--- Starting GRPO Iteration {iteration + 1}/{num_iterations} ---")

        # Reference model
        ref_model = copy.deepcopy(model)
        ref_model.eval()
        for param in ref_model.parameters():
            param.requires_grad = False
        ref_model.to(device)
        print("Reference model created.")

        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        model.train()

        # Inner loop for batch updates
        for step in range(num_steps):
            n = min(batch_size, len(train_data))
            batch_samples = random.sample(train_data, n)

            # 1️⃣ Generate rollouts
            rollout_data = generate_rollout_data(
                model,
                ref_model,
                tokenizer,
                batch_samples,
                num_generations,
                max_completion_length
            )

            # 2️⃣ PPO-style updates with mixed precision
            for _ in range(mu):
                optimizer.zero_grad()
                
                # Mixed precision context
                with autocast():  
                    loss, avg_reward = grpo_loss(
                        model,
                        ref_model,
                        rollout_data,
                        reward_function,
                        beta=beta,
                        epsilon=epsilon
                    )

                # Scaled backward pass
                scaler.scale(loss).backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()

                # Clear memory after each inner step
                del loss
                torch.cuda.empty_cache()

            # Clear rollout_data after all inner iterations are done
            del rollout_data
            torch.cuda.empty_cache()

            print(f"Iter {iteration+1}, Step {step+1}/{num_steps}, Avg Reward: {avg_reward:.2f}")

    return model


# ==============================================================================
# SECTION 6: EVALUATION (ADAPTED FOR PUBMEDQA)
# ==============================================================================
def evaluate_supervisor(model, tokenizer, eval_examples, device=None, max_supervisor_new_tokens=64):
    """Evaluates accuracy by: supervisor routes -> sub-agent answers -> compare to gold yes/no."""
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    correct, total = 0, len(eval_examples)
    print("\n" + "="*50)
    print(f"STARTING SUPERVISOR EVALUATION ON {total} EXAMPLES")
    print("="*50)

    for ex in eval_examples:
        # 1) Build supervisor prompt from the original user prompt
        sup_prompt = build_supervisor_prompt(ex)
        expected = ex["answer"]

        # 2) Supervisor generates a routing decision
        inputs = tokenizer.encode(sup_prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=max_supervisor_new_tokens,
                temperature=0.1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        sup_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # 3) Parse agent choice (default to 'answering' if bad format)
        agent = parse_supervisor_choice(sup_response)
                # Extract only the newly generated tokens (not the input prompt)
        input_length = inputs.shape[1]
        new_tokens = outputs[0][input_length:]
        sup_response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        print(f"\n[DEBUG] Parsed agent choice: {agent}")
        print(f"[DEBUG] Valid agents: {VALID_AGENTS}")
        print(f"[DEBUG] Agent in valid agents: {agent in VALID_AGENTS}")

        # 4) Run sub-agent via Ollama on the ORIGINAL user prompt
        print(f"\n[DEBUG] About to call sub-agent: {agent}")
        if agent in VALID_AGENTS:
            print(f"[DEBUG] Agent is valid, calling run_subagent...")
            sub_text, pred = run_subagent(agent, ex)
            print(f"[DEBUG] Sub-agent call completed. Pred: {pred}")
        else:
            print(f"[WARNING] Invalid agent '{agent}', skipping sub-agent call")
            sub_text, pred = "", None

        # 5) Score using your unchanged extractor
        # pred = extract_answer_from_model_output(sub_text)
        is_correct = (pred == expected)
        correct += int(is_correct)

        # Optional logging
        print("\n--- Example ---")
        print(f"Expected: {expected} | Pred: {pred} | Agent: {agent} | Correct: {'✓' if is_correct else '✗'}")
        print(f"[Supervisor Message]\n{sup_response}")
        print(f"[Sub-agent Response]\n{sub_text}")
        print("-"*50)

    acc = 100.0 * correct / max(total, 1)
    print(f"\nEvaluation Complete. Accuracy: {acc:.2f}% ({correct}/{total})")
    print("="*50)
    model.train()
    return acc


# ==============================================================================
# SECTION 7: MAIN EXECUTION BLOCK
# ==============================================================================
from torch.cuda.amp import autocast, GradScaler

def main():
    """Main function to orchestrate the entire process."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using primary device: {device}")

    # --- Model and Tokenizer Loading ---
    model_name = "Qwen/Qwen2.5-0.5B-Instruct"
    print(f"Loading model: {model_name}...")
    # Load in default FP32 precision, as mixed precision will be handled by the training loop if enabled.
    # model = AutoModelForCausalLM.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Model and tokenizer loaded.")
    
    # Move model to device FIRST before evaluation
    model.to(device)
    
    # --- Data Preparation ---
    all_data = prepare_pubmedqa_dataset()
    random.shuffle(all_data)
    eval_data_size = 400
    eval_data = all_data[:eval_data_size]
    train_data = all_data[eval_data_size:]
    print(f"Data prepared. Training examples: {len(train_data)}, Evaluation examples: {len(eval_data)}")

    # --- Pre-Training Evaluation ---
    print("\nEvaluating model before fine-tuning...")
    evaluate_supervisor(model, tokenizer, eval_data, device)

    # --- Training Configuration ---
    # This config is designed for a single GPU with ~16-24GB VRAM. Adjust if needed.
    training_config = {
        'num_iterations': 1,        # Number of times to update the reference model
        'num_steps': 100,           # Batches per iteration. Increase for more training.
        'batch_size': 2,            # Prompts per batch. Decrease if OOM.
        'num_generations': 4,       # Completions per prompt. Decrease if OOM.
        'max_completion_length': 300, # Decrease if OOM.
        'beta': 0.01,               # KL penalty strength
        'learning_rate': 5e-6,      # Optimizer learning rate
        'mu': 2,                    # Number of optimization steps per batch
        'epsilon': 0.2              # PPO clipping value
    }
    
    # Initialize wandb if API key is set
    # if os.environ.get("WANDB_API_KEY"):
    #     wandb.init(project=os.environ["WANDB_PROJECT"], config=training_config, reinit=True)
    #     print("Weights & Biases initialized.")
    # --- Start Training ---
    print("\nStarting GRPO fine-tuning...")
    trained_model = train_with_grpo(
        model=model,
        tokenizer=tokenizer,
        train_data=train_data,
        device=device,
        batch_size=1,            # safer for GPU memory
        num_generations=2,       # safer for GPU memory
        max_completion_length=128,
        use_lora=True            # optional
    )

    # if os.environ.get("WANDB_API_KEY"):
    #     wandb.finish()

    # --- Post-Training Evaluation ---
    print("\nEvaluating model after GRPO fine-tuning...")
    evaluate_supervisor(trained_model, tokenizer, eval_data, device)

    # --- Save Final Model ---
    output_dir = "grpo_pubmedqa_finetuned_model"
    print(f"\nSaving fine-tuned model to {output_dir}...")
    trained_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print("Model saved successfully.")

if __name__ == "__main__":
    main()

Using primary device: cuda:0
Loading model: Qwen/Qwen2.5-0.5B-Instruct...
Model and tokenizer loaded.
Data prepared. Training examples: 100, Evaluation examples: 400

Evaluating model before fine-tuning...

STARTING SUPERVISOR EVALUATION ON 400 EXAMPLES

[DEBUG] Parsed agent choice: context_analysis
[DEBUG] Valid agents: {'reasoning', 'context_analysis', 'question_understanding', 'answering'}
[DEBUG] Agent in valid agents: True

[DEBUG] About to call sub-agent: context_analysis
[DEBUG] Agent is valid, calling run_subagent...

[DEBUG] Running sub-agent: context_analysis
[DEBUG] Example keys: ['problem', 'context', 'prompt', 'answer']
[DEBUG] Generated prompt length: 1636
[DEBUG] Calling Ollama with model: qwen2.5:0.5b-instruct
[DEBUG] Ollama response length: 1158
[DEBUG] Ollama response: Analysis of the context:

The provided context discusses surgical excision of ovarian endometriomas and the potential for these patients to have a negative diagnosis due to low levels of CA125, which ..

  scaler = GradScaler()



[DEBUG] Running sub-agent: reasoning
[DEBUG] Example keys: ['problem', 'context', 'prompt', 'answer']
[DEBUG] Generated prompt length: 1942
[DEBUG] Calling Ollama with model: qwen2.5:0.5b-instruct
[DEBUG] Ollama response length: 1331
[DEBUG] Ollama response: To answer whether displaced midshaft clavicular fractures should be treated surgically or not, I will follow these steps based on the provided context:

1. **Identify Relevant Information**: The study...
[DEBUG] Extracted answer: yes

[DEBUG] Running sub-agent: reasoning
[DEBUG] Example keys: ['problem', 'context', 'prompt', 'answer']
[DEBUG] Generated prompt length: 1942
[DEBUG] Calling Ollama with model: qwen2.5:0.5b-instruct
[DEBUG] Ollama response length: 1956
[DEBUG] Ollama response: Answer: Maybe

Reasoning:
To reach a conclusion about whether displaced midshaft clavicular fractures should be treated surgically, we need to examine the findings from the study described. The study ...
[DEBUG] Extracted answer: maybe


  with autocast():


Iter 1, Step 1/100, Avg Reward: 1.00

[DEBUG] Running sub-agent: reasoning
[DEBUG] Example keys: ['problem', 'context', 'prompt', 'answer']
[DEBUG] Generated prompt length: 1685
[DEBUG] Calling Ollama with model: qwen2.5:0.5b-instruct
[DEBUG] Ollama response length: 1479
[DEBUG] Ollama response: Let's break down the given information into logical steps:

1. **Context Analysis:**
   - The study aims to understand the influence of body mass index (BMI) on morbidity and long-term survival in gas...
[DEBUG] Extracted answer: no
Iter 1, Step 2/100, Avg Reward: 0.00

[DEBUG] Running sub-agent: reasoning
[DEBUG] Example keys: ['problem', 'context', 'prompt', 'answer']
[DEBUG] Generated prompt length: 1265
[DEBUG] Calling Ollama with model: qwen2.5:0.5b-instruct
[DEBUG] Ollama response length: 2406
[DEBUG] Ollama response: To determine if diffusion-weighted imaging (DWI) is a significant indicator of the development of vascularization in hypovascular hepatocellular lesions, we need to analyze 

In [8]:
import requests, json

# 1) Is the server up and which models are installed?
print("TAGS:", requests.get("http://localhost:11434/api/tags").text)
print("VERSION:", requests.get("http://localhost:11434/api/version").text)

TAGS: {"models":[{"name":"qwen2.5:0.5b-instruct","model":"qwen2.5:0.5b-instruct","modified_at":"2025-09-18T17:14:08.6959557-07:00","size":397821319,"digest":"a8b0c51577010a279d933d14c2a8ab4b268079d44c5c8830c0a93900f1827c67","details":{"parent_model":"","format":"gguf","family":"qwen2","families":["qwen2"],"parameter_size":"494.03M","quantization_level":"Q4_K_M"}},{"name":"qwen2.5:7b","model":"qwen2.5:7b","modified_at":"2025-09-03T10:49:46.343869-07:00","size":4683087332,"digest":"845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e","details":{"parent_model":"","format":"gguf","family":"qwen2","families":["qwen2"],"parameter_size":"7.6B","quantization_level":"Q4_K_M"}},{"name":"qwen2.5:latest","model":"qwen2.5:latest","modified_at":"2025-09-03T09:52:53.9825475-07:00","size":4683087332,"digest":"845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e","details":{"parent_model":"","format":"gguf","family":"qwen2","families":["qwen2"],"parameter_size":"7.6B","quantiza