## Demo

### Imports

In [None]:
%load_ext autoreload
%autoreload 2

from decouple import config
from fastembed import (
    TextEmbedding, 
    SparseTextEmbedding, 
    LateInteractionTextEmbedding
)
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    Fusion,
    VectorParams, 
    SparseVectorParams,
    SparseIndexParams,
    MultiVectorConfig,
    MultiVectorComparator,
    Modifier,
    HnswConfigDiff
)
from ranx import Qrels, Run, evaluate

from rag.models import (
    DenseModelConfig, 
    SparseModelConfig,
    RerankingModelConfig,
    DenseSearchManager,
    SparseSearchManager,
    HybridFusionSearchManager, 
    HybridRerankingSearchManager, 
    Metadata
)
from rag.utils import load_datasets


QDRANT_URL = config('QDRANT_URL')
QDRANT_API_KEY = config('QDRANT_API_KEY')
NVIDIA_API_KEY = config('NVIDIA_API_KEY')
CUDA=config('CUDA', cast=bool)
PROVIDER = 'CUDAExecutionProvider' if CUDA else 'CPUExecutionProvider'

### Client

In [2]:
qdrant_client = QdrantClient(
    url=QDRANT_URL, 
    api_key=QDRANT_API_KEY,
)

print(qdrant_client.get_collections())

collections=[CollectionDescription(name='dense_collection'), CollectionDescription(name='fusion_collection2'), CollectionDescription(name='fusion_collection'), CollectionDescription(name='dense_collection2'), CollectionDescription(name='reranking_collection2'), CollectionDescription(name='sparse_collection2'), CollectionDescription(name='reranking_collection'), CollectionDescription(name='sparse_collection')]


### Models

In [3]:
dense_model_name = 'BAAI/bge-small-en-v1.5'
dense_model = TextEmbedding(
    model_name=dense_model_name,
    providers=[PROVIDER]
)

sparse_model_name = 'Qdrant/bm25'
sparse_model = SparseTextEmbedding(
    model_name=sparse_model_name,
    providers=[PROVIDER],
    k1=1.5,
    b=0.75
)

reranking_model_name = 'colbert-ir/colbertv2.0'
reranking_model = LateInteractionTextEmbedding(
    model_name=reranking_model_name,
    providers=[PROVIDER]
)

dense_model_config = DenseModelConfig(
    name=dense_model_name,
    vector_params=VectorParams(
        size=384,
        distance=Distance.COSINE,
        hnsw_config=HnswConfigDiff(
            m=16,
            ef_construct=128,
            on_disk=True
        ),
        on_disk=True
    )
)

sparse_model_config = SparseModelConfig(
    name=sparse_model_name,
    sparse_vector_params=SparseVectorParams(
        index=SparseIndexParams(
            on_disk=True
        ),
        modifier=Modifier.IDF
    )
)

reranking_model_config = RerankingModelConfig(
    name=reranking_model_name,
    vector_params=VectorParams(
        size=128,
        distance=Distance.COSINE,
        hnsw_config=HnswConfigDiff(
            m=0     # disable HNSW
        ),
        on_disk=True,
        multivector_config=MultiVectorConfig(
            comparator=MultiVectorComparator.MAX_SIM,
        )
    )
)



### Documents

In [4]:
corpus_df, queries_df, qrels_df = load_datasets(5_000, 5_000)

print(len(corpus_df), len(queries_df), len(qrels_df))

104 109 111


In [18]:
len(corpus_df)

104

In [10]:
corpus_df[:3]

Unnamed: 0,_id,title,text
0,307,Abraham Lincoln,"Abraham Lincoln ( ; February 12, 1809 – April ..."
1,628,Aldous Huxley,Aldous Leonard Huxley ( ; 26 July 1894 – 22 No...
2,844,Amsterdam,Amsterdam ( ; ] ) is the capital and most popu...


In [11]:
queries_df[:3]

Unnamed: 0,_id,title,text
0,5ac3b95755429939154138e6,,What language family is the language of the tr...
1,5abee3d95542994516f4546c,,"Which of the following is acclaimed for his ""l..."
2,5a8c4c8e554299585d9e3652,,Filipino sitcom Iskul Bukol had a theme song t...


In [12]:
qrels_df[:3]

Unnamed: 0,query-id,corpus-id,score
101,5ac3b95755429939154138e6,7222,1
137,5abee3d95542994516f4546c,2310,1
203,5a8c4c8e554299585d9e3652,9288,1


In [34]:
corpus_texts: list[str] = corpus_df['text'].values.tolist()
metadatas = [
    Metadata(
        id=row['_id'],
        text=row['text']
    )
    for _, row in corpus_df.iterrows()
]

dense_embeddings = list(dense_model.embed(corpus_texts))
sparse_embeddings = list(sparse_model.embed(corpus_texts))
reranking_embeddings = list(reranking_model.embed(corpus_texts))

### Indexing

In [32]:
dense_collection_name = 'dense_collection2'
sparse_collection_name = 'sparse_collection2'
fusion_collection_name = 'fusion_collection2'
reranking_collection_name = 'reranking_collection2'

#### Dense

In [None]:
dense_search = DenseSearchManager(
    qdrant_client=qdrant_client, 
    dense_model_config=dense_model_config
)

dense_search.create_collection(dense_collection_name)
dense_search.upsert(
    dense_collection_name,
    dense_embeddings,
    metadatas
)

#### Sparse

In [20]:
sparse_search = SparseSearchManager(
    qdrant_client=qdrant_client, 
    sparse_model_config=sparse_model_config
)

sparse_search.create_collection(sparse_collection_name)
sparse_search.upsert(
    sparse_collection_name,
    sparse_embeddings,
    metadatas
)

'completed'

#### Hybrid - Fusion

In [21]:
fusion_search = HybridFusionSearchManager(
    qdrant_client=qdrant_client, 
    dense_model_config=dense_model_config,
    sparse_model_config=sparse_model_config
)

fusion_search.create_collection(fusion_collection_name)
fusion_search.upsert(
    fusion_collection_name, 
    dense_embeddings, 
    sparse_embeddings,
    metadatas
)

'completed'

#### Hybrid - Reranking

In [48]:
qdrant_client.create_collection(
    collection_name="movies",
    vectors_config=VectorParams(
        size=128, #size of each vector produced by ColBERT
        distance=Distance.COSINE, #similarity metric between each vector
        multivector_config=MultiVectorConfig(
            comparator=MultiVectorComparator.MAX_SIM #similarity metric between multivectors (matrices)
        ),
    ),
)



True

In [63]:
import numpy as np

duplicated_flattened = reranking_embeddings * 8

In [67]:
from qdrant_client.models import PointStruct

qdrant_client.upload_points(
    collection_name="movies",
    points=[
        PointStruct(
            id=idx,
            payload={
                'text': 'bla',
            },
            vector=vector
        )
        for idx, vector in enumerate(duplicated_flattened)
    ],
)

In [None]:
reranking_search = HybridRerankingSearchManager(
    qdrant_client=qdrant_client, 
    dense_model_config=dense_model_config,
    sparse_model_config=sparse_model_config,
    reranking_model_config=reranking_model_config
)

reranking_search.create_collection(reranking_collection_name)
reranking_search.upsert(
    reranking_collection_name, 
    dense_embeddings, 
    sparse_embeddings,
    reranking_embeddings,
    metadatas
)

### Query

In [7]:
query_texts: list[str] = queries_df['text'].values.tolist()

query_dense_embeddings = list(dense_model.embed(query_texts))
query_sparse_embeddings = list(sparse_model.embed(query_texts))
query_reranking_embeddings = list(reranking_model.embed(query_texts))

In [None]:
# TODO how to evaluate and handle multiple points for each question

### Search

In [8]:
top_k = 5

#### Dense

In [14]:
dense_scored_points_list = [
    dense_search.search(
        dense_collection_name,
        query_dense_embedding,
        top_k
    )
    for query_dense_embedding in query_dense_embeddings
]

print(*dense_scored_points_list[0][:3], sep='\n')

id='abe66f13-361d-4205-a4b0-92006f8b5efd' version=0 score=0.6729106 payload={'id': 7222, 'text': 'The Choctaw (In the Choctaw language, Chahta) are a Native American people originally occupying what is now the Southeastern United States (modern-day Alabama, Florida, Mississippi, and Louisiana). Their Choctaw language belongs to the Muskogean language family group.'} vector=None shard_key=None order_value=None
id='bba3ee6b-8818-4335-a61f-d6b3c5fb178e' version=0 score=0.63985586 payload={'id': 2303, 'text': 'Aramaic (אַרָמָיָא "Arāmāyā", Syriac: ܐܪܡܝܐ\u200e , Arabic: آرامية\u200e \u200e ) is a language or group of languages belonging to the Semitic subfamily of the Afroasiatic language family. More specifically, it is part of the Northwest Semitic group, which also includes the Canaanite languages such as Hebrew and Phoenician. The Aramaic alphabet was widely adopted for other languages and is ancestral to the Hebrew, Syriac and Arabic alphabets.'} vector=None shard_key=None order_value=

#### Sparse

In [13]:
sparse_scored_points = sparse_search.search(
    sparse_collection_name,
    query_sparse_embedding,
    1
)

print(*sparse_scored_points, sep='\n')

id='2bd53b67-23b4-4a84-8f47-12782ef07e2a' version=0 score=4.1631217 payload={'id': 1, 'text': 'FastEmbed is lighter than Transformers & Sentence-Transformers.'} vector=None shard_key=None order_value=None


#### Hybrid - Fusion

In [14]:
fusion_scored_points = fusion_search.search(
    fusion_collection_name,
    query_dense_embedding,
    query_sparse_embedding,
    Fusion.RRF,
    1
)

print(*fusion_scored_points, sep='\n')

id='2c0b2e24-1435-451c-8dc0-298e1c710e53' version=0 score=1.0 payload={'id': 1, 'text': 'FastEmbed is lighter than Transformers & Sentence-Transformers.'} vector=None shard_key=None order_value=None


#### Hybrid - Reranking

In [15]:
reranking_scored_points = reranking_search.search(
    reranking_collection_name,
    query_dense_embedding,
    query_sparse_embedding,
    query_reranking_embedding,
    2,
    1
)

print(*reranking_scored_points, sep='\n')

id='2d78e6a2-e0a4-4380-afe8-7d054612d469' version=0 score=5.8587246 payload={'id': 1, 'text': 'FastEmbed is lighter than Transformers & Sentence-Transformers.'} vector=None shard_key=None order_value=None


### Evaluate

In [15]:
qrels_dict = {}

for _, row in qrels_df.iterrows():
    query_id = row['query-id']
    corpus_id = str(row['corpus-id'])
    relevance = int(row['score'])
    
    if query_id not in qrels_dict:
        qrels_dict[query_id] = {}

    qrels_dict[query_id][corpus_id] = relevance

In [31]:
runs_dict = {}

for i, query_id in enumerate(queries_df['_id'].values):
    runs_dict[query_id] = {}
    
    for scored_point in dense_scored_points_list[i]:
        doc_id = str(scored_point.payload['id'])
        runs_dict[query_id][doc_id] = float(scored_point.score)

qrels_ranx = Qrels(qrels_dict)
run_ranx = Run(runs_dict)

results = evaluate(qrels_ranx, run_ranx, metrics=[f'ndcg@{top_k}'])

print(results)

0.9481393103951706


In [None]:
metrics = [
    'mrr',
    'map',
    f'precision@{top_k}',
    f'recall@{top_k}',
    f'ndcg@{top_k}'
]