# LLMCompiler
LLMCompiler is an agent architecture designed to speed up agentic tasks by eagerly executing tasks within a DAG. It also cuts costs by reducing the number of calls to the LLM, and with it the redundant token usage.
<br>
It has 3 main components:
1. Planner: Streams a DAG of tasks
2. Task Fetching Unit: Schedules and executes the tasks as soon as they are executable
3. Joiner: Responds to the user or triggers a second plan

In [26]:
import re
import os
import json
import time
import base64
import asyncio
import platform
import requests
import operator
import itertools
import playwright
import numpy as np
import pandas as pd
import datetime as dt

from enum import Enum
from typing import Any
from typing import List
from typing import Dict
from typing import Tuple
from typing import Union
from typing import Literal
from typing import Iterable
from typing import Optional
from typing import Sequence
from typing import Annotated
from typing import TypedDict
from operator import itemgetter

from IPython import display
from IPython.display import HTML
from IPython.display import Image

from langgraph.graph import END
from langgraph.graph import StateGraph
from langgraph.graph import MessageGraph

from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings

from langchain_core.tools import BaseTool
from langchain_core.messages import BaseMessage
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.chat import ChatMessage
from langchain_core.messages.tool import ToolMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.system import SystemMessage
from langchain_core.messages.function import FunctionMessage
from langchain_core.prompts.image import ImagePromptTemplate

from langchain_core.pydantic_v1 import Field
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables import RunnableBranch
from langchain_core.runnables import RunnableParallel
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import chain as chain_decorator

from langchain_core.runnables.graph import CurveStyle
from langchain_core.runnables.graph import NodeColors
from langchain_core.runnables.graph import MermaidDrawMethod

from langchain_core.language_models import BaseChatModel

from langchain import hub
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.prompts import ChatPromptTemplate
from langchain.prompts import MessagesPlaceholder
from langchain.prompts import HumanMessagePromptTemplate
from langchain.prompts import SystemMessagePromptTemplate
from langchain.agents import create_openai_functions_agent
from langchain.text_splitter import RecursiveCharacterTextSplitter

from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.tools.tavily_search import TavilySearchResults

from math_tools import get_math_tool
from output_parser import LLMCompilerPlanParser, Task

from concurrent.futures import wait
from concurrent.futures import ThreadPoolExecutor

from dotenv import load_dotenv
load_dotenv()
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ['LANGCHAIN_PROJECT'] = 'LLMCompiler'

## 1. Tools
We'll first define the tools for the agent to use in our demo. We'll give it the classic search engine + calculator combo.

In [2]:
calculate = get_math_tool(ChatOpenAI(model='gpt-4o'))
search = TavilySearchResults(max_results=1, description='tavily_search_results_json(query="the search query") - a search engine.')
tools = [search, calculate]


In [3]:
calculate.invoke({
    'problem': 'What is the temp of sf + 57',
    'context': ['The temperature of sf is 32 degrees']
})

'89'

## 2. Planner
Largely adapted from the original source code, the planner accepts the input question and generates a task list to execute.<br>
If it is provided with a previous plan, it is instructed to re-plan, which is useful if, upon completion of the first batch of tasks, the agent must take more actions.<br><br>
The code below constructs the prompt template for the planner and composes it with the LLM and the output parser defined in output_parser.py. The output parser processes a task list in the following form.<br>
```plaintext
1. tool_1(arg1='arg1', arg2=3.5, ...)
Thought: I then want to find out Y by using tool_2
2. tool_2(arg1='', arg2='${1}')
3. join()<END_OF_PLAN>
```

The Thought lines are optional. The ${#} placeholders are variables. These are used to route tool (task) outputs to other tools.
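
To make the placeholder convention concrete, here is a minimal sketch of how the dependencies of each task could be read off its argument string. It is not the parser from output_parser.py: `find_dependencies` and the three-line `plan` are hypothetical names used only for illustration, and the regular expression accepts both the `$1` and `${1}` spellings.

```python
import re
from typing import Dict, List

# Matches $1 or ${1}; the captured group is the index of the referenced task.
ID_PATTERN = re.compile(r"\$\{?(\d+)\}?")


def find_dependencies(call: str) -> List[int]:
    """Return the sorted, de-duplicated task indices referenced in a call string."""
    return sorted({int(idx) for idx in ID_PATTERN.findall(call)})


# Hypothetical plan, keyed by task index (not produced by the real planner).
plan: Dict[int, str] = {
    1: "tool_1(arg1='arg1', arg2=3.5)",
    2: "tool_2(arg1='', arg2='${1}')",
    3: "tool_3(arg1='$1', arg2='${2}')",
}

dependencies = {idx: find_dependencies(call) for idx, call in plan.items()}
print(dependencies)  # {1: [], 2: [1], 3: [1, 2]}
```

Tasks with an empty dependency list can start immediately, while the others have to wait for the referenced outputs.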

In [6]:
prompt = hub.pull('wfh/llm-compiler')
prompt.pretty_print()  # pretty_print() prints directly and returns None


Given a user query, create a plan to solve it with the utmost parallelizability. Each plan should comprise an action from the following {num_tools} types:
{tool_descriptions}
{num_tools}. join(): Collects and combines results from prior actions.

 - An LLM agent is called upon invoking join() to either finalize the user query or wait until the plans are executed.
 - join should always be the last action in the plan, and will be called in two scenarios:
   (a) if the answer can be determined by gathering the outputs from tasks to generate the final response.
   (b) if the answer cannot be determined in the planning phase before you execute the plans. Guidelines:
 - Each action described above contains input/output types and description.
    - You must strictly adhere to the input and output types for each action.
    - The action descriptions contain the guidelines. You MUST strictly follow those guidelines when you use the actions.
 -

In [7]:
def create_planner(llm: BaseChatModel, tools: Sequence[BaseTool], base_prompt: ChatPromptTemplate):
    tool_descriptions = '\n'.join(f"{i + 1}. {tool.description}\n" for i, tool in enumerate(tools))
    planner_prompt = base_prompt.partial(replan='', num_tools=len(tools) + 1, tool_descriptions=tool_descriptions)
    replanner_prompt = base_prompt.partial(
        replan=' - You are given "Previous Plan" which is the plan that the previous agent created along with the execution results '
        '(given as Observation) of each plan and a general thought (given as Thought) about the executed results. '
        'You MUST use these information to create the next plan under "Current Plan".\n'
        ' - When starting the Current Plan, you should start with "Thought" that outlines the strategy for the next plan.\n'
        ' - In the Current Plan, you should NEVER repeat the actions that are already executed in the Previous Plan.\n'
        ' - You must continue the task index from the end of the previous one. Do not repeat task indices.',
        num_tools=len(tools) + 1,
        tool_descriptions=tool_descriptions
    )

    def should_replan(state: list):
        # context is passed as a system message
        return isinstance(state[-1], SystemMessage)
    
    def wrap_messages(state: list):
        return {'messages': state}
    
    def wrap_and_get_last_index(state: list):
        next_task = 0
        for message in state[::-1]:
            if isinstance(message, FunctionMessage):
                next_task = message.additional_kwargs['idx'] + 1
                break
        state[-1].content = state[-1].content + f' - Begin counting at : {next_task}'
        return {'messages': state}
    
    return (
        RunnableBranch((should_replan, wrap_and_get_last_index | replanner_prompt), wrap_messages | planner_prompt)
        | llm
        | LLMCompilerPlanParser(tools=tools)
    )
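
The index-continuation step in the replanning branch can be illustrated without LangChain. The sketch below is a simplified stand-in, assuming a hypothetical `StubMessage` dataclass in place of the real message classes: it walks the history from newest to oldest and continues numbering after the last executed task.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class StubMessage:
    """Hypothetical stand-in for a LangChain message; only the fields used here."""
    kind: str  # "human", "function", or "system"
    content: str
    additional_kwargs: Dict[str, Any] = field(default_factory=dict)


def next_task_index(history: List[StubMessage]) -> int:
    """Scan from newest to oldest and continue after the last executed task."""
    for message in reversed(history):
        if message.kind == "function":
            return message.additional_kwargs["idx"] + 1
    return 0  # same default as the planner cell when nothing has run yet


history = [
    StubMessage("human", "What is the temperature in SF raised to the 3rd power?"),
    StubMessage("function", "search result", {"idx": 1}),
    StubMessage("function", "math error", {"idx": 2}),
    StubMessage("function", "join", {"idx": 3}),
    StubMessage("system", "Context from last attempt: retry the search"),
]
print(next_task_index(history))  # 4
```

Continuing the numbering keeps `${N}` references unambiguous across plans, so a replanned task can still point at an observation from the previous round.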

In [8]:
llm = ChatOpenAI(model='gpt-4o')
planner = create_planner(llm, tools, prompt)

In [9]:
example_question = 'What is the temperature in SF raised to the 3rd power?'

for task in planner.stream([HumanMessage(content=example_question)]):
    print(task['tool'], task['args'])
    print('----')

description='tavily_search_results_json(query="the search query") - a search engine.' max_results=1 {'query': 'current temperature in San Francisco'}
----
name='math' description='math(problem: str, context: Optional[list[str]]) -> float:\n - Solves the provided math problem\n - `problem` can either be a simple math problem (eg "1 + 3") or a word problem (eg "how many apples are there if there are 3 apples and 2 apples").\n - You cannot calculate multiple expressions in one call. For instance, `math("1 + 3, 2 + 4")` does not work. If you need to calculate multiple expressions, you need to call them separately like `math("1 + 3")` and then `math("2 + 4")`\n - Minimize the number of `math` actions as much as possible. For instance, instead of calling 2. math("what is the 10% of $1") and then call 3. math("$1 + $2"), you MUST call 2. math("what is the 110% of $1") instead, which will reduce the number of math actions.\n - You can optionally provide a list of strings as `context` to help t

## 3. Task Fetching Unit
This component schedules the tasks. It receives a stream of tasks in the following format:
```typescript
{
    tool: BaseTool,
    dependencies: number[]
}
```

The basic idea is to begin executing tools as soon as their dependencies are met. This is done through multi-threading. We will combine the task fetching unit and executor.
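
As a rough illustration of that idea, the sketch below runs a toy plan with a thread pool. Everything in it is an assumption made for the example (`run_when_ready`, the three lambda "tools", the polling interval); it only shows the pattern of waiting on dependencies and then executing, not the scheduler defined later in this notebook.

```python
import time
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Callable, Dict, List


def run_when_ready(idx: int, fn: Callable[..., int], deps: List[int],
                   results: Dict[int, int], poll: float = 0.01) -> None:
    """Block until every dependency has a result, then run the task."""
    while any(dep not in results for dep in deps):
        time.sleep(poll)
    results[idx] = fn(*(results[dep] for dep in deps))


# Hypothetical toy plan: tasks 1 and 2 are independent, task 3 needs both.
toy_plan = [
    (1, lambda: 2, []),
    (2, lambda: 3, []),
    (3, lambda a, b: a * b, [1, 2]),
]

results: Dict[int, int] = {}
with ThreadPoolExecutor() as executor:
    futures = [executor.submit(run_when_ready, idx, fn, deps, results)
               for idx, fn, deps in toy_plan]
    wait(futures)
    for future in futures:
        future.result()  # re-raise any exception from a worker thread

print(results[3])  # 6
```

Each task writes to its own key in the shared dictionary, which is what lets the threads share it without extra locking in this sketch.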

### LLMCompiler
LLMCompiler is a framework that enables efficient and effective orchestration of parallel function calling with LLMs, including both open-source and closed-source models, by automatically identifying which tasks can be performed in parallel and which ones are interdependent.

In [13]:
def _get_observations(messages: List[BaseMessage]) -> Dict[int, Any]:
    # get all previous tool responses
    results = {}
    for message in messages[::-1]:
        if isinstance(message, FunctionMessage):
            results[int(message.additional_kwargs['idx'])] = message.content
    return results

```python
class Task(TypedDict):
    idx: int
    tool: BaseTool
    args: list
    dependencies: Dict[str, list]
    thought: Optional[str]
```

In [18]:
class SchedulerInput(TypedDict):
    messages: List[BaseMessage]
    tasks: Iterable[Task]

def _execute_task(task, observations, config):
    tool_to_use = task['tool']
    if isinstance(tool_to_use, str):
        return tool_to_use
    args = task['args']
    try:
        if isinstance(args, str):
            resolved_args = _resolve_arg(args, observations)
        elif isinstance(args, dict):
            resolved_args = {key: _resolve_arg(val, observations) for key, val in args.items()}
        else:
            # this will likely fail
            resolved_args = args

    except Exception as e:
        return (f'ERROR(Failed to call {tool_to_use.name} with args {args}. Args could not be resolved. Error: {repr(e)})')
    
    try:
        return tool_to_use.invoke(resolved_args, config)
    except Exception as e:
        return (f'ERROR(Failed to call {tool_to_use.name} with args {args}. Args resolved to {resolved_args}. Error: {repr(e)})')
    
def _resolve_arg(arg: Union[str, Any], observations: Dict[int, Any]):
    # $1 or ${1} -> 1
    ID_PATTERN = r'\$\{?(\d+)\}?'

    def replace_match(match):
        # if the string is ${123}, match.group(0) is ${123}, and match.group(1) is 123
        # return the match group, in this case the index, from the string. This is the index number we get back.
        idx = int(match.group(1))
        return str(observations.get(idx, match.group(0)))
    
    if isinstance(arg, str):
        return re.sub(ID_PATTERN, replace_match, arg)
    elif isinstance(arg, list):
        return [_resolve_arg(a, observations) for a in arg]
    else:
        return str(arg)
    
@chain_decorator
def schedule_task(task_inputs, config):
    task: Task = task_inputs['task']
    observations: Dict[int, Any] = task_inputs['observations']
    try:
        observation = _execute_task(task, observations, config)
    except Exception:
        import traceback

        # format_exc() returns the current exception's traceback as a string
        observation = traceback.format_exc()
    observations[task['idx']] = observation

def schedule_pending_task(task: Task, observations: Dict[int, Any], retry_after: float=0.2):
    while True:
        deps = task['dependencies']
        if deps and (any([dep not in observations for dep in deps])):
            # dependencies not yet satisfied
            time.sleep(retry_after)
            continue
        schedule_task.invoke({'task': task, 'observations': observations})
        break

@chain_decorator
def schedule_tasks(scheduler_input: SchedulerInput) -> List[FunctionMessage]:
    '''Group the tasks into a DAG schedule'''
    # for streaming we are making a few simplifying assumptions
    # 1. The LLM does not create cyclic dependencies
    # 2. The LLM will not generate tasks with future deps
    # If this ceases to be a good assumption, you can either adjust to do a proper topological sort (not-stream)
    # or use a more complicated data-structure
    tasks = scheduler_input['tasks']
    args_for_tasks = {}
    messages = scheduler_input['messages']

    # if we are re-planning, we may have calls that depend on previous plans. Start with those
    observations = _get_observations(messages)
    task_names = {}
    originals = set(observations)
    # ^^ We assume each task inserts a different key above to avoid race conditions
    futures = []
    retry_after = 0.25 # retry after every quarter second
    with ThreadPoolExecutor() as executor:
        for task in tasks:
            deps = task['dependencies']
            task_names[task['idx']] = (task['tool'] if isinstance(task['tool'], str) else task['tool'].name)
            args_for_tasks[task['idx']] = task['args']
            if (deps and (any([dep not in observations for dep in deps]))):
                futures.append(executor.submit(schedule_pending_task, task, observations, retry_after))
            else:
                # no deps, or all deps satisfied
                # can schedule now
                schedule_task.invoke(dict(task=task, observations=observations))
                # futures.append(executor.submit(schedule_task.invoke dict(task=task, observations=observations)))

        # all tasks have been submitted or enqueued
        # wait for them to complete
        wait(futures)

    # convert observations to new tools messages to add to the state
    new_observations = {
        k: (task_names[k], args_for_tasks[k], observations[k]) for k in sorted(observations.keys() - originals)
    }

    tool_messages = [
        FunctionMessage(name=name, content=str(obs), additional_kwargs={'idx': k, 'args': task_args}) for k, (name, task_args, obs) in new_observations.items()
    ]

    return tool_messages
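
To see the placeholder substitution in isolation, here is a standalone copy of the rule used by `_resolve_arg`, renamed `resolve` so that it runs without the rest of the notebook. The observation text is borrowed from the calculator example earlier; the `$7` reference is a made-up index with no observation behind it.

```python
import re
from typing import Any, Dict

ID_PATTERN = r"\$\{?(\d+)\}?"


def resolve(arg: Any, observations: Dict[int, Any]) -> Any:
    """Swap $N / ${N} for observation N; recurse into lists; stringify the rest."""
    def replace_match(match):
        idx = int(match.group(1))
        # Unknown indices fall back to the original placeholder text.
        return str(observations.get(idx, match.group(0)))

    if isinstance(arg, str):
        return re.sub(ID_PATTERN, replace_match, arg)
    if isinstance(arg, list):
        return [resolve(a, observations) for a in arg]
    return str(arg)


observations = {1: "The temperature of sf is 32 degrees"}
print(resolve("What is ${1}?", observations))
# What is The temperature of sf is 32 degrees?
print(resolve(["$1", "$7"], observations))
# ['The temperature of sf is 32 degrees', '$7']
```

Note that any argument that is neither a string nor a list is converted with `str()`, so a numeric argument such as `3.5` reaches the tool as `'3.5'`.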

In [19]:
@chain_decorator
def plan_and_schedule(messages: List[BaseMessage], config):
    tasks = planner.stream(messages, config)
    # begin executing the planner immediately
    try:
        tasks = itertools.chain([next(tasks)], tasks)
    except StopIteration:
        # handle the case where tasks is empty
        tasks = iter([])
    scheduled_tasks = schedule_tasks.invoke({'messages': messages, 'tasks': tasks}, config)
    return scheduled_tasks
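
The `itertools.chain([next(tasks)], tasks)` idiom above is easy to misread, so here is a small sketch of what it does. `fake_planner` is a hypothetical generator standing in for `planner.stream()`, and the `events` list records the order in which things happen.

```python
import itertools
from typing import Iterator, List

events: List[str] = []


def fake_planner() -> Iterator[str]:
    """Hypothetical stand-in for planner.stream(); records when it starts running."""
    events.append("planner started")
    yield "task 1"
    yield "task 2"


tasks = fake_planner()
events.append("generator created")  # nothing has run yet: generators are lazy

try:
    # next() forces the generator to run up to its first yield;
    # chain() puts that first item back in front of the remaining ones.
    tasks = itertools.chain([next(tasks)], tasks)
except StopIteration:
    tasks = iter([])  # the planner produced no tasks at all

events.append("handed to scheduler")
remaining = list(tasks)

print(events)     # ['generator created', 'planner started', 'handed to scheduler']
print(remaining)  # ['task 1', 'task 2']
```

The planner therefore starts producing before the scheduler is invoked, and no task is lost by peeking at the first one.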

#### Example plan
We still haven't introduced any cycles in our computation graph, so this is all easily expressed in LCEL.

In [20]:
tool_messages = plan_and_schedule.invoke([HumanMessage(content=example_question)])

In [21]:
tool_messages

[FunctionMessage(content="[{'url': 'https://forecast.weather.gov/MapClick.php?lat=37.7749&lon=-122.4194', 'content': 'NOAA National Weather Service. Current conditions at SAN FRANCISCO DOWNTOWN (SFOC1) Lat: 37.77056°NLon: 122.42694°WElev: 150.0ft.'}]", additional_kwargs={'idx': 1, 'args': {'query': 'current temperature in San Francisco'}}, name='tavily_search_results_json'),
 FunctionMessage(content='ValueError(\'Failed to evaluate "${current_temperature}**3". Raised error: SyntaxError(\\\'invalid syntax\\\', (\\\'<expr>\\\', 1, 1, \\\'${current_temperature}**3\\\', 1, 2)). Please try again with a valid numerical expression\')', additional_kwargs={'idx': 2, 'args': {'problem': 'what is the temperature in San Francisco raised to the 3rd power?', 'context': ['$1']}}, name='math'),
 FunctionMessage(content='join', additional_kwargs={'idx': 3, 'args': ()}, name='join')]

## 4. Joiner
So now we have the planning and initial execution done. We need a component to process these outputs and either:
1. Respond with the correct answer
2. Loop with a new plan

<br>
The paper refers to this as the "joiner". It's another LLM call. We are using function calling to improve parsing reliability.

In [23]:
class FinalResponse(BaseModel):
    response: str


class Replan(BaseModel):
    feedback: str = Field(description='Analysis of the previous attempts and recommendations on what needs to be fixed.')


class JoinOutputs(BaseModel):
    '''Decide whether to replan or whether you can return the final response'''
    thought: str = Field(description='The chain of thought reasoning for the selected action')
    action: Union[FinalResponse, Replan]

joiner_prompt = hub.pull('wfh/llm-compiler-joiner').partial(examples='') # you can optionally add examples
llm = ChatOpenAI(model='gpt-4o')
runnable = joiner_prompt | llm.with_structured_output(JoinOutputs)

In [24]:
def _parse_joiner_output(decision: JoinOutputs) -> List[BaseMessage]:
    response = [AIMessage(content=f'Thought: {decision.thought}')]
    if isinstance(decision.action, Replan):
        return response + [SystemMessage(content=f'Context from last attempt: {decision.action.feedback}')]
    else:
        return response + [AIMessage(content=decision.action.response)]

def select_recent_messages(messages: list) -> dict:
    selected = []
    for msg in messages[::-1]:
        selected.append(msg)
        if isinstance(msg, HumanMessage):
            break
    return {'messages': selected[::-1]}

joiner = select_recent_messages | runnable | _parse_joiner_output
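
The effect of `select_recent_messages` is that the joiner only sees the latest turn. A minimal sketch of the same selection, assuming messages are represented as hypothetical `(role, content)` tuples rather than LangChain message objects:

```python
from typing import List, Tuple

# Hypothetical stand-in: each message is a (role, content) tuple.
Message = Tuple[str, str]


def select_recent(messages: List[Message]) -> List[Message]:
    """Keep everything from the most recent human message onward."""
    selected = []
    for msg in reversed(messages):
        selected.append(msg)
        if msg[0] == "human":
            break
    return selected[::-1]  # restore chronological order


history = [
    ("human", "first question"),
    ("ai", "first answer"),
    ("human", "second question"),
    ("function", "search result"),
    ("function", "join"),
]
print(select_recent(history))
# [('human', 'second question'), ('function', 'search result'), ('function', 'join')]
```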

In [25]:
input_messages = [HumanMessage(content=example_question)] + tool_messages
joiner.invoke(input_messages)

[AIMessage(content='Thought: I attempted to find the current temperature in San Francisco and then calculate its cube, but I encountered an error. I need to successfully retrieve the current temperature first.'),
 SystemMessage(content='Context from last attempt: I need to retrieve the current temperature in San Francisco from a reliable source.')]

## 5. Compose using LangGraph
We'll define the agent as a stateful graph, with the main nodes being:
1. Plan and execute (the DAG from the first step above)
2. Join: determine if we should finish or replan
3. Recontextualize: update the graph state based on the output from the joiner

In [29]:
graph_builder = MessageGraph()

# 1. Define vertices
# we defined plan_and_schedule above already
# assign each node to a state variable to update
graph_builder.add_node('plan_and_schedule', plan_and_schedule)
graph_builder.add_node('join', joiner)

# define edges
graph_builder.add_edge('plan_and_schedule', 'join')

# this condition determines looping logic
def should_continue(state: List[BaseMessage]):
    if isinstance(state[-1], AIMessage):
        return END
    return 'plan_and_schedule'

graph_builder.add_conditional_edges('join', should_continue)
graph_builder.set_entry_point('plan_and_schedule')
chain = graph_builder.compile()
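
Stripped of LangGraph, the control flow of this graph is a short loop. The sketch below uses stub nodes and hypothetical `(role, content)` tuples in place of real messages; `run_loop`, `stub_plan`, `stub_join`, and the `max_rounds` cap are all assumptions for illustration (the cap plays a role similar to a recursion limit).

```python
from typing import Callable, List, Tuple

Message = Tuple[str, str]  # hypothetical (role, content) stand-in for messages


def run_loop(question: str,
             plan_and_schedule: Callable[[List[Message]], List[Message]],
             join: Callable[[List[Message]], List[Message]],
             max_rounds: int = 5) -> List[Message]:
    """Alternate plan/execute and join until the joiner ends on an AI message."""
    state: List[Message] = [("human", question)]
    for _ in range(max_rounds):
        state += plan_and_schedule(state)
        state += join(state)
        if state[-1][0] == "ai":  # final answer: stop (END in the graph)
            break
        # otherwise the last message is a system message: plan again
    return state


def stub_plan(state: List[Message]) -> List[Message]:
    return [("function", "observation")]


def stub_join(state: List[Message]) -> List[Message]:
    """Ask for a replan after the first round, answer after the second."""
    rounds = sum(1 for role, _ in state if role == "function")
    if rounds < 2:
        return [("ai", "Thought: need more data"),
                ("system", "Context from last attempt: retry")]
    return [("ai", "Thought: done"), ("ai", "final answer")]


final_state = run_loop("toy question", stub_plan, stub_join)
print(final_state[-1])  # ('ai', 'final answer')
```

The type of the last message is the only signal used to decide between finishing and replanning, which mirrors `should_continue` above.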

In [30]:
for step in chain.stream([HumanMessage(content="What's the GDP of New York?")]):
    print(step)
    print('----')

{'plan_and_schedule': [FunctionMessage(content='[{\'url\': \'https://en.wikipedia.org/wiki/Economy_of_New_York_(state)\', \'content\': "36,000 farms occupy 7.6\\xa0million acres or about 25 percent of the state\'s land area, to produce a variety of food products.[22] Here are some of the items in which New York ranks high nationally:\\nNew York is an agricultural leader and is one of the top five states for agricultural products, including dairy, cattle, apples, cabbages, potatoes, beets, viticulture, onions, maple syrup and many others.[23] The state is the second largest producer of cabbage in the U.S.[22] In April 2021, GlobalFoundries, a company specializing in the semiconductor industry, moved its headquarters from Silicon Valley, California to its most advanced semiconductor-chip manufacturing facility in Saratoga County near a section of the Adirondack Northway, in Malta, New York.[9]\\nNew York City[edit]\\nNew York City, characterized as the world\'s premier financial center,[

In [33]:
step['join'][-1].content

'The GDP of the State of New York in 2022 was $2.053 trillion.'

#### Multi-hop question

In [34]:
steps = chain.stream([HumanMessage(content="What's the oldest parrot alive, and how much longer is that than the average?")], {'recursion_limit': 100})
for step in steps:
    print(step)
    print('----')

{'plan_and_schedule': [FunctionMessage(content='[{\'url\': \'https://www.guinnessworldrecords.com/world-records/442525-oldest-parrot-ever\', \'content\': "Oldest parrot ever. The oldest parrot ever is Cookie, a Major Mitchell\'s cockatoo (Cacatua leadbeateri) who was at least 82 years and 88 days old when he passed away on 27 August 2016. Cookie\'s exact age was unknown when he arrived at Brookfield Zoo in May 1934. His arrival was documented in a ledger dated May 1934, when he was estimated to ..."}]', additional_kwargs={'idx': 1, 'args': {'query': 'oldest parrot alive'}}, name='tavily_search_results_json', id='22f91ac3-29c5-461d-af47-c69dcf8df7b4'), FunctionMessage(content='[{\'url\': \'https://www.thesprucepets.com/how-long-do-parrots-and-other-pet-birds-live-1238433\', \'content\': "It\'s possible that a pet bird can outlive its owners\\nThe Spruce / Adrienne Legault\\nParrots and other birds can live up to 10 to 50 years or more depending on the type and the conditions they live i

In [35]:
step['join'][-1].content

"The oldest known parrot was Cookie, a Major Mitchell's cockatoo, who lived to be at least 82 years and 88 days old. The average lifespan of parrots ranges from 10 to 50 years, so Cookie lived significantly longer than the average parrot."

#### Multi-step math

In [36]:
for step in chain.stream([HumanMessage(content="What's ((3*(4+5)/0.5)+3245) + 8? What's 32/4.23? What's the sum of those two values?")]):
    print(step)

{'plan_and_schedule': [FunctionMessage(content='None', additional_kwargs={'idx': 1, 'args': {'problem': '3*(4+5)/0.5+3245+8'}}, name='math', id='59533f8e-111e-4b2f-b907-f19f41baa394'), FunctionMessage(content='None', additional_kwargs={'idx': 2, 'args': {'problem': '32/4.23'}}, name='math', id='9d123b3f-f9f3-426f-8509-d84d16d91783'), FunctionMessage(content='join', additional_kwargs={'idx': 3, 'args': ()}, name='join', id='df950090-a788-45b5-8790-153222e57c35')]}
{'join': [AIMessage(content='Thought: I need to perform the calculations to find the values and their sum as requested by the user.', id='aff1ffe9-7e39-4e67-ac5a-ebe27b89d4cc'), SystemMessage(content='Context from last attempt: I need to perform the calculations to find the values and their sum as requested by the user.', id='29309ac3-bcfe-40f5-b9d0-b11b438b359d')]}
{'plan_and_schedule': [FunctionMessage(content='None', additional_kwargs={'idx': 4, 'args': {'problem': '3*(4+5)/0.5 + 3245 + 8'}}, name='math', id='793bcc12-e2d6-

In [38]:
print(step['join'][-1].content)

I encountered difficulties in directly evaluating the expressions. Here are the steps you can follow to solve them manually:

1. Calculate the first expression: ((3*(4+5)/0.5)+3245) + 8
   - First, solve inside the parentheses: (4+5) = 9
   - Then multiply by 3: 3 * 9 = 27
   - Divide by 0.5: 27 / 0.5 = 54
   - Add 3245: 54 + 3245 = 3299
   - Finally, add 8: 3299 + 8 = 3307

2. Calculate the second expression: 32 / 4.23
   - Perform the division: 32 / 4.23 ≈ 7.57

3. Sum the two results: 3307 + 7.57 = 3314.57

So, the final sum is approximately 3314.57.
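
The manual steps in this fallback answer can be checked directly. The snippet below is plain Python arithmetic, independent of the math tool, and the variable names are chosen only for this check.

```python
first = ((3 * (4 + 5) / 0.5) + 3245) + 8
second = 32 / 4.23
total = first + second

print(first)             # 3307.0
print(round(second, 2))  # 7.57
print(round(total, 2))   # 3314.57
```

The values agree with the joiner's step-by-step answer, even though the `math` tool itself returned `None` in this run.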


## Conclusion
Known limitations to the implementation above:
1. The planner output parsing format is fragile if your function requires more than 1 or 2 arguments. We could make it more robust by using streaming tool calling.
