diff --git a/prometheus/graph/knowledge_graph.py b/prometheus/graph/knowledge_graph.py index df612e76..a2fd5dec 100644 --- a/prometheus/graph/knowledge_graph.py +++ b/prometheus/graph/knowledge_graph.py @@ -1,13 +1,24 @@ from collections import deque import logging from pathlib import Path +from typing import Sequence from prometheus.graph.file_graph_builder import FileGraphBuilder from prometheus.graph.graph_types import ( + ASTNode, FileNode, KnowledgeGraphEdge, KnowledgeGraphEdgeType, KnowledgeGraphNode, + Neo4jASTNode, + Neo4jFileNode, + Neo4jHasASTEdge, + Neo4jHasFileEdge, + Neo4jHasTextEdge, + Neo4jNextChunkEdge, + Neo4jParentOfEdge, + Neo4jTextNode, + TextNode, ) @@ -64,3 +75,83 @@ def _build_graph(self, root_dir: Path): continue self._logger.info(f"Skip parsing {file} because it is not supported.") + + def get_file_nodes(self) -> Sequence[KnowledgeGraphNode]: + return [ + kg_node + for kg_node in self._knowledge_graph_nodes + if isinstance(kg_node.node, FileNode) + ] + + def get_ast_nodes(self) -> Sequence[KnowledgeGraphNode]: + return [ + kg_node + for kg_node in self._knowledge_graph_nodes + if isinstance(kg_node.node, ASTNode) + ] + + def get_text_nodes(self) -> Sequence[KnowledgeGraphNode]: + return [ + kg_node + for kg_node in self._knowledge_graph_nodes + if isinstance(kg_node.node, TextNode) + ] + + def get_has_ast_edges(self) -> Sequence[KnowledgeGraphEdge]: + return [ + kg_edge + for kg_edge in self._knowledge_graph_edges + if kg_edge.type == KnowledgeGraphEdgeType.has_ast + ] + + def get_has_file_edges(self) -> Sequence[KnowledgeGraphEdge]: + return [ + kg_edge + for kg_edge in self._knowledge_graph_edges + if kg_edge.type == KnowledgeGraphEdgeType.has_file + ] + + def get_has_text_edges(self) -> Sequence[KnowledgeGraphEdge]: + return [ + kg_edge + for kg_edge in self._knowledge_graph_edges + if kg_edge.type == KnowledgeGraphEdgeType.has_text + ] + + def get_next_chunk_edges(self) -> Sequence[KnowledgeGraphEdge]: + return [ + kg_edge + for kg_edge in self._knowledge_graph_edges + if kg_edge.type == KnowledgeGraphEdgeType.next_chunk + ] + + def get_parent_of_edges(self) -> Sequence[KnowledgeGraphEdge]: + return [ + kg_edge + for kg_edge in self._knowledge_graph_edges + if kg_edge.type == KnowledgeGraphEdgeType.parent_of + ] + + def get_neo4j_file_nodes(self) -> Sequence[Neo4jFileNode]: + return [kg_node.to_neo4j_node() for kg_node in self.get_file_nodes()] + + def get_neo4j_ast_nodes(self) -> Sequence[Neo4jASTNode]: + return [kg_node.to_neo4j_node() for kg_node in self.get_ast_nodes()] + + def get_neo4j_text_nodes(self) -> Sequence[Neo4jTextNode]: + return [kg_node.to_neo4j_node() for kg_node in self.get_text_nodes()] + + def get_neo4j_has_ast_edges(self) -> Sequence[Neo4jHasASTEdge]: + return [kg_edge.to_neo4j_edge() for kg_edge in self.get_has_ast_edges()] + + def get_neo4j_has_file_edges(self) -> Sequence[Neo4jHasFileEdge]: + return [kg_edge.to_neo4j_edge() for kg_edge in self.get_has_file_edges()] + + def get_neo4j_has_text_edges(self) -> Sequence[Neo4jHasTextEdge]: + return [kg_edge.to_neo4j_edge() for kg_edge in self.get_has_text_edges()] + + def get_neo4j_next_chunk_edges(self) -> Sequence[Neo4jNextChunkEdge]: + return [kg_edge.to_neo4j_edge() for kg_edge in self.get_next_chunk_edges()] + + def get_neo4j_parent_of_edges(self) -> Sequence[Neo4jParentOfEdge]: + return [kg_edge.to_neo4j_edge() for kg_edge in self.get_parent_of_edges()] diff --git a/prometheus/neo4j/__init__.py b/prometheus/neo4j/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/prometheus/neo4j/handler.py b/prometheus/neo4j/handler.py new file mode 100644 index 00000000..f281b855 --- /dev/null +++ b/prometheus/neo4j/handler.py @@ -0,0 +1,152 @@ +from typing import Sequence +from neo4j import GraphDatabase, ManagedTransaction + +from prometheus.graph.graph_types import ( + Neo4jASTNode, + Neo4jFileNode, + Neo4jHasASTEdge, + Neo4jHasFileEdge, + Neo4jHasTextEdge, + Neo4jNextChunkEdge, + Neo4jParentOfEdge, + Neo4jTextNode, +) +from prometheus.graph.knowledge_graph import KnowledgeGraph + + +class Handler: + """The handler to writing the Knowledge graph to neo4j.""" + + def __init__( + self, uri: str, user: str, password: str, database: str, batch_size: int + ): + self.driver = GraphDatabase.driver(uri, auth=(user, password)) + self.database = database + self.batch_size = batch_size + + def _init_database(self, tx: ManagedTransaction): + queries = [ + "CREATE CONSTRAINT unique_file_node_id IF NOT EXISTS " + "FOR (n:FileNode) REQUIRE n.node_id IS UNIQUE", + "CREATE CONSTRAINT unique_ast_node_id IF NOT EXISTS " + "FOR (n:ASTNode) REQUIRE n.node_id IS UNIQUE", + "CREATE CONSTRAINT unique_text_node_id IF NOT EXISTS " + "FOR (n:TextNode) REQUIRE n.node_id IS UNIQUE", + ] + for query in queries: + tx.run(query) + + def _write_file_nodes( + self, tx: ManagedTransaction, file_nodes: Sequence[Neo4jFileNode] + ): + query = """ + UNWIND $file_nodes AS file_node + CREATE (a:FileNode {node_id: file_node.node_id, basename: file_node.basename, relative_path: file_node.relative_path}) + """ + for i in range(0, len(file_nodes), self.batch_size): + file_nodes_batch = file_nodes[i : i + self.batch_size] + tx.run(query, file_nodes=file_nodes_batch) + + def _write_ast_nodes(self, tx: ManagedTransaction, ast_nodes: Sequence[Neo4jASTNode]): + query = """ + UNWIND $ast_nodes AS ast_node + CREATE (a:ASTNode {node_id: ast_node.node_id, start_line: ast_node.start_line, end_line: ast_node.end_line, type: ast_node.type, text: ast_node.text}) + """ + for i in range(0, len(ast_nodes), self.batch_size): + ast_nodes_batch = ast_nodes[i : i + self.batch_size] + tx.run(query, ast_nodes=ast_nodes_batch) + + def _write_text_nodes( + self, tx: ManagedTransaction, text_nodes: Sequence[Neo4jTextNode] + ): + query = """ + UNWIND $text_nodes AS text_node + CREATE (a:TextNode {node_id: text_node.node_id, text: text_node.text, metadata: text_node.metadata}) + """ + for i in range(0, len(text_nodes), self.batch_size): + text_nodes_batch = text_nodes[i : i + self.batch_size] + tx.run(query, text_nodes=text_nodes_batch) + + def _write_has_file_edges( + self, tx: ManagedTransaction, has_file_edges: Sequence[Neo4jHasFileEdge] + ): + query = """ + UNWIND $edges AS edge + MATCH (source:FileNode), (target:FileNode) + WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id + CREATE (source) -[:HAS_FILE]-> (target) + """ + for i in range(0, len(has_file_edges), self.batch_size): + has_file_edges_batch = has_file_edges[i : i + self.batch_size] + tx.run(query, edges=has_file_edges_batch) + + def _write_has_ast_edges( + self, tx: ManagedTransaction, has_ast_edges: Sequence[Neo4jHasASTEdge] + ): + query = """ + UNWIND $edges AS edge + MATCH (source:FileNode), (target:ASTNode) + WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id + CREATE (source) -[:HAS_AST]-> (target) + """ + for i in range(0, len(has_ast_edges), self.batch_size): + has_ast_edges_batch = has_ast_edges[i : i + self.batch_size] + tx.run(query, edges=has_ast_edges_batch) + + def _write_has_text_edges( + self, tx: ManagedTransaction, has_text_edges: Sequence[Neo4jHasTextEdge] + ): + query = """ + UNWIND $edges AS edge + MATCH (source:FileNode), (target:TextNode) + WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id + CREATE (source) -[:HAS_TEXT]-> (target) + """ + for i in range(0, len(has_text_edges), self.batch_size): + has_text_edges_batch = has_text_edges[i : i + self.batch_size] + tx.run(query, edges=has_text_edges_batch) + + def _write_parent_of_edges( + self, tx: ManagedTransaction, parent_of_edges: Sequence[Neo4jParentOfEdge] + ): + query = """ + UNWIND $edges AS edge + MATCH (source:ASTNode), (target:ASTNode) + WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id + CREATE (source) -[:PARENT_OF]-> (target) + """ + for i in range(0, len(parent_of_edges), self.batch_size): + parent_of_edges_batch = parent_of_edges[i : i + self.batch_size] + tx.run(query, edges=parent_of_edges_batch) + + def _write_next_chunk_edges( + self, tx: ManagedTransaction, next_chunk_edges: Sequence[Neo4jNextChunkEdge] + ): + query = """ + UNWIND $edges AS edge + MATCH (source:TextNode), (target:TextNode) + WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id + CREATE (source) -[:NEXT_CHUNK]-> (target) + """ + for i in range(0, len(next_chunk_edges), self.batch_size): + next_chunk_edges_batch = next_chunk_edges[i : i + self.batch_size] + tx.run(query, edges=next_chunk_edges_batch) + + def write_knowledge_graph(self, kg: KnowledgeGraph): + with self.driver.session() as session: + session.execute_write(self._init_database) + + session.execute_write(self._write_file_nodes, kg.get_neo4j_file_nodes()) + session.execute_write(self._write_ast_nodes, kg.get_neo4j_ast_nodes()) + session.execute_write(self._write_text_nodes, kg.get_neo4j_text_nodes()) + + session.execute_write(self._write_has_ast_edges, kg.get_neo4j_has_ast_edges()) + session.execute_write(self._write_has_file_edges, kg.get_neo4j_has_file_edges()) + session.execute_write(self._write_has_text_edges, kg.get_neo4j_has_text_edges()) + session.execute_write( + self._write_next_chunk_edges, kg.get_neo4j_next_chunk_edges() + ) + session.execute_write(self._write_parent_of_edges, kg.get_neo4j_parent_of_edges()) + + def close(self): + self.driver.close() diff --git a/pyproject.toml b/pyproject.toml index e14f0ef2..64574a6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ test = [ "ruff==0.6.9", "pytest==8.3.3", "pytest-cov==5.0.0", + "testcontainers==4.8.2", ] [tool.ruff] diff --git a/tests/graph/test_knowledge_graph.py b/tests/graph/test_knowledge_graph.py index 0f010682..86db20d0 100644 --- a/tests/graph/test_knowledge_graph.py +++ b/tests/graph/test_knowledge_graph.py @@ -1,9 +1,3 @@ -from prometheus.graph.graph_types import ( - ASTNode, - FileNode, - KnowledgeGraphEdgeType, - TextNode, -) from prometheus.graph.knowledge_graph import KnowledgeGraph from tests.test_utils import test_project_paths @@ -18,58 +12,11 @@ def test_build_graph(): assert len(knowledge_graph._knowledge_graph_nodes) == 97 assert len(knowledge_graph._knowledge_graph_edges) == 99 - file_nodes = [ - kg_node - for kg_node in knowledge_graph._knowledge_graph_nodes - if isinstance(kg_node.node, FileNode) - ] - assert len(file_nodes) == 8 - - ast_nodes = [ - kg_node - for kg_node in knowledge_graph._knowledge_graph_nodes - if isinstance(kg_node.node, ASTNode) - ] - assert len(ast_nodes) == 85 - - text_nodes = [ - kg_node - for kg_node in knowledge_graph._knowledge_graph_nodes - if isinstance(kg_node.node, TextNode) - ] - assert len(text_nodes) == 4 - - parent_of_edges = [ - kg_edge - for kg_edge in knowledge_graph._knowledge_graph_edges - if kg_edge.type == KnowledgeGraphEdgeType.parent_of - ] - assert len(parent_of_edges) == 82 - - has_file_edges = [ - kg_edge - for kg_edge in knowledge_graph._knowledge_graph_edges - if kg_edge.type == KnowledgeGraphEdgeType.has_file - ] - assert len(has_file_edges) == 7 - - has_ast_edges = [ - kg_edge - for kg_edge in knowledge_graph._knowledge_graph_edges - if kg_edge.type == KnowledgeGraphEdgeType.has_ast - ] - assert len(has_ast_edges) == 3 - - has_text_edges = [ - kg_edge - for kg_edge in knowledge_graph._knowledge_graph_edges - if kg_edge.type == KnowledgeGraphEdgeType.has_text - ] - assert len(has_text_edges) == 4 - - next_chunk_edges = [ - kg_edge - for kg_edge in knowledge_graph._knowledge_graph_edges - if kg_edge.type == KnowledgeGraphEdgeType.next_chunk - ] - assert len(next_chunk_edges) == 3 + assert len(knowledge_graph.get_file_nodes()) == 8 + assert len(knowledge_graph.get_ast_nodes()) == 85 + assert len(knowledge_graph.get_text_nodes()) == 4 + assert len(knowledge_graph.get_parent_of_edges()) == 82 + assert len(knowledge_graph.get_has_file_edges()) == 7 + assert len(knowledge_graph.get_has_ast_edges()) == 3 + assert len(knowledge_graph.get_has_text_edges()) == 4 + assert len(knowledge_graph.get_next_chunk_edges()) == 3 diff --git a/tests/neo4j/__init__.py b/tests/neo4j/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/neo4j/test_handler.py b/tests/neo4j/test_handler.py new file mode 100644 index 00000000..8a0b3bea --- /dev/null +++ b/tests/neo4j/test_handler.py @@ -0,0 +1,169 @@ +from neo4j import GraphDatabase +import pytest + +from testcontainers.neo4j import Neo4jContainer + +from prometheus.graph.knowledge_graph import KnowledgeGraph +from prometheus.neo4j.handler import Handler +from tests.test_utils import test_project_paths + +NEO4J_IMAGE = "neo4j:5.20.0" +NEO4J_USERNAME = "neo4j" +NEO4J_PASSWORD = "password" + + +@pytest.fixture(scope="session") +def setup_neo4j_container(): + kg = KnowledgeGraph(test_project_paths.TEST_PROJECT_PATH, 1000) + with Neo4jContainer( + image=NEO4J_IMAGE, username=NEO4J_USERNAME, password=NEO4J_PASSWORD + ) as neo4j_container: + uri = neo4j_container.get_connection_url() + handler = Handler(uri, NEO4J_USERNAME, NEO4J_PASSWORD, "neo4j", 100) + handler.write_knowledge_graph(kg) + handler.close() + yield neo4j_container + + +def test_num_ast_nodes(setup_neo4j_container): + neo4j_container = setup_neo4j_container + uri = neo4j_container.get_connection_url() + + def _count_num_ast_nodes(tx): + result = tx.run(""" + MATCH (n:ASTNode) + RETURN n + """) + values = [record.values() for record in result] + return values + + with GraphDatabase.driver(uri, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver: + with driver.session() as session: + read_ast_nodes = session.execute_read(_count_num_ast_nodes) + assert len(read_ast_nodes) == 85 + + +def test_num_file_nodes(setup_neo4j_container): + neo4j_container = setup_neo4j_container + uri = neo4j_container.get_connection_url() + + def _count_num_file_nodes(tx): + result = tx.run(""" + MATCH (n:FileNode) + RETURN n + """) + values = [record.values() for record in result] + return values + + with GraphDatabase.driver(uri, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver: + with driver.session() as session: + read_file_nodes = session.execute_read(_count_num_file_nodes) + assert len(read_file_nodes) == 8 + + +def test_num_text_nodes(setup_neo4j_container): + neo4j_container = setup_neo4j_container + uri = neo4j_container.get_connection_url() + + def _count_num_text_nodes(tx): + result = tx.run(""" + MATCH (n:TextNode) + RETURN n + """) + values = [record.values() for record in result] + return values + + with GraphDatabase.driver(uri, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver: + with driver.session() as session: + read_text_nodes = session.execute_read(_count_num_text_nodes) + assert len(read_text_nodes) == 4 + + +def test_num_parent_of_edges(setup_neo4j_container): + neo4j_container = setup_neo4j_container + uri = neo4j_container.get_connection_url() + + def _count_num_parent_of_edges(tx): + result = tx.run(""" + MATCH () -[r:PARENT_OF]-> () + RETURN r + """) + values = [record.values() for record in result] + return values + + with GraphDatabase.driver(uri, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver: + with driver.session() as session: + read_parent_of_edges = session.execute_read(_count_num_parent_of_edges) + assert len(read_parent_of_edges) == 82 + + +def test_num_has_file_edges(setup_neo4j_container): + neo4j_container = setup_neo4j_container + uri = neo4j_container.get_connection_url() + + def _count_num_has_file_edges(tx): + result = tx.run(""" + MATCH () -[r:HAS_FILE]-> () + RETURN r + """) + values = [record.values() for record in result] + return values + + with GraphDatabase.driver(uri, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver: + with driver.session() as session: + read_has_file_edges = session.execute_read(_count_num_has_file_edges) + assert len(read_has_file_edges) == 7 + + +def test_num_has_ast_edges(setup_neo4j_container): + neo4j_container = setup_neo4j_container + uri = neo4j_container.get_connection_url() + + def _count_num_has_ast_edges(tx): + result = tx.run(""" + MATCH () -[r:HAS_AST]-> () + RETURN r + """) + values = [record.values() for record in result] + return values + + with GraphDatabase.driver(uri, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver: + with driver.session() as session: + read_has_ast_edges = session.execute_read(_count_num_has_ast_edges) + assert len(read_has_ast_edges) == 3 + + +def test_num_has_text_edges(setup_neo4j_container): + neo4j_container = setup_neo4j_container + uri = neo4j_container.get_connection_url() + + def _count_num_has_text_edges(tx): + result = tx.run(""" + MATCH () -[r:HAS_TEXT]-> () + RETURN r + """) + values = [record.values() for record in result] + return values + + with GraphDatabase.driver(uri, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver: + with driver.session() as session: + read_has_text_edges = session.execute_read(_count_num_has_text_edges) + assert len(read_has_text_edges) == 4 + + +def test_num_next_chunk_edges(setup_neo4j_container): + neo4j_container = setup_neo4j_container + uri = neo4j_container.get_connection_url() + + def _count_num_next_chunk_edges(tx): + result = tx.run(""" + MATCH () -[r:NEXT_CHUNK]-> () + RETURN r + """) + values = [record.values() for record in result] + return values + + with GraphDatabase.driver(uri, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) as driver: + with driver.session() as session: + read_next_chunk_edges = session.execute_read(_count_num_next_chunk_edges) + assert len(read_next_chunk_edges) == 3