In [1]:
import json
import os

from dotenv import load_dotenv, find_dotenv
from langchain_core.prompt_values import PromptValue
from langchain_openai import ChatOpenAI
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

from agent.utils.loader import load_prompt, load_processed_data
from agent.utils.tools.google_search import GoogleSearchTool

# Load variables from the .env file located by find_dotenv() into the process
# environment (presumably API keys for the LLM/search clients — confirm).
_ = load_dotenv(find_dotenv())

# LangChain-related environment settings: the tracing flag is set to "false",
# and a project name and endpoint URL are configured.
os.environ["LANGCHAIN_TRACING_V2"] = "false"
os.environ["LANGCHAIN_PROJECT"] = "self-correct"
os.environ["LANGCHAIN_ENDPOINT"]="https://api.smith.langchain.com"

In [2]:
# Upper bound on LLM calls made inside a single critic_iter() pass.
MAX_ITERATION = 7

# ---- Experiment configuration -------------------------------------------
dataset_name = 'hotpot_qa'
mode = "critic"
model = "gpt-4o-mini-2024-07-18"
num_samples = 1000

# Sampling settings depend on the dataset family.
if dataset_name in ['hotpot_qa', 'ambig_qa', 'trivia_qa']:
	# Only this branch creates the search tool that critic_iter() relies on.
	google_search = GoogleSearchTool()
	temperature = 0
	n = 1
elif dataset_name in ['gsm8k', 'svamp', 'tabmwp']:
	temperature = 0.5
	n = 1
else:
	temperature = 0.9
	n = 25
	model = "gpt-4o-mini"
top_p = 0.95
seed = 42
batch_size = 50

# Input: generations from the "cot" stage (its file name always carries temperature_0).
processed_data_path = f"../../../output/inference/{model}/{dataset_name}/cot/num_samples_{num_samples}_top_p_{top_p}_temperature_0_seed_{seed}.jsonl"
# Output: one JSON line per critiqued sample; also used by inference() to resume.
save_results_path = f"../../../output/inference/{model}/{dataset_name}/{mode}/num_samples_{num_samples}_top_p_{top_p}_temperature_{temperature}_seed_{seed}.jsonl"

prompt = load_prompt(dataset_name=dataset_name, mode=mode)
# Drops the first three characters of every stored generation.
# NOTE(review): presumably a fixed prefix left by the CoT stage — confirm.
dataset = load_processed_data(dataset_name=dataset_name, file_path=processed_data_path).map(
	lambda x: {"generation": x["generation"][3:]})
if num_samples > 0:
	# Clamp so a dataset with fewer than num_samples rows does not produce out-of-range indices.
	dataset = dataset.select(range(min(num_samples, dataset.num_rows)))

# Use the configured `model` so the model actually called always matches the
# model name embedded in the input/output paths above.
# NOTE(review): requests go through a non-default API base URL — confirm it is intended.
llm = ChatOpenAI(
	model_name=model,
	top_p=top_p,
	n=n,
	temperature=temperature,
	openai_api_base="https://api.chsdw.top/v1",
	seed=seed,
	max_retries=3,
)

prompt.pretty_print()

print(dataset[2])


Your response should follow the previous format and style.


Question: Serianna is a band of what genre that combines elements of heavy metal and hardcore punk?
Proposed Answer: Let's think step by step. Serianna is a band of metalcore genre. Metalcore is a subgenre of heavy metal and hardcore punk. So Serianna is a band of heavy metal and hardcore punk. FINAL ANSWER: heavy metal and hardcore punk.

What's the problem with the above answer?

1. Plausibility:

The question asks for the genre that combines elements of heavy metal and hardcore punk, and the answer is "heavy metal and hardcore punk", simply repeat the question. So it's not plausible.

2. Truthfulness:

Let's search the question in google:
> Search Query: Serianna is a band of what genre that combines elements of heavy metal and hardcore punk? site: wikipedia.org
> Evidence: [Metalcore - Wikipedia] Metalcore is a fusion music genre that combines elements of extreme metal and hardcore punk.

The evidence suggests that metal

In [3]:
import logging

# Configure logging: only WARNING and above are emitted, so the
# logging.debug(...) calls in this notebook stay silent at this level.
logging.basicConfig(
	level=logging.WARNING,
	format="%(asctime)s - %(levelname)s - %(message)s",  # log record format
    datefmt="%Y-%m-%d %H:%M:%S",  # timestamp format
)

async def critic_iter(item: dict, previous_cot: str) -> str:
	"""Run one critique pass over a previous chain of thought.

	The critic prompt is rendered from ``item`` plus ``previous_cot`` and then
	extended turn by turn. On every turn the LLM either

	1. asks for a search (``> Search Query:``): the query is sent to
	   ``google_search`` and the first result is appended as ``> Evidence:``;
	2. gives its revised answer (``most possible answer:``): the text after the
	   marker is returned;
	3. writes something else: the text is kept and the model is nudged to give
	   its most possible answer.

	Args:
		item: Sample whose fields are passed to the prompt template; must
			contain ``'question'``.
		previous_cot: The chain of thought to be critiqued.

	Returns:
		The revised chain of thought, or ``""`` when the model returned an
		empty response or no answer was produced within ``MAX_ITERATION`` calls.

	Raises:
		Exception: Any error raised while invoking the LLM is logged and re-raised.
	"""
	prompt_critic: PromptValue = prompt.invoke(input={**item, "previous_cot": previous_cot})
	logging.debug(prompt_critic.to_string())

	# Open the critique with the plausibility check; the model continues from here.
	context = "What's the problem with the above answer?\n\n1. Plausibility:\n\n"
	prompt_critic.messages[-1].content += context
	logging.debug(context)

	# Text that steers the model towards emitting "... most possible answer:".
	answer_cue = f"Let's give the most possible answer.\n\nQuestion: {item['question']}\nHere's "

	revised_cot = ""
	for idx in range(MAX_ITERATION):  # max interaction with tool
		try:
			# Stopping at "> Evidence:" leaves the evidence to be filled in from
			# the search tool below rather than by the model.
			res = await llm.ainvoke(input=prompt_critic, stop=["> Evidence:", "---"])
			res = res.content
		except Exception:
			logging.error("Error when invoking LLM.")
			raise

		# case1: search
		if "> Search Query:" in res:
			logging.debug("CASE 1:")
			# The marker is known to be present, so index 1 always exists.
			search_query = res.split("> Search Query:")[1].split("\n")[0].strip()
			prompt_critic.messages[-1].content += res

			try:
				search_res: list = await google_search.arun(search_query)
				# Only the first search result is used as evidence.
				context = f"> Evidence: [{search_res[0]}]\n\n"
			except Exception:
				# Covers tool failures as well as an empty result list.
				logging.exception("Error when searching evidence for query %r.", search_query)
				context = "> Error when trying to search evidence.\n\n"

			if idx == MAX_ITERATION - 2:
				# Only one LLM call is left after this one: ask for the answer now.
				context += answer_cue
			prompt_critic.messages[-1].content += context

		# case2: most possible answer
		elif "most possible answer:" in res:
			# maxsplit=1 keeps this safe when the marker occurs more than once.
			revised_cot = res.split("most possible answer:", 1)[1].strip()
			prompt_critic.messages[-1].content += revised_cot
			break

		# case3: other output
		else:
			if not res:
				print("NOT A RESPONSE.")
				break
			prompt_critic.messages[-1].content += res + "\n" + answer_cue

	return revised_cot


def is_null_answer(text):
	"""Tell whether a predicted answer should be treated as "no answer".

	Returns True for falsy input, and for text that — once stripped and
	lower-cased — is empty, starts with "none", or equals one of
	"no answer", "never", "null", "both", "neither". Returns False otherwise.
	"""
	if text:
		normalized = text.strip().lower()
		# "none" itself is covered by the prefix check.
		null_markers = ("", "no answer", "never", "null", "both", "neither")
		return normalized.startswith("none") or normalized in null_markers
	return True


async def critic(item: dict):
	"""Iteratively critique and revise the answer of a single sample.

	``item['cot']`` and ``item['pred']`` are created from ``item['generation']``
	and ``item['prediction']``; every round appends the revised chain of thought
	and the prediction extracted from it. At most three rounds are run, and the
	loop ends early once two consecutive rounds leave the prediction unchanged.

	Returns:
		The same ``item`` dict, mutated in place.
	"""
	# Round 0 of the history is the initial generation and its extracted answer.
	initial_cot = item['generation']
	initial_pred = item['prediction']
	item['cot'] = [initial_cot]
	item['pred'] = [initial_pred]

	changed_last_round = True
	for round_no in range(1, 4):
		# Critique the most recent non-null answer, falling back to the initial one.
		anchor = round_no - 1
		while anchor > 0 and is_null_answer(item['pred'][anchor]):
			anchor -= 1
		previous_cot = item['cot'][anchor]
		previous_pred = item['pred'][anchor]

		# One critique round; the prediction is whatever follows the last "FINAL ANSWER:".
		revised_cot = await critic_iter(item, previous_cot)
		revised_pred = revised_cot.rpartition("FINAL ANSWER:")[2].strip()

		# An empty prediction, or one differing from the critiqued one, counts as changed.
		changed = (not revised_pred) or (revised_pred != previous_pred)

		item['cot'].append(revised_cot)
		item['pred'].append(revised_pred)

		# Two unchanged rounds in a row: stop iterating.
		if not (changed or changed_last_round):
			print("Stop.")
			break
		changed_last_round = changed

	return item

In [6]:
results = []
async def inference() -> None:
	"""Run ``critic`` over the dataset in batches, checkpointing after each batch.

	Results already present in ``save_results_path`` are loaded into the
	module-level ``results`` list, and processing resumes at row
	``len(results)``. After every batch the newly produced results are appended
	to the file as JSON lines.
	"""
	# Start from a clean list so calling inference() again does not duplicate
	# the rows that are re-loaded from disk.
	results.clear()
	if os.path.exists(save_results_path):
		with open(save_results_path, 'r', encoding="utf-8") as file:
			# Skip blank lines (e.g. a trailing newline).
			results.extend(json.loads(line) for line in file if line.strip())
	else:
		folder_path = os.path.dirname(save_results_path)
		os.makedirs(folder_path, exist_ok=True)

	for idx in tqdm(range(len(results), dataset.num_rows, batch_size)):
		batch = dataset.select(range(idx, min(idx + batch_size, dataset.num_rows)))
		batch_results = await tqdm_asyncio.gather(*(critic(item) for item in batch))
		results.extend(batch_results)
		# Append only the new batch: the file already holds everything before it,
		# and appending never truncates earlier checkpoints.
		with open(save_results_path, 'a', encoding="utf-8") as file:
			for result in batch_results:
				file.write(json.dumps(result) + "\n")

In [7]:
# Run the critic pipeline over the dataset; the output below shows the
# per-batch progress bars and the messages printed by critic()/critic_iter().
await inference()

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s][A[A

  2%|▏         | 1/50 [00:12<09:57, 12.19s/it][A[A

Stop.




  4%|▍         | 2/50 [00:12<04:07,  5.16s/it][A[A

  6%|▌         | 3/50 [00:12<02:13,  2.85s/it][A[A

Stop.
Stop.




  8%|▊         | 4/50 [00:14<01:46,  2.31s/it][A[A

Stop.




 10%|█         | 5/50 [00:18<02:19,  3.10s/it][A[A

Stop.




 12%|█▏        | 6/50 [00:19<01:47,  2.45s/it][A[A

Stop.




 14%|█▍        | 7/50 [00:21<01:38,  2.28s/it][A[A

Stop.




 16%|█▌        | 8/50 [00:23<01:32,  2.19s/it][A[A

Stop.




 18%|█▊        | 9/50 [00:24<01:10,  1.72s/it][A[A

 20%|██        | 10/50 [00:25<00:56,  1.41s/it][A[A

Stop.




 22%|██▏       | 11/50 [00:26<00:54,  1.40s/it][A[A

Stop.




 24%|██▍       | 12/50 [00:37<02:40,  4.23s/it][A[A

Stop.




 26%|██▌       | 13/50 [00:38<02:05,  3.39s/it][A[A

Stop.




 28%|██▊       | 14/50 [00:44<02:32,  4.25s/it][A[A

Stop.




 30%|███       | 15/50 [00:46<02:01,  3.48s/it][A[A

Stop.




 32%|███▏      | 16/50 [00:53<02:35,  4.59s/it][A[A

 34%|███▍      | 17/50 [00:54<01:56,  3.52s/it][A[A

Stop.




 36%|███▌      | 18/50 [00:55<01:25,  2.67s/it][A[A

 38%|███▊      | 19/50 [00:56<01:04,  2.07s/it][A[A

 40%|████      | 20/50 [00:59<01:15,  2.52s/it][A[A

Stop.




 42%|████▏     | 21/50 [01:04<01:35,  3.28s/it][A[A

 44%|████▍     | 22/50 [01:05<01:08,  2.44s/it][A[A

Stop.




 46%|████▌     | 23/50 [01:09<01:24,  3.12s/it][A[A

Stop.




 48%|████▊     | 24/50 [01:22<02:37,  6.05s/it][A[A

 50%|█████     | 25/50 [01:29<02:38,  6.34s/it][A[A

Stop.




 52%|█████▏    | 26/50 [01:30<01:52,  4.68s/it][A[A

Stop.




 54%|█████▍    | 27/50 [01:32<01:28,  3.83s/it][A[A

Stop.




 56%|█████▌    | 28/50 [01:36<01:26,  3.94s/it][A[A

Stop.




 58%|█████▊    | 29/50 [01:37<01:03,  3.01s/it][A[A

 60%|██████    | 30/50 [01:40<00:58,  2.92s/it][A[A

Stop.




 62%|██████▏   | 31/50 [01:44<01:04,  3.38s/it][A[A

Stop.




 64%|██████▍   | 32/50 [01:45<00:45,  2.51s/it][A[A

Stop.




 66%|██████▌   | 33/50 [01:46<00:34,  2.04s/it][A[A

Stop.




 68%|██████▊   | 34/50 [01:48<00:34,  2.14s/it][A[A

 70%|███████   | 35/50 [01:49<00:28,  1.93s/it][A[A

Stop.




 72%|███████▏  | 36/50 [01:51<00:26,  1.88s/it][A[A

 74%|███████▍  | 37/50 [01:53<00:22,  1.74s/it][A[A

 76%|███████▌  | 38/50 [01:53<00:16,  1.38s/it][A[A

 78%|███████▊  | 39/50 [01:58<00:28,  2.57s/it][A[A

Stop.




 80%|████████  | 40/50 [02:00<00:21,  2.18s/it][A[A

Stop.




 82%|████████▏ | 41/50 [02:00<00:14,  1.61s/it][A[A

Stop.




 84%|████████▍ | 42/50 [02:01<00:11,  1.46s/it][A[A

 86%|████████▌ | 43/50 [02:07<00:20,  2.88s/it][A[A

Stop.




 88%|████████▊ | 44/50 [02:08<00:12,  2.14s/it][A[A

Stop.




 90%|█████████ | 45/50 [02:11<00:12,  2.52s/it][A[A

 94%|█████████▍| 47/50 [02:13<00:05,  1.90s/it][A[A

 96%|█████████▌| 48/50 [02:15<00:03,  1.94s/it][A[A

 98%|█████████▊| 49/50 [02:19<00:02,  2.34s/it][A[A

100%|██████████| 50/50 [02:27<00:00,  2.96s/it][A[A
 33%|███▎      | 1/3 [02:27<04:55, 147.81s/it]

  0%|          | 0/50 [00:00<?, ?it/s][A[A

  2%|▏         | 1/50 [00:13<11:03, 13.55s/it][A[A

Stop.




  4%|▍         | 2/50 [00:21<08:16, 10.35s/it][A[A

Stop.




  6%|▌         | 3/50 [00:46<13:27, 17.19s/it][A[A

Stop.
NOT A RESPONSE.




  8%|▊         | 4/50 [01:18<17:37, 22.99s/it][A[A

Stop.




 10%|█         | 5/50 [01:20<11:24, 15.21s/it][A[A

Stop.




 12%|█▏        | 6/50 [01:23<08:05, 11.03s/it][A[A

Stop.




 14%|█▍        | 7/50 [01:23<05:28,  7.65s/it][A[A

Stop.




 16%|█▌        | 8/50 [01:44<08:19, 11.90s/it][A[A

Stop.




 18%|█▊        | 9/50 [01:47<06:13,  9.10s/it][A[A

 20%|██        | 10/50 [02:19<10:41, 16.03s/it][A[A

Stop.




 22%|██▏       | 11/50 [02:19<07:20, 11.29s/it][A[A

Stop.




 24%|██▍       | 12/50 [02:21<05:16,  8.32s/it][A[A

Stop.






 28%|██▊       | 14/50 [02:40<05:31,  9.21s/it][A[A

 30%|███       | 15/50 [02:41<03:53,  6.68s/it][A[A

Stop.




 32%|███▏      | 16/50 [02:48<03:54,  6.91s/it][A[A

Stop.




 34%|███▍      | 17/50 [02:52<03:19,  6.06s/it][A[A

Stop.




 36%|███▌      | 18/50 [02:54<02:31,  4.73s/it][A[A

 38%|███▊      | 19/50 [02:55<01:54,  3.69s/it][A[A

 40%|████      | 20/50 [03:01<02:10,  4.34s/it][A[A

 42%|████▏     | 21/50 [03:05<02:02,  4.22s/it][A[A

Stop.




 44%|████▍     | 22/50 [03:07<01:36,  3.45s/it][A[A

 46%|████▌     | 23/50 [03:11<01:37,  3.61s/it][A[A

Stop.




 48%|████▊     | 24/50 [03:14<01:32,  3.57s/it][A[A

Stop.




 50%|█████     | 25/50 [03:15<01:09,  2.80s/it][A[A

Stop.




 52%|█████▏    | 26/50 [03:17<01:03,  2.65s/it][A[A

 54%|█████▍    | 27/50 [03:19<00:57,  2.48s/it][A[A

Stop.




 56%|█████▌    | 28/50 [03:29<01:39,  4.54s/it][A[A

 58%|█████▊    | 29/50 [03:33<01:30,  4.29s/it][A[A

 60%|██████    | 30/50 [03:40<01:45,  5.28s/it][A[A

 62%|██████▏   | 31/50 [03:42<01:21,  4.28s/it][A[A

 64%|██████▍   | 32/50 [03:42<00:54,  3.04s/it][A[A

 66%|██████▌   | 33/50 [03:43<00:42,  2.50s/it][A[A

 68%|██████▊   | 34/50 [03:45<00:35,  2.22s/it][A[A



 72%|███████▏  | 36/50 [03:47<00:21,  1.54s/it][A[A

 74%|███████▍  | 37/50 [03:49<00:23,  1.77s/it][A[A

 76%|███████▌  | 38/50 [03:50<00:16,  1.40s/it][A[A

 78%|███████▊  | 39/50 [03:54<00:24,  2.20s/it][A[A

 80%|████████  | 40/50 [03:55<00:18,  1.83s/it][A[A

 84%|████████▍ | 42/50 [03:57<00:11,  1.38s/it][A[A

Stop.




 86%|████████▌ | 43/50 [04:02<00:15,  2.25s/it][A[A

 88%|████████▊ | 44/50 [04:02<00:10,  1.82s/it][A[A

 92%|█████████▏| 46/50 [04:03<00:04,  1.24s/it][A[A

 94%|█████████▍| 47/50 [04:06<00:04,  1.59s/it][A[A

 96%|█████████▌| 48/50 [04:07<00:02,  1.39s/it][A[A

100%|██████████| 50/50 [04:07<00:00,  4.95s/it][A[A
 67%|██████▋   | 2/3 [06:35<03:26, 206.47s/it]

  0%|          | 0/50 [00:00<?, ?it/s][A[A

  2%|▏         | 1/50 [00:35<29:12, 35.77s/it][A[A

Stop.




  4%|▍         | 2/50 [00:50<18:44, 23.42s/it][A[A

  6%|▌         | 3/50 [00:55<11:43, 14.96s/it][A[A

Stop.




  8%|▊         | 4/50 [01:31<17:43, 23.12s/it][A[A

 10%|█         | 5/50 [01:39<13:15, 17.69s/it][A[A

Stop.




 12%|█▏        | 6/50 [01:44<09:56, 13.56s/it][A[A

Stop.




 14%|█▍        | 7/50 [01:45<06:49,  9.52s/it][A[A

Stop.




 16%|█▌        | 8/50 [01:56<06:57,  9.93s/it][A[A

 18%|█▊        | 9/50 [01:59<05:19,  7.78s/it][A[A

Stop.




 20%|██        | 10/50 [02:10<05:43,  8.58s/it][A[A

Stop.




 22%|██▏       | 11/50 [02:13<04:29,  6.91s/it][A[A

Stop.




 24%|██▍       | 12/50 [02:13<03:07,  4.94s/it][A[A

 26%|██▌       | 13/50 [02:19<03:11,  5.18s/it][A[A

Stop.




 28%|██▊       | 14/50 [02:26<03:25,  5.70s/it][A[A

NOT A RESPONSE.




 30%|███       | 15/50 [02:29<02:53,  4.96s/it][A[A

Stop.




 32%|███▏      | 16/50 [02:31<02:20,  4.13s/it][A[A

 34%|███▍      | 17/50 [02:38<02:46,  5.05s/it][A[A

 36%|███▌      | 18/50 [02:41<02:22,  4.46s/it][A[A

Stop.




 38%|███▊      | 19/50 [02:45<02:09,  4.19s/it][A[A

 40%|████      | 20/50 [02:47<01:49,  3.65s/it][A[A

Stop.




 42%|████▏     | 21/50 [03:05<03:50,  7.97s/it][A[A

Stop.




 44%|████▍     | 22/50 [03:07<02:53,  6.18s/it][A[A

 46%|████▌     | 23/50 [03:09<02:09,  4.80s/it][A[A

Stop.




 48%|████▊     | 24/50 [03:09<01:29,  3.46s/it][A[A

 50%|█████     | 25/50 [03:11<01:11,  2.87s/it][A[A

 52%|█████▏    | 26/50 [03:13<01:01,  2.55s/it][A[A

Stop.




 54%|█████▍    | 27/50 [03:15<00:56,  2.45s/it][A[A

 56%|█████▌    | 28/50 [03:18<00:55,  2.52s/it][A[A

 58%|█████▊    | 29/50 [03:21<01:00,  2.87s/it][A[A

 60%|██████    | 30/50 [03:29<01:27,  4.38s/it][A[A

Stop.




 62%|██████▏   | 31/50 [03:30<01:00,  3.19s/it][A[A

 64%|██████▍   | 32/50 [03:30<00:43,  2.39s/it][A[A

NOT A RESPONSE.




 66%|██████▌   | 33/50 [03:34<00:45,  2.70s/it][A[A

 68%|██████▊   | 34/50 [03:44<01:18,  4.88s/it][A[A

 70%|███████   | 35/50 [03:46<01:01,  4.09s/it][A[A



 74%|███████▍  | 37/50 [03:52<00:51,  3.92s/it][A[A

 76%|███████▌  | 38/50 [03:56<00:47,  3.98s/it][A[A

 78%|███████▊  | 39/50 [03:58<00:36,  3.30s/it][A[A

Stop.




 80%|████████  | 40/50 [04:01<00:30,  3.03s/it][A[A

 82%|████████▏ | 41/50 [04:02<00:22,  2.51s/it][A[A

 84%|████████▍ | 42/50 [04:03<00:16,  2.00s/it][A[A

 86%|████████▌ | 43/50 [04:05<00:15,  2.16s/it][A[A

 88%|████████▊ | 44/50 [04:07<00:12,  2.06s/it][A[A

 90%|█████████ | 45/50 [04:09<00:09,  1.96s/it][A[A

 92%|█████████▏| 46/50 [04:09<00:06,  1.52s/it][A[A

 94%|█████████▍| 47/50 [04:10<00:03,  1.21s/it][A[A

 96%|█████████▌| 48/50 [04:11<00:02,  1.35s/it][A[A

Stop.




 98%|█████████▊| 49/50 [04:23<00:04,  4.48s/it][A[A

100%|██████████| 50/50 [04:24<00:00,  5.28s/it][A[A
100%|██████████| 3/3 [10:59<00:00, 219.87s/it]
