Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 56 additions & 18 deletions src/semble/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,52 @@
from model2vec.utils import get_package_extras

from semble.index import SembleIndex
from semble.index.file_walker import FILE_TYPES
from semble.stats import format_savings_report
from semble.utils import _format_results, _is_git_url, _resolve_chunk

_CLAUDE_FILE_PATH = Path(".claude") / "agents" / "semble-search.md"
_CLI_DISPATCH_ARGS = frozenset({"search", "find-related", "init", "savings", "-h", "--help"})


def _parse_extensions(value: str) -> frozenset[str]:
"""Split a comma-separated string of extensions into a frozenset."""
return frozenset(ext.strip() for ext in value.split(",") if ext.strip())


def _resolve_extensions(args: argparse.Namespace) -> frozenset[str] | None:
"""Return the effective extension set based on parsed CLI arguments."""
if args.extensions is not None:
return _parse_extensions(args.extensions)
if args.add_extensions is not None:
return frozenset(FILE_TYPES.keys()) | _parse_extensions(args.add_extensions)
return None


def _add_extension_args(parser: argparse.ArgumentParser) -> None:
"""Add mutually exclusive --extensions and --add-extension flags to *parser*."""
group = parser.add_mutually_exclusive_group()
group.add_argument(
"--extensions",
default=None,
help="Comma-separated list of file extensions to index instead of the default set (e.g. '.py,.js,.ext').",
)
group.add_argument(
"--add-extensions",
default=None,
help="Comma-separated list of file extensions to add to the default set (e.g. '.ext,.lua').",
)


def _add_include_text_files_arg(parser: argparse.ArgumentParser) -> None:
"""Add the --include-text-files flag to *parser*."""
parser.add_argument(
"--include-text-files",
action="store_true",
help="Also index non-code text files (.md, .yaml, .json, etc.).",
)


def main() -> None:
"""Entry point for the semble command-line tool."""
if len(sys.argv) > 1 and sys.argv[1] in _CLI_DISPATCH_ARGS:
Expand All @@ -35,18 +74,22 @@ def _mcp_main() -> None:
help="Local directory or git URL to pre-index at startup (optional).",
)
parser.add_argument("--ref", default=None, help="Branch or tag to check out (git URLs only).")
parser.add_argument(
"--include-text-files",
action="store_true",
help="Also index non-code text files (.md, .yaml, .json, etc.).",
)
_add_extension_args(parser)
_add_include_text_files_arg(parser)
args = parser.parse_args()
if any(find_spec(dep) is None for dep in get_package_extras("semble", "mcp")):
print("MCP dependencies are not installed. Run: pip install 'semble[mcp]'", file=sys.stderr)
raise SystemExit(1)
from semble.mcp import serve

asyncio.run(serve(args.path, ref=args.ref, include_text_files=args.include_text_files))
asyncio.run(
serve(
args.path,
ref=args.ref,
include_text_files=args.include_text_files,
extensions=_resolve_extensions(args),
)
)


def _run_init(*, force: bool = False) -> None:
Expand All @@ -72,22 +115,16 @@ def _cli_main() -> None:
search_p.add_argument(
"-m", "--mode", default="hybrid", choices=["hybrid", "semantic", "bm25"], help="Search mode (default: hybrid)."
)
search_p.add_argument(
"--include-text-files",
action="store_true",
help="Also index non-code text files (.md, .yaml, .json, etc.).",
)
_add_extension_args(search_p)
_add_include_text_files_arg(search_p)

related_p = sub.add_parser("find-related", help="Find code similar to a specific location.")
related_p.add_argument("file_path", help="File path as shown in search results.")
related_p.add_argument("line", type=int, help="Line number (1-indexed).")
related_p.add_argument("path", nargs="?", default=".", help="Local path or git URL (default: current directory).")
related_p.add_argument("-k", "--top-k", type=int, default=5, help="Number of results (default: 5).")
related_p.add_argument(
"--include-text-files",
action="store_true",
help="Also index non-code text files (.md, .yaml, .json, etc.).",
)
_add_extension_args(related_p)
_add_include_text_files_arg(related_p)

init_p = sub.add_parser("init", help="Write .claude/agents/semble-search.md for Claude Code sub-agent support.")
init_p.add_argument("--force", action="store_true", help="Overwrite if the file already exists.")
Expand All @@ -106,10 +143,11 @@ def _cli_main() -> None:
return

include_text = args.include_text_files
extensions = _resolve_extensions(args)
index = (
SembleIndex.from_git(args.path, include_text_files=include_text)
SembleIndex.from_git(args.path, include_text_files=include_text, extensions=extensions)
if _is_git_url(args.path)
else SembleIndex.from_path(args.path, include_text_files=include_text)
else SembleIndex.from_path(args.path, include_text_files=include_text, extensions=extensions)
)

if args.command == "search":
Expand Down
21 changes: 17 additions & 4 deletions src/semble/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,15 @@ async def find_related(
return server


async def serve(path: str | None = None, ref: str | None = None, include_text_files: bool = False) -> None:
async def serve(
path: str | None = None,
ref: str | None = None,
include_text_files: bool = False,
extensions: frozenset[str] | None = None,
) -> None:
"""Start an MCP stdio server, optionally pre-indexing a default source."""
model = await asyncio.to_thread(load_model)
cache = _IndexCache(model=model, include_text_files=include_text_files)
cache = _IndexCache(model=model, include_text_files=include_text_files, extensions=extensions)
if path:
await cache.get(path, ref=ref)
if not _is_git_url(path):
Expand All @@ -132,10 +137,13 @@ async def serve(path: str | None = None, ref: str | None = None, include_text_fi
class _IndexCache:
"""Cache of indexed repos and local paths for the lifetime of the MCP server process."""

def __init__(self, model: Encoder, include_text_files: bool = False) -> None:
def __init__(
self, model: Encoder, include_text_files: bool = False, extensions: frozenset[str] | None = None
) -> None:
"""Initialise an empty cache with a shared embedding model."""
self._model = model
self._include_text_files = include_text_files
self._extensions = extensions
self._tasks: OrderedDict[str, asyncio.Task[SembleIndex]] = OrderedDict() # ordered for LRU eviction
self._watcher_task: asyncio.Task[None] | None = None

Expand Down Expand Up @@ -180,12 +188,17 @@ async def get(self, source: str, ref: str | None = None) -> SembleIndex:
ref=ref,
model=self._model,
include_text_files=self._include_text_files,
extensions=self._extensions,
)
)
else:
self._tasks[cache_key] = asyncio.create_task(
asyncio.to_thread(
SembleIndex.from_path, cache_key, model=self._model, include_text_files=self._include_text_files
SembleIndex.from_path,
cache_key,
model=self._model,
include_text_files=self._include_text_files,
extensions=self._extensions,
)
)
task = self._tasks[cache_key]
Expand Down
39 changes: 39 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from semble.cli import _CLAUDE_FILE_PATH, _cli_main, _run_init, main
from semble.index.file_walker import FILE_TYPES
from semble.types import SearchMode, SearchResult
from tests.conftest import make_chunk

Expand Down Expand Up @@ -202,3 +203,41 @@ def test_agent_file_tools_are_bash_only() -> None:
tools = [t.strip() for t in tools_line.removeprefix("tools:").split(",")]
assert set(tools) == {"Bash", "Read"}, f"Unexpected tools in agent file: {tools}"
assert not any("mcp__" in t for t in tools)


@pytest.mark.parametrize(
"flag,value,expected",
[
("--extensions", ".ext", frozenset({".ext"})),
("--add-extensions", ".ext", frozenset(FILE_TYPES.keys()) | {".ext"}),
],
)
def test_cli_search_extensions(
flag: str,
value: str,
expected: frozenset[str],
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture[str],
) -> None:
"""_cli_main forwards --extensions and --add-extensions to from_path."""
chunk = make_chunk("def foo(): pass", "src/foo.py")
fake_index = MagicMock()
fake_index.search.return_value = [SearchResult(chunk=chunk, score=0.9, source=SearchMode.HYBRID)]
monkeypatch.setattr(sys, "argv", ["semble", "search", "query", "/some/path", flag, value])
with patch("semble.cli.SembleIndex.from_path", return_value=fake_index) as mock_from_path:
_cli_main()
mock_from_path.assert_called_once_with("/some/path", include_text_files=False, extensions=expected)


def test_cli_extensions_mutual_exclusion(
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture[str],
) -> None:
"""Providing both --extensions and --add-extensions causes argparse to exit with an error."""
monkeypatch.setattr(
sys, "argv", ["semble", "search", "query", "/some/path", "--extensions", ".ext", "--add-extensions", ".lua"]
)
with pytest.raises(SystemExit) as exc_info:
_cli_main()
assert exc_info.value.code == 2
assert "not allowed with argument" in capsys.readouterr().err
39 changes: 37 additions & 2 deletions tests/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,30 @@ async def test_index_cache_builds_and_caches(
assert first is fake_index
assert second is fake_index
mock_build.assert_called_once()
_, kwargs = mock_build.call_args
assert kwargs.get("extensions") is None


@pytest.mark.anyio
@pytest.mark.parametrize(
("source", "patch_target"),
[
("local_tmp_path", "from_path"),
("https://github.com/org/repo", "from_git"),
],
ids=["local_path", "git_url"],
)
async def test_index_cache_forwards_extensions(tmp_path: Path, source: str, patch_target: str) -> None:
"""_IndexCache passes custom extensions through to the index builder."""
resolved_source = str(tmp_path) if source == "local_tmp_path" else source
extensions = frozenset({".ext", ".lua"})
cache = _IndexCache(model=MagicMock(spec=Encoder), extensions=extensions)
fake_index = MagicMock()
with patch(f"semble.mcp.SembleIndex.{patch_target}", return_value=fake_index) as mock_build:
result = await cache.get(resolved_source)
assert result is fake_index
_, kwargs = mock_build.call_args
assert kwargs.get("extensions") == extensions


@pytest.mark.anyio
Expand Down Expand Up @@ -243,17 +267,28 @@ async def test_tool_output(

@pytest.mark.anyio
@pytest.mark.parametrize("with_path", [True, False], ids=["pre_index", "no_path"])
async def test_serve_runs_stdio(tmp_path: Path, with_path: bool) -> None:
@pytest.mark.parametrize("extensions", [None, frozenset({".ext"})], ids=["default", "custom"])
async def test_serve_runs_stdio(
tmp_path: Path, with_path: bool, extensions: frozenset[str] | None
) -> None:
"""serve() loads the model, runs stdio, and optionally pre-indexes when a path is given."""
original_init = _IndexCache.__init__
with (
patch("semble.mcp.load_model", return_value=MagicMock(spec=Encoder)),
patch("semble.mcp.SembleIndex.from_path", return_value=MagicMock()),
patch.object(_IndexCache, "start_watcher", new_callable=AsyncMock),
patch.object(_IndexCache, "__init__", side_effect=original_init, autospec=True) as mock_init,
patch("mcp.server.fastmcp.FastMCP.run_stdio_async", new_callable=AsyncMock) as mock_run,
):
await (serve(str(tmp_path)) if with_path else serve())
await (
serve(str(tmp_path), extensions=extensions)
if with_path
else serve(extensions=extensions)
)

mock_run.assert_called_once()
_, kwargs = mock_init.call_args
assert kwargs.get("extensions") == extensions


@pytest.mark.anyio
Expand Down