From 0c2620678a93268dda21800321d48bba4b3d1a1e Mon Sep 17 00:00:00 2001 From: stephantul Date: Thu, 16 Apr 2026 08:10:03 +0200 Subject: [PATCH] feat: add language and document selector --- src/semble/index/create.py | 9 +++--- src/semble/index/dense.py | 52 +++++++++++++++++++++++++++++++++++ src/semble/index/index.py | 56 ++++++++++++++++++++++++++++++++------ src/semble/index/sparse.py | 12 ++++++++ src/semble/search.py | 30 ++++++++++++-------- 5 files changed, 135 insertions(+), 24 deletions(-) diff --git a/src/semble/index/create.py b/src/semble/index/create.py index 76ac6a3..4620888 100644 --- a/src/semble/index/create.py +++ b/src/semble/index/create.py @@ -4,11 +4,11 @@ from pathlib import Path import bm25s -from vicinity import Metric, Vicinity +from vicinity.backends.basic import BasicArgs from semble.chunker import chunk_source from semble.file_walker import filter_extensions, language_for_path, walk_files -from semble.index.dense import embed_chunks +from semble.index.dense import SelectableBasicBackend, embed_chunks from semble.index.sparse import enrich_for_bm25 from semble.tokens import tokenize from semble.types import Chunk, Encoder @@ -21,7 +21,7 @@ def create_index_from_path( ignore: frozenset[str] | None = None, include_docs: bool = False, display_root: Path | None = None, -) -> tuple[bm25s.BM25, Vicinity, list[Chunk]]: +) -> tuple[bm25s.BM25, SelectableBasicBackend, list[Chunk]]: """Create an index from a resolved directory, optionally storing chunk paths relative to display_root. :param path: Resolved absolute path to index. 
@@ -51,7 +51,8 @@ def create_index_from_path( [tokenize(enrich_for_bm25(chunk, display_root or path)) for chunk in chunks], show_progress=False, ) - semantic_index = Vicinity.from_vectors_and_items(embeddings, chunks, metric=Metric.COSINE) + args = BasicArgs() + semantic_index = SelectableBasicBackend(embeddings, args) else: raise ValueError("Unable to create index.") diff --git a/src/semble/index/dense.py b/src/semble/index/dense.py index 2f864ee..fe3b12c 100644 --- a/src/semble/index/dense.py +++ b/src/semble/index/dense.py @@ -4,6 +4,10 @@ import numpy.typing as npt from huggingface_hub.utils.tqdm import disable_progress_bars from model2vec import StaticModel +from numpy import typing as npt +from vicinity.backends.basic import CosineBasicBackend +from vicinity.datatypes import QueryResult +from vicinity.utils import normalize from semble.types import Chunk, Encoder @@ -28,3 +32,51 @@ def embed_chunks(model: Encoder, chunks: list[Chunk]) -> npt.NDArray[np.float32] if not chunks: return np.empty((0, 256), dtype=np.float32) return np.array(model.encode([c.content for c in chunks]), dtype=np.float32) + + +class SelectableBasicBackend(CosineBasicBackend): + def _selector_dist(self, x: npt.NDArray, selector: npt.NDArray[np.int_]) -> npt.NDArray: + """Compute cosine distance.""" + x_norm = normalize(x) + sim = x_norm.dot(self._vectors[selector].T) + return 1 - sim + + def query(self, vectors: npt.NDArray, k: int, selector: npt.NDArray[np.int_] | None = None) -> QueryResult: + """Batched distance query. + + :param vectors: The vectors to query. + :param k: The number of nearest neighbors to return. + :param selector: Optional array of chunk indices to filter results by. + :return: A list of tuples with the indices and distances. + :raises ValueError: If k is less than 1. 
+ """ + if k < 1: + raise ValueError(f"k should be >= 1, is now {k}") + + out: QueryResult = [] + num_vectors = len(self.vectors) + effective_k = min(k, num_vectors) + if selector is not None: + effective_k = min(effective_k, len(selector)) + + # Batch the queries + for index in range(0, len(vectors), 1024): + batch = vectors[index : index + 1024] + if selector is not None: + distances = self._selector_dist(batch, selector) + else: + distances = self._dist(batch) + + # Efficiently get the k smallest distances + indices = np.argpartition(distances, kth=effective_k - 1, axis=1)[:, :effective_k] + sorted_indices = np.take_along_axis( + indices, np.argsort(np.take_along_axis(distances, indices, axis=1)), axis=1 + ) + sorted_distances = np.take_along_axis(distances, sorted_indices, axis=1) + + # Extend the output with tuples of (indices, distances) + if selector is not None: + sorted_indices = selector[sorted_indices] + out.extend(zip(sorted_indices, sorted_distances)) + + return out diff --git a/src/semble/index/index.py b/src/semble/index/index.py index 0ef66fe..17f6788 100644 --- a/src/semble/index/index.py +++ b/src/semble/index/index.py @@ -5,11 +5,13 @@ from collections import defaultdict from pathlib import Path +import numpy as np +import numpy.typing as npt from bm25s import BM25 -from vicinity import Vicinity from semble.index.create import create_index_from_path -from semble.index.dense import load_model +from semble.index.dense import SelectableBasicBackend, load_model +from semble.index.sparse import selector_to_mask from semble.search import search_bm25, search_hybrid, search_semantic from semble.types import Chunk, Encoder, IndexStats, SearchMode, SearchResult @@ -21,7 +23,7 @@ def __init__( self, model: Encoder, bm25_index: BM25, - semantic_index: Vicinity, + semantic_index: SelectableBasicBackend, chunks: list[Chunk], index_root: Path, ) -> None: @@ -36,8 +38,21 @@ def __init__( self.model: Encoder = model self.chunks: list[Chunk] = chunks 
self._bm25_index: BM25 = bm25_index - self._semantic_index: Vicinity = semantic_index + self._semantic_index: SelectableBasicBackend = semantic_index self._index_root: Path = index_root + self.file_mapping, self.language_mapping = self._populate_mapping() + + def _populate_mapping(self) -> tuple[dict[str, list[int]], dict[str, list[int]]]: + """Creates two mappings, one from language to chunk, and one from file to chunk.""" + language_to_id = defaultdict(list) + file_to_id = defaultdict(list) + for i, chunk in enumerate(self.chunks): + language = chunk.language + if language: + language_to_id[language].append(i) + file_to_id[chunk.file_path].append(i) + + return dict(file_to_id), dict(language_to_id) @property def stats(self) -> IndexStats: @@ -143,15 +158,33 @@ def find_related(self, file_path: str, line: int, top_k: int = 5) -> list[Search ) if target is None: return [] - results = search_semantic(target.content, self.model, self._semantic_index, top_k + 1) + if target.language: + selector = self._get_selector_vector(select_language=[target.language]) + else: + selector = None + results = search_semantic(target.content, self.model, self._semantic_index, self.chunks, top_k + 1, selector) return [r for r in results if r.chunk != target][:top_k] + def _get_selector_vector( + self, select_language: list[str] | None = None, select_document: list[str] | None = None + ) -> npt.NDArray[np.int_] | None: + """Create a vector of integers corresponding to the items that should be retrieved.""" + selector = [] + for language in select_language or []: + selector.extend(self.language_mapping.get(language, [])) + for filename in select_document or []: + selector.extend(self.file_mapping.get(filename, [])) + + return np.asarray(selector) if selector else None + def search( self, query: str, top_k: int = 10, mode: SearchMode | str = SearchMode.HYBRID, alpha: float | None = None, + select_language: list[str] | None = None, + select_document: list[str] | None = None, ) -> 
list[SearchResult]: """Search the index and return the top-k most relevant chunks. @@ -161,6 +194,8 @@ def search( :param alpha: Blend weight for hybrid score combination; 1.0 = full semantic weight, 0.0 = full BM25 weight. File-path penalties and diversity reranking are applied regardless. ``None`` auto-detects from query type. + :param select_language: Optional list of language codes to filter results by. + :param select_document: Optional list of document paths to filter results by. :return: Ranked list of :class:`SearchResult` objects, best match first. :raises ValueError: If `mode` is not a recognised search strategy. """ @@ -168,11 +203,14 @@ def search( if not self.chunks: return [] - if mode == SearchMode.BM25: - return search_bm25(query, bm25_index, self.chunks, top_k) + selector = self._get_selector_vector(select_language, select_document) + if mode == SearchMode.BM25: + return search_bm25(query, bm25_index, self.chunks, top_k, selector=selector) if mode == SearchMode.SEMANTIC: - return search_semantic(query, self.model, semantic_index, top_k) + return search_semantic(query, self.model, semantic_index, self.chunks, top_k, selector=selector) if mode == SearchMode.HYBRID: - return search_hybrid(query, self.model, semantic_index, bm25_index, self.chunks, top_k, alpha=alpha) + return search_hybrid( + query, self.model, semantic_index, bm25_index, self.chunks, top_k, alpha=alpha, selector=selector + ) raise ValueError(f"Unknown search mode: {mode!r}") diff --git a/src/semble/index/sparse.py b/src/semble/index/sparse.py index 9fd839d..c064376 100644 --- a/src/semble/index/sparse.py +++ b/src/semble/index/sparse.py @@ -3,9 +3,21 @@ import contextlib from pathlib import Path +import numpy as np +import numpy.typing as npt + from semble.types import Chunk +def selector_to_mask(selector: npt.NDArray[np.int_] | None, size: int) -> npt.NDArray[np.bool_] | None: + """Convert a selector array of chunk indices to a boolean mask of length ``size`` (the corpus size).""" + if selector is None: + return None + mask = 
np.zeros(size, dtype=bool) + mask[selector] = True + return mask + + def enrich_for_bm25(chunk: Chunk, root: Path | None) -> str: """Append file path components to BM25 content to boost path-based queries. diff --git a/src/semble/search.py b/src/semble/search.py index b5338a6..e501e6f 100644 --- a/src/semble/search.py +++ b/src/semble/search.py @@ -1,8 +1,9 @@ import bm25s import numpy as np import numpy.typing as npt -from vicinity import Vicinity +from semble.index.dense import SelectableBasicBackend +from semble.index.sparse import selector_to_mask from semble.ranking import apply_query_boost, rerank_topk, resolve_alpha from semble.tokens import tokenize from semble.types import Chunk, Encoder, SearchMode, SearchResult @@ -21,15 +22,18 @@ def _rrf_scores(scores: dict[Chunk, float]) -> dict[Chunk, float]: def search_semantic( query: str, model: Encoder, - semantic_index: Vicinity[Chunk], + semantic_index: SelectableBasicBackend, + chunks: list[Chunk], top_k: int, + selector: npt.NDArray[np.int_] | None, ) -> list[SearchResult]: """Run semantic search for a query.""" query_embedding = model.encode([query]) - hits = semantic_index.query(query_embedding, k=top_k)[0] + indices, scores = semantic_index.query(query_embedding, k=top_k, selector=selector)[0] # Vicinity returns cosine distance; convert to similarity so higher = better. 
return [ - SearchResult(chunk=chunk, score=1.0 - float(distance), source=SearchMode.SEMANTIC) for chunk, distance in hits + SearchResult(chunk=chunks[index], score=1.0 - float(distance), source=SearchMode.SEMANTIC) + for index, distance in zip(indices, scores) ] @@ -47,9 +51,11 @@ def search_bm25( bm25_index: bm25s.BM25, chunks: list[Chunk], top_k: int, + selector: npt.NDArray[np.int_] | None, ) -> list[SearchResult]: """Return chunks ranked by BM25 score, excluding zero-score results.""" - scores: npt.NDArray[np.float32] = bm25_index.get_scores(tokenize(query)) + mask = selector_to_mask(selector, len(chunks)) + scores: npt.NDArray[np.float32] = bm25_index.get_scores(tokenize(query), weight_mask=mask) indices = _sort_top_k(scores, top_k) # Exclude chunks with zero score, no query tokens matched. @@ -61,11 +67,12 @@ def search_hybrid( query: str, model: Encoder, - semantic_index: Vicinity[Chunk], + semantic_index: SelectableBasicBackend, bm25_index: bm25s.BM25, chunks: list[Chunk], top_k: int, alpha: float | None = None, + selector: npt.NDArray[np.int_] | None = None, ) -> list[SearchResult]: """Hybrid search: alpha-weighted combination of semantic and BM25 scores. @@ -79,6 +86,7 @@ def search_hybrid( :param chunks: All indexed chunks (parallel to BM25 index). :param top_k: Number of results to return. :param alpha: Weight for semantic score (1-alpha goes to BM25). None = auto-detect based on query type. + :param selector: Optional array of chunk indices to filter results by. :return: List of search results sorted by combined score descending. """ alpha_weight = resolve_alpha(query, alpha) @@ -87,15 +95,15 @@ def search_hybrid( # 5x is sufficient; latency difference vs larger multipliers is negligible. 
candidate_count = top_k * 5 - semantic = search_semantic(query, model, semantic_index, candidate_count) + semantic = search_semantic(query, model, semantic_index, chunks, candidate_count, selector) semantic_scores: dict[Chunk, float] = {result.chunk: result.score for result in semantic} - bm25_result_scores = {} - for result in search_bm25(query, bm25_index, chunks, candidate_count): + bm25_scores = {} + for result in search_bm25(query, bm25_index, chunks, candidate_count, selector): if result.score: - bm25_result_scores[result.chunk] = result.score + bm25_scores[result.chunk] = result.score normalized_semantic = _rrf_scores(semantic_scores) - normalized_bm25 = _rrf_scores(bm25_result_scores) + normalized_bm25 = _rrf_scores(bm25_scores) combined_scores: dict[Chunk, float] = { chunk: alpha_weight * normalized_semantic.get(chunk, 0.0)