In [1]:
import re

from dotenv import load_dotenv, find_dotenv
from langchain_core.outputs import LLMResult
from langchain_openai import ChatOpenAI

from agent.utils.loader import load_prompt, load_processed_data
from agent.utils.tools import GoogleSearchTool

_ = load_dotenv(find_dotenv())

In [11]:
# --- Experiment configuration -------------------------------------------
# Everything a reader might tune lives here; the values are also baked into
# the output file name so that separate runs do not overwrite each other.
dataset_name = 'hotpot_qa'
mode = "react"
model = "gpt-4o-mini-2024-07-18"
num_samples = 1000  # a value <= 0 keeps the full dataset
top_p = 0.95
temperature = 0
seed = 42
batch_size = 100  # number of examples processed concurrently per batch
n = 1  # completions requested per LLM call

processed_data_path = f"../../data/processed_data/{dataset_name}.jsonl"
save_results_path = f"../../output/inference/{model}/{dataset_name}/{mode}/num_samples_{num_samples}_top_p_{top_p}_temperature_{temperature}_seed_{seed}.jsonl"

# --- Prompt and data ------------------------------------------------------
prompt = load_prompt(dataset_name=dataset_name, mode=mode)
dataset = load_processed_data(dataset_name=dataset_name, file_path=processed_data_path)
if num_samples > 0:
	dataset = dataset.select(range(num_samples))

# --- Model and tools ------------------------------------------------------
# Use the `model` setting rather than a second hard-coded copy of the name,
# so the LLM that is queried always matches the one recorded in
# `save_results_path`.
llm = ChatOpenAI(model_name=model, top_p=top_p, n=n, temperature=temperature, openai_api_base="https://api.chsdw.top/v1", seed=seed, max_retries=3)
google_search = GoogleSearchTool()

# Show the prompt template and one example record as a quick sanity check.
prompt.pretty_print()
print(dataset[2])

Failed to retrieve data from Google Search: Rate limit exceeded or persistent error after 3 retries.

Your response should follow the previous format and style.


Question: Serianna is a band of what genre that combines elements of heavy metal and hardcore punk?
Thought 1: Let's search the question in google
Action 1: Search[Serianna is a band of what genre that combines elements of heavy metal and hardcore punk? site: wikipedia.org]
Observation 1: [Metalcore - Wikipedia] Metalcore is a fusion music genre that combines elements of extreme metal and hardcore punk.
Thought 2: The evidence suggests that metalcore is a genre that combines elements of extreme metal and hardcore punk.
Action 2: Search[Serianna is a band of metalcore genre. site: wikipedia.org
Observation 2: [Serianna - Wikipedia] Serianna was a metalcore band from Madison, Wisconsin. The band formed in 2006...
Thought 3: The evidence suggests Serianna is a metalcore band.
Action 3: Finish[Metalcore]
---
Question: Which band 

In [3]:
print(await google_search.arun(
	"science fantasy young adult series with companion books about enslaved worlds alien species"))

[{'title': 'Sci-fi books with non-human protagonist : r/suggestmeabook', 'link': 'https://www.reddit.com/r/suggestmeabook/comments/pk9nvx/scifi_books_with_nonhuman_protagonist/', 'snippet': 'Sep 8, 2021 ... Preferably no YA. I\'ve already read the Ancillary trilogy, Murderbot diaries, and most of Adrian Tschaikovsky\'s stuff ("Dogs of War" was\xa0...'}, {'title': 'Book about lizard people and cat people - Science Fiction & Fantasy ...', 'link': 'https://scifi.stackexchange.com/questions/255516/book-about-lizard-people-and-cat-people', 'snippet': 'Oct 31, 2021 ... Show activity on this post. When I was young I read a book about a humanoid lizard race owning a humanoid feline race as slaves, and the cat\xa0...'}, {'title': 'Fantasy Books with Animal-People : r/Fantasy', 'link': 'https://www.reddit.com/r/Fantasy/comments/14ny56l/fantasy_books_with_animalpeople/', 'snippet': "Jul 1, 2023 ... Adrian Tchaikovsky's Shadows of the Apt series has various humanoid species with insect (and spider

In [4]:
MAX_ITERATION = 5


async def iteration(item):
	completion = f"{item['question']}\n"
	for i in range(MAX_ITERATION):
		response: LLMResult = await llm.agenerate(
			messages=[prompt.invoke(input={"question": completion})],
			stop=["Observation", "---"]
		)
		response_content: str = response.generations[0][0].message.content
				# return {**item, "generation": completion, "prediction": "ERROR"}
		# 如果生成的内容包含"Finish"，则停止迭代
		if "Finish" in response_content:
			completion += response_content
			matches = re.findall(r"Finish\[(.*)]", response_content, re.DOTALL)
			if matches:
				prediction = matches[0]
			else:
				prediction = "None"
			return {**item, "generation": completion, "prediction": prediction}

		# 如果生成的内容包含"Search"，则进行搜索
		elif "Search" in response_content:
			completion += response_content
			matches = re.findall(r"Search\[(.*)]", response_content, re.DOTALL)
			if matches:
				tool_input = matches[0]
				tool_result = await google_search.arun(tool_input)
				title = tool_result[0]['title']
				evidence = f"{tool_result[0].get('snippet', '')}"
				completion += f"Observation {i + 1}: [{title}] {evidence}\n"
			else:
				completion += f"Observation {i + 1}: [None] None\n"
			if i >= 3:
				completion += f"Thought {i + 1}: Now I know the answer, and I will provide the answer in the following Action.\n"
			continue
		else:
			completion += response_content
			completion += f"Thought {i + 1}: Now I know the answer, and I will provide the answer in the following Action.\n"
	return {**item, "generation": completion, "prediction": "None"}



In [9]:
import json
import os
from tqdm.asyncio import tqdm, tqdm_asyncio
import nest_asyncio

nest_asyncio.apply()


def _load_results(path):
	"""Read a JSONL checkpoint into a list of dicts, ignoring blank lines."""
	with open(path, 'r', encoding='utf-8') as file:
		return [json.loads(line) for line in file if line.strip()]


def _save_results(path, results):
	"""Rewrite the JSONL checkpoint via a temp file plus rename.

	Writing to a sibling temp file first means an interrupt mid-write cannot
	leave a truncated checkpoint that would break the next resume.
	"""
	tmp_path = f"{path}.tmp"
	with open(tmp_path, 'w', encoding='utf-8') as file:
		for result in results:
			file.write(json.dumps(result) + "\n")
	os.replace(tmp_path, path)


async def react_inference() -> None:
	"""Run (or resume) ReAct inference over ``dataset`` and checkpoint to disk.

	Steps:
	1. Load any results already stored at ``save_results_path``; otherwise
	   create the output folder.
	2. Re-run every stored example whose prediction is ``"None"``, in batches
	   of ``batch_size``, replacing the stored entry in place.
	3. Process the remaining dataset rows (those past ``len(results)``) in
	   batches of ``batch_size``.

	The checkpoint file is rewritten after every batch, so an interrupted run
	loses at most the batch that was in flight.
	"""
	if os.path.exists(save_results_path):
		results = _load_results(save_results_path)
	else:
		results = []
		folder_path = os.path.dirname(save_results_path)
		if folder_path:  # dirname is '' for a bare file name
			os.makedirs(folder_path, exist_ok=True)

	# Positions of stored examples that ended without an answer.
	error_index_list = [idx for idx, example in enumerate(results) if example['prediction'] == "None"]

	for start in range(0, len(error_index_list), batch_size):
		batch_indices = error_index_list[start: start + batch_size]
		new_results = await tqdm_asyncio.gather(*(iteration(results[i]) for i in batch_indices))
		# Write the retried results back over the original entries.
		for i, new_result in zip(batch_indices, new_results):
			results[i] = new_result
		_save_results(save_results_path, results)

	# Continue with the rows that have not been processed yet.
	for idx in tqdm(range(len(results), dataset.num_rows, batch_size)):
		batch = dataset.select(range(idx, min(idx + batch_size, dataset.num_rows)))
		results.extend(await tqdm_asyncio.gather(*(iteration(item) for item in batch)))
		_save_results(save_results_path, results)

In [12]:
await react_inference()



Failed to retrieve data from Google Search: Rate limit exceeded or persistent error after 3 retries.




CancelledError: 