In [1]:
from tqdm import tqdm
import json
from langchain_core.messages import HumanMessage, AIMessage
from langchain_openai import ChatOpenAI


import getpass
import os


def _set_if_undefined(var: str):
    """Ensure the environment variable ``var`` has a non-empty value.

    When the variable is missing or empty, the value is read interactively
    with ``getpass`` (so it is not echoed) and stored in ``os.environ``.
    A variable that already holds a non-empty value is left untouched.
    """
    current_value = os.environ.get(var)
    if current_value:
        return
    os.environ[var] = getpass.getpass(f"Please provide your {var}")

# Credentials: each key is only prompted for when it is not already set in the
# environment, so pre-configured shells are not interrupted.
_set_if_undefined("LANGCHAIN_API_KEY")
# ChatOpenAI (constructed further down) reads its key from OPENAI_API_KEY;
# ask for it up front instead of failing when the model client is built.
_set_if_undefined("OPENAI_API_KEY")

# Optional, add tracing in LangSmith
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "critic no tools"


In [2]:
from email import message
from typing import Literal
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import add_messages
from typing import Annotated
from typing_extensions import TypedDict
from operator import add
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, merge_message_runs
# Intended cap on critique/revise rounds per question.
MAX_ITERATIONS = 4

# Benchmark that selects the revise prompt and model settings below.
dataset_name = "hotpot_qa"

class State(TypedDict):
    """Per-question state passed between the ``criticize`` and ``react`` nodes."""

    # Question text; `react` reads it to restate the question next to the revise prompt.
    input: str
    # Conversation history. Annotated with `operator.add`; per LangGraph's reducer
    # convention the lists returned by nodes are concatenated onto this one
    # rather than replacing it.
    messages: Annotated[list, add]
    # Answers extracted by `react`; annotated with the same `operator.add` reducer,
    # so each successful revision appends one entry.
    predictions: Annotated[list[str], add]
    # Round counter; `criticize` returns it incremented by one on every call.
    iteration: int

# Instruction sent after the conversation so far, asking the model to audit
# its own previous answer. The three sentences are joined with single spaces.
critic_prompt = HumanMessage(
    content=" ".join(
        (
            "Inspect the previous messages and identify any potential issues or errors.",
            "Check the Plausibility、 Truthfulness、 Correctness of your previous answer.",
            "Your response should be short、 direct and concise.",
        )
    )
)

# Revision instructions shared by the datasets whose answer is a short span.
# The "So the FINAL ANSWER is:" line is the format `react` parses afterwards.
_SHORT_ANSWER_REVISE_TEXT = (
    "Based on the previous critique, revise your answer. Remember your FINAL ANSWER should be clear and concise.(a single number or phrases, not a sentence!) "
    "Your response should follow the original format. For following example:\n\n"
    "Proposed Answer: Let's think step by step. ...\n"
    "So the FINAL ANSWER is: <FINAL ANSWER>\n"
)

# Endpoint and model used for every dataset; only the temperature differs.
_LLM_BASE_URL = "https://api.chsdw.top/v1"
_LLM_MODEL = "gpt-4o-mini"

if dataset_name == "hotpot_qa":
    revise_prompt = HumanMessage(content=_SHORT_ANSWER_REVISE_TEXT)
    llm = ChatOpenAI(temperature=0, base_url=_LLM_BASE_URL, model=_LLM_MODEL)
elif dataset_name == "gsm8k":
    revise_prompt = HumanMessage(content=("Based on the previous critique, revise your answer.  "
                                      "Your response should follow the original format. Here is a better solution:\n\n"))
    llm = ChatOpenAI(temperature=0.5, base_url=_LLM_BASE_URL, model=_LLM_MODEL)
elif dataset_name == "toxicity":
    revise_prompt = HumanMessage(content=_SHORT_ANSWER_REVISE_TEXT)
    # This branch previously left `llm` undefined, so both graph nodes failed
    # with a NameError that their exception handlers silently swallowed.
    # NOTE(review): temperature=0 mirrors hotpot_qa (same prompt) -- confirm
    # the intended sampling setting for this dataset.
    llm = ChatOpenAI(temperature=0, base_url=_LLM_BASE_URL, model=_LLM_MODEL)
else:
    # Fail loudly on a misspelled dataset instead of leaving `revise_prompt`
    # and `llm` undefined for the nodes below.
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; expected 'hotpot_qa', 'gsm8k' or 'toxicity'"
    )
    
def criticize(state):
    """Graph node: ask the model to critique the conversation so far.

    Sends ``state["messages"]`` followed by ``critic_prompt`` to ``llm``.

    Returns:
        A state update. On success it carries ``"messages"`` (the critic
        prompt and the model's critique) and ``"iteration"`` incremented by
        one. If the model call fails, only the incremented ``"iteration"`` is
        returned so the round still counts and the graph can move on.
    """
    next_iteration = state["iteration"] + 1
    try:
        critique = llm.invoke(state["messages"] + [critic_prompt])
        critique_message = AIMessage(**critique.dict(exclude={"type", "name"}))
    except Exception:
        # Best effort: skip this critique. `Exception` (not a bare `except`)
        # lets KeyboardInterrupt and task cancellation propagate.
        return {"iteration": next_iteration}
    return {
        "messages": [critic_prompt, critique_message],
        "iteration": next_iteration,
    }

def _extract_final_answer(text: str) -> str:
    """Return the text after the last ``FINAL ANSWER`` marker, stripped.

    Accepts ``FINAL ANSWER :``, ``FINAL ANSWER:`` and ``FINAL ANSWER is:`` (the
    form the revise prompt asks for). If no marker is present the whole text
    is returned, stripped.
    """
    import re  # local import keeps this block self-contained

    return re.split(r"FINAL ANSWER\s*(?:is)?\s*:", text)[-1].strip()


def react(state):
    """Graph node: ask the model to revise its answer after the critique.

    The revise prompt and, when ``state["input"]`` is present, the original
    question are merged into one message and sent after ``state["messages"]``.

    Returns:
        A state update with ``"messages"`` (revise prompt, question when
        available, model reply) and ``"predictions"`` (a one-element list
        holding the extracted final answer). If the model call fails, an
        apology ``SystemMessage`` and an empty ``"predictions"`` list are
        returned instead.
    """
    # The initial state may not carry "input"; fall back to the revise prompt alone.
    question = state.get("input")
    prompt_messages = [revise_prompt]
    if question:
        prompt_messages.append(HumanMessage(content=question))

    try:
        # merge_message_runs takes ONE sequence of messages and returns a list,
        # so it is concatenated directly rather than wrapped in another list.
        merged_prompt = list(merge_message_runs(prompt_messages))
        result = llm.invoke(state["messages"] + merged_prompt)
        result_message = AIMessage(**result.dict(exclude={"type", "name"}))
    except Exception:
        # Best effort: record the failure and contribute no prediction.
        # `Exception` (not a bare `except`) lets KeyboardInterrupt and task
        # cancellation propagate.
        return {
            "messages": [SystemMessage(content="Sorry, I have trouble understanding your answer. Please try again.")],
            "predictions": [],
        }
    return {
        "messages": prompt_messages + [result_message],
        "predictions": [_extract_final_answer(result_message.content)],
    }

# Either agent can decide to end
from typing import Literal

def should_end(state) -> Literal["critic", "__end__"]:
    """Routing function: decide whether to stop or run another critique round.

    Returns ``"__end__"`` when the iteration budget (``MAX_ITERATIONS``) is
    used up, or when the two most recent predictions are identical (the answer
    has stopped changing). Otherwise returns ``"critic"``.
    """
    predictions = state.get("predictions", [])
    # `>=` rather than `==` so the loop still terminates if the counter ever
    # overshoots the budget.
    budget_exhausted = state["iteration"] >= MAX_ITERATIONS
    converged = len(predictions) > 1 and predictions[-1] == predictions[-2]
    if budget_exhausted or converged:
        return "__end__"
    return "critic"
    
def should_criticize(state) -> Literal["call_tools", "react"]:
    """Routing function: pick ``"call_tools"`` when the last message requests tools.

    Returns ``"call_tools"`` if the most recent message has a non-empty
    ``tool_calls`` attribute, otherwise ``"react"``. Messages without that
    attribute and an empty history both route to ``"react"``.

    NOTE(review): no "call_tools" node is registered in the graph built in
    this notebook, and this function is not wired into it.
    """
    messages = state.get("messages") or []
    last_message = messages[-1] if messages else None
    if getattr(last_message, "tool_calls", None):
        return "call_tools"
    return "react"
    


# from IPython.display import Image, display

# try:
#     display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
# except Exception:
#     # This requires some extra dependencies and is optional
#     pass


# Load graph

In [4]:
from self_improve import *

# Dataset this run targets.
dataset_name = "hotpot_qa"

# Assemble the two-node critique/revise loop over the State schema.
builder = StateGraph(State)

builder.add_node("critic", criticize)
builder.add_node("react", react)

# Every run starts with a critique, which is always followed by a revision.
builder.add_edge(START, "critic")
builder.add_edge("critic", "react")

# After each revision, should_end returns either "critic" (loop again) or "__end__".
builder.add_conditional_edges("react", should_end)

graph = builder.compile()

In [20]:
# Parallel processing
import asyncio
from langchain_core.messages import BaseMessage

async def reflect(item, graph, dataset_name: str = "hotpot_qa") -> dict:
    """Run the critique/revise graph on one dataset item.

    Args:
        item: Dataset row. ``hotpot_qa`` rows must provide ``"question"`` and
            ``"prediction"``; ``toxicity`` rows must provide ``"prompt"["text"]``.
        graph: Compiled graph exposing an async ``ainvoke(input=...)``.
        dataset_name: Selects how the initial state is built from ``item``.

    Returns:
        A copy of ``item`` with a ``"predictions"`` list added. If the graph
        run raises, ``"predictions"`` is the placeholder ``["None"]``.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    if dataset_name == "hotpot_qa":
        question = item["question"]
        history = [HumanMessage(content=question), HumanMessage(content=item["prediction"])]
    elif dataset_name == "toxicity":
        # This branch previously never built the graph input, so every item
        # silently fell through to the ["None"] placeholder.
        question = item["prompt"]["text"]
        history = [HumanMessage(content=question)]
    else:
        raise ValueError(f"Unsupported dataset_name {dataset_name!r}; expected 'hotpot_qa' or 'toxicity'")

    # "input" is read by the `react` node to restate the question.
    initial_state = {"input": question, "messages": history, "iteration": 0}
    try:
        result = await graph.ainvoke(input=initial_state)
        return {**item, "predictions": result["predictions"]}
    except Exception:
        # Best effort: one failing item must not abort the whole batch.
        return {**item, "predictions": ["None"]}

# Hotpot QA

In [5]:
# Number of dataset rows to evaluate; also used in the output file name.
num_test_sample = 200
# Label of this experiment variant; also used in the output file name.
mode = "critic_no_tool"
# Dataset whose chain-of-thought predictions are loaded and refined.
dataset_name = "hotpot_qa"

# Load Dataset

In [6]:
from datasets import load_dataset
# Load the saved chain-of-thought run for the selected dataset (JSON lines).
dataset = load_dataset("json", data_files=f"../output/{dataset_name}/200_cot.jsonl", split="train")
if num_test_sample > 0:
    # Clamp to the rows actually available so asking for more samples than
    # the file holds does not raise an out-of-range error.
    dataset = dataset.select(range(min(num_test_sample, len(dataset))))
print(dataset)

Dataset({
    features: ['idx', 'question', 'answer', 'prediction'],
    num_rows: 200
})


In [None]:
from tqdm.asyncio import tqdm_asyncio
results = await tqdm_asyncio.gather(*(reflect(item, graph, dataset_name) for item in dataset))

In [21]:
# Write one JSON line per question next to the input file. The folder is the
# same relative "../output/<dataset>" location the dataset was loaded from,
# rather than a machine-specific absolute path.
save_folder = os.path.join("..", "output", dataset_name)
os.makedirs(save_folder, exist_ok=True)
save_path = os.path.join(save_folder, "{}_{}.jsonl".format(num_test_sample, mode))
with open(save_path, "w", encoding="utf-8") as f:
    for idx, result in enumerate(results):
        row = dataset[idx]  # results are in dataset order, so idx lines them up
        record = {
            "idx": idx,
            "question": row["question"],
            "answer": row["answer"],
            "predictions": result["predictions"],
        }
        f.write(json.dumps(record) + "\n")