In [5]:
from pylate import evaluation, indexes, models, retrieve

# Step 1: Choose the dataset and initialize the ColBERT model
dataset = "scifact"  # Dataset to evaluate; reused below as the index name

model = models.ColBERT(
    model_name_or_path="lightonai/GTE-ModernColBERT-v1",
    device="cpu",  # "cpu" or "cuda" or "mps"
)

# Step 2: Create a Voyager index named after the dataset
# (override=True overwrites any existing index)
index = indexes.Voyager(index_folder="pylate-index", index_name=dataset, override=True)

# Step 3: Load the documents, queries, and relevance judgments (qrels)
# for the "test" split of the chosen dataset
documents, queries, qrels = evaluation.load_beir(dataset, split="test")

# Step 4: Encode the documents (is_query=False marks the inputs as documents)
document_texts = [doc["text"] for doc in documents]
documents_embeddings = model.encode(
    document_texts,
    batch_size=32,
    is_query=False,
    show_progress_bar=True,
)

# Step 5: Add document embeddings to the index, keyed by document id
document_ids = [doc["id"] for doc in documents]
index.add_documents(
    documents_ids=document_ids,
    documents_embeddings=documents_embeddings,
)

# Step 6: Encode the queries (is_query=True marks the inputs as queries)
query_encoding_options = {
    "batch_size": 32,
    "is_query": True,
    "show_progress_bar": True,
}
queries_embeddings = model.encode(queries, **query_encoding_options)

# Step 7: Retrieve the top 100 matches for each query
retriever = retrieve.ColBERT(index=index)
scores = retriever.retrieve(queries_embeddings=queries_embeddings, k=100)

# Step 8: Evaluate the retrieval results
ndcg_metrics = [f"ndcg@{k}" for k in [1, 3, 5, 10]]  # NDCG for different k values
recall_metrics = ["recall@10", "recall@100"]  # Recall at k
precision_metrics = ["precision@10", "precision@100"]  # Precision at k

results = evaluation.evaluate(
    scores=scores,
    qrels=qrels,
    queries=queries,
    metrics=[*ndcg_metrics, *recall_metrics, *precision_metrics],
)

print(results)

./evaluation_datasets/scifact.zip: 100%|██████████| 2.69M/2.69M [00:00<00:00, 2.89MiB/s]
100%|██████████| 5183/5183 [00:00<00:00, 209729.36it/s]
Encoding documents (bs=32): 100%|██████████| 162/162 [03:58<00:00,  1.47s/it]
Adding documents to the index (bs=2000): 100%|██████████| 3/3 [00:22<00:00,  7.66s/it]
Encoding queries (bs=32): 100%|██████████| 10/10 [00:01<00:00,  6.07it/s]
Retrieving documents (bs=50):  86%|████████▌ | 6/7 [01:09<00:11, 11.57s/it]

{'ndcg@1': np.float64(0.6066666666666667), 'ndcg@3': np.float64(0.6942010170870506), 'ndcg@5': np.float64(0.716427828679468), 'ndcg@10': np.float64(0.7324129787189755), 'recall@10': np.float64(0.8562222222222222), 'recall@100': np.float64(0.9610000000000001), 'precision@10': np.float64(0.09633333333333333), 'precision@100': np.float64(0.010866666666666667)}



