In [None]:
import os
import json

from langchain_core.outputs import LLMResult
from agent.utils.tools.google_search import GoogleSearchTool
from langchain_core.prompt_values import PromptValue

from agent.utils.loader import load_prompt, load_processed_data
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv, find_dotenv
from agent.utils.tools.interpreter_api import safe_execute

# Load variables from the .env file located by find_dotenv() before any client
# is constructed (presumably the API credentials read by ChatOpenAI — confirm).
# The return value is bound to `_`, presumably so the cell does not echo it.
_ = load_dotenv(find_dotenv())

In [None]:
# ---- Experiment configuration: everything a reader might tune ----
dataset_name = 'gsm8k'
mode = "critic"  # selects the prompt to load and names the output sub-directory
basemode = "pot"  # sub-directory of the earlier inference run whose outputs are read as input
model = "gpt-4o-mini-2024-07-18"
num_samples = -1  # values <= 0 keep the whole dataset (see the `num_samples > 0` check below)
top_p = 0.95
temperature = 0
n=1
seed = 42
batch_size = 100  # number of samples gathered concurrently per batch in critic()

# Input: results of the base-mode run. The gpt-4o-mini branch reads from a
# directory named with the short model name rather than the full `model` string.
# NOTE(review): the else-branch path spells the field `top_p-{top_p}` while the
# other two paths use `top_p_{top_p}` — confirm this matches the files on disk.
# NOTE(review): both input paths hard-code `temperature_0` instead of using
# `temperature`; they only agree while `temperature` stays 0.
if "gpt-4o-mini" in model:
    processed_data_path = f"../../../output/inference/gpt-4o-mini/{dataset_name}/{basemode}/num_samples_{num_samples}_top_p_{top_p}_temperature_0_seed_{seed}.jsonl"
else:
	processed_data_path = f"../../../output/inference/{model}/{dataset_name}/{basemode}/num_samples_{num_samples}_top_p-{top_p}_temperature_0_seed_{seed}.jsonl"
# Output: one JSON object per line, written (and resumed from) by critic().
save_results_path = f"../../../output/inference/{model}/{dataset_name}/{mode}/num_samples_{num_samples}_top_p_{top_p}_temperature_{temperature}_seed_{seed}.jsonl"
prompt = load_prompt(dataset_name=dataset_name, mode=mode)
dataset = load_processed_data(dataset_name=dataset_name, file_path=processed_data_path)
# Optionally truncate to the first `num_samples` records for a quick run.
if num_samples > 0:
	dataset = dataset.select(range(num_samples))
# NOTE(review): requests go to a non-default, hard-coded base_url — confirm this
# endpoint is intended before sharing the notebook.
llm = ChatOpenAI(model=model, top_p=top_p, n=n, temperature=temperature, base_url="https://api.chsdw.top/v1", seed=seed)

# Sanity check: show the loaded prompt and one record (index 2) of the dataset.
prompt.pretty_print()
print(dataset[2])

In [None]:
from tqdm.asyncio import tqdm_asyncio
from tqdm import tqdm
from utils import remove_comment, floatify_ans, finqa_equal
# NOTE(review): critic_iter below does not read this constant; its revision loop
# is bounded independently (3 rounds). Confirm whether this value should drive it.
MAX_ITERATION = 7
async def critic_iter(sample: dict, max_iterations: int = 3) -> dict:
	"""Run the critique-and-revise loop on a single sample and return it.

	On the first round the sample's original program is executed and the
	``code``, ``prediction`` and ``report`` fields are turned into lists that
	hold one entry per round. Each round then asks the model what is wrong with
	the most recent program that produced a non-``None`` prediction, asks for a
	better solution, executes it, and appends the new code / report /
	prediction to those lists.

	Args:
		sample: Record with at least ``question``, ``code`` and ``answer``
			keys. It is mutated in place; newlines are removed from ``answer``.
		max_iterations: Number of critique/revise rounds. Defaults to 3, the
			value that used to be hard-coded in the loop.

	Returns:
		The same ``sample`` dict, with ``code``, ``report`` and ``prediction``
		holding ``max_iterations + 1`` entries each (when at least one round ran).
	"""
	import logging
	import re

	logger = logging.getLogger(__name__)

	def extract_code(text: str) -> str:
		"""Return the body of the first ```python fenced block, or ``text`` unchanged if there is none."""
		match = re.search(r'```python(.*?)```', text, re.DOTALL)
		return match.group(1).strip() if match else text

	async def ask_llm(prompt_value, stop: list):
		"""Query the model once; return None (after logging) if the request fails."""
		try:
			return await llm.ainvoke(prompt_value, stop=stop)
		except Exception:
			# Deliberately best-effort: a failed request yields an empty reply
			# instead of aborting the batch. Catching Exception (not a bare
			# except) lets task cancellation and KeyboardInterrupt propagate.
			logger.warning("LLM request failed; continuing with an empty response", exc_info=True)
			return None

	sample['answer'] = sample['answer'].replace('\n', '')
	for itr in range(1, max_iterations + 1):
		if itr == 1:
			# Seed the per-round histories with the original program's result.
			prediction, report = safe_execute(sample["code"])
			sample['prediction'] = [prediction]
			sample['report'] = [report]
			sample['code'] = [sample['code']]

		# Criticize the latest round whose prediction is not None; fall back to
		# round 0 if every later round produced None.
		base_idx = itr - 1
		while base_idx > 0 and sample['prediction'][base_idx] is None:
			base_idx -= 1

		# presumably strips comments from the program shown to the model — confirm in utils
		previous_code = remove_comment(sample['code'][base_idx])

		# Build the critique prompt: question, program, execution report, output.
		context = f"Question: {sample['question']}\n"
		context += f"```python\n{previous_code}\n```\n"
		context += f"Execution: {sample['report'][base_idx]}\n"
		context += f"Output: answer = {floatify_ans(sample['prediction'][base_idx])}\n"
		context += "\nWhat's the problem with the above code?\n\n"

		# Step 1: ask for a critique; generation stops before any proposed fix.
		critique = await ask_llm(prompt.invoke(input={"context": context}), stop=["Here's", "---"])
		context += critique.content if critique else ""

		# Make sure the critique ends with a newline before the next cue.
		if context and context[-1] != "\n":
			context += "\n"

		# Step 2: ask for a revised program, conditioned on the critique.
		context += "Here's a better solution:\n"
		revision = await ask_llm(prompt.invoke(input={"context": context}), stop=["```\n", "---"])
		code = extract_code(revision.content) if revision else ""

		if code.strip() == sample['code'][base_idx].strip():
			# No correction: carry the previous round forward unchanged, without
			# re-executing a program whose result is already known.
			code = sample['code'][base_idx]
			report = sample['report'][base_idx]
			prediction = sample['prediction'][base_idx]
		else:
			prediction, report = safe_execute(code)
			prediction = floatify_ans(prediction)

		# Record this round.
		sample['code'].append(code)
		sample['report'].append(report)
		sample['prediction'].append(prediction)

	return sample

async def critic() -> None:
	"""Run critic_iter over the whole dataset in batches, checkpointing to disk.

	Results are stored as JSON Lines in ``save_results_path``. If that file
	already exists, the number of records in it is used as the resume offset
	into ``dataset``; otherwise its parent directory is created. After each
	batch only the new records are appended, so an interruption never truncates
	results that were already saved and the file is not rewritten every batch.

	Raises:
		json.JSONDecodeError: if the existing results file contains a line that
			is not valid JSON (e.g. a partially written record).
	"""
	n_done = 0
	if os.path.exists(save_results_path):
		with open(save_results_path, 'r', encoding='utf-8') as file:
			for line in file:
				# Parse each line so a corrupt checkpoint fails loudly instead
				# of silently shifting the resume offset.
				json.loads(line)
				n_done += 1
	else:
		folder_path = os.path.dirname(save_results_path)
		if folder_path:  # os.makedirs('') raises; a bare filename needs no directory
			os.makedirs(folder_path, exist_ok=True)

	for idx in tqdm(range(n_done, dataset.num_rows, batch_size)):
		batch = dataset.select(range(idx, min(idx + batch_size, dataset.num_rows)))
		# Samples within a batch run concurrently; the order of the returned
		# list follows the order of the awaitables passed in.
		batch_results = await tqdm_asyncio.gather(*(critic_iter(item) for item in batch))
		with open(save_results_path, 'a', encoding='utf-8') as file:
			file.writelines(json.dumps(result) + "\n" for result in batch_results)


In [None]:
# Top-level await: relies on the notebook kernel already running an event loop.
# Outside a notebook this would need asyncio.run(critic()) instead.
await critic()