diff --git a/kuzu_queries.py b/kuzu_queries.py index 6eba924..4b0dde5 100644 --- a/kuzu_queries.py +++ b/kuzu_queries.py @@ -9,6 +9,7 @@ from __future__ import annotations import json +import logging import os import threading from dataclasses import asdict, dataclass @@ -19,6 +20,8 @@ from ast_java import ONTOLOGY_VERSION as _ONTOLOGY_VERSION +log = logging.getLogger(__name__) + __all__ = [ "KuzuGraph", "resolve_kuzu_path", @@ -178,6 +181,18 @@ def _scope_filters( "lombok.", ) +_EDGE_TYPES: tuple[str, ...] = ( + "EXTENDS", + "IMPLEMENTS", + "INJECTS", + "DECLARES", + "CALLS", + "EXPOSES", + "DECLARES_CLIENT", + "HTTP_CALLS", + "ASYNC_CALLS", +) + def _type_part_fqn(sym_fqn: str) -> str: return sym_fqn.split("#", 1)[0] @@ -517,6 +532,20 @@ def meta(self) -> dict[str, Any]: cross_service_resolution = ( str(raw_csr) if raw_csr not in (None, "") else None ) + edge_counts = {edge: 0 for edge in _EDGE_TYPES} + failed_edges: list[str] = [] + for edge_type in _EDGE_TYPES: + try: + edge_rows = self._rows( + f"MATCH ()-[e:{edge_type}]->() RETURN count(e) AS n" + ) + edge_counts[edge_type] = int(edge_rows[0].get("n") or 0) if edge_rows else 0 + except Exception as exc: + failed_edges.append(edge_type) + log.warning("edge count query failed for %s: %s", edge_type, exc) + if len(failed_edges) == len(_EDGE_TYPES): + log.warning("edge count queries failed for all edge types; returning zeroed edge_counts") + return { "ontology_version": int(row.get("ontology_version") or 0), "built_at": int(row.get("built_at") or 0), @@ -543,9 +572,33 @@ def meta(self) -> dict[str, Any]: "pass3_skipped_cross_service": pass3_skipped_cross_service, "pass4_exposes_suppressed_feign": pass4_exposes_suppressed_feign, "cross_service_resolution": cross_service_resolution, + "edge_counts": edge_counts, "db_path": self.db_path, } + def edge_counts_for(self, node_id: str) -> dict[str, dict[str, int]]: + rows = self._rows( + "MATCH (n {id: $id})-[e]->() " + "RETURN label(e) AS edge_type, 'out' AS direction, count(e) AS n " + "UNION ALL " + "MATCH (n {id: $id})<-[e]-() " + "RETURN label(e) AS edge_type, 'in' AS direction, count(e) AS n", + {"id": node_id}, + ) + out: dict[str, dict[str, int]] = {} + for row in rows: + edge_type = str(row.get("edge_type") or "") + direction = str(row.get("direction") or "") + if edge_type == "" or direction not in ("in", "out"): + continue + out.setdefault(edge_type, {"in": 0, "out": 0}) + out[edge_type][direction] = int(row.get("n") or 0) + return { + edge_type: dirs + for edge_type, dirs in out.items() + if int(dirs.get("in", 0)) > 0 or int(dirs.get("out", 0)) > 0 + } + def _scope_counts(self, column: str) -> dict[str, int]: """Generic helper: count resolved type symbols grouped by `column`. diff --git a/mcp_v2.py b/mcp_v2.py index c2c3c14..f7e18f2 100644 --- a/mcp_v2.py +++ b/mcp_v2.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import os from pathlib import Path import threading @@ -143,7 +144,7 @@ def _row_to_search_hit(row: dict[str, Any]) -> SearchHit: score = float(row.get("_rrf_score") or row.get("_score") or 0.0) return SearchHit( chunk_id=_chunk_id_from_row(row), - symbol_id=str(row.get("symbol_id")) if row.get("symbol_id") else None, + symbol_id=_chunk_to_symbol_id(row), fqn=str(row.get("primary_type_fqn")) if row.get("primary_type_fqn") else None, score=score, snippet=str(row.get("text") or ""), @@ -153,6 +154,25 @@ def _row_to_search_hit(row: dict[str, Any]) -> SearchHit: ) +def _chunk_to_symbol_id(chunk_row: dict[str, Any]) -> str | None: + symbol_id = chunk_row.get("symbol_id") + if symbol_id: + return str(symbol_id) + meta = chunk_row.get("metadata") + if isinstance(meta, str): + try: + parsed = json.loads(meta) + if isinstance(parsed, dict): + meta = parsed + except Exception: + meta = None + if isinstance(meta, dict): + nested = meta.get("symbol_id") + if nested: + return str(nested) + return None + + def _symbol_where_from_filter(f: NodeFilter) -> tuple[str, dict[str, Any]]: preds: list[str] = [] params: dict[str, Any] = {} @@ -242,6 +262,32 @@ def _load_node_record(graph: KuzuGraph, node_id: str, kind: Literal["symbol", "r return rows[0] +def _edge_summary_for_node(graph: Any, node_id: str) -> dict[str, dict[str, int]]: + if hasattr(graph, "edge_counts_for"): + return graph.edge_counts_for(node_id) + rows = graph._rows( # noqa: SLF001 + "MATCH (n {id: $id})-[e]->() " + "RETURN label(e) AS edge_type, 'out' AS direction, count(e) AS n " + "UNION ALL " + "MATCH (n {id: $id})<-[e]-() " + "RETURN label(e) AS edge_type, 'in' AS direction, count(e) AS n", + {"id": node_id}, + ) + out: dict[str, dict[str, int]] = {} + for row in rows: + edge_type = str(row.get("edge_type") or "") + direction = str(row.get("direction") or "") + if edge_type == "" or direction not in ("in", "out"): + continue + out.setdefault(edge_type, {"in": 0, "out": 0}) + out[edge_type][direction] = int(row.get("n") or 0) + return { + edge_type: dirs + for edge_type, dirs in out.items() + if int(dirs.get("in", 0)) > 0 or int(dirs.get("out", 0)) > 0 + } + + def _node_matches_filter(kind: Literal["symbol", "route", "client"], row: dict[str, Any], f: NodeFilter | None) -> bool: if f is None: return True @@ -386,9 +432,10 @@ def describe_v2(id: str, graph: KuzuGraph | None = None) -> DescribeOutput: if row is None: return DescribeOutput(success=False, message=f"No node found for `{id}`") ref = _node_ref_from_row(kind, row) + edge_summary = _edge_summary_for_node(g, id) return DescribeOutput( success=True, - record=NodeRecord(id=ref.id, kind=kind, fqn=ref.fqn, data=row, edge_summary=None), + record=NodeRecord(id=ref.id, kind=kind, fqn=ref.fqn, data=row, edge_summary=edge_summary), ) except ValueError as exc: return DescribeOutput(success=False, message=str(exc)) @@ -399,6 +446,8 @@ def describe_v2(id: str, graph: KuzuGraph | None = None) -> DescribeOutput: @validate_call(config={"arbitrary_types_allowed": True}) def neighbors_v2( ids: str | list[str], + # Required fields are intentional: direct Python calls and MCP-bound calls + # share the same validation contract through @validate_call. direction: Literal["in", "out"] = Field(...), edge_types: list[str] = Field(...), limit: int = 25, diff --git a/search_lancedb.py b/search_lancedb.py index f8ec55e..228981c 100644 --- a/search_lancedb.py +++ b/search_lancedb.py @@ -34,6 +34,8 @@ "role", "annotations_on_type", "symbols", + "symbol_id", + "metadata", "ontology_version", "capabilities", ) diff --git a/server.py b/server.py index 36e6d92..55accca 100644 --- a/server.py +++ b/server.py @@ -366,6 +366,7 @@ class GraphMetaOutput(BaseModel): routes_resolved_pct: float = 0.0 routes_from_brownfield_pct: float = 0.0 routes_by_layer: dict[str, int] = Field(default_factory=dict) + edge_counts: dict[str, int] = Field(default_factory=dict) http_calls_match_breakdown: dict[str, int] = Field(default_factory=dict) async_calls_match_breakdown: dict[str, int] = Field(default_factory=dict) cross_service_calls_total: int = 0 @@ -547,6 +548,7 @@ def _graph_meta_output() -> GraphMetaOutput: routes_resolved_pct=float(meta.get("routes_resolved_pct") or 0.0), routes_from_brownfield_pct=float(meta.get("routes_from_brownfield_pct") or 0.0), routes_by_layer=routes_by_layer, + edge_counts={str(k): int(v) for k, v in (meta.get("edge_counts") or {}).items()}, http_calls_match_breakdown={str(k): int(v) for k, v in (meta.get("http_calls_match_breakdown") or {}).items()}, async_calls_match_breakdown={str(k): int(v) for k, v in (meta.get("async_calls_match_breakdown") or {}).items()}, cross_service_calls_total=int(meta.get("cross_service_calls_total") or 0), diff --git a/tests/test_mcp_v2_compose.py b/tests/test_mcp_v2_compose.py new file mode 100644 index 0000000..cde4c76 --- /dev/null +++ b/tests/test_mcp_v2_compose.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +from typing import Any + +from mcp_v2 import describe_v2, neighbors_v2, search_v2 +from server import _graph_meta_output + + +_EDGE_TYPES = ( + "ASYNC_CALLS", + "CALLS", + "DECLARES", + "DECLARES_CLIENT", + "EXPOSES", + "EXTENDS", + "HTTP_CALLS", + "IMPLEMENTS", + "INJECTS", +) + + +def _controller_method_with_calls(kuzu_graph) -> tuple[str, str]: + rows = kuzu_graph._rows( # noqa: SLF001 + "MATCH (t:Symbol)-[:DECLARES]->(m:Symbol) " + "WHERE t.role = 'CONTROLLER' AND m.kind IN ['method', 'constructor'] " + "AND (EXISTS { MATCH (:Symbol)-[:CALLS]->(m) } OR EXISTS { MATCH (m)-[:CALLS]->(:Symbol) }) " + "RETURN m.id AS id, m.fqn AS fqn " + "ORDER BY m.fqn LIMIT 1" + ) + assert rows + return str(rows[0]["id"]), str(rows[0]["fqn"]) + + +def _method_with_incoming_calls(kuzu_graph) -> tuple[str, str]: + rows = kuzu_graph._rows( # noqa: SLF001 + "MATCH (src:Symbol)-[:CALLS]->(dst:Symbol) " + "RETURN dst.id AS id LIMIT 1" + ) + assert rows + node_id = str(rows[0]["id"]) + fqn_rows = kuzu_graph._rows( # noqa: SLF001 + "MATCH (s:Symbol) WHERE s.id = $id RETURN s.fqn AS fqn LIMIT 1", + {"id": node_id}, + ) + assert fqn_rows + return node_id, str(fqn_rows[0]["fqn"]) + + +def _route_with_handler(kuzu_graph) -> str: + rows = kuzu_graph._rows( # noqa: SLF001 + "MATCH (:Symbol)-[:EXPOSES]->(r:Route) RETURN r.id AS id ORDER BY r.id LIMIT 1" + ) + assert rows + return str(rows[0]["id"]) + + +def test_describe_edge_summary_for_controller(kuzu_graph) -> None: + node_id, fqn = _controller_method_with_calls(kuzu_graph) + out = describe_v2(node_id, graph=kuzu_graph) + assert out.success is True + assert out.record is not None + assert out.record.edge_summary is not None + calls = out.record.edge_summary.get("CALLS", {"in": 0, "out": 0}) + callers = kuzu_graph.find_callers(fqn, limit=1000, exclude_external=False) + callees = kuzu_graph.find_callees(fqn, limit=1000, exclude_external=False) + assert int(calls.get("in", 0)) == len(callers) + assert int(calls.get("out", 0)) == len(callees) + + +def test_describe_edge_summary_omits_zero_count_types(kuzu_graph) -> None: + node_id, _ = _controller_method_with_calls(kuzu_graph) + out = describe_v2(node_id, graph=kuzu_graph) + assert out.success is True + assert out.record is not None + assert out.record.edge_summary is not None + for edge_type in _EDGE_TYPES: + if edge_type in out.record.edge_summary: + continue + rows = kuzu_graph._rows( # noqa: SLF001 + f"MATCH (n {{id: $id}})-[e:{edge_type}]->() RETURN count(e) AS n " + f"UNION ALL " + f"MATCH (n {{id: $id}})<-[e:{edge_type}]-() RETURN count(e) AS n", + {"id": node_id} + ) + assert sum(int(r.get("n") or 0) for r in rows) == 0 + + +def test_describe_edge_summary_for_route(kuzu_graph) -> None: + route_id = _route_with_handler(kuzu_graph) + out = describe_v2(route_id, graph=kuzu_graph) + assert out.success is True + assert out.record is not None + assert out.record.kind == "route" + assert out.record.edge_summary is not None + exposes = out.record.edge_summary.get("EXPOSES", {"in": 0, "out": 0}) + assert int(exposes.get("in", 0)) >= 1 + + +def test_search_populates_symbol_id_when_chunk_rooted_in_symbol(monkeypatch, kuzu_graph) -> None: + rows: list[dict[str, Any]] = [ + { + "filename": "A.java", + "start": {"byte_offset": 0}, + "end": {"byte_offset": 10}, + "symbol_id": "sym:one", + "primary_type_fqn": "com.example.A", + "_rrf_score": 0.9, + "text": "A", + }, + { + "filename": "B.java", + "start": {"byte_offset": 10}, + "end": {"byte_offset": 20}, + "metadata": {"symbol_id": "sym:two"}, + "primary_type_fqn": "com.example.B", + "_rrf_score": 0.8, + "text": "B", + }, + { + "filename": "C.java", + "start": {"byte_offset": 30}, + "end": {"byte_offset": 40}, + "metadata": '{"symbol_id":"sym:three"}', + "primary_type_fqn": "com.example.C", + "_rrf_score": 0.75, + "text": "C", + }, + { + "filename": "raw.txt", + "start": {"byte_offset": 20}, + "end": {"byte_offset": 30}, + "_rrf_score": 0.7, + "text": "raw", + }, + ] + monkeypatch.setattr("mcp_v2.run_search", lambda *args, **kwargs: rows) + out = search_v2("query", graph=kuzu_graph) + assert out.success is True + rooted = [hit for hit in out.results if hit.fqn is not None] + assert rooted + assert all(hit.symbol_id is not None for hit in rooted) + + +def test_meta_returns_per_edge_type_counts() -> None: + out = _graph_meta_output() + assert out.success is True + assert set(out.edge_counts.keys()) == set(_EDGE_TYPES) + assert all(int(v) >= 0 for v in out.edge_counts.values()) + + +def test_search_describe_neighbors_chain_end_to_end(kuzu_graph, monkeypatch) -> None: + node_id, _ = _method_with_incoming_calls(kuzu_graph) + rows = kuzu_graph._rows( # noqa: SLF001 + "MATCH (m:Symbol {id: $id}) RETURN m.fqn AS fqn, m.role AS role, m.module AS module, " + "m.microservice AS microservice, m.filename AS filename", + {"id": node_id}, + ) + assert rows + row = rows[0] + monkeypatch.setattr( + "mcp_v2.run_search", + lambda *args, **kwargs: [ + { + "filename": str(row.get("filename") or "x.java"), + "start": {"byte_offset": 0}, + "end": {"byte_offset": 1}, + "symbol_id": node_id, + "primary_type_fqn": str(row.get("fqn") or ""), + "role": str(row.get("role") or ""), + "module": str(row.get("module") or ""), + "microservice": str(row.get("microservice") or ""), + "_rrf_score": 0.95, + "text": "match", + } + ], + ) + search_out = search_v2("assign", graph=kuzu_graph) + assert search_out.success is True + assert search_out.results + top_symbol_id = search_out.results[0].symbol_id + assert top_symbol_id is not None + describe_out = describe_v2(top_symbol_id, graph=kuzu_graph) + assert describe_out.success is True + assert describe_out.record is not None + neighbors_out = neighbors_v2(top_symbol_id, direction="in", edge_types=["CALLS"], graph=kuzu_graph) + assert neighbors_out.success is True + assert neighbors_out.results diff --git a/tests/test_search_lancedb.py b/tests/test_search_lancedb.py index 4918ac9..bae58b7 100644 --- a/tests/test_search_lancedb.py +++ b/tests/test_search_lancedb.py @@ -2,7 +2,10 @@ from __future__ import annotations -from search_lancedb import _rrf_merge +import numpy as np + +import search_lancedb +from search_lancedb import JAVA_ENRICHED_COLUMNS, _rrf_merge def test_rrf_merge_weights_second_list_by_row() -> None: @@ -32,3 +35,61 @@ def test_rrf_merge_weights_second_list_by_row() -> None: by_file = {m["filename"]: float(m["_rrf_score"]) for m in merged} # Rank 0 in graph list (weight 1.0) should contribute more than rank 1 (weight 0.5). assert by_file["b.java"] > by_file["c.java"] + + +def test_java_enriched_columns_include_symbol_identity_fields() -> None: + assert "symbol_id" in JAVA_ENRICHED_COLUMNS + assert "metadata" in JAVA_ENRICHED_COLUMNS + + +def test_search_one_table_selects_symbol_identity_columns_when_schema_has_them(monkeypatch) -> None: + selected: list[str] = [] + + class _FakeQuery: + def select(self, cols): + selected[:] = list(cols) + return self + + def limit(self, _n): + return self + + def to_list(self): + return [] + + class _FakeTable: + def search(self, *_args, **_kwargs): + return _FakeQuery() + + class _FakeDb: + def open_table(self, _name): + return _FakeTable() + + monkeypatch.setattr( + search_lancedb, + "_table_columns", + lambda *_args, **_kwargs: { + "filename", + "text", + "start", + "end", + "language", + "package", + "primary_type_fqn", + "symbol_id", + "metadata", + }, + ) + search_lancedb._search_one_table( + "javacodeindex_java_code", + uri="mem://", + db=_FakeDb(), + query_vec=np.zeros((3,), dtype=np.float32), + limit=5, + path_predicate=None, + kind="java", + hybrid=False, + fts_text=None, + extra_predicates=None, + ) + assert "symbol_id" in selected + assert "metadata" in selected