In [None]:
from datasets import load_dataset

# Load the GSM8K "main" test split and keep its first 1250 rows as a DataFrame.
df = (
    load_dataset("openai/gsm8k", "main", split="test")
    .to_pandas()
    .head(1250)
)
df.shape

In [2]:
# Keep only the text that follows the "####" marker in each reference answer.
df['ground_truth'] = df['answer'].str.extract(
    r'####\s*(.*)',
    expand=False,
)

In [None]:
import re
from typing import Tuple, Union, Optional

from pydantic import BaseModel

from sketch_of_thought import SoT
from doraemon import Doraemon

class Entity(BaseModel):
    """One dataset row, validated from a record of the loaded DataFrame."""
    # Problem statement passed to the prompt builder.
    question: str
    # Reference answer text from the dataset; the part after "####" is what
    # the notebook stores separately as ground_truth.
    answer: str
    # Text extracted from `answer` after the "####" marker.
    ground_truth: str

class ReasoningPath(BaseModel):
    """Outcome of one inference call for one question at one temperature."""
    # Question text copied from the source Entity.
    question: str
    # Raw model response, stored as a string.
    reason: str
    # Content of the \boxed{...} expression found in the response, or None
    # when the response contains no such expression.
    answer: Optional[str]=None
    # Expected answer copied from the source Entity.
    ground_truth: str
    # Sampling temperature used for the inference call.
    temperature: float
    # Second value returned by the inference call -- presumably a token
    # count; TODO confirm against Doraemon.inference.
    tokens: int
    # Initialised to 0.0 when the path is created in this notebook.
    score: float


# Logger for this notebook; messages go to the named log file.
logger = Doraemon.get_logger(
    name=__name__,
    logfile="gsm8k_sot_dataset_builder.log",
)

# Validate every DataFrame row into an Entity.
entities = [
    Entity.model_validate(row)
    for row in df.to_dict(orient='records')
]


def get_answer(raw_answer: str) -> Optional[str]:
    """Extract the content of the first ``\\boxed{...}`` expression.

    Braces are matched by depth, so nested groups such as
    ``\\boxed{\\frac{1}{2}}`` are returned whole instead of being cut at the
    first closing brace. The content may span several lines.

    Args:
        raw_answer: Model response text to search. Must be a string; any
            other type raises ``TypeError`` from the ``re`` module.

    Returns:
        The text between the braces of the first ``\\boxed{`` whose braces
        balance. If none balances, the text up to the first ``}`` after a
        ``\\boxed{`` (single-line match), or ``None`` when there is no boxed
        expression at all.
    """
    for opener in re.finditer(r"\\boxed\{", raw_answer):
        start = opener.end()
        depth = 1
        for pos in range(start, len(raw_answer)):
            char = raw_answer[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return raw_answer[start:pos]

    # No balanced expression: fall back to the shortest single-line match so
    # malformed output such as "\boxed{a{b}" still yields a value.
    fallback = re.search(r"\\boxed\{(.*?)\}", raw_answer)
    if fallback:
        return fallback.group(1)
    return None


def process_entity(args)-> Optional[ReasoningPath]:
    """Run one inference call for a question and package the outcome.

    Args:
        args: A 3-tuple ``(et, paradigm, temperature)`` packed into a single
            argument so the function can be passed to ``executor.map``.
            ``et`` supplies ``question`` and ``ground_truth``; ``paradigm``
            selects the prompt; ``temperature`` is forwarded to inference.

    Returns:
        A ``ReasoningPath`` with ``score`` set to 0.0, or ``None`` when
        building the prompt, running inference, or constructing the result
        raised an exception (the error is logged).
    """
    et, paradigm, temperature = args

    try:
        prompt = SoT.get_initialized_prompt(
            paradigm=paradigm,
            question=f"Question:{et.question}\n"
        )

        r_s, tokens = Doraemon.inference(
            logger=logger,
            messages=prompt,
            temperature=temperature
        )

        return ReasoningPath(
            question=str(et.question),
            reason=str(r_s),
            answer=get_answer(r_s),
            ground_truth=str(et.ground_truth),
            temperature=float(temperature),
            tokens=int(tokens),
            score=0.0
        )
    except Exception as e:
        # Best-effort: a single failed call must not abort the whole run.
        # repr() keeps the exception type visible even when it has no message.
        logger.error(f"Error processing question {et.question} at temperature {temperature} with exception {e!r}")
        return None

# Only the first question is classified; the resulting paradigm is reused for
# every task built from `entities`.
# NOTE(review): presumably all questions share one paradigm -- confirm.
paradigm = SoT.classify_question(entities[0].question)
logger.info(paradigm)

In [None]:
# Nine sampling temperatures in steps of 0.25: [0.0, 0.25, 0.5, ... ,2.0]
temperatures = [step * 0.25 for step in range(9)]

# One (entity, paradigm, temperature) task per entity/temperature pair,
# ordered entity-major.
tasks = [
    (entity, paradigm, temperature)
    for entity in entities
    for temperature in temperatures
]
tasks[0]

In [None]:
from tqdm import tqdm
import concurrent.futures

# A thread pool is used rather than a process pool: worker processes must be
# able to import the target function from __main__, which fails for a function
# defined in a notebook cell when processes are spawned rather than forked.
# With a single worker, the tasks run one at a time either way, and map()
# yields results in task order.
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    results = list(tqdm(executor.map(process_entity, tasks), total=len(tasks)))

In [6]:
import pandas as pd
import pickle


# Skip failed runs (None entries) and persist the remaining paths as a pickled DataFrame.
pd.DataFrame(
    [path.model_dump() for path in results if path is not None]
).to_pickle('sots.pkl')