In [1]:
from typing import List, Dict
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.message import add_messages
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.tools.tavily_search import TavilySearchResults
from dotenv import load_dotenv
import re
import os

In [2]:
load_dotenv()

# Memory for state checkpointing
memory = MemorySaver()

# ReWOO State Definition
class ReWOOState(TypedDict):
    task: str
    plan_string: str
    steps: List  # List of (Plan description, step_name, tool, tool_input)
    results: Dict[str, str]
    result: str

# Groq LLaMA 3 Model Initialization
llm = ChatGroq(
    model="llama3-8b-8192",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)

# Prompt Template for planning
planner_prompt_template = """
For the following task, make step-by-step plans to solve the problem. Indicate which external tool to use:
(1) Google(query): For online search.
(2) LLM(instruction): General knowledge reasoning.

Example:
Task: Example math problem
Plan: Analyze the problem using LLM. #E1 = LLM[Analyze math expression]
Plan: Search current economic situation. #E2 = Google[Nigeria economic news 2024]

Describe plans with rich details. Format:
Plan: (your plan description) #E(step_number) = (Tool)[(input)]

Task: {task}
"""

regex_pattern = r"Plan:\s*(.*?)\s*#E(\d+)\s*=\s*(\w+)\s*\[(.*?)\]"

planner_chain = ChatPromptTemplate.from_template(planner_prompt_template) | llm
search_tool = TavilySearchResults(max_results=2)

# Step 1: Planning Node
def planning_node(state: ReWOOState) -> Dict:
    task = state["task"]
    response = planner_chain.invoke({"task": task})
    plan_text = response.content

    matches = re.findall(regex_pattern, plan_text, re.DOTALL)
    steps = [(plan.strip(), f"#E{step}", tool.strip(), tool_input.strip())
             for plan, step, tool, tool_input in matches]

    return {"plan_string": plan_text, "steps": steps, "results": {}}

# Step 2: Tool Execution Node
def get_next_step_index(state: ReWOOState):
    if not state.get("results"):
        return 0
    return len(state["results"])

def tool_execution_node(state: ReWOOState) -> Dict:
    step_idx = get_next_step_index(state)
    if step_idx >= len(state["steps"]):
        return {}

    plan, step_name, tool, tool_input = state["steps"][step_idx]
    results = state.get("results", {})

    # Replace previous evidence variables in the input
    for k, v in results.items():
        tool_input = tool_input.replace(k, v)

    if tool == "Google":
        result = search_tool.invoke({"query": tool_input})
        output = str(result)
    elif tool == "LLM":
        output = llm.invoke(tool_input).content
    else:
        raise ValueError(f"Unsupported tool: {tool}")

    results[step_name] = output
    return {"results": results}

# Step 3: Solving Node
solve_template = """
Solve the following problem based on the Plan and collected Evidence:

{plan_string}

Evidence:
{evidence}

Task: {task}

Provide the final answer directly:
"""

def solve_node(state: ReWOOState) -> Dict:
    evidence = "\n".join([f"{k}: {v}" for k, v in (state["results"] or {}).items()])
    prompt = solve_template.format(
        plan_string=state.get("plan_string", ""),
        evidence=evidence,
        task=state["task"]
    )
    response = llm.invoke(prompt)
    return {"result": response.content}

# Step 4: Routing Logic
def route_next(state: ReWOOState) -> str:
    if len(state.get("results", {})) < len(state.get("steps", [])):
        return "tool"
    return "solve"



In [3]:
# Graph Construction
graph = StateGraph(ReWOOState)
graph.add_node("plan", planning_node)
graph.add_node("tool", tool_execution_node)
graph.add_node("solve", solve_node)

graph.add_edge("plan", "tool")
graph.add_conditional_edges("tool", route_next)
graph.add_edge("solve", END)
graph.add_edge(START, "plan")

app = graph.compile()


In [14]:
from IPython.display import Markdown, display

# Compile the graph
app = graph.compile()

# Extract mermaid code (only after compile)
mermaid_code = app.get_graph().draw_mermaid()

# Display the mermaid diagram (renders in Jupyter Notebooks)
display(Markdown(f"```mermaid\n{mermaid_code}\n```"))





```mermaid
---
config:
  flowchart:
    curve: linear
---
graph TD;
	__start__([<p>__start__</p>]):::first
	plan(plan)
	tool(tool)
	solve(solve)
	__end__([<p>__end__</p>]):::last
	__start__ --> plan;
	plan --> tool;
	solve --> __end__;
	tool -.-> plan;
	tool -.-> solve;
	tool -.-> __end__;
	classDef default fill:#f2f0ff,line-height:1.2
	classDef first fill-opacity:0
	classDef last fill:#bfb6fc

```

In [5]:

# Run Example
if __name__ == "__main__":
    task_input = {"task": "Write a movie script about Nigeria and its economic situation"}

    for s in app.stream(task_input):
        print(f"Current State: {s}\n{'-'*60}")

    print("\nFinal Result:")
    print(s.get('result'))


Current State: {'plan': {'plan_string': 'Here\'s a step-by-step plan to solve the problem:\n\n**Plan: Gather information about Nigeria\'s economic situation #E1 = Google[Nigeria economic news 2024]**\n\n* Start by searching online for recent news articles about Nigeria\'s economic situation using Google.\n* Use specific keywords such as "Nigeria economic news 2024" to get relevant results.\n* Read through the search results to get an overview of the current economic situation in Nigeria, including any major challenges or developments.\n* Take note of key points, such as GDP growth rates, inflation rates, and major industries driving the economy.\n\n**Plan: Analyze the economic situation using LLM #E2 = LLM[Analyze economic data]**\n\n* Use a Large Language Model (LLM) to analyze the economic data gathered from the search results.\n* Input the data into the LLM and ask it to identify trends, patterns, and correlations between different economic indicators.\n* Use the LLM\'s reasoning ca

In [13]:
from IPython.display import Markdown, display

mermaid_code = app.get_graph().draw_mermaid()
display(Markdown(f"```mermaid\n{mermaid_code}\n```"))


```mermaid
---
config:
  flowchart:
    curve: linear
---
graph TD;
	__start__([<p>__start__</p>]):::first
	plan(plan)
	tool(tool)
	solve(solve)
	__end__([<p>__end__</p>]):::last
	__start__ --> plan;
	plan --> tool;
	solve --> __end__;
	tool -.-> plan;
	tool -.-> solve;
	tool -.-> __end__;
	classDef default fill:#f2f0ff,line-height:1.2
	classDef first fill-opacity:0
	classDef last fill:#bfb6fc

```