In [4]:

import re

def calculate_reward(text):
    # Pattern to match <api> [Calculator(...)] </api>
    api_pattern = r'<api>\s*\[Calculator\((.*?)\)\]\s*</api>'
    
    # Check if the pattern exists
    match = re.search(api_pattern, text)
    
    if match:
        # If the complete API pattern is found, return a positive reward
        return {
            'reward': 1.0,
            'detected_expression': match.group(1),  # This will capture the expression inside Calculator()
            'message': 'Valid API call pattern detected'
        }
    else:
        # If the pattern is not found or incomplete, return no reward
        return {
            'reward': 0.0,
            'detected_expression': None,
            'message': 'Invalid or missing API call pattern'
        }

# Test cases
test_cases = [
    '<api> [Calculator(18 + 12 * 3)] </api>',  # Valid case
    'Calculator(18 + 12 * 3)',                  # Invalid case (missing api tags)
    '<api> Calculator(18 + 12 * 3) </api>',     # Invalid case (missing square brackets)
    'Random text',
    """
    <think>
I have to solve this expression "18 + 12 x 3". let me use the calculator to solve this expression

<api> [Calculator(18 + 2 * 3)] </api> -> 54

I got 54 as an answer from the Calculator. So the answer is 54

<answer>
54
<answer>
    """
                                  # Invalid case (no pattern)
]

# Test the function
for test in test_cases:
    result = calculate_reward(test)
    print(f"Input: {test}")
    print(f"Result: {result}\n")


Input: <api> [Calculator(18 + 12 * 3)] </api>
Result: {'reward': 1.0, 'detected_expression': '18 + 12 * 3', 'message': 'Valid API call pattern detected'}

Input: Calculator(18 + 12 * 3)
Result: {'reward': 0.0, 'detected_expression': None, 'message': 'Invalid or missing API call pattern'}

Input: <api> Calculator(18 + 12 * 3) </api>
Result: {'reward': 0.0, 'detected_expression': None, 'message': 'Invalid or missing API call pattern'}

Input: Random text
Result: {'reward': 0.0, 'detected_expression': None, 'message': 'Invalid or missing API call pattern'}

Input: 
    <think>
I have to solve this expression "18 + 12 x 3". let me use the calculator to solve this expression

<api> [Calculator(18 + 2 * 3)] </api> -> 54

I got 54 as an answer from the Calculator. So the answer is 54

<answer>
54
<answer>
    
Result: {'reward': 1.0, 'detected_expression': '18 + 2 * 3', 'message': 'Valid API call pattern detected'}



In [19]:
import pytest
import re

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r'<think>[\s\S]*?</think>\s*<answer>\d+</answer>'
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


In [22]:
 valid_completion = [[{
        "content": """<think>
I have to solve this expression "18 + 12 x 3". let me use the calculator to solve this expression

<api>[Calculator(18 + 12 * 3)]</api> -> 54

I got 54 as an answer from the Calculator. So the answer is 54

</think>
<api>[Calculator(18 + 12 * 3)]</api> -> 54
<answer>54</answer>

        """
    }]]
x = strict_format_reward_func(valid_completion)
print(x)

[0.0]


In [23]:
import re
import numexpr as ne

def validate_math_expression(text):
    """
    Validates mathematical expressions found within API tags using numexpr.
    
    Args:
        text (str): The text containing API tags with math expressions
        
    Returns:
        float: 1.0 if all expressions are valid, 0.0 if any expression is invalid
    """
    # Find all expressions within API tags
    api_pattern = r'\<api\>\s*\[Calculator\((.*?)\)\].*?\<\/api\>'
    expressions = re.findall(api_pattern, text, re.DOTALL)
    
    if not expressions:
        return 0.0  # No expressions found
    
    try:
        # Validate each expression
        for expr in expressions:
            # Clean the expression
            cleaned_expr = expr.strip()
            
            # Replace common mathematical words with symbols
            replacements = {
                'x': '*',
                '×': '*',
                '÷': '/',
            }
            for old, new in replacements.items():
                cleaned_expr = cleaned_expr.replace(old, new)
            
            # Try to evaluate the expression using numexpr
            # This will raise an exception if the expression is invalid
            ne.evaluate(cleaned_expr)
        
        return 1.0  # All expressions are valid
        
    except Exception as e:
        return 0.0  # Invalid expression found

# Example usage
def test_validate_math_expression():
    # Valid expression
    text1 = """
    <think>
    Let me calculate this.
    <api>[Calculator(18 + 12 * 3)]</api> -> 54
    </think>
    """
    
    # Invalid expression
    text2 = """
    <think>
    Let me calculate this.
    <api>[Calculator(18 ++ 12 * 3)]</api> -> 54
    </think>
    """
    
    # Multiple expressions
    text3 = """
    <think>
    Let me break this down.
    <api>[Calculator(10 + 5)]</api> -> 15
    Then,
    <api>[Calculator(15 * 2)]</api> -> 30
    </think>
    """
    
    print(f"Text 1 reward: {validate_math_expression(text1)}")  # Should print 1.0
    print(f"Text 2 reward: {validate_math_expression(text2)}")  # Should print 0.0
    print(f"Text 3 reward: {validate_math_expression(text3)}")  # Should print 1.0

if __name__ == "__main__":
    test_validate_math_expression()


Text 1 reward: 1.0
Text 2 reward: 1.0
Text 3 reward: 1.0


In [32]:
def validate_math_expression_reward(completions, **kwargs) -> list[float]:
    """Reward function that validates if the math expression in API tags is accepted by numexpr."""
    from utils import extract_last_calculator_expression, calculate_safe
    
    responses = [completion[0]["content"] for completion in completions]
    rewards = []
    
    for response in responses:
        # Extract the expression from Calculator API tags
        expression = extract_last_calculator_expression(response)
        print(expression)
        
        if expression is None:
            rewards.append(0.0)
            continue
            
        # Check if the expression is valid using numexpr
        try:
            calculate_safe(expression)
            rewards.append(1.0)  # Expression is valid
        except Exception as e:
            rewards.append(0.0)  # Expression is invalid
            
    return rewards

In [37]:
def test_validate_math_expression_reward():
    # Valid expression
    text1 = """
    <think>
    Let me calculate this.
    <api>[Calculator(18 + 12 * 3)]</api> -> 54
    </think>
    """
    
    # Invalid expression
    text2 = """
 
    """
    
    # Multiple expressions
    text3 = """
    <think>
    Let me break this down.
    <api>[Calculator(10 + 5)]</api> -> 15
    Then,
    <api>[Calculator(15 * 2)]</api> -> 30
    </think>
    """
    
    print(f"Text 1 reward: {validate_math_expression_reward(text1)}")  # Should print 1.0
    print(f"Text 2 reward: {validate_math_expression_reward(text2)}")  # Should print 0.0
    print(f"Text 3 reward: {validate_math_expression_reward(text3)}")  # Should print 1.0


In [38]:
test_validate_math_expression()

Text 1 reward: 1.0
Text 2 reward: 1.0
Text 3 reward: 1.0


In [41]:
 valid_completion = [[{
        "content": """<think>
I have to solve this expression "18 + 12 x 3". let me use the calculator to solve this expression

<api>[Calculator(18 + 12 * 3)]</api> -> 54

I got 54 as an answer from the Calculator. So the answer is 54

</think>
<api>[Calculator(3**3)]</api> -> 54
<answer>54</answer>

        """
    }]]
x = validate_math_expression_reward(valid_completion)
x

3**3
27
27
27


[1.0]

In [54]:
from utils import calculate_safe
calculate_safe("((3*3)**(1/2))")

3.0


3.0

In [57]:
def extract_answer(output_text):
    """Extract the numeric answer from the model output."""
    # Look for an answer tag if it exists
    answer_match = re.search(r'<answer>(.*?)</answer>', output_text, re.DOTALL)
    if answer_match:
        # Extract numeric value from the answer
        answer_text = answer_match.group(1).strip()
        numeric_match = re.search(r'\d+\.?\d*', answer_text)
        if numeric_match:
            return numeric_match.group(0)

In [63]:
def numerical_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward function that compares numerical answers, allowing for different formats.
    
    Args:
        prompts: List of conversation histories
        completions: List of model completions
        answer: List of correct answers
        **kwargs: Additional arguments
        
    Returns:
        List of rewards (1.0 for correct numerical answer, 0.0 otherwise)
    """
    #from evaluate_gsm8k import extract_answer
    
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_answer(r) for r in responses]
    correct_answers = [extract_answer(str(a)) for a in answer]

    print(extracted_responses, correct_answers)
    
    # Compare numerical values with conversion to float for proper comparison
    rewards = []
    for resp, ans in zip(extracted_responses, correct_answers):
        if resp is not None and ans is not None:
            try:
                resp_num = float(resp)
                ans_num = float(ans)
                print(resp_num, ans_num)
                rewards.append(1.0 if resp_num == ans_num else 0.0)
            except ValueError:
                rewards.append(0.0)
        else:
            rewards.append(0.0)
    
    return rewards


In [84]:
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

In [85]:
def tool_usage_reward(completions, prompts, answer, **kwargs):
        """Reward function that evaluates tool usage patterns.
        
        Rewards proper tool usage and penalizes excessive tool calls compared to golden response.
        
        Args:
            completions: List of model completions
            solution: List of ground truth solutions containing tool usage
        
        Returns:
            List of rewards, one per completion
        """
        responses = [completion[0]['content'] for completion in completions]
        #q = prompts[0][-1]['content']
        extracted_responses = [extract_xml_answer(r) for r in responses]
        print('-'*20, f"Question:", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
        return [2.0 if r.count("[Calculator(") == a.count("[Calculator(") else 1.0 if (r.count("[Calculator(") > a.count("[Calculator(") or r.count("[Calculator(") < a.count("[Calculator(")) and r.count("[Calculator(") != 0 else 0.0 for r, a in zip(extracted_responses, answer)]



In [86]:
"""Example showing how to use the numerical reward function."""

#from rewards import numerical_reward_func

# Example question
question = "On Monday, Mack writes in his journal for 60 minutes at a rate of 1 page every 30 minutes. On Tuesday, Mack writes in his journal for 45 minutes at a rate of 1 page every 15 minutes. On Wednesday, Mack writes 5 pages in his journal. How many pages total does Mack write in his journal from Monday to Wednesday? "

# Example completion (model's response)
completion = [[{"content": """ <think>
Let's calculate the pages Mack wrote each day:

**Monday:**
- He writes for 60 minutes at a rate of 1 page every 30 minutes.
- The number of pages he writes on Monday = 60 minutes / 30 minutes per page = 2 pages.

**Tuesday:**
- He writes for 45 minutes at a rate of 1 page every 15 minutes.
- The number of pages he writes on Tuesday = 45 minutes / 15 minutes per page = 3 pages.

**Wednesday:**
- He writes 5 pages in his journal.

Now, we sum up the pages written over the three days:
- Total pages = Pages on Monday + Pages on Tuesday + Pages on Wednesday = 2 + 3 + 5 = 10 pages

<api> [Calculator(2 + 3 + 5)] </api> -> 10

The total number of pages Mack writes from Monday to Wednesday is 10 pages.
</think>

<answer>
10
</answer> 
    """ }]]

# Example correct answer
answer = """<think>
On Monday, Mack writes 60 / 30 = <api>[Calculator(60/30)]</api> 2 pages
On Tuesday, Mack writes 45 / 15 = <api>[Calculator(45/15)]</api> 3 pages
In total, from Monday to Wednesday, Mack writes <api>[Calculator(2 + 3 + 5)]</api> 10 pages
</think> """

# Calculate reward
rewards = tool_usage_reward(completion, [question], [answer])

print(f"Reward: {rewards[0]}")  # Should print 1.0 if model's answer matches the correct answer

-------------------- Question: 
Answer:
<think>
On Monday, Mack writes 60 / 30 = <api>[Calculator(60/30)]</api> 2 pages
On Tuesday, Mack writes 45 / 15 = <api>[Calculator(45/15)]</api> 3 pages
In total, from Monday to Wednesday, Mack writes <api>[Calculator(2 + 3 + 5)]</api> 10 pages
</think>  
Response:
 <think>
Let's calculate the pages Mack wrote each day:

**Monday:**
- He writes for 60 minutes at a rate of 1 page every 30 minutes.
- The number of pages he writes on Monday = 60 minutes / 30 minutes per page = 2 pages.

**Tuesday:**
- He writes for 45 minutes at a rate of 1 page every 15 minutes.
- The number of pages he writes on Tuesday = 45 minutes / 15 minutes per page = 3 pages.

**Wednesday:**
- He writes 5 pages in his journal.

Now, we sum up the pages written over the three days:
- Total pages = Pages on Monday + Pages on Tuesday + Pages on Wednesday = 2 + 3 + 5 = 10 pages

<api> [Calculator(2 + 3 + 5)] </api> -> 10

The total number of pages Mack writes from Monday to Wedn

In [1]:
def extract_answer(text):
    # Pattern to match the final answer (usually after "The answer is" or similar phrases)
    patterns = [
        r"The answer is\s*(-?\d+(?:\.\d+)?)",
        r"The final answer is\s*(-?\d+(?:\.\d+)?)",
        r"The result is\s*(-?\d+(?:\.\d+)?)",
        r"equals\s*(-?\d+(?:\.\d+)?)",
        r"=\s*(-?\d+(?:\.\d+)?)",
        r"(-?\d+(?:\.\d+)?)\s*\$",  # If the answer is just at the end
        r"####\s*(-?\d+(?:\.\d+)?)"  # New pattern to match "#### number" format
    ]
    
    for pattern in patterns:
        matches = re.search(pattern, text)
        if matches:
            return float(matches.group(1))
    
    # If no patterns match, try to find any number in the last line
    lines = text.strip().split("\n")
    if lines:
        last_line = lines[-1]
        numbers = re.findall(r"(-?\d+(?:\.\d+)?)", last_line)
        if numbers:
            return float(numbers[-1])
    
    return None


In [6]:
text = """
"Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
#### 72"

"""

In [7]:
import re
extract_answer(text)

24.0

In [1]:
def extract_final_answer(text):
    # Method 1: Look for answer after '####' marker
    if '####' in text:
        parts = text.split('####')
        if len(parts) > 1:
            return parts[1].strip()
    
    # Method 2: Look for boxed answer format
    if r'\boxed{' in text:
        import re
        match = re.search(r'\\boxed\{(.*?)\}', text)
        if match:
            return match.group(1).strip()
    
    # Method 3: Look for "final answer is:" pattern
    if "final answer is:" in text.lower():
        import re
        match = re.search(r'final answer is:?\s*\\$?\\?boxed\{?(.*?)\}?\\$?', text.lower())
        if match:
            return match.group(1).strip()
    
    return None

# Example usage
answer1 = "Natalia sold 48/2 = <<48/2=24>>24 clips in May.Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.#### 72"
answer2 = "## Step 1: Calculate the number of clips sold in April.Natalia sold 48 clips in April.## Step 2: Calculate the number of clips sold in May.She sold half as many clips in May as she did in April, so she sold 48 / 2 = 24 clips in May.## Step 3: Add the number of clips sold in April and May to find the total number of clips sold.Total clips sold = clips sold in April + clips sold in May = 48 + 24 = 72.The final answer is: \$\\boxed{72}\$"

print("First answer:", extract_final_answer(answer1))  # Should output: 72
print("Second answer:", extract_final_answer(answer2))  # Should output: 72

First answer: 72
Second answer: 72


In [1]:
import re
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    #pattern = r'<think>.*?</think>\s*<answer>\d+(\.\d+)?</answer>'
    pattern= r'<think>[\s\S]*?</think>\s*<answer>[0-9.]+</answer>'
    responses = [completion["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [2.5 if match else 0.0 for match in matches]



In [2]:
prompt = [{"content": """<think> First, I need to calculate the price after the 15% discount. Discount amount = \$80 × 0.15 <api>[Calculator(80 * 0.15)]</api> -> 12 Discounted price = \$80 - \$12 <api>[Calculator(80 - 12)]</api> -> 68
Now I need to add the 8% sales tax to this discounted price. Tax amount = $68 × 0.08 <api>[Calculator(68 * 0.08)]</api> -> 5.44 Final price = $68 + $5.44 <api>[Calculator(68 + 5.44)]</api> -> 73.44
So the final price is $73.44 </think>
<answer>73.44</answer>"""},
{"content": """<think> First, I need to calculate the price after the 15% discount. Discount amount = \$80 × 0.15 <api>[Calculator(80 * 0.15)]</api> -> 12 Discounted price = \$80 - \$12 <api>[Calculator(80 - 12)]</api> -> 68
Now I need to add the 8% sales tax to this discounted price. Tax amount = $68 × 0.08 <api>[Calculator(68 * 0.08)]</api> -> 5.44 Final price = $68 + $5.44 <api>[Calculator(68 + 5.44)]</api> -> 73.44
So the final price is $73.44 </think>
<answer>73.44</answer>"""}]

In [3]:
prompt[1]["content"]

'<think> First, I need to calculate the price after the 15% discount. Discount amount = \\$80 × 0.15 <api>[Calculator(80 * 0.15)]</api> -> 12 Discounted price = \\$80 - \\$12 <api>[Calculator(80 - 12)]</api> -> 68\nNow I need to add the 8% sales tax to this discounted price. Tax amount = $68 × 0.08 <api>[Calculator(68 * 0.08)]</api> -> 5.44 Final price = $68 + $5.44 <api>[Calculator(68 + 5.44)]</api> -> 73.44\nSo the final price is $73.44 </think>\n<answer>73.44</answer>'

In [4]:
strict_format_reward_func(prompt)

[2.5, 2.5]