Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions kuzu_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

import json
import logging
import os
import threading
from dataclasses import asdict, dataclass
Expand All @@ -19,6 +20,8 @@

from ast_java import ONTOLOGY_VERSION as _ONTOLOGY_VERSION

log = logging.getLogger(__name__)

__all__ = [
"KuzuGraph",
"resolve_kuzu_path",
Expand Down Expand Up @@ -178,6 +181,18 @@ def _scope_filters(
"lombok.",
)

# Relationship labels that `meta()` counts one by one: each label is
# interpolated into a `MATCH ()-[e:<label>]->() RETURN count(e)` query and
# becomes a key of the returned `edge_counts` mapping.
_EDGE_TYPES: tuple[str, ...] = (
"EXTENDS",
"IMPLEMENTS",
"INJECTS",
"DECLARES",
"CALLS",
"EXPOSES",
"DECLARES_CLIENT",
"HTTP_CALLS",
"ASYNC_CALLS",
)


def _type_part_fqn(sym_fqn: str) -> str:
    """Return the part of ``sym_fqn`` that precedes its first ``#``.

    A string without ``#`` is returned unchanged.
    """
    head, _sep, _member = sym_fqn.partition("#")
    return head
Expand Down Expand Up @@ -517,6 +532,20 @@ def meta(self) -> dict[str, Any]:
cross_service_resolution = (
str(raw_csr) if raw_csr not in (None, "") else None
)
edge_counts = {edge: 0 for edge in _EDGE_TYPES}
failed_edges: list[str] = []
for edge_type in _EDGE_TYPES:
try:
edge_rows = self._rows(
f"MATCH ()-[e:{edge_type}]->() RETURN count(e) AS n"
)
edge_counts[edge_type] = int(edge_rows[0].get("n") or 0) if edge_rows else 0
except Exception as exc:
failed_edges.append(edge_type)
log.warning("edge count query failed for %s: %s", edge_type, exc)
if len(failed_edges) == len(_EDGE_TYPES):
log.warning("edge count queries failed for all edge types; returning zeroed edge_counts")

return {
"ontology_version": int(row.get("ontology_version") or 0),
"built_at": int(row.get("built_at") or 0),
Expand All @@ -543,9 +572,33 @@ def meta(self) -> dict[str, Any]:
"pass3_skipped_cross_service": pass3_skipped_cross_service,
"pass4_exposes_suppressed_feign": pass4_exposes_suppressed_feign,
"cross_service_resolution": cross_service_resolution,
"edge_counts": edge_counts,
"db_path": self.db_path,
}

def edge_counts_for(self, node_id: str) -> dict[str, dict[str, int]]:
    """Count the relationships attached to one node, per label and direction.

    Args:
        node_id: Value matched against the node's ``id`` property.

    Returns:
        Mapping of relationship label to ``{"in": n, "out": m}``. Labels whose
        counts are zero in both directions are omitted.
    """
    rows = self._rows(
        "MATCH (n {id: $id})-[e]->() "
        "RETURN label(e) AS edge_type, 'out' AS direction, count(e) AS n "
        "UNION ALL "
        "MATCH (n {id: $id})<-[e]-() "
        "RETURN label(e) AS edge_type, 'in' AS direction, count(e) AS n",
        {"id": node_id},
    )
    counts: dict[str, dict[str, int]] = {}
    for row in rows:
        edge_type = str(row.get("edge_type") or "")
        direction = str(row.get("direction") or "")
        # Ignore rows that carry no label or an unexpected direction marker.
        if not edge_type or direction not in ("in", "out"):
            continue
        bucket = counts.setdefault(edge_type, {"in": 0, "out": 0})
        # Accumulate rather than assign, so a repeated (label, direction)
        # row adds to the tally instead of replacing it.
        bucket[direction] += int(row.get("n") or 0)
    return {
        edge_type: bucket
        for edge_type, bucket in counts.items()
        if bucket["in"] > 0 or bucket["out"] > 0
    }

def _scope_counts(self, column: str) -> dict[str, int]:
"""Generic helper: count resolved type symbols grouped by `column`.

Expand Down
53 changes: 51 additions & 2 deletions mcp_v2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import os
from pathlib import Path
import threading
Expand Down Expand Up @@ -143,7 +144,7 @@ def _row_to_search_hit(row: dict[str, Any]) -> SearchHit:
score = float(row.get("_rrf_score") or row.get("_score") or 0.0)
return SearchHit(
chunk_id=_chunk_id_from_row(row),
symbol_id=str(row.get("symbol_id")) if row.get("symbol_id") else None,
symbol_id=_chunk_to_symbol_id(row),
fqn=str(row.get("primary_type_fqn")) if row.get("primary_type_fqn") else None,
score=score,
snippet=str(row.get("text") or ""),
Expand All @@ -153,6 +154,25 @@ def _row_to_search_hit(row: dict[str, Any]) -> SearchHit:
)


def _chunk_to_symbol_id(chunk_row: dict[str, Any]) -> str | None:
    """Extract the symbol id a search chunk row refers to, if any.

    Lookup order:
      1. a truthy top-level ``symbol_id`` value;
      2. a truthy ``symbol_id`` inside ``metadata``, where ``metadata`` is
         either a dict or a JSON string that decodes to a dict.

    Args:
        chunk_row: Row describing one search chunk.

    Returns:
        The id converted with ``str``, or ``None`` when neither location
        yields a truthy value (including unparsable or non-object metadata).
    """
    symbol_id = chunk_row.get("symbol_id")
    if symbol_id:
        return str(symbol_id)

    meta = chunk_row.get("metadata")
    if isinstance(meta, str):
        try:
            meta = json.loads(meta)
        except (ValueError, RecursionError):
            # Only decoding failures are treated as "no metadata":
            # json.JSONDecodeError is a ValueError, and RecursionError covers
            # pathologically nested documents. Anything else is a real bug
            # and should surface instead of being swallowed.
            return None

    if isinstance(meta, dict):
        nested = meta.get("symbol_id")
        if nested:
            return str(nested)
    return None


def _symbol_where_from_filter(f: NodeFilter) -> tuple[str, dict[str, Any]]:
preds: list[str] = []
params: dict[str, Any] = {}
Expand Down Expand Up @@ -242,6 +262,32 @@ def _load_node_record(graph: KuzuGraph, node_id: str, kind: Literal["symbol", "r
return rows[0]


def _edge_summary_for_node(graph: Any, node_id: str) -> dict[str, dict[str, int]]:
    """Summarise a node's relationships as per-label in/out counts.

    When ``graph`` exposes ``edge_counts_for`` the call is delegated to it.
    Otherwise the counts are computed here by running the query through
    ``graph._rows``.

    Args:
        graph: Graph handle; must provide ``edge_counts_for`` or ``_rows``.
        node_id: Value matched against the node's ``id`` property.

    Returns:
        Mapping of relationship label to ``{"in": n, "out": m}``. On the
        fallback path, labels with zero edges in both directions are omitted.
    """
    if hasattr(graph, "edge_counts_for"):
        return graph.edge_counts_for(node_id)

    rows = graph._rows(  # noqa: SLF001
        "MATCH (n {id: $id})-[e]->() "
        "RETURN label(e) AS edge_type, 'out' AS direction, count(e) AS n "
        "UNION ALL "
        "MATCH (n {id: $id})<-[e]-() "
        "RETURN label(e) AS edge_type, 'in' AS direction, count(e) AS n",
        {"id": node_id},
    )
    counts: dict[str, dict[str, int]] = {}
    for row in rows:
        edge_type = str(row.get("edge_type") or "")
        direction = str(row.get("direction") or "")
        # Ignore rows that carry no label or an unexpected direction marker.
        if not edge_type or direction not in ("in", "out"):
            continue
        bucket = counts.setdefault(edge_type, {"in": 0, "out": 0})
        # Accumulate rather than assign, so a repeated (label, direction)
        # row adds to the tally instead of replacing it.
        bucket[direction] += int(row.get("n") or 0)
    return {
        edge_type: bucket
        for edge_type, bucket in counts.items()
        if bucket["in"] > 0 or bucket["out"] > 0
    }


def _node_matches_filter(kind: Literal["symbol", "route", "client"], row: dict[str, Any], f: NodeFilter | None) -> bool:
if f is None:
return True
Expand Down Expand Up @@ -386,9 +432,10 @@ def describe_v2(id: str, graph: KuzuGraph | None = None) -> DescribeOutput:
if row is None:
return DescribeOutput(success=False, message=f"No node found for `{id}`")
ref = _node_ref_from_row(kind, row)
edge_summary = _edge_summary_for_node(g, id)
return DescribeOutput(
success=True,
record=NodeRecord(id=ref.id, kind=kind, fqn=ref.fqn, data=row, edge_summary=None),
record=NodeRecord(id=ref.id, kind=kind, fqn=ref.fqn, data=row, edge_summary=edge_summary),
)
except ValueError as exc:
return DescribeOutput(success=False, message=str(exc))
Expand All @@ -399,6 +446,8 @@ def describe_v2(id: str, graph: KuzuGraph | None = None) -> DescribeOutput:
@validate_call(config={"arbitrary_types_allowed": True})
def neighbors_v2(
ids: str | list[str],
# Required fields are intentional: direct Python calls and MCP-bound calls
# share the same validation contract through @validate_call.
direction: Literal["in", "out"] = Field(...),
edge_types: list[str] = Field(...),
limit: int = 25,
Expand Down
2 changes: 2 additions & 0 deletions search_lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
"role",
"annotations_on_type",
"symbols",
"symbol_id",
"metadata",
"ontology_version",
"capabilities",
)
Expand Down
2 changes: 2 additions & 0 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ class GraphMetaOutput(BaseModel):
routes_resolved_pct: float = 0.0
routes_from_brownfield_pct: float = 0.0
routes_by_layer: dict[str, int] = Field(default_factory=dict)
edge_counts: dict[str, int] = Field(default_factory=dict)
http_calls_match_breakdown: dict[str, int] = Field(default_factory=dict)
async_calls_match_breakdown: dict[str, int] = Field(default_factory=dict)
cross_service_calls_total: int = 0
Expand Down Expand Up @@ -547,6 +548,7 @@ def _graph_meta_output() -> GraphMetaOutput:
routes_resolved_pct=float(meta.get("routes_resolved_pct") or 0.0),
routes_from_brownfield_pct=float(meta.get("routes_from_brownfield_pct") or 0.0),
routes_by_layer=routes_by_layer,
edge_counts={str(k): int(v) for k, v in (meta.get("edge_counts") or {}).items()},
http_calls_match_breakdown={str(k): int(v) for k, v in (meta.get("http_calls_match_breakdown") or {}).items()},
async_calls_match_breakdown={str(k): int(v) for k, v in (meta.get("async_calls_match_breakdown") or {}).items()},
cross_service_calls_total=int(meta.get("cross_service_calls_total") or 0),
Expand Down
187 changes: 187 additions & 0 deletions tests/test_mcp_v2_compose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from __future__ import annotations

from typing import Any

from mcp_v2 import describe_v2, neighbors_v2, search_v2
from server import _graph_meta_output


# Relationship labels this suite knows about. The describe tests iterate
# over them, and the meta test requires the reported `edge_counts` keys to
# equal exactly this set.
_EDGE_TYPES = (
"ASYNC_CALLS",
"CALLS",
"DECLARES",
"DECLARES_CLIENT",
"EXPOSES",
"EXTENDS",
"HTTP_CALLS",
"IMPLEMENTS",
"INJECTS",
)


def _controller_method_with_calls(kuzu_graph) -> tuple[str, str]:
    """Pick a controller-declared method or constructor that has a CALLS edge.

    Returns the ``(id, fqn)`` of the first match ordered by fqn. Asserts that
    the graph contains at least one such symbol.
    """
    query = (
        "MATCH (t:Symbol)-[:DECLARES]->(m:Symbol) "
        "WHERE t.role = 'CONTROLLER' AND m.kind IN ['method', 'constructor'] "
        "AND (EXISTS { MATCH (:Symbol)-[:CALLS]->(m) } OR EXISTS { MATCH (m)-[:CALLS]->(:Symbol) }) "
        "RETURN m.id AS id, m.fqn AS fqn "
        "ORDER BY m.fqn LIMIT 1"
    )
    matches = kuzu_graph._rows(query)  # noqa: SLF001
    assert matches
    first = matches[0]
    return str(first["id"]), str(first["fqn"])


def _method_with_incoming_calls(kuzu_graph) -> tuple[str, str]:
    """Pick a symbol that is the target of at least one CALLS edge.

    The id and fqn are fetched in a single query, and the result is ordered
    so the same symbol is chosen on every run (matching the other fixture
    helpers in this module, which also use ``ORDER BY ... LIMIT 1``).

    Returns:
        ``(id, fqn)`` of the chosen symbol, both converted with ``str``.
        Asserts that the graph contains at least one CALLS edge.
    """
    rows = kuzu_graph._rows(  # noqa: SLF001
        "MATCH (src:Symbol)-[:CALLS]->(dst:Symbol) "
        "RETURN dst.id AS id, dst.fqn AS fqn "
        "ORDER BY dst.id LIMIT 1"
    )
    assert rows
    first = rows[0]
    return str(first["id"]), str(first["fqn"])


def _route_with_handler(kuzu_graph) -> str:
    """Return the id of the first Route (ordered by id) that some Symbol EXPOSES.

    Asserts that the graph contains at least one such route.
    """
    query = "MATCH (:Symbol)-[:EXPOSES]->(r:Route) RETURN r.id AS id ORDER BY r.id LIMIT 1"
    exposed = kuzu_graph._rows(query)  # noqa: SLF001
    assert exposed
    first = exposed[0]
    return str(first["id"])


def test_describe_edge_summary_for_controller(kuzu_graph) -> None:
    """The CALLS entry of describe_v2's edge summary must agree with the
    number of callers and callees the graph reports for the same method."""
    node_id, fqn = _controller_method_with_calls(kuzu_graph)

    described = describe_v2(node_id, graph=kuzu_graph)
    assert described.success is True
    assert described.record is not None
    summary = described.record.edge_summary
    assert summary is not None

    call_counts = summary.get("CALLS", {"in": 0, "out": 0})
    incoming = kuzu_graph.find_callers(fqn, limit=1000, exclude_external=False)
    outgoing = kuzu_graph.find_callees(fqn, limit=1000, exclude_external=False)
    assert int(call_counts.get("in", 0)) == len(incoming)
    assert int(call_counts.get("out", 0)) == len(outgoing)


def test_describe_edge_summary_omits_zero_count_types(kuzu_graph) -> None:
    """Every known edge type missing from the summary must genuinely have
    zero edges touching the node, in either direction."""
    node_id, _ = _controller_method_with_calls(kuzu_graph)
    described = describe_v2(node_id, graph=kuzu_graph)
    assert described.success is True
    assert described.record is not None
    summary = described.record.edge_summary
    assert summary is not None

    absent = [edge_type for edge_type in _EDGE_TYPES if edge_type not in summary]
    for edge_type in absent:
        count_rows = kuzu_graph._rows(  # noqa: SLF001
            f"MATCH (n {{id: $id}})-[e:{edge_type}]->() RETURN count(e) AS n "
            f"UNION ALL "
            f"MATCH (n {{id: $id}})<-[e:{edge_type}]-() RETURN count(e) AS n",
            {"id": node_id},
        )
        total = 0
        for count_row in count_rows:
            total += int(count_row.get("n") or 0)
        assert total == 0


def test_describe_edge_summary_for_route(kuzu_graph) -> None:
    """Describing an exposed route yields a route record whose edge summary
    reports at least one incoming EXPOSES edge."""
    route_id = _route_with_handler(kuzu_graph)

    described = describe_v2(route_id, graph=kuzu_graph)
    assert described.success is True
    record = described.record
    assert record is not None
    assert record.kind == "route"
    assert record.edge_summary is not None

    exposes_counts = record.edge_summary.get("EXPOSES", {"in": 0, "out": 0})
    incoming = int(exposes_counts.get("in", 0))
    assert incoming >= 1


def test_search_populates_symbol_id_when_chunk_rooted_in_symbol(monkeypatch, kuzu_graph) -> None:
    """Search hits rooted in a symbol must carry that symbol's id.

    The stubbed search rows cover each place a symbol id can live: a
    top-level ``symbol_id`` column, a ``metadata`` dict, and a ``metadata``
    JSON string. A fourth row has no symbol information at all.
    """
    rows: list[dict[str, Any]] = [
        {
            "filename": "A.java",
            "start": {"byte_offset": 0},
            "end": {"byte_offset": 10},
            "symbol_id": "sym:one",
            "primary_type_fqn": "com.example.A",
            "_rrf_score": 0.9,
            "text": "A",
        },
        {
            "filename": "B.java",
            "start": {"byte_offset": 10},
            "end": {"byte_offset": 20},
            "metadata": {"symbol_id": "sym:two"},
            "primary_type_fqn": "com.example.B",
            "_rrf_score": 0.8,
            "text": "B",
        },
        {
            "filename": "C.java",
            "start": {"byte_offset": 30},
            "end": {"byte_offset": 40},
            "metadata": '{"symbol_id":"sym:three"}',
            "primary_type_fqn": "com.example.C",
            "_rrf_score": 0.75,
            "text": "C",
        },
        {
            "filename": "raw.txt",
            "start": {"byte_offset": 20},
            "end": {"byte_offset": 30},
            "_rrf_score": 0.7,
            "text": "raw",
        },
    ]
    expected_by_fqn = {
        "com.example.A": "sym:one",
        "com.example.B": "sym:two",
        "com.example.C": "sym:three",
    }
    monkeypatch.setattr("mcp_v2.run_search", lambda *args, **kwargs: rows)
    out = search_v2("query", graph=kuzu_graph)
    assert out.success is True
    rooted = [hit for hit in out.results if hit.fqn is not None]
    assert rooted
    assert all(hit.symbol_id is not None for hit in rooted)
    # A bare "is not None" check would pass even if ids were attached to the
    # wrong hit, so pin each stubbed fqn to the id it was given.
    for hit in rooted:
        if hit.fqn in expected_by_fqn:
            assert hit.symbol_id == expected_by_fqn[hit.fqn]


def test_meta_returns_per_edge_type_counts() -> None:
    """The graph meta output reports a non-negative count for exactly the
    edge types listed in ``_EDGE_TYPES``."""
    meta_out = _graph_meta_output()
    assert meta_out.success is True
    reported = meta_out.edge_counts
    assert set(reported.keys()) == set(_EDGE_TYPES)
    for count in reported.values():
        assert int(count) >= 0


def test_search_describe_neighbors_chain_end_to_end(kuzu_graph, monkeypatch) -> None:
    """Chain search -> describe -> neighbors on a symbol with incoming CALLS.

    The search backend is stubbed to return a single hit built from the
    chosen symbol's own properties, so the chain exercises id propagation
    from a search hit into the two graph lookups.
    """
    node_id, _ = _method_with_incoming_calls(kuzu_graph)
    symbol_rows = kuzu_graph._rows(  # noqa: SLF001
        "MATCH (m:Symbol {id: $id}) RETURN m.fqn AS fqn, m.role AS role, m.module AS module, "
        "m.microservice AS microservice, m.filename AS filename",
        {"id": node_id},
    )
    assert symbol_rows
    symbol = symbol_rows[0]

    def fake_run_search(*args, **kwargs):
        # Stand-in for mcp_v2.run_search: one hit describing the chosen symbol.
        hit = {
            "filename": str(symbol.get("filename") or "x.java"),
            "start": {"byte_offset": 0},
            "end": {"byte_offset": 1},
            "symbol_id": node_id,
            "primary_type_fqn": str(symbol.get("fqn") or ""),
            "role": str(symbol.get("role") or ""),
            "module": str(symbol.get("module") or ""),
            "microservice": str(symbol.get("microservice") or ""),
            "_rrf_score": 0.95,
            "text": "match",
        }
        return [hit]

    monkeypatch.setattr("mcp_v2.run_search", fake_run_search)

    search_out = search_v2("assign", graph=kuzu_graph)
    assert search_out.success is True
    assert search_out.results
    top_symbol_id = search_out.results[0].symbol_id
    assert top_symbol_id is not None

    describe_out = describe_v2(top_symbol_id, graph=kuzu_graph)
    assert describe_out.success is True
    assert describe_out.record is not None

    neighbors_out = neighbors_v2(
        top_symbol_id,
        direction="in",
        edge_types=["CALLS"],
        graph=kuzu_graph,
    )
    assert neighbors_out.success is True
    assert neighbors_out.results
Loading