In [None]:
import re
from typing import Optional
from datasets import load_dataset, Dataset

# Prompt templates for the openai/gsm8k task.
# SYSTEM_PROMPT is used as the system message of every prompt built by
# get_gsm8k_questions(); it asks the model to wrap its reasoning and its final
# answer in XML-style tags.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Template that renders a (reasoning, answer) pair in the same tag layout;
# fill it with XML_COT_FORMAT.format(reasoning=..., answer=...).
# NOTE: not referenced by any code visible in this section.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    """Return the stripped text enclosed by the last ``<answer>`` tag.

    The result is whatever follows the last ``<answer>`` up to the first
    ``</answer>`` found after it (or up to the end of the text when the only
    closing tag sits earlier). An empty string is returned unless both tags
    occur somewhere in ``text``.
    """
    open_tag, close_tag = "<answer>", "</answer>"
    if open_tag not in text or close_tag not in text:
        return ""

    after_last_open = text.rpartition(open_tag)[2]
    inner = after_last_open.partition(close_tag)[0]
    return inner.strip()

def extract_hash_answer(text: str) -> Optional[str]:
    """Return the stripped text that follows the first ``####`` marker.

    If a second ``####`` occurs, only the text between the first two markers
    is returned. ``None`` is returned when the marker is absent.
    """
    marker = "####"
    _, found, remainder = text.partition(marker)
    if not found:
        return None
    # Cut at the next marker (if any) so only the first segment is kept.
    segment = remainder.partition(marker)[0]
    return segment.strip()

def get_gsm8k_questions(split="train") -> Optional[Dataset]:
    """Load one split of ``openai/gsm8k`` and convert it to chat prompts.

    Every example is mapped to:

    * ``prompt`` - a two-message chat: ``SYSTEM_PROMPT`` as the system
      message followed by the example's ``question`` as the user message.
    * ``answer`` - the text after the ``####`` marker of the reference
      solution, or ``""`` when the marker is missing.

    Args:
        split: Name of the split to pick from the ``main`` configuration.

    Returns:
        The mapped dataset, or ``None`` when the dataset (or the requested
        split) could not be loaded; the error is printed in that case.
    """
    try:
        data = load_dataset('openai/gsm8k', 'main')[split]
    except Exception as e:
        # Best-effort loading: report the problem and signal failure with
        # None instead of raising, so callers must check the result.
        print(f"Error loading dataset: {e}")
        return None

    def to_chat_example(example):
        """Build the prompt/answer columns for a single GSM8K row."""
        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': example['question']},
            ],
            'answer': extract_hash_answer(example['answer']) or "",
        }

    return data.map(to_chat_example)

# Build the default ("train") split as soon as this code runs.
# NOTE: get_gsm8k_questions() prints the error and returns None when loading
# fails, so `dataset` may be None here.
dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose extracted answer equals the reference answer.

    Args:
        prompts: Chat prompts (lists of message dicts). Only the last message
            of the first prompt is read, for the debug print.
        completions: One entry per sample; each is a list of message dicts
            whose first element carries the generated text under
            ``'content'``.
        answer: Reference answers, compared position-by-position with
            ``completions``.
        **kwargs: Ignored.

    Returns:
        ``2.0`` where the text inside the ``<answer>`` tags equals the
        reference string exactly, otherwise ``0.0``. Empty completions and
        non-string contents score ``0.0`` but keep their position, so reward
        ``i`` always belongs to completion ``i``. ``[0.0]`` is returned when
        ``completions`` itself is empty.
    """
    if not completions:
        return [0.0]

    # Keep one slot per completion. Dropping empty completions would shift
    # every later response onto the wrong reference answer in the zip below.
    responses = [c[0]['content'] if c else None for c in completions]
    extracted_responses = [
        extract_xml_answer(r) if isinstance(r, str) else None
        for r in responses
    ]

    # Debug trace of the first sample; every field falls back to 'N/A'.
    question = prompts[0][-1]['content'] if prompts else 'N/A'
    first_response = 'N/A' if responses[0] is None else responses[0]
    first_extracted = (
        'N/A' if extracted_responses[0] is None else extracted_responses[0]
    )
    print('-'*20, f"Question:\n{question}",
          f"\nAnswer:\n{answer[0] if answer else 'N/A'}",
          f"\nResponse:\n{first_response}",
          f"\nExtracted:\n{first_extracted}")

    return [
        2.0 if r is not None and r == a else 0.0
        for r, a in zip(extracted_responses, answer)
    ]

def int_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions whose extracted answer consists only of digits.

    Args:
        completions: One entry per sample; each is a list of message dicts
            whose first element carries the generated text under
            ``'content'``.
        **kwargs: Ignored.

    Returns:
        ``0.5`` where the text inside the ``<answer>`` tags satisfies
        ``str.isdigit()``, otherwise ``0.0``. Empty completions and
        non-string contents score ``0.0`` but keep their position, so the
        result has one reward per completion. ``[0.0]`` is returned when
        ``completions`` itself is empty.
    """
    if not completions:
        return [0.0]

    rewards = []
    for completion in completions:
        # Do not filter: a skipped entry would shorten the reward list and
        # misalign it with the completions it is supposed to score.
        content = completion[0]['content'] if completion else None
        extracted = extract_xml_answer(content) if isinstance(content, str) else ""
        rewards.append(0.5 if extracted.isdigit() else 0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the tag layout exactly, line by line.

    The whole text must be ``<reasoning>``, the reasoning, ``</reasoning>``,
    ``<answer>``, the answer, ``</answer>``, each tag on its own line and the
    text ending with a newline.

    Args:
        completions: One entry per sample; each is a list of message dicts
            whose first element carries the generated text under
            ``"content"``.
        **kwargs: Ignored.

    Returns:
        ``0.5`` per matching completion, otherwise ``0.0``. Empty completions
        and non-string contents score ``0.0`` but keep their position, so the
        result has one reward per completion. ``[0.0]`` is returned when
        ``completions`` itself is empty.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"

    if not completions:
        return [0.0]

    rewards = []
    for completion in completions:
        content = completion[0]["content"] if completion else None
        # re.DOTALL lets "." span newlines; without it any reasoning or
        # answer longer than a single line could never match.
        matched = (
            isinstance(content, str)
            and re.match(pattern, content, flags=re.DOTALL) is not None
        )
        rewards.append(0.5 if matched else 0.0)
    return rewards

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that start with a reasoning block then an answer block.

    Looser than the strict variant: only the tag order is checked, any
    whitespace may separate the two blocks, and text after ``</answer>`` is
    allowed. The match is still anchored at the start of the completion.

    Args:
        completions: One entry per sample; each is a list of message dicts
            whose first element carries the generated text under
            ``"content"``.
        **kwargs: Ignored.

    Returns:
        ``0.5`` per matching completion, otherwise ``0.0``. Empty completions
        and non-string contents score ``0.0`` but keep their position, so the
        result has one reward per completion. ``[0.0]`` is returned when
        ``completions`` itself is empty.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"

    if not completions:
        return [0.0]

    rewards = []
    for completion in completions:
        content = completion[0]["content"] if completion else None
        # re.DOTALL is required: the requested layout puts newlines between
        # the tags and their content, which "." cannot cross by default.
        matched = (
            isinstance(content, str)
            and re.match(pattern, content, flags=re.DOTALL) is not None
        )
        rewards.append(0.5 if matched else 0.0)
    return rewards

def count_xml(text: str) -> float:
    """Score how closely ``text`` follows the reasoning/answer tag layout.

    Each of the four tag markers earns 0.125 when it occurs exactly once.
    The two answer markers additionally subtract 0.001 per character found
    after the closing answer tag. Non-string input scores 0.0.
    """
    if not isinstance(text, str):
        return 0.0

    tag_bonus = 0.125
    per_char_penalty = 0.001

    def appears_once(marker: str) -> bool:
        return text.count(marker) == 1

    score = 0.0

    # The two reasoning markers carry no length penalty.
    for marker in ("<reasoning>\n", "\n</reasoning>\n"):
        if appears_once(marker):
            score += tag_bonus

    if appears_once("\n<answer>\n"):
        # Everything after the last closing-tag line (the whole text when
        # that line is absent) counts as trailing content.
        trailing = text.split("\n</answer>\n")[-1]
        score += tag_bonus
        score -= len(trailing) * per_char_penalty

    if appears_once("\n</answer>"):
        trailing = text.split("\n</answer>")[-1]
        score += tag_bonus
        # One character (the expected final newline) is not penalised.
        score -= (len(trailing) - 1) * per_char_penalty

    return score

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion's tag layout with ``count_xml``.

    Args:
        completions: One entry per sample; each is a list of message dicts
            whose first element carries the generated text under
            ``"content"``.
        **kwargs: Ignored.

    Returns:
        The ``count_xml`` score of every completion's text. Empty completions
        score ``0.0`` but keep their position, so the result has one reward
        per completion. ``[0.0]`` is returned when ``completions`` itself is
        empty.
    """
    if not completions:
        return [0.0]

    # No filtering here: dropping an empty completion would shorten the list
    # and shift the remaining rewards onto the wrong samples.
    return [count_xml(c[0]["content"]) if c else 0.0 for c in completions]
