In [1]:

import json
import os

from dotenv import load_dotenv, find_dotenv
from langchain_core.prompt_values import PromptValue
from langchain_openai import ChatOpenAI
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

from agent.utils.loader import load_prompt, load_processed_data
from agent.utils.tools.google_search import GoogleSearchTool

# Load variables from the .env file located by find_dotenv() into the process
# environment; the boolean return value is deliberately discarded.
# NOTE(review): presumably this supplies the API credentials used by ChatOpenAI /
# GoogleSearchTool below — confirm.
_ = load_dotenv(find_dotenv())

In [9]:
# Upper bound on LLM calls inside one critic_iter() round.
MAX_ITERATION = 7

dataset_name = 'ambig_qa'
mode = "critic"
model = "gpt-4o-mini-2024-07-18"
num_samples = -1  # values <= 0 keep the full dataset (see the select() guard below)

# Per-task sampling settings. The search tool is only created for the QA datasets;
# critic_iter() calls `google_search`, so it cannot run for the other branches as written.
if dataset_name in ['hotpot_qa', 'ambig_qa', 'trivia_qa']:
	google_search = GoogleSearchTool()
	temperature = 0
	n = 1
elif dataset_name in ['gsm8k', 'svamp', 'tabmwp']:
	temperature = 0.5
	n = 1
else:
	temperature = 0.9
	n = 25
	model = "gpt-4o-mini"
top_p = 0.95
seed = 42
batch_size = 100  # number of items critiqued concurrently per checkpoint

# Input: generations of the earlier "cot" stage (its file name pins temperature_0).
processed_data_path = f"../../../output/inference/{model}/{dataset_name}/cot/num_samples_{num_samples}_top_p_{top_p}_temperature_0_seed_{seed}.jsonl"
# Output: one JSON line per critiqued item.
save_results_path = f"../../../output/inference/{model}/{dataset_name}/{mode}/num_samples_{num_samples}_top_p_{top_p}_temperature_{temperature}_seed_{seed}.jsonl"

prompt = load_prompt(dataset_name=dataset_name, mode=mode)
# Drop the first three characters of every stored generation.
# NOTE(review): presumably a fixed prefix emitted by the cot stage — confirm.
dataset = load_processed_data(dataset_name=dataset_name, file_path=processed_data_path).map(
	lambda x: {"generation": x["generation"][3:]})
if num_samples > 0:
	dataset = dataset.select(range(num_samples))

# Use the configured `model` rather than a hard-coded name, so the model that is
# actually called always matches the one the input/output paths are keyed on.
llm = ChatOpenAI(model_name=model,
                 top_p=top_p,
                 n=n,
                 temperature=temperature,
                 openai_api_base="https://api.chsdw.top/v1",
                 seed=seed,
                 max_retries=3)

prompt.pretty_print()

print(dataset[2])


Question: When did men's figure skating become a summer Olympic sport?
Proposed Answer: Men's figure skating has never been a summer Olympic sport. It has been a part of the Winter Olympics since the first Winter Olympics in 1924. FINAL ANSWER: never

What's the problem with the above answer?

1. Plausibility:

The question asks for the time men's figure skating become a summer Olympic sport, and the answer "never" does not provide a time. So it's not plausible. The answer should be a time, like year or date.

2. Truthfulness:

Let's search the question in google:

> Search Query: When did men's figure skating become a summer Olympic sport?
> Evidence: [Figure skating at the Olympic Games - Wikipedia] Figure skating was first contested in the Olympic Games at the 1908 Summer Olympics . Since 1924, the sport has been a part of the Winter Olympic Games .

The evidence suggests Figure skating became an Olympic sport at the 1908 Summer Olympics, and has been a part of the Winter Olympic G

In [3]:
async def critic_iter(item: dict):
	"""Run one critique round (LLM + search-tool loop) for a single item.

	The critic prompt is rendered from ``item`` and its last message is extended
	in place with each model reply and each piece of retrieved evidence. At most
	``MAX_ITERATION`` LLM calls are made.

	Args:
		item: Mapping used to render ``prompt``; must contain ``'question'``.

	Returns:
		The stripped text following the first "most possible answer:" in a model
		reply, or ``""`` if the loop ends without such a reply.
	"""
	# load prompt
	prompt_critic: PromptValue = prompt.invoke(input=item)

	# verify: plausible & truthful
	prompt_critic.messages[-1].content += "What's the problem with the above answer?\n\n1. Plausibility:\n\n"

	revised_cot = ""
	for idx in range(MAX_ITERATION):  # max interaction with tool
		# get LLM res; generation stops before the model can invent its own evidence
		try:
			res = (await llm.ainvoke(input=prompt_critic, stop=["> Evidence:", "---"])).content
		except Exception as err:
			# `Exception`, not a bare except: task cancellation and KeyboardInterrupt
			# must propagate instead of being turned into an empty reply.
			print(f"LLM call failed: {err!r}")
			res = ""

		# case1: search
		if "> Search Query:" in res:
			# The marker is known to be present, so index 1 always exists.
			search_query = res.split("> Search Query:")[1].split("\n")[0].strip()
			prompt_critic.messages[-1].content += res

			# use Tool: search a new evidence. The call and the indexing are both
			# guarded so a tool failure or an empty result list degrades to an
			# error note in the prompt rather than aborting the whole item.
			try:
				search_res = (await google_search.arun(search_query))[0]
				context = f"""> Evidence: [{search_res}]\n\n"""
			except Exception:
				context = f"""> Error when trying to search evidence.\n\n"""
			# On the second-to-last round, push the model to commit to an answer.
			if idx == MAX_ITERATION - 2:
				context += f"Let's give the most possible answer.\n\nQuestion: {item['question']}\nHere's "
			prompt_critic.messages[-1].content += context

		# case2: most possible answer
		elif "most possible answer:" in res:
			# maxsplit=1: a reply that repeats the marker must not raise on unpacking.
			revised_cot = res.split("most possible answer:", 1)[1].strip()
			prompt_critic.messages[-1].content += revised_cot
			break

		# case3: other output
		else:
			if not res:
				print("NOT A RESPONSE.")
				break
			context = res
			context += f"\nLet's give the most possible answer.\n\nQuestion: {item['question']}\nHere's "
			prompt_critic.messages[-1].content += context

	return revised_cot


def is_null_answer(text):
	"""Tell whether ``text`` is a placeholder rather than a real answer.

	Falsy input counts as null. Otherwise the text is stripped and lower-cased,
	and is null when it equals one of a fixed set of non-answers or starts
	with "none".
	"""
	if not text:
		return True
	normalized = text.strip().lower()
	# An exact "none" is already covered by the prefix check below.
	non_answers = ("", "no answer", "never", "null", "both", "neither")
	return normalized in non_answers or normalized.startswith("none")


async def critic(item: dict):
	"""Iteratively critique and revise one item's answer (up to three rounds).

	Args:
		item: Must contain ``'generation'`` (initial reasoning) and
			``'prediction'`` (initial answer), plus whatever ``critic_iter``
			needs. It is mutated in place.

	Returns:
		The same ``item`` with two added lists, ``'cot'`` and ``'pred'``: index 0
		holds the initial generation/prediction and each later index holds one
		round's revision. The lists are shorter than four entries when the loop
		stops early.
	"""
	# Round 0 of the history is the initial generation and its extracted answer.
	item['cot'] = [item['generation']]
	item['pred'] = [item['prediction']]

	# iterative correction
	previous_corrected = True
	for itr in range(1, 4):
		# choose the latest answer that is not "None" to critic
		base_idx = itr - 1
		while base_idx > 0 and is_null_answer(item['pred'][base_idx]):
			base_idx -= 1
		previous_cot = item['cot'][base_idx]
		previous_pred = item['pred'][base_idx]

		# one iteration: critique the selected previous answer, not the initial one.
		# A copy is passed so item['generation'] / item['prediction'] stay untouched.
		# NOTE(review): assumes the prompt template reads the proposed answer from
		# the 'generation' key — confirm against load_prompt().
		critic_input = {**item, 'generation': previous_cot, 'prediction': previous_pred}
		revised_cot = await critic_iter(critic_input)
		revised_pred = revised_cot.split("FINAL ANSWER:")[-1].strip()

		# is corrected: a non-empty revision identical to its input changed nothing
		corrected = not (revised_cot and revised_cot == previous_cot)

		item['cot'].append(revised_cot)
		item['pred'].append(revised_pred)

		# if no correction for twice, break
		if not corrected and not previous_corrected:
			print("Stop.")
			break
		previous_corrected = corrected

	return item

In [None]:
async def inference() -> None:
	"""Critique the dataset in batches, checkpointing to ``save_results_path``.

	The run is resumable: lines already present in the results file are counted
	and that many leading dataset rows are skipped. Each finished batch is
	appended as JSON lines. Appending, instead of rewriting the whole file after
	every batch, keeps earlier results intact if the run is interrupted mid-write
	and avoids re-serialising all previous results each time.
	"""
	num_done = 0
	if os.path.exists(save_results_path):
		with open(save_results_path, 'r') as file:
			for line in file:
				json.loads(line)  # still fail fast on a corrupt checkpoint line
				num_done += 1
	else:
		folder_path = os.path.dirname(save_results_path)
		if folder_path:  # a bare filename has no directory to create
			os.makedirs(folder_path, exist_ok=True)

	for idx in tqdm(range(num_done, dataset.num_rows, batch_size)):
		batch = dataset.select(range(idx, min(idx + batch_size, dataset.num_rows)))
		# All items of the batch are critiqued concurrently.
		batch_results = await tqdm_asyncio.gather(*(critic(item) for item in batch))
		with open(save_results_path, 'a') as file:
			for result in batch_results:
				file.write(json.dumps(result) + "\n")