# Retrieval

## Libraries

In [None]:
import sys
import warnings

import lancedb
from lancedb.db import DBConnection
from lancedb.rerankers import (
    ColbertReranker,
    CrossEncoderReranker,
    LinearCombinationReranker,
)
from lancedb.table import Table


# isort: off
sys.path.append("..")  # include repository-root to load modules from src folder
from src.constants import LANCEDB_URI, get_rag_config  # noqa: E402
from src.embeddings import create_local_emb_func  # noqa: E402

# Hide FutureWarning messages to keep the notebook output readable.
warnings.filterwarnings("ignore", category=FutureWarning)

## Parameters

In [None]:
# Load the RAG configuration once and reuse it; get_rag_config() may re-read
# the configuration source on every call.
rag_config = get_rag_config()
table_name: str = rag_config["knowledge_base"]["table_name"]
emb_model_name: str = rag_config["embeddings"]["model_name"]

## Code

## Setup

### Text Embedding Setup

In [None]:
emb_func = create_local_emb_func(emb_model_name)

# Smoke test: embed a single dummy string and record the length of the
# resulting vector (the embedding dimension).
sample_embeddings = emb_func(["foo"])
n_dim_vec = len(sample_embeddings[0])

### Access Knowledge Base

In [None]:
# Connect to the LanceDB database and list the tables it contains.
db: DBConnection = lancedb.connect(uri=LANCEDB_URI)
existing_tables = db.table_names()
print(f"List of all tables in the LanceDB database: {existing_tables}")

In [None]:
# Open the knowledge-base table and report how many rows it holds.
k_base: Table = db.open_table(table_name)
n_rows = k_base.count_rows()
print(f"Number of entries in the table '{table_name}': {n_rows}")

## Testing Retrieval

In [None]:
# Example query used by the searches below.
query_text: str = "How to reduce heart Disease Risk"
# Embed the query locally; the vector search below uses this vector directly.
query_embeddings = emb_func([query_text])
query_vec: list[float] = query_embeddings[0]

# Number of hits to request from each search.
n_best: int = 10

In [None]:
def show_results(response: list[dict]) -> None:
    """Print a summary of search hits, followed by every individual hit.

    The summary reports the number of hits and the number of distinct
    ``url`` and ``title`` values. Each hit is then printed on one line as
    ``(score) 'title' : 'text'``.

    The score shown is the first non-``None`` value among the keys
    ``_distance``, ``score``, ``_relevance_score`` and ``_score`` (which
    one exists depends on the search type and the LanceDB version). If
    none is present, ``0`` is shown.

    Args:
        response: Search hits, e.g. as returned by LanceDB's ``.to_list()``;
            every hit must provide ``url``, ``title`` and ``text`` keys.
    """
    # Candidate score columns, in order of preference.
    score_keys = ("_distance", "score", "_relevance_score", "_score")

    print(f"{len(response)} results for query")

    # unique URLs
    urls: set = {hit["url"] for hit in response}
    print(f"{len(urls)} unique URL(s)")

    # unique Titles
    titles: set = {hit["title"] for hit in response}
    print(f"{len(titles)} unique Title(s)")

    print("Paragraphs:")
    for rank, hit in enumerate(response, start=1):
        # Use the first score that is present *and* not None; a None value
        # would otherwise crash the ``:.3f`` format below.
        score: float = next(
            (hit[key] for key in score_keys if hit.get(key) is not None), 0
        )
        print(f" {rank:2}. ({score:.3f}) '{hit['title']}' : '{hit['text']}'")

### Vector search
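- [Vector search in LanceDB](https://lancedb.github.io/lancedb/search/#exhaustive-search-knn) (exhaustive kNN search vs. approximate nearest neighbor (ANN) search)
- [ANN indexes](https://lancedb.github.io/lancedb/ann_indexes/)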

In [None]:
# Create and train the ANN index. Training needs enough rows in the table to
# be effective and takes about one minute.
# The index uses the cosine metric, so vector queries should use the same one.
# Training runs on a CUDA GPU by default; set `index_accelerator = None` to
# train on the CPU when no GPU is available.
index_accelerator = "cuda"
k_base.create_index(metric="cosine", replace=True, accelerator=index_accelerator)

In [None]:
# Vector search with the locally computed query embedding. Request the same
# metric the index was built with ("cosine"); without it LanceDB uses its
# default L2 metric, which does not match the index.
response_vec = (
    k_base.search(query_vec, query_type="vector")
    .metric("cosine")
    .limit(n_best)
    .to_list()
)

In [None]:
# Print a summary of the vector-search hits.
show_results(response_vec)

### Full-text search (aka keyword-based search)
- https://lancedb.github.io/lancedb/fts/

In [None]:
# Build (or rebuild, via replace=True) the full-text index over the
# `text` and `title` columns.
fts_columns = ["text", "title"]
k_base.create_fts_index(fts_columns, replace=True)

In [None]:
# Keyword search on the raw query text using the full-text index.
fts_query = k_base.search(query_text, query_type="fts")
response_fts = fts_query.limit(n_best).to_list()

In [None]:
# Print a summary of the full-text-search hits.
show_results(response_fts)

### Hybrid search
- https://lancedb.github.io/lancedb/hybrid_search/hybrid_search
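- only possible if the embeddings are handled by the database (i.e. an embedding function is registered with the table), because LanceDB must embed the query text itself for the vector part of the search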

In [None]:
# Hybrid search: vector and full-text search on the same query text, with the
# two result lists merged by LanceDB's default reranker.
# NOTE(review): older LanceDB versions default to a linear combination with
# 70% weight on vector search; newer ones default to RRFReranker — confirm
# for the installed version.
response_hy = (
    k_base.search(query_text, query_type="hybrid")
    .limit(n_best)
    .to_list()
)

In [None]:
# Print a summary of the hybrid-search hits (default reranker).
show_results(response_hy)

### Hybrid search with reranker (linear combination)

In [None]:
# Weight given to the vector-search score; full-text search gets 1 - weight.
# 0.1 puts most of the emphasis on keyword matches (the reranker's default
# weight is 0.7).
vector_weight: float = 0.1
reranker_lc = LinearCombinationReranker(weight=vector_weight)

In [None]:
# Hybrid search whose vector and full-text hits are merged by the
# linear-combination reranker configured above.
hybrid_query = k_base.search(query_text, query_type="hybrid")
response_rr = hybrid_query.rerank(reranker=reranker_lc).limit(n_best).to_list()

In [None]:
# Print a summary of the hybrid hits reranked by linear combination.
show_results(response_rr)

### Hybrid search with reranker (Cross Encoder)

In [None]:
# Rerank the hybrid hits with a cross-encoder applied to the `text` column.
ce_reranker = CrossEncoderReranker(column="text")

ce_query = k_base.search(query_text, query_type="hybrid").rerank(
    reranker=ce_reranker
)
results_ce = ce_query.limit(n_best).to_list()

In [None]:
# Print a summary of the hybrid hits reranked by the cross-encoder.
show_results(results_ce)

### Hybrid search with reranker (ColBERT)

In [None]:
# Rerank the hybrid hits with a ColBERT reranker (library defaults).
reranker_co = ColbertReranker()

colbert_query = k_base.search(query_text, query_type="hybrid").rerank(
    reranker=reranker_co
)
results_co = colbert_query.limit(n_best).to_list()

In [None]:
# Print a summary of the hybrid hits reranked by ColBERT.
show_results(results_co)