In [3]:
import json
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_core.messages.tool import (
    ToolCall,
)
from langchain_anthropic.output_parsers import ToolsOutputParser
from typing import Annotated, Dict, List, Literal, TypedDict
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage

from langchain_core.messages import HumanMessage
from langchain_anthropic import ChatAnthropic
from langchain_core.tools import tool
from langgraph.checkpoint import MemorySaver
from langgraph.graph import END, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from langgraph.pregel import RetryPolicy
from langchain.load import dumps
from langgraph.graph import END, StateGraph, MessagesState


# Placeholder implementation: it touches no repository and always reports success.
# The docstring is surfaced as the tool's `description` (see the tabulate cell),
# so changing its wording changes the text shown for this tool.
@tool
def reset_repo() -> str:
    """Resets the project repository to the initial state. Undoes all file changes."""
    return "Successful reset of repository"

# Placeholder implementation: `diff` is ignored and an empty string is always returned.
# The docstring is surfaced as the tool's `description`, so its wording is part of
# the tool's runtime-visible interface.
@tool
def validate_diffs(diff: str) -> str:
    """Tests whether the Diff is applicable. Run this before compiling. Returns either a Diff Error or the applied file. Diff has to be wrapped in a Markdown codeblock and has to follow the file edit rules. The Diff verified here will not persist to disk."""
    return ""

class LineInfo(TypedDict):
    """A line number paired with a piece of text.

    Appears as the innermost element type of ``MavenReturn.compile_error_details``.
    """
    line_no: int  # presumably a line number in a source file; indexing base is not shown here — TODO confirm
    content: str  # text associated with that line — TODO confirm whether this is source text or a compiler message

class MavenReturn(TypedDict):
    """Result payload returned by the ``compile_maven_*`` tools.

    ``should_improve_non_test_diff`` reads ``compilation_has_succeeded``,
    ``test_has_succeeded`` and ``error_text`` from the JSON-encoded form of
    this payload to decide whether the graph run can end.
    """
    # presumably file path -> file content after the change (the tool docstrings
    # speak of "the content of the changed files") — TODO confirm key format
    updated_files: dict[str, str]
    # Flags checked by the routing function; both true ends the run.
    compilation_has_succeeded: bool
    test_has_succeeded: bool
    # Free-form error output; the routing function searches it for a known Mockito failure string.
    error_text: str
    # Nested mapping of str -> int -> list of LineInfo; the meaning of the two key
    # levels is not shown in this notebook — TODO confirm (file path / line number?).
    compile_error_details: Dict[str, Dict[int, List[LineInfo]]]

# Placeholder implementation: `diff` is ignored and a fixed "nothing succeeded"
# MavenReturn with empty collections is returned.
# The docstring is surfaced as the tool's `description`, so its wording is part of
# the tool's runtime-visible interface.
@tool
def compile_maven_stateful(diff: str) -> MavenReturn:
    """Compiles the project with the given diffs applied. Returns metadata for the run as well as the content of the changed files. The Diff applied here will persist to the disk, unless the repository is reset after. When the Diff has errors, nothing will be applied."""
    return {"updated_files": {}, "compilation_has_succeeded": False, "test_has_succeeded": False, "error_text": "", "compile_error_details": {}}

# Placeholder implementation: `diff` is ignored and a fixed "nothing succeeded"
# MavenReturn with empty collections is returned.
# NOTE(review): this tool is not part of the module-level `tools` list.
@tool
def compile_maven_stateless(diff: str) -> MavenReturn:
    """Compiles the project with the given diffs applied. Returns metadata for the run as well as the content of the changed files. The Diff applied wont persist to disk, subsequent file reads will show the old file."""
    return {"updated_files": {}, "compilation_has_succeeded": False, "test_has_succeeded": False, "error_text": "", "compile_error_details": {}}


# Placeholder implementation: logs its arguments and returns a fixed
# "nothing succeeded" MavenReturn with empty collections; no file is written.
# NOTE(review): this tool is not part of the module-level `tools` list, and unlike
# the diff-based compile tools it takes (new_file_content, file_path) rather than `diff`.
@tool
def compile_maven_file_edit(new_file_content: str, file_path: str) -> MavenReturn:
    """Compiles the project, after replacing the file at file_path with the new_file_content. Returns metadata for the run as well as the content of the changed files. The File written here will persist to the disk, unless the repository is reset after."""
    # Debug trace; prints the complete replacement file content to stdout.
    print("[TOOL] Compiling Maven with full file edit", file_path, new_file_content)

    return {"updated_files": {}, "compilation_has_succeeded": False, "test_has_succeeded": False, "error_text": "", "compile_error_details": {}}



# Placeholder implementation: `file_path` and `lines` are ignored and a fixed
# one-entry mapping is returned.
# The docstring is surfaced as the tool's `description`, so its wording is part of
# the tool's runtime-visible interface and is left untouched.
@tool
def read_file_lines(file_path: str, lines: list[int]) -> Dict[int, str]:
    """Reads the file lines (1-indexed) at the given path and returns it, or an error message if the file could not be read. Limit yourself to a reasonable amount of lines, otherwise do a full file read."""
    # The key is an int so the stub matches the declared Dict[int, str] return
    # type (it used to be the string "1").
    return {1: "Hi!"}


# Placeholder implementation: `file_path` is ignored and a fixed string is returned.
# The docstring is surfaced as the tool's `description` (it appears verbatim in the
# LaTeX table printed by the tabulate cell).
@tool
def read_file(file_path: str) -> str:
    """Reads the file at the given path and returns it, or an error message if the file could not be read."""
    return "hi!"

# Placeholder implementation: `relative_directory_path` is ignored and the JSON
# text of an empty object ("{}") is returned.
# The docstring is surfaced as the tool's `description`, so its wording is part of
# the tool's runtime-visible interface.
@tool
def get_directory_tree_for_path(relative_directory_path: str) -> str:
    """Returns the directory tree of the given path. Make sure that the Path is a directory."""
    return json.dumps({})


# Tools that only read from the project.
base_tooling = [read_file, read_file_lines, get_directory_tree_for_path]

# Full tool set bound to the model and served by the ToolNode: the read-only
# tools first, then diff validation, repository reset and the stateful compile.
tools = [*base_tooling, validate_diffs, reset_repo, compile_maven_stateful]


In [4]:



import random
import string


def should_continue(state: MessagesState) -> Literal["tools", "compile_agent"]:
    """Pick the node that follows the agent.

    Returns "tools" when the newest message carries at least one tool call,
    otherwise "compile_agent" so that the message is handed to the compile step.
    """
    newest = state["messages"][-1]

    # No pending tool calls: the model produced a final answer to compile.
    if not newest.tool_calls:
        print("[AGENT] Routing to compile agent")
        return "compile_agent"

    print("[AGENT] Routing to tools")
    return "tools"


def should_improve_non_test_diff(state: MessagesState) -> Literal["agent", END]:
    """Decide whether the run is finished or the agent gets another turn.

    The content of the newest message is inspected. The run ends (``END``) when:

    * the content contains the marker "Compilation and Testing successful:", or
    * the content is JSON whose ``compilation_has_succeeded`` and
      ``test_has_succeeded`` are both truthy, or
    * the content is JSON whose ``compilation_has_succeeded`` is truthy and
      whose ``error_text`` contains the known Mockito ``ClassImposterizer``
      initialisation failure.

    Anything else, including content that is not JSON or lacks the expected
    keys, routes back to "agent".

    Args:
        state: Graph state; ``state["messages"]`` must be non-empty.

    Returns:
        ``END`` or the node name "agent".
    """
    last_message = state["messages"][-1]
    content = last_message.content

    if "Compilation and Testing successful:" in content:
        print("[AGENT] Compilation and Testing successful")
        return END

    try:
        parsed = json.loads(content)
        compiled = parsed["compilation_has_succeeded"]
        if compiled and parsed["test_has_succeeded"]:
            print("[AGENT] Compilation and Testing successful")
            return END
        if compiled and "Could not initialize class org.mockito.internal.creation.cglib.ClassImposterizer" in parsed["error_text"]:
            print("[AGENT] Compilation successful, hitting mockito issue")
            return END
    except (ValueError, KeyError, TypeError):
        # Best-effort parse: the content is not a JSON compile result (plain
        # text, non-string content, missing keys, non-dict JSON). Only these
        # expected failures are ignored; anything else, such as an interrupt,
        # propagates instead of being silently swallowed.
        pass

    print("[AGENT] Back to Agent")
    return "agent"


def compile_agent(state: MessagesState):
    """Compile the agent's final answer and record the outcome as a tool message.

    The content of the newest message is treated as the diff to compile. A
    synthetic tool call for the selected compile tool is attached to that
    message, the tool is invoked, and a ``ToolMessage`` holding either the
    JSON-encoded result or an error description is added to the conversation.

    The compile tool is chosen from the module-level ``trial_type``:
    ``TrialType.STATELESS`` selects ``compile_maven_stateless``,
    ``TrialType.FULL_FILE_EDIT`` selects ``compile_maven_file_edit``, and
    anything else selects ``compile_maven_stateful``.

    Args:
        state: Graph state; ``state["messages"]`` must be non-empty.

    Returns:
        ``{"messages": [...]}`` with the previous messages followed by the new
        ``ToolMessage``.
    """
    print("[AGENT] Compiling")
    messages = state["messages"]
    last_message = messages[-1]

    # NOTE(review): `trial_type` and `TrialType` are not defined in the visible
    # cells; they must already exist in the kernel for this node to run.
    # The tool objects are referenced directly because the module-level `tools`
    # list only contains the stateful compile tool.
    if trial_type == TrialType.FULL_FILE_EDIT:
        compile_tool = compile_maven_file_edit
    elif trial_type == TrialType.STATELESS:
        compile_tool = compile_maven_stateless
    else:
        compile_tool = compile_maven_stateful

    # Built for every trial type (it used to be created only inside the
    # FULL_FILE_EDIT branch, leaving it unbound otherwise).
    # NOTE(review): compile_maven_file_edit expects `new_file_content` and
    # `file_path`, not `diff`; with these args its invocation is expected to
    # fail and be reported through the error branch below — confirm the
    # intended argument mapping.
    tool_call = ToolCall(
        name=compile_tool.name,
        args={"diff": last_message.content},
        id="".join(random.choices(string.ascii_uppercase + string.digits, k=9)),
    )
    # The preceding AI message must carry the call the ToolMessage answers.
    last_message.tool_calls = [tool_call]

    try:
        tool_result = compile_tool.invoke(tool_call["args"])
        tool_message = ToolMessage(
            content=json.dumps(tool_result),
            name=tool_call["name"],
            tool_call_id=tool_call["id"],
        )
    except Exception as e:
        # Put the failure text into the message content so the model can see
        # what went wrong (the content used to be left empty).
        tool_message = ToolMessage(
            content=f"Error during compilation: {str(e)}",
            name=tool_call["name"],
            tool_call_id=tool_call["id"],
            # NOTE(review): a raw exception object may not survive checkpoint
            # serialization — confirm whether a string would be preferable.
            additional_kwargs={"error": e},
        )

    # Return a new list rather than appending to the state's list in place.
    return {"messages": [*messages, tool_message]}

In [5]:
from langchain_openai import AzureChatOpenAI

# Graph whose state is the running list of chat messages.
workflow = StateGraph(MessagesState)

# Azure OpenAI chat model; temperature 0, no explicit token or time limit,
# up to three retries.
llm = AzureChatOpenAI(
    azure_deployment="gpt-4o-mini",
    api_version="2024-06-01",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=3,
)

# Model handle with the tool set bound to it.
llm_with_tools = llm.bind_tools(tools)

def call_model(state: MessagesState):
    """Send the conversation so far to the tool-bound model.

    Returns a state update whose "messages" list holds only the model's reply.
    """
    reply = llm_with_tools.invoke(state["messages"])
    return {"messages": [reply]}


# Nodes: the model, the generic tool executor, and the compile step.
workflow.add_node("agent", call_model)
workflow.add_node("tools", ToolNode(tools))
workflow.add_node("compile_agent", compile_agent)
workflow.set_entry_point("agent")

# After the agent: run requested tools, or hand the answer to the compile step.
workflow.add_conditional_edges("agent", should_continue)

# After compiling or running tools: finish, or give the agent another turn.
workflow.add_conditional_edges("compile_agent", should_improve_non_test_diff)
workflow.add_conditional_edges("tools", should_improve_non_test_diff)

app = workflow.compile()

In [6]:
# Print the Mermaid flowchart source for the compiled graph (shown as this cell's output).
print(app.get_graph().draw_mermaid())

%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
	__start__([__start__]):::first
	agent(agent)
	tools(tools)
	compile_agent(compile_agent)
	__end__([__end__]):::last
	__start__ --> agent;
	agent -.-> tools;
	agent -.-> compile_agent;
	compile_agent -.-> agent;
	compile_agent -.-> __end__;
	tools -.-> agent;
	tools -.-> __end__;
	classDef default fill:#f2f0ff,line-height:1.2
	classDef first fill-opacity:0
	classDef last fill:#bfb6fc



In [14]:
from tabulate import tabulate

# Render a LaTeX (booktabs) table listing every tool's name and description.
print(
    tabulate(
        [(entry.name, entry.description) for entry in tools],
        headers=["Tool Name", "Description"],
        tablefmt="latex_booktabs",
    )
)

\begin{tabular}{ll}
\toprule
 Tool Name                   & Description                                                                                                                                                                                                                                                           \\
\midrule
 read\_file                   & Reads the file at the given path and returns it, or an error message if the file could not be read.                                                                                                                                                                   \\
 read\_file\_lines             & Reads the file lines (1-indexed) at the given path and returns it, or an error message if the file could not be read. Limit yourself to a reasonable amount of lines, otherwise do a full file read.                                                                  \\
 get\_directory\_tree\_for\_path & Returns the directory tree of the gi