In [None]:
import numpy as np
import faiss
import pytrec_eval
from sentence_transformers import SentenceTransformer
import torch
from datasets import load_dataset

# --------------------------
# Our Evaluation Class
# --------------------------
import logging

# Configure the root logger first so INFO-level records (the evaluation
# summary emitted below) are actually shown rather than filtered out.
logging.basicConfig(level=logging.INFO)

# Module-scoped logger; records carry this module's name as their prefix.
logger = logging.getLogger(__name__)

class Evaluator:
    """Aggregate rank-based retrieval metrics computed by ``pytrec_eval``."""

    @staticmethod
    def evaluate(
        qrels: dict[str, dict[str, int]],
        results: dict[str, dict[str, float]],
        k_values: list[int],
        ignore_identical_ids: bool = True,
    ) -> tuple[dict[str, float], dict[str, float], dict[str, float], dict[str, float]]:
        """Compute mean NDCG, MAP, Recall and Precision at each cutoff.

        Args:
            qrels: Relevance judgements, ``{query_id: {doc_id: relevance}}``.
            results: Retrieval output, ``{query_id: {doc_id: score}}``.
                This mapping is not modified.
            k_values: Rank cutoffs to report (for example ``[10, 100]``).
            ignore_identical_ids: When true, a retrieved document whose id
                equals the query id is dropped before scoring.

        Returns:
            Four dicts ``(ndcg, map, recall, precision)`` keyed by
            ``"NDCG@k"``, ``"MAP@k"``, ``"Recall@k"`` and ``"P@k"``. Each value
            is the mean over all scored queries, rounded to 5 decimals. If no
            query was scored, every value is ``0.0``.
        """
        if ignore_identical_ids:
            logger.info(
                "For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to not ignore them."
            )
            # Build a filtered view instead of popping from the caller's
            # dicts; only the per-query dicts that need a change are copied.
            results = {
                qid: (
                    {pid: score for pid, score in rels.items() if pid != qid}
                    if qid in rels
                    else rels
                )
                for qid, rels in results.items()
            }

        # pytrec_eval takes one "<measure>.<k1>,<k2>,..." string per measure.
        cutoffs = ",".join(str(k) for k in k_values)
        evaluator = pytrec_eval.RelevanceEvaluator(
            qrels,
            {
                f"map_cut.{cutoffs}",
                f"ndcg_cut.{cutoffs}",
                f"recall.{cutoffs}",
                f"P.{cutoffs}",
            },
        )
        scores = evaluator.evaluate(results)
        num_queries = len(scores)

        if num_queries == 0:
            # Nothing to average; without this guard the mean below would
            # raise ZeroDivisionError.
            logger.warning(
                "No queries were scored (no overlap between qrels and results); "
                "reporting 0.0 for every metric."
            )

        def mean_of(measure: str) -> float:
            """Return the rounded mean of ``measure`` over the scored queries."""
            if num_queries == 0:
                return 0.0
            total = sum(per_query.get(measure, 0) for per_query in scores.values())
            return round(total / num_queries, 5)

        ndcg = {}
        _map = {}
        recall = {}
        precision = {}
        for k in k_values:
            ndcg[f"NDCG@{k}"] = mean_of(f"ndcg_cut_{k}")
            _map[f"MAP@{k}"] = mean_of(f"map_cut_{k}")
            recall[f"Recall@{k}"] = mean_of(f"recall_{k}")
            precision[f"P@{k}"] = mean_of(f"P_{k}")

        for metric in (ndcg, _map, recall, precision):
            logger.info("\nEvaluation metrics:")
            for key, value in metric.items():
                logger.info("%s: %.4f", key, value)

        return ndcg, _map, recall, precision

# --------------------------
# Retrieval and Evaluation
# --------------------------
def get_embedding(texts):
    """
    Encode ``texts`` with the module-level ``model`` on the module-level
    ``device`` and return the result as NumPy output
    (``convert_to_numpy=True``).
    """
    embeddings = model.encode(
        texts,
        convert_to_numpy=True,
        device=device,
    )
    return embeddings

def preprocess_dataset(dataset, num_queries=None):
    """
    Split the ``"train"`` split of ``dataset`` into queries and a corpus.

    Titles act as queries, ``docid`` values act as query ids, and the full
    ``text`` column is the corpus.

    Args:
        dataset: Mapping with a ``"train"`` entry exposing ``"title"``,
            ``"docid"`` and ``"text"`` columns.
        num_queries: Number of leading rows to use as queries; ``None`` uses
            as many rows as there are titles.

    Returns:
        ``(queries, query_ids, corpus, query_to_positive)`` where
        ``query_to_positive`` maps each query id to a one-element set holding
        the text found on the same row.
    """
    data = dataset["train"]  # Use the appropriate split

    # Read every column exactly once: a column lookup can be expensive
    # (presumably it materializes the whole column for a Hugging Face
    # dataset — TODO confirm), and the original re-read "text" and "title".
    titles = data["title"]
    doc_ids = data["docid"]
    corpus = data["text"]

    if num_queries is None:
        num_queries = len(titles)

    queries = titles[:num_queries]
    query_ids = doc_ids[:num_queries]  # Assuming each query has a unique docid

    # Assume that for each query, the positive document is the corpus text
    # on the same row.
    query_to_positive = {
        qid: {doc} for qid, doc in zip(query_ids, corpus[:num_queries])
    }
    return queries, query_ids, corpus, query_to_positive

def build_faiss_index(corpus_embeddings):
    """
    Create a flat (exhaustive) L2 FAISS index holding ``corpus_embeddings``.

    The vector dimensionality is taken from the second axis of
    ``corpus_embeddings``; the populated index is returned.
    """
    # NOTE(review): this ranks by L2 distance — confirm that matches the
    # similarity the embedding model was trained with.
    flat_index = faiss.IndexFlatL2(corpus_embeddings.shape[1])
    flat_index.add(corpus_embeddings)
    return flat_index

def evaluate_retrieval(dataset, lang, num_queries=None):
    """
    Evaluate the retrieval model and compute NDCG@10 and Recall@100.

    Queries and corpus come from ``preprocess_dataset``; both are embedded,
    the corpus is indexed with FAISS, and the nearest corpus entries for each
    query are scored with ``Evaluator.evaluate``.

    Args:
        dataset: Dataset object accepted by ``preprocess_dataset``.
        lang: Human-readable language label, used only in printed messages.
        num_queries: Optional cap on the number of queries to evaluate.

    Returns:
        ``(recall_at_100, ndcg_at_10)`` — recall first, then NDCG. Returns
        ``(0, 0)`` when there are no queries or no corpus.
    """
    print(f"\nEvaluating language: {lang}...")

    # Preprocess dataset
    queries, query_ids, corpus, query_to_positive = preprocess_dataset(dataset, num_queries)
    if not queries or not corpus:
        print(f"No valid data found for {lang}")
        return 0, 0

    # Embed queries and corpus
    query_embeddings = get_embedding(queries)
    corpus_embeddings = get_embedding(corpus)

    # Build FAISS index
    index = build_faiss_index(corpus_embeddings)

    # Cutoffs reported by the evaluator; the search depth is derived from
    # them so the two can never drift apart.
    k_values = [10, 100]
    # Never ask for more neighbours than the corpus holds: FAISS pads missing
    # results with label -1, which would silently index the last document.
    top_k = min(max(k_values), len(corpus))
    distances, indices = index.search(query_embeddings, top_k)

    # Format the results for pytrec_eval:
    # corpus texts serve as document ids (the qrels below use them too).
    results = {}
    for qid, row_distances, row_indices in zip(query_ids, distances, indices):
        results[qid] = {
            corpus[int(idx)]: float(-dist)  # Negative distance as score
            for dist, idx in zip(row_distances, row_indices)
            if idx >= 0  # skip FAISS "no result" padding
        }

    # Format the qrels. Each query has a set of relevant documents (here, the positive document).
    qrels = {
        qid: {doc: 1 for doc in query_to_positive.get(qid, [])}
        for qid in query_ids
    }

    # Evaluate using our Evaluator class with k_values of 10 and 100.
    ndcg, _map, recall, precision = Evaluator.evaluate(qrels, results, k_values)

    # Extract the desired metrics
    ndcg_at_10 = ndcg.get("NDCG@10", 0)
    recall_at_100 = recall.get("Recall@100", 0)
    print(f"\n{lang} - NDCG@10: {ndcg_at_10:.4f}, Recall@100: {recall_at_100:.4f}")

    return recall_at_100, ndcg_at_10

# --------------------------
# Main Code: Load Model and Dataset
# --------------------------
# Pick the GPU when torch reports one, otherwise fall back to the CPU.
# ``device`` is read as a module global by get_embedding().
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Load the embedding model; ``model`` is read as a module global by
# get_embedding().
model_name = "facebook/mcontriever-msmarco"
model = SentenceTransformer(model_name, device=device)

dataset_name = "miracl/miracl-corpus"
# For example, use the Yoruba version. Adjust the configuration if needed.
# NOTE(review): trust_remote_code=True presumably lets the dataset
# repository's loading script run locally — confirm the source is trusted.
dataset = load_dataset(dataset_name, "yo", trust_remote_code=True)
lang = "Yoruba (yo)"  # label used only in printed messages

# Run evaluation
# evaluate_retrieval returns (recall_at_100, ndcg_at_10), in that order.
recall, ndcg = evaluate_retrieval(dataset, lang)
print(f"\nFinal Results:\nRecall@100: {recall:.4f}\nNDCG@10: {ndcg:.4f}")

Evaluating language: Yoruba (yo)...
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:13<00:00,  1.96s/it]
Batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 1533/1533 [01:09<00:00, 22.07it/s]
INFO:__main__:For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to not ignore them.
INFO:__main__:
Evaluation metrics:
INFO:__main__:NDCG@10: 0.0282
INFO:__main__:NDCG@100: 0.0295
INFO:__main__:
Evaluation metrics:
INFO:__main__:MAP@10: 0.0275
INFO:__main__:MAP@100: 0.0279
INFO:__main__:
Evaluation metrics:
INFO:__main__:Recall@10: 0.0300
INFO:__main__:Recall@100: 0.0350
INFO:__main__:
Evaluation metrics:
INFO:__main__:P@10: 0.0030
INFO:__main__:P@100: 0.0003

Yoruba (yo) - NDCG@10: 0.0282, Recall@100: 0.0350

Final Results:
Recall@100: 0.0350
NDCG@10: 0.0282
l78gao@cdr2644 ~/scratch/Evaluate $ python test.py
Using device: cuda
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/mcontriever-msmarco
WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/mcontriever-msmarco. Creating a new one with mean pooling.
Some weights of BertModel were not initialized from the model checkpoint at facebook/mcontriever-msmarco and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Evaluating language: Yoruba (yo)...
Batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 1533/1533 [00:33<00:00, 45.61it/s]
Batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 1533/1533 [01:09<00:00, 22.16it/s]
INFO:__main__:For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to not ignore them.
INFO:__main__:
Evaluation metrics:
INFO:__main__:NDCG@10: 0.4726
INFO:__main__:NDCG@100: 0.4822
INFO:__main__:
Evaluation metrics:
INFO:__main__:MAP@10: 0.4621
INFO:__main__:MAP@100: 0.4639
INFO:__main__:
Evaluation metrics:
INFO:__main__:Recall@10: 0.5053
INFO:__main__:Recall@100: 0.5531
INFO:__main__:
Evaluation metrics:
INFO:__main__:P@10: 0.0505
INFO:__main__:P@100: 0.0055

Yoruba (yo) - NDCG@10: 0.4726, Recall@100: 0.5531

Final Results:
Recall@100: 0.5531
NDCG@10: 0.4726