diff --git a/src/semble/cli.py b/src/semble/cli.py index 0f97adb..82040e4 100644 --- a/src/semble/cli.py +++ b/src/semble/cli.py @@ -8,6 +8,7 @@ from model2vec.utils import get_package_extras from semble.index import SembleIndex +from semble.index.file_walker import FILE_TYPES from semble.stats import format_savings_report from semble.utils import _format_results, _is_git_url, _resolve_chunk @@ -15,6 +16,44 @@ _CLI_DISPATCH_ARGS = frozenset({"search", "find-related", "init", "savings", "-h", "--help"}) +def _parse_extensions(value: str) -> frozenset[str]: + """Split a comma-separated string of extensions into a frozenset.""" + return frozenset(ext.strip() for ext in value.split(",") if ext.strip()) + + +def _resolve_extensions(args: argparse.Namespace) -> frozenset[str] | None: + """Return the effective extension set based on parsed CLI arguments.""" + if args.extensions is not None: + return _parse_extensions(args.extensions) + if args.add_extensions is not None: + return frozenset(FILE_TYPES.keys()) | _parse_extensions(args.add_extensions) + return None + + +def _add_extension_args(parser: argparse.ArgumentParser) -> None: + """Add mutually exclusive --extensions and --add-extensions flags to *parser*.""" + group = parser.add_mutually_exclusive_group() + group.add_argument( + "--extensions", + default=None, + help="Comma-separated list of file extensions to index instead of the default set (e.g. '.py,.js,.ext').", + ) + group.add_argument( + "--add-extensions", + default=None, + help="Comma-separated list of file extensions to add to the default set (e.g. 
'.ext,.lua').", + ) + + +def _add_include_text_files_arg(parser: argparse.ArgumentParser) -> None: + """Add the --include-text-files flag to *parser*.""" + parser.add_argument( + "--include-text-files", + action="store_true", + help="Also index non-code text files (.md, .yaml, .json, etc.).", + ) + + def main() -> None: """Entry point for the semble command-line tool.""" if len(sys.argv) > 1 and sys.argv[1] in _CLI_DISPATCH_ARGS: @@ -35,18 +74,22 @@ def _mcp_main() -> None: help="Local directory or git URL to pre-index at startup (optional).", ) parser.add_argument("--ref", default=None, help="Branch or tag to check out (git URLs only).") - parser.add_argument( - "--include-text-files", - action="store_true", - help="Also index non-code text files (.md, .yaml, .json, etc.).", - ) + _add_extension_args(parser) + _add_include_text_files_arg(parser) args = parser.parse_args() if any(find_spec(dep) is None for dep in get_package_extras("semble", "mcp")): print("MCP dependencies are not installed. Run: pip install 'semble[mcp]'", file=sys.stderr) raise SystemExit(1) from semble.mcp import serve - asyncio.run(serve(args.path, ref=args.ref, include_text_files=args.include_text_files)) + asyncio.run( + serve( + args.path, + ref=args.ref, + include_text_files=args.include_text_files, + extensions=_resolve_extensions(args), + ) + ) def _run_init(*, force: bool = False) -> None: @@ -72,22 +115,16 @@ def _cli_main() -> None: search_p.add_argument( "-m", "--mode", default="hybrid", choices=["hybrid", "semantic", "bm25"], help="Search mode (default: hybrid)." 
) - search_p.add_argument( - "--include-text-files", - action="store_true", - help="Also index non-code text files (.md, .yaml, .json, etc.).", - ) + _add_extension_args(search_p) + _add_include_text_files_arg(search_p) related_p = sub.add_parser("find-related", help="Find code similar to a specific location.") related_p.add_argument("file_path", help="File path as shown in search results.") related_p.add_argument("line", type=int, help="Line number (1-indexed).") related_p.add_argument("path", nargs="?", default=".", help="Local path or git URL (default: current directory).") related_p.add_argument("-k", "--top-k", type=int, default=5, help="Number of results (default: 5).") - related_p.add_argument( - "--include-text-files", - action="store_true", - help="Also index non-code text files (.md, .yaml, .json, etc.).", - ) + _add_extension_args(related_p) + _add_include_text_files_arg(related_p) init_p = sub.add_parser("init", help="Write .claude/agents/semble-search.md for Claude Code sub-agent support.") init_p.add_argument("--force", action="store_true", help="Overwrite if the file already exists.") @@ -106,10 +143,11 @@ def _cli_main() -> None: return include_text = args.include_text_files + extensions = _resolve_extensions(args) index = ( - SembleIndex.from_git(args.path, include_text_files=include_text) + SembleIndex.from_git(args.path, include_text_files=include_text, extensions=extensions) if _is_git_url(args.path) - else SembleIndex.from_path(args.path, include_text_files=include_text) + else SembleIndex.from_path(args.path, include_text_files=include_text, extensions=extensions) ) if args.command == "search": diff --git a/src/semble/mcp.py b/src/semble/mcp.py index 38d170c..16e528e 100644 --- a/src/semble/mcp.py +++ b/src/semble/mcp.py @@ -116,10 +116,15 @@ async def find_related( return server -async def serve(path: str | None = None, ref: str | None = None, include_text_files: bool = False) -> None: +async def serve( + path: str | None = None, + ref: str | 
None = None, + include_text_files: bool = False, + extensions: frozenset[str] | None = None, +) -> None: """Start an MCP stdio server, optionally pre-indexing a default source.""" model = await asyncio.to_thread(load_model) - cache = _IndexCache(model=model, include_text_files=include_text_files) + cache = _IndexCache(model=model, include_text_files=include_text_files, extensions=extensions) if path: await cache.get(path, ref=ref) if not _is_git_url(path): @@ -132,10 +137,13 @@ async def serve(path: str | None = None, ref: str | None = None, include_text_fi class _IndexCache: """Cache of indexed repos and local paths for the lifetime of the MCP server process.""" - def __init__(self, model: Encoder, include_text_files: bool = False) -> None: + def __init__( + self, model: Encoder, include_text_files: bool = False, extensions: frozenset[str] | None = None + ) -> None: """Initialise an empty cache with a shared embedding model.""" self._model = model self._include_text_files = include_text_files + self._extensions = extensions self._tasks: OrderedDict[str, asyncio.Task[SembleIndex]] = OrderedDict() # ordered for LRU eviction self._watcher_task: asyncio.Task[None] | None = None @@ -180,12 +188,17 @@ async def get(self, source: str, ref: str | None = None) -> SembleIndex: ref=ref, model=self._model, include_text_files=self._include_text_files, + extensions=self._extensions, ) ) else: self._tasks[cache_key] = asyncio.create_task( asyncio.to_thread( - SembleIndex.from_path, cache_key, model=self._model, include_text_files=self._include_text_files + SembleIndex.from_path, + cache_key, + model=self._model, + include_text_files=self._include_text_files, + extensions=self._extensions, ) ) task = self._tasks[cache_key] diff --git a/tests/test_cli.py b/tests/test_cli.py index 0520b7a..1a49377 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,6 +6,7 @@ import pytest from semble.cli import _CLAUDE_FILE_PATH, _cli_main, _run_init, main +from semble.index.file_walker 
import FILE_TYPES from semble.types import SearchMode, SearchResult from tests.conftest import make_chunk @@ -202,3 +203,41 @@ def test_agent_file_tools_are_bash_only() -> None: tools = [t.strip() for t in tools_line.removeprefix("tools:").split(",")] assert set(tools) == {"Bash", "Read"}, f"Unexpected tools in agent file: {tools}" assert not any("mcp__" in t for t in tools) + + +@pytest.mark.parametrize( + "flag,value,expected", + [ + ("--extensions", ".ext", frozenset({".ext"})), + ("--add-extensions", ".ext", frozenset(FILE_TYPES.keys()) | {".ext"}), + ], +) +def test_cli_search_extensions( + flag: str, + value: str, + expected: frozenset[str], + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + """_cli_main forwards --extensions and --add-extensions to from_path.""" + chunk = make_chunk("def foo(): pass", "src/foo.py") + fake_index = MagicMock() + fake_index.search.return_value = [SearchResult(chunk=chunk, score=0.9, source=SearchMode.HYBRID)] + monkeypatch.setattr(sys, "argv", ["semble", "search", "query", "/some/path", flag, value]) + with patch("semble.cli.SembleIndex.from_path", return_value=fake_index) as mock_from_path: + _cli_main() + mock_from_path.assert_called_once_with("/some/path", include_text_files=False, extensions=expected) + + +def test_cli_extensions_mutual_exclusion( + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + """Providing both --extensions and --add-extensions causes argparse to exit with an error.""" + monkeypatch.setattr( + sys, "argv", ["semble", "search", "query", "/some/path", "--extensions", ".ext", "--add-extensions", ".lua"] + ) + with pytest.raises(SystemExit) as exc_info: + _cli_main() + assert exc_info.value.code == 2 + assert "not allowed with argument" in capsys.readouterr().err diff --git a/tests/test_mcp.py b/tests/test_mcp.py index f9a1c80..33cd7da 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -120,6 +120,30 @@ async def 
test_index_cache_builds_and_caches( assert first is fake_index assert second is fake_index mock_build.assert_called_once() + _, kwargs = mock_build.call_args + assert kwargs.get("extensions") is None + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("source", "patch_target"), + [ + ("local_tmp_path", "from_path"), + ("https://github.com/org/repo", "from_git"), + ], + ids=["local_path", "git_url"], +) +async def test_index_cache_forwards_extensions(tmp_path: Path, source: str, patch_target: str) -> None: + """_IndexCache passes custom extensions through to the index builder.""" + resolved_source = str(tmp_path) if source == "local_tmp_path" else source + extensions = frozenset({".ext", ".lua"}) + cache = _IndexCache(model=MagicMock(spec=Encoder), extensions=extensions) + fake_index = MagicMock() + with patch(f"semble.mcp.SembleIndex.{patch_target}", return_value=fake_index) as mock_build: + result = await cache.get(resolved_source) + assert result is fake_index + _, kwargs = mock_build.call_args + assert kwargs.get("extensions") == extensions @pytest.mark.anyio @@ -243,17 +267,28 @@ async def test_tool_output( @pytest.mark.anyio @pytest.mark.parametrize("with_path", [True, False], ids=["pre_index", "no_path"]) -async def test_serve_runs_stdio(tmp_path: Path, with_path: bool) -> None: +@pytest.mark.parametrize("extensions", [None, frozenset({".ext"})], ids=["default", "custom"]) +async def test_serve_runs_stdio( + tmp_path: Path, with_path: bool, extensions: frozenset[str] | None +) -> None: """serve() loads the model, runs stdio, and optionally pre-indexes when a path is given.""" + original_init = _IndexCache.__init__ with ( patch("semble.mcp.load_model", return_value=MagicMock(spec=Encoder)), patch("semble.mcp.SembleIndex.from_path", return_value=MagicMock()), patch.object(_IndexCache, "start_watcher", new_callable=AsyncMock), + patch.object(_IndexCache, "__init__", side_effect=original_init, autospec=True) as mock_init, 
patch("mcp.server.fastmcp.FastMCP.run_stdio_async", new_callable=AsyncMock) as mock_run, ): - await (serve(str(tmp_path)) if with_path else serve()) + await ( + serve(str(tmp_path), extensions=extensions) + if with_path + else serve(extensions=extensions) + ) mock_run.assert_called_once() + _, kwargs = mock_init.call_args + assert kwargs.get("extensions") == extensions @pytest.mark.anyio