In [1]:
import json
import os
import re

from dotenv import load_dotenv, find_dotenv
from langchain_openai import ChatOpenAI

# Load variables from the nearest .env file (located by find_dotenv) into the
# process environment; the return value is discarded.
_ = load_dotenv(find_dotenv())
# LangSmith tracing is optional and switched off here ("false"); the project
# name and endpoint are still set so that tracing can be enabled by flipping
# the first value. Outbound HTTP(S) traffic is pointed at a local proxy.
os.environ.update(
	{
		"LANGCHAIN_TRACING_V2": "false",
		"LANGCHAIN_PROJECT": "self-correct",
		"LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
		"HTTP_PROXY": "http://127.0.0.1:7890",
		"HTTPS_PROXY": "http://127.0.0.1:7890",
	}
)

In [2]:
# Experiment configuration. These values select the model and sampling settings
# and are also embedded in the input/output file names below, so changing one
# of them points the notebook at a different set of files.
meta_info = {
	"dataset_name": 'hotpot_qa',
	"mode": "self-improve",
	"base_mode": "cot",
	"model": "gpt-3.5-turbo",
	"num_samples": 1000,
	"top_p": 0.95,
	"temperature": 0,
	"seed": 42,
	"batch_size": 100
}
# Fail fast on a configuration this notebook is not written for.
assert meta_info["mode"] == "self-improve"
assert meta_info["dataset_name"] in ["hotpot_qa", "trivia_qa", "ambig_qa"], "Invalid dataset name"

# Raw string: a plain "D:\Projects\self-improve" contains the invalid escape
# sequences "\P" and "\s", which newer Python versions warn about. The value of
# the string is unchanged.
ROOT_DIR = r"D:\Projects\self-improve"
# Input file: the JSONL written by the inference stage for this configuration.
processed_data_path = os.path.join(ROOT_DIR, "output", "inference", meta_info["model"], meta_info["dataset_name"], meta_info["mode"], f"with_question_before_fusion_{meta_info['base_mode']}_num_samples_{meta_info['num_samples']}_top_p_{meta_info['top_p']}_temperature_{meta_info['temperature']}_seed_{meta_info['seed']}.jsonl")
print("Loading processed data from:", processed_data_path)
# Load the processed dataset: one JSON object per line (JSONL).
dataset = []
try:
    with open(processed_data_path, 'r', encoding='utf-8') as file:
        for line in file:
            # Tolerate blank lines (e.g. a trailing newline at the end of the
            # file) instead of failing with a JSONDecodeError on them.
            if not line.strip():
                continue
            dataset.append(json.loads(line))
    # An empty file would otherwise surface as an opaque IndexError on
    # dataset[0]; report it explicitly instead.
    if not dataset:
        raise ValueError(f"No records found in '{processed_data_path}'")
    print("Data loaded successfully from:", processed_data_path)
    print("Sample data:", dataset[0])

except json.JSONDecodeError as e:
    print(f"Error decoding JSON from '{processed_data_path}': {e}")
    raise
except Exception as e:
    # Missing file, permission problems, empty dataset, ...: report and abort
    # so that later cells never run against a partially loaded dataset.
    print(f"An unexpected error occurred while loading the data: {e}")
    raise

print("Setting the save_results_path")
# Results go under <ROOT_DIR>/output/ablation/<model>/<dataset>/<mode>/, in a
# file whose name encodes the sampling settings from meta_info.
results_dir = os.path.join(ROOT_DIR, "output", "ablation", meta_info["model"], meta_info["dataset_name"], meta_info["mode"])
results_file_name = f"without_fusion_{meta_info['base_mode']}_num_samples_{meta_info['num_samples']}_top_p_{meta_info['top_p']}_temperature_{meta_info['temperature']}_seed_{meta_info['seed']}.jsonl"
save_results_path = os.path.join(results_dir, results_file_name)
print("Results will be saved to:", save_results_path)


# Chat model shared by the prompt chains in this notebook. Model name and
# sampling settings are taken from meta_info; requests are sent to a custom
# API base URL rather than the library default.
chat_model_settings = {
	"model_name": meta_info["model"],
	"top_p": meta_info["top_p"],
	"temperature": meta_info["temperature"],
	"seed": meta_info["seed"],
	"openai_api_base": "https://api.chsdw.top/v1",
}
model = ChatOpenAI(**chat_model_settings)



Loading processed data from: D:\Projects\self-improve\output\inference\gpt-3.5-turbo\hotpot_qa\self-improve\with_question_before_fusion_cot_num_samples_1000_top_p_0.95_temperature_0_seed_42.jsonl
Data loaded successfully from: D:\Projects\self-improve\output\inference\gpt-3.5-turbo\hotpot_qa\self-improve\with_question_before_fusion_cot_num_samples_1000_top_p_0.95_temperature_0_seed_42.jsonl
Sample data: {'context': '', 'question': 'Were Scott Derrickson and Ed Wood of the same nationality?', 'answer': ['yes'], 'guidance': 'Step 1: Determine the nationalities of Scott Derrickson and Ed Wood.\n- Research the nationality of Scott Derrickson.\n- Research the nationality of Ed Wood.\n\nError-prone points:\n- Inaccurate or conflicting information about their nationalities in sources.\n- Confusion with individuals who have similar names.\n- Lack of clear documentation of their nationalities.', 'messages': [{'content': 'You are a reactive agent. Given a question or problem, your job is to sele

In [None]:
# from langgraph.graph import add_messages
# from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
# from typing import Sequence
# from typing_extensions import TypedDict, Annotated
# from langgraph.managed.is_last_step import RemainingSteps
#
#
# class State(TypedDict):
# 	question: str
# 	guidance: str
# 	remaining_steps: RemainingSteps
# 	messages: Annotated[Sequence[BaseMessage], add_messages]
# 	initial_answer: str
# 	prediction: str
#
# from langchain_core.prompts import ChatPromptTemplate
#
# guidance_prompt = ChatPromptTemplate.from_messages(
# 	[
# 		(
# 			"system",
# 			"You are a question planner and error prone points identifier. Given a question or problem, your job is to come up with a step by step plan, and you should also identify the most error-prone points for each step, following them closely behind each step. The plan and the error prone points will then be used to guide the selection of subsequent tools and the corresponding tool inputs. The tool results should always be considered as reliable. Do not add any superfluous steps. Make sure that each step has all the information needed - do not skip steps. You should focus on the logic of how to solve the problem, rather than actually solving it."
# 		),
# 		(
# 			"user",
# 			"Question: {question}"
# 		)
# 	])
# guidance_generator = guidance_prompt | model
#
#
# async def guidance_node(state: State) -> State:
# 	assert state["question"] is not None, "Question is required"
# 	question:str = state["question"]
# 	guidance:AIMessage = await guidance_generator.ainvoke(input={"question": question})
# 	state["guidance"] = guidance.content
# 	return state
#
# from langchain_core.messages import AIMessage
# from langchain_community.utilities.wikidata import WikidataAPIWrapper
# from langgraph.prebuilt import ToolNode
# from agent.utils.tools import GoogleSearchTool, GoogleKnowledgeGraphTool, WikidataTool, WikipediaTool, python_interpreter
#
# google_search = GoogleSearchTool()
# google_knowledge_graph = GoogleKnowledgeGraphTool()
# wikidata = WikidataTool(api_wrapper=WikidataAPIWrapper())
# wikipedia = WikipediaTool()
# tools = [google_search, google_knowledge_graph, wikipedia, wikidata, python_interpreter]
#
# model_with_tools = model.bind_tools(tools)
#
# from typing import Literal
#
# critique_prompt = ChatPromptTemplate.from_messages([
# 	(
# 		"system",
# 		"You are a reactive agent. Given a question or problem, your job is to select the appropriate tools to answer the question or solve the problem. You should consider the guidance provided by the question planner and error prone points identifier, and the tool results are reliable. If you find the answer from the tool results, you should provide the answer."
# 	),
# 	(
# 		"user",
# 		"Question: {question}"
# 		"Guidance: {guidance}"
# 	)
# ])
#
# async def critique_node(state: State):
# 	assert state["question"] is not None, "Question is required"
# 	assert state["guidance"] is not None, "Guidance is required"
# 	question:str = state["question"]
# 	guidance:str = state["guidance"]
# 	messages:list[BaseMessage] = []
# 	if len(state["messages"]) == 0:
# 		messages = critique_prompt.invoke(input={"question": question, "guidance": guidance}).to_messages()
# 		critique:AIMessage = await model_with_tools.ainvoke(input=messages)
# 		messages.append(critique)
# 	else:
# 		critique:AIMessage = await model_with_tools.ainvoke(input=state["messages"])
# 		messages.append(critique)
# 	return {"messages": messages}
#
# # Define our tool node
# tool_node = ToolNode(tools)
# # Define our tool node
#
#
# fusion_prompt = ChatPromptTemplate.from_messages(
# 	[
# 		(
# 			"placeholder",
# 			"{messages}"
# 		),
# 		(
# 			"user",
# 			"Question: {question}"
# 			"Now based on the previous information, revise your answer. Use the XML tag <answer></answer> to indicate the final answer part. Do not provide multiple answers in the final answer to increase your chances of getting the answer right. You need to give the answer you think is the most appropriate."
# 			"Do not include any explanations, context, or additional information. Just focus on delivering the exact answer as concisely as possible!!! "
# 			"There is no need to answer the question in the form of a complete sentence, just provide the answer in the form of a noun, time, entity, single number, yes or no, etc."
# 		)
# 	])
# fusion_generator = fusion_prompt | model
#
# async def fusion_node(state: State) -> State:
# 	assert state["question"] is not None, "Question is required"
# 	assert state["guidance"] is not None, "Guidance is required"
# 	critique_messages:Sequence[BaseMessage] = state["messages"][1:]
# 	response:AIMessage = await fusion_generator.ainvoke(input={"messages": critique_messages, "question": state["question"]})
# 	answer_matches = re.findall(r"<answer>(.*?)</answer>", response.content, re.DOTALL)
# 	if answer_matches:
# 		state["prediction"] = answer_matches[0]
# 	else:
# 		state["prediction"] = "None"
#
# 	return state
#
# # Define the conditional edge that determines whether to continue or not
# def should_continue(state: State) -> Literal["fuse", "tools"]:
# 	messages = state["messages"]
# 	last_message = messages[-1]
#
# 	# If the last message requested tool calls, run the tools next
# 	if last_message.tool_calls:
# 		return "tools"
# 	# Otherwise there is nothing left to call, so move on to fusion
# 	else:
# 		return "fuse"
#
# def tools_router(state: State) -> Literal["fuse",  "critique"]:
# 	if state["remaining_steps"] <= 3:
# 		return "fuse"
# 	else:
# 		return "critique"
#
#
# from langgraph.graph import StateGraph
#
# workflow = StateGraph(State)
# workflow.add_node("guide", guidance_node)
# workflow.add_node("critique", critique_node)
# workflow.add_node("tools", tool_node)
# workflow.add_node("fuse", fusion_node)
#
# workflow.set_entry_point("guide")
# workflow.add_edge("guide", "critique")
# workflow.add_conditional_edges("tools", tools_router)
# workflow.add_conditional_edges("critique", should_continue)
# workflow.add_edge("fuse", "__end__")
#
# app = workflow.compile()
#
# from IPython.display import Image, display
#
# try:
#     display(Image(app.get_graph().draw_mermaid_png()))
# except Exception:
#     pass

In [None]:
# inputs = {**dataset[5], "messages": []}
# async for event in app.astream(inputs):
#     for k, v in event.items():
#         if k != "__end__":
#             print(v)

In [3]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import ToolMessage, BaseMessage, HumanMessage, AIMessage


# from langchain_core.prompts import ChatPromptTemplate
# from typing import Union
# from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
#
#
# def parse_pretty_repr(pretty_str: str) -> Union[HumanMessage, AIMessage, ToolMessage]:
# 	"""
#     解析 pretty_repr 字符串并返回对应的 LangChain 消息对象。
#     """
# 	# 去掉多余的分隔符和空行
# 	lines = [line.strip() for line in pretty_str.split("\n") if line.strip()]
#
# 	msg_type = lines[0].strip("=").strip()
#
# 	content = "\n".join(lines[1:])  # 消息的实际内容
#
# 	if "Human Message" in msg_type:
# 		return HumanMessage(content=content)
# 	elif "Ai Message" in msg_type:
# 		if "Tool Calls" in content:
# 			tool_calls = re.findall(r"(\w+)\s+\((.*)\).*Args:\s+(.+)", content, re.DOTALL)
# 			list_tool_calls = []
# 			for tool_call in tool_calls:
# 				name, call_id = tool_call[:2]
# 				args = re.findall(r"(\w+): (.+)", tool_call[2], re.DOTALL)
# 				args_dict = {arg[0]: arg[1] for arg in args}
# 				list_tool_calls.append({"name": name, "args": args_dict, "id": call_id})
#
# 			return AIMessage(content="", tool_calls=list_tool_calls)
# 		else:
# 			return AIMessage(content=content)
# 	elif "Tool Message" in msg_type:
# 		name = re.findall(r"Name: (.*)\n", content, re.DOTALL)[0]
# 		content = content.split(f"{name}\n")[-1]
# 		return ToolMessage(content=content, name=name, tool_call_id="123")
#
#
# async def construct_messages(item: dict) -> list:
# 	messages = []
# 	for message in item["messages"][1:]:
# 		messages.append(parse_pretty_repr(message))
# 	return messages
#
def construct_messages(messages_dict: list[dict]) -> list[BaseMessage]:
	"""Rebuild LangChain message objects from their serialised dict form.

	Each dict is dispatched on its ``"type"`` key: ``"human"``, ``"ai"`` and
	``"tool"`` become ``HumanMessage``, ``AIMessage`` and ``ToolMessage``
	respectively, with the whole dict passed as keyword arguments. Dicts of any
	other type (for example ``"system"``) are dropped without error.

	Args:
		messages_dict: Serialised messages, in conversation order.

	Returns:
		The reconstructed messages, in the same order as the input.

	Raises:
		KeyError: If a dict has no ``"type"`` key.
	"""
	rebuilt = []
	for payload in messages_dict:
		kind = payload["type"]
		if kind == "human":
			message_cls = HumanMessage
		elif kind == "ai":
			message_cls = AIMessage
		elif kind == "tool":
			message_cls = ToolMessage
		else:
			# Unsupported type (system messages included): skip it.
			continue
		rebuilt.append(message_cls(**payload))

	return rebuilt

# Final-answer prompt: the stored conversation is replayed through the
# "messages" placeholder, followed by one user turn that restates the question
# and asks for a concise answer wrapped in <answer></answer> tags.
#
# NOTE(review): the adjacent string literals below are concatenated with no
# separator, so the rendered text reads "Question: <question>Now based on ..."
# and "...most appropriate.Do not include ...". The text is deliberately left
# byte-identical so results stay comparable with earlier runs that used it;
# confirm before changing.
prediction_messages = [
	("placeholder", "{messages}"),
	(
		"user",
		"Question: {question}"
		"Now based on the previous information, revise your answer. Use the XML tag <answer></answer> to indicate the final answer part. Do not provide multiple answers in the final answer to increase your chances of getting the answer right. You need to give the answer you think is the most appropriate."
		"Do not include any explanations, context, or additional information. Just focus on delivering the exact answer as concisely as possible!!! "
		"There is no need to answer the question in the form of a complete sentence, just provide the answer in the form of a noun, time, entity, single number, yes or no, etc."
	),
]
prediction_prompt = ChatPromptTemplate.from_messages(prediction_messages)
# Chain: render the prompt, then send it to the chat model.
prediction_generator = prediction_prompt | model


In [7]:
from langchain_core.messages import message_to_dict
from tqdm.asyncio import tqdm_asyncio
import os
import json
import logging
import nest_asyncio

# Configure logging: records go both to "inference.log" and to the console.
# The file handler is pinned to UTF-8 because error records embed whole dataset
# items (read from a UTF-8 file), which the platform default encoding may not
# be able to represent.
log_file_handler = logging.FileHandler("inference.log", encoding="utf-8")
logging.basicConfig(
	level=logging.ERROR,  # only ERROR and above are emitted; info()/warning() calls are filtered out
	format='%(asctime)s - %(levelname)s - %(message)s',  # timestamp - level - message
	handlers=[
		log_file_handler,  # write log records to the file
		logging.StreamHandler()  # and echo them to the console
	]
)

logger = logging.getLogger("InferenceLogger")

# NOTE(review): presumably needed so asyncio code can run inside the notebook's
# already-running event loop -- confirm against the nest_asyncio documentation.
nest_asyncio.apply()
# Accumulates one result dict per dataset item; filled by self_improve_inference().
results = []
# Number of items processed concurrently per batch. Read from the shared
# experiment config so that editing meta_info["batch_size"] actually takes
# effect; falls back to the previous hard-coded value if the key is absent.
batch_size = meta_info.get("batch_size", 100)

async def process(item):
	"""Produce a final prediction for one dataset item.

	Items without stored messages are returned immediately with the sentinel
	prediction ``"None"``. Otherwise the stored messages are rebuilt, replayed
	through ``prediction_generator`` together with the question, and the text
	inside the first ``<answer>...</answer>`` pair of the reply becomes the
	prediction (``"None"`` when no such pair is present).

	Args:
		item: Dataset record; ``"messages"`` and ``"question"`` are read.

	Returns:
		A copy of ``item`` extended with ``"prediction"`` and, when the model
		was called, ``"response"`` (the serialised model reply).

	Raises:
		Exception: Any error raised while processing is logged and re-raised.
	"""
	try:
		outcome = dict(item)
		stored_messages = item.get("messages")
		if not stored_messages:
			# Nothing to replay: skip the model call entirely.
			outcome["prediction"] = "None"
			return outcome

		history = construct_messages(item["messages"])
		response: AIMessage = await prediction_generator.ainvoke(input={"messages": history, "question": item["question"]})
		# DOTALL lets the tagged answer span several lines; only the first match is used.
		tagged_answers = re.findall(r"<answer>(.*?)</answer>", response.content, re.DOTALL)
		outcome["response"] = message_to_dict(response)
		outcome["prediction"] = tagged_answers[0] if tagged_answers else "None"
		return outcome
	except Exception as e:
		logger.error(f"Error processing item: {item}. Error: {e}")
		raise

# async def process(item):
# 	try:
# 		item["messages"] = construct_messages(item["messages"])
# 		del item["prediction"]
# 		state = await app.ainvoke({**item}, config={"recursion_limit": 18})
# 		state["messages"] = messages_to_dict(state["messages"])
# 		logger.info(f"Processed item: {item}")
# 		return {**item, **state}
# 	except Exception as e:
# 		logger.error(f"Error processing item: {item}. Error: {e}")
# 		return {**item, "prediction": "None"}



def _write_results(path: str, records: list) -> None:
	"""Overwrite ``path`` with ``records``, one JSON object per line."""
	with open(path, 'w', encoding='utf-8') as file:
		for record in records:
			file.write(json.dumps(record) + "\n")


async def self_improve_inference() -> None:
	"""Run (or resume) inference over ``dataset`` and checkpoint to ``save_results_path``.

	Steps:
	1. If a results file already exists, load it into the module-level
	   ``results`` list; otherwise create the output directory.
	2. Re-run every loaded entry whose prediction is the sentinel ``"None"``
	   and write the refreshed results back to disk.
	3. Process the not-yet-covered tail of ``dataset`` in batches of
	   ``batch_size``, rewriting the results file after each batch.

	Saved results are matched to dataset items by position: line ``i`` of the
	results file is taken to belong to ``dataset[i]``.

	Raises:
		Exception: Whatever ``process`` raises for an item propagates out of
			the enclosing gather; results saved by earlier batches stay on disk.
	"""
	retry_indices = []  # positions in `results` whose prediction is "None"

	# Start from a clean list: `results` is module-level, so without this a
	# second call (re-running the cell) would append the file contents on top
	# of what is already in memory, duplicating and misaligning entries.
	results.clear()

	# Load existing results, or prepare the output directory.
	if os.path.exists(save_results_path):
		logger.info(f"Loading existing results from {save_results_path}")
		with open(save_results_path, 'r', encoding='utf-8') as file:
			for line in file:
				if not line.strip():
					continue
				result = json.loads(line)
				# "None" marks an entry for which no answer was extracted.
				if "None" == result.get("prediction"):
					retry_indices.append(len(results))
				results.append(result)
	else:
		folder_path = os.path.dirname(save_results_path)
		os.makedirs(folder_path, exist_ok=True)
		logger.info(f"Created directory for results: {folder_path}")

	# Re-run inference for the entries that previously had no prediction.
	if retry_indices:
		logger.warning(f"Found {len(retry_indices)} ERROR entries. Retrying inference...")
		retry_items = [dataset[idx] for idx in retry_indices]
		new_results = await tqdm_asyncio.gather(*(process(item) for item in retry_items))
		# Put each refreshed result back at its original position.
		for i, new_result in zip(retry_indices, new_results):
			results[i] = new_result
		# Persist right away: when the whole dataset is already covered the
		# batch loop below never runs, and the retried results would be lost.
		_write_results(save_results_path, results)
		logger.info(f"Saved results to {save_results_path}")

	# Continue with the part of the dataset that has no result yet.
	for idx in range(len(results), len(dataset), batch_size):
		batch = dataset[idx: min(idx + batch_size, len(dataset))]
		batch_results = await tqdm_asyncio.gather(*(process(item) for item in batch))
		results.extend(batch_results)

		logger.info(f"Processed batch starting at index {idx}")

		# Checkpoint after every batch so an interrupted run can resume.
		_write_results(save_results_path, results)
		logger.info(f"Saved results to {save_results_path}")


In [8]:
# Run the inference coroutine; top-level `await` is used directly in this notebook cell.
await self_improve_inference()


  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:01<00:04,  1.16s/it][A
 80%|████████  | 4/5 [00:01<00:00,  2.89it/s][A
100%|██████████| 5/5 [00:01<00:00,  2.70it/s][A

  0%|          | 0/100 [00:00<?, ?it/s][A
  3%|▎         | 3/100 [00:02<01:34,  1.03it/s][A
 10%|█         | 10/100 [00:03<00:21,  4.23it/s][A
 14%|█▍        | 14/100 [00:03<00:13,  6.37it/s][A
 19%|█▉        | 19/100 [00:03<00:09,  9.00it/s][A
 26%|██▌       | 26/100 [00:03<00:05, 14.70it/s][A
 31%|███       | 31/100 [00:03<00:03, 17.90it/s][A
 44%|████▍     | 44/100 [00:03<00:01, 32.67it/s][A
 52%|█████▏    | 52/100 [00:03<00:01, 40.07it/s][A
 59%|█████▉    | 59/100 [00:03<00:00, 43.89it/s][A
 66%|██████▌   | 66/100 [00:04<00:01, 21.86it/s][A
 71%|███████   | 71/100 [00:04<00:01, 21.40it/s][A
 77%|███████▋  | 77/100 [00:05<00:00, 25.74it/s][A
 82%|████████▏ | 82/100 [00:06<00:01, 11.67it/s][A
 95%|█████████▌| 95/100 [00:06<00:00, 20.78it/s][A
 95%|█████████▌| 95/100 [00:17<00:00, 20