In [None]:
import pandas as pd

# Parquet files of the dataset, keyed by split name.
splits = {
    'train': 'data/train-00000-of-00001.parquet',
    'test': 'data/test-00000-of-00001.parquet',
}
# Only the 'train' split is read here, straight from its hf:// URL.
df = pd.read_parquet(f"hf://datasets/nlile/hendrycks-MATH-benchmark/{splits['train']}")
df.shape

In [None]:
# Keep only the rows whose 'subject' column is exactly 'Algebra'.
algebra_df = df.loc[df['subject'].eq('Algebra')]
algebra_df.shape

In [3]:
# Cap the working set at the first 1200 Algebra rows (re-running this cell is a no-op).
algebra_df = algebra_df.head(n=1200)

# Model the Dataset

In [None]:
from typing import Optional, List, Tuple, Union, Dict
from pydantic import BaseModel, Field

class Entity(BaseModel):
    """One row of the MATH benchmark frame, validated with pydantic.

    The field names match the keys of the row dictionaries that the notebook
    passes to ``Entity.model_validate``.
    """

    problem: str    # problem statement; later sent to the model as the question
    solution: str   # solution text carried over from the dataset row
    answer: str     # reference answer; later stored as the ground truth
    subject: str    # subject label (the notebook filters on 'Algebra')
    level: int      # level value carried over from the dataset row
    unique_id: str  # identifier carried over from the dataset row

class ReasoningPath(BaseModel):
    """One generated reasoning attempt for one problem at one temperature."""

    question: str                 # problem text that was asked
    reason: str                   # model output, stored as a string
    answer: Optional[str] = None  # answer extracted from the output; None when absent
    ground_truth: str             # reference answer taken from the dataset entity
    temperature: float            # temperature passed to the inference call
    tokens: int                   # second value returned by the inference call, as an int
    score: float                  # score of this path; the notebook initialises it to 0.0

# Validate every Algebra row against the Entity schema, preserving row order.
entities = list(map(Entity.model_validate, algebra_df.to_dict(orient='records')))

# Quick look: how many entities were built, and what the first one contains.
print(len(entities))
print(entities[0])

# Generate Sketch-of-Thought (SoT) reasoning paths

In [None]:
import re

from sketch_of_thought import SoT
from doraemon import Doraemon
from relaxed_fda import RelaxedFDA

# Logger obtained from Doraemon, configured with a log file name.
logger = Doraemon.get_logger(
    name=__name__,
    logfile="math_dataset_builder.log",
)

# Only the first problem is classified; the resulting paradigm is the one
# attached to every task built later in the notebook.
paradigm = SoT.classify_question(entities[0].problem)
logger.info(paradigm)

# Sanity check on the classifier output for the first problem.
assert str(paradigm) == "chunked_symbolism"

In [6]:
# Nine temperatures in steps of 0.25: [0.0, 0.25, 0.5, ..., 2.0]
temperatures = [0.25 * step for step in range(9)]

# One (entity, paradigm, temperature) task per combination, grouped by entity:
# all temperatures for the first entity, then all for the second, and so on.
tasks = [(et, paradigm, tp) for et in entities for tp in temperatures]

In [None]:
from tqdm import tqdm
import concurrent.futures

def process_entity(args) -> Optional[Dict[str, Union[str, float, int, None]]]:
    """Generate one reasoning path for a single task.

    Args:
        args: A 3-tuple ``(et, paradigm, temperature)``. The values are packed
            into one argument so the function can be passed directly to
            ``executor.map``.

    Returns:
        The ``ReasoningPath.model_dump()`` dictionary for the task (string,
        float, int and possibly ``None`` values), or ``None`` when any step
        inside the ``try`` block raised. Failures are logged instead of
        propagated, so one failing task does not abort the whole batch.
    """
    et, paradigm, temperature = args
    try:
        prompt = SoT.get_initialized_prompt(paradigm=paradigm, question=f"Question:{et.problem}")
        r_s, tokens = Doraemon.inference(logger=logger, messages=prompt, temperature=temperature)
        result = ReasoningPath(
            question=et.problem,
            reason=str(r_s),
            answer=RelaxedFDA.get_answer(r_s),
            ground_truth=str(et.answer),
            temperature=float(temperature),
            tokens=int(tokens),
            # Every path starts with a score of 0.0; nothing in this function computes one.
            score=0.0,
        )
        return result.model_dump()
    except Exception as e:
        logger.error(f"Error processing question {et.problem} at temperature {temperature} with exception {e}")
        return None

# Run every task through process_entity, keeping results in task order.
# A thread pool is used rather than a process pool: ProcessPoolExecutor needs
# the worker processes to import __main__, which does not work for functions
# defined in a notebook cell when the start method is spawn or forkserver.
# Threads also avoid pickling the tasks and results.
# max_workers=1 keeps the calls strictly sequential, as before.
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    results = list(tqdm(executor.map(process_entity, tasks), total=len(tasks)))

In [8]:
# Discard failed tasks: process_entity returns None for those.
results_filtered: List[Dict[str, str]] = [
    path for path in results if path is not None
]

# One row per successful reasoning path.
final_df = pd.DataFrame(results_filtered)

# Persist the frame to disk.
final_df.to_pickle('sots_df.pkl')

In [None]:
import matplotlib.pyplot as plt

# Tokens distribution: histogram of the 'tokens' column on a fresh figure.
tokens_ax = plt.figure().gca()
final_df['tokens'].hist(ax=tokens_ax)
tokens_ax.set_title("Tokens per Example")
tokens_ax.set_xlabel("tokens")
tokens_ax.set_ylabel("Number of SoTs")
plt.show()