In [1]:
# install
!pip install -qU langgraph langsmith langchain_anthropic

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.4/50.4 kB[0m [31m378.8 kB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.7/87.7 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m149.1/149.1 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m891.5/891.5 kB[0m [31m26.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.6/75.6 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.9/77.9 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m393.9/393.9 kB[0m [31m21.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m141.9/141.9 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
# import
from typing import Optional, List, Annotated, Tuple
from typing_extensions import TypedDict
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import PromptTemplate
from dataclasses import dataclass
import json

In [9]:
@dataclass
class EvalResult:
    """Verdict produced by evaluating one planned graph.

    Attributes:
        passed: True when the evaluation succeeded.
        details: free-text findings; the planner sends the newest one back to
            the LLM when asking for a revision.
    """

    passed: bool
    details: str


@dataclass
class HLAState:
    """State object threaded through the planner / evaluator loop.

    Attributes:
        lla_code: generated code string (the test script starts it as "").
        lla_graph: parsed JSON description of the planned graph's nodes and edges.
        task: natural-language task the planned graph should accomplish.
        eval_results: evaluation history, newest last; the planner and the
            router read the last entry.
    """

    lla_code: str
    lla_graph: Annotated[dict, "json describing the graph nodes and edges"]
    task: str
    eval_results: List[EvalResult]

In [4]:
import os
from google.colab import userdata
# Copy the Colab-stored secret into the ANTHROPIC_API_KEY environment variable.
os.environ.update({"ANTHROPIC_API_KEY": userdata.get('HACKATHON_ANTHROPIC_API_KEY')})


# Retry budget consulted by check_start_again when deciding whether to loop
# back to the planner.
NUM_TRIES = 3

# JSON "schema" the LLM is asked to follow. It is supplied as a template
# *variable* (not embedded in the template text), so its literal braces are
# never interpreted as placeholders by PromptTemplate.
FORMAT = """
{
    "nodes": [
        {
            "name": str,
            "description": str,
            "input_names": [str],
            "output_names": [str]
        }
    ],
    "edges": [
        {
            "source_node": str,
            "destination_node": str
        }
    ],
    "cond_edges": [
        {
            "source_node": str,
            "destinations": [str]
        }
    ]
}
"""

# LangGraph primer + task statement shared by both prompts (previously
# copy-pasted into each template). Contains the {task} placeholder.
_PLAN_PREAMBLE = """In LangGraph, nodes are typically python functions where the first positional argument is the state.
The START Node is a special node that represents the node that sends user input to the graph. The main purpose for referencing this node is to determine which nodes should be called first.
The END Node is a special node that represents a terminal node. This node is referenced when you want to denote which edges have no actions after they are done.

Edges define how the logic is routed and how the graph decides to stop. This is a big part of how your agents work and how different nodes communicate with each other. There are a few key types of edges:
- Normal Edges: Go directly from one node to the next.
- Conditional Edges: Call a function to determine which node(s) to go to next.
- Entry Point: Which node to call first when user input arrives.

Describe the LangGraph nodes and edges you would use to accomplish this task:
{task}
"""

# First attempt. Variables: task, format.
NEW_PLAN_PROMPT_TEMPLATE = PromptTemplate.from_template(
    _PLAN_PREAMBLE
    + """
Give your response as JSON in the following format, don't include an explanation or anything other than JSON:
{format}
"""
)

# Revision after a failed evaluation. Variables: task, graph, eval_result, format.
REVISION_PROMPT_TEMPLATE = PromptTemplate.from_template(
    _PLAN_PREAMBLE
    + """
A previous iteration attempted this, and output this graph:
{graph}

But, when we evaluated the result this is what we found:
{eval_result}

Try again, fixing the previous attempt. Respond with JSON in the following format, don't include an explanation or anything other than JSON:
{format}
"""
)

In [14]:

def planner(state: HLAState) -> HLAState:
    """Ask the LLM for a LangGraph plan for ``state.task`` and store it on the state.

    With no evaluations yet, a fresh plan is requested. Otherwise the previous
    graph and the newest evaluation's ``details`` are sent back so the model
    can revise its plan.

    Args:
        state: current agent state; ``task`` must be non-empty.

    Returns:
        The same ``state`` object, with ``lla_graph`` replaced by the parsed plan.

    Raises:
        AssertionError: if ``state.task`` is empty.
        KeyError: if ANTHROPIC_API_KEY is not set in the environment.
        json.JSONDecodeError: if no JSON object can be parsed from the reply.
    """
    assert state.task

    # Read the key from the environment (populated in the config cell). The
    # previous code referenced a Python variable ANTHROPIC_API_KEY that this
    # notebook never defines, so it only worked via leftover kernel state.
    llm = ChatAnthropic(
        model="claude-3-haiku-20240307",
        api_key=os.environ["ANTHROPIC_API_KEY"],
    )

    variables = {"task": state.task, "format": FORMAT}
    if state.eval_results:
        template = REVISION_PROMPT_TEMPLATE
        # Serialize as JSON (not the dict's Python repr) so the model sees the
        # same format it is asked to produce.
        variables["graph"] = json.dumps(state.lla_graph, indent=2)
        variables["eval_result"] = state.eval_results[-1].details
    else:
        template = NEW_PLAN_PROMPT_TEMPLATE

    response = (template | llm).invoke(variables)
    state.lla_graph = _parse_graph_json(response.content)
    return state


def _parse_graph_json(content) -> dict:
    """Parse the JSON object in an LLM reply, tolerating fences or surrounding prose."""
    if not isinstance(content, str):
        # Chat models may return a list of content blocks; keep their text parts.
        content = "".join(
            block.get("text", "") if isinstance(block, dict) else str(block)
            for block in content
        )
    text = content.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span, e.g. when the model wraps the
        # JSON in ```json fences or adds a sentence around it.
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(text[start : end + 1])


In [11]:
def check_start_again(state: "HLAState", max_tries: Optional[int] = None) -> str:
    """Route after an evaluation: loop back to the planner or finish.

    Args:
        state: agent state; only ``eval_results`` (newest last) is read.
        max_tries: evaluation budget. Defaults to ``NUM_TRIES``, looked up at
            call time so edits to the config constant take effect without
            redefining this function.

    Returns:
        ``"__end__"`` if the newest evaluation passed or ``max_tries``
        evaluations have already been recorded; otherwise ``"planner"``.
    """
    limit = NUM_TRIES if max_tries is None else max_tries
    results = state.eval_results

    if results and results[-1].passed:
        return "__end__"
    # ``>=`` stops once the budget is spent; the old ``>`` permitted one extra try.
    if len(results) >= limit:
        return "__end__"
    # Latest attempt failed with budget left, or nothing evaluated yet: plan again.
    return "planner"


In [15]:
# test script

def mock_evaluator(state: HLAState, does_pass: bool) -> HLAState:
    """Stand-in evaluator for the test script: record a canned EvalResult.

    ``does_pass`` sets the verdict; the details text is fixed regardless of it.
    Appends to ``state.eval_results`` in place and returns the same state object.
    """
    canned = EvalResult(passed=does_pass, details="cond_edges should have >1 destinations")
    state.eval_results.append(canned)
    return state

task = "Tell me about the weather"
# lla_graph is typed as a dict, so start from an empty dict rather than a list.
state = HLAState("", {}, task, [])


def _assert_has_graph(s: HLAState) -> None:
    """Check the planner left a non-empty parsed graph on the state."""
    # The previous ``lla_graph != ""`` checks could never fail for a dict.
    assert isinstance(s.lla_graph, dict) and s.lla_graph, s.lla_graph
    assert "nodes" in s.lla_graph, list(s.lla_graph)


# First plan: no evaluations yet.
state = planner(state)
assert state.lla_code == ""
_assert_has_graph(state)
assert state.task == task
assert len(state.eval_results) == 0

# A failing evaluation should route back to the planner.
state = mock_evaluator(state, False)
_assert_has_graph(state)
assert state.task == task
assert len(state.eval_results) == 1
assert not state.eval_results[-1].passed

next_node = check_start_again(state)
assert next_node == "planner"

# Revision: planner replaces lla_graph with a new parsed object.
old_graph = state.lla_graph
state = planner(state)
_assert_has_graph(state)
# NOTE(review): relies on the LLM actually changing its answer; could be flaky.
assert state.lla_graph != old_graph
assert state.task == task
assert len(state.eval_results) == 1

# A passing evaluation should end the loop.
state = mock_evaluator(state, True)
_assert_has_graph(state)
assert state.task == task
assert len(state.eval_results) == 2
assert state.eval_results[-1].passed

next_node = check_start_again(state)
assert next_node == "__end__"