# Implementing LLMCompiler using LangGraph
By Kim, et. al [🔗](https://arxiv.org/abs/2312.04511)

LLMCompiler is an agent architecture intented on speeding up the latency of agentic tasks via fast, parallel tool execution. It has 3 main components:

1. Planner: generate a DAG of tasks.
2. Task Fetching Unit: schedules and executes the tasks
3. Joiner: Responds to the user or triggers a second plan


This notebook walks through each component and shows how to wire them together using LangGraph.

# Part 1: Planner


Largely adapted from [the original source code](https://github.com/SqueezeAILab/LLMCompiler/blob/main/src/llm_compiler/output_parser.py), the planner  accepts the input question and generates a task list to execute.

If it is provided with a previous plan, it is instructed to re-plan, which is useful if, upon completion of the first batch of tasks, the agent must take more actions.

The code below composes constructs the prompt template for the planner and composes it with LLM and output parser, defined in [output_parser.py](./output_parser.py). The output parser processes a task list in the following form:

```plaintext
1. tool_1(arg1="arg1", arg2=3.5, ...)
Thought: I then want to find out Y by using tool_2
2. tool_2(arg1="", arg2="${1}")'
3. join()<END_OF_PLAN>"
```

The "Thought" lines are optional. The `${#}` placeholders are variables. These are used to route tool (task) outputs to other tools.

In [1]:
from typing import Optional, Sequence

from langchain.chat_models.base import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableBranch
from langchain_core.tools import BaseTool
from llm_compiler.output_parser import LLMCompilerPlanParser

END_OF_PLAN = "<END_OF_PLAN>"


# The required extra "tool"
JOIN_DESCRIPTION = (
    "join():\n"
    " - Collects and combines results from prior actions.\n"
    " - A LLM agent is called upon invoking join to either finalize the user query or wait until the plans are executed.\n"
    " - join should always be the last action in the plan, and will be called in two scenarios:\n"
    "   (a) if the answer can be determined by gathering the outputs from tasks to generate the final response.\n"
    "   (b) if the answer cannot be determined in the planning phase before you execute the plans. "
)

planner_prompt_tmpl_str = (
    "Given a user query, create a plan to solve it with the utmost parallelizability. "
    "Each plan should comprise an action from the following {num_tools} types:\n"
    "{tool_descriptions}"
    f"\n{{num_toolsp1}}. {JOIN_DESCRIPTION}"
    "Guidelines:\n"
    " - Each action described above contains input/output types and description.\n"
    "    - You must strictly adhere to the input and output types for each action.\n"
    "    - The action descriptions contain the guidelines. You MUST strictly follow those guidelines when you use the actions.\n"
    " - Each action in the plan should strictly be one of the above types. Follow the Python conventions for each action.\n"
    " - Each action MUST have a unique ID, which is strictly increasing.\n"
    " - Inputs for actions can either be constants or outputs from preceding actions. "
    "In the latter case, use the format $id to denote the ID of the previous action whose output will be the input.\n"
    f" - Always call join as the last action in the plan. Say '{END_OF_PLAN}' after you call join\n"
    " - Ensure the plan maximizes parallelizability.\n"
    " - Only use the provided action types. If a query cannot be addressed using these, invoke the join action for the next steps.\n"
    " - Never introduce new actions other than the ones provided.\n\n"
    "{replan}"
    "{examples}"
)


def _generate_planner_prompt(
    tools: Sequence[BaseTool],
    example_prompt=str,
):
    tool_descriptions = "\n".join(
        f"{i+1}. {tool.name}: {tool.description}" for i, tool in enumerate(tools)
    )
    planner_prompt_template = ChatPromptTemplate.from_messages(
        [("system", planner_prompt_tmpl_str), ("user", "Question: {input}{context}")]
    ).partial(
        tool_descriptions=tool_descriptions,
        examples="Here are some examples:\n\n" + example_prompt
        if example_prompt
        else "",
        num_tools=len(tools),
        num_toolsp1=len(tools) + 1,
    )

    return planner_prompt_template


def create_planner(
    llm: BaseChatModel,
    example_prompt: str,
    tools: Sequence[BaseTool],
    stop: Optional[list[str]] = None,
):
    og_planner_prompt = _generate_planner_prompt(tools, example_prompt).partial(
        replan="",
        context="",
    )
    replanner_prompt = _generate_planner_prompt(tools, example_prompt).partial(
        replan=' - You are given "Previous Plan" which is the plan that the previous agent created along with the execution results '
        "(given as Observation) of each plan and a general thought (given as Thought) about the executed results."
        'You MUST use these information to create the next plan under "Current Plan".\n'
        ' - When starting the Current Plan, you should start with "Thought" that outlines the strategy for the next plan.\n'
        " - In the Current Plan, you should NEVER repeat the actions that are already executed in the Previous Plan.\n"
        " - You must continue the task index from the end of the previous one. Do not repeat task indices."
    )
    bound_llm = llm.bind(stop=stop)
    return (
        RunnableBranch(
            ((lambda x: x.get("context") is not None), replanner_prompt),
            og_planner_prompt,
        )
        | bound_llm
        | LLMCompilerPlanParser(tools=tools)
    )

#### Example usage

Here's an example usage of the planner module.

In [2]:
from typing import Optional

from langchain.tools import tool
from langchain_openai import ChatOpenAI


@tool
def get_user_id(first_name: str, last_name: str) -> Optional[int]:
    """Query the user IDs of everyone with the provided name."""
    student_ids = {
        ("Eric", "Zhang"): 1432,
        ("Sam", "Van Damm"): 8523,
        ("Will", "Van Damm"): 2341,
    }
    return student_ids.get((first_name, last_name))


@tool
def get_scores(class_name: str, user_id: int) -> Optional[str]:
    """Query the class registry for grades of the provided user ID."""
    return {
        ("Geology", 1432): "A+",
        ("Geology", 8523): "A",
        ("Geology", 2341): "B",
    }.get((class_name, user_id))


examples = (
    "Question: What's the user ID for Johnny Drop Tables?\n"
    '1. get_user_id(first_name="Johnny", "ast_name="Drop Tables")\n'
    f"2. join(){END_OF_PLAN}\n"
    "###\n"
    "\n"
    "Question: What was Eric Zhang's score in Calc?\n"
    '1. get_user_id("Eric")\n'
    '2. get_scores("calc", "$1")\n'
    f"3. join(){END_OF_PLAN}\n"
    "###\n"
    "\n"
)

planner = create_planner(
    ChatOpenAI(model="gpt-3.5-turbo"),
    example_prompt=examples,
    tools=[get_user_id, get_scores],
)

In [3]:
tasks = planner.invoke(
    {"input": "What are the Calc BC grades for Sam and Will Van Damm?"}
)
tasks

{1: {'idx': 1,
  'tool': StructuredTool(name='get_user_id', description='get_user_id(first_name: str, last_name: str) -> Optional[int] - Query the user IDs of everyone with the provided name.', args_schema=<class 'pydantic.main.get_user_idSchemaSchema'>, func=<function get_user_id at 0x104e33c40>),
  'args': {'first_name': 'Sam', 'last_name': 'Van Damm'},
  'dependencies': [],
  'thought': None},
 2: {'idx': 2,
  'tool': StructuredTool(name='get_user_id', description='get_user_id(first_name: str, last_name: str) -> Optional[int] - Query the user IDs of everyone with the provided name.', args_schema=<class 'pydantic.main.get_user_idSchemaSchema'>, func=<function get_user_id at 0x104e33c40>),
  'args': {'first_name': 'Will', 'last_name': 'Van Damm'},
  'dependencies': [],
  'thought': None},
 3: {'idx': 3,
  'tool': StructuredTool(name='get_scores', description='get_scores(class_name: str, user_id: int) -> Optional[str] - Query the class registry for grades of the provided user ID.', arg

## 2. Task Fetching Unit

This component schedules the tasks. In the paper, it's kept separate from the "executor", but here we create a single DAG defined in LangChain expression language.

The basic idea is that, given a list of dicts of the form:

```typescript
{
    tool: BaseTool,
    dependencies: number[],
}
```

1. Create a topological sort of the tasks
2. Execute them on the previous step's output, ensuring to perform variable substitution where appropriate

If we make the assumption that the tasks generated by the LLM are already sorted, we could adapt this to execute in a purely streaming fashion.

In [4]:
import functools
from typing import Any, Union

from langchain_core.runnables import (
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)


def _sort_tasks(data):
    if not data:
        return []
    sorted_tasks = []
    # Remove tasks already completed
    min_idx = min([int(k) for k in data])
    data = {
        int(k): {
            **v,
            "dependencies": [dep for dep in v["dependencies"] if dep >= min_idx],
        }
        for k, v in data.items()
    }
    while data:
        no_deps = {k: v for k, v in data.items() if not v["dependencies"]}
        if not no_deps:
            raise ValueError("We seem to have run into a circular dependency.")

        sorted_tasks.append(no_deps)
        data = {
            k: {
                **v,
                "dependencies": [d for d in v["dependencies"] if d not in no_deps],
            }
            for k, v in data.items()
            if k not in no_deps
        }
    return sorted_tasks


def _resolve_arg(x: dict, arg: Union[str, Any]):
    if isinstance(arg, str) and arg.startswith("$"):
        try:
            return x[f"task_{arg[1:]}"]
        except:
            if arg.endswith(".output"):
                return x[f"task_{arg[1:-7]}"]
            raise
    else:
        return arg


def _execute_task(x, task):
    tool_to_use = task["tool"]
    args = task["args"]
    if isinstance(args, str):
        resolved_args = _resolve_arg(x, args)
    elif isinstance(args, dict):
        resolved_args = {key: _resolve_arg(x, val) for key, val in args.items()}
    else:
        # This will likely fail
        resolved_args = args
    try:
        return tool_to_use.invoke(resolved_args)
    except Exception as e:
        return (
            f"ERROR(Failed to call tool {tool_to_use} with args {tool_to_use}."
            + f" Args resolved to {resolved_args}. Error: {repr(e)})"
        )


def construct_dag(tasks):
    sorted_tasks = _sort_tasks(tasks)
    chain = None
    for idx, task_group in enumerate(sorted_tasks):
        if len(task_group) == 1 and next(iter(task_group.values()))["tool"] == "join":
            step = lambda x: {"join": x}
        else:
            # Cascade all results forward
            constructor = (
                RunnableParallel if chain is None else RunnablePassthrough.assign
            )
            task_dict = {}
            for idx, task in task_group.items():
                task_dict[f"task_{idx}"] = RunnableLambda(
                    functools.partial(_execute_task, task=task)
                ).with_config(run_name=f"task_{idx}")

            step = constructor(**task_dict).with_config(run_name=f"TaskGroup{idx}")
        if chain is None:
            chain = step
        else:
            chain |= step

    if chain is not None:
        return chain | RunnablePassthrough.assign(tasks=lambda _: tasks)
    return chain

In [5]:
graph = construct_dag(tasks)
graph.get_graph().print_ascii()

                      +------------------------------+                       
                      | Parallel<task_1,task_2>Input |                       
                      +------------------------------+                       
                             ***            ***                              
                           **                  **                            
                         **                      **                          
               +-------------+               +-------------+                 
               | Lambda(...) |               | Lambda(...) |                 
               +-------------+               +-------------+                 
                             ***            ***                              
                                **        **                                 
                                  **    **                                   
                     +-------------------------------+          

#### Example Plan

We still haven't introduced any cycles in our computation graph, so this is all easily expressed in LCEL.

In [6]:
chain = planner | construct_dag

In [7]:
example_question = "Did Sam Van Damm score higher than Eric Zhang in Geology?"
task_results = chain.invoke({"input": example_question})
# task_results["join"]

In [8]:
task_results

{'join': {'task_1': 8523,
  'task_2': 1432,
  'task_3': "ERROR(Failed to call tool name='get_scores' description='get_scores(class_name: str, user_id: int) -> Optional[str] - Query the class registry for grades of the provided user ID.' args_schema=<class 'pydantic.main.get_scoresSchemaSchema'> func=<function get_scores at 0x104e337e0> with args name='get_scores' description='get_scores(class_name: str, user_id: int) -> Optional[str] - Query the class registry for grades of the provided user ID.' args_schema=<class 'pydantic.main.get_scoresSchemaSchema'> func=<function get_scores at 0x104e337e0>. Args resolved to {}. Error: ValidationError(model='get_scoresSchemaSchema', errors=[{'loc': ('class_name',), 'msg': 'field required', 'type': 'value_error.missing'}, {'loc': ('user_id',), 'msg': 'field required', 'type': 'value_error.missing'}]))",
  'task_4': "ERROR(Failed to call tool name='get_scores' description='get_scores(class_name: str, user_id: int) -> Optional[str] - Query the class 

## 3. "Joiner" 

So now we have the planning and initial execution done. We need a component to process these outputs and either:

1. Respond with the correct answer.
2. Loop with a new plan.

The paper refers to this as the "joiner". It's another LLM call, defined below:

In [9]:
from langchain_core.output_parsers import StrOutputParser
from typing_extensions import TypedDict


def format_task(task, idx):
    tool = task["tool"]
    tool_name = tool if isinstance(tool, str) else tool.name  # Handle join()
    args = ", ".join([f"{k}={v}" for k, v in task["args"].items()])
    return f"{idx}. {tool_name}({args})"


def format_tasks(executor_output: dict):
    tasks = executor_output["tasks"]
    prior_observations = executor_output.get("observations")
    joined_output = executor_output["join"]
    execution_results = []
    for idx, task in tasks.items():
        observation_idx = f"task_{idx}"
        if observation_idx in joined_output:
            observation = joined_output[observation_idx]
            execution_results.append(f"{format_task(task, idx)}\n\t=> {observation}")
    joined_results = "\n".join(execution_results)
    result = f"Executed plan results:\n{joined_results}"
    if prior_observations:
        result += f"\nPrevious Results:\n{prior_observations}"
    return result


def _parse_joiner_output(raw_answer: str) -> str:
    thought, answer, is_replan = "", "", False  # default values
    raw_answers = raw_answer.split("\n")
    for ans in raw_answers:
        if ans.startswith("Action:"):
            answer = ans[ans.find("(") + 1 : ans.find(")")]
            is_replan = JOINER_REPLAN in ans
        elif ans.startswith("Thought:"):
            thought = ans.split("Thought:")[1].strip()
    if is_replan:
        return {"thought": thought, "context": answer}
    else:
        return {"thought": thought, "answer": answer}

In [10]:
from langchain_core.prompts import ChatPromptTemplate


def create_joiner(prompt, llm):
    return (
        (
            lambda x: {
                **x["plan"],
                "input": x["input"],
                "context": x.get("context"),
                "observations": x.get("observations"),
            }
        )
        | RunnablePassthrough.assign(scratchpad=format_tasks)
        | ChatPromptTemplate.from_messages([("system", prompt), ("user", "{input}")])
        | llm
        | StrOutputParser()
        | _parse_joiner_output
    )

In [11]:
JOINER_FINISH = "Finish"
JOINER_REPLAN = "Replan"

system_prompt = (
    "Solve a question answering task. Here are some guidelines:\n"
    " - In the Assistant Scratchpad, you will be given results of a plan you have executed to answer the user's question.\n"
    " - Thought needs to reason about the question based on the Observations in 1-2 sentences.\n"
    " - Ignore irrelevant action results.\n"
    " - If the required information is present, give a concise but complete and helpful answer to the user's question.\n"
    " - If you are unable to give a satisfactory finishing answer, replan to get the required information."
    " Respond in the following format:\n\n"
    "Thought: <reason about the task results and whether you have sufficient information to answer the question>\n"
    "Action: <action to take>\n"
    "Available actions:\n"
    f" (1) {JOINER_FINISH}(the final answer to return to the user): returns the answer and finishes the task.\n"
    f" (2) {JOINER_REPLAN}(the reasoning and other information that will help you plan again. Can be a line of any length): instructs why we must replan\n\n"
    " Examples:\n"
    "Question: How many users are currently using the new product?\n"
    "...task returns the number 32,000\n"
    "Thought: I find no issue with the original plan, and the results satisfy everything in the user question.\n"
    f"Action: {JOINER_FINISH}(32,000 users currently use the new product)\n###\n"
    "Question: How much cooler is it in NY than SF?\n"
    "...task results show SF is 57 degrees fahrenheit today, and they show in NY it has a high of 32 degrees fahrenheit \n"
    "Thought: I can answer by synthesizing the results.\n"
    f"Action: {JOINER_FINISH}(NY is 25 degrees cooler than SF today, as it has a high of 32 degrees Fahrenheit today, whereas in SF, it is 57 degrees Fahrenheit.)\n###\n"
    "Question: Are the gophers beating the rabbits??\n"
    "...task returns the a score of 7 for rabbits but no other value...\n"
    "Thought: I need the gophers' score to make a final decision.\n"
    f"Action: {JOINER_REPLAN}(The rabbits have a score of 7, but I need the gophers' score.)"
    "\n\nAssistant Scratchpad:\n{scratchpad}"
)


joiner = create_joiner(system_prompt, ChatOpenAI(model="gpt-4"))
joiner.invoke({"plan": task_results, "input": example_question})

{'thought': 'The initial plan failed to get the scores for Sam Van Damm and Eric Zhang in Geology. In order to provide a final answer, I need these scores.',
 'context': "We need to execute get_scores with class_name set to 'Geology' and user_id set to the respective IDs for Sam Van Damm (8523"}

## Compose using LangGraph

Now we have all the required pieces! Let's construct an LLMCompiler agent. We'll give it a search engine (Tavily) and a simple "calculate" function.

In [12]:
import getpass
import os

os.environ["TAVILY_API_KEY"] = (
    os.environ.get("TAVILY_API_KEY")
    if "TAVILY_API_KEY" in os.environ
    else getpass.getpass("Tavily API Key:")
)

In [13]:
from operator import add, mul, sub, truediv
from typing import Literal

from langchain_community.agent_toolkits import GmailToolkit
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import tool


@tool
def calculate(
    arg1: float,
    arg2: float,
    op: Union[Literal["+"], Literal["-"], Literal["*"], Literal["/"]],
):
    """Calculate a mathematical operation on two arguments."""
    resolved_op = {"+": add, "-": sub, "*": mul, "/": truediv}
    return resolved_op[op](arg1, arg2)


tools = [TavilySearchResults(max_results=1), calculate]

In [14]:
calculate.invoke(dict(arg1=1, arg2=3, op="+"))

4.0

#### Defining the stateful graph

We'll define the agent as a stateful graph, with the main nodes being:

1. Plan and execute (the DAG from the first step above)
2. Join: determine if we should finish or replan
3. Recontextualize: update the graph state based on the output from the joiner


In [16]:
from typing import Dict

MAX_ITERATIONS = 5


class GraphState(TypedDict):
    input: str
    plan: Dict
    agent_output: Dict
    observations: Dict
    num_iterations: int  # Maximum
    context: str  # Exra commentary for the joiner
    stop_reason: str


def recontextualize(state):
    # Insert a context string for the re-planner.
    # This could alternatively call an LLM to provide additional logic
    context = state["agent_output"]["context"]
    num_iterations = int(state.get("num_iterations") or 1) + 1
    formatted_tasks = format_tasks(state["plan"])
    context_str = f"\n\nPrevious Plan:\n{formatted_tasks}\n" f"{context}"
    observations = state["observations"] or {}
    for task, observation in state["plan"]["join"].items():
        observations[task] = observation
    return {
        "context": context_str,
        "num_iterations": num_iterations,
        "observations": observations,
    }


def add_stop_reason(state: GraphState):
    # Helpful for letting the user know why the agent responded the way it did
    num_iterations = int(state.get("num_iterations") or 0)
    if num_iterations >= MAX_ITERATIONS:
        return {"stop_reason": "end_max_iter"}
    if state["agent_output"].get("answer"):
        return {"stop_reason": "answer"}
    return {"stop_reason": None}

In [17]:
import json

from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# 1.  Define vertices

planner = create_planner(
    llm=ChatOpenAI(model="gpt-4-1106-preview"),
    # Add more examples to improve reliability
    example_prompt=(
        "Question: What's the capital of Myanmar?\n"
        '1. tavily_search_results_json(query="Capital of Myanmar)\n'
        f"2. join(){END_OF_PLAN}\n"
        "###\n"
        "\n"
    ),
    tools=tools,
)

plan_and_execute = planner | construct_dag
joiner = create_joiner(system_prompt, ChatOpenAI(model="gpt-4-1106-preview"))


# Assign each node to a state variable to update
workflow.add_node("plan_and_execute", RunnablePassthrough.assign(plan=plan_and_execute))
workflow.add_node("join", RunnablePassthrough.assign(agent_output=joiner))
workflow.add_node("recontextualize", recontextualize)
workflow.add_node("provide_stop_reason", add_stop_reason)


## Define edges

workflow.add_edge("plan_and_execute", "join")
workflow.add_edge("recontextualize", "plan_and_execute")
workflow.add_edge("join", "provide_stop_reason")

### This condition determines looping logic


def should_continue(state):
    if state["stop_reason"] is None:
        return "continue"
    return "end"


workflow.add_conditional_edges(
    start_key="provide_stop_reason",
    # Next, we pass in the function that will determine which node is called next.
    condition=should_continue,
    conditional_edge_mapping={
        # If it generates context, we must replan
        "continue": "recontextualize",
        # Otherwise we finish.
        "end": END,
    },
)
workflow.set_entry_point("plan_and_execute")
chain = workflow.compile()

## Simple question

Let's ask a simple question of the agent.

In [18]:
result = chain.invoke({"input": "What's the GDP of New York?"})
print(result["agent_output"]["answer"])

In 2022, the real GDP of New York was about 1.56 trillion U.S. dollars.


## Multi-hop question

In [21]:
result = chain.invoke(
    {
        "input": "What's the oldest parrot alive, and how much longer is that than the average?"
    },
    {
        "recursion_limit": 100,
    },
)

In [22]:
print(result["agent_output"]["answer"])

Cookie, a cockatoo, was the oldest parrot alive, having reached the age of 83, which is 23 years longer than the maximum average lifespan of a cockatoo in captivity, which is 60 years.


## Streaming

In [25]:
last_step = None
for step in chain.stream({"input": "What's ((3*(4+5)/0.5)+3245) + 8?"}):
    print("Step: ", str(step)[:10] + "...")
    last_step = step
print("***")
print(last_step["__end__"]["agent_output"]["answer"])

Step:  {'plan_and...
Step:  {'join': {...
Step:  {'provide_...
Step:  {'recontex...
Step:  {'plan_and...
Step:  {'join': {...
Step:  {'provide_...
Step:  {'__end__'...
3307.0
