Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions prometheus/graph/knowledge_graph.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
from collections import deque
import logging
from pathlib import Path
from typing import Sequence

from prometheus.graph.file_graph_builder import FileGraphBuilder
from prometheus.graph.graph_types import (
ASTNode,
FileNode,
KnowledgeGraphEdge,
KnowledgeGraphEdgeType,
KnowledgeGraphNode,
Neo4jASTNode,
Neo4jFileNode,
Neo4jHasASTEdge,
Neo4jHasFileEdge,
Neo4jHasTextEdge,
Neo4jNextChunkEdge,
Neo4jParentOfEdge,
Neo4jTextNode,
TextNode,
)


Expand Down Expand Up @@ -64,3 +75,83 @@ def _build_graph(self, root_dir: Path):
continue

self._logger.info(f"Skip parsing {file} because it is not supported.")

def get_file_nodes(self) -> Sequence[KnowledgeGraphNode]:
    """Return the knowledge graph nodes that wrap a ``FileNode``.

    The result keeps the iteration order of ``self._knowledge_graph_nodes``.
    """
    file_nodes = []
    for candidate in self._knowledge_graph_nodes:
        if isinstance(candidate.node, FileNode):
            file_nodes.append(candidate)
    return file_nodes

def get_ast_nodes(self) -> Sequence[KnowledgeGraphNode]:
    """Return the knowledge graph nodes that wrap an ``ASTNode``.

    The result keeps the iteration order of ``self._knowledge_graph_nodes``.
    """
    return list(
        filter(
            lambda candidate: isinstance(candidate.node, ASTNode),
            self._knowledge_graph_nodes,
        )
    )

def get_text_nodes(self) -> Sequence[KnowledgeGraphNode]:
    """Return the knowledge graph nodes that wrap a ``TextNode``.

    The result keeps the iteration order of ``self._knowledge_graph_nodes``.
    """
    text_nodes = []
    for candidate in self._knowledge_graph_nodes:
        if not isinstance(candidate.node, TextNode):
            continue
        text_nodes.append(candidate)
    return text_nodes

def get_has_ast_edges(self) -> Sequence[KnowledgeGraphEdge]:
    """Return every edge whose type is ``KnowledgeGraphEdgeType.has_ast``.

    The result keeps the iteration order of ``self._knowledge_graph_edges``.
    """
    matching_edges = []
    for edge in self._knowledge_graph_edges:
        if edge.type == KnowledgeGraphEdgeType.has_ast:
            matching_edges.append(edge)
    return matching_edges

def get_has_file_edges(self) -> Sequence[KnowledgeGraphEdge]:
    """Return every edge whose type is ``KnowledgeGraphEdgeType.has_file``.

    The result keeps the iteration order of ``self._knowledge_graph_edges``.
    """
    return list(
        filter(
            lambda edge: edge.type == KnowledgeGraphEdgeType.has_file,
            self._knowledge_graph_edges,
        )
    )

def get_has_text_edges(self) -> Sequence[KnowledgeGraphEdge]:
    """Return every edge whose type is ``KnowledgeGraphEdgeType.has_text``.

    The result keeps the iteration order of ``self._knowledge_graph_edges``.
    """
    matching_edges = []
    for edge in self._knowledge_graph_edges:
        if edge.type == KnowledgeGraphEdgeType.has_text:
            matching_edges.append(edge)
    return matching_edges

def get_next_chunk_edges(self) -> Sequence[KnowledgeGraphEdge]:
    """Return every edge whose type is ``KnowledgeGraphEdgeType.next_chunk``.

    The result keeps the iteration order of ``self._knowledge_graph_edges``.
    """
    return list(
        filter(
            lambda edge: edge.type == KnowledgeGraphEdgeType.next_chunk,
            self._knowledge_graph_edges,
        )
    )

def get_parent_of_edges(self) -> Sequence[KnowledgeGraphEdge]:
    """Return every edge whose type is ``KnowledgeGraphEdgeType.parent_of``.

    The result keeps the iteration order of ``self._knowledge_graph_edges``.
    """
    matching_edges = []
    for edge in self._knowledge_graph_edges:
        if edge.type != KnowledgeGraphEdgeType.parent_of:
            continue
        matching_edges.append(edge)
    return matching_edges

def get_neo4j_file_nodes(self) -> Sequence[Neo4jFileNode]:
    """Return ``to_neo4j_node()`` of each node from ``get_file_nodes()``."""
    neo4j_nodes = []
    for file_node in self.get_file_nodes():
        neo4j_nodes.append(file_node.to_neo4j_node())
    return neo4j_nodes

def get_neo4j_ast_nodes(self) -> Sequence[Neo4jASTNode]:
    """Return ``to_neo4j_node()`` of each node from ``get_ast_nodes()``."""
    return list(map(lambda ast_node: ast_node.to_neo4j_node(), self.get_ast_nodes()))

def get_neo4j_text_nodes(self) -> Sequence[Neo4jTextNode]:
    """Return ``to_neo4j_node()`` of each node from ``get_text_nodes()``."""
    neo4j_nodes = []
    for text_node in self.get_text_nodes():
        neo4j_nodes.append(text_node.to_neo4j_node())
    return neo4j_nodes

def get_neo4j_has_ast_edges(self) -> Sequence[Neo4jHasASTEdge]:
    """Return ``to_neo4j_edge()`` of each edge from ``get_has_ast_edges()``."""
    return list(map(lambda edge: edge.to_neo4j_edge(), self.get_has_ast_edges()))

def get_neo4j_has_file_edges(self) -> Sequence[Neo4jHasFileEdge]:
    """Return ``to_neo4j_edge()`` of each edge from ``get_has_file_edges()``."""
    neo4j_edges = []
    for edge in self.get_has_file_edges():
        neo4j_edges.append(edge.to_neo4j_edge())
    return neo4j_edges

def get_neo4j_has_text_edges(self) -> Sequence[Neo4jHasTextEdge]:
    """Return ``to_neo4j_edge()`` of each edge from ``get_has_text_edges()``."""
    return list(map(lambda edge: edge.to_neo4j_edge(), self.get_has_text_edges()))

def get_neo4j_next_chunk_edges(self) -> Sequence[Neo4jNextChunkEdge]:
    """Return ``to_neo4j_edge()`` of each edge from ``get_next_chunk_edges()``."""
    neo4j_edges = []
    for edge in self.get_next_chunk_edges():
        neo4j_edges.append(edge.to_neo4j_edge())
    return neo4j_edges

def get_neo4j_parent_of_edges(self) -> Sequence[Neo4jParentOfEdge]:
    """Return ``to_neo4j_edge()`` of each edge from ``get_parent_of_edges()``."""
    return list(map(lambda edge: edge.to_neo4j_edge(), self.get_parent_of_edges()))
Empty file added prometheus/neo4j/__init__.py
Empty file.
152 changes: 152 additions & 0 deletions prometheus/neo4j/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from typing import Sequence
from neo4j import GraphDatabase, ManagedTransaction

from prometheus.graph.graph_types import (
Neo4jASTNode,
Neo4jFileNode,
Neo4jHasASTEdge,
Neo4jHasFileEdge,
Neo4jHasTextEdge,
Neo4jNextChunkEdge,
Neo4jParentOfEdge,
Neo4jTextNode,
)
from prometheus.graph.knowledge_graph import KnowledgeGraph


class Handler:
    """Write a ``KnowledgeGraph`` to a Neo4j database.

    Nodes and edges are sent with ``UNWIND`` queries, at most ``batch_size``
    rows per query. The handler can be used as a context manager, in which
    case the driver is closed when the ``with`` block exits.
    """

    def __init__(
        self, uri: str, user: str, password: str, database: str, batch_size: int
    ):
        """Create the Neo4j driver and remember the write settings.

        Args:
            uri: Connection URI handed to ``GraphDatabase.driver``.
            user: User name used for authentication.
            password: Password used for authentication.
            database: Name of the database that sessions are opened against.
            batch_size: Maximum number of rows sent per ``UNWIND`` query.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        # Validate before opening the driver: a zero step makes range() raise
        # much later, and a negative step would silently write nothing.
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        self.driver = GraphDatabase.driver(uri, auth=(user, password))
        self.database = database
        self.batch_size = batch_size

    def __enter__(self):
        """Return the handler itself so it can be bound by ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the driver when leaving the ``with`` block."""
        self.close()

    @staticmethod
    def _split_batches(rows: Sequence, batch_size: int) -> Sequence[Sequence]:
        """Slice ``rows`` into consecutive chunks of at most ``batch_size`` items."""
        return [
            rows[start : start + batch_size]
            for start in range(0, len(rows), batch_size)
        ]

    def _run_batched(
        self, tx: ManagedTransaction, query: str, parameter_name: str, rows: Sequence
    ):
        """Run ``query`` once per batch, binding the batch to ``$parameter_name``."""
        for batch in self._split_batches(rows, self.batch_size):
            tx.run(query, **{parameter_name: batch})

    def _init_database(self, tx: ManagedTransaction):
        """Create the ``node_id`` uniqueness constraints if they do not exist."""
        queries = [
            "CREATE CONSTRAINT unique_file_node_id IF NOT EXISTS "
            "FOR (n:FileNode) REQUIRE n.node_id IS UNIQUE",
            "CREATE CONSTRAINT unique_ast_node_id IF NOT EXISTS "
            "FOR (n:ASTNode) REQUIRE n.node_id IS UNIQUE",
            "CREATE CONSTRAINT unique_text_node_id IF NOT EXISTS "
            "FOR (n:TextNode) REQUIRE n.node_id IS UNIQUE",
        ]
        for query in queries:
            tx.run(query)

    def _write_file_nodes(
        self, tx: ManagedTransaction, file_nodes: Sequence[Neo4jFileNode]
    ):
        """Create one ``FileNode`` per entry of ``file_nodes``, in batches."""
        query = """
            UNWIND $file_nodes AS file_node
            CREATE (a:FileNode {node_id: file_node.node_id, basename: file_node.basename, relative_path: file_node.relative_path})
        """
        self._run_batched(tx, query, "file_nodes", file_nodes)

    def _write_ast_nodes(self, tx: ManagedTransaction, ast_nodes: Sequence[Neo4jASTNode]):
        """Create one ``ASTNode`` per entry of ``ast_nodes``, in batches."""
        query = """
            UNWIND $ast_nodes AS ast_node
            CREATE (a:ASTNode {node_id: ast_node.node_id, start_line: ast_node.start_line, end_line: ast_node.end_line, type: ast_node.type, text: ast_node.text})
        """
        self._run_batched(tx, query, "ast_nodes", ast_nodes)

    def _write_text_nodes(
        self, tx: ManagedTransaction, text_nodes: Sequence[Neo4jTextNode]
    ):
        """Create one ``TextNode`` per entry of ``text_nodes``, in batches."""
        query = """
            UNWIND $text_nodes AS text_node
            CREATE (a:TextNode {node_id: text_node.node_id, text: text_node.text, metadata: text_node.metadata})
        """
        self._run_batched(tx, query, "text_nodes", text_nodes)

    def _write_has_file_edges(
        self, tx: ManagedTransaction, has_file_edges: Sequence[Neo4jHasFileEdge]
    ):
        """Create ``HAS_FILE`` relationships between matching ``FileNode`` pairs."""
        query = """
            UNWIND $edges AS edge
            MATCH (source:FileNode), (target:FileNode)
            WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id
            CREATE (source) -[:HAS_FILE]-> (target)
        """
        self._run_batched(tx, query, "edges", has_file_edges)

    def _write_has_ast_edges(
        self, tx: ManagedTransaction, has_ast_edges: Sequence[Neo4jHasASTEdge]
    ):
        """Create ``HAS_AST`` relationships from ``FileNode`` to ``ASTNode``."""
        query = """
            UNWIND $edges AS edge
            MATCH (source:FileNode), (target:ASTNode)
            WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id
            CREATE (source) -[:HAS_AST]-> (target)
        """
        self._run_batched(tx, query, "edges", has_ast_edges)

    def _write_has_text_edges(
        self, tx: ManagedTransaction, has_text_edges: Sequence[Neo4jHasTextEdge]
    ):
        """Create ``HAS_TEXT`` relationships from ``FileNode`` to ``TextNode``."""
        query = """
            UNWIND $edges AS edge
            MATCH (source:FileNode), (target:TextNode)
            WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id
            CREATE (source) -[:HAS_TEXT]-> (target)
        """
        self._run_batched(tx, query, "edges", has_text_edges)

    def _write_parent_of_edges(
        self, tx: ManagedTransaction, parent_of_edges: Sequence[Neo4jParentOfEdge]
    ):
        """Create ``PARENT_OF`` relationships between matching ``ASTNode`` pairs."""
        query = """
            UNWIND $edges AS edge
            MATCH (source:ASTNode), (target:ASTNode)
            WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id
            CREATE (source) -[:PARENT_OF]-> (target)
        """
        self._run_batched(tx, query, "edges", parent_of_edges)

    def _write_next_chunk_edges(
        self, tx: ManagedTransaction, next_chunk_edges: Sequence[Neo4jNextChunkEdge]
    ):
        """Create ``NEXT_CHUNK`` relationships between matching ``TextNode`` pairs."""
        query = """
            UNWIND $edges AS edge
            MATCH (source:TextNode), (target:TextNode)
            WHERE source.node_id = edge.source.node_id AND target.node_id = edge.target.node_id
            CREATE (source) -[:NEXT_CHUNK]-> (target)
        """
        self._run_batched(tx, query, "edges", next_chunk_edges)

    def write_knowledge_graph(self, kg: KnowledgeGraph):
        """Write all nodes and edges of ``kg`` to the configured database.

        Each node kind and each edge kind is written in its own write
        transaction, after the uniqueness constraints have been ensured.

        Args:
            kg: The knowledge graph whose Neo4j nodes and edges are written.
        """
        # Open the session against the configured database; without the
        # argument the driver would use the server's default database.
        with self.driver.session(database=self.database) as session:
            session.execute_write(self._init_database)

            # Nodes go first: the edge queries MATCH their endpoints by
            # node_id and create nothing for endpoints that are missing.
            session.execute_write(self._write_file_nodes, kg.get_neo4j_file_nodes())
            session.execute_write(self._write_ast_nodes, kg.get_neo4j_ast_nodes())
            session.execute_write(self._write_text_nodes, kg.get_neo4j_text_nodes())

            session.execute_write(self._write_has_ast_edges, kg.get_neo4j_has_ast_edges())
            session.execute_write(self._write_has_file_edges, kg.get_neo4j_has_file_edges())
            session.execute_write(self._write_has_text_edges, kg.get_neo4j_has_text_edges())
            session.execute_write(
                self._write_next_chunk_edges, kg.get_neo4j_next_chunk_edges()
            )
            session.execute_write(self._write_parent_of_edges, kg.get_neo4j_parent_of_edges())

    def close(self):
        """Close the underlying Neo4j driver."""
        self.driver.close()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ test = [
"ruff==0.6.9",
"pytest==8.3.3",
"pytest-cov==5.0.0",
"testcontainers==4.8.2",
]

[tool.ruff]
Expand Down
69 changes: 8 additions & 61 deletions tests/graph/test_knowledge_graph.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
from prometheus.graph.graph_types import (
ASTNode,
FileNode,
KnowledgeGraphEdgeType,
TextNode,
)
from prometheus.graph.knowledge_graph import KnowledgeGraph
from tests.test_utils import test_project_paths

Expand All @@ -18,58 +12,11 @@ def test_build_graph():
assert len(knowledge_graph._knowledge_graph_nodes) == 97
assert len(knowledge_graph._knowledge_graph_edges) == 99

file_nodes = [
kg_node
for kg_node in knowledge_graph._knowledge_graph_nodes
if isinstance(kg_node.node, FileNode)
]
assert len(file_nodes) == 8

ast_nodes = [
kg_node
for kg_node in knowledge_graph._knowledge_graph_nodes
if isinstance(kg_node.node, ASTNode)
]
assert len(ast_nodes) == 85

text_nodes = [
kg_node
for kg_node in knowledge_graph._knowledge_graph_nodes
if isinstance(kg_node.node, TextNode)
]
assert len(text_nodes) == 4

parent_of_edges = [
kg_edge
for kg_edge in knowledge_graph._knowledge_graph_edges
if kg_edge.type == KnowledgeGraphEdgeType.parent_of
]
assert len(parent_of_edges) == 82

has_file_edges = [
kg_edge
for kg_edge in knowledge_graph._knowledge_graph_edges
if kg_edge.type == KnowledgeGraphEdgeType.has_file
]
assert len(has_file_edges) == 7

has_ast_edges = [
kg_edge
for kg_edge in knowledge_graph._knowledge_graph_edges
if kg_edge.type == KnowledgeGraphEdgeType.has_ast
]
assert len(has_ast_edges) == 3

has_text_edges = [
kg_edge
for kg_edge in knowledge_graph._knowledge_graph_edges
if kg_edge.type == KnowledgeGraphEdgeType.has_text
]
assert len(has_text_edges) == 4

next_chunk_edges = [
kg_edge
for kg_edge in knowledge_graph._knowledge_graph_edges
if kg_edge.type == KnowledgeGraphEdgeType.next_chunk
]
assert len(next_chunk_edges) == 3
assert len(knowledge_graph.get_file_nodes()) == 8
assert len(knowledge_graph.get_ast_nodes()) == 85
assert len(knowledge_graph.get_text_nodes()) == 4
assert len(knowledge_graph.get_parent_of_edges()) == 82
assert len(knowledge_graph.get_has_file_edges()) == 7
assert len(knowledge_graph.get_has_ast_edges()) == 3
assert len(knowledge_graph.get_has_text_edges()) == 4
assert len(knowledge_graph.get_next_chunk_edges()) == 3
Empty file added tests/neo4j/__init__.py
Empty file.
Loading
Loading