## Obtaining BM25 top-k results

In [1]:
import json
import numpy as np
from tqdm import tqdm
import torch
import pytrec_eval
import logging
from typing import Dict, List, Tuple
import os

def get_bm25_run(top_k: int, run_file: str = "run.miracl.bm25.fr.dev.txt",
                 max_hits: int = 5000) -> Dict[str, List[Tuple[str, float]]]:
    """
    Reads a BM25 run file (TREC format) and returns the top-k results for each query.

    Each non-empty line must contain six whitespace-separated fields:
    ``qid Q0 docid rank score run_name``. Lines are consumed in file order and
    are not re-sorted, so the first ``top_k`` lines seen for a query are kept;
    this assumes the file is already ordered by rank within each query.

    Args:
        top_k: Number of top documents to return per query
        run_file: Path to the BM25 run file
        max_hits: Number of hits per query the run file was generated with
            (``--hits`` of the Pyserini search). Requests deeper than this
            cannot be served from the file.

    Returns:
        Dictionary mapping query IDs to list of (doc_id, score) tuples

    Raises:
        ValueError: If ``top_k`` exceeds ``max_hits``; the message contains a
            Pyserini command that regenerates ``run_file`` deep enough.
    """
    if top_k > max_hits:
        # Suggest a command that actually produces enough hits for this request
        # and writes to the file this function will read.
        cmd = (
            "python -m pyserini.search.lucene"
            " --threads 16 --batch-size 128"
            " --language fr"
            " --topics miracl-v1.0-fr-dev"
            " --index miracl-v1.0-fr"
            f" --output {run_file}"
            f" --bm25 --hits {top_k}"
        )
        raise ValueError(
            f"Only the top {max_hits} results are saved in the run file. "
            f"You may need to rerun the search with: {cmd} "
            f"(and pass max_hits={top_k})"
        )

    runs: Dict[str, List[Tuple[str, float]]] = {}
    with open(run_file, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a stray trailing newline).
                continue
            qid, _, docid, _rank, score, _ = fields
            hits = runs.setdefault(qid, [])
            if len(hits) < top_k:
                hits.append((docid, float(score)))
    return runs

## Main pipeline

### Loading data

In [None]:

from typing import Dict, Tuple
import datasets
from tqdm import tqdm

def _get_split_data(miracl_ds: datasets.DatasetDict,
                   split: str) -> Tuple[Dict[str, str], Dict[str, Dict[str, int]]]:
    """
    Collect the queries and binary relevance judgments of one MIRACL split.

    Args:
        miracl_ds: MIRACL dataset dictionary, indexed by split name
        split: Name of the split to read (e.g. 'train' or 'dev')

    Returns:
        A ``(queries, qrels)`` pair: ``queries`` maps query_id -> query text,
        and ``qrels`` maps query_id -> {docid: 1} for every positive passage.
        Records without a 'positive_passages' field get no qrels entry. Both
        dicts are empty when ``split`` is not present in ``miracl_ds``.
    """
    query_texts: Dict[str, str] = {}
    judgments: Dict[str, Dict[str, int]] = {}

    if split not in miracl_ds:
        return query_texts, judgments

    for record in miracl_ds[split]:
        query_id = record['query_id']
        query_texts[query_id] = record['query']
        if 'positive_passages' not in record:
            continue
        # Every positive passage gets relevance grade 1.
        judgments[query_id] = dict.fromkeys(
            (passage['docid'] for passage in record['positive_passages']), 1
        )

    return query_texts, judgments

def load_miracl_split(lang: str = 'fr',
                     split: str = 'dev',
                     cache_dir: str = "hf_datasets_cache") -> Tuple[Dict[str, str], Dict[str, str], Dict[str, Dict[str, int]]]:
    """
    Load the MIRACL corpus together with the queries and qrels of one split.

    Note that the entire corpus is materialized as a Python dict.

    Args:
        lang: MIRACL language code (used as the dataset config name)
        split: Split whose queries/qrels are returned ('train' or 'dev')
        cache_dir: Hugging Face datasets cache directory

    Returns:
        ``(documents, queries, qrels)``: ``documents`` maps every corpus docid
        to its stripped "title text" string; ``queries`` and ``qrels`` are the
        ones extracted by ``_get_split_data`` for ``split``.
    """
    topics_ds = datasets.load_dataset("miracl", lang, cache_dir=cache_dir)
    corpus_ds = datasets.load_dataset("miracl/miracl-corpus", lang, cache_dir=cache_dir)

    # The corpus only ships a 'train' split; prepend the title to the passage text.
    documents: Dict[str, str] = {}
    for passage in corpus_ds['train']:
        title = passage.get('title', '')
        documents[passage['docid']] = (title + " " + passage['text']).strip()

    queries, qrels = _get_split_data(topics_ds, split)
    return documents, queries, qrels

### Saving runs

In [3]:
from typing import Dict, List, Tuple
import os

def save_runs(runs: Dict[str, List[Tuple[str, float]]],
             output_path: str,
             run_name: str = "miracl",
             max_rank: int = 1000) -> None:
    """
    Saves retrieval/reranking results in TREC format

    Args:
        runs: Dictionary mapping query IDs to lists of (doc_id, score) tuples
        output_path: Path to save the run file; missing parent directories
            are created. A bare file name (no directory part) is allowed.
        run_name: Name of the run (used in the output format)
        max_rank: Maximum number of results to save per query

    The TREC format is:
    qid Q0 docid rank score run_name

    Example:
    1 Q0 doc1 1 14.8989 miracl
    1 Q0 doc2 2 14.7654 miracl

    Each query's documents are re-sorted by descending score before writing;
    scores are written with 6 decimals and the file is UTF-8 encoded.
    """
    # os.path.dirname("run.txt") is "", and os.makedirs("") raises
    # FileNotFoundError, so only create a parent when there is one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    print(f"Saving runs to {output_path}")

    with open(output_path, 'w', encoding='utf-8') as f:
        for qid, doc_scores in runs.items():
            # Highest score first; sorted() is stable, so ties keep input order.
            ranked = sorted(doc_scores, key=lambda x: x[1], reverse=True)

            for rank, (doc_id, score) in enumerate(ranked[:max_rank], start=1):
                # TREC format: qid Q0 docid rank score run_name
                f.write(f"{qid} Q0 {doc_id} {rank} {score:.6f} {run_name}\n")

    print(f"Saved results for {len(runs)} queries")

# Example usage with validation
def save_runs_with_validation(runs: Dict[str, List[Tuple[str, float]]],
                            output_path: str,
                            run_name: str = "miracl",
                            max_rank: int = 1000) -> None:
    """
    Saves runs with ``save_runs`` and adds sanity checks around the write.

    Before writing, warns (via print) about queries with no results and about
    queries whose scores fall outside [-1e6, 1e6]. After writing, verifies
    the file exists and is non-empty, then prints its first three lines.

    Args:
        runs: Dictionary mapping query IDs to lists of (doc_id, score) tuples
        output_path: Path of the TREC run file to write
        run_name: Run name written in the last column
        max_rank: Maximum number of results to save per query

    Raises:
        ValueError: If ``runs`` is empty.
        RuntimeError: If the output file is missing or empty after saving.
    """
    # Validate input
    if not runs:
        raise ValueError("Empty runs dictionary provided")

    # Check if all queries have results
    empty_queries = [qid for qid, docs in runs.items() if not docs]
    if empty_queries:
        print(f"Warning: {len(empty_queries)} queries have no results")

    # Check for reasonable score ranges
    for qid, doc_scores in runs.items():
        scores = [score for _, score in doc_scores]
        if scores:
            min_score, max_score = min(scores), max(scores)
            if max_score > 1e6 or min_score < -1e6:
                print(f"Warning: Unusual score range for query {qid}: [{min_score}, {max_score}]")

    # Save runs
    save_runs(runs, output_path, run_name, max_rank)

    # Verify file was created and is non-empty
    if not os.path.exists(output_path):
        raise RuntimeError(f"Failed to create output file: {output_path}")

    if os.path.getsize(output_path) == 0:
        raise RuntimeError(f"Output file is empty: {output_path}")

    # Print first few lines of the file for verification.
    # save_runs writes UTF-8, so read it back with the same encoding rather
    # than the platform default.
    print("\nFirst few lines of the output file:")
    with open(output_path, 'r', encoding='utf-8') as f:
        for _ in range(3):
            line = f.readline().strip()
            if line:
                print(line)

### Common 'Reranker' class

In [4]:
from sentence_transformers import SentenceTransformer, util, CrossEncoder

class Reranker:
    """
    Scores (query, document) pairs with either a cross-encoder or a bi-encoder.

    Model names containing "crossencoder" (case-insensitive) are loaded as a
    ``CrossEncoder`` with a single output label; any other name is loaded as a
    ``SentenceTransformer`` and documents are scored by cosine similarity
    between query and document embeddings. Both kinds run on ``self.device``.
    """

    def __init__(self, model_name: str, batch_size: int = 8):
        """
        Args:
            model_name: Hugging Face model id or local model path
            batch_size: Stored as ``self.batch_size`` for callers that chunk
                documents; this class does not batch internally.
        """
        is_mps = torch.backends.mps.is_available()
        is_cuda = torch.cuda.is_available()

        # Preference order: Apple MPS, then CUDA, then CPU.
        self.device = torch.device('mps' if is_mps else 'cuda' if is_cuda else 'cpu')
        self.batch_size = batch_size

        if 'crossencoder' in model_name.lower():
            logging.info("Actually using a CrossEncoder this time lmaoo")
            # Pass the selected device explicitly so the cross-encoder runs on
            # the same device as the bi-encoder branch below.
            self.model = CrossEncoder(model_name, num_labels=1, device=str(self.device))
        else:
            self.model = SentenceTransformer(model_name)
            self.model.to(self.device)

    def rerank_batch(self, query: str, documents: List[str],
                    doc_ids: List[str]) -> List[Tuple[str, float]]:
        """
        Score ``documents`` against ``query``.

        Args:
            query: Query text
            documents: Document texts, aligned position-by-position with ``doc_ids``
            doc_ids: IDs to attach to the scores

        Returns:
            ``(doc_id, score)`` pairs in input order (not sorted), with scores
            as Python floats. Cross-encoder scores come from
            ``CrossEncoder.predict``; bi-encoder scores are cosine similarities.
        """
        if isinstance(self.model, CrossEncoder):
            pairs = [[query, doc] for doc in documents]
            scores = self.model.predict(pairs, show_progress_bar=False)
        else:
            query_emb = self.model.encode(query, convert_to_tensor=True, show_progress_bar=False)
            doc_embs = self.model.encode(documents, convert_to_tensor=True, show_progress_bar=False)
            # Take the first (only) row: similarities of the single query to every document.
            scores = util.pytorch_cos_sim(query_emb, doc_embs)[0].cpu().numpy()

        # Convert numpy scalars to plain floats to match the annotated return type.
        return [(doc_id, float(score)) for doc_id, score in zip(doc_ids, scores)]

### Define a reranking pipeline

In [None]:
def a():
    """Return the row count of the 'train' split of the French MIRACL corpus."""
    corpus = datasets.load_dataset("miracl/miracl-corpus", "fr", cache_dir="hf_datasets_cache")
    train_split = corpus['train']
    return train_split.num_rows
a()

In [6]:
class DocumentStore:
    """
    In-memory store holding only the corpus passages whose IDs were requested.

    The corpus is scanned once at construction time, and only passages listed
    in ``doc_ids`` are kept, as stripped "title text" strings.
    """
    def __init__(self, lang: str, doc_ids: list[str], cache_dir: str = "hf_datasets_cache"):
        """
        Args:
            lang: MIRACL language code (dataset config name)
            doc_ids: IDs of the passages to keep; any iterable of strings works
            cache_dir: Hugging Face datasets cache directory
        """
        dataset = datasets.load_dataset("miracl/miracl-corpus", lang, cache_dir=cache_dir)['train']

        # Membership is tested once per corpus row (millions of rows), so use a
        # set; a list would make the scan O(rows * len(doc_ids)).
        wanted = set(doc_ids)

        # Convert collection to dictionary
        logging.info(f"\nLoading {len(wanted)}/{len(dataset)} documents into memory (dict)..")
        self.documents = {
            doc['docid']: (doc.get('title', '') + " " + doc['text']).strip()
            for doc in dataset if doc['docid'] in wanted
        }
        logging.info("\n DONE loading all documents into memory (dict).")

    def get_documents(self, doc_ids: List[str]) -> List[str]:
        """
        Return the texts of ``doc_ids``, in the same order.

        Raises:
            KeyError: If an ID was not loaded at construction time.
        """
        return [self.documents[doc_id] for doc_id in doc_ids]

class QueryStore:
    """
    Lookup table from MIRACL query IDs to query text for one split.

    All queries of the split are indexed into a dict when the store is built.
    """
    def __init__(self, lang: str, split: str = 'dev', cache_dir: str = "hf_datasets_cache"):
        loaded = datasets.load_dataset("miracl/miracl", lang, cache_dir=cache_dir)
        self.dataset = loaded[split]
        self._query_map = dict(
            (row['query_id'], row['query']) for row in self.dataset
        )

    def get_query(self, query_id: str) -> str:
        """Return the text of ``query_id``, or an empty string if the ID is unknown."""
        if query_id in self._query_map:
            return self._query_map[query_id]
        return ""

In [7]:
from typing import Iterator

def batch_iterator(items: List, batch_size: int) -> Iterator:
    """
    Lazily yield consecutive slices of ``items`` of length ``batch_size``.

    The last slice may be shorter; an empty ``items`` yields nothing.
    """
    offsets = range(0, len(items), batch_size)
    yield from (items[start:start + batch_size] for start in offsets)

def rerank_runs(initial_runs: Dict[str, List[Tuple[str, float]]],
                reranker: Reranker,
                query_store: QueryStore,
                doc_store: DocumentStore) -> Dict[str, List[Tuple[str, float]]]:
    """
    Rescore every query's candidate documents with ``reranker``.

    Candidates are sent to the reranker in chunks of ``reranker.batch_size``,
    so only one chunk of document texts is fetched at a time.

    Args:
        initial_runs: query_id -> [(doc_id, first-stage score), ...]; only the
            doc_ids are used
        reranker: Object providing ``rerank_batch`` and ``batch_size``
        query_store: Source of query texts
        doc_store: Source of document texts

    Returns:
        query_id -> [(doc_id, reranker score), ...] sorted by descending score.
    """
    results: Dict[str, List[Tuple[str, float]]] = {}

    for query_id in tqdm(initial_runs):
        query_text = query_store.get_query(query_id)
        candidate_ids = [doc_id for doc_id, _first_stage_score in initial_runs[query_id]]

        scored: List[Tuple[str, float]] = []
        for chunk_ids in batch_iterator(candidate_ids, reranker.batch_size):
            chunk_texts = doc_store.get_documents(chunk_ids)
            scored.extend(reranker.rerank_batch(query_text, chunk_texts, chunk_ids))

        # Stable sort: equal scores keep their first-stage order.
        results[query_id] = sorted(scored, key=lambda pair: pair[1], reverse=True)

    return results

# Module-level caches reused across run_reranking_pipeline() calls, so the BM25
# run and the document/query stores are built once per session rather than
# once per model. Reset them to None to force a reload.
initial_runs = None  # Dict[qid, List[(docid, bm25_score)]] once loaded
doc_store = None     # DocumentStore once built
query_store = None   # QueryStore once built

# End-to-end pipeline: BM25 candidates -> model reranking -> TREC run file.
def run_reranking_pipeline(model_name: str,
                          initial_run_file: str,
                          qrels_file: str,
                          output_run_file: str,
                          top_k: int = 1000):
    """
    Rerank the BM25 top-k candidates with ``model_name`` and save the new run.

    The BM25 run and the document/query stores are cached in the module-level
    globals ``initial_runs``, ``doc_store`` and ``query_store`` after they are
    first built, so subsequent calls (e.g. one per model) reuse them.

    Args:
        model_name: Model id or path passed to ``Reranker``
        initial_run_file: BM25 run file to rerank (only read while the cache is empty)
        qrels_file: Currently unused; evaluation is done separately on the saved run
        output_run_file: Where the reranked run is written (TREC format)
        top_k: BM25 candidates per query (only applied while the cache is empty)
    """
    global initial_runs, doc_store, query_store

    # NOTE: the cache is not keyed on (initial_run_file, top_k); changing either
    # between calls in one session has no effect until the globals are reset.
    if not initial_runs:
        logging.info("Loading inital BM25 results")
        initial_runs = get_bm25_run(top_k, initial_run_file)

    # Each store is initialized on its own, so a failure while building one of
    # them is retried on the next call instead of leaving it None for good.
    if doc_store is None:
        # Due to memory constraints the whole corpus cannot be loaded; only
        # the documents referenced in BM25's rankings are kept. The ID set is
        # only needed here, so it is computed only when the store is built.
        doc_ids = {docid for query_hits in initial_runs.values() for docid, _ in query_hits}
        logging.info("Initializing document and query stores")
        doc_store = DocumentStore("fr", doc_ids=doc_ids)
    if query_store is None:
        query_store = QueryStore("fr")

    # Initialize reranker
    logging.info("Initializing model")
    reranker = Reranker(model_name, batch_size=32)

    # Rerank
    logging.info("Reranking queries...")
    reranked_runs = rerank_runs(initial_runs, reranker, query_store, doc_store)

    # Save results
    logging.info("Saving results...")
    save_runs_with_validation(reranked_runs, output_run_file)

## Evaluate all models

In [9]:
import pandas as pd
from datetime import datetime
import os
import logging
import datasets
from typing import Dict, List

def evaluate_all_models(lang: str = 'fr',
                       cache_dir: str = "hf_datasets_cache",
                       output_dir: str = "runs"):
    """
    Reranks the BM25 dev run with every candidate model and saves one run file per model.

    This function only produces run files; scoring them against the qrels is
    done separately (e.g. with ``ir_measures``). A failure for one model is
    logged together with its traceback and does not stop the remaining models.

    Args:
        lang: Language code used to build the BM25 run, qrels and output file names
        cache_dir: Currently unused by this function
        output_dir: Directory where the reranked run files are written
    """
    top_k = 1000

    # Setup logging (to a timestamped file and to the console)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(f'model_evaluation_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'),
            logging.StreamHandler()
        ]
    )

    # Models to evaluate
    models = [
        "facebook/mcontriever-msmarco",
        "castorini/mdpr-tied-pft-msmarco",
        "castorini/mdpr-tied-pft-msmarco-ft-miracl-fr",
        "antoinelouis/crossencoder-camembert-base-mmarcoFR"
    ]

    initial_run_file = f"run.miracl.bm25.{lang}.dev.txt"
    qrels_file = f"qrels.miracl-v1.0-{lang}-dev.tsv"

    for model_name in models:
        logging.info(f"\nEvaluating {model_name}")
        # Output file is named after the last path component of the model id.
        run_name = model_name.split('/')[-1]
        output_run_file = os.path.join(output_dir, f"run.miracl.{run_name}.{lang}.dev.txt")
        try:
            run_reranking_pipeline(
                model_name=model_name,
                initial_run_file=initial_run_file,
                qrels_file=qrels_file,
                output_run_file=output_run_file,
                top_k=top_k
            )
        except Exception as e:
            # logging.exception logs at ERROR level and appends the traceback,
            # which logging.error(str(e)) would discard.
            logging.exception(f"Error evaluating {model_name}: {e}")

    # Evaluate the locally fine-tuned Camembert cross-encoder if it exists.
    try:
        ft_model_path = "output/crossencoder-camembert_miracl_fr-2024-11-11_15-55-59"
        if os.path.exists(ft_model_path):
            logging.info("\nEvaluating fine-tuned Camembert cross-encoder")

            run_reranking_pipeline(
                model_name=ft_model_path,
                initial_run_file=initial_run_file,
                qrels_file=qrels_file,
                output_run_file=os.path.join(output_dir, f"run.miracl.crossencoder-camembert-ft.{lang}.dev.txt"),
                top_k=top_k
            )

    except Exception as e:
        logging.exception(f"Error evaluating fine-tuned Camembert: {e}")

In [None]:
# Create the output directory for the reranked run files up front.
os.makedirs("runs", exist_ok=True)

# Rerank the French dev BM25 candidates with every model and save the runs.
evaluate_all_models(lang='fr', cache_dir="hf_datasets_cache", output_dir="runs")

In [10]:
!ir_measures qrels.miracl-v1.0-fr-dev.tsv run.miracl.bm25.fr.dev.txt 'P@1 P@10 AP@10 nDCG@10' 

P@1	0.1370
P@10	0.0560
AP@10	0.1303
nDCG@10	0.1832


In [11]:
!ir_measures qrels.miracl-v1.0-fr-dev.tsv runs/run.miracl.mcontriever-msmarco.fr.dev.txt 'P@1 P@10 AP@10 nDCG@10' 

P@1	0.1603
P@10	0.0755
AP@10	0.1798
nDCG@10	0.2525


In [12]:
!ir_measures qrels.miracl-v1.0-fr-dev.tsv runs/run.miracl.mdpr-tied-pft-msmarco.fr.dev.txt 'P@1 P@10 AP@10 nDCG@10' 

P@1	0.1370
P@10	0.0726
AP@10	0.1630
nDCG@10	0.2341


In [13]:
!ir_measures qrels.miracl-v1.0-fr-dev.tsv runs/run.miracl.mdpr-tied-pft-msmarco-ft-miracl-fr.fr.dev.txt 'P@1 P@10 AP@10 nDCG@10' 

P@1	0.1778
P@10	0.0866
AP@10	0.2052
nDCG@10	0.2855


In [14]:
!ir_measures qrels.miracl-v1.0-fr-dev.tsv runs/run.miracl.crossencoder-camembert-base-mmarcoFR.fr.dev.txt 'P@1 P@10 AP@10 nDCG@10' 

P@1	0.4169
P@10	0.1239
AP@10	0.3728
nDCG@10	0.4688


In [15]:
!ir_measures qrels.miracl-v1.0-fr-dev.tsv runs/run.miracl.crossencoder-camembert-ft.fr.dev.txt 'P@1 P@10 AP@10 nDCG@10' 

P@1	0.4869
P@10	0.1327
AP@10	0.4349
nDCG@10	0.5245
