In [23]:
from glob import glob

from datasets import load_dataset
from sal.utils.math import *
from sal.utils.grader import *

from sal.utils.qwen_math_parser import *
from sal.utils.data import get_dataset, save_dataset
from collections import defaultdict
import json
import numpy as np
import pickle
from sal.config import Config

import os

In [24]:
# Load the ground-truth problems that the particle-filter outputs are graded against.
config = Config()
# The attribute was previously spelled `datset_name`; that spelling presumably
# just attached an unused attribute, leaving whatever default dataset `Config`
# declares in effect.
# NOTE(review): confirm `Config` exposes `dataset_name` and that `get_dataset`
# accepts a local .jsonl path.
config.dataset_name = "/ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/datasets/math500.jsonl"
config.dataset_split = "train"
gt = get_dataset(config)


In [25]:
def is_correct(sample, key):
    r"""Grade the prediction stored under ``pred_<key>`` against ``sample['answer']``.

    Both strings are passed through ``memoized_canonical_form``. If the
    canonical prediction is wrapped as ``\boxed{...}``, the wrapper is removed
    before the two values are compared with ``math_equal``.

    Args:
        sample: Mapping holding ``'answer'`` and ``'pred_<key>'`` strings.
        key: Suffix selecting which prediction column to grade.

    Returns:
        The same ``sample`` object, with ``'is_correct_<key>'`` set to the
        result of ``math_equal``.
    """
    boxed_open = "\\boxed{"
    ans = memoized_canonical_form(sample['answer'])
    pred = memoized_canonical_form(sample[f'pred_{key}'])

    # str.strip(chars) treats `chars` as a *set* of characters, so the previous
    # `.strip("\\boxed{").strip("}")` also removed leading/trailing answer
    # characters such as 'b', 'd', 'e', 'o', 'x', '\\' and every trailing '}'.
    # Remove exactly one wrapper instead.
    if pred.startswith(boxed_open) and pred.endswith("}"):
        pred = pred[len(boxed_open):-1]

    sample['is_correct_' + key] = math_equal(ans, pred)
    return sample


def parse_responses(sample):
    """Attach the cleaned answer of every completion to ``sample``.

    Each entry of ``sample['completions']`` is run through
    ``extract_answer(..., 'math')`` and then ``strip_string``; the resulting
    list is stored under ``'parsed_responses'``.

    Returns:
        The same ``sample`` object, updated in place.
    """
    parsed = []
    for completion in sample['completions']:
        raw_answer = extract_answer(completion, 'math')
        parsed.append(strip_string(raw_answer))
    sample['parsed_responses'] = parsed
    return sample


def extract(string):
    """Return the cleaned answer found in ``string``.

    Calls ``extract_answer`` with the ``'math'`` argument and passes the
    result through ``strip_string``.
    """
    raw_answer = extract_answer(string, 'math')
    return strip_string(raw_answer)

In [26]:
# Sorted so the load order (and printed log) is deterministic across runs.
files = sorted(glob('/ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p2/*.pkl'))

# Maps "<pickle stem>.json" -> unpickled particle-filter output for that problem.
data = {}
for file in files:
    name = os.path.basename(file)
    # Decide whether to skip *before* opening, so batch files are never touched.
    if name.startswith('batch'):
        continue

    print("file: ", file)
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted outputs.
    with open(file, 'rb') as f:
        # The '.json' suffix presumably mirrors the ground-truth `unique_id`
        # values that these keys are matched against later.
        data[os.path.splitext(name)[0] + '.json'] = pickle.load(f)

file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p2/test_intermediate_algebra_1994.pkl
file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p2/test_algebra_1349.pkl
file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p2/test_precalculus_807.pkl
file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p2/test_number_theory_572.pkl
file:  /ssdscratch/byuan48/particle_filtering/probabilistic-inference-scaling/output/p2/test_algebra_2584.pkl


In [27]:
# Keys of the loaded pickles, e.g. "test_algebra_1349.json".
unique_ids = list(data)
# Rewrite the ground-truth ids with '_' in place of '/' so they can be compared
# with the pickle-derived keys, then keep only problems that have an output file.
gt = gt.map(lambda row: {'unique_id': row['unique_id'].replace('/', '_')})
gt = gt.filter(lambda row: row['unique_id'] in unique_ids)

Filter: 100%|██████████| 500/500 [00:00<00:00, 59226.52 examples/s]


In [28]:
# Sanity check: number of distinct ids collected from the loaded pickles.
len(set(unique_ids))

5

In [29]:
# Display the ground-truth dataset (features and row count) for inspection.
gt

Dataset({
    features: ['problem', 'solution', 'answer', 'subject', 'level', 'unique_id'],
    num_rows: 5
})

In [34]:
def get_pg_response(sample, pg_data):
    """Grade the highest-reward particle of each recorded state for ``sample``.

    ``pg_data[sample['unique_id']]`` is expected to hold exactly one state
    (enforced by an assertion). For each state, the particle whose last reward
    is largest is selected, its trajectory steps are joined with blank lines,
    and the answer returned by ``extract`` is compared to ``sample['answer']``
    via ``math_equal`` on the canonical forms.

    Args:
        sample: Mapping with ``'unique_id'`` and ``'answer'``.
        pg_data: Mapping from unique id to a sequence of states, where each
            state is a sequence of particles exposing ``rewards``,
            ``tokens_num`` and ``trajectory``.

    Returns:
        The same ``sample`` object with these keys added:
        ``'is_correct_pg_states'`` (per-state correctness), ``'preds'``
        (per-state extracted answers), ``'tokens_num'`` (total tokens summed
        over all particles of the last state processed) and ``'is_correct'``
        (true if any state was graded correct).
    """
    states = pg_data[sample['unique_id']]
    assert len(states) == 1

    per_state_correct = []
    per_state_answer = []
    for particles in states:
        final_rewards = [particle.rewards[-1] for particle in particles]
        # Total token count across every particle in this state.
        tokens_num = np.array([sum(particle.tokens_num) for particle in particles]).sum()
        winner = particles[np.argmax(final_rewards)]

        predicted = extract("\n\n".join(winner.trajectory))
        matches = math_equal(
            memoized_canonical_form(sample['answer']),
            memoized_canonical_form(predicted),
        )
        per_state_correct.append(matches)
        per_state_answer.append(predicted)

    sample['is_correct_pg_states'] = per_state_correct
    sample['preds'] = per_state_answer
    sample['tokens_num'] = tokens_num
    sample['is_correct'] = any(per_state_correct)

    return sample


In [35]:
# Grade every ground-truth row against its loaded particle-filter output;
# `data` is forwarded to get_pg_response as the `pg_data` keyword argument.
gt = gt.map(get_pg_response, fn_kwargs={'pg_data': data})

Map: 100%|██████████| 5/5 [00:00<00:00, 240.94 examples/s]


In [40]:
# Token totals do not depend on the state index, so collect them once up front.
# This also guarantees `tokens_num` is defined for the final display even when
# there are no states to iterate over.
tokens_num = [row['tokens_num'] for row in gt]

# Every row is expected to carry the same number of states; read it from the
# first row, and fall back to zero when the dataset is empty.
num_states = len(gt[0]['is_correct_pg_states']) if len(gt) else 0
for i in range(num_states):
    correct_at_state = [row['is_correct_pg_states'][i] for row in gt]
    print(f"State {i}: {np.mean(correct_at_state):.3f} accuracy ({sum(correct_at_state)} correct)")
    print(f"State {i}: {np.mean(tokens_num):.3f} tokens") # Represents the number of tokens in tokens_num

tokens_num

State 0: 0.600 accuracy (3 correct)
State 0: 968.400 tokens


[684, 2506, 737, 303, 612]