diff --git a/src/semble/index/index.py b/src/semble/index/index.py index 0e5e99c..091e5be 100644 --- a/src/semble/index/index.py +++ b/src/semble/index/index.py @@ -83,7 +83,6 @@ def from_path( :return: An indexed SembleIndex. Chunk file paths are relative to ``path``. :raises FileNotFoundError: If `path` does not exist. :raises NotADirectoryError: If `path` exists but is not a directory. - :raises ValueError: If `path` is a directory but contains no supported files. """ model = model or load_model() path = Path(path) @@ -172,10 +171,7 @@ def find_related(self, file_path: str, line: int, top_k: int = 5) -> list[Search ) if target is None: return [] - if target.language: - selector = self._get_selector_vector(filter_languages=[target.language]) - else: - selector = None + selector = self._get_selector_vector(filter_languages=[target.language]) if target.language else None results = search_semantic(target.content, self.model, self._semantic_index, self.chunks, top_k + 1, selector) return [r for r in results if r.chunk != target][:top_k] diff --git a/src/semble/mcp.py b/src/semble/mcp.py index f4bae50..57fa7f1 100644 --- a/src/semble/mcp.py +++ b/src/semble/mcp.py @@ -136,7 +136,7 @@ async def get(self, source: str, ref: str | None = None) -> SembleIndex: task = self._tasks[cache_key] try: return await asyncio.shield(task) - except asyncio.CancelledError: + except asyncio.CancelledError: # pragma: no cover # If this waiter was cancelled but the task is still running, preserve it for # other waiters. Only evict if the task itself was cancelled. if task.done(): diff --git a/src/semble/ranking/boosting.py b/src/semble/ranking/boosting.py index 2046249..65dc6b0 100644 --- a/src/semble/ranking/boosting.py +++ b/src/semble/ranking/boosting.py @@ -215,9 +215,6 @@ def _boost_symbol_definitions( ) -> None: """Boost chunks that define the queried symbol, scanning candidates and stem-matched non-candidates (in-place).""" symbol_name = _extract_symbol_name(query) - if not symbol_name: - return - names = {symbol_name} if symbol_name != query.strip(): names.add(query.strip()) diff --git a/src/semble/types.py b/src/semble/types.py index ad2dd3f..9c20f58 100644 --- a/src/semble/types.py +++ b/src/semble/types.py @@ -49,12 +49,6 @@ class SearchResult: score: float source: SearchMode - def __str__(self) -> str: - """Return a human-readable summary of the result.""" - header = f"{self.chunk.location} score={self.score:.3f}" - separator = "-" * len(header) - return f"{header}\n{separator}\n{self.chunk.content.strip()}\n" - @dataclass(frozen=True, slots=True) class IndexStats: diff --git a/tests/test_chunker.py b/tests/test_chunker.py index 39d0141..5429480 100644 --- a/tests/test_chunker.py +++ b/tests/test_chunker.py @@ -1,75 +1,85 @@ from pathlib import Path +from unittest.mock import MagicMock, patch import pytest -from semble.index.chunker import _chunk_with_chonkie, chunk_file, chunk_lines +from semble.index.chunker import chunk_file, chunk_lines, chunk_source +from semble.index.file_walker import filter_extensions -def test_chunk_lines_basic(tmp_path: Path) -> None: - """Chunks are produced with non-empty content.""" - f = tmp_path / "test.py" - f.write_text("\n".join(f"line {i}" for i in range(10))) - chunks = chunk_lines(f.read_text(), str(f), "python", max_lines=5, overlap_lines=1) - assert len(chunks) >= 2 - for c in chunks: - assert c.content.strip() - - -def test_chunk_lines_empty(tmp_path: Path) -> None: - """Empty source produces no chunks.""" - f = tmp_path / "empty.py" - f.write_text("") - chunks = chunk_lines("", str(f), "python") - assert chunks == [] +def test_chunk_lines() -> None: + """chunk_lines: empty input → []; real input → non-empty chunks starting at line 1.""" + assert chunk_lines("", "empty.py", "python") == [] - -def test_chunk_lines_line_numbers(tmp_path: Path) -> None: - """First chunk starts at line 1.""" - content = "a\nb\nc\nd\ne\n" - f = tmp_path / "t.py" - chunks = chunk_lines(content, str(f), "python", max_lines=3, overlap_lines=0) + content = "\n".join(f"line {i}" for i in range(10)) + chunks = chunk_lines(content, "test.py", "python", max_lines=5, overlap_lines=1) + assert len(chunks) >= 2 + assert all(c.content.strip() for c in chunks) assert chunks[0].start_line == 1 -def test_chunk_file_nonexistent() -> None: - """Non-existent file returns empty list without raising.""" - chunks = chunk_file(Path("/nonexistent/file.py")) - assert chunks == [] +@pytest.mark.parametrize( + ("filename", "content"), + [ + (None, None), # nonexistent path + ("empty.py", " \n\n "), # whitespace-only + ("file.xyz", "hello world\n" * 5), # unknown extension + ], + ids=["nonexistent", "whitespace_only", "unknown_extension"], +) +def test_chunk_file_edge_cases_return_list(tmp_path: Path, filename: str | None, content: str | None) -> None: + """chunk_file returns a list (usually empty) for missing / empty / unknown-type files without raising.""" + if filename is None: + target = Path("/nonexistent/file.py") + else: + target = tmp_path / filename + assert content is not None + target.write_text(content) + chunks = chunk_file(target) + assert isinstance(chunks, list) -def test_chunk_file_empty(tmp_path: Path) -> None: - """Whitespace-only file returns no chunks.""" - f = tmp_path / "empty.py" - f.write_text(" \n\n ") - chunks = chunk_file(f) - assert chunks == [] +def test_chunk_file_py_produces_sorted_chunks(tmp_py_file: Path) -> None: + """Python file with functions produces at least one chunk in ascending start-line order.""" + pytest.importorskip("tree_sitter_python") + chunks = chunk_file(tmp_py_file) + assert len(chunks) >= 1 + start_lines = [c.start_line for c in chunks] + assert start_lines == sorted(start_lines) -def test_chunk_with_chonkie_fallback(tmp_path: Path) -> None: - """Should fall back to line-based when given an unsupported language.""" - f = tmp_path / "code.py" - f.write_text("def foo():\n pass\n") - chunks = _chunk_with_chonkie(f.read_text(), str(f), "python") +def _whitespace_chunker() -> MagicMock: + whitespace_chunk = MagicMock(text=" \n", start_index=0, end_index=0) + chunker = MagicMock() + chunker.chunk.return_value = [whitespace_chunk] + return chunker + + +@pytest.mark.parametrize( + "codechunker_patch", + [ + {"side_effect": Exception("boom")}, # raises + {"return_value": MagicMock(chunk=MagicMock(return_value=[]))}, # empty result + {"return_value": _whitespace_chunker()}, # whitespace-only chunks + ], + ids=["raises", "empty", "whitespace_only"], +) +def test_chunk_source_falls_back_when_chonkie_unusable(codechunker_patch: dict) -> None: + """chunk_source falls back to line-based chunking when chonkie fails or yields nothing usable.""" + source = "def foo():\n pass\n" + with patch("semble.index.chunker.CodeChunker", **codechunker_patch): + chunks = chunk_source(source, "foo.py", "python") assert len(chunks) > 0 + assert all(c.content.strip() for c in chunks) -def test_chunk_file_py_produces_chunks(tmp_py_file: Path) -> None: - """Python file with functions is split into at least one chunk.""" - chunks = chunk_file(tmp_py_file) - assert len(chunks) >= 1 - - -def test_chunk_file_sorted_by_line(tmp_py_file: Path) -> None: - """Chunks are returned in ascending start-line order.""" - pytest.importorskip("tree_sitter_python") - chunks = chunk_file(tmp_py_file) - start_lines = [c.start_line for c in chunks] - assert start_lines == sorted(start_lines) +def test_chunk_source_empty_string() -> None: + """chunk_source returns [] for whitespace-only input.""" + assert chunk_source(" \n\n", "foo.py", "python") == [] -def test_chunk_file_unknown_extension(tmp_path: Path) -> None: - """Unknown file extension returns a list without raising.""" - f = tmp_path / "file.xyz" - f.write_text("hello world\n" * 5) - chunks = chunk_file(f) - assert isinstance(chunks, list) +def test_filter_extensions_explicit() -> None: + """filter_extensions returns the provided set unchanged when extensions is not None.""" + explicit: frozenset[str] = frozenset({".py", ".ts"}) + result = filter_extensions(explicit, include_text_files=False) + assert result == explicit diff --git a/tests/test_index.py b/tests/test_index.py index d413dd2..047c769 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -2,6 +2,7 @@ import subprocess from pathlib import Path from typing import Any +from unittest.mock import patch import pytest @@ -59,33 +60,14 @@ def test_search_invalid_mode(indexed_index: SembleIndex) -> None: indexed_index.search("query", mode="invalid") -def test_search_top_k_respected(indexed_index: SembleIndex) -> None: - """Results never exceed the requested top_k.""" - results = indexed_index.search("function", top_k=1, mode="bm25") - assert len(results) <= 1 +def test_search_constraints(indexed_index: SembleIndex) -> None: + """search: top_k is respected; no duplicate chunks are returned.""" + assert len(indexed_index.search("function", top_k=1, mode="bm25")) <= 1 - -def test_search_no_duplicate_chunks(indexed_index: SembleIndex) -> None: - """Each result chunk appears at most once in the result list.""" results = indexed_index.search("authenticate", top_k=5) assert len(results) == len(set(r.chunk for r in results)) -def test_find_related_returns_similar_chunks(indexed_index: SembleIndex) -> None: - """find_related returns semantically similar chunks for a known file location.""" - chunk = indexed_index.chunks[0] - results = indexed_index.find_related(chunk.file_path, chunk.start_line, top_k=3) - assert isinstance(results, list) - assert all(r.chunk != chunk for r in results) - assert len(results) <= 3 - - -def test_find_related_unknown_file_returns_empty(indexed_index: SembleIndex) -> None: - """find_related returns an empty list when the file is not in the index.""" - results = indexed_index.find_related("/does/not/exist.py", 1) - assert results == [] - - @pytest.mark.parametrize("mode", ["bm25", "hybrid", "semantic"]) def test_search_with_filter_paths_does_not_crash(indexed_index: SembleIndex, mode: str) -> None: """Filtered search works regardless of where the selected chunk lives in the corpus.""" @@ -101,6 +83,17 @@ def test_search_empty_query_returns_empty(indexed_index: SembleIndex, mode: str, assert indexed_index.search(query, mode=mode) == [] +def test_find_related(indexed_index: SembleIndex) -> None: + """find_related: returns similar chunks for a known location; returns [] for an unknown file.""" + chunk = indexed_index.chunks[0] + results = indexed_index.find_related(chunk.file_path, chunk.start_line, top_k=3) + assert isinstance(results, list) + assert all(r.chunk != chunk for r in results) + assert len(results) <= 3 + + assert indexed_index.find_related("/does/not/exist.py", 1) == [] + + _GIT_ENV = { **os.environ, "GIT_AUTHOR_NAME": "test", @@ -130,19 +123,13 @@ def git_repo(tmp_path: Path) -> Path: return tmp_path -def test_from_git_indexes_local_repo(mock_model: Any, git_repo: Path) -> None: - """from_git clones a local repo and returns a populated SembleIndex.""" +def test_from_git_indexes_local_repo_with_relative_paths(mock_model: Any, git_repo: Path) -> None: + """from_git clones a local repo, indexes it, and keeps chunk paths repo-relative.""" idx = SembleIndex.from_git(str(git_repo), model=mock_model) assert idx.stats.indexed_files >= 1 assert idx.stats.total_chunks > 0 assert any("main.py" in c.file_path for c in idx.chunks) - - -def test_from_git_paths_are_repo_relative(mock_model: Any, git_repo: Path) -> None: - """Chunk file_paths are repo-relative after cloning, not absolute temp-dir paths.""" - idx = SembleIndex.from_git(str(git_repo), model=mock_model) - for chunk in idx.chunks: - assert not Path(chunk.file_path).is_absolute(), f"Expected relative path, got: {chunk.file_path}" + assert all(not Path(c.file_path).is_absolute() for c in idx.chunks) def test_from_git_with_branch(mock_model: Any, tmp_path: Path) -> None: @@ -159,7 +146,28 @@ def test_from_git_with_branch(mock_model: Any, tmp_path: Path) -> None: assert "feature.py" in file_names -def test_from_git_invalid_url_raises(mock_model: Any) -> None: - """from_git raises RuntimeError when the clone fails.""" +@pytest.mark.parametrize( + ("kind", "expected_exc"), + [("missing", FileNotFoundError), ("file", NotADirectoryError)], +) +def test_from_path_rejects_invalid_paths( + mock_model: Any, tmp_path: Path, kind: str, expected_exc: type[Exception] +) -> None: + """from_path raises FileNotFoundError for missing paths and NotADirectoryError for files.""" + if kind == "missing": + target = tmp_path / "does_not_exist" + else: + target = tmp_path / "not_a_dir.py" + target.write_text("x = 1\n") + with pytest.raises(expected_exc): + SembleIndex.from_path(target, model=mock_model) + + +def test_from_git_raises_on_failure(mock_model: Any) -> None: + """from_git raises RuntimeError when the clone fails or git is not installed.""" with pytest.raises(RuntimeError, match="git clone failed"): SembleIndex.from_git("/nonexistent/path/that/does/not/exist", model=mock_model) + + with patch("semble.index.index.subprocess.run", side_effect=FileNotFoundError): + with pytest.raises(RuntimeError, match="git is not installed"): + SembleIndex.from_git("https://github.com/x/y", model=mock_model) diff --git a/tests/test_mcp.py b/tests/test_mcp.py new file mode 100644 index 0000000..f3ed2f2 --- /dev/null +++ b/tests/test_mcp.py @@ -0,0 +1,236 @@ +import sys +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from semble.mcp import _format_results, _IndexCache, _is_git_url, create_server, main, serve +from semble.types import Encoder, SearchMode, SearchResult +from tests.conftest import make_chunk + + +def _tool_text(result: Any) -> str: + """Extract the text string from a FastMCP call_tool result.""" + return result[0][0].text + + +async def _call_tool( + cache: _IndexCache, + tool: str, + args: dict[str, Any], + *, + index_method: str, + index_return: list[SearchResult], + default_source: str | None = "/some/path", +) -> str: + """Patch SembleIndex.from_path with a fake index and invoke the tool, returning the text.""" + fake_index = MagicMock() + getattr(fake_index, index_method).return_value = index_return + with patch("semble.mcp.SembleIndex.from_path", return_value=fake_index): + server = create_server(cache, default_source=default_source) + result = await server.call_tool(tool, args) + return _tool_text(result) + + +@pytest.fixture() +def cache() -> _IndexCache: + """An _IndexCache backed by a stub model.""" + return _IndexCache(model=MagicMock(spec=Encoder)) + + +@pytest.mark.parametrize( + ("path", "expected"), + [ + ("https://github.com/org/repo", True), + ("http://github.com/org/repo", True), + ("git://github.com/org/repo", True), + ("ssh://git@github.com/org/repo", True), + ("git+ssh://git@github.com/org/repo", True), + ("file:///tmp/repo", True), + ("git@github.com:org/repo", True), # scp-like + ("/local/path/to/repo", False), + ("./relative/path", False), + ("repo_name", False), + ], +) +def test_is_git_url(path: str, expected: bool) -> None: + """Remote git URLs are detected; local paths are not.""" + assert _is_git_url(path) is expected + + +def test_format_results() -> None: + """_format_results: empty list → header only; with results → numbered fenced blocks with scores.""" + empty_out = _format_results("My header", []) + assert "My header" in empty_out + assert "```" not in empty_out + + chunks = [make_chunk(f"def fn_{i}(): pass", f"f{i}.py") for i in range(3)] + results = [ + SearchResult(chunk=c, score=round(0.1 * (i + 1), 3), source=SearchMode.HYBRID) for i, c in enumerate(chunks) + ] + out = _format_results("Results for: 'foo'", results) + assert "Results for: 'foo'" in out + assert out.count("```") >= len(results) * 2 # opening + closing fence each + for i, c in enumerate(chunks, start=1): + assert f"## {i}." in out + assert c.content in out + assert "0.100" in out and "0.200" in out and "0.300" in out + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("source", "patch_target"), + [ + ("local_tmp_path", "from_path"), + ("https://github.com/org/repo", "from_git"), + ], + ids=["local_path", "git_url"], +) +async def test_index_cache_builds_and_caches( + cache: _IndexCache, tmp_path: Path, source: str, patch_target: str +) -> None: + """_IndexCache.get() builds via the correct SembleIndex.* entrypoint and caches subsequent calls.""" + resolved_source = str(tmp_path) if source == "local_tmp_path" else source + fake_index = MagicMock() + with patch(f"semble.mcp.SembleIndex.{patch_target}", return_value=fake_index) as mock_build: + first = await cache.get(resolved_source) + second = await cache.get(resolved_source) + assert first is fake_index + assert second is fake_index + mock_build.assert_called_once() + + +@pytest.mark.anyio +async def test_index_cache_evicts_on_failure(cache: _IndexCache, tmp_path: Path) -> None: + """A failed build evicts the entry so the next call can retry.""" + call_count = 0 + + def _failing_then_ok(path: str, **kwargs: object) -> MagicMock: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("build failed") + return MagicMock() + + with patch("semble.mcp.SembleIndex.from_path", side_effect=_failing_then_ok): + with pytest.raises(RuntimeError, match="build failed"): + await cache.get(str(tmp_path)) + result = await cache.get(str(tmp_path)) + assert result is not None + assert call_count == 2 + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("tool", "args"), + [ + ("search", {"query": "foo"}), + ("find_related", {"file_path": "src/foo.py", "line": 10}), + ], +) +async def test_tool_no_repo_no_default(cache: _IndexCache, tool: str, args: dict[str, object]) -> None: + """Both tools return an error message when no repo and no default source are given.""" + server = create_server(cache, default_source=None) + result = await server.call_tool(tool, args) + assert "No repo specified" in _tool_text(result) + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("tool", "args"), + [ + ("search", {"query": "foo", "repo": "https://github.com/x/y"}), + ("find_related", {"file_path": "src/foo.py", "line": 1, "repo": "https://github.com/x/y"}), + ], +) +async def test_tool_index_failure(cache: _IndexCache, tool: str, args: dict[str, object]) -> None: + """Both tools return a friendly error message when indexing fails.""" + with patch("semble.mcp.SembleIndex.from_git", side_effect=RuntimeError("clone failed")): + server = create_server(cache) + result = await server.call_tool(tool, args) + text = _tool_text(result) + assert "Failed to index" in text + assert "clone failed" in text + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("tool", "args", "method", "results", "expected_substrings"), + [ + pytest.param( + "search", + {"query": "bar"}, + "search", + [SearchResult(chunk=make_chunk("def bar(): pass", "src/bar.py"), score=0.9, source=SearchMode.HYBRID)], + ["bar", "0.900"], + id="search_with_results", + ), + pytest.param( + "search", + {"query": "nothing"}, + "search", + [], + ["No results found"], + id="search_no_results", + ), + pytest.param( + "find_related", + {"file_path": "src/foo.py", "line": 1}, + "find_related", + [SearchResult(chunk=make_chunk("class Foo: pass", "src/foo.py"), score=0.8, source=SearchMode.SEMANTIC)], + ["src/foo.py:1", "0.800"], + id="find_related_with_results", + ), + pytest.param( + "find_related", + {"file_path": "src/foo.py", "line": 99}, + "find_related", + [], + ["No related chunks found"], + id="find_related_no_results", + ), + ], +) +async def test_tool_output( + cache: _IndexCache, + tool: str, + args: dict[str, Any], + method: str, + results: list[SearchResult], + expected_substrings: list[str], +) -> None: + """Search and find_related format results (or an empty-state message) through the server.""" + text = await _call_tool(cache, tool, args, index_method=method, index_return=results) + for substring in expected_substrings: + assert substring in text + + +@pytest.mark.anyio +@pytest.mark.parametrize("with_path", [True, False], ids=["pre_index", "no_path"]) +async def test_serve_runs_stdio(tmp_path: Path, with_path: bool) -> None: + """serve() loads the model, runs stdio, and optionally pre-indexes when a path is given.""" + with ( + patch("semble.mcp.load_model", return_value=MagicMock(spec=Encoder)), + patch("semble.mcp.SembleIndex.from_path", return_value=MagicMock()), + patch("mcp.server.fastmcp.FastMCP.run_stdio_async", new_callable=AsyncMock) as mock_run, + ): + await (serve(str(tmp_path)) if with_path else serve()) + + mock_run.assert_called_once() + + +@pytest.mark.parametrize( + "argv", + [ + ["semble", "/some/path", "--ref", "main"], + ["semble"], + ], +) +def test_main_calls_asyncio_run(argv: list[str], monkeypatch: pytest.MonkeyPatch) -> None: + """main() parses argv and delegates to asyncio.run(serve(...)).""" + monkeypatch.setattr(sys, "argv", argv) + with patch("semble.mcp.asyncio.run") as mock_run: + mock_run.side_effect = lambda coro: coro.close() + main() + mock_run.assert_called_once() diff --git a/tests/test_ranking.py b/tests/test_ranking.py index 3c84d80..486dee4 100644 --- a/tests/test_ranking.py +++ b/tests/test_ranking.py @@ -1,82 +1,149 @@ import pytest -from semble.ranking.boosting import _chunk_defines_symbol, _is_symbol_query -from semble.ranking.penalties import _file_path_penalty, rerank_topk +from semble.ranking.boosting import apply_query_boost, boost_multi_chunk_files, resolve_alpha +from semble.ranking.penalties import rerank_topk from tests.conftest import make_chunk +def test_rerank_topk() -> None: + """rerank_topk: empty → []; penalise_paths=False respects raw scores; saturation decay keeps order.""" + assert rerank_topk({}, top_k=5) == [] + + init_chunk = make_chunk("from .auth import authenticate", "src/semble/__init__.py") + impl_chunk = make_chunk("def authenticate(token): ...", "src/semble/auth.py") + ranked = rerank_topk({init_chunk: 2.0, impl_chunk: 1.0}, top_k=2, penalise_paths=False) + assert ranked[0][0] == init_chunk + + saturated = [make_chunk(f"def fn_{i}(): pass", "big_file.py") for i in range(5)] + ranked_sat = rerank_topk({c: float(5 - i) for i, c in enumerate(saturated)}, top_k=5) + scores = [s for _, s in ranked_sat] + assert scores == sorted(scores, reverse=True) + + @pytest.mark.parametrize( - ("query", "expected"), + "penalised_path", [ - ("HTTPAdapter", True), - ("field_validator", True), - ("URL", True), - ("getUser", True), - ("Sinatra::Base", True), - ("_private", True), - ("__init__", True), - ("session", False), - ("response", False), - ("how does routing work", False), + "src/semble/__init__.py", # _REEXPORT_FILENAMES + "tests/test_auth.py", # _TEST_FILE_RE / _TEST_DIR_RE + "src/compat/old_api.py", # _COMPAT_DIR_RE + "examples/demo.py", # _EXAMPLES_DIR_RE + "src/types/index.d.ts", # _TYPE_DEFS_RE ], ) -def test_is_symbol_query(query: str, expected: bool) -> None: - """Identifiers with uppercase/underscore/separator are symbols; plain lowercase words are not.""" - assert _is_symbol_query(query) is expected +def test_rerank_topk_demotes_penalised_paths(penalised_path: str) -> None: + """Files matching each penalty pattern rank below an equal-scored regular file.""" + regular = make_chunk("def impl(): pass", "src/regular.py") + penalised = make_chunk("def impl(): pass", penalised_path) + ranked = rerank_topk({regular: 1.0, penalised: 1.0}, top_k=2) + assert ranked[0][0] == regular @pytest.mark.parametrize( - ("file_path", "expected"), + ("query", "alpha_in", "expected"), [ - ("src/auth.py", 1.0), - ("src/semble/__init__.py", 0.5), - ("tests/test_auth.py", 0.3), - ("src/compat/old_api.py", 0.3), - ("examples/demo.py", 0.3), - ("src/types/index.d.ts", 0.7), + ("MyService", 0.7, 0.7), # explicit value returned as-is + ("MyService", None, 0.3), # symbol query → _ALPHA_SYMBOL + ("how does routing work", None, 0.5), # NL query → _ALPHA_NL ], ) -def test_file_path_penalty(file_path: str, expected: float) -> None: - """Path penalties are applied correctly per file type.""" - assert _file_path_penalty(file_path) == pytest.approx(expected) +def test_resolve_alpha(query: str, alpha_in: float | None, expected: float) -> None: + """resolve_alpha returns explicit alpha or auto-detects from symbol/NL query type.""" + assert resolve_alpha(query, alpha_in) == pytest.approx(expected) @pytest.mark.parametrize( - ("content", "symbol", "expected"), + "query", [ - ("class UserService:\n pass", "UserService", True), - ("def authenticate(token):\n return token", "authenticate", True), - ("struct Config {\n host: String,\n}", "Config", True), - ("CREATE TABLE users (\n id INT\n);", "users", True), - ("x = UserService()\n", "UserService", False), - ("return Config(host='localhost')", "Config", False), + "MyService", # bare symbol query + "how does MyService work", # NL query with embedded symbol ], ) -def test_chunk_defines_symbol(content: str, symbol: str, expected: bool) -> None: - """Definition keyword + symbol name matches; bare usage does not.""" - assert _chunk_defines_symbol(make_chunk(content), symbol) is expected +def test_apply_query_boost_boosts_defining_chunk(query: str) -> None: + """Symbol and NL-with-symbol queries both boost chunks that define the symbol.""" + defining = make_chunk("class MyService:\n pass", "src/my_service.py") + other = make_chunk("x = MyService()", "src/utils.py") + scores: dict = {defining: 0.5, other: 0.4} + boosted = apply_query_boost(scores, query, [defining, other]) -def test_rerank_topk_init_demoted_by_default() -> None: - """__init__.py is demoted below an equal-scored regular file.""" - init_chunk = make_chunk("from .auth import authenticate", "src/semble/__init__.py") - impl_chunk = make_chunk("def authenticate(token): ...", "src/semble/auth.py") - ranked = rerank_topk({init_chunk: 1.0, impl_chunk: 1.0}, top_k=2) - assert ranked[0][0] == impl_chunk + assert boosted[defining] > boosted[other] -def test_rerank_topk_penalise_paths_false_respects_scores() -> None: - """penalise_paths=False leaves score order intact, including __init__.py.""" - init_chunk = make_chunk("from .auth import authenticate", "src/semble/__init__.py") - impl_chunk = make_chunk("def authenticate(token): ...", "src/semble/auth.py") - ranked = rerank_topk({init_chunk: 2.0, impl_chunk: 1.0}, top_k=2, penalise_paths=False) - assert ranked[0][0] == init_chunk +@pytest.mark.parametrize( + "query", + [ + "MyService", + "how does MyService work", + ], +) +def test_apply_query_boost_scans_non_candidates(query: str) -> None: + """Non-candidate chunks on stem-matched files get boosted when defining the symbol.""" + defining = make_chunk("class MyService:\n pass", "src/myservice.py") + candidate = make_chunk("x = 1", "src/other.py") + scores: dict = {candidate: 0.5} + boosted = apply_query_boost(scores, query, [defining, candidate]) -def test_rerank_topk_saturation_decay_preserves_order() -> None: - """Chunks beyond the saturation threshold get decay but results stay score-ordered.""" - chunks = [make_chunk(f"def fn_{i}(): pass", "big_file.py") for i in range(5)] - ranked = rerank_topk({c: float(5 - i) for i, c in enumerate(chunks)}, top_k=5) - assert len(ranked) == 5 - scores = [s for _, s in ranked] - assert scores == sorted(scores, reverse=True) + assert defining in boosted + assert boosted[defining] > 0 + + +@pytest.mark.parametrize( + "query", + [ + "UserService", # bare symbol query + "how does UserService work", # NL with embedded symbol + ], +) +def test_apply_query_boost_skips_non_matching_stem(query: str) -> None: + """Non-candidate chunk with an unrelated stem is not boosted, regardless of query style.""" + defining = make_chunk("class UserService:\n pass", "src/user_service.py") + unrelated = make_chunk("x = 1", "src/totally_unrelated_name.py") + scores: dict = {defining: 0.5} + boosted = apply_query_boost(scores, query, [defining, unrelated]) + assert unrelated not in boosted + + +@pytest.mark.parametrize( + ("query", "file_path"), + [ + ("authenticate user session", "src/auth.py"), # prefix / morphological match + ("auth service", "src/auth_service.py"), # every keyword exact-matches a stem part + ], +) +def test_apply_query_boost_nl_stem_match_boosts(query: str, file_path: str) -> None: + """NL query keywords matching file-stem parts boost the chunk above its baseline score.""" + chunk = make_chunk("def authenticate(): pass", file_path) + scores: dict = {chunk: 0.5} + boosted = apply_query_boost(scores, query, [chunk]) + assert boosted[chunk] > 0.5 + + +def test_apply_query_boost_edge_cases() -> None: + """apply_query_boost: stopwords → noop; namespace-qualified → boosts leaf; empty scores → {}.""" + chunk = make_chunk("def foo(): pass", "src/auth.py") + assert apply_query_boost({chunk: 0.5}, "the and or", [chunk])[chunk] == pytest.approx(0.5) + + defining = make_chunk("class Base:\n pass", "src/base.py") + assert apply_query_boost({defining: 0.5}, "Sinatra::Base", [defining])[defining] > 0.5 + + assert apply_query_boost({}, "SomeQuery", []) == {} + + +def test_boost_multi_chunk_files() -> None: + """boost_multi_chunk_files: no-op on empty / all-zero; promotes top chunk of a multi-chunk file.""" + empty: dict = {} + boost_multi_chunk_files(empty) + assert empty == {} + + zero_chunk = make_chunk("x = 1", "src/foo.py") + all_zero: dict = {zero_chunk: 0.0} + boost_multi_chunk_files(all_zero) + assert all_zero[zero_chunk] == 0.0 + + c1 = make_chunk("def a(): pass", "src/big.py") + c2 = make_chunk("def b(): pass", "src/big.py") + c3 = make_chunk("def c(): pass", "src/small.py") + scores: dict = {c1: 1.0, c2: 0.8, c3: 1.0} + boost_multi_chunk_files(scores) + assert scores[c1] > 1.0 diff --git a/tests/test_search.py b/tests/test_search.py index 510b755..2a0d487 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1,4 +1,5 @@ from typing import Any +from unittest.mock import MagicMock, patch import bm25s import numpy as np @@ -6,10 +7,10 @@ import pytest from vicinity.backends.basic import BasicArgs -from semble.index.dense import SelectableBasicBackend +from semble.index.dense import SelectableBasicBackend, embed_chunks, load_model from semble.search import _sort_top_k, search_bm25, search_hybrid, search_semantic from semble.tokens import tokenize -from semble.types import Chunk, SearchMode +from semble.types import Chunk, Encoder, SearchMode from tests.conftest import make_chunk @@ -48,17 +49,21 @@ def semantic(embeddings: npt.NDArray[np.float32]) -> SelectableBasicBackend: return SelectableBasicBackend(embeddings, BasicArgs()) -def test_bm25_search(bm25: bm25s.BM25, chunks: list[Chunk]) -> None: - """BM25 returns results with the most relevant chunk first.""" +def test_search_bm25(bm25: bm25s.BM25, chunks: list[Chunk]) -> None: + """search_bm25: returns most relevant chunk first; selector restricts to given indices.""" results = search_bm25("authenticate token", bm25, chunks, top_k=4, selector=None) assert len(results) > 0 assert "authenticate" in results[0].chunk.content + selector = np.array([len(chunks) - 1], dtype=np.int_) + filtered = search_bm25("format", bm25, chunks, top_k=4, selector=selector) + assert all(r.chunk is chunks[len(chunks) - 1] for r in filtered) + -def test_bm25_no_results_for_garbage(bm25: bm25s.BM25, chunks: list[Chunk]) -> None: - """Query with no matching tokens returns an empty list.""" - results = search_bm25("zzzznonexistentterm", bm25, chunks, top_k=3, selector=None) - assert results == [] +@pytest.mark.parametrize("query", ["", " ", "\n\n", "zzzznonexistentterm"]) +def test_bm25_returns_empty_for_no_match(bm25: bm25s.BM25, chunks: list[Chunk], query: str) -> None: + """Empty / whitespace-only / token-less queries return [] instead of crashing bm25s.""" + assert search_bm25(query, bm25, chunks, top_k=3, selector=None) == [] def test_semantic_search(semantic: SelectableBasicBackend, chunks: list[Chunk], mock_model: Any) -> None: @@ -68,16 +73,13 @@ def test_semantic_search(semantic: SelectableBasicBackend, chunks: list[Chunk], assert all(-1.0 <= r.score <= 1.0 for r in results) -def test_hybrid_returns_results( +def test_search_hybrid( chunks: list[Chunk], semantic: SelectableBasicBackend, bm25: bm25s.BM25, mock_model: Any ) -> None: - """Hybrid search returns results combining semantic and BM25 signals.""" + """search_hybrid: returns combined results; identical content in different files produces separate results.""" results = search_hybrid("authenticate token", mock_model, semantic, bm25, chunks, top_k=3) assert len(results) > 0 - -def test_hybrid_keeps_both_locations_for_identical_content(mock_model: Any) -> None: - """Identical chunk content in different files produces two distinct results.""" shared_content = "def helper():\n pass" chunk_a = make_chunk(shared_content, "module_a.py") chunk_b = make_chunk(shared_content, "module_b.py") @@ -91,8 +93,8 @@ def test_hybrid_keeps_both_locations_for_identical_content(mock_model: Any) -> N bm25_index = bm25s.BM25() bm25_index.index([tokenize(c.content) for c in all_chunks], show_progress=False) - results = search_hybrid("helper", mock_model, sem_index, bm25_index, all_chunks, top_k=5) - result_locations = {r.chunk.file_path for r in results} + deduped = search_hybrid("helper", mock_model, sem_index, bm25_index, all_chunks, top_k=5) + result_locations = {r.chunk.file_path for r in deduped} assert "module_a.py" in result_locations assert "module_b.py" in result_locations @@ -122,7 +124,7 @@ def test_search_source_labels( def test_sort_top_k() -> None: - """Test that the sort top k is a faster version of argsort.""" + """_sort_top_k returns the same indices as np.argsort(-x)[:top_k].""" gen = np.random.default_rng() x = gen.standard_normal(size=(10000,)) top_k = 100 @@ -130,14 +132,32 @@ def test_sort_top_k() -> None: assert np.all(indices == np.argsort(-x)[:top_k]) -def test_bm25_with_selector_high_indices(bm25: bm25s.BM25, chunks: list[Chunk]) -> None: - """BM25 with a selector whose indices exceed len(selector) does not crash.""" - selector = np.array([len(chunks) - 1], dtype=np.int_) - results = search_bm25("format", bm25, chunks, top_k=4, selector=selector) - assert all(r.chunk is chunks[len(chunks) - 1] for r in results) +@pytest.mark.parametrize( + ("model_path", "expected_call_arg"), + [ + (None, "minishlab/potion-code-16M"), # default model + ("some/custom/model", "some/custom/model"), # explicit path forwarded + ], +) +def test_load_model(model_path: str | None, expected_call_arg: str) -> None: + """load_model calls from_pretrained with default or custom model path.""" + fake_model = MagicMock(spec=Encoder) + with patch("semble.index.dense.StaticModel.from_pretrained", return_value=fake_model) as mock_fp: + result = load_model(model_path) + mock_fp.assert_called_once_with(expected_call_arg) + assert result is fake_model -@pytest.mark.parametrize("query", ["", " ", "\n\n"]) -def test_bm25_empty_query_returns_empty(bm25: bm25s.BM25, chunks: list[Chunk], query: str) -> None: - """Empty / whitespace-only queries return [] instead of crashing bm25s.""" - assert search_bm25(query, bm25, chunks, top_k=3, selector=None) == [] +def test_embed_chunks_empty_returns_empty_array(mock_model: Any) -> None: + """embed_chunks with an empty list returns a (0, 256) float32 array.""" + result = embed_chunks(mock_model, []) + assert result.shape == (0, 256) + assert result.dtype == np.float32 + + +def test_selectable_basic_backend_rejects_k_below_one( + semantic: SelectableBasicBackend, embeddings: npt.NDArray[np.float32] +) -> None: + """SelectableBasicBackend.query guards against k < 1.""" + with pytest.raises(ValueError, match="k should be >= 1"): + semantic.query(embeddings[:1], k=0)