In [1]:
from glob import glob

from datasets import load_dataset
from sal.utils.math import *
from sal.utils.grader import *

from sal.utils.qwen_math_parser import *
from sal.utils.data import get_dataset, save_dataset
from collections import defaultdict
import json
import numpy as np
import pickle
from sal.config import Config

import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the MATH-500 problems used as ground truth for this analysis.
MATH500_JSONL = "/ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/datasets/math500.jsonl"

config = Config()
# NOTE(review): the attribute is spelled `datset_name` (no second "a"). If `Config`
# declares `dataset_name` instead, this assignment leaves that field at its
# default value — TODO confirm against sal.config.Config before renaming.
config.datset_name = MATH500_JSONL
config.dataset_split = "train"
gt = get_dataset(config)


In [3]:
def _unwrap_boxed(text):
    """Return the contents of a LaTeX boxed wrapper, or ``text`` unchanged if not wrapped."""
    # The wrapper is the literal prefix \boxed{ closed by one trailing }.
    prefix = "\\boxed{"
    if text.startswith(prefix) and text.endswith("}"):
        return text[len(prefix):-1]
    return text


def is_correct(sample, key):
    """Grade the prediction stored under ``pred_<key>`` against ``sample['answer']``.

    Args:
        sample: Mapping with an ``'answer'`` entry and a ``'pred_<key>'`` entry.
        key: Suffix selecting which prediction to grade.

    Returns:
        The same ``sample`` object, with ``'is_correct_<key>'`` set to the result
        of ``math_equal`` on the canonicalised answer and prediction.
    """
    ans = memoized_canonical_form(sample['answer'])
    # str.strip("\\boxed{") removes any of the characters \ b o x e d { from both
    # ends rather than the prefix itself, which mangles answers such as
    # \boxed{\frac{1}{2}} or \boxed{e}. Unwrap the exact prefix/suffix instead.
    pred = _unwrap_boxed(memoized_canonical_form(sample[f'pred_{key}']))

    sample['is_correct_' + key] = math_equal(ans, pred)
    return sample


def parse_responses(sample):
    """Store the extracted answer of every completion under ``'parsed_responses'``.

    Each entry of ``sample['completions']`` is passed through
    ``extract_answer(..., 'math')`` and then ``strip_string``; the results are
    kept in the original order. The same ``sample`` object is returned.
    """
    parsed = []
    for completion in sample['completions']:
        raw_answer = extract_answer(completion, 'math')
        parsed.append(strip_string(raw_answer))
    sample['parsed_responses'] = parsed
    return sample


def extract(string):
    """Return ``strip_string`` applied to the 'math' answer extracted from ``string``."""
    raw_answer = extract_answer(string, 'math')
    return strip_string(raw_answer)

In [4]:
# Sorted so the load order (and the printed log) is the same on every run.
files = sorted(glob('./output/p2/seed2/*.pkl'))

# Maps "<pickle stem>.json" -> unpickled content of that file.
data = {}
for file in files:
    name = os.path.basename(file)
    # Decide whether to skip before opening, so "batch*" files are never touched.
    if name.startswith('batch'):
        continue

    print("file: ", file)
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # result files produced by trusted runs.
    with open(file, 'rb') as f:
        # Swap only the extension, not every ".pkl" occurrence in the name.
        data[os.path.splitext(name)[0] + '.json'] = pickle.load(f)

file:  ./output/p2/seed2/test_algebra_1072.pkl


2025-04-22 14:51:26,665	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


file:  ./output/p2/seed2/test_prealgebra_930.pkl
file:  ./output/p2/seed2/test_algebra_2036.pkl
file:  ./output/p2/seed2/test_intermediate_algebra_232.pkl
file:  ./output/p2/seed2/test_prealgebra_505.pkl
file:  ./output/p2/seed2/test_algebra_1098.pkl
file:  ./output/p2/seed2/test_geometry_434.pkl
file:  ./output/p2/seed2/test_number_theory_516.pkl
file:  ./output/p2/seed2/test_intermediate_algebra_207.pkl
file:  ./output/p2/seed2/test_algebra_1035.pkl
file:  ./output/p2/seed2/test_precalculus_779.pkl
file:  ./output/p2/seed2/test_intermediate_algebra_1408.pkl
file:  ./output/p2/seed2/test_prealgebra_1302.pkl
file:  ./output/p2/seed2/test_number_theory_847.pkl
file:  ./output/p2/seed2/test_geometry_178.pkl
file:  ./output/p2/seed2/test_algebra_722.pkl
file:  ./output/p2/seed2/test_counting_and_probability_134.pkl
file:  ./output/p2/seed2/test_algebra_2700.pkl
file:  ./output/p2/seed2/test_geometry_1140.pkl
file:  ./output/p2/seed2/test_counting_and_probability_159.pkl
file:  ./output/p2

In [5]:
unique_ids = list(data.keys())
# Set copy for O(1) membership tests inside the row-wise filter below;
# `unique_ids` itself stays a list for later cells.
loaded_ids = set(unique_ids)
# Replace '/' with '_' in the dataset ids so they can be compared with the
# file-name based keys of `data`.
gt = gt.map(lambda x: {'unique_id': x['unique_id'].replace('/', '_')})
# Keep only the problems for which a result file was loaded.
gt = gt.filter(lambda x: x['unique_id'] in loaded_ids)

Filter: 100%|██████████| 500/500 [00:00<00:00, 53278.59 examples/s]


In [6]:
# Number of distinct keys collected from `data` (one per loaded result file).
len(set(unique_ids))

149

In [7]:
# Display the dataset summary after filtering to the loaded ids.
gt

Dataset({
    features: ['problem', 'solution', 'answer', 'subject', 'level', 'unique_id'],
    num_rows: 149
})

In [8]:
def get_pg_response(sample, pg_data):
    """Grade the stored particle output for ``sample`` and attach the results.

    ``pg_data[sample['unique_id']]`` must contain exactly one state (asserted).
    A state is iterated as a sequence of particles, each exposing ``.rewards``
    and ``.trajectory``. The particle with the highest final reward is chosen,
    its trajectory steps are joined with blank lines, and the extracted answer
    is compared with ``sample['answer']`` via ``math_equal``.

    Keys written to ``sample``:
        is_correct_pg_states: one correctness flag per state.
        preds: the extracted answer per state.
        is_correct: True if any state was graded correct.
        trajectories: joined trajectory text of every particle.
        rewards: final reward of every particle in the (single) state.
        all_rewards: full reward list of every particle in the (single) state.

    Returns:
        The same ``sample`` object.
    """
    states = pg_data[sample['unique_id']]
    per_state_correct = []
    per_state_answer = []
    assert len(states) == 1
    for particles in states:
        rewards = [p.rewards[-1] for p in particles]
        all_rewards = [p.rewards for p in particles]

        # Select the particle whose last reward is the largest.
        top_idx = np.argmax(rewards)
        winner = particles[top_idx]

        answer = extract("\n\n".join(winner.trajectory))
        matches = math_equal(
            memoized_canonical_form(sample['answer']),
            memoized_canonical_form(answer),
        )
        per_state_correct.append(matches)
        per_state_answer.append(answer)

    trajectories = [
        "\n\n".join(p.trajectory)
        for particles in states
        for p in particles
    ]

    sample['is_correct_pg_states'] = per_state_correct
    sample['preds'] = per_state_answer
    sample['is_correct'] = any(per_state_correct)
    sample['trajectories'] = trajectories
    # `rewards` / `all_rewards` come from the loop above; with the single-state
    # assertion they describe that one state.
    sample['rewards'] = rewards
    sample['all_rewards'] = all_rewards

    return sample


In [9]:
# Run `get_pg_response` on every problem, handing it the loaded results as `pg_data`.
gt = gt.map(
    get_pg_response,
    fn_kwargs={'pg_data': data},
)

Map: 100%|██████████| 149/149 [00:02<00:00, 70.34 examples/s]


In [10]:
#num_states_to_test = 2  # Based on seeing 7 states in the data from context
# Read the per-state correctness column once, instead of re-iterating over every
# full row of `gt` (trajectories included) for each state index.
state_flags = list(gt['is_correct_pg_states'])
# An empty dataset reports nothing rather than failing on `gt[0]`.
num_states = len(state_flags[0]) if state_flags else 0
for i in range(num_states):
    correct_at_state = [flags[i] for flags in state_flags]
    print(f"State {i}: {np.mean(correct_at_state):.3f} accuracy ({sum(correct_at_state)} correct)")



State 0: 0.376 accuracy (56 correct)


In [11]:
# Spot-check one mapped example (index 10), including its added prediction fields.
gt[10]

{'problem': 'What is the least positive integer multiple of 30 that can be written with only the digits 0 and 2?',
 'solution': "Let $M$ be the least positive multiple of 30 that can be written with only the digits 0 and 2. First, $M$ is a multiple of 10, so its units digit must be 0. $M$ is also a multiple of 3, which means the sum of its digits must be a multiple of 3. Therefore, we must take at least  three 2's. Since $M$ is minimal, we take exactly three 2's and do not have any additional 0's: $M=\\boxed{2220}$.",
 'answer': '2220',
 'subject': 'Number Theory',
 'level': 3,
 'unique_id': 'test_number_theory_1032.json',
 'is_correct_pg_states': [False],
 'preds': ['2020'],
 'is_correct': False,
 'trajectories': ["## Step 1: Determine the possible digits\nWe need to find the least positive integer multiple of 30 that can be written using only the digits 0 and 2.\n\n## Step 2: Analyze the divisibility rules for 2 and 3\nSince 2 and 30 are both even numbers, the number must be even. To

In [12]:
# Display the dataset summary again to see the columns added by the map step.
gt

Dataset({
    features: ['problem', 'solution', 'answer', 'subject', 'level', 'unique_id', 'is_correct_pg_states', 'preds', 'is_correct', 'trajectories', 'rewards', 'all_rewards'],
    num_rows: 149
})