diff --git a/example.env b/example.env index 8895d9bd..8c5f321a 100644 --- a/example.env +++ b/example.env @@ -17,6 +17,7 @@ PROMETHEUS_MAX_TOKEN_PER_NEO4J_RESULT=10000 # LLM API keys and model settings PROMETHEUS_ADVANCED_MODEL=gpt-4o PROMETHEUS_BASE_MODEL=gpt-4o +PROMETHEUS_MAX_TOKENS=64000 PROMETHEUS_ANTHROPIC_API_KEY=anthropic_api_key PROMETHEUS_GEMINI_API_KEY=gemini_api_key PROMETHEUS_OPENAI_API_KEY=openai_api_key diff --git a/prometheus/app/main.py b/prometheus/app/main.py index c20afbb9..554efdca 100644 --- a/prometheus/app/main.py +++ b/prometheus/app/main.py @@ -26,6 +26,7 @@ logger.info(f"KNOWLEDGE_GRAPH_CHUNK_SIZE={settings.KNOWLEDGE_GRAPH_CHUNK_SIZE}") logger.info(f"KNOWLEDGE_GRAPH_CHUNK_OVERLAP={settings.KNOWLEDGE_GRAPH_CHUNK_OVERLAP}") logger.info(f"MAX_TOKEN_PER_NEO4J_RESULT={settings.MAX_TOKEN_PER_NEO4J_RESULT}") +logger.info(f"MAX_TOKENS={settings.MAX_TOKENS}") @asynccontextmanager diff --git a/prometheus/app/services/issue_service.py b/prometheus/app/services/issue_service.py index e980706b..cebf7a87 100644 --- a/prometheus/app/services/issue_service.py +++ b/prometheus/app/services/issue_service.py @@ -40,6 +40,31 @@ def answer_issue( build_commands: Optional[Sequence[str]] = None, test_commands: Optional[Sequence[str]] = None, ): + """ + Processes an issue, generates patches if needed, runs optional builds and tests, and returns the results. + + Args: + issue_title (str): The title of the issue. + issue_body (str): The body of the issue. + issue_comments (Sequence[Mapping[str, str]]): Comments on the issue. + issue_type (IssueType): The type of the issue (BUG or QUESTION). + run_build (bool): Whether to run the build commands. + run_existing_test (bool): Whether to run existing tests. + number_of_candidate_patch (int): Number of candidate patches to generate. + dockerfile_content (Optional[str]): Content of the Dockerfile for user-defined environments. + image_name (Optional[str]): Name of the Docker image. 
+ workdir (Optional[str]): Working directory for the container. + build_commands (Optional[Sequence[str]]): Commands to build the project. + test_commands (Optional[Sequence[str]]): Commands to test the project. + Returns: + Tuple containing: + - edit_patch (Optional[str]): The generated patch for the issue (None for QUESTION issues). + - passed_reproducing_test (bool): Whether the reproducing test passed. + - passed_build (bool): Whether the build passed. + - passed_existing_test (bool): Whether the existing tests passed. + - issue_response (str): Response generated for the issue. + """ + # Construct the container used as the working environment if dockerfile_content or image_name: container = UserDefinedContainer( self.kg_service.kg.get_local_path(), @@ -51,7 +76,7 @@ def answer_issue( ) else: container = GeneralContainer(self.kg_service.kg.get_local_path()) - + # Initialize the issue graph with the necessary services and parameters issue_graph = IssueGraph( advanced_model=self.llm_service.advanced_model, base_model=self.llm_service.base_model, @@ -63,7 +88,7 @@ def answer_issue( build_commands=build_commands, test_commands=test_commands, ) - + # Invoke the issue graph with the provided parameters output_state = issue_graph.invoke( issue_title, issue_body, @@ -84,11 +109,13 @@ def answer_issue( ) elif output_state["issue_type"] == IssueType.QUESTION: return ( - "", + None, False, False, False, output_state["issue_response"], ) - return "", False, False, False, "" + raise ValueError( + f"Unknown issue type: {output_state['issue_type']}. Expected BUG or QUESTION." 
+ ) diff --git a/prometheus/app/services/service_coordinator.py b/prometheus/app/services/service_coordinator.py index 3f4467e0..693f317e 100644 --- a/prometheus/app/services/service_coordinator.py +++ b/prometheus/app/services/service_coordinator.py @@ -90,6 +90,34 @@ def answer_issue( test_commands: Optional[Sequence[str]] = None, push_to_remote: Optional[bool] = None, ): + """ + Processes an issue, generates patches if needed, runs optional builds and tests, + and can push changes to a remote branch. + + Args: + issue_number: The issue number to process. + issue_title: Title of the issue. + issue_body: Body of the issue. + issue_comments: Comments on the issue. + issue_type: Type of the issue (BUG or QUESTION). + run_build: Whether to run a build after applying the patch. + run_existing_test: Whether to run existing tests after applying the patch. + number_of_candidate_patch: Number of candidate patches to generate. + dockerfile_content: Optional Dockerfile content for user-defined environment. + image_name: Optional name for the Docker image. + workdir: Working directory for the container. + build_commands: Commands to build the project. + test_commands: Commands to test the project. + push_to_remote: Whether to push changes to a remote branch. + Returns: + A tuple containing: + - remote_branch_name: Name of the remote branch if changes were pushed. + - patch: The generated patch for the issue. + - passed_reproducing_test: Whether the reproducing test passed. + - passed_build: Whether the build passed. + - passed_existing_test: Whether existing tests passed. + - issue_response: Response from the issue service after processing. 
+ """ logger = logging.getLogger("prometheus") formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") @@ -99,6 +127,7 @@ def answer_issue( logger.addHandler(file_handler) try: + # fix issue patch, passed_reproducing_test, passed_build, passed_existing_test, issue_response = ( self.issue_service.answer_issue( issue_title=issue_title, @@ -115,7 +144,7 @@ def answer_issue( test_commands=test_commands, ) ) - + # push to remote if requested remote_branch_name = None if patch and push_to_remote: remote_branch_name = self.repository_service.push_change_to_remote( diff --git a/prometheus/lang_graph/graphs/issue_graph.py b/prometheus/lang_graph/graphs/issue_graph.py index 834be6be..3e45b213 100644 --- a/prometheus/lang_graph/graphs/issue_graph.py +++ b/prometheus/lang_graph/graphs/issue_graph.py @@ -16,6 +16,13 @@ class IssueGraph: + """ + A LangGraph-based workflow to handle and triage GitHub issues with LLM assistance. + Attributes: + git_repo (GitRepository): The Git repository to work with. + graph (StateGraph): The state graph representing the issue handling workflow. 
+ """ + def __init__( self, advanced_model: BaseChatModel, @@ -30,7 +37,9 @@ def __init__( ): self.git_repo = git_repo + # Entrance point for the issue handling workflow issue_type_branch_node = NoopNode() + # Subgraph nodes for issue classification and bug handling issue_classification_subgraph_node = IssueClassificationSubgraphNode( model=base_model, kg=kg, @@ -48,14 +57,16 @@ def __init__( build_commands=build_commands, test_commands=test_commands, ) - + # Create the state graph for the issue handling workflow workflow = StateGraph(IssueState) - + # Add nodes to the workflow workflow.add_node("issue_type_branch_node", issue_type_branch_node) workflow.add_node("issue_classification_subgraph_node", issue_classification_subgraph_node) workflow.add_node("issue_bug_subgraph_node", issue_bug_subgraph_node) - + # Set the entry point for the workflow workflow.set_entry_point("issue_type_branch_node") + # Define the edges and conditions for the workflow + # Classify the issue type if not provided workflow.add_conditional_edges( "issue_type_branch_node", lambda state: state["issue_type"], @@ -67,6 +78,7 @@ def __init__( IssueType.QUESTION: END, }, ) + # Add edges for the issue classification subgraph workflow.add_conditional_edges( "issue_classification_subgraph_node", lambda state: state["issue_type"], @@ -77,6 +89,7 @@ def __init__( IssueType.QUESTION: END, }, ) + # Add edges for ending the workflow workflow.add_edge("issue_bug_subgraph_node", END) self.graph = workflow.compile() @@ -91,6 +104,9 @@ def invoke( run_existing_test: bool, number_of_candidate_patch: int, ): + """ + Invoke the issue handling workflow with the provided parameters. 
+ """ config = None input_state = { @@ -105,6 +121,7 @@ def invoke( output_state = self.graph.invoke(input_state, config) + # Reset the git repository to its original state self.git_repo.reset_repository() return output_state diff --git a/prometheus/lang_graph/nodes/context_provider_node.py b/prometheus/lang_graph/nodes/context_provider_node.py index d9637757..790f02c9 100644 --- a/prometheus/lang_graph/nodes/context_provider_node.py +++ b/prometheus/lang_graph/nodes/context_provider_node.py @@ -120,14 +120,18 @@ def __init__( self._logger = logging.getLogger("prometheus.lang_graph.nodes.context_provider_node") def _init_tools(self): - """Initializes KnowledgeGraph traversal tools. - + """ + Initializes KnowledgeGraph traversal tools. Returns: List of StructuredTool instances configured for KnowledgeGraph traversal. """ tools = [] + # === FILE SEARCH TOOLS === + + # Tool: Find file node by filename (basename) + # Used when only the filename (not full path) is known find_file_node_with_basename_fn = functools.partial( graph_traversal.find_file_node_with_basename, driver=self.neo4j_driver, @@ -142,6 +146,8 @@ def _init_tools(self): ) tools.append(find_file_node_with_basename_tool) + # Tool: Find file node by relative path + # Preferred method when the exact file path is known find_file_node_with_relative_path_fn = functools.partial( graph_traversal.find_file_node_with_relative_path, driver=self.neo4j_driver, @@ -156,6 +162,10 @@ def _init_tools(self): ) tools.append(find_file_node_with_relative_path_tool) + # === AST NODE SEARCH TOOLS === + + # Tool: Find AST node by text match in file (by basename) + # Useful for searching specific snippets or patterns in unknown locations find_ast_node_with_text_in_file_with_basename_fn = functools.partial( graph_traversal.find_ast_node_with_text_in_file_with_basename, driver=self.neo4j_driver, @@ -170,6 +180,7 @@ def _init_tools(self): ) tools.append(find_ast_node_with_text_in_file_with_basename_tool) + # Tool: Find AST node by 
text match in file (by relative path) find_ast_node_with_text_in_file_with_relative_path_fn = functools.partial( graph_traversal.find_ast_node_with_text_in_file_with_relative_path, driver=self.neo4j_driver, @@ -184,6 +195,8 @@ def _init_tools(self): ) tools.append(find_ast_node_with_text_in_file_with_relative_path_tool) + # Tool: Find AST node by type in file (by basename) + # Example types: FunctionDef, ClassDef, Assign, etc. find_ast_node_with_type_in_file_with_basename_fn = functools.partial( graph_traversal.find_ast_node_with_type_in_file_with_basename, driver=self.neo4j_driver, @@ -198,6 +211,7 @@ def _init_tools(self): ) tools.append(find_ast_node_with_type_in_file_with_basename_tool) + # Tool: Find AST node by type in file (by relative path) find_ast_node_with_type_in_file_with_relative_path_fn = functools.partial( graph_traversal.find_ast_node_with_type_in_file_with_relative_path, driver=self.neo4j_driver, @@ -212,6 +226,9 @@ def _init_tools(self): ) tools.append(find_ast_node_with_type_in_file_with_relative_path_tool) + # === TEXT/DOCUMENT SEARCH TOOLS === + + # Tool: Find text node globally by keyword find_text_node_with_text_fn = functools.partial( graph_traversal.find_text_node_with_text, driver=self.neo4j_driver, @@ -226,6 +243,7 @@ def _init_tools(self): ) tools.append(find_text_node_with_text_tool) + # Tool: Find text node by keyword in specific file find_text_node_with_text_in_file_fn = functools.partial( graph_traversal.find_text_node_with_text_in_file, driver=self.neo4j_driver, @@ -240,6 +258,7 @@ def _init_tools(self): ) tools.append(find_text_node_with_text_in_file_tool) + # Tool: Fetch the next text node chunk in a chain (used for long docs/comments) get_next_text_node_with_node_id_fn = functools.partial( graph_traversal.get_next_text_node_with_node_id, driver=self.neo4j_driver, @@ -254,6 +273,9 @@ def _init_tools(self): ) tools.append(get_next_text_node_with_node_id_tool) + # === FILE PREVIEW & READING TOOLS === + + # Tool: Preview contents of 
file by basename preview_file_content_with_basename_fn = functools.partial( graph_traversal.preview_file_content_with_basename, driver=self.neo4j_driver, @@ -268,6 +290,7 @@ def _init_tools(self): ) tools.append(preview_file_content_with_basename_tool) + # Tool: Preview contents of file by relative path preview_file_content_with_relative_path_fn = functools.partial( graph_traversal.preview_file_content_with_relative_path, driver=self.neo4j_driver, @@ -282,6 +305,7 @@ def _init_tools(self): ) tools.append(preview_file_content_with_relative_path_tool) + # Tool: Read entire code file by basename read_code_with_basename_fn = functools.partial( graph_traversal.read_code_with_basename, driver=self.neo4j_driver, @@ -296,6 +320,7 @@ def _init_tools(self): ) tools.append(read_code_with_basename_tool) + # Tool: Read entire code file by relative path read_code_with_relative_path_fn = functools.partial( graph_traversal.read_code_with_relative_path, driver=self.neo4j_driver, @@ -316,13 +341,15 @@ def __call__(self, state: Dict): """Processes the current state and traverse the knowledge graph to retrieve context. Args: - state: Current state containing the human query and preivous context_messages. + state: Current state containing the human query and previous context_messages. Returns: Dictionary that will update the state with the model's response messages. 
""" + self._logger.debug(f"Context provider messages: {state['context_provider_messages']}") message_history = [self.system_prompt] + state["context_provider_messages"] truncated_message_history = truncate_messages(message_history) response = self.model_with_tools.invoke(truncated_message_history) self._logger.debug(response) + # The response will be added to the bottom of the list return {"context_provider_messages": [response]} diff --git a/prometheus/lang_graph/nodes/context_query_message_node.py b/prometheus/lang_graph/nodes/context_query_message_node.py index d77eec34..ba43a5b1 100644 --- a/prometheus/lang_graph/nodes/context_query_message_node.py +++ b/prometheus/lang_graph/nodes/context_query_message_node.py @@ -12,4 +12,5 @@ def __init__(self): def __call__(self, state: ContextRetrievalState): human_message = HumanMessage(state["query"]) self._logger.debug(f"Sending query to ContextProviderNode:\n{human_message}") + # The message will be added to the end of the context provider messages return {"context_provider_messages": [human_message]} diff --git a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py index 847ddd4d..5c2f27b0 100644 --- a/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py +++ b/prometheus/lang_graph/nodes/context_retrieval_subgraph_node.py @@ -21,7 +21,7 @@ def __init__( self._logger = logging.getLogger( "prometheus.lang_graph.nodes.context_retrieval_subgraph_node" ) - self.context_retrevial_subgraph = ContextRetrievalSubgraph( + self.context_retrieval_subgraph = ContextRetrievalSubgraph( model=model, kg=kg, neo4j_driver=neo4j_driver, @@ -32,8 +32,9 @@ def __init__( def __call__(self, state: Dict): self._logger.info("Enter context retrieval subgraph") - output_state = self.context_retrevial_subgraph.invoke( + output_state = self.context_retrieval_subgraph.invoke( state[self.query_key_name], state["max_refined_query_loop"] ) + 
self._logger.info(f"Context retrieved: {output_state['context']}") return {self.context_key_name: output_state["context"]} diff --git a/prometheus/lang_graph/nodes/context_selection_node.py b/prometheus/lang_graph/nodes/context_selection_node.py index 815c8d9e..d7f0920d 100644 --- a/prometheus/lang_graph/nodes/context_selection_node.py +++ b/prometheus/lang_graph/nodes/context_selection_node.py @@ -96,6 +96,7 @@ def format_human_prompt(self, state: ContextRetrievalState, search_result: str) return context_info def __call__(self, state: ContextRetrievalState): + self._logger.info("Starting context selection process") context_list = state.get("context", []) for tool_message in extract_last_tool_messages(state["context_provider_messages"]): for search_result in neo4j_data_for_context_generator(tool_message.artifact): diff --git a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py index 238d84fb..acf6ccd8 100644 --- a/prometheus/lang_graph/nodes/issue_bug_context_message_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_context_message_node.py @@ -5,8 +5,7 @@ class IssueBugContextMessageNode: - BUG_FIX_QUERY = ( - """\ + BUG_FIX_QUERY = """\ {issue_info} Find all relevant source code context and documentation needed to understand and fix this issue. @@ -17,10 +16,7 @@ class IssueBugContextMessageNode: 4. 
Follow imports to find dependent code that directly impacts the issue Skip any test files -""".replace("{", "{{") - .replace("}", "}}") - .replace("{{issue_info}}", "{issue_info}") - ) +""" def __init__(self): self._logger = logging.getLogger( diff --git a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py index 49ee14cf..89415443 100644 --- a/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_bug_subgraph_node.py @@ -13,6 +13,10 @@ class IssueBugSubgraphNode: + """ + A LangGraph node that handles the issue bug subgraph, which is responsible for solving bugs in a GitHub issue. + """ + def __init__( self, advanced_model: BaseChatModel, @@ -40,6 +44,7 @@ def __init__( ) def __call__(self, state: IssueState): + # Ensure the container is built and started self.container.build_docker_image() self.container.start_container() @@ -70,11 +75,11 @@ def __call__(self, state: IssueState): except GraphRecursionError: self._logger.critical("Please increase the recursion limit of IssueBugSubgraph") return { - "edit_patch": "", + "edit_patch": None, "passed_reproducing_test": False, "passed_build": False, "passed_existing_test": False, - "issue_response": "", + "issue_response": None, } finally: self.container.cleanup() diff --git a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py index ead733dd..43ace6f2 100644 --- a/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py +++ b/prometheus/lang_graph/nodes/issue_verified_bug_subgraph_node.py @@ -12,6 +12,10 @@ class IssueVerifiedBugSubgraphNode: + """ + A LangGraph node that handles the verified issue bug, which is responsible for solving bugs. 
+ """ + def __init__( self, advanced_model: BaseChatModel, @@ -56,14 +60,16 @@ def __call__(self, state: Dict): self._logger.info("Recursion limit reached") self.git_repo.reset_repository() return { - "edit_patch": "", + "edit_patch": None, "passed_reproducing_test": False, "passed_build": False, "passed_existing_test": False, } - + # whether the fix made the reproducing test pass passed_reproducing_test = not bool(output_state["reproducing_test_fail_log"]) + # whether the build was run and passed passed_build = state["run_build"] and not output_state["build_fail_log"] + # whether the existing tests were run and passed passed_existing_test = ( state["run_existing_test"] and not output_state["existing_test_fail_log"] ) diff --git a/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py b/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py index a44c07b3..20ddecb8 100644 --- a/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py +++ b/prometheus/lang_graph/subgraphs/context_retrieval_subgraph.py @@ -16,6 +16,26 @@ class ContextRetrievalSubgraph: + """ + A LangGraph-based subgraph for retrieving relevant contextual information + (e.g., code, documentation, definitions) from a knowledge graph based on a query. + + This subgraph performs an iterative retrieval process: + 1. Constructs a context query message from the user prompt + 2. Uses tool-based retrieval (Neo4j-backed) to gather candidate context snippets + 3. Selects relevant context with LLM assistance + 4. Optionally refines the query and retries if necessary + 5. 
Outputs the final selected context + + Nodes: + - ContextQueryMessageNode: Converts user query to internal query prompt + - ContextProviderNode: Queries knowledge graph using structured tools + - ToolNode: Dynamically invokes retrieval tools based on tool condition + - ContextSelectionNode: Uses LLM to select useful context snippets + - ResetMessagesNode: Clears previous context messages + - ContextRefineNode: Decides whether to refine the query and retry + """ + def __init__( self, model: BaseChatModel, @@ -23,21 +43,44 @@ def __init__( neo4j_driver: neo4j.Driver, max_token_per_neo4j_result: int, ): + """ + Initializes the context retrieval subgraph. + + Args: + model (BaseChatModel): The LLM used for context selection and refinement. + kg (KnowledgeGraph): The graph-based semantic index for code/docs retrieval. + neo4j_driver (neo4j.Driver): Driver for executing Cypher queries in Neo4j. + max_token_per_neo4j_result (int): Token limit for responses from graph tools. + """ + # Step 1: Generate an initial query from the user's input context_query_message_node = ContextQueryMessageNode() + + # Step 2: Provide candidate context snippets using knowledge graph tools context_provider_node = ContextProviderNode( model, kg, neo4j_driver, max_token_per_neo4j_result ) + + # Step 3: Add tool node to handle tool-based retrieval invocation dynamically + # The tool message will be added to the end of the context provider messages context_provider_tools = ToolNode( tools=context_provider_node.tools, name="context_provider_tools", messages_key="context_provider_messages", ) + + # Step 4: Select relevant context snippets from the candidates context_selection_node = ContextSelectionNode(model) + + # Step 5: Reset tool messages to prepare for the next iteration (if needed) reset_context_provider_messages_node = ResetMessagesNode("context_provider_messages") + + # Step 6: Refine the query if needed and loop back context_refine_node = ContextRefineNode(model, kg) + # Construct the 
LangGraph workflow workflow = StateGraph(ContextRetrievalState) + # Add all nodes to the graph workflow.add_node("context_query_message_node", context_query_message_node) workflow.add_node("context_provider_node", context_provider_node) workflow.add_node("context_provider_tools", context_provider_tools) @@ -47,8 +90,12 @@ def __init__( ) workflow.add_node("context_refine_node", context_refine_node) + # Set the entry point for the workflow workflow.set_entry_point("context_query_message_node") + # Define edges between nodes workflow.add_edge("context_query_message_node", "context_provider_node") + + # Conditional: Use tool node if tools_condition is satisfied workflow.add_conditional_edges( "context_provider_node", functools.partial(tools_condition, messages_key="context_provider_messages"), @@ -57,20 +104,39 @@ def __init__( workflow.add_edge("context_provider_tools", "context_provider_node") workflow.add_edge("context_selection_node", "reset_context_provider_messages_node") workflow.add_edge("reset_context_provider_messages_node", "context_refine_node") + + # If refined_query is non-empty, loop back to provider; else terminate workflow.add_conditional_edges( "context_refine_node", lambda state: bool(state["refined_query"]), {True: "context_provider_node", False: END}, ) + # Compile and store the subgraph self.subgraph = workflow.compile() def invoke( self, query: str, max_refined_query_loop: int, recursion_limit: int = 999 ) -> Sequence[str]: + """ + Executes the context retrieval subgraph given an initial query. + + Args: + query (str): The natural language query representing the information need. + max_refined_query_loop (int): Maximum number of times the system can refine and retry the query. + recursion_limit (int, optional): Global recursion limit for LangGraph. Default is 999. + + Returns: + Dict with a single key: + - "context" (Sequence[str]): A list of selected context snippets relevant to the query. 
+ """ config = {"recursion_limit": recursion_limit} - input_state = {"query": query, "max_refined_query_loop": max_refined_query_loop} + input_state = { + "query": query, + "max_refined_query_loop": max_refined_query_loop, + } output_state = self.subgraph.invoke(input_state, config) + return {"context": output_state["context"]} diff --git a/prometheus/lang_graph/subgraphs/issue_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_bug_subgraph.py index 086cd8af..7b2dd83d 100644 --- a/prometheus/lang_graph/subgraphs/issue_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_bug_subgraph.py @@ -31,6 +31,7 @@ def __init__( build_commands: Optional[Sequence[str]] = None, test_commands: Optional[Sequence[str]] = None, ): + # Construct bug reproduction node bug_reproduction_subgraph_node = BugReproductionSubgraphNode( advanced_model=advanced_model, base_model=base_model, @@ -41,7 +42,7 @@ def __init__( max_token_per_neo4j_result=max_token_per_neo4j_result, test_commands=test_commands, ) - + # Construct issue bug verified subgraph nodes issue_verified_bug_subgraph_node = IssueVerifiedBugSubgraphNode( advanced_model=advanced_model, base_model=base_model, @@ -53,6 +54,7 @@ def __init__( build_commands=build_commands, test_commands=test_commands, ) + # Construct issue not verified bug subgraph node issue_not_verified_bug_subgraph_node = IssueNotVerifiedBugSubgraphNode( advanced_model=advanced_model, base_model=base_model, @@ -61,11 +63,12 @@ def __init__( neo4j_driver=neo4j_driver, max_token_per_neo4j_result=max_token_per_neo4j_result, ) - + # Construct issue bug responder node issue_bug_responder_node = IssueBugResponderNode(base_model) + # Create the state graph for the issue bug subgraph workflow = StateGraph(IssueBugState) - + # Add nodes to the workflow workflow.add_node("bug_reproduction_subgraph_node", bug_reproduction_subgraph_node) workflow.add_node("issue_verified_bug_subgraph_node", issue_verified_bug_subgraph_node) @@ -74,9 +77,9 @@ def __init__( ) 
workflow.add_node("issue_bug_responder_node", issue_bug_responder_node) - + # Set the entry point for the workflow workflow.set_entry_point("bug_reproduction_subgraph_node") - + # Go to verified bug subgraph if the bug is verified, otherwise go to not verified bug subgraph workflow.add_conditional_edges( "bug_reproduction_subgraph_node", lambda state: state["reproduced_bug"] @@ -87,11 +90,13 @@ def __init__( False: "issue_not_verified_bug_subgraph_node", }, ) + # Go to issue bug responder node if the bug is solved, otherwise go to not verified bug subgraph workflow.add_conditional_edges( "issue_verified_bug_subgraph_node", lambda state: bool(state["edit_patch"]), {True: "issue_bug_responder_node", False: "issue_not_verified_bug_subgraph_node"}, ) + # Add edges for the issue bug responder node workflow.add_edge("issue_not_verified_bug_subgraph_node", "issue_bug_responder_node") workflow.add_edge("issue_bug_responder_node", END) diff --git a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py index 8f777848..e151116a 100644 --- a/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py +++ b/prometheus/lang_graph/subgraphs/issue_verified_bug_subgraph.py @@ -26,6 +26,22 @@ class IssueVerifiedBugSubgraph: + """ + A LangGraph-based subgraph that handles verified bug issues by generating, + applying, and validating patch candidates. + + This subgraph executes the following phases: + 1. Context construction and retrieval from knowledge graph and codebase + 2. Semantic analysis of the bug using advanced LLM + 3. Patch generation via LLM and optional tool invocations + 4. Patch application with Git diff visualization + 5. Build and test the modified code in a containerized environment + 6. Iterative refinement if verification fails + + Attributes: + subgraph (StateGraph): The compiled LangGraph workflow to handle verified bugs. 
+ """ + def __init__( self, advanced_model: BaseChatModel, @@ -38,6 +54,22 @@ def __init__( build_commands: Optional[Sequence[str]] = None, test_commands: Optional[Sequence[str]] = None, ): + """ + Initialize the verified bug fix subgraph. + + Args: + advanced_model (BaseChatModel): A strong LLM used for bug understanding and patch generation. + base_model (BaseChatModel): A smaller, less expensive LLM used for context retrieval and test verification. + container (BaseContainer): A build/test container to run code validations. + kg (KnowledgeGraph): A knowledge graph used for context-aware retrieval of relevant code entities. + git_repo (GitRepository): Git interface to apply patches and get diffs. + neo4j_driver (neo4j.Driver): Neo4j driver for executing graph-based semantic queries. + max_token_per_neo4j_result (int): Maximum tokens to limit output from Neo4j query results. + build_commands (Optional[Sequence[str]]): Commands to build the project inside the container. + test_commands (Optional[Sequence[str]]): Commands to test the project inside the container. 
+ """ + + # Phase 1: Retrieve context related to the bug issue_bug_context_message_node = IssueBugContextMessageNode() context_retrieval_subgraph_node = ContextRetrievalSubgraphNode( model=base_model, @@ -48,9 +80,11 @@ def __init__( context_key_name="bug_fix_context", ) + # Phase 2: Analyze the bug and generate hypotheses issue_bug_analyzer_message_node = IssueBugAnalyzerMessageNode() issue_bug_analyzer_node = IssueBugAnalyzerNode(advanced_model) + # Phase 3: Generate code edits and optionally apply toolchains edit_message_node = EditMessageNode() edit_node = EditNode(advanced_model, kg) edit_tools = ToolNode( @@ -58,13 +92,18 @@ def __init__( name="edit_tools", messages_key="edit_messages", ) + + # Phase 4: Apply patch, diff changes, and update the container git_diff_node = GitDiffNode(git_repo, "edit_patch", "reproduced_bug_file") update_container_node = UpdateContainerNode(container, git_repo) + # Phase 5: Re-run test case that reproduces the bug bug_fix_verification_subgraph_node = BugFixVerificationSubgraphNode( base_model, container, ) + + # Phase 6: Optionally run full build and test after fix build_or_test_branch_node = NoopNode() build_and_test_subgraph_node = BuildAndTestSubgraphNode( container, @@ -74,8 +113,10 @@ def __init__( test_commands, ) + # Build the LangGraph workflow workflow = StateGraph(IssueVerifiedBugState) + # Add nodes to graph workflow.add_node("issue_bug_context_message_node", issue_bug_context_message_node) workflow.add_node("context_retrieval_subgraph_node", context_retrieval_subgraph_node) @@ -92,39 +133,47 @@ def __init__( workflow.add_node("build_or_test_branch_node", build_or_test_branch_node) workflow.add_node("build_and_test_subgraph_node", build_and_test_subgraph_node) + # Define edges for full flow workflow.set_entry_point("issue_bug_context_message_node") workflow.add_edge("issue_bug_context_message_node", "context_retrieval_subgraph_node") workflow.add_edge("context_retrieval_subgraph_node", 
"issue_bug_analyzer_message_node") - workflow.add_edge("issue_bug_analyzer_message_node", "issue_bug_analyzer_node") workflow.add_edge("issue_bug_analyzer_node", "edit_message_node") - workflow.add_edge("edit_message_node", "edit_node") + + # Conditionally invoke tools or continue to diffing workflow.add_conditional_edges( "edit_node", functools.partial(tools_condition, messages_key="edit_messages"), {"tools": "edit_tools", END: "git_diff_node"}, ) + workflow.add_edge("edit_tools", "edit_node") workflow.add_edge("git_diff_node", "update_container_node") workflow.add_edge("update_container_node", "bug_fix_verification_subgraph_node") + # If test still fails, loop back to reanalyze the bug workflow.add_conditional_edges( "bug_fix_verification_subgraph_node", lambda state: bool(state["reproducing_test_fail_log"]), {True: "issue_bug_analyzer_message_node", False: "build_or_test_branch_node"}, ) + + # Optionally run full build/test suite workflow.add_conditional_edges( "build_or_test_branch_node", lambda state: state["run_build"] or state["run_existing_test"], {True: "build_and_test_subgraph_node", False: END}, ) + + # If build/test fail, go back to reanalyze and patch workflow.add_conditional_edges( "build_and_test_subgraph_node", lambda state: bool(state["build_fail_log"]) or bool(state["existing_test_fail_log"]), {True: "issue_bug_analyzer_message_node", False: END}, ) + # Compile and assign the subgraph self.subgraph = workflow.compile() def invoke( diff --git a/prometheus/utils/lang_graph_util.py b/prometheus/utils/lang_graph_util.py index b5b4192c..21cb306f 100644 --- a/prometheus/utils/lang_graph_util.py +++ b/prometheus/utils/lang_graph_util.py @@ -1,6 +1,7 @@ from typing import Callable, Dict, Sequence import tiktoken +from langchain_core.language_models import BaseChatModel from langchain_core.messages import ( AIMessage, BaseMessage, @@ -11,6 +12,8 @@ ) from langchain_core.output_parsers import StrOutputParser +from prometheus.configuration.config import 
settings + def check_remaining_steps( state: Dict, @@ -57,9 +60,32 @@ def tiktoken_counter(messages: Sequence[BaseMessage]) -> int: return num_tokens +def compress_messages( + model: BaseChatModel, messages: Sequence[BaseMessage], max_tokens: int = settings.MAX_TOKENS +) -> Sequence[BaseMessage]: + """Compress messages if it exceeds the max token limit.""" + if tiktoken_counter(messages) <= max_tokens: + return messages + prompt = """ + You are a helpful assistant for software engineering. + Your task is to compress the following conversation messages while preserving their meaning as most as you can. + All these messages are context information for a software engineering issues, which including debugging, + feature, document and question. + The compressed messages should be concise and clear. + Avoid adding any new information or changing the meaning of the original messages. + Any redundant or irrelevant information should be removed. + """ + filtered_messages = [msg for msg in messages if not isinstance(msg, SystemMessage)] + system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)] + messages = [SystemMessage(content=prompt)] + filtered_messages + response = model.invoke(messages) + return system_messages + [response] + + def truncate_messages( - messages: Sequence[BaseMessage], max_tokens: int = 100000 + messages: Sequence[BaseMessage], max_tokens: int = settings.MAX_TOKENS ) -> Sequence[BaseMessage]: + """TODO: Instead of truncating, we should use a better strategy to keep the most relevant messages.""" return trim_messages( messages, token_counter=tiktoken_counter,