In [1]:
import re

from dotenv import load_dotenv, find_dotenv
from langchain_core.outputs import LLMResult
from langchain_openai import ChatOpenAI

from agent.utils.loader import load_prompt, load_processed_data
from agent.utils.tools import GoogleSearchTool

_ = load_dotenv(find_dotenv())

In [2]:
# --- Experiment configuration -------------------------------------------------
# These names are module-level state read by the cells below.
dataset_name = 'trivia_qa'
mode = "react"
model = "gpt-4o-mini-2024-07-18"
# Values > 0 restrict the run to the first `num_samples` rows (see the select() below);
# any other value keeps the full dataset.
num_samples = 1000
# Sampling parameters forwarded to ChatOpenAI below.
top_p = 0.95
temperature = 0
seed = 42
# Step size used by react_inference when it walks the dataset in batches.
batch_size = 100
# Passed to ChatOpenAI as `n`.
n = 1

# Input data and output locations; the output path encodes the model, dataset, mode
# and the sampling settings so that different configurations do not overwrite each other.
processed_data_path = f"../../data/processed_data/{dataset_name}.jsonl"
save_results_path = f"../../output/inference/{model}/{dataset_name}/{mode}/num_samples_{num_samples}_top_p_{top_p}_temperature_{temperature}_seed_{seed}.jsonl"
prompt = load_prompt(dataset_name=dataset_name, mode=mode)
dataset = load_processed_data(dataset_name=dataset_name, file_path=processed_data_path)
if num_samples > 0:
	# NOTE(review): select(range(num_samples)) presumably fails if the dataset has
	# fewer than `num_samples` rows — confirm against the loader's dataset type.
	dataset = dataset.select(range(num_samples))
# NOTE(review): requests are sent to the hard-coded `base_url` below rather than the
# client's default endpoint — confirm this is intended before sharing the notebook.
llm = ChatOpenAI(model=model, top_p=top_p, n=n, temperature=temperature, base_url="https://api.chsdw.top/v1", seed=seed)
google_search = GoogleSearchTool()
# Sanity checks: show the loaded prompt template and one dataset row.
prompt.pretty_print()
print(dataset[2])


Your response should follow the previous format and style.


Question: Which innovation for the car was developed by Prince Henry of Prussia in 1911?
Thought 1: Let's search the question in google
Action 1: Search[Which innovation for the car was developed by Prince Henry of Prussia in 1911? site: wikipedia.org]
Observation 1: [Prince Henry of Prussia (1862–1929) - Wikipedia] Henry was interested in motor cars as well and supposedly invented a windshield wiper and, according to other sources, the car horn.
Thought 2: The evidence suggests that Prince Henry of Prussia invented a windshield wiper and the car horn.
Action 2: Search[When did Prince Henry of Prussia invented a windshield wiper and the car horn?]
Observation 2: [110 years ago: windscreen wiper patent for Prince Henry of Prussia] Quite apart from a member of the German aristocracy: it was Prince Henry of Prussia (1862-1929) who was granted the first German patent about the windscreen wiper on 24. March 1908.
Thought 3: Accor

In [3]:
# Smoke test for the search tool: run a single query and print the raw result
# (the captured output below is a list of dicts with 'title', 'link' and 'snippet').
print(await google_search.arun(
	"science fantasy young adult series with companion books about enslaved worlds alien species"))

[{'title': 'Sci-fi books with non-human protagonist : r/suggestmeabook', 'link': 'https://www.reddit.com/r/suggestmeabook/comments/pk9nvx/scifi_books_with_nonhuman_protagonist/', 'snippet': 'Sep 8, 2021 ... Preferably no YA. I\'ve already read the Ancillary trilogy, Murderbot diaries, and most of Adrian Tschaikovsky\'s stuff ("Dogs of War" was\xa0...'}, {'title': 'Animorphs - Wikipedia', 'link': 'https://en.wikipedia.org/wiki/Animorphs', 'snippet': 'Animorphs is a science fantasy series of youth books written by Katherine Applegate and her husband Michael Grant, writing together under the name K. A.\xa0...'}, {'title': 'Fantasy Books with Animal-People : r/Fantasy', 'link': 'https://www.reddit.com/r/Fantasy/comments/14ny56l/fantasy_books_with_animalpeople/', 'snippet': "Jul 1, 2023 ... Adrian Tchaikovsky's Shadows of the Apt series has various humanoid species with insect (and spider) characteristics."}]


In [4]:
MAX_ITERATION = 5
# How many times a single LLM call is attempted before the episode is marked "ERROR".
MAX_LLM_RETRIES = 3


async def iteration(item):
	"""Run one ReAct (Thought / Action / Observation) episode for a single item.

	The running transcript starts with the question and is extended with each
	model response and each search observation. The loop runs for at most
	``MAX_ITERATION`` model calls.

	Args:
		item: Mapping with at least a ``"question"`` key. All of its entries are
			copied into the returned dict.

	Returns:
		A dict with the entries of ``item`` plus:
		- ``"generation"``: the accumulated transcript.
		- ``"prediction"``: the text captured from ``Finish[...]``; ``"None"`` if
		  a response contained "Finish" without a parsable ``Finish[...]`` or the
		  iteration budget ran out; ``"ERROR"`` if one model call failed
		  ``MAX_LLM_RETRIES`` times in a row.
	"""
	# Single quotes inside the replacement field keep this valid before Python 3.12.
	completion = f"{item['question']}\n"
	for i in range(MAX_ITERATION):
		for _ in range(MAX_LLM_RETRIES):
			try:
				response: LLMResult = await llm.agenerate(
					messages=[prompt.invoke(input={"question": completion})],
					stop=["Observation", "---"]
				)
				response_content: str = response.generations[0][0].message.content
				break
			except Exception:
				# Only ordinary errors are retried; task cancellation and
				# KeyboardInterrupt (BaseException subclasses) propagate.
				continue
		else:
			# Every attempt failed: give up on this item but keep the batch alive.
			return {**item, "generation": completion, "prediction": "ERROR"}

		# If the generated content contains "Finish", stop iterating.
		if "Finish" in response_content:
			completion += response_content
			matches = re.findall(r"Finish\[(.*)]", response_content, re.DOTALL)
			prediction = matches[0] if matches else "None"
			return {**item, "generation": completion, "prediction": prediction}

		# If the generated content contains "Search", run the search tool.
		elif "Search" in response_content:
			completion += response_content
			matches = re.findall(r"Search\[(.*)]", response_content, re.DOTALL)
			# Fallback observation, used when no query could be parsed or the
			# search returned no hits (indexing an empty result would raise).
			observation = "[None] None"
			if matches:
				hits = await google_search.arun(matches[0])
				if hits:
					top_hit = hits[0]
					observation = f"[{top_hit.get('title', '')}] {top_hit.get('snippet', '')}"
			completion += f"Observation {i + 1}: {observation}\n"
			if i == MAX_ITERATION - 2:
				# Only one model call is left: steer it towards answering. The thought
				# that follows Observation i+1 is Thought i+2 in the ReAct numbering.
				completion += f"Thought {i + 2}: Now I know the answer, and I will provide the answer in the following Action.\n"
			continue
		else:
			# Neither a Finish nor a Search action was produced; append the
			# response and nudge the model to answer on the next call.
			completion += response_content
			completion += f"Thought {i + 1}: Now I know the answer, and I will provide the answer in the following Action.\n"
	return {**item, "generation": completion, "prediction": "None"}



In [5]:
import json
import os
from tqdm.asyncio import tqdm, tqdm_asyncio
import nest_asyncio

nest_asyncio.apply()


async def react_inference() -> None:
	"""Run `iteration` over `dataset` in batches, checkpointing to a JSONL file.

	If `save_results_path` already exists, the number of records in it is used
	as the resume offset, so rows that were already processed are skipped.
	Otherwise its parent directory is created. After each batch of `batch_size`
	items (run concurrently), the new results are appended to the file, one JSON
	object per line.

	Returns:
		None. The results are only persisted to `save_results_path`.
	"""
	num_done = 0
	if os.path.exists(save_results_path):
		with open(save_results_path, 'r', encoding="utf-8") as file:
			# Parse every record so a corrupt checkpoint is detected up front;
			# blank lines (e.g. a trailing newline) are ignored.
			num_done = sum(1 for line in file if line.strip() and json.loads(line) is not None)
	else:
		folder_path = os.path.dirname(save_results_path)
		if folder_path:  # os.makedirs('') raises for a bare file name
			os.makedirs(folder_path, exist_ok=True)

	for idx in tqdm(range(num_done, dataset.num_rows, batch_size)):
		batch = dataset.select(range(idx, min(idx + batch_size, dataset.num_rows)))
		batch_results = await tqdm_asyncio.gather(*(iteration(item) for item in batch))
		# Append only the new batch: earlier results are never truncated, so an
		# interruption during the write cannot wipe previously saved records.
		with open(save_results_path, 'a', encoding="utf-8") as file:
			for result in batch_results:
				file.write(json.dumps(result) + "\n")

In [6]:
# Start the batched ReAct run (top-level await); results are written to save_results_path.
await react_inference()

  0%|          | 0/9 [00:00<?, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:07<11:51,  7.18s/it][A
  2%|▏         | 2/100 [00:07<05:13,  3.19s/it][A
  3%|▎         | 3/100 [00:07<03:02,  1.88s/it][A
  4%|▍         | 4/100 [00:08<02:05,  1.31s/it][A
  5%|▌         | 5/100 [00:08<01:28,  1.07it/s][A
  6%|▌         | 6/100 [00:08<01:01,  1.54it/s][A
 10%|█         | 10/100 [00:08<00:21,  4.18it/s][A
 12%|█▏        | 12/100 [00:09<00:23,  3.81it/s][A
 14%|█▍        | 14/100 [00:09<00:18,  4.55it/s][A
 16%|█▌        | 16/100 [00:09<00:15,  5.35it/s][A
 17%|█▋        | 17/100 [00:10<00:14,  5.56it/s][A
 18%|█▊        | 18/100 [00:10<00:17,  4.66it/s][A
 19%|█▉        | 19/100 [00:10<00:15,  5.17it/s][A
 20%|██        | 20/100 [00:11<00:24,  3.23it/s][A
 21%|██        | 21/100 [00:11<00:24,  3.19it/s][A
 23%|██▎       | 23/100 [00:12<00:22,  3.39it/s][A
 24%|██▍       | 24/100 [00:12<00:26,  2.86it/s][A
 26%|██▌       | 26/100 [00:12<00:20,  3.