In [1]:
from glob import glob

from datasets import load_dataset
from sal.utils.math import *
from sal.utils.grader import *

from sal.utils.qwen_math_parser import *
from sal.utils.data import get_dataset, save_dataset
from collections import defaultdict
import json
import numpy as np
import pickle
from sal.config import Config

import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = Config()
config.datset_name = "/ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/datasets/math500.jsonl"
config.dataset_split = "train"
gt = get_dataset(config)


In [3]:
def is_correct(sample, key):
    ans = memoized_canonical_form(sample['answer'])
    pred = memoized_canonical_form(sample[f'pred_{key}']).strip("\\boxed{").strip("}")

    sample['is_correct_' + key] = math_equal(ans, pred)
    return sample


def parse_responses(sample):
    sample['parsed_responses'] = [strip_string(extract_answer(x, 'math')) for x in sample['completions']]
    return sample


def extract(string):
    return strip_string(extract_answer(string, 'math'))

In [4]:
files = glob('/ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p4/*.pkl')

data = {}
for file in files:
    with open(file, 'rb') as f:
        if file.split('/')[-1].startswith('batch'):
                continue
                
        # if not os.path.exists('/shiv/search-and-learn/llama1b_jan11_mathShephard_vs_deepSeek_PRM_comparison/seed96/temp_1/p128/' + file.split('/')[-1]):
        #         continue
        print("file: ", file)
        data[f"{file.split('/')[-1].replace('.pkl', '.json')}"] = pickle.load(f)

file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p4/test_prealgebra_1622.pkl


2025-04-21 14:09:00,690	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p4/test_intermediate_algebra_1410.pkl
file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p4/test_algebra_841.pkl
file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p4/test_precalculus_477.pkl
file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p4/test_number_theory_1055.pkl
file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p4/test_prealgebra_572.pkl
file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p4/test_prealgebra_1686.pkl
file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p4/test_geometry_283.pkl
file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p4/test_intermediate_algebra_964.pkl
file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/outp

In [5]:
unique_ids = list(data.keys())
gt = gt.map(lambda x: {'unique_id': x['unique_id'].replace('/', '_')})
gt = gt.filter(lambda x: x['unique_id'] in unique_ids)

Filter: 100%|██████████| 500/500 [00:00<00:00, 56669.06 examples/s]


In [6]:
len(set(unique_ids))

200

In [7]:
gt

Dataset({
    features: ['problem', 'solution', 'answer', 'subject', 'level', 'unique_id'],
    num_rows: 200
})

In [26]:
def get_pg_response(sample, pg_data):
    pg_sample = pg_data[sample['unique_id']]
    # Track accuracy at each state
    state_accuracies = []
    state_answers = []
    assert len(pg_sample) == 1 
    for state in pg_sample:
        rewards = [x.rewards[-1] for x in state]
        all_rewards = [x.rewards for x in state]
        tokens_num = np.array([sum(x.tokens_num) for x in state]).sum()
        best_particle = state[np.argmax(rewards)]


        # logits = [inverse_sigmoid(r) for r in rewards]
        # logits = np.array(logits)
        # weights = softmax(logits/.5)
        # best_particle = np.random.choice(state, p=weights)
        trajectory = "\n\n".join(best_particle.trajectory)
        answer = extract(trajectory)
        is_correct = math_equal(memoized_canonical_form(sample['answer']), memoized_canonical_form(answer))
        state_accuracies.append(is_correct)
        state_answers.append(answer)
    
    trajectories = []
    for state in pg_sample:
        for x in state:
            trajectories.append("\n\n".join(x.trajectory))
    
    # Still store the final prediction and accuracy
    # last_state = pg_sample[-1]
    # rewards = [x.rewards[-1] for x in last_state]
    # best_particle = last_state[np.argmax(rewards)]
    # trajectory = "\n\n".join(best_particle.trajectory)
    # answer = extract(trajectory)
    # sample['pred_pg'] = answer
    sample['is_correct_pg_states'] = state_accuracies
    sample['preds'] = state_answers
    sample['tokens_num'] = tokens_num
    sample['is_correct'] = any(state_accuracies)
    sample['trajectories'] = trajectories
    sample['rewards'] = rewards
    sample['all_rewards'] = all_rewards
    

    return sample


In [27]:
gt = gt.map(get_pg_response, fn_kwargs={'pg_data': data})

Map: 100%|██████████| 200/200 [00:00<00:00, 480.56 examples/s]


In [10]:
#num_states_to_test = 2  # Based on seeing 7 states in the data from context
for i in range(len(gt[0]['is_correct_pg_states'])):
    correct_at_state = [d['is_correct_pg_states'][i] for d in gt]
    tokens_num = [d['tokens_num'] for d in gt]
    print(f"State {i}: {np.mean(correct_at_state):.3f} accuracy ({sum(correct_at_state)} correct)")
    print(f"State {i}: {np.mean(tokens_num):.3f} tokens") # Represents the number of tokens in tokens_num

tokens_num

State 0: 0.450 accuracy (90 correct)
State 0: 2484.395 tokens


[1618,
 7016,
 1999,
 703,
 1644,
 675,
 1552,
 2192,
 944,
 3108,
 1164,
 670,
 1233,
 1010,
 2152,
 3350,
 561,
 6145,
 2101,
 2989,
 484,
 5243,
 3518,
 2580,
 1984,
 7348,
 1753,
 1092,
 1337,
 1040,
 1000,
 569,
 1136,
 1890,
 2489,
 1924,
 2999,
 640,
 801,
 1104,
 1142,
 1664,
 842,
 5262,
 2116,
 768,
 2629,
 1280,
 918,
 960,
 1874,
 3223,
 1089,
 1070,
 323,
 1372,
 556,
 708,
 1144,
 1168,
 3827,
 2652,
 3212,
 2131,
 7520,
 292,
 1100,
 1069,
 7796,
 21593,
 1244,
 2360,
 920,
 1251,
 1784,
 1158,
 1344,
 776,
 9057,
 772,
 8664,
 1672,
 1924,
 1700,
 2690,
 1474,
 493,
 1226,
 2965,
 2037,
 448,
 1304,
 2503,
 1377,
 2547,
 3272,
 8588,
 731,
 680,
 3909,
 7768,
 9750,
 442,
 11870,
 4241,
 1062,
 837,
 1035,
 11571,
 2524,
 4192,
 1302,
 2030,
 646,
 2572,
 1124,
 804,
 753,
 1782,
 1679,
 3628,
 768,
 1168,
 5508,
 2704,
 921,
 6448,
 744,
 752,
 3520,
 2067,
 1652,
 934,
 953,
 2740,
 1336,
 3532,
 1512,
 1447,
 3820,
 6254,
 1719,
 1098,
 2034,
 1011,
 8036,
 840,
 262

In [30]:
(gt[10]['trajectories'])


['## Step 1: Understand the requirements\nWe need to find the smallest positive integer that is a multiple of 30 and can be formed using only the digits 0 and 2.\n\n## Step 2: Determine the factors of 30\nThe factors of 30 are 1, 2, 3, 5, 6, 10, 15, and 30.\n\n## Step 3: Combine the digits 0 and 2 to form multiples of 30\nStarting with the smallest possible number using the digits 0 and 2, we can try different combinations that also result in a multiple of 30.\n\n## Step 4:  Check the multiples of 30 using the digits 0 and 2\nThe multiples of 30 are 30, 60, 90, 120, 150, 180, 210, 240, 270, 300, and 330. Now we need to check each one.\n\n## Step 5: Test the multiples of 30\nAmong these multiples, 60 is the first number that can be formed using the digits 0 and 2, and it is indeed a multiple of 30.\n\nThe final answer is: $\\boxed{60}$',
 '## Step 1: Understand the requirements\nWe need to find the smallest positive integer that is a multiple of 30 and can be formed using only the digit