In [13]:
from datasets import load_from_disk
from matplotlib import pyplot as plt
import re
import glob
import os
import pandas as pd
import numpy as np

In [14]:
def format_reward(completion):
    """Tell whether a completion follows the reasoning-then-answer layout.

    The text must start with a ``<reasoning>...</reasoning>`` section that is
    followed, after optional whitespace, by an ``<answer>...</answer>``
    section ending the text (``$`` also tolerates one trailing newline).
    ``re.DOTALL`` lets both sections span several lines.

    Args:
        completion: Model output string to check.

    Returns:
        bool: True when the layout matches, False otherwise.
    """
    layout = re.compile(
        r"^<reasoning>.*?</reasoning>\s*<answer>.*</answer>$",
        re.DOTALL,
    )
    return layout.match(completion) is not None

In [15]:
def eval_model_answer(equation, nums, target):
    """Check whether an arithmetic expression solves the number puzzle.

    The expression is accepted only when it
      * uses exactly the integers in ``nums`` (same multiset, any order),
      * contains nothing but digits, ``+ - * /``, parentheses, dots and
        whitespace,
      * does not use ``**`` or ``//``, and
      * evaluates to ``target`` within an absolute tolerance of 1e-5.

    Note: numbers are extracted as runs of digits (``\\d+``), so a decimal such
    as ``1.5`` is counted as the two integers 1 and 5.

    Args:
        equation: Expression string to evaluate, or None.
        nums: Integers that must each be used exactly once.
        target: Value the expression has to evaluate to.

    Returns:
        bool: True if the expression is valid and hits the target; False
        otherwise, including when ``equation`` is None or evaluation raises.
    """
    try:
        if equation is None:
            return False

        # Every provided number must be used exactly once.
        used_numbers = [int(n) for n in re.findall(r"\d+", equation)]
        if sorted(used_numbers) != sorted(nums):
            return False

        # Whitelist: digits, operators, parentheses, dots and whitespace only.
        allowed_pattern = r"^[\d+\-*/().\s]+$"
        if not re.match(allowed_pattern, equation):
            return False

        # The whitelist alone still lets "**" and "//" through. Neither is one
        # of the four basic operations, and chained powers (e.g. 9**9**9) can
        # build enormous integers that stall the evaluation.
        if "**" in equation or "//" in equation:
            return False

        # SECURITY NOTE: eval() runs on model-generated text. The character
        # whitelist above is what keeps this safe (no names, attributes or
        # calls can be spelled); the emptied builtins are only a second layer.
        result = eval(equation, {"__builtins__": None}, {})

        return abs(float(result) - float(target)) < 1e-5
    except Exception:
        # Any failure (syntax error, division by zero, bad types) means the
        # answer is not correct.
        return False

In [16]:
# One entry per experiment under eval_outputs/. glob returns paths in
# arbitrary order, so sort them to keep the evaluation order reproducible.
EXPERIMENTS = sorted(glob.glob("eval_outputs/*"))

In [17]:
# Accumulator for one metrics dict per evaluated checkpoint; filled by the
# evaluation loop and later turned into a DataFrame.
results_records = []

In [18]:
# Empty the accumulator first so re-running this cell does not append
# duplicate rows for checkpoints that were already evaluated.
results_records.clear()

for experiment in EXPERIMENTS:
    checkpoints = glob.glob(os.path.join(experiment, "checkpoint-*"))
    for checkpoint in checkpoints:
        print(f"Evaluating {checkpoint}")
        dataset = load_from_disk(checkpoint)
        total = len(dataset)
        if total == 0:
            # Nothing to score; the ratios below would divide by zero.
            print(f"Skipping {checkpoint}: dataset is empty")
            continue

        format_matches = 0   # responses that follow the expected layout
        correct = 0          # correct answers, regardless of layout
        correct_strict = 0   # correct answers that also follow the layout
        for item in dataset:
            # Score the answer once and reuse it for both accuracy counters.
            answer_ok = eval_model_answer(
                item["model_answer"], item["nums"], item["target"]
            )
            format_ok = format_reward(item["model_response"])
            if answer_ok:
                correct += 1
            if format_ok:
                format_matches += 1
                if answer_ok:
                    correct_strict += 1

        accuracy = correct / total
        strict_accuracy = correct_strict / total
        format_accuracy = format_matches / total
        print(f"Accuracy at {checkpoint}: {accuracy:.4f}")
        print(f"Format Accuracy at {checkpoint}: {format_accuracy:.4f}")
        results_records.append({
            # basename instead of split("/") so the name is also extracted
            # correctly on platforms with a different path separator.
            "experiment_name": os.path.basename(experiment),
            # Checkpoint directories are named "checkpoint-<step>".
            "checkpoint_number": int(checkpoint.split("-")[-1]),
            "accuracy": accuracy,
            "strict_accuracy": strict_accuracy,
            "format_accuracy": format_accuracy
        })

Evaluating eval_outputs/rl_equation_base_config/checkpoint-10
Accuracy at eval_outputs/rl_equation_base_config/checkpoint-10: 0.8000
Format Accuracy at eval_outputs/rl_equation_base_config/checkpoint-10: 0.8200
Evaluating eval_outputs/rl_equation_base_config/checkpoint-100
Accuracy at eval_outputs/rl_equation_base_config/checkpoint-100: 0.0400
Format Accuracy at eval_outputs/rl_equation_base_config/checkpoint-100: 1.0000
Evaluating eval_outputs/rl_equation_base_config/checkpoint-110
Accuracy at eval_outputs/rl_equation_base_config/checkpoint-110: 0.0800
Format Accuracy at eval_outputs/rl_equation_base_config/checkpoint-110: 1.0000
Evaluating eval_outputs/rl_equation_base_config/checkpoint-120
Accuracy at eval_outputs/rl_equation_base_config/checkpoint-120: 0.3100
Format Accuracy at eval_outputs/rl_equation_base_config/checkpoint-120: 1.0000
Evaluating eval_outputs/rl_equation_base_config/checkpoint-20
Accuracy at eval_outputs/rl_equation_base_config/checkpoint-20: 0.7800
Format Accurac

In [19]:
# One row per evaluated checkpoint, one column per recorded metric.
results_df = pd.DataFrame.from_records(data=results_records)

In [20]:
# Order rows by experiment, then by checkpoint number. The name is rebound to
# the sorted copy instead of using inplace=True, so the frame built in the
# previous cell is not mutated behind the reader's back.
results_df = results_df.sort_values(by=["experiment_name", "checkpoint_number"])

In [21]:
# Show the per-checkpoint metrics table as the cell output.
results_df

Unnamed: 0,experiment_name,checkpoint_number,accuracy,strict_accuracy,format_accuracy
0,rl_equation_base_config,10,0.8,0.8,0.82
4,rl_equation_base_config,20,0.78,0.78,0.79
5,rl_equation_base_config,30,0.8,0.8,0.82
6,rl_equation_base_config,40,0.82,0.82,0.82
7,rl_equation_base_config,50,0.65,0.65,0.73
8,rl_equation_base_config,60,0.01,0.01,1.0
9,rl_equation_base_config,70,0.0,0.0,0.99
10,rl_equation_base_config,80,0.01,0.01,1.0
11,rl_equation_base_config,90,0.01,0.01,1.0
1,rl_equation_base_config,100,0.04,0.04,1.0


In [22]:
# Row order for the per-experiment summary tables; the summary cells reindex
# their results with this list.
EXPERIMENTS_ORDER = [
    "rl_equation_base_config",
    "rl_equation_base_config_kl_0.0001_beta",
    "rl_equation_base_config_kl_0.01_beta",
    "rl_equation_base_config_kl_0.1_beta",
    "rl_equation_base_config_kl_high_lr",
    "rl_equation_base_config_alllinear",
    "rl_equation_base_config_alllinear_kl_largebeta",
    "rl_equation_first",
    "rl_equation_think",
]

In [23]:
# For each experiment, take the row label of its highest-accuracy checkpoint,
# arrange the labels in EXPERIMENTS_ORDER, and drop experiments that have no
# results (reindex leaves NaN for those, hence the cast back to int).
idx = (
    results_df.groupby("experiment_name")["accuracy"]
    .idxmax()
    .reindex(EXPERIMENTS_ORDER)
    .dropna()
    .astype(int)
)
best_results_df = results_df.loc[idx]
best_results_df

Unnamed: 0,experiment_name,checkpoint_number,accuracy,strict_accuracy,format_accuracy
6,rl_equation_base_config,40,0.82,0.82,0.82
23,rl_equation_base_config_kl_0.0001_beta,30,0.81,0.81,0.81
26,rl_equation_base_config_kl_0.01_beta,20,0.85,0.85,0.85
31,rl_equation_base_config_kl_0.1_beta,30,0.8,0.8,0.81
33,rl_equation_base_config_kl_high_lr,10,0.8,0.8,0.82
14,rl_equation_base_config_alllinear,30,0.85,0.85,0.85
17,rl_equation_base_config_alllinear_kl_largebeta,10,0.8,0.8,0.81
37,rl_equation_first,20,0.75,0.0,0.0
41,rl_equation_think,20,0.79,0.79,0.8


In [24]:
# Same selection as for plain accuracy, but ranking checkpoints by strict
# accuracy (answer correct AND response in the expected format). Experiments
# absent from the results become NaN after reindex and are dropped.
idx = (
    results_df.groupby("experiment_name")["strict_accuracy"]
    .idxmax()
    .reindex(EXPERIMENTS_ORDER)
    .dropna()
    .astype(int)
)
best_results_df = results_df.loc[idx]
best_results_df

Unnamed: 0,experiment_name,checkpoint_number,accuracy,strict_accuracy,format_accuracy
6,rl_equation_base_config,40,0.82,0.82,0.82
23,rl_equation_base_config_kl_0.0001_beta,30,0.81,0.81,0.81
26,rl_equation_base_config_kl_0.01_beta,20,0.85,0.85,0.85
31,rl_equation_base_config_kl_0.1_beta,30,0.8,0.8,0.81
33,rl_equation_base_config_kl_high_lr,10,0.8,0.8,0.82
14,rl_equation_base_config_alllinear,30,0.85,0.85,0.85
17,rl_equation_base_config_alllinear_kl_largebeta,10,0.8,0.8,0.81
37,rl_equation_first,20,0.75,0.0,0.0
41,rl_equation_think,20,0.79,0.79,0.8
