In [1]:
import os
import re

from dotenv import load_dotenv, find_dotenv
from langchain_openai import ChatOpenAI

from agent.utils.loader import load_processed_data
from agent.utils.tools import python_interpreter

# Load environment variables (presumably the API key used by ChatOpenAI) from the nearest .env file.
# Binding the bool return value to `_` keeps the cell from echoing `True`.
_ = load_dotenv(find_dotenv())

In [2]:
# --- Run configuration ---------------------------------------------------------
dataset_name = 'tabmwp'
mode = "self-improve"
base_mode = "pot"  # "pot": start from a previous PoT inference run; otherwise from processed data
model = "gpt-4o-mini-2024-07-18"
num_samples = -1  # values <= 0 keep the whole dataset
top_p = 0.95
temperature = 0
seed = 42
batch_size = 100

# Results file resumed/rewritten by self_improve_inference and by the final save cell.
# NOTE(review): this name was used later in the notebook but never defined (it presumably lived in
# a since-deleted cell). The pattern mirrors the PoT input path below -- confirm it matches any
# existing results file you want to resume from.
save_results_path = (
    f"../../output/inference/gpt-4o-mini/{dataset_name}/{mode}/"
    f"num_samples_{num_samples}_top_p_{top_p}_temperature_{temperature}_seed_{seed}.jsonl"
)

# --- Load the initial answers that the pipeline will try to improve -------------
if base_mode == "pot":
    processed_data_path = f"../../output/inference/gpt-4o-mini/{dataset_name}/pot/num_samples_-1_top_p_0.95_temperature_0_seed_42.jsonl"
else:
    processed_data_path = f"../../data/processed_data/{dataset_name}.jsonl"
dataset = load_processed_data(dataset_name=dataset_name, file_path=processed_data_path)

# --- Adapt each dataset to the graph inputs ("question", "initial_answer") -------
# Single quotes inside the f-strings keep this cell parseable on Python < 3.12 (PEP 701).
if dataset_name == "tabmwp":
    dataset = dataset.map(lambda x: {"question": f"{x['context']}\n{x['question']}", "initial_answer": f"```python\n{x['code']}```"})
    dataset = dataset.remove_columns(["prediction", "code"])
elif dataset_name in ["gsmhard", "gsm8k"]:
    dataset = dataset.rename_column("code", "initial_answer")
    dataset = dataset.remove_columns("prediction")
    dataset = dataset.map(lambda x: {"question": f"For the following math question, just focus on the calculation process without considering any realistic factors.\n{x['question']}", "initial_answer": f"```python\n{x['initial_answer']}```"})
elif dataset_name == "math":
    from agent.utils.math_util import last_boxed_only_string

    def remove_boxed(s):
        """Return the content of a string of the exact form ``\\boxed{...}``, else None.

        Explicit checks instead of ``assert`` + bare ``except``: asserts vanish under
        ``python -O`` and a bare except would also swallow KeyboardInterrupt.
        """
        left = "\\boxed{"
        if not isinstance(s, str) or not s.startswith(left) or not s.endswith("}"):
            return None
        return s[len(left):-1]

    # NOTE(review): unlike the other branches this does not create an "initial_answer" column,
    # which plan_step asserts on -- confirm how MATH inputs are meant to be prepared.
    dataset = dataset.map(lambda x: {"answer": remove_boxed(last_boxed_only_string(x["solution"]))})
if num_samples > 0:
    dataset = dataset.select(range(num_samples))
print(dataset[10])

{'context': 'Read the following table regarding "None" and then answer a question.\n\nfine gravel | $2 per lb\npebbles | $3 per lb\nblack sand | $3 per lb\nrocks | $3 per lb\ncoarse gravel | $3 per lb\nwhite sand | $5 per lb', 'question': 'Read the following table regarding "None" and then answer a question.\n\nfine gravel | $2 per lb\npebbles | $3 per lb\nblack sand | $3 per lb\nrocks | $3 per lb\ncoarse gravel | $3 per lb\nwhite sand | $5 per lb\nBrenda purchased 1.1 pounds of coarse gravel. What was the total cost?', 'answer': '3.30', 'ques_type': 'free_text', 'choices': None, 'initial_answer': '```python\n# Python code, return answer \nprice_per_pound_coarse_gravel = 3 \npounds_purchased = 1.1 \n# Calculate total cost\ntotal_cost = price_per_pound_coarse_gravel * pounds_purchased\nanswer = total_cost\nprint(answer)```'}


In [3]:

from typing import  List
from typing_extensions import TypedDict
from langgraph.graph import   END
from langchain_core.prompts import ChatPromptTemplate

class Step(TypedDict):
	step: str
	error_prone_points: str
	evidence: str
	result: str

class State(TypedDict):
	question: str
	initial_answer: str
	code: str
	result: str
	step_list: List[Step]
	prediction: str
	answer: str


In [4]:
from langchain import hub

plan_prompt: ChatPromptTemplate = hub.pull("arietem/math_plan")

# Planner chain: hub prompt -> chat model. Sampling settings come from the config cell so every
# LLM call in the notebook follows the same `model` / `temperature` / `top_p`.
plan = plan_prompt | ChatOpenAI(
    model=model, temperature=temperature, top_p=top_p, n=1, base_url="https://api.chsdw.top/v1"
)


async def plan_step(state: State):
    """Ask the planner LLM for a plan and parse its <step>...</step> items into Step dicts.

    Returns:
        {"step_list": [...]} with one Step per <step> tag (points empty, evidence/result None).
        If the LLM call raises, the error is printed and {"prediction": "None"} is returned
        without a "step_list" (downstream nodes then fail on the missing key).
    """
    assert state["question"] is not None
    assert state["initial_answer"] is not None

    try:
        llm_reply = await plan.ainvoke(
            {"question": state["question"], "initial_answer": state["initial_answer"]}
        )
    except Exception as err:
        print("plan_step error", err)
        return {"prediction": "None"}

    step_texts = re.findall(r'<step>(.*?)</step>', llm_reply.content, re.DOTALL)
    return {
        "step_list": [
            {"step": text, "error_prone_points": [], "evidence": None, "result": None}
            for text in step_texts
        ],
    }

In [5]:
error_prone_identification_prompt: ChatPromptTemplate = hub.pull("arietem/math_error_prone_identification")
# Uses the config-cell `model` / `temperature` / `top_p` (previously hardcoded to the same values)
# so changing the configuration affects every chain consistently.
error_prone_identification = error_prone_identification_prompt | ChatOpenAI(
    model=model, temperature=temperature, top_p=top_p, n=1, base_url="https://api.chsdw.top/v1"
)

async def error_prone_identification_step(state: State):
    """Ask the LLM which parts of each planned step are error-prone and attach them to the steps.

    The LLM response is split on "step:" and the "point: ..." lines of each section are
    collected; sections are matched to ``state["step_list"]`` by position.

    Returns:
        {"step_list": [...]} where each step's "error_prone_points" holds the parsed points.
        Steps without a matching section keep their existing points instead of being dropped.
        If the LLM call fails, the error is printed and the step list is returned unchanged.
    """
    assert state["step_list"] is not None
    steps_text = "\n".join(f"step: {step['step']}" for step in state["step_list"])
    try:
        response = await error_prone_identification.ainvoke(
            {"question": state["question"], "initial_answer": state["initial_answer"], "step_list": steps_text}
        )
    except Exception as e:
        print("error_prone_identification error", e)
        # Return the steps unchanged under a key State declares, so later nodes still find step_list.
        return {"step_list": state["step_list"]}

    sections = response.content.split("step:")[1:]
    # `(?:\n|$)` also captures a last point that has no trailing newline.
    points_per_section = [re.findall(r"point: (.*?)(?:\n|$)", section) for section in sections]

    # Build new dicts rather than mutating the incoming state; index-based so that a response
    # with fewer sections than steps does not silently shorten the plan (as zip would).
    step_list = [
        {
            **step,
            "error_prone_points": (
                points_per_section[i] if i < len(points_per_section) else step.get("error_prone_points", [])
            ),
        }
        for i, step in enumerate(state["step_list"])
    ]
    return {"step_list": step_list}

In [6]:
# Dataset-specific code-generation prompt on LangChain Hub; the MATH prompt is the fallback.
_CODE_GENERATION_PROMPTS = {
    "gsmhard": "arietem/pot_generation",
    "gsm8k": "arietem/pot_generation",
    "tabmwp": "arietem/tabmwp_pot_generation",
}
code_generation_prompt = hub.pull(_CODE_GENERATION_PROMPTS.get(dataset_name, "arietem/math_code_generation"))

# Uses the config-cell `model` / `temperature` / `top_p` (previously hardcoded to the same values).
code_generation = code_generation_prompt | ChatOpenAI(
    model=model, temperature=temperature, top_p=top_p, n=1, base_url="https://api.chsdw.top/v1"
)

async def code_generation_step(state: State):
    """Regenerate code for the question, guided by the plan steps and their error-prone points.

    Returns:
        {"code": <LLM reply text>, "result": <python_interpreter output for that text>}.
        If the LLM call raises, the error is printed and {"code": "None", "result": "None"}
        is returned. Errors from the interpreter call are not caught here.
    """
    assert state["step_list"] is not None

    # One block per step: "step: ...\n" followed by its "point: ..." lines.
    guidance_text = "\n".join(
        f"step: {step['step']}\n" + "\n".join(f"point: {point}" for point in step["error_prone_points"])
        for step in state["step_list"]
    )

    try:
        llm_reply = await code_generation.ainvoke({
            "question": state["question"],
            "guidance": guidance_text + "\n Your response should follow previous pattern, which is a code block.",
        })
    except Exception as err:
        print("code generation step error", err)
        return {"code": "None", "result": "None"}

    generated = llm_reply.content
    execution_output = await python_interpreter.arun(generated)
    return {"code": generated, "result": execution_output}

In [7]:
# Prompt for the last node: the model must reply with a <fusion>...</fusion> part and a
# <final_answer>...</final_answer> part, which final_answer_step extracts with regexes.
# NOTE(review): the system text has grammatical slips (e.g. "should contains two part"); it is kept
# verbatim because editing it changes the model input and therefore experiment results.
final_answer_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """Using the provided evidence, answer the question by extracting only the specific information required. \
Your response should contains two part, the first part is the fusion of the Revising Process and the second part is the final answer. \
In the final answer, do not include any explanations, context, or additional information. Just focus on delivering the exact answer as \
concisely as possible!!! There is no need to answer the question in the form of a complete sentence, just provide the answer in the form \
of a noun, time, entity, single number, yes or no, etc.

Each part of your response should be enclosed by a XML tag, following the format:
<fusion>The fusion of the Revising Process</fusion>
<final_answer>The final answer</final_answer>
""",
        ),
        # The user turn(s) are supplied at call time through the "messages" variable.
        ("placeholder", "{messages}"),
    ]
)

# Final-answer chain; sampling settings come from the config cell (previously hardcoded to the same values).
get_final_answer = final_answer_prompt | ChatOpenAI(
    model=model, temperature=temperature, top_p=top_p, n=1, base_url="https://api.chsdw.top/v1"
)

async def final_answer_step(state: State):
    """Ask the LLM to fuse question, guidance, generated code and its output into a final answer.

    Returns:
        {"fusion": ..., "prediction": ...} taken from the last <fusion> / <final_answer> tags
        of the reply. Either value stays the string "None" if the call fails or a tag is
        missing (the error is printed in that case).
    """
    question_line = f"Question: {state['question']}"
    guidance = "\n".join(
        f"step: {step['step']}\n" + "\n".join(f"point: {point}" for point in step["error_prone_points"])
        for step in state["step_list"]
    )
    code_block = f"```python\n{state['code']}```"
    interpreter_output = state["result"]

    fusion, prediction = "None", "None"
    try:
        reply = await get_final_answer.ainvoke(
            {"messages": [f"{question_line}\n\n{guidance}\n\n{code_block}\n\n{interpreter_output}"]}
        )
        fusion = re.findall(r"<fusion>(.*?)</fusion>", reply.content, re.DOTALL)[-1]
        prediction = re.findall(r"<final_answer>(.*?)</final_answer>", reply.content, re.DOTALL)[-1]
    except Exception as err:
        print("final_answer_step", err)

    return {"fusion": fusion, "prediction": prediction}

In [8]:
from langgraph.graph import StateGraph, START

def _build_self_improve_graph() -> StateGraph:
    """Assemble the (uncompiled) linear pipeline.

    START -> plan -> error_identification -> code_generation -> get_final_answer -> END
    """
    graph = StateGraph(State)
    stages = [
        ("plan", plan_step),                                        # draft a step-by-step plan
        ("error_identification", error_prone_identification_step),  # flag error-prone points per step
        ("code_generation", code_generation_step),                  # regenerate code from the guidance
        ("get_final_answer", final_answer_step),                    # extract fusion + final answer
    ]
    for stage_name, stage_fn in stages:
        graph.add_node(stage_name, stage_fn)

    route = [START, *(stage_name for stage_name, _ in stages), END]
    for upstream, downstream in zip(route, route[1:]):
        graph.add_edge(upstream, downstream)
    return graph


workflow = _build_self_improve_graph()
# Compile into a LangChain Runnable so it supports .ainvoke / .astream like any other runnable.
app = workflow.compile()

In [9]:
async for event in app.astream({**dataset[900]},):
    for k, v in event.items():
        if k != "__end__":
            print(v)
print(dataset[900])

{'step_list': [{'step': ' Identify the waiting times for April and May from the table. ', 'error_prone_points': [], 'evidence': None, 'result': None}, {'step': ' Calculate the difference in waiting times between April and May to find the rate of change. ', 'error_prone_points': [], 'evidence': None, 'result': None}, {'step': ' Execute the Python code to compute the rate of change. ', 'error_prone_points': [], 'evidence': None, 'result': None}]}
{'step_list': [{'step': ' Identify the waiting times for April and May from the table. ', 'error_prone_points': ['There is potential for error if the administrator misreads the table or records the wrong values for April and May. It is important to double-check the values: April is 18 minutes and May is 17 minutes.'], 'evidence': None, 'result': None}, {'step': ' Calculate the difference in waiting times between April and May to find the rate of change. ', 'error_prone_points': ['The description incorrectly states that the rate of change is simp

In [11]:
from tqdm.asyncio import tqdm, tqdm_asyncio
import nest_asyncio
import json

# Patch asyncio to allow re-entrant event loops (presumably needed when gathering inside Jupyter's loop).
nest_asyncio.apply()

# Module-level accumulator filled by self_improve_inference and written out by the save cells.
results = []


async def process(item):
    """Run the compiled graph on one dataset row and merge the final state into the row.

    Returns:
        {**item, **final_state} on success. On failure the error is printed and
        {**item, "prediction": "None"} is returned; it carries no graph "code" output, so
        self_improve_inference selects it for retry on the next run.
    """
    try:
        state = await app.ainvoke({**item})
        return {**item, **state}
    except Exception as e:  # not a bare except: KeyboardInterrupt / CancelledError must propagate
        print("process error", e)
        return {**item, "prediction": "None"}
async def self_improve_inference() -> None:
    """Run (or resume) the self-improvement graph over `dataset`, persisting to `save_results_path`.

    1. If the results file exists, reload it into the module-level `results` list and collect
       the indices of entries with a missing/empty "code" or the "None" failure sentinel
       written by code_generation_step; those entries are re-run.
    2. Process the remaining dataset rows in batches of `batch_size`, appending to `results`.
    3. Rewrite the whole JSONL file after the retries and after every batch, so retried entries
       are saved even when no new batches remain.
    """
    # Start from the file, not from a previous call's leftovers (avoids duplicated entries on re-run).
    results.clear()
    error_indices = []  # indices of entries whose inference failed and should be retried

    def _save() -> None:
        with open(save_results_path, 'w') as fh:
            for record in results:
                fh.write(json.dumps(record) + "\n")

    # Load existing results, or make sure the output folder exists.
    if os.path.exists(save_results_path):
        with open(save_results_path, 'r') as fh:
            for idx, line in enumerate(fh):
                record = json.loads(line)
                results.append(record)
                code = record.get("code")
                if not code or code == "None":
                    error_indices.append(idx)
    else:
        folder_path = os.path.dirname(save_results_path)
        if folder_path:  # os.makedirs("") raises for a bare filename
            os.makedirs(folder_path, exist_ok=True)

    # Re-run inference for the failed entries and update them in place.
    if error_indices:
        print(f"Found {len(error_indices)} ERROR entries. Retrying inference...")
        retried = await tqdm_asyncio.gather(*(process(dataset[idx]) for idx in error_indices))
        for idx, new_result in zip(error_indices, retried):
            results[idx] = new_result
        _save()

    # Continue with the rows that have not been processed yet.
    for start in tqdm(range(len(results), dataset.num_rows, batch_size)):
        batch = dataset.select(range(start, min(start + batch_size, dataset.num_rows)))
        results.extend(await tqdm_asyncio.gather(*(process(item) for item in batch)))
        _save()


In [12]:
# Run or resume the self-improvement pass; top-level `await` works here because the notebook kernel runs an event loop.
await self_improve_inference()

Found 39 ERROR entries. Retrying inference...


100%|██████████| 39/39 [04:23<00:00,  6.76s/it]
0it [00:00, ?it/s]


In [13]:
# Persist the in-memory results as JSON Lines (one record per line), overwriting the file.
with open(save_results_path, 'w') as out_file:
    out_file.writelines(json.dumps(record) + "\n" for record in results)