In [None]:
import getpass
import os


def _set_if_undefined(var: str):
    """Prompt for environment variable ``var`` when it is unset or empty.

    A value that is already present and non-empty is left untouched; otherwise
    the user is asked for it through a hidden ``getpass`` prompt and the answer
    is stored in ``os.environ``.
    """
    already_set = os.environ.get(var)
    if already_set:
        return
    os.environ[var] = getpass.getpass(f"Please provide your {var}")

# Ask for the API key only when the environment does not already provide one.
_set_if_undefined("LANGCHAIN_API_KEY")

# Optional, add tracing in LangSmith
os.environ.update({
    "LANGCHAIN_TRACING_V2": "true",
    "LANGCHAIN_PROJECT": "critic",
})

# Tools

In [None]:
from utils.tools import construct_tools, get_tools_descriptions
from langgraph.prebuilt import ToolNode

# Build the tool set from the project helper in utils.tools; the same list is
# later handed to the LLM (bind_tools) and to the ReAct agent.
tools = construct_tools()
# NOTE(review): tools_descriptions and tool_node are not referenced by any other
# cell visible in this notebook -- confirm they are still needed.
tools_descriptions = get_tools_descriptions(tools)
tool_node = ToolNode(tools)
from langchain_openai import ChatOpenAI

# Chat model shared by the critic and react nodes (temperature=0, top_p=1).
# NOTE(review): base_url points at a non-default endpoint; confirm that routing
# requests (and the API key) through this host is intended.
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini", base_url="https://api.chsdw.top/v1", top_p=1, max_retries=3)
# NOTE(review): llm_with_tools is not referenced by the other visible cells.
llm_with_tools = llm.bind_tools(tools)

# Define State

In [None]:
from typing import Literal
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import add_messages
from typing import Annotated
from typing_extensions import TypedDict
from operator import add

# Cap on critique/revise rounds: should_end stops the graph once
# state["iteration"] reaches this value.
MAX_ITERATIONS = 4

class State(TypedDict):
    """State passed between the graph's ``critic`` and ``react`` nodes.

    Fields annotated with ``add`` (``operator.add``) carry a reducer, so a list
    returned by a node is presumably concatenated onto the stored list instead
    of replacing it -- confirm against the LangGraph reducer documentation.
    """
    # The question being worked on.
    input: str
    # One message list per round; criticize uses the last entry as its context.
    react_messages: Annotated[list[list], add]
    # Answer history; the last entry is the current answer.
    predictions: Annotated[list[str], add]
    # Critique history; react reads the last entry.
    critiques: Annotated[list[str], add]
    # Round counter; criticize returns it incremented by one.
    iteration: int

# Define Critic

In [42]:
import stat
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage, AIMessage

# Follow-up human message that criticize appends after the latest answer to
# request a single-response critique of it.
critic_prompt = HumanMessage(content=("What's the problem of the previous answer? "
                                      "Reflect on the process of problem-solving. "
                                      "Identify any potential issues or errors. "
                                      "Then list them in a single response. "))

async def criticize(state):
    """Graph node: ask the LLM to critique the most recent answer.

    The last entry of ``state["react_messages"]`` (question plus current answer)
    is sent to the model followed by ``critic_prompt``.

    Args:
        state: Graph state; reads ``react_messages`` and ``iteration``.

    Returns:
        A state update with ``iteration`` incremented by one and ``critiques``
        holding a single new entry: the critique text, or a fixed apology
        string when the model call fails.
    """
    try:
        critique = await llm.ainvoke(state["react_messages"][-1] + [critic_prompt])
    except Exception as exc:
        # Catch Exception rather than using a bare ``except``: a bare clause
        # also swallows asyncio.CancelledError, which would stop the
        # asyncio.wait_for timeout in process() from cancelling this node.
        print(f"Critique generation failed: {exc}")
        critique_text = "I'm sorry, I couldn't generate a critique. Please try again."
    else:
        critique_text = critique.content

    return {
        "iteration": state["iteration"] + 1,
        "critiques": [critique_text],
    }

# Define Reacter

In [47]:
from langchain_core.prompts import ChatPromptTemplate
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage

# revise_prompt = ("Based on the previous criqique, you should use tools to confirm each point listed one by one. "
#                  "If you find any errors in the previous points, correct them and change the information in the subsequent points accordingly! "
#                  "Remember that there may be multiple tools that can be used at one time. "
#                  "So if there are more than one tool that can be used, your single response should contain all of them. "
#                  "Then revise your answer. Remember your FINAL ANSWER should be clear and concise.(a single number or phrases, not a sentence!) "
#                  "Your final response should follow the original format. For following example:\n\n"
#                  "{question}\n"
#                  "Let's think step by step. (the reasoning of your thought)\n"
#                  "So the FINAL ANSWER is: <FINAL ANSWER>\n")

# Instruction asking the model to verify its answer with tools in light of the
# critique and to propagate any correction to later steps.
reflect_prompt = ("Based on the criqique, use tools to check the Truthfulness and Plausibility of your answer. "
                  "If you find any errors in the previous process, correct them and change the information in the subsequent step accordingly! ")

# Template for the final rewrite. It contains a "{question}" placeholder, so it
# must be filled with str.format(question=...) before being sent to the model
# (react does this).
revise_prompt = ("Based on the previous messages, revise your answer.\n"
                 "Remember your FINAL ANSWER should be clear and concise.(a single number or phrases, not a sentence!) "
                 "Your response should follow the original format. For following example:\n\n"
                 "Question: <ORIGINAL QUESTION>\n"
                 "Let's think step by step. (the reasoning of your thought)\n"
                 "FINAL ANSWER: <FINAL ANSWER>\n\nBegin !\n\n\n"
                 "Question: {question}\n")

# Tool-using agent for the verification step. It is given reflect_prompt as its
# instruction: revise_prompt is a template whose "{question}" placeholder is
# only filled in by react() for the follow-up rewrite call, so passing it here
# would show the agent the literal text "{question}".
reacter = create_react_agent(model=llm, tools=tools, state_modifier=reflect_prompt)

async def react(state):
    """Graph node: verify the current answer with tools, then rewrite it.

    The question, the latest prediction and the latest critique are sent to the
    ``reacter`` agent; its message history is then passed to ``llm`` together
    with ``revise_prompt`` (filled with the question) to get the revised answer.

    Args:
        state: Graph state; reads ``input``, ``predictions``, ``critiques`` and
            ``react_messages``.

    Returns:
        A state update with one new ``react_messages`` entry and one new
        ``predictions`` entry. When either model call fails, the previous
        message list is repeated and the prediction is the string ``"None"``.
    """
    question = state["input"]
    previous_answer = state["predictions"][-1]
    critique = state["critiques"][-1]
    # Values are pulled into locals first: reusing double quotes inside a
    # double-quoted f-string is a SyntaxError before Python 3.12.
    messages = [HumanMessage(content=f"{question}\n{previous_answer}\n\n{critique}")]
    try:
        reflect_result = await reacter.ainvoke(input={"messages": messages}, config={"recursion_limit": 15})
        revised_answer = await llm.ainvoke(
            reflect_result["messages"]
            + [HumanMessage(content=revise_prompt.format(question=question))]
        )
    except Exception as e:
        # Report the cause instead of dropping it; the round is still recorded
        # so react_messages and predictions keep growing in step.
        print(f"Error revising answer: {e}")
        return {
            "react_messages": [state["react_messages"][-1]],
            "predictions": ["None"],
        }
    return {
        "react_messages": [[HumanMessage(content=f"{question}\n{revised_answer.content}")]],
        "predictions": [revised_answer.content],
    }

# Define Router

In [48]:
# Either agent can decide to end
from typing import Literal

def should_end(state) -> Literal["critic", "__end__"]:
    """Router after ``react``: stop the graph or start another critique round.

    The graph ends when the iteration cap is reached, or when the last two
    predictions give the same final answer (the text after the last
    ``"FINAL ANSWER:"`` marker, stripped).

    Args:
        state: Graph state; reads ``iteration`` and ``predictions``.

    Returns:
        ``"__end__"`` to finish, otherwise ``"critic"``.
    """
    # ``>=`` rather than ``==`` so a counter that starts above the cap still
    # terminates instead of looping until the recursion limit.
    if state["iteration"] >= MAX_ITERATIONS:
        return "__end__"

    predictions = state["predictions"]
    # More than two entries are required, so when a run starts with a single
    # seed prediction (as process does) only revised answers are compared.
    if len(predictions) > 2:
        latest = predictions[-1].split("FINAL ANSWER:")[-1].strip()
        previous = predictions[-2].split("FINAL ANSWER:")[-1].strip()
        if latest == previous:
            return "__end__"

    return "critic"


# Construct Graph

In [49]:
# Assemble the critique/revise loop: START -> critic -> react, after which
# should_end either finishes the run or routes back to critic.
builder = StateGraph(State)

builder.add_node("critic", criticize)
builder.add_node("react", react)
# builder.add_node("tools", tool_node)

builder.add_edge(START, "critic")
builder.add_edge("critic", "react")


# should_end returns the name of the next node ("critic") or "__end__".
builder.add_conditional_edges("react", should_end)

graph = builder.compile()

In [55]:
from datasets import load_dataset
# Experiment configuration.
dataset_name = "gsm8k"
num_test_sample = 200  # values <= 0 keep every row of the file
mode = "critic"
batch_size = 50

# NOTE(review): absolute, machine-specific path -- consider a configurable data
# directory so the notebook runs on other machines.
dataset = load_dataset("json", data_files="/Users/ariete/Projects/self-improve/output/gsm8k/200_pot.jsonl", split="train")
if num_test_sample > 0:
    # Clamp to the rows actually available so a shorter file does not make
    # select() fail on out-of-range indices.
    dataset = dataset.select(range(min(num_test_sample, len(dataset))))

print(dataset)


Dataset({
    features: ['idx', 'question', 'answer', 'python code', 'prediction'],
    num_rows: 200
})


In [51]:
from time import sleep
from tqdm.asyncio import tqdm_asyncio
import asyncio
# define process function
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage
import random

In [57]:
import json
import asyncio
from tqdm.asyncio import tqdm_asyncio
# Number of dataset items gathered concurrently per batch.
batch_size = 50
# Accumulates one entry per processed item (a graph result or an error string).
results = []
# Limits how many graph invocations run at once; acquired inside process().
semaphore = asyncio.Semaphore(20)
async def process(item, dataset_name:str="hotpot_qa", timeout: int = 180):
    """Run the critique/revise graph on one dataset item.

    Args:
        item: Mapping with at least ``"question"`` and ``"prediction"`` keys.
        dataset_name: ``"hotpot_qa"`` or ``"gsm8k"``; selects how the first
            message is worded.
        timeout: Seconds allowed for the graph run once a semaphore slot has
            been acquired.

    Returns:
        The final graph state on success. On timeout, on any other error, or
        for an unsupported ``dataset_name``, a string describing the failure is
        returned instead, so callers can tell the cases apart with
        ``isinstance(result, str)``.
    """
    if dataset_name == "hotpot_qa":
        first_message = f"{item['question']}\n{item['prediction']}"
    elif dataset_name == "gsm8k":
        first_message = (
            "Write a python code which could be used to solve the following problem, variable <answer> should contain the final answer. Use \"print(answer)\" to get the final answer.\n"
            f"{item['question']}\n{item['prediction']}"
        )
    else:
        # Previously this fell through with the builtin ``input`` function as
        # the graph payload; report it in the same shape as other failures.
        reason = f"unsupported dataset_name {dataset_name!r}"
        print(f"Error processing item {item}: {reason}")
        return f"Error on {item}: {reason}"

    # Named graph_input so the builtin ``input`` is not shadowed.
    graph_input = {
        "react_messages": [[HumanMessage(content=first_message)]],
        "input": item["question"],
        "predictions": [item["prediction"]],
        "iteration": 0,
    }
    try:
        async with semaphore:
            return await asyncio.wait_for(graph.ainvoke(input=graph_input), timeout=timeout)
    except asyncio.TimeoutError:
        print(f"Timeout processing item {item}")
        return f"Timeout on {item}"
    except Exception as e:
        print(f"Error processing item {item}: {e}")
        return f"Error on {item}: {str(e)}"

In [58]:
# Run the pipeline on the first dataset item only (single-item check).
temp = await process(dataset[0], dataset_name=dataset_name)



In [54]:
for i in range(0, 200, batch_size):
    batch = dataset.select(range(i, i + batch_size))
    batch_results = await asyncio.gather(*(process(item, dataset_name) for item in batch))
    results.extend(batch_results)

with open("/Users/ariete/Projects/self-improve/output/hotpot_qa/200_critic.jsonl", 'w') as f:
    for idx, item in enumerate(results):
        if isinstance(item, str):
            temp = {"idx": idx, "question":  dataset[idx]["question"], "predictions": [dataset[idx]["prediction"]], "answer": dataset[idx]["answer"]}
            f.write(json.dumps(temp) + "\n")
        else:
            temp = {"idx": idx, "question": item["input"], "predictions": item["predictions"], "answer": dataset[idx]["answer"]}
            f.write(json.dumps(temp) + "\n")



Cache hit for text: Brown County, Kansas




Cache hit for text: Ralph Hefferline




Cache hit for text: Jerry Goldsmith




Cache hit for text: Carrefour




Cache hit for text: Handi-Snacks




Cache hit for text: 514th Flight Test Squadron




Cache hit for text: Ángel Cabrera




Cache hit for text: Marco Da Silva




Cache hit for text: Gurney Norman


