In [None]:
import os
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
import torch
from datasets import load_dataset
import logging

# For BM25 hard negatives
from rank_bm25 import BM25Okapi

# Module-level logger shared by the functions in this cell. INFO level is
# enabled so the progress messages they emit via logger.info are shown.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Preprocess dataset for hard negative mining
def preprocess_dataset(dataset):
    """
    Process the MIRACL dev split to create queries, positives, and corpus.

    Args:
        dataset: Mapping with a "dev" split whose items expose "query_id",
            "query", and "positive_passages" (a list of dicts with "docid").

    Returns:
        queries (dict): {query_id: query_text}
        positives (dict): {query_id: positive_docid}. Only the first listed
            positive passage is kept; queries without positives are omitted.
        corpus (dict): {passage_id: passage_text}

    Note:
        The corpus is always loaded from "miracl/miracl-corpus" with the "yo"
        configuration, independent of the `dataset` argument.
    """
    queries = {}
    positives = {}
    # Single pass over the split: every query is recorded, but a positive is
    # recorded only when the item actually lists one.
    for item in dataset["dev"]:
        query_id = str(item["query_id"])
        queries[query_id] = item["query"]
        positive_passages = item["positive_passages"]
        if positive_passages:
            positives[query_id] = str(positive_passages[0]["docid"])

    # Load corpus from miracl/miracl-corpus for Yoruba
    corpus_dataset = load_dataset("miracl/miracl-corpus", "yo", trust_remote_code=True)["train"]
    corpus = {str(item["docid"]): item["text"] for item in corpus_dataset}

    return queries, positives, corpus

def bm25_retrieve(query_text, corpus_texts, corpus_ids, top_k=20):
    """
    Retrieve hard negatives using BM25 and normalize scores to [0, 1].

    Texts are tokenized by lowercasing and splitting on whitespace, and a
    fresh BM25Okapi index is built from `corpus_texts` on every call.

    Args:
        query_text (str): Query to score the corpus against.
        corpus_texts (list): Passage texts, aligned with `corpus_ids`.
        corpus_ids (list): Passage ids, aligned with `corpus_texts`.
        top_k (int): Maximum number of candidates to return.

    Returns:
        list: (docid, normalized_score, corpus_index) tuples, best first.
        Scores are min-max normalized over the returned candidates only.
        When all candidates tie, they receive 1.0 if the shared raw score is
        positive and 0.0 otherwise. An empty corpus or a non-positive `top_k`
        returns an empty list.
    """
    # Nothing to rank: min()/max() below would raise on an empty candidate list.
    if top_k <= 0 or len(corpus_texts) == 0:
        return []

    def tokenize(text):
        return text.lower().split()

    bm25 = BM25Okapi([tokenize(text) for text in corpus_texts])
    scores = bm25.get_scores(tokenize(query_text))

    # Get indices sorted by BM25 score (highest first)
    sorted_indices = np.argsort(scores)[::-1][:top_k]
    results = [(corpus_ids[idx], scores[idx], idx) for idx in sorted_indices]

    # Normalize BM25 scores to [0, 1]
    raw_scores = [score for _, score, _ in results]
    min_score, max_score = min(raw_scores), max(raw_scores)
    if max_score > min_score:  # Avoid division by zero
        score_range = max_score - min_score
        results = [(docid, (score - min_score) / score_range, idx)
                   for docid, score, idx in results]
    else:
        # All candidates tie. A shared raw score of 0 means none of them
        # matched the query, so they must not be promoted to the maximum.
        tied_score = 1.0 if max_score > 0 else 0.0
        results = [(docid, tied_score, idx) for docid, _, idx in results]
    logger.info("BM25 finish 1 iterate")
    return results

def kalm_retrieve(query_embedding, corpus_embeddings, corpus_ids, top_k=20):
    """
    Retrieve hard negatives using pre-computed KaLM embeddings with FAISS.

    Both embedding arrays are L2-normalized IN PLACE, then searched with an
    inner-product index that is rebuilt on every call.

    Args:
        query_embedding: 2-D array holding the query vector in row 0
            (only the first row of the search result is used).
        corpus_embeddings: 2-D array of passage vectors, aligned with
            `corpus_ids`.
        corpus_ids (list): Passage ids, aligned with `corpus_embeddings`.
        top_k (int): Maximum number of candidates to return.

    Returns:
        list: (docid, score, corpus_index) tuples in the order FAISS returns
        them. At most min(top_k, number of passages) entries; empty when the
        corpus is empty or `top_k` is not positive.
    """
    # Never ask FAISS for more neighbours than it holds: missing results are
    # padded with label -1, which would index the LAST corpus id.
    num_passages = corpus_embeddings.shape[0]
    k = min(top_k, num_passages)
    if k <= 0:
        return []

    # Normalize embeddings
    faiss.normalize_L2(query_embedding)
    faiss.normalize_L2(corpus_embeddings)

    index = faiss.IndexFlatIP(corpus_embeddings.shape[1])
    index.add(corpus_embeddings)
    distances, indices = index.search(query_embedding, k)

    # Defensive: skip any padded (-1) labels should FAISS still emit them.
    results = [
        (corpus_ids[idx], score, idx)
        for score, idx in zip(distances[0], indices[0])
        if idx >= 0
    ]
    logger.info("KaLM finish 1 iterate")
    return results

# Hard negative mining combining dense, BM25, and KaLM strategies
def mine_hard_negatives(model, queries, positives, corpus_dict, bm25_top_k=30, kalm_top_k=30, negatives_to_mine=15):
    """
    Mine hard negatives for each query using a combination of BM25 and KaLM-based scoring.

    Args:
        model: Preloaded SentenceTransformer model for KaLM retrieval.
        queries (dict): {query_id: query_text}
        positives (dict): {query_id: positive_docid}
        corpus_dict (dict): {passage_id: passage_text}
        bm25_top_k (int): Number of BM25 candidates requested per query.
        kalm_top_k (int): Number of KaLM candidates requested per query.
        negatives_to_mine (int): Maximum number of negatives kept per query.

    Returns:
        training_batches (list): List of dicts with query, positive, and hard negative info.
            Queries whose positive is missing or not in the corpus are skipped.

    Note:
        Only the single positive in `positives` is excluded from the
        candidates; any other passage relevant to the query can still be
        returned as a "negative".
    """
    # Prepare corpus lists
    corpus_ids = list(corpus_dict.keys())
    corpus_texts = list(corpus_dict.values())

    # Without passages there is nothing to encode or retrieve.
    if not corpus_ids:
        return []

    # Pre-compute corpus embeddings once to save time
    corpus_embeddings = model.encode(corpus_texts, convert_to_numpy=True)
    faiss.normalize_L2(corpus_embeddings)

    training_batches = []
    for qid, query_text in queries.items():
        gold_docid = positives.get(qid)
        # Dict membership is O(1); scanning the id list for every query is not.
        if gold_docid not in corpus_dict:
            # Skip if the positive document is not in the corpus.
            continue

        # BM25 negatives
        logger.info("BM25")
        bm25_results = bm25_retrieve(query_text, corpus_texts, corpus_ids, top_k=bm25_top_k)
        bm25_results = [item for item in bm25_results if item[0] != gold_docid]

        # KaLM negatives (use pre-computed corpus embeddings)
        logger.info("KaLM")
        query_embedding = model.encode([query_text], convert_to_numpy=True)
        faiss.normalize_L2(query_embedding)
        kalm_results = kalm_retrieve(query_embedding, corpus_embeddings, corpus_ids, top_k=kalm_top_k)
        kalm_results = [item for item in kalm_results if item[0] != gold_docid]
        logger.info('Finish retrieval')

        # Merge negatives from BM25 and KaLM; a passage found by both keeps
        # the higher of its two scores.
        neg_dict = {}
        for source in (bm25_results, kalm_results):
            for docid, score, _ in source:
                neg_dict[docid] = max(neg_dict.get(docid, float('-inf')), score)

        # Sort negatives by score (descending) and take the top negatives_to_mine
        sorted_negatives = sorted(neg_dict.items(), key=lambda x: x[1], reverse=True)[:negatives_to_mine]

        training_batches.append({
            "query_id": qid,
            "query_text": query_text,
            "positive_id": gold_docid,
            "hard_negative_ids": [docid for docid, _ in sorted_negatives],
            "hard_negative_scores": [score for _, score in sorted_negatives],
        })

    return training_batches

# Main execution
def main():
    """
    Run the full pipeline: load the KaLM model and the MIRACL "yo" dataset,
    mine hard negatives, print the first two batches, and save all batches
    to 'miracl_yo_training_batches.npy'.
    """
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    print("Using device:", device)

    # Model used for the dense (KaLM) retrieval side
    logger.info("Load model")
    model = SentenceTransformer("HIT-TMG/KaLM-embedding-multilingual-mini-v1", device=device)

    # MIRACL dataset, "yo" configuration
    logger.info("Load dataset")
    miracl = load_dataset("miracl/miracl", "yo", trust_remote_code=True)

    logger.info("Preprocess dataset")
    queries, positives, corpus_dict = preprocess_dataset(miracl)

    logger.info("Hard negative mining")
    mined_batches = mine_hard_negatives(
        model,
        queries,
        positives,
        corpus_dict,
        bm25_top_k=30,
        kalm_top_k=30,
        negatives_to_mine=15,
    )

    # Report the total and preview the first two batches
    print(f"\nMined {len(mined_batches)} training batches.")
    for sample in mined_batches[:2]:
        negative_ids = sample['hard_negative_ids']
        preview_lines = [
            f"\nQuery ID: {sample['query_id']}",
            f"Query Text: {sample['query_text']}",
            f"Positive ID: {sample['positive_id']}",
            f"Hard Negatives Count: {len(negative_ids)}",
            f"Hard Negative IDs: {negative_ids}",
            f"Hard Negative Scores: {sample['hard_negative_scores']}",
        ]
        for line in preview_lines:
            print(line)

    # Persist the batches (a list of dicts, hence pickling is required)
    np.save("miracl_yo_training_batches.npy", mined_batches, allow_pickle=True)
    print("\nSaved training batches to 'miracl_yo_training_batches.npy'.")

# Run the pipeline only when this code is executed directly (as a script or
# in a notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()

Mined 119 training batches.

```
Query ID: 10020#0
Query Text: Odun wo ni wọn ṣe idije Olympiiki akọkọ?
Positive ID: 10020#1
Hard Negatives Count: 15
Hard Negative IDs: ['61510#0', '67605#2', '10020#0', '54580#0', '53588#0', '54579#0', '54582#0', '54581#0', '10019#0', '54589#0', '54588#0', '54590#0', '54583#0', '54591#0', '54597#0']
Hard Negative Scores: [1.0, 0.8826476173871683, 0.7056303, 0.6600299, 0.6508212, 0.64164484, 0.63624036, 0.63371813, 0.632797, 0.6303463, 0.6296907, 0.62422645, 0.62160224, 0.620029, 0.6162815]
```
```
Query ID: 10118#0
Query Text: Orile ede wo ni Washington DC wà?
Positive ID: 55716#0
Hard Negatives Count: 15
Hard Negative IDs: ['68969#3', '3280#0', '66970#0', '20051#0', '13619#0', '70132#0', '67978#0', '18364#0', '66836#0', '71410#0', '22522#0', '1835#5', '16271#0', '67976#0', '1835#0']
Hard Negative Scores: [1.0, 0.8165482099800115, 0.7664398104753086, 0.7495139429683063, 0.7296947, 0.7291269635579585, 0.6989294197861712, 0.6839287908370626, 0.6579214304185333, 0.6550035706510758, 0.6037975, 0.58926904, 0.57326627, 0.5725800646653812, 0.5608194]
```

In [1]:
import numpy as np

# Path to the .npy file holding the saved training batches
file_path = "./miracl_yo_training_batches.npy"

# Load the .npy file.
# NOTE: allow_pickle=True unpickles arbitrary Python objects, which can
# execute code; only load files you produced yourself or otherwise trust.
data = np.load(file_path, allow_pickle=True)

# np.load always returns an ndarray. A single pickled object (e.g. a dict)
# comes back as a 0-d object array that cannot be iterated, so unwrap it.
if isinstance(data, np.ndarray) and data.ndim == 0:
    data = data.item()

# Print the loaded data
print(data)

# Then print the entries one by one
if isinstance(data, (list, np.ndarray)):
    for item in data:
        print(item)
elif isinstance(data, dict):
    for key, value in data.items():
        print(f"{key}: {value}")

[{'query_id': '10020#0', 'query_text': 'Odun wo ni wọn ṣe idije Olympiiki akọkọ?', 'positive_id': '10020#1', 'hard_negative_ids': ['61510#0', '67605#2', '10020#0', '54580#0', '53588#0', '54579#0', '54582#0', '54581#0', '10019#0', '54589#0', '54588#0', '54590#0', '54583#0', '54591#0', '54597#0'], 'hard_negative_scores': [1.0, 0.8826476173871683, 0.7056303, 0.6600299, 0.6508212, 0.64164484, 0.63624036, 0.63371813, 0.632797, 0.6303463, 0.6296907, 0.62422645, 0.62160224, 0.620029, 0.6162815]}
 {'query_id': '10118#0', 'query_text': 'Orile ede wo ni Washington DC wà?', 'positive_id': '55716#0', 'hard_negative_ids': ['68969#3', '3280#0', '66970#0', '20051#0', '13619#0', '70132#0', '67978#0', '18364#0', '66836#0', '71410#0', '22522#0', '1835#5', '16271#0', '67976#0', '1835#0'], 'hard_negative_scores': [1.0, 0.8165482099800115, 0.7664398104753086, 0.7495139429683063, 0.7296947, 0.7291269635579585, 0.6989294197861712, 0.6839287908370626, 0.6579214304185333, 0.6550035706510758, 0.6037975, 0.58926