In [1]:
from sentence_transformers import SentenceTransformer
from fastembed import SparseTextEmbedding
import pandas as pd
from qdrant_client import QdrantClient, models
from tqdm.notebook import tqdm
import numpy as np
#evaluate(qrels, dense_run, metrics=["precision@10", "mrr@10"], make_comparable=True)
from fastembed import TextEmbedding, SparseTextEmbedding, LateInteractionTextEmbedding


In [40]:
dataset = pd.read_parquet(
            'hf://datasets/neural-bridge/rag-dataset-12000/data/train-00000-of-00001-9df3a936e1f63191.parquet'
        ).dropna().reset_index(drop=True)

In [34]:
client = QdrantClient()
dataset = pd.read_parquet(
            'hf://datasets/neural-bridge/rag-dataset-12000/data/train-00000-of-00001-9df3a936e1f63191.parquet'
        ).dropna().reset_index(drop=True)
load_batch_size = 16,
ef_construct = 100,
m = 16,
full_scan_threshold = 10,
rerank_model_name="colbert-ir/colbertv2.0",
dense_model_name="sentence-transformers/all-MiniLM-L6-v2",
sparse_model_name = 'Qdrant/bm25',
device="cpu",
dense_model = TextEmbedding("sentence-transformers/all-MiniLM-L6-v2")
sparse_model = SparseTextEmbedding('Qdrant/bm25')
late_interaction_embedding_model = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")

In [3]:
len(list(late_interaction_embedding_model.passage_embed(dataset["context"][0:1]))[0])

512

In [35]:
dataset=dataset[:1000]

In [5]:
dataset["context"][0:1]

0    Caption: Tasmanian berry grower Nic Hansen sho...
Name: context, dtype: object

In [43]:
if client.collection_exists(collection_name="test"):
    client.delete_collection(collection_name="test")
client.create_collection("test",
            vectors_config={
                "all-MiniLM-L6-v2": models.VectorParams(
                    size=384,
                    distance=models.Distance.COSINE,
                ),
                "colbertv2.0": models.VectorParams(
                    size=128,
                    distance=models.Distance.COSINE,
                    multivector_config=models.MultiVectorConfig(
                        comparator=models.MultiVectorComparator.MAX_SIM,
                    )
                ),
            },
            sparse_vectors_config={
                "bm25": models.SparseVectorParams(
                    modifier=models.Modifier.IDF,
                )
            }
        )


load_batch_size = 16  

for i in tqdm(range(len(dataset) // load_batch_size + 1)):
    start_idx = i * load_batch_size
    end_idx = (i + 1) * load_batch_size
    row = dataset.iloc[start_idx:end_idx] 

    dense_embeddings = list(dense_model.passage_embed(row["context"].values))
    bm25_embeddings = list(sparse_model.passage_embed(row["context"].values))
    late_interaction_embeddings = list(late_interaction_embedding_model.passage_embed(row["context"].values))

    client.upload_points(
        "test",
        points=[
            models.PointStruct(
                id=int(id_),
                vector={
                    "all-MiniLM-L6-v2": dense_embeddings[i],
                    "bm25": bm25_embeddings[i].as_object(),
                    "colbertv2.0": late_interaction_embeddings[i],
                },
                payload={
                    "_id": i,
                    "text": row["context"][id_],
                }
            ) for i, id_ in enumerate(row.index)  
        ],
        batch_size=load_batch_size,
    )

  0%|          | 0/600 [00:00<?, ?it/s]

In [None]:
dense_queries = list(dense_model.passage_embed(dataset['question'], show_progress_bar=True, batch_size=load_batch_size))
sparse_queries = list(sparse_model.passage_embed(dataset['question'], show_progress_bar=True, batch_size=load_batch_size))
rerank_queries = list(late_interaction_embedding_model.passage_embed(dataset['question'], show_progress_bar=True, batch_size=load_batch_size))
search_results = []

dense_queries = list(dense_model.passage_embed(
    dataset['question'], 
    show_progress_bar=True, 
    batch_size=load_batch_size
))
sparse_queries = list(sparse_model.passage_embed(
    dataset['question'], 
    show_progress_bar=True, 
    batch_size=load_batch_size
))
rerank_queries = list(late_interaction_embedding_model.passage_embed(
    dataset['question'], 
    show_progress_bar=True, 
    batch_size=load_batch_size
))

run_dict = {}

for query_idx in tqdm(range(len(dense_queries))):
    query_id = str(query_idx)
    
    dense_query_vector = dense_queries[query_idx]
    sparse_query_vector = sparse_queries[query_idx]
    late_query_vector = rerank_queries[query_idx]
    prefetch = [
        models.Prefetch(
            query=dense_query_vector,
            using="all-MiniLM-L6-v2", 
            limit=20
        ),
        models.Prefetch(
            query=models.SparseVector(**sparse_query_vector.as_object()),
            using="bm25",            
            limit=20
        )
    ]
    results = client.query_points(
        "test",              
        prefetch=prefetch,      
        query=late_query_vector, 
        using="colbertv2.0",     
        with_payload=False,
        limit=20                 
    )
    
    run_dict[query_id] = {
        str(point.id): point.score
        for point in results.points
    }

  0%|          | 0/9598 [00:00<?, ?it/s]

In [45]:
run_dict

{'0': {'0': 9.09739,
  '7347': 6.6065974,
  '1161': 6.435562,
  '9225': 5.969359,
  '2109': 5.891042,
  '8609': 5.859406,
  '5236': 5.8293796,
  '4116': 5.687734,
  '5618': 5.6869073,
  '6956': 5.6410103,
  '6842': 5.4884815,
  '6400': 5.354558,
  '7143': 5.3535113,
  '6451': 5.3401933,
  '2290': 5.306905,
  '6839': 5.2700114,
  '843': 5.22953,
  '8007': 5.1873846,
  '258': 5.0651,
  '8390': 5.0574837},
 '1': {'1': 12.034189,
  '4253': 9.978613,
  '1965': 9.165308,
  '1316': 9.093838,
  '2380': 9.085963,
  '4826': 8.98652,
  '1750': 8.97985,
  '6282': 8.685312,
  '5507': 8.531012,
  '2203': 8.444836,
  '467': 8.442221,
  '2630': 8.407343,
  '5596': 8.271476,
  '162': 8.183465,
  '8768': 8.068176,
  '4168': 8.041452,
  '8303': 8.024167,
  '2108': 8.022286,
  '331': 7.651526,
  '3356': 7.6453066},
 '2': {'2': 11.32105,
  '8907': 8.161923,
  '404': 8.038729,
  '3456': 7.61776,
  '9405': 7.611228,
  '4588': 7.4964156,
  '7591': 7.45772,
  '699': 7.381824,
  '5868': 7.359888,
  '89': 7.2628

In [None]:
import numpy as np

results = run_dict

correct_at_1 = 0
correct_at_5 = 0
correct_at_2=0
correct_at_3=0
correct_at_10 = 0
correct_at_20=0
total_queries = len(results)

for query_id, retrieved_items in results.items():
    retrieved_ids = list(retrieved_items.keys())
    
    if query_id in retrieved_ids[:1]:
        correct_at_1 += 1
    if query_id in retrieved_ids[:2]:
        correct_at_2 += 1
    
    if query_id in retrieved_ids[:3]:
        correct_at_3 += 1
    if query_id in retrieved_ids[:5]:
        correct_at_5 += 1
    
    if query_id in retrieved_ids[:10]:
        correct_at_10 += 1
    if query_id in retrieved_ids[:20]:
        correct_at_20 += 1


recall_at_1 = correct_at_1 / total_queries
recall_at_2 = correct_at_2 / total_queries
recall_at_3 = correct_at_3 / total_queries
recall_at_5 = correct_at_5 / total_queries
recall_at_10 = correct_at_10 / total_queries
recall_at_20 = correct_at_20 / total_queries

print(f"Recall@1: {recall_at_1:.4f}")
print(f"Recall@2: {recall_at_2:.4f}")
print(f"Recall@3: {recall_at_3:.4f}")
print(f"Recall@5: {recall_at_5:.4f}")
print(f"Recall@10: {recall_at_10:.4f}")
print(f"Recall@20: {recall_at_20:.4f}")

Recall@1: 0.9061
Recall@2: 0.9331
Recall@3: 0.9437
Recall@5: 0.9543
Recall@10: 0.9642
Recall@20: 0.9709


In [64]:
metrics_dict = {
            'recall@1': recall_at_1,
            'recall@2': recall_at_2,
            'recall@3': recall_at_3,
            'recall@5': recall_at_5,
            'recall@10': recall_at_10,
            'recall@20': recall_at_20
        }
metric_df = pd.DataFrame([metrics_dict])

In [65]:
metric_df

Unnamed: 0,recall@1,recall@2,recall@3,recall@5,recall@10,recall@20
0,0.906126,0.933111,0.943738,0.954261,0.964159,0.970931
