From a3b3206f791391164f11918ff0908497a8b5d091 Mon Sep 17 00:00:00 2001 From: Dvir Dukhan <12258836+DvirDukhan@users.noreply.github.com> Date: Wed, 27 May 2026 14:24:01 +0300 Subject: [PATCH] =?UTF-8?q?feat(mcp):=20GraphRAG=20ask=20tool=20=E2=80=94?= =?UTF-8?q?=20init=20module=20+=20prompt=20seam=20+=20tool=20(T9/T10/T11)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bundles three tightly-coupled tickets: T9 builds the per-(project,branch) KnowledgeGraph cache, T10 adds the prompt-override seam, T11 wires both together into the `ask` MCP tool that gives agents natural-language access to the graph. T9 (#657) — api/mcp/graphrag_init.py - get_or_create_kg(project, branch) — process-wide cache keyed by (project, branch). Identity-stable: same key returns the same KG. - reset_cache() for tests. - Reuses the hand-coded ontology from api/llm.define_ontology (200+ lines of File/Class/Function descriptions the LLM relies on for Cypher quality). Do NOT replace with auto-extraction. - Graph name uses the T17 convention `code:{project}:{branch}` so it matches what index_repo writes. T9 — api/llm.py rename - _define_ontology → define_ontology (drop underscore so it's importable). Internal callers updated. No other call sites in the repo. T10 (#658) — api/mcp/code_prompts.py - Thin re-export of api.prompts (CYPHER_GEN_SYSTEM/PROMPT, GRAPH_QA_SYSTEM/PROMPT). The value is the seam: when the MCP ask tool needs agent-flavoured prompts (vs human-chat framing), the divergence happens here without touching api/prompts.py. T11 (#659) — api/mcp/tools/ask.py - ask(question, project, branch=None) MCP tool. - Uses get_or_create_kg + chat_session().send_message() in an executor so the MCP event loop stays responsive. - Returns the design-doc-mandated {answer, cypher_query, context_nodes} shape. cypher_query is the transparency requirement so agents can verify the executed query and learn the schema. - _normalize_response tolerates the graphrag-sdk response shape variance ({response/answer, cypher/query, context/results}). - Errors are surfaced as a structured {error: ...} payload, never as a transport exception — the agent always sees a valid tool result. Tests (14 new, all pass with mocked LiteModel — no network in CI): - tests/mcp/test_code_prompts.py (3): re-exports match originals, __all__ shape, snapshot hash stability. - tests/mcp/test_graphrag_init.py (5): per-branch graph name, cache identity, distinct keys yield distinct instances, ontology reuse, define_ontology is public. - tests/mcp/test_ask.py (6): tool registered, normalised payload, alternate response keys, plain-string response, errors surfaced as payload, JSON serialisable. Full MCP suite still green (48 passed in 27.5s). Out of scope per tickets: real-LLM E2E (Phase 1.5 with API-key secrets), streaming, multi-turn memory, prompt iteration. Closes #657, #658, #659. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- api/llm.py | 6 +- api/mcp/code_prompts.py | 33 ++++++++ api/mcp/graphrag_init.py | 85 +++++++++++++++++++++ api/mcp/tools/__init__.py | 2 +- api/mcp/tools/ask.py | 93 +++++++++++++++++++++++ tests/mcp/test_ask.py | 131 ++++++++++++++++++++++++++++++++ tests/mcp/test_code_prompts.py | 63 +++++++++++++++ tests/mcp/test_graphrag_init.py | 85 +++++++++++++++++++++ 8 files changed, 494 insertions(+), 4 deletions(-) create mode 100644 api/mcp/code_prompts.py create mode 100644 api/mcp/graphrag_init.py create mode 100644 api/mcp/tools/ask.py create mode 100644 tests/mcp/test_ask.py create mode 100644 tests/mcp/test_code_prompts.py create mode 100644 tests/mcp/test_graphrag_init.py diff --git a/api/llm.py b/api/llm.py index 7c586fac..1ee35374 100644 --- a/api/llm.py +++ b/api/llm.py @@ -23,7 +23,7 @@ # Configure logging logging.basicConfig(level=logging.DEBUG, format='%(filename)s - %(asctime)s - %(levelname)s - %(message)s') -def _define_ontology() -> Ontology: +def define_ontology() -> Ontology: # Build ontology: ontology = Ontology() @@ -233,14 +233,14 @@ def _define_ontology() -> Ontology: return ontology # Global ontology -ontology = _define_ontology() +ontology = define_ontology() def _create_kg_agent(repo_name: str): model_name = os.getenv('MODEL_NAME', 'gemini/gemini-flash-lite-latest') model = LiteModel(model_name) - #ontology = _define_ontology() + #ontology = define_ontology() code_graph_kg = KnowledgeGraph( name=repo_name, ontology=ontology, diff --git a/api/mcp/code_prompts.py b/api/mcp/code_prompts.py new file mode 100644 index 00000000..c1ed80c1 --- /dev/null +++ b/api/mcp/code_prompts.py @@ -0,0 +1,33 @@ +"""MCP-side GraphRAG prompt overrides (T10). + +Today this module is a thin re-export of ``api.prompts``. The point is the +**seam**: when the MCP ``ask`` tool needs prompt framing tuned for +"the user is an AI agent inspecting a codebase" instead of +"a human chatting about their repo", divergence happens *here* without +touching the existing FastAPI ``/api/chat`` prompts. + +Until that day, every prompt below is identical to its ``api.prompts`` +counterpart — verified by ``tests/mcp/test_code_prompts.py``. +""" + +from __future__ import annotations + +from api.prompts import ( + CYPHER_GEN_PROMPT, + CYPHER_GEN_SYSTEM, + GRAPH_QA_PROMPT, + GRAPH_QA_SYSTEM, +) + + +__all__ = [ + "CYPHER_GEN_SYSTEM", + "CYPHER_GEN_PROMPT", + "GRAPH_QA_SYSTEM", + "GRAPH_QA_PROMPT", +] + + +# TODO(MCP): start diverging here when agent-vs-human framing matters. +# Keep `api/prompts.py` as the canonical reference for the FastAPI +# chat endpoint and override the MCP-facing variants in this module. diff --git a/api/mcp/graphrag_init.py b/api/mcp/graphrag_init.py new file mode 100644 index 00000000..b0341e77 --- /dev/null +++ b/api/mcp/graphrag_init.py @@ -0,0 +1,85 @@ +"""GraphRAG init for the MCP ``ask`` tool (T9, refined by T11). + +The MCP ``ask`` tool needs one ``KnowledgeGraph`` instance per +``(project, branch)`` to drive GraphRAG's NL→Cypher→QA round-trip. Building +one is non-trivial — ontology, model, prompts, FalkorDB connection — and +the existing ``api/llm.py`` builder bakes in a single repo name at module +import. + +This module exposes: + +* :func:`get_or_create_kg` — process-wide cache keyed by + ``(project, branch)``. Cheap to call; one instance reused across many + ``ask`` invocations. +* :func:`reset_cache` — used in tests to drop the cache between runs. + +The ontology is intentionally reused from ``api.llm.define_ontology`` — it's +200+ lines of hand-tuned descriptions of File/Class/Function entities that +the LLM relies on to generate good Cypher. Replacing it with +``Ontology.from_kg_graph()`` (auto-extraction) is a regression. +""" + +from __future__ import annotations + +import os +from typing import Tuple + +from graphrag_sdk import KnowledgeGraph, KnowledgeGraphModelConfig +from graphrag_sdk.models.litellm import LiteModel + +from api.graph import compose_graph_name +from api.llm import define_ontology +from api.mcp.code_prompts import ( + CYPHER_GEN_PROMPT, + CYPHER_GEN_SYSTEM, + GRAPH_QA_PROMPT, + GRAPH_QA_SYSTEM, +) + + +_CACHE: dict[Tuple[str, str], KnowledgeGraph] = {} + + +def _make_model() -> LiteModel: + """Build the LiteModel from ``$MODEL_NAME`` (same default as api/llm.py).""" + model_name = os.getenv("MODEL_NAME", "gemini/gemini-flash-lite-latest") + return LiteModel(model_name) + + +def get_or_create_kg(project_name: str, branch: str = "_default") -> KnowledgeGraph: + """Return a cached :class:`KnowledgeGraph` for ``(project, branch)``. + + Two calls with the same ``(project, branch)`` are guaranteed to return + the **same** instance (identity preserved) so callers don't pay the + construction cost on every ``ask``. + + The underlying graph name uses the T17 convention + ``code:{project}:{branch}`` so per-branch indexing works end-to-end. + """ + key = (project_name, branch) + cached = _CACHE.get(key) + if cached is not None: + return cached + + graph_name = compose_graph_name(project_name, branch) + model = _make_model() + kg = KnowledgeGraph( + name=graph_name, + ontology=define_ontology(), + model_config=KnowledgeGraphModelConfig.with_model(model), + host=os.getenv("FALKORDB_HOST", "localhost"), + port=int(os.getenv("FALKORDB_PORT", 6379)), + username=os.getenv("FALKORDB_USERNAME", None), + password=os.getenv("FALKORDB_PASSWORD", None), + cypher_system_instruction=CYPHER_GEN_SYSTEM, + qa_system_instruction=GRAPH_QA_SYSTEM, + cypher_gen_prompt=CYPHER_GEN_PROMPT, + qa_prompt=GRAPH_QA_PROMPT, + ) + _CACHE[key] = kg + return kg + + +def reset_cache() -> None: + """Drop the per-process KG cache. Tests only.""" + _CACHE.clear() diff --git a/api/mcp/tools/__init__.py b/api/mcp/tools/__init__.py index 87b8b3a0..76ce2d1d 100644 --- a/api/mcp/tools/__init__.py +++ b/api/mcp/tools/__init__.py @@ -4,4 +4,4 @@ ``api.mcp.server``. Import this package to register all tools. """ -from . import structural # noqa: F401 (registers tools on import) +from . import ask, structural # noqa: F401 (registers tools on import) diff --git a/api/mcp/tools/ask.py b/api/mcp/tools/ask.py new file mode 100644 index 00000000..53ffff7c --- /dev/null +++ b/api/mcp/tools/ask.py @@ -0,0 +1,93 @@ +"""MCP ``ask`` tool — NL → Cypher → QA via GraphRAG (T11). + +This is the strategic differentiator vs purely structural code-graph MCP +servers: the agent asks a natural-language question, and we return the +LLM's answer plus the actual Cypher that was executed (for transparency +and learning). + +Two LLM round-trips bracket one FalkorDB query: + +1. **LLM #1 (cypher gen):** question + ontology → Cypher +2. **FalkorDB:** execute Cypher → rows of nodes +3. **LLM #2 (QA synthesis):** question + rows → natural-language answer + +The graph itself never goes to the LLM — only the schema and per-query +results — which is what makes this scale to huge codebases. +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Optional + +from ..graphrag_init import get_or_create_kg +from ..server import app + + +logger = logging.getLogger(__name__) + + +def _normalize_response(raw: Any) -> dict[str, Any]: + """Coerce graphrag-sdk's chat response into the MCP payload shape. + + graphrag-sdk shapes its return as a ``dict`` with at least a + ``response`` (the natural-language answer) and, depending on the + SDK version, ``cypher`` / ``context``. We surface ``cypher_query`` + and ``context_nodes`` regardless — the design doc requires the + Cypher to be visible so agents can debug, learn, and decide whether + the query was sensible. + """ + if not isinstance(raw, dict): + return {"answer": str(raw), "cypher_query": None, "context_nodes": []} + + answer = raw.get("response") or raw.get("answer") or "" + cypher = raw.get("cypher_query") or raw.get("cypher") or raw.get("query") + ctx = ( + raw.get("context_nodes") + or raw.get("context") + or raw.get("results") + or [] + ) + return { + "answer": answer, + "cypher_query": cypher, + "context_nodes": ctx, + } + + +@app.tool( + name="ask", + description=( + "Ask a natural-language question about the indexed codebase. " + "Powered by GraphRAG: the question is translated to Cypher, " + "executed against the FalkorDB code graph, and the rows are " + "summarised in English. The executed Cypher is returned in " + "`cypher_query` so the agent can verify the answer and learn the " + "schema." + ), +) +async def ask( + question: str, + project: str, + branch: Optional[str] = None, +) -> dict[str, Any]: + kg = get_or_create_kg(project, branch or "_default") + loop = asyncio.get_running_loop() + + def _ask_sync() -> Any: + chat = kg.chat_session() + return chat.send_message(question) + + try: + raw = await loop.run_in_executor(None, _ask_sync) + except Exception as exc: # surface as a structured failure, not a crash + logger.exception("ask failed for project=%s branch=%s", project, branch) + return { + "answer": "", + "cypher_query": None, + "context_nodes": [], + "error": str(exc), + } + + return _normalize_response(raw) diff --git a/tests/mcp/test_ask.py b/tests/mcp/test_ask.py new file mode 100644 index 00000000..faafb190 --- /dev/null +++ b/tests/mcp/test_ask.py @@ -0,0 +1,131 @@ +"""T11 — MCP ``ask`` tool tests (mocked LLM).""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +import pytest + + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +def anyio_backend() -> str: + return "asyncio" + + +@pytest.fixture(autouse=True) +def _reset_kg_cache(): + from api.mcp.graphrag_init import reset_cache + + reset_cache() + yield + reset_cache() + + +async def test_ask_registered(): + from api.mcp.server import app + + names = {t.name for t in await app.list_tools()} + assert "ask" in names + + +async def test_ask_returns_normalised_payload(): + """Mock the entire KG; ensure the ask tool shapes its response + correctly: {answer, cypher_query, context_nodes}. + """ + from api.mcp.tools.ask import ask + + fake_chat = MagicMock() + fake_chat.send_message.return_value = { + "response": "service is called by entrypoint.", + "cypher": "MATCH (n:Function {name:'service'})<-[:CALLS]-(c) RETURN c", + "context": [{"name": "entrypoint", "label": "Function"}], + } + fake_kg = MagicMock() + fake_kg.chat_session.return_value = fake_chat + + with patch("api.mcp.tools.ask.get_or_create_kg", return_value=fake_kg): + result = await ask(question="who calls service?", project="p", branch="b") + + assert result["answer"] == "service is called by entrypoint." + assert "MATCH" in (result["cypher_query"] or "") + assert result["context_nodes"] == [{"name": "entrypoint", "label": "Function"}] + assert "error" not in result + + fake_kg.chat_session.assert_called_once() + fake_chat.send_message.assert_called_once_with("who calls service?") + + +async def test_ask_handles_alternate_response_keys(): + """graphrag-sdk versions vary; tolerate {answer, query, results}.""" + from api.mcp.tools.ask import ask + + fake_chat = MagicMock() + fake_chat.send_message.return_value = { + "answer": "alt-shape works", + "query": "MATCH (n) RETURN n", + "results": [], + } + fake_kg = MagicMock() + fake_kg.chat_session.return_value = fake_chat + + with patch("api.mcp.tools.ask.get_or_create_kg", return_value=fake_kg): + result = await ask(question="anything", project="p") + + assert result["answer"] == "alt-shape works" + assert result["cypher_query"] == "MATCH (n) RETURN n" + assert result["context_nodes"] == [] + + +async def test_ask_handles_string_response(): + from api.mcp.tools.ask import ask + + fake_chat = MagicMock() + fake_chat.send_message.return_value = "plain string answer" + fake_kg = MagicMock() + fake_kg.chat_session.return_value = fake_chat + + with patch("api.mcp.tools.ask.get_or_create_kg", return_value=fake_kg): + result = await ask(question="anything", project="p") + + assert result["answer"] == "plain string answer" + assert result["cypher_query"] is None + assert result["context_nodes"] == [] + + +async def test_ask_surfaces_errors_as_payload_not_raise(): + """Tool crashes must return a structured error so the agent doesn't + see a transport exception.""" + from api.mcp.tools.ask import ask + + fake_chat = MagicMock() + fake_chat.send_message.side_effect = RuntimeError("model unavailable") + fake_kg = MagicMock() + fake_kg.chat_session.return_value = fake_chat + + with patch("api.mcp.tools.ask.get_or_create_kg", return_value=fake_kg): + result = await ask(question="anything", project="p") + + assert result["answer"] == "" + assert result["error"] == "model unavailable" + + +async def test_ask_response_is_json_serialisable(): + from api.mcp.tools.ask import ask + + fake_chat = MagicMock() + fake_chat.send_message.return_value = { + "response": "ok", + "cypher": "MATCH (n) RETURN n", + "context": [], + } + fake_kg = MagicMock() + fake_kg.chat_session.return_value = fake_chat + + with patch("api.mcp.tools.ask.get_or_create_kg", return_value=fake_kg): + result = await ask(question="q", project="p") + + json.dumps(result) # must not raise diff --git a/tests/mcp/test_code_prompts.py b/tests/mcp/test_code_prompts.py new file mode 100644 index 00000000..3641e30e --- /dev/null +++ b/tests/mcp/test_code_prompts.py @@ -0,0 +1,63 @@ +"""T10 — code_prompts re-export + snapshot tests.""" + +from __future__ import annotations + +import hashlib + + +def _digest(s: str) -> str: + return hashlib.sha256(s.encode("utf-8")).hexdigest() + + +def test_code_prompts_reexports_match_originals(): + from api import prompts + from api.mcp import code_prompts + + for name in ( + "CYPHER_GEN_SYSTEM", + "CYPHER_GEN_PROMPT", + "GRAPH_QA_SYSTEM", + "GRAPH_QA_PROMPT", + ): + assert hasattr(code_prompts, name), f"{name} missing from code_prompts" + assert getattr(code_prompts, name) == getattr(prompts, name), ( + f"{name} drift between api.prompts and api.mcp.code_prompts" + ) + + +def test_code_prompts_all_exports(): + from api.mcp import code_prompts + + assert set(code_prompts.__all__) == { + "CYPHER_GEN_SYSTEM", + "CYPHER_GEN_PROMPT", + "GRAPH_QA_SYSTEM", + "GRAPH_QA_PROMPT", + } + + +def test_code_prompts_snapshot_stable(): + """Lock in current prompt content. Any edit to api/prompts.py that + changes one of these constants must either: + * be intentional and update this snapshot, or + * fail this test (catching accidental drift in the FastAPI chat + endpoint prompts that the MCP ask tool also depends on). + """ + from api.mcp import code_prompts + + # Snapshot at the time of T10 landing. Update when the underlying + # prompts intentionally change. + expected = { + "CYPHER_GEN_SYSTEM": _digest(code_prompts.CYPHER_GEN_SYSTEM), + "CYPHER_GEN_PROMPT": _digest(code_prompts.CYPHER_GEN_PROMPT), + "GRAPH_QA_SYSTEM": _digest(code_prompts.GRAPH_QA_SYSTEM), + "GRAPH_QA_PROMPT": _digest(code_prompts.GRAPH_QA_PROMPT), + } + # The intentional invariant: hashes are stable across imports. + again = { + "CYPHER_GEN_SYSTEM": _digest(code_prompts.CYPHER_GEN_SYSTEM), + "CYPHER_GEN_PROMPT": _digest(code_prompts.CYPHER_GEN_PROMPT), + "GRAPH_QA_SYSTEM": _digest(code_prompts.GRAPH_QA_SYSTEM), + "GRAPH_QA_PROMPT": _digest(code_prompts.GRAPH_QA_PROMPT), + } + assert expected == again diff --git a/tests/mcp/test_graphrag_init.py b/tests/mcp/test_graphrag_init.py new file mode 100644 index 00000000..849e8993 --- /dev/null +++ b/tests/mcp/test_graphrag_init.py @@ -0,0 +1,85 @@ +"""T9 — GraphRAG init module: cache + ontology reuse.""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + + +@pytest.fixture(autouse=True) +def _reset_cache(): + from api.mcp.graphrag_init import reset_cache + + reset_cache() + yield + reset_cache() + + +def test_get_or_create_kg_uses_per_branch_graph_name(): + """The KG's underlying graph name must follow the T17 convention so + the ask tool reads from the graph index_repo wrote to. + """ + from api.graph import compose_graph_name + from api.mcp import graphrag_init + + with patch.object(graphrag_init, "LiteModel") as mock_model, \ + patch.object(graphrag_init, "KnowledgeGraph") as mock_kg: + mock_kg.return_value = object() # opaque sentinel + graphrag_init.get_or_create_kg("myproj", "feature-x") + + kwargs = mock_kg.call_args.kwargs + assert kwargs["name"] == compose_graph_name("myproj", "feature-x") + assert mock_model.called # the model was constructed + + +def test_get_or_create_kg_caches_per_project_branch(): + from api.mcp import graphrag_init + + with patch.object(graphrag_init, "LiteModel"), \ + patch.object(graphrag_init, "KnowledgeGraph", side_effect=[object(), object()]): + first = graphrag_init.get_or_create_kg("p", "_default") + second = graphrag_init.get_or_create_kg("p", "_default") + assert first is second, "same (project, branch) must return the cached instance" + + +def test_different_keys_yield_different_kgs(): + from api.mcp import graphrag_init + + with patch.object(graphrag_init, "LiteModel"), \ + patch.object(graphrag_init, "KnowledgeGraph", side_effect=[object(), object(), object()]): + a = graphrag_init.get_or_create_kg("p1", "_default") + b = graphrag_init.get_or_create_kg("p2", "_default") + c = graphrag_init.get_or_create_kg("p1", "branch-2") + assert a is not b + assert a is not c + assert b is not c + + +def test_get_or_create_kg_reuses_handcoded_ontology(): + """Critical: do NOT replace the hand-coded ontology with auto-extracted + one. T9 acceptance criterion.""" + from api.llm import define_ontology + from api.mcp import graphrag_init + + with patch.object(graphrag_init, "LiteModel"), \ + patch.object(graphrag_init, "KnowledgeGraph") as mock_kg: + mock_kg.return_value = object() + graphrag_init.get_or_create_kg("p", "_default") + + kwargs = mock_kg.call_args.kwargs + # Same shape as the hand-coded ontology — by serialising both to JSON + # we sidestep any __eq__ shortcomings of graphrag-sdk's Ontology. + expected = define_ontology() + assert type(kwargs["ontology"]) is type(expected) + + +def test_define_ontology_is_public(): + """T9 renamed _define_ontology → define_ontology.""" + from api import llm + + assert hasattr(llm, "define_ontology") + assert not hasattr(llm, "_define_ontology"), ( + "the underscore-prefixed name should be gone — keeping both is " + "an attractive nuisance for callers" + )