---
layout: post
title:  Batching with SBERT and LanceDB
description: Retrieval with SBERT and LanceDB
author: "Mahamadi NIKIEMA"
thumbnail-img: profile.jpg
tags: [Python, Scraping, Podcast]
date:   2024-12-29 21:55:51 +0200
categories: scraping
draft: true
---

In [None]:
from sentence_transformers import  SentenceTransformer
from lancedb.embeddings import get_registry
from lancedb.db import DBConnection
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunction
from lancedb.table import Table
from typing import Literal

In [2]:
def get_embeddings(batch, model: SentenceTransformer, column: str = "chunk"):
    """Encode one column of a batch and return it under the ``"vector"`` key.

    Args:
        batch: Object indexable by ``column`` (e.g. a dict of column name to
            list of texts).
        model: Encoder whose ``encode`` method is applied to ``batch[column]``.
        column: Name of the entry in ``batch`` that holds the texts.

    Returns:
        A dict ``{"vector": <result of model.encode>}``.
    """
    texts = batch[column]
    return {"vector": model.encode(texts)}

In [None]:
def get_or_create_lancedb_table(
    db: DBConnection,
    table_name: str,
    all_docs,
    embedding_model: str = "sentence-transformers",
    model_name: str = "all-MiniLM-L6-v2",
):
    """Return a non-empty LanceDB table, creating and filling it if needed.

    If ``table_name`` already exists in ``db`` and contains at least one row,
    that table is reused. Otherwise the table is (re)created with
    ``mode="overwrite"`` and ``all_docs`` is added to it. In both cases the
    full-text-search index on the ``chunk`` column is rebuilt before the
    table is returned.

    Args:
        db: Open LanceDB connection.
        table_name: Name of the table to open or create.
        all_docs: Rows to ingest when the table has to be created. Must
            support ``len()``; presumably each row provides ``id`` and
            ``chunk`` to match the schema below (verify against caller).
        embedding_model: Key looked up in the LanceDB embedding registry.
        model_name: Model name passed to the embedding function's ``create``.

    Returns:
        The opened or newly created table.
    """
    embd_func: EmbeddingFunction = get_registry().get(embedding_model)
    func = embd_func.create(name=model_name)

    class Chunk(LanceModel):
        id: str
        # Source text; declared as the embedding function's source field.
        chunk: str = func.SourceField()
        # Single vector field sized from the model's dimensionality.
        # (A duplicate, dimension-less `vector` declaration was removed; it
        # was immediately overwritten by this one.)
        vector: Vector(func.ndims()) = func.VectorField()  # type: ignore

    if table_name in db.table_names():
        # Open once and reuse the handle for both the row check and return.
        table = db.open_table(table_name)
        if table.count_rows() > 0:
            print(f"Table {table_name} already exists")
            table.create_fts_index("chunk", replace=True)
            return table

    # Missing or empty table: rebuild it from scratch.
    table = db.create_table(table_name, schema=Chunk, mode="overwrite")
    table.add(all_docs)
    print(f"Table {table_name} created with {len(all_docs)} chunks")
    table.create_fts_index("chunk", replace=True)
    print(f"{table.count_rows()} chunks ingested into the database")
    return table

In [None]:
def retrieve(
    question: str,
    table: Table,
    max_k=25,
    mode: Literal["vector", "fts", "hybrid"] = "vector",
):
    """Search ``table`` for ``question`` and return the matching chunks.

    Args:
        question: Query text.
        table: Table to search.
        max_k: Maximum number of results requested via ``limit``.
        mode: Query type forwarded to ``table.search``.

    Returns:
        A list of ``{"id": ..., "chunk": ...}`` dicts, one per result row.
        On any exception the error is printed and an empty list is returned
        (best-effort behaviour).
    """
    try:
        if mode in ("fts", "hybrid"):
            # Text-based modes pass the vector column explicitly as None.
            query = table.search(
                query=question, vector_column_name=None, query_type=mode
            )
        else:
            query = table.search(question, query_type=mode)

        hits = query.limit(max_k).to_list()
        return [{"id": hit["id"], "chunk": hit["chunk"]} for hit in hits]
    except Exception as err:
        print(f"Error: {err}")
        return []

In [3]:
def calculate_recall(predictions: list[str], gt: list[str]):
    """Return the fraction of ground-truth labels present in ``predictions``.

    Args:
        predictions: Retrieved labels.
        gt: Ground-truth (relevant) labels. Duplicates are counted
            individually, as in the denominator ``len(gt)``.

    Returns:
        Recall in ``[0.0, 1.0]``. Returns ``0.0`` when ``gt`` is empty
        instead of raising ``ZeroDivisionError``.
    """
    if not gt:
        # Recall is undefined without relevant items; report 0.0 rather
        # than crash.
        return 0.0
    # Build the lookup once so each membership test is O(1).
    retrieved = set(predictions)
    hits = sum(1 for label in gt if label in retrieved)
    return hits / len(gt)

In [None]:
from datasets import load_dataset
# Take the first 100 rows of the TREC train split and expose each row's
# text under both "chunk" and "question".
dataset = load_dataset("trec", split="train[:100]").map(
    lambda row: {"chunk": row["text"], "question": row["text"]}
)