---
layout: post
title:  Batching with SBERT and LanceDB
description: Retrieval with SBERT and LanceDB
author: "Mahamadi NIKIEMA"
thumbnail-img: profile.jpg
tags: [Reranking, Python, Embeddings]
date:   2025-05-14 21:55:51 +0200
categories: scraping
draft: true
---

I have started using LanceDB in many projects because it is simple and easy to set up. I use it with open-source models from Hugging Face, and I noticed that some retrieval tasks are slow.

The [Sentence Transformers v4.1](https://github.com/UKPLab/sentence-transformers/releases/tag/v4.1.0) release adds multiple backends for SBERT cross-encoder (reranker) models.
We can now use the ``ONNX`` and ``OpenVINO`` backends to speed up inference. According to the [benchmark](https://sbert.net/docs/cross_encoder/usage/efficiency.html), the speed-up is *1.73x* on GPU while preserving 99.61% of the accuracy. This means we can use these backends inside a LanceDB reranker to speed up retrieval.
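Benchmark numbers depend on hardware, model and batch size, so it is worth timing the backends on your own machine with the same (query, passage) pairs. The sketch below is a generic timing helper; `median_latency` and `speedup` are hypothetical names of my own, not part of Sentence Transformers or LanceDB, and the only assumption is that each backend exposes a callable that scores a list of pairs.

```python
import statistics
import time
from typing import Callable, Sequence


def median_latency(predict: Callable[[Sequence], object], pairs: Sequence, repeats: int = 5) -> float:
    """Return the median wall-clock time (in seconds) of predict(pairs) over `repeats` runs."""
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    predict(pairs)  # warm-up call, so one-off setup costs are not measured
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        predict(pairs)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def speedup(baseline_seconds: float, candidate_seconds: float) -> float:
    """A 1.73x speed-up means the candidate takes 1/1.73 of the baseline time."""
    return baseline_seconds / candidate_seconds
```

With Sentence Transformers 4.1 or later, you would pass `CrossEncoder(model_name).predict` as the baseline and `CrossEncoder(model_name, backend="onnx").predict` as the candidate, using the same list of pairs for both.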

LanceDB already integrates SBERT cross-encoders natively, but that integration does not expose the backend option. LanceDB does, however, provide [documentation](https://lancedb.github.io/lancedb/reranking/custom_reranker/#example-of-a-custom-reranker) on writing a custom reranker.

I will show you how to write a custom LanceDB reranker that runs an SBERT cross-encoder with the ``ONNX`` backend.

Since LanceDB already implements a ``cross-encoder`` reranker, I took inspiration from its [cross-encoder](https://lancedb.github.io/lancedb/reranking/cross_encoder/) example and added the ``ONNX`` backend.
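Conceptually, a cross-encoder reranker does three things: it pairs the query with each retrieved passage, scores every pair jointly, and sorts the results by that score, stored in a `_relevance_score` column. The plain-Python sketch below mirrors that flow on a list of dicts instead of a `pyarrow.Table`. `toy_overlap_score` is a hypothetical word-overlap scorer standing in for `CrossEncoder.predict`; it is not a real relevance model.

```python
from typing import Callable, Dict, List


def rerank_rows(
    query: str,
    rows: List[Dict[str, str]],
    score_pair: Callable[[str, str], float],
    column: str = "text",
) -> List[Dict[str, object]]:
    """Score every (query, passage) pair, attach the score, and sort best-first."""
    # Build new dicts so the input rows are left untouched
    scored = [{**row, "_relevance_score": score_pair(query, row[column])} for row in rows]
    # Highest relevance first, like sort_by([("_relevance_score", "descending")])
    return sorted(scored, key=lambda r: r["_relevance_score"], reverse=True)


def toy_overlap_score(query: str, passage: str) -> float:
    """Hypothetical stand-in scorer: fraction of query words that appear in the passage."""
    query_words = set(query.lower().split())
    passage_words = set(passage.lower().split())
    return len(query_words & passage_words) / len(query_words) if query_words else 0.0
```

A real cross-encoder replaces `toy_overlap_score` with a model that reads the query and passage together, which is what makes it more accurate (and slower) than comparing precomputed embeddings.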

```python
import warnings
from functools import cached_property
from typing import Optional

import lancedb
import pyarrow as pa
from lancedb.rerankers import Reranker
from sentence_transformers import CrossEncoder

warnings.filterwarnings("ignore")


class ONNXCrossEncoderReranker(Reranker):
    """
    A custom LanceDB reranker that runs a Sentence Transformers cross-encoder
    with the ONNX backend.

    It provides:
    1. Faster inference through ONNX Runtime
    2. Support for any cross-encoder model that ships ONNX weights
    3. Batched scoring of (query, passage) pairs
    4. Reranking of vector, full-text and hybrid search results
    """

    def __init__(
        self,
        model_name: str,
        max_length: int = 256,
        batch_size: int = 32,
        device: Optional[str] = None,
        model_kwargs: Optional[dict] = None,
        trust_remote_code: bool = True,
        column: str = "text",
        **kwargs,
    ):
        """
        Initialize the ONNX cross-encoder reranker.

        Args:
            model_name: Hugging Face model ID or local path of the cross-encoder
            max_length: Maximum token length of a (query, passage) pair
            batch_size: Number of pairs scored per forward pass
            device: Device to run inference on ('cpu', 'cuda', ...), defaults to 'cpu'
            model_kwargs: Extra model-loading arguments, e.g.
                {"file_name": "onnx/model_int8.onnx"} to pick a specific ONNX file
            trust_remote_code: Whether to allow custom code from the model repository
            column: Name of the column that holds the passages to rerank
            **kwargs: Forwarded to LanceDB's Reranker (e.g. return_score)
        """
        super().__init__(**kwargs)

        self.model_name = model_name
        self.model_kwargs = model_kwargs if model_kwargs is not None else {}
        self.max_length = max_length
        self.column = column
        self.device = device if device is not None else "cpu"
        self.batch_size = batch_size
        self.trust_remote_code = trust_remote_code

    @cached_property
    def model(self) -> CrossEncoder:
        # Loaded lazily on first use, then cached on the instance
        return CrossEncoder(
            model_name_or_path=self.model_name,
            max_length=self.max_length,
            device=self.device,
            trust_remote_code=self.trust_remote_code,
            model_kwargs=self.model_kwargs,
            backend="onnx",
        )

    def _rerank(self, result_set: pa.Table, query: str) -> pa.Table:
        result_set = self._handle_empty_results(result_set)
        if len(result_set) == 0:
            return result_set
        passages = result_set[self.column].to_pylist()
        cross_inp = [[query, passage] for passage in passages]
        # Score the pairs in chunks of `batch_size`
        cross_scores = self.model.predict(cross_inp, batch_size=self.batch_size)
        return result_set.append_column(
            "_relevance_score", pa.array(cross_scores, type=pa.float32())
        )

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ) -> pa.Table:
        combined_results = self.merge_results(vector_results, fts_results)
        combined_results = self._rerank(combined_results, query)
        if self.score == "relevance":
            # Keep only the cross-encoder score column
            combined_results = self._keep_relevance_score(combined_results)
        elif self.score == "all":
            raise NotImplementedError(
                "return_score='all' is not implemented for ONNXCrossEncoderReranker"
            )
        # Best match first
        return combined_results.sort_by([("_relevance_score", "descending")])

    def rerank_vector(self, query: str, vector_results: pa.Table) -> pa.Table:
        vector_results = self._rerank(vector_results, query)
        if self.score == "relevance":
            # Drop the vector distance, keep the cross-encoder score
            vector_results = vector_results.drop_columns(["_distance"])
        return vector_results.sort_by([("_relevance_score", "descending")])

    def rerank_fts(self, query: str, fts_results: pa.Table) -> pa.Table:
        fts_results = self._rerank(fts_results, query)
        if self.score == "relevance":
            # Drop the full-text search score, keep the cross-encoder score
            fts_results = fts_results.drop_columns(["_score"])
        return fts_results.sort_by([("_relevance_score", "descending")])
```

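The reranker's `batch_size` is meant to control how many pairs the model scores at once: `CrossEncoder.predict` accepts a `batch_size` argument (32 by default) and splits the pairs into chunks of that size. The helper below is a simplified, hypothetical illustration of that chunking, not the library's own code. Larger batches amortize per-call overhead, while smaller ones bound peak memory.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, List, Sequence, Tuple

Pair = Tuple[str, str]


def batched_pairs(pairs: Iterable[Pair], batch_size: int) -> Iterator[List[Pair]]:
    """Yield consecutive chunks of at most `batch_size` pairs."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    iterator = iter(pairs)
    while chunk := list(islice(iterator, batch_size)):
        yield chunk


def predict_in_batches(
    score_batch: Callable[[List[Pair]], Sequence[float]],
    pairs: Iterable[Pair],
    batch_size: int = 32,
) -> List[float]:
    """Score the pairs chunk by chunk and return the scores in the original order."""
    scores: List[float] = []
    for chunk in batched_pairs(pairs, batch_size):
        scores.extend(score_batch(chunk))
    return scores
```

On Python 3.12 and later, `itertools.batched` provides the same chunking (yielding tuples instead of lists).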
Let's try it out with a simple example.

```python
db = lancedb.connect("./lancedb")
```

```python
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry


func = get_registry().get("sentence-transformers").create(name="sentence-transformers/all-MiniLM-L6-v2",
                                                          device="cpu")

# Define a Schema
class Words(LanceModel):
    # This is the source field to compute the embeddings and index
    text: str = func.SourceField()

    # This is the vector field that will store the output of the embeddings
    vector: Vector(func.ndims()) = func.VectorField()
```

```python
data = [
    {"text": "This guy is happy"},
    {"text": "This person is not happy"},
    {"text": "That is a very happy person"},
    {"text": "This is a good guy"},
]
```

```python
table = db.create_table("testing", schema=Words, mode="overwrite")
table.add(data)
```

```python
cross_encoder = ONNXCrossEncoderReranker(model_name="Alibaba-NLP/gte-reranker-modernbert-base",
                                         model_kwargs={"file_name": "onnx/model_int8.onnx"})
question = "This is a happy person"
results = table.search(question, query_type="vector").limit(4)
```
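Creating the reranker does not load the model yet: `model` is a `functools.cached_property`, so the ONNX model is downloaded and loaded on the first rerank call and reused afterwards. That is why the model-loading messages only show up in the output of the `rerank` cell. The stdlib-only sketch below shows the same pattern with a hypothetical `LazyModelHolder` class that counts how often its loader runs.

```python
from functools import cached_property


class LazyModelHolder:
    """Hypothetical stand-in that mimics the reranker's lazy `model` property."""

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.load_count = 0  # counts how often the expensive load runs

    @cached_property
    def model(self) -> str:
        # The real class builds CrossEncoder(..., backend="onnx") here;
        # this sketch only records the call and returns a placeholder string.
        self.load_count += 1
        return f"loaded:{self.model_name}"


holder = LazyModelHolder("some-cross-encoder")
print(holder.load_count)  # 0: nothing is loaded at construction time
holder.model
holder.model
print(holder.load_count)  # 1: the second access reuses the cached value
```

One consequence of this design: changing `model_name` or `model_kwargs` after the first rerank does not rebuild the model. Deleting the cached attribute (`del reranker.model`) or creating a new reranker forces a reload.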

```python
print(results.to_pandas().drop("vector", axis=1))
```

```text
                          text  _distance
0  That is a very happy person   0.394280
1            This guy is happy   0.515812
2     This person is not happy   0.548586
3           This is a good guy   0.982840
```


```python
print(results.rerank(reranker=cross_encoder).to_pandas().drop("vector", axis=1))
```

```text
Too many ONNX model files were found in onnx/model.onnx ,onnx/model_bnb4.onnx ,onnx/model_fp16.onnx ,onnx/model_int8.onnx ,onnx/model_q4.onnx ,onnx/model_q4f16.onnx ,onnx/model_quantized.onnx ,onnx/model_uint8.onnx. specify which one to load by using the `file_name` and/or the `subfolder` arguments. Loading the file model_int8.onnx in the subfolder onnx.

                          text  _relevance_score
0  That is a very happy person          0.938593
1            This guy is happy          0.916949
2           This is a good guy          0.885104
3     This person is not happy          0.819631
```
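Comparing the two tables, the top two results keep their positions, while "This is a good guy" moves up one place and "This person is not happy" drops to last. The small helper below (a hypothetical `rank_shifts` function, not part of LanceDB) quantifies how far each result moved, using the two orderings printed above.

```python
from typing import Dict, List


def rank_shifts(before: List[str], after: List[str]) -> Dict[str, int]:
    """Map each item to how many positions it moved; positive means it moved up."""
    if sorted(before) != sorted(after):
        raise ValueError("both rankings must contain the same items")
    old_position = {text: i for i, text in enumerate(before)}
    return {text: old_position[text] - new for new, text in enumerate(after)}


# Orderings copied from the vector search and reranked outputs above
vector_order = [
    "That is a very happy person",
    "This guy is happy",
    "This person is not happy",
    "This is a good guy",
]
reranked_order = [
    "That is a very happy person",
    "This guy is happy",
    "This is a good guy",
    "This person is not happy",
]
print(rank_shifts(vector_order, reranked_order))
```

One plausible reading is that the cross-encoder, which reads the query and passage together, penalizes the negation in "This person is not happy", whereas the embedding distance placed it closer to the query than "This is a good guy".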
