In [1]:
import json
import pandas as pd    

dreamprm_meta = "DreamPRM/data/meta.json"
dreamprm_train = "DreamPRM/data/train.json"
MATH_path = "MATH"

import os
MATH = {'train':{}, 'test':{}}
for i in MATH:
    path = os.path.join(MATH_path, i)
    for j in os.listdir(path):
        MATH[i][j] = [os.path.join(path, j, k) for k in os.listdir(os.path.join(path, j)) if k.endswith('.json')]
        

prm800k_path = "prm800k/prm800k/data"
prm_split_path = {"test": "prm800k/prm800k/math_splits/test.jsonl", 
                 "train": "prm800k/prm800k/math_splits/train.jsonl"}
prm800k  = {}

for phase in range(1,3):
    prm800k[f"phase{phase}"] = {}
    for type in ['train', 'test']:
        prm800k[f"phase{phase}"][type] = f"{prm800k_path}/phase{phase}_{type}.jsonl"
        


In [2]:
from datasets import  Dataset, concatenate_datasets, DatasetDict
import pandas as pd
from tqdm import tqdm


splits = ["train", "test"]


def augment(row):
    new_completions = []
    new_labels = []
    for t, l in zip(row["trajectories"], row["trajectories_labels"]):
        new_completions.append(row["completions"] + [t])
        new_labels.append(row["labels"] + [l])
    row["new_completions"] = new_completions
    row["new_labels"] = new_labels
    return row



def unnest_dataset_lists(
    dataset: Dataset,
    list_columns: str | list[str],
) -> Dataset:
    """
    Unnest list columns in a datasets.Dataset where each row contains a list of lists.
    Each inner list becomes a new row, maintaining alignment across specified columns.
    
    Args:
        dataset: The input datasets.Dataset
        list_columns: Column name(s) containing lists to unnest
        
    Returns:
        A new datasets.Dataset with unnested rows
    """
    df = dataset.to_pandas()
    
    if isinstance(list_columns, str):
        list_columns = [list_columns]
    
    # Create rows for each completion/label pair
    rows = []
    
    for idx, row in df.iterrows():
        # Get number of completions for this row
        num_completions = len(row[list_columns[0]])
        
        # For each completion
        for completion_idx in range(num_completions):
            new_row = {}
            
            # Copy non-list columns
            for col in df.columns:
                if col not in list_columns:
                    new_row[col] = row[col]
            
            # Add index to keep the reference to the prompt
            new_row['index'] = idx
            
            # Add list columns
            for col in list_columns:
                new_row[col] = row[col][completion_idx]
            
            rows.append(new_row)
    
    # Create new dataframe
    result_df = pd.DataFrame(rows)
    
    # Convert back to datasets.Dataset
    return Dataset.from_pandas(result_df)

full_ds = DatasetDict()

for split in splits:
    print("Split:", split)
    datasets = []
    for i, ds in enumerate(prm800k):
        ds = prm800k[ds]
        ds = pd.read_json(path_or_buf=ds[split], lines=True)
        print("phase:", i + 1)
        new_ds = []
        for row in tqdm(ds.iterrows(), total=len(ds)):
            row = row[1]  # Get the row data
            solution = []
            trajectories = []
            trajectories_labels = []
            label = []
            new_row = {}
            new_row["prompt"] = row["question"]["problem"]

            for step in row["label"]["steps"]:
                completions = step["completions"]
                if not completions:
                    continue
                if step["chosen_completion"] is None:
                    # Haven't found any relevant point here, discard them for the moment
                    #continue
                    if len(completions) > 1:
                        trajectories = []
                        trajectories_labels = []
                        for c in completions:
                            trajectories.append(c["text"])
                            trajectories_labels.append(not (c["rating"] == -1))
                        # We can continue here assuming there shouldn't be any more trajectories after this point
                        continue

                else:
                    completion = completions[step["chosen_completion"]]
                    solution.append(completion["text"])
                    label.append(not (completion["rating"] == -1))
                    continue

            new_row["trajectories"] = trajectories
            new_row["trajectories_labels"] = trajectories_labels
            new_row["completions"] = solution
            new_row["labels"] = label
            new_ds.append(new_row)

        new_ds = Dataset.from_list(new_ds)

        # Augment the dataset by adding the trajectories as completions
        ds_updated = new_ds.map(augment, num_proc=8)

        # Similar to a pd.DataFrame.explode, get a row per each nested list in the dataset,
        # so that we have all the possible alternative completions in separate rows
        print("Start unnesting...")
        updated = unnest_dataset_lists(
            ds_updated.select_columns(
                ["prompt", "new_completions", "new_labels"]
            ),
            ["new_completions", "new_labels"]
        ).rename_columns(
            {"new_completions": "completions", "new_labels": "labels"}
        ).select_columns(["prompt", "completions", "labels", "index"])  # Sort the column names

        datasets.append(updated)
    
    full_ds[split] = concatenate_datasets(datasets)

Split: train
phase: 1


100%|██████████| 949/949 [00:00<00:00, 17068.88it/s]


Map (num_proc=8):   0%|          | 0/949 [00:00<?, ? examples/s]

Start unnesting...
phase: 2


100%|██████████| 97782/97782 [00:05<00:00, 18664.27it/s]


Map (num_proc=8):   0%|          | 0/97782 [00:00<?, ? examples/s]

Start unnesting...
Split: test
phase: 1


100%|██████████| 106/106 [00:01<00:00, 78.54it/s]


Map (num_proc=8):   0%|          | 0/106 [00:00<?, ? examples/s]

Start unnesting...
phase: 2


100%|██████████| 2762/2762 [00:00<00:00, 22473.77it/s]


Map (num_proc=8):   0%|          | 0/2762 [00:00<?, ? examples/s]

Start unnesting...


In [70]:
full_ds['test'].to_pandas()

Unnamed: 0,prompt,completions,labels,index
0,Three pencils and a jumbo eraser cost $\$1.24$...,[Let's call the price of a pencil p and the pr...,"[True, True, True, True, True, False]",0
1,Three pencils and a jumbo eraser cost $\$1.24$...,[Let's call the price of a pencil p and the pr...,"[True, True, True, True, True, False]",0
2,Three pencils and a jumbo eraser cost $\$1.24$...,[Let's call the price of a pencil p and the pr...,"[True, True, True, True, True, False]",0
3,Three pencils and a jumbo eraser cost $\$1.24$...,[Let's call the price of a pencil p and the pr...,"[True, True, True, True, True, False]",0
4,Three pencils and a jumbo eraser cost $\$1.24$...,[Let's call the price of a pencil p and the pr...,"[True, True, True, True, True, False]",0
...,...,...,...,...
10241,Simplify $\frac{(10r^3)(4r^6)}{8r^4}$.,[To simplify a fraction involving powers of th...,"[True, True, True, False]",2760
10242,Simplify $\frac{(10r^3)(4r^6)}{8r^4}$.,[To simplify a fraction involving powers of th...,"[True, True, True, True]",2760
10243,Simplify $\frac{(10r^3)(4r^6)}{8r^4}$.,[To simplify a fraction involving powers of th...,"[True, True, True, False]",2760
10244,Simplify $\frac{(10r^3)(4r^6)}{8r^4}$.,[To simplify a fraction involving powers of th...,"[True, True, True, True]",2760


In [72]:
exhaustive=True

from datasets import Dataset
dataset = DatasetDict()
for split_type in ['test', 'train']:
    df = full_ds[split_type].to_pandas()
    train_split = pd.read_json(path_or_buf=prm_split_path[split_type], lines=True)
    df = df.merge(train_split, left_on='prompt', right_on='problem', how='inner')
    df = df.drop(columns=['solution', 'problem'])
    
    if exhaustive:
        rows = []
        for idx, row in tqdm(df.iterrows(), total=len(df)):
            for i in range(len(row["completions"])):
                new_row = {}
                new_row["prompt"] = row["prompt"]
                new_row["answer"] = row["answer"]
                new_row["subject"] = row["subject"]
                new_row["level"] = row["level"]
                new_row["completions"] = row["completions"][:i+1]
                new_row["labels"] = row["labels"][:i+1]
                new_row["index"] = row["index"]
                rows.append(new_row)
    
        df = pd.DataFrame(rows)
    
    
    
    df['completions_str'] = df['completions'].apply(lambda x: str(x))
    df['labels_str'] = df['labels'].apply(lambda x: str(x))
    df['prompt_str'] = df['prompt'].astype(str)
    df['unique_id'] = df['completions_str'] + df['prompt_str'] + df['labels_str']
    print(f"Total unique completions in train set: {len(df['unique_id'].unique())} out of {len(df)} total completions.")

    ## only keep unique rows
    df = df.drop_duplicates(subset=['unique_id'])
    df = df.drop(columns=['completions_str', 'labels_str', 'prompt_str', 'unique_id'])

    print("Length of dataset after removing duplicates:", len(df))
    dataset[split_type] = Dataset.from_pandas(df)

100%|██████████| 10246/10246 [00:01<00:00, 9239.76it/s] 


Total unique completions in train set: 20849 out of 64171 total completions.
Length of dataset after removing duplicates: 20849


100%|██████████| 377938/377938 [00:41<00:00, 9031.40it/s] 


Total unique completions in train set: 651944 out of 2378449 total completions.
Length of dataset after removing duplicates: 651944


In [73]:
dataset['train'], dataset['test']

(Dataset({
     features: ['prompt', 'answer', 'subject', 'level', 'completions', 'labels', 'index', '__index_level_0__'],
     num_rows: 651944
 }),
 Dataset({
     features: ['prompt', 'answer', 'subject', 'level', 'completions', 'labels', 'index', '__index_level_0__'],
     num_rows: 20849
 }))

In [None]:
# dataset.save_to_disk('PRM800k_cleaned')
# dataset.save_to_disk('PRM800k_cleaned_exhaustive')

Saving the dataset (0/1 shards):   0%|          | 0/20849 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/651944 [00:00<?, ? examples/s]

In [None]:
### set API key, through environment variable or directly
# dataset.push_to_hub("FrozenWolf/prm800k-exhaustive", token="")

# dataset.push_to_hub("FrozenWolf/prm800k", token="")

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/21 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/2 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/326 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/326 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/FrozenWolf/prm800k-exhaustive/commit/9dcac23fd0f5624d2838dd9d950a37ec78c54733', commit_message='Upload dataset', commit_description='', oid='9dcac23fd0f5624d2838dd9d950a37ec78c54733', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/FrozenWolf/prm800k-exhaustive', endpoint='https://huggingface.co', repo_type='dataset', repo_id='FrozenWolf/prm800k-exhaustive'), pr_revision=None, pr_num=None)

In [248]:
dataset

DatasetDict({
    test: Dataset({
        features: ['prompt', 'completions', 'labels', 'index', 'answer', 'subject', 'level'],
        num_rows: 9956
    })
    train: Dataset({
        features: ['prompt', 'completions', 'labels', 'index', 'answer', 'subject', 'level'],
        num_rows: 327797
    })
})

In [57]:
sample = dataset['test'][0]

In [59]:
sample['completions']

["Let's call the price of a pencil p and the price of a jumbo eraser e. Then we can write two equations.",
 "To solve this system, let's subtract the first equation from the second equation. This will eliminate e.",
 '$5p+e-3p-e=1.82-1.24$.',
 'This simplifies to $2p=0.58$. So $p=0.29$.',
 'That means a pencil costs 29 cents.\n\n# Answer\n\n29',
 'The first equation is 3p + e = 124, and the second equation is 5p + e = 182.']

In [60]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForTokenClassification

model_name = "Qwen/Qwen2.5-Math-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [235]:
import DreamPRM.prm_data
from DreamPRM.prm_data import QwenMathMetaDataset, QwenMathDataset
## reload this dataset class
import importlib
importlib.reload(DreamPRM.prm_data)
meta_ds = QwenMathMetaDataset(dataset['test'], tokenizer)
math_ds = QwenMathDataset(dataset['test'], tokenizer, meta_ds)

In [236]:
batch = next(iter(meta_ds))

In [246]:
from transformers import DataCollatorForTokenClassification, DataCollatorWithPadding
from torch.utils.data import DataLoader
import torch
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True, return_tensors="pt")
pad_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")




def collate_merge_minibatch(batch):
    '''batch is a list of dictionaries'''

  
    names = ["input_ids", "attention_mask", "labels", "index"]
    merged = {}

    for name_idx, item in enumerate(zip(*batch)):
        j = []
        for i in item:
            j+=i
        
        if isinstance(j[0], torch.Tensor):
            max_len = max([i.shape[-1] for i in j ])
        for i in range(len(j)):    
            if isinstance(j[i], torch.Tensor):
                if names[name_idx] == "input_ids":
                    j[i] = torch.cat((j[i], torch.ones(1, max_len - j[i].shape[-1])*tokenizer.pad_token_id), dim=-1)
                elif names[name_idx] == "attention_mask":
                    j[i] = torch.cat((j[i], torch.zeros(1, max_len - j[i].shape[-1])), dim=-1)
                elif names[name_idx] == "labels":
                    j[i] = torch.cat((j[i], torch.ones(1, max_len - j[i].shape[-1])*-100), dim=-1)

        if isinstance(j[0], torch.Tensor):
            j = torch.cat(j, dim=0)
        
        merged[names[name_idx]] = j
        
    return merged
        
        

       
                
def collate_fn(batch):
    # batch = collate_merge_minibatch(batch)
    print([i['input_ids'].shape for i in batch])
    batch = data_collator(batch)
    return batch
    
new_tokens = ['<PRM_STEP_SCORE>']
num_added_tokens = tokenizer.add_tokens(new_tokens)
        
dataloader = DataLoader(
    math_ds,
    batch_size=8,
    collate_fn=collate_merge_minibatch,
    shuffle=False
)

sample = next(iter(dataloader))

In [247]:
sample['input_ids'].shape, sample['attention_mask'].shape, sample['labels'].shape, sample['index']

(torch.Size([8, 219]),
 torch.Size([8, 219]),
 torch.Size([8, 219]),
 [0, 0, 0, 0, 0, 0, 0, 0])

In [5]:
from datasets import load_from_disk
dataset = load_from_disk("PRM800k_cleaned")
dataset

DatasetDict({
    test: Dataset({
        features: ['prompt', 'completions', 'labels', 'index', 'answer', 'subject', 'level'],
        num_rows: 9956
    })
    train: Dataset({
        features: ['prompt', 'completions', 'labels', 'index', 'answer', 'subject', 'level'],
        num_rows: 327797
    })
})

In [12]:
sample = dataset['train'][100]
for completition, label in zip(sample['completions'], sample['labels']):
    print(f"Completion: {completition}, Label: {label}")

Completion: I think we should start by letting the number of people taking French be a variable, say $f$., Label: True
Completion: And then the number of people taking English is $2f$ because there are twice as many people in the English class as there are in the French class., Label: True
Completion: That's right. Now we can use the fact that there are 25 people taking either English or French., Label: True
Completion: So we need to subtract 2 from the sum. That gives $2f + f - 2 = 25$. Then we can combine like terms to get $3f - 2 = 25$., Label: True
Completion: So $3f = 27$. Then $f = 9$., Label: True
Completion: And that means $2f = 18$., Label: True
Completion: So the number of people taking English but not French is $18 - 2 = 16$., Label: True
Completion: Right. That's the final answer.

# Answer

16, Label: True
Completion: So $f + 2f = 25$ people taking either English or French., Label: False


In [13]:
dataset['train']

Dataset({
    features: ['prompt', 'completions', 'labels', 'index', 'answer', 'subject', 'level'],
    num_rows: 327797
})

In [26]:
df = dataset['train'].to_pandas()
df

Unnamed: 0,prompt,completions,labels,index,answer,subject,level
0,How many positive two-digit integers leave a r...,[So if a number leaves a remainder of 2 when d...,"[True, True, True, True, True, True, True, Tru...",1,12,Number Theory,2
1,How many positive two-digit integers leave a r...,[So if a number leaves a remainder of 2 when d...,"[True, True, True, True, True, True, True, Tru...",1,12,Number Theory,2
2,How many positive two-digit integers leave a r...,[So if a number leaves a remainder of 2 when d...,"[True, True, True, True, True, True, True, Tru...",1,12,Number Theory,2
3,How many positive two-digit integers leave a r...,[So if a number leaves a remainder of 2 when d...,"[True, True, True, True, True, True, True, Tru...",1,12,Number Theory,2
4,How many positive two-digit integers leave a r...,[So if a number leaves a remainder of 2 when d...,"[True, True, True, True, True, True, True, Tru...",1,12,Number Theory,2
...,...,...,...,...,...,...,...
327792,Find the greatest common divisor of 12 and 20.,"[To find the greatest common divisor, we need ...",[False],97732,4,Prealgebra,2
327793,Find the greatest common divisor of 12 and 20.,[Let's write out the prime factorization of 12.],[True],97732,4,Prealgebra,2
327794,Find the greatest common divisor of 12 and 20.,[Let's write the prime factorization for each ...,[True],97732,4,Prealgebra,2
327795,Find the greatest common divisor of 12 and 20.,[Let's prime factorize 12. 12 is divisible by ...,[True],97732,4,Prealgebra,2


In [11]:
neg = 0
pos = 0
neg_per_sent = 0
for label in dataset['train']['labels']:
    c = False
    for l in label:
        if l:
            pos += 1
        else:
            neg += 1
            if not c:
                neg_per_sent += 1
                c = True
            
print(f"Total positive labels in train set: {pos*100/(pos+neg)}, negative labels: {neg*100/(pos+neg)}")
print(f"Atleast 1 Neg labels per sentence: {neg_per_sent*100 / len(dataset['train'])}")

Total positive labels in train set: 89.96407002368821, negative labels: 10.035929976311786
Atleast 1 Neg labels per sentence: 60.82239922879099
