Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/semble/index/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from pathlib import Path

import bm25s
from vicinity import Metric, Vicinity
from vicinity.backends.basic import BasicArgs

from semble.chunker import chunk_source
from semble.file_walker import filter_extensions, language_for_path, walk_files
from semble.index.dense import embed_chunks
from semble.index.dense import SelectableBasicBackend, embed_chunks
from semble.index.sparse import enrich_for_bm25
from semble.tokens import tokenize
from semble.types import Chunk, Encoder
Expand All @@ -21,7 +21,7 @@ def create_index_from_path(
ignore: frozenset[str] | None = None,
include_docs: bool = False,
display_root: Path | None = None,
) -> tuple[bm25s.BM25, Vicinity, list[Chunk]]:
) -> tuple[bm25s.BM25, SelectableBasicBackend, list[Chunk]]:
"""Create an index from a resolved directory, optionally storing chunk paths relative to display_root.

:param path: Resolved absolute path to index.
Expand Down Expand Up @@ -51,7 +51,8 @@ def create_index_from_path(
[tokenize(enrich_for_bm25(chunk, display_root or path)) for chunk in chunks],
show_progress=False,
)
semantic_index = Vicinity.from_vectors_and_items(embeddings, chunks, metric=Metric.COSINE)
args = BasicArgs()
semantic_index = SelectableBasicBackend(embeddings, args)
else:
raise ValueError("Unable to create index.")

Expand Down
52 changes: 52 additions & 0 deletions src/semble/index/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import numpy.typing as npt
from huggingface_hub.utils.tqdm import disable_progress_bars
from model2vec import StaticModel
from numpy import typing as npt
from vicinity.backends.basic import CosineBasicBackend
from vicinity.datatypes import QueryResult
from vicinity.utils import normalize

from semble.types import Chunk, Encoder

Expand All @@ -28,3 +32,51 @@ def embed_chunks(model: Encoder, chunks: list[Chunk]) -> npt.NDArray[np.float32]
if not chunks:
return np.empty((0, 256), dtype=np.float32)
return np.array(model.encode([c.content for c in chunks]), dtype=np.float32)


class SelectableBasicBackend(CosineBasicBackend):
    """Cosine-distance backend that can restrict a query to a subset of items.

    Extends the plain cosine backend with an optional ``selector``: an array
    of item indices. When given, distances are computed only against the
    selected vectors and the returned indices are mapped back into the
    original item space.
    """

    # Query vectors processed per iteration; bounds the size of the
    # intermediate (batch x items) distance matrix.
    _BATCH_SIZE = 1024

    def _selector_dist(self, x: npt.NDArray, selector: npt.NDArray[np.int_]) -> npt.NDArray:
        """Compute cosine distances between ``x`` and the selected vectors only.

        :param x: Query vectors, one per row.
        :param selector: Indices of the stored vectors to compare against.
        :return: Cosine-distance matrix of shape (len(x), len(selector)).
        """
        x_norm = normalize(x)
        # NOTE(review): assumes the stored vectors are already normalized by
        # the cosine backend, so the dot product equals cosine similarity.
        similarity = x_norm.dot(self._vectors[selector].T)
        return 1 - similarity

    def query(self, vectors: npt.NDArray, k: int, selector: npt.NDArray[np.int_] | None = None) -> QueryResult:
        """Batched nearest-neighbour query.

        :param vectors: The vectors to query, one per row.
        :param k: The number of nearest neighbors to return.
        :param selector: Optional array of item indices to restrict the search to.
        :return: Per query vector, a tuple of (indices, distances) sorted by
            ascending distance.
        :raises ValueError: If k is less than 1.
        """
        if k < 1:
            raise ValueError(f"k should be >= 1, is now {k}")

        out: QueryResult = []
        effective_k = min(k, len(self.vectors))
        if selector is not None:
            effective_k = min(effective_k, len(selector))

        # Degenerate case: empty index or empty selector. Without this guard
        # np.argpartition below would be called with kth=-1 on a zero-column
        # matrix and raise; return an empty hit list per query instead.
        if effective_k == 0:
            out.extend(
                (np.empty(0, dtype=np.int64), np.empty(0, dtype=np.float64)) for _ in range(len(vectors))
            )
            return out

        # Batch the queries to bound peak memory.
        for start in range(0, len(vectors), self._BATCH_SIZE):
            batch = vectors[start : start + self._BATCH_SIZE]
            distances = self._dist(batch) if selector is None else self._selector_dist(batch, selector)

            # argpartition finds the k smallest in O(n); only those k are
            # then fully sorted.
            top = np.argpartition(distances, kth=effective_k - 1, axis=1)[:, :effective_k]
            order = np.argsort(np.take_along_axis(distances, top, axis=1), axis=1)
            sorted_indices = np.take_along_axis(top, order, axis=1)
            sorted_distances = np.take_along_axis(distances, sorted_indices, axis=1)

            # Map positions within the subset back to original item indices.
            if selector is not None:
                sorted_indices = selector[sorted_indices]
            out.extend(zip(sorted_indices, sorted_distances))

        return out
56 changes: 47 additions & 9 deletions src/semble/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from collections import defaultdict
from pathlib import Path

import numpy as np
import numpy.typing as npt
from bm25s import BM25
from vicinity import Vicinity

from semble.index.create import create_index_from_path
from semble.index.dense import load_model
from semble.index.dense import SelectableBasicBackend, load_model
from semble.index.sparse import selector_to_mask
from semble.search import search_bm25, search_hybrid, search_semantic
from semble.types import Chunk, Encoder, IndexStats, SearchMode, SearchResult

Expand All @@ -21,7 +23,7 @@ def __init__(
self,
model: Encoder,
bm25_index: BM25,
semantic_index: Vicinity,
semantic_index: SelectableBasicBackend,
chunks: list[Chunk],
index_root: Path,
) -> None:
Expand All @@ -36,8 +38,21 @@ def __init__(
self.model: Encoder = model
self.chunks: list[Chunk] = chunks
self._bm25_index: BM25 = bm25_index
self._semantic_index: Vicinity = semantic_index
self._semantic_index: SelectableBasicBackend = semantic_index
self._index_root: Path = index_root
self.file_mapping, self.language_mapping = self._populate_mapping()

def _populate_mapping(self) -> tuple[dict[str, list[int]], dict[str, list[int]]]:
    """Build per-file and per-language indexes of chunk positions.

    :return: Two dicts: file path -> chunk indices, and language -> chunk indices.
    """
    by_file: defaultdict[str, list[int]] = defaultdict(list)
    by_language: defaultdict[str, list[int]] = defaultdict(list)
    for idx, chunk in enumerate(self.chunks):
        by_file[chunk.file_path].append(idx)
        # Chunks without a language are indexed by file only.
        if chunk.language:
            by_language[chunk.language].append(idx)
    return dict(by_file), dict(by_language)

@property
def stats(self) -> IndexStats:
Expand Down Expand Up @@ -143,15 +158,33 @@ def find_related(self, file_path: str, line: int, top_k: int = 5) -> list[Search
)
if target is None:
return []
results = search_semantic(target.content, self.model, self._semantic_index, top_k + 1)
if target.language:
selector = self._get_selector_vector(select_language=[target.language])
else:
selector = None
results = search_semantic(target.content, self.model, self._semantic_index, self.chunks, top_k + 1, selector)
return [r for r in results if r.chunk != target][:top_k]

def _get_selector_vector(
    self, select_language: list[str] | None = None, select_document: list[str] | None = None
) -> npt.NDArray[np.int_] | None:
    """Collect the chunk indices matching the requested languages and documents.

    :param select_language: Languages whose chunks should be candidates.
    :param select_document: File paths whose chunks should be candidates.
    :return: Array of candidate chunk indices, or None when nothing matched.
    """
    candidate_ids: list[int] = []
    if select_language:
        for lang in select_language:
            candidate_ids.extend(self.language_mapping.get(lang, []))
    if select_document:
        for doc in select_document:
            candidate_ids.extend(self.file_mapping.get(doc, []))
    # NOTE(review): when filters are given but match no chunks, this returns
    # None, which downstream means "no filtering at all" — confirm intended.
    if not candidate_ids:
        return None
    return np.asarray(candidate_ids)

def search(
self,
query: str,
top_k: int = 10,
mode: SearchMode | str = SearchMode.HYBRID,
alpha: float | None = None,
select_language: list[str] | None = None,
select_document: list[str] | None = None,
) -> list[SearchResult]:
"""Search the index and return the top-k most relevant chunks.

Expand All @@ -161,18 +194,23 @@ def search(
:param alpha: Blend weight for hybrid score combination; 1.0 = full semantic
weight, 0.0 = full BM25 weight. File-path penalties and diversity reranking
are applied regardless. ``None`` auto-detects from query type.
:param select_language: Optional list of language codes to filter results by.
:param select_document: Optional list of document paths to filter results by.
:return: Ranked list of :class:`SearchResult` objects, best match first.
:raises ValueError: If `mode` is not a recognised search strategy.
"""
bm25_index, semantic_index = self._bm25_index, self._semantic_index
if not self.chunks:
return []

if mode == SearchMode.BM25:
return search_bm25(query, bm25_index, self.chunks, top_k)
selector = self._get_selector_vector(select_language, select_document)

if mode == SearchMode.BM25:
return search_bm25(query, bm25_index, self.chunks, top_k, selector=selector)
if mode == SearchMode.SEMANTIC:
return search_semantic(query, self.model, semantic_index, top_k)
return search_semantic(query, self.model, semantic_index, self.chunks, top_k, selector=selector)
if mode == SearchMode.HYBRID:
return search_hybrid(query, self.model, semantic_index, bm25_index, self.chunks, top_k, alpha=alpha)
return search_hybrid(
query, self.model, semantic_index, bm25_index, self.chunks, top_k, alpha=alpha, selector=selector
)
raise ValueError(f"Unknown search mode: {mode!r}")
12 changes: 12 additions & 0 deletions src/semble/index/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,21 @@
import contextlib
from pathlib import Path

import numpy as np
import numpy.typing as npt

from semble.types import Chunk


def selector_to_mask(selector: npt.NDArray[np.int_] | None) -> npt.NDArray[np.bool_] | None:
"""Convert a selector array to a boolean mask."""
if selector is None:
return None
mask = np.zeros(len(selector), dtype=bool)
mask[selector] = True
return mask


def enrich_for_bm25(chunk: Chunk, root: Path | None) -> str:
"""Append file path components to BM25 content to boost path-based queries.

Expand Down
30 changes: 19 additions & 11 deletions src/semble/search.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import bm25s
import numpy as np
import numpy.typing as npt
from vicinity import Vicinity

from semble.index.dense import SelectableBasicBackend
from semble.index.sparse import selector_to_mask
from semble.ranking import apply_query_boost, rerank_topk, resolve_alpha
from semble.tokens import tokenize
from semble.types import Chunk, Encoder, SearchMode, SearchResult
Expand All @@ -21,15 +22,18 @@ def _rrf_scores(scores: dict[Chunk, float]) -> dict[Chunk, float]:
def search_semantic(
    query: str,
    model: Encoder,
    semantic_index: SelectableBasicBackend,
    chunks: list[Chunk],
    top_k: int,
    selector: npt.NDArray[np.int_] | None,
) -> list[SearchResult]:
    """Embed the query and return the top-k nearest chunks.

    :param query: Free-text search query.
    :param model: Encoder used to embed the query.
    :param semantic_index: Dense backend holding the chunk embeddings.
    :param chunks: All indexed chunks, parallel to the stored vectors.
    :param top_k: Number of results to return.
    :param selector: Optional array of chunk indices to restrict the search to.
    :return: Search results ordered best match first.
    """
    embedding = model.encode([query])
    indices, distances = semantic_index.query(embedding, k=top_k, selector=selector)[0]
    # The backend reports cosine distance; flip it so higher score = better.
    results: list[SearchResult] = []
    for chunk_index, distance in zip(indices, distances):
        results.append(
            SearchResult(chunk=chunks[chunk_index], score=1.0 - float(distance), source=SearchMode.SEMANTIC)
        )
    return results


Expand All @@ -47,9 +51,11 @@ def search_bm25(
bm25_index: bm25s.BM25,
chunks: list[Chunk],
top_k: int,
selector: npt.NDArray[np.int_] | None,
) -> list[SearchResult]:
"""Return chunks ranked by BM25 score, excluding zero-score results."""
scores: npt.NDArray[np.float32] = bm25_index.get_scores(tokenize(query))
mask = selector_to_mask(selector)
scores: npt.NDArray[np.float32] = bm25_index.get_scores(tokenize(query), weight_mask=mask)
indices = _sort_top_k(scores, top_k)

# Exclude chunks with zero score, no query tokens matched.
Expand All @@ -61,11 +67,12 @@ def search_bm25(
def search_hybrid(
query: str,
model: Encoder,
semantic_index: Vicinity[Chunk],
semantic_index: SelectableBasicBackend,
bm25_index: bm25s.BM25,
chunks: list[Chunk],
top_k: int,
alpha: float | None = None,
selector: npt.NDArray[np.int_] | None = None,
) -> list[SearchResult]:
"""Hybrid search: alpha-weighted combination of semantic and BM25 scores.

Expand All @@ -79,6 +86,7 @@ def search_hybrid(
:param chunks: All indexed chunks (parallel to BM25 index).
:param top_k: Number of results to return.
:param alpha: Weight for semantic score (1-alpha goes to BM25). None = auto-detect based on query type.
:param selector: Optional array of chunk indices to filter results by.
:return: List of search results sorted by combined score descending.
"""
alpha_weight = resolve_alpha(query, alpha)
Expand All @@ -87,15 +95,15 @@ def search_hybrid(
# 5x is sufficient; latency difference vs larger multipliers is negligible.
candidate_count = top_k * 5

semantic = search_semantic(query, model, semantic_index, candidate_count)
semantic = search_semantic(query, model, semantic_index, chunks, candidate_count, selector)
semantic_scores: dict[Chunk, float] = {result.chunk: result.score for result in semantic}
bm25_result_scores = {}
for result in search_bm25(query, bm25_index, chunks, candidate_count):
bm25_scores = {}
for result in search_bm25(query, bm25_index, chunks, candidate_count, selector):
if result.score:
bm25_result_scores[result.chunk] = result.score
bm25_scores[result.chunk] = result.score

normalized_semantic = _rrf_scores(semantic_scores)
normalized_bm25 = _rrf_scores(bm25_result_scores)
normalized_bm25 = _rrf_scores(bm25_scores)

combined_scores: dict[Chunk, float] = {
chunk: alpha_weight * normalized_semantic.get(chunk, 0.0)
Expand Down