In [None]:
import numpy as np
import faiss
import pytrec_eval
from sentence_transformers import SentenceTransformer
import torch
from datasets import load_dataset

def get_embedding(texts, encoder=None, target_device=None):
    """
    Encode ``texts`` into dense embeddings.

    Args:
        texts: A string or list of strings, passed straight to ``encoder.encode``.
        encoder: Object exposing ``encode(texts, convert_to_numpy=..., device=...)``
            (e.g. a ``SentenceTransformer``). Defaults to the module-level
            ``model`` so existing callers are unchanged.
        target_device: Device forwarded to ``encode``. Defaults to the
            module-level ``device``.

    Returns:
        The encoder output as a C-contiguous float32 NumPy array, which is the
        layout FAISS requires for ``add``/``search``.
    """
    if encoder is None:
        encoder = model  # module-level encoder configured by the driver cells
    if target_device is None:
        target_device = device
    embeddings = encoder.encode(texts, convert_to_numpy=True, device=target_device)
    return np.ascontiguousarray(embeddings, dtype=np.float32)

def _passage_text(passage):
    """Return the text used to embed and identify one passage (None if empty).

    Plain strings pass through unchanged. Dict passages (e.g. ones with
    "title"/"text" keys, as in the HF MIRACL schema -- confirm for other
    datasets) become "title text", or just the text when there is no title.
    """
    if passage is None:
        return None
    if isinstance(passage, str):
        return passage or None
    if isinstance(passage, dict):
        title = (passage.get("title") or "").strip()
        text = (passage.get("text") or "").strip()
        combined = f"{title} {text}".strip() if title else text
        return combined or None
    return str(passage)


def preprocess_dataset(dataset, num_queries=200):
    """
    Extract queries, query IDs, the passage pool and positives from a dataset.

    Args:
        dataset: Anything indexable by column name ("query", "query_id",
            "positive_passages", "negative_passages"). Each passage column holds
            one list of passages per row; a passage is a string or a dict.
        num_queries: Number of leading queries to keep (``None`` keeps all).

    Returns:
        ``(queries, query_ids, corpus, query_to_positive)``:
        - ``queries``/``query_ids``: the first ``num_queries`` entries.
        - ``corpus``: de-duplicated passage texts, in first-seen order. It
          contains every positive and negative passage of every row, so all
          judged positives are retrievable.
        - ``query_to_positive``: query id -> set of positive passage texts,
          only for kept queries that have at least one positive.
    """
    queries = list(dataset["query"][:num_queries])
    query_ids = list(dataset["query_id"][:num_queries])

    positive_lists = list(dataset["positive_passages"])
    negative_lists = list(dataset["negative_passages"])

    # dict.fromkeys-style dedup keeps a deterministic order; a set would not
    # (and dict passages are unhashable anyway).
    seen = {}
    for passages in positive_lists + negative_lists:
        for passage in passages or []:
            text = _passage_text(passage)
            if text is not None:
                seen[text] = None
    corpus = list(seen)

    query_to_positive = {}
    for qid, passages in zip(query_ids, positive_lists):
        positives = {t for t in map(_passage_text, passages or []) if t is not None}
        if positives:
            query_to_positive[qid] = positives

    return queries, query_ids, corpus, query_to_positive

def build_faiss_index(corpus_embeddings, metric="l2"):
    """
    Build an exact (brute-force) FAISS index over the corpus embeddings.

    Args:
        corpus_embeddings: 2-D array-like, one row per passage. It is converted
            to C-contiguous float32 because FAISS rejects other layouts.
        metric: "l2" (default) for an ``IndexFlatL2``, whose ``search`` returns
            squared L2 distances (lower is closer). "ip" selects an
            ``IndexFlatIP``, whose ``search`` returns inner products (higher is
            closer).

    Returns:
        The populated FAISS index.

    Raises:
        ValueError: if the embeddings are not 2-D or ``metric`` is unknown.
    """
    vectors = np.ascontiguousarray(corpus_embeddings, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(f"corpus_embeddings must be 2-D, got shape {vectors.shape}")

    dim = vectors.shape[1]
    metric_name = metric.lower()
    if metric_name == "l2":
        index = faiss.IndexFlatL2(dim)
    elif metric_name == "ip":
        index = faiss.IndexFlatIP(dim)
    else:
        raise ValueError(f"Unsupported metric: {metric!r} (expected 'l2' or 'ip')")

    index.add(vectors)
    return index

def evaluate_retrieval(dataset, lang, k_recall=100, k_ndcg=10, num_queries=200):
    """
    Evaluate dense retrieval on one dataset split using pytrec_eval.

    Embeds the queries and the passage pool produced by ``preprocess_dataset``,
    retrieves with an exact FAISS index, and scores the run against the
    positive passages.

    Args:
        dataset: Column-indexable dataset accepted by ``preprocess_dataset``.
        lang: Label used only in the printed progress/result messages.
        k_recall: Recall cutoff (pytrec_eval measure ``recall.<k>``).
        k_ndcg: nDCG cutoff (pytrec_eval measure ``ndcg_cut.<k>``).
        num_queries: Number of leading queries to evaluate.

    Returns:
        ``(mean_recall, mean_ndcg)`` averaged over the evaluated query ids. A
        query that pytrec_eval did not score (e.g. one with no positives)
        contributes 0. Returns ``(0, 0)`` when there are no queries, no
        passages, or no relevance judgments.
    """
    print(f"\nEvaluating language: {lang}...")

    queries, query_ids, corpus, query_to_positive = preprocess_dataset(dataset, num_queries)

    if not queries or not corpus:
        print(f"No valid data found for {lang}")
        return 0, 0

    # FAISS requires C-contiguous float32 query vectors.
    query_embeddings = np.ascontiguousarray(get_embedding(queries), dtype=np.float32)
    corpus_embeddings = get_embedding(corpus)

    index = build_faiss_index(corpus_embeddings)

    # When k exceeds the index size, FAISS pads results with id -1, and
    # corpus[-1] would silently alias (and overwrite) the last passage.
    top_k = min(max(k_recall, k_ndcg), index.ntotal)
    distances, indices = index.search(query_embeddings, top_k)

    # pytrec_eval ranks by descending score, so negate the L2 distances
    # returned by the flat L2 index; it also requires string ids.
    run = {}
    for qid, row_distances, row_indices in zip(query_ids, distances, indices):
        run[str(qid)] = {
            corpus[idx]: -float(dist)
            for dist, idx in zip(row_distances, row_indices)
            if idx >= 0
        }

    # Only queries with at least one positive are judged; the others fall
    # back to 0 in the averages below.
    qrels = {
        str(qid): {doc: 1 for doc in query_to_positive[qid]}
        for qid in query_ids
        if query_to_positive.get(qid)
    }
    if not qrels:
        print(f"No relevance judgments found for {lang}")
        return 0, 0

    recall_key = f"recall_{k_recall}"
    ndcg_key = f"ndcg_cut_{k_ndcg}"
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, {f"recall.{k_recall}", f"ndcg_cut.{k_ndcg}"})
    scores = evaluator.evaluate(run)

    # Use .get on the outer dict as well: pytrec_eval only reports queries
    # that appear in qrels, so direct indexing could raise KeyError.
    mean_recall = np.mean([scores.get(str(qid), {}).get(recall_key, 0) for qid in query_ids])
    mean_ndcg = np.mean([scores.get(str(qid), {}).get(ndcg_key, 0) for qid in query_ids])

    print(f"{lang} - Recall@{k_recall}: {mean_recall:.4f}, NDCG@{k_ndcg}: {mean_ndcg:.4f}")
    return mean_recall, mean_ndcg


# Primary model: KaLM multilingual embedding model, loaded via sentence-transformers.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "HIT-TMG/KaLM-embedding-multilingual-mini-v1"
model = SentenceTransformer(model_name).to(device)

# MIRACL languages to evaluate (dev split of each).
languages = ["ar", "bn", "en", "es", "fa", "fi", "fr", "hi", "id", "ja", "ko", "ru", "sw", "te", "th", "zh"]
# Per-language metrics: {lang: {"Recall@100": ..., "NDCG@10": ...}}
results = {}
for lang in languages:
    print(f"\nProcessing language: {lang}")
    dataset = load_dataset("miracl/miracl", lang, split="dev")
    # evaluate_retrieval requires the language label as its second argument;
    # omitting it raised TypeError on the first iteration.
    recall, ndcg = evaluate_retrieval(dataset, lang)
    results[lang] = {"Recall@100": recall, "NDCG@10": ndcg}

# Second model: mContriever (MS MARCO fine-tuned), loaded from Hugging Face.
# The original used AutoTokenizer/AutoModel, which were never imported, and a
# raw AutoModel has no .encode() (which get_embedding calls). Loading a plain
# HF checkpoint through SentenceTransformer wraps it with mean pooling
# (Contriever checkpoints are documented to use mean pooling -- confirm) and
# provides .encode().
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "facebook/mcontriever-msmarco"
model = SentenceTransformer(model_name, device=device)
tokenizer = model.tokenizer  # kept for any later cells that reference it
# NOTE(review): mContriever is usually scored by dot product, whereas the
# FAISS index built in evaluate_retrieval uses L2 distance.

# "miracl/miracl-corpus" holds only passages: it has no "query" or
# "positive_passages" columns, and with no split it loads a DatasetDict.
# The queries and judgments come from the MIRACL topics dataset, dev split,
# as in the main loop.
dataset = load_dataset("miracl/miracl", "yo", split="dev")  # "yo" for Yoruba language
lang = "Yoruba (yo)"
evaluate_retrieval(dataset, lang)