In [1]:
from rag.config import Settings
%load_ext autoreload
%autoreload 2

import os
import gc
import torch
import faiss
import numpy as np
import pandas as pd
from datasets import load_dataset, Dataset
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder
from FlagEmbedding import FlagReranker
from time import time
from tqdm import tqdm

from rag.embeddings import LocalEmbedder
from rag.utils import get_metrics, embed_dataset

# Load datasets: the passage corpus and the test-split questions of rag-mini-bioasq,
# selecting each split directly in load_dataset rather than indexing the DatasetDict.
doc_ds = load_dataset("rag-datasets/rag-mini-bioasq", "text-corpus", split="passages")
query_ds = load_dataset("rag-datasets/rag-mini-bioasq", "question-answer-passages", split="test")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Drop corpus rows whose passage text is the literal string 'nan'
# (presumably a placeholder for missing text in the source data).
doc_ds = doc_ds.filter(lambda passage: passage != "nan", input_columns="passage")

In [3]:
# Precompute query information
import ast  # literal_eval: parse stringified id lists without executing code

queries = np.array(query_ds['question'])
# Each query's gold passage ids are stored as a string literal (presumably a list
# such as "[123, 456]"). ast.literal_eval only accepts Python literals, so unlike
# eval() it cannot execute code embedded in downloaded dataset content.
qrels = [np.array(ast.literal_eval(gold)) for gold in query_ds['relevant_passage_ids']]
# Number of gold passages per query
qrels_counts = [len(s) for s in qrels]

In [4]:
# Define chunking function
def chunk_documents(dataset, chunker, text_col='passage', id_col='id'):
    """
    Split every document in ``dataset`` into chunks with ``chunker``.

    Args:
        dataset: Iterable of row dicts that supports ``len()`` (e.g. a ``datasets.Dataset``)
        chunker: Object exposing ``split_text(str) -> list[str]``
        text_col: Name of the column holding the document text
        id_col: Name of the column holding the document id

    Returns:
        ``Dataset`` with one row per chunk and columns:
            - ``passage``: chunk text (always named 'passage', whatever ``text_col`` is)
            - ``parent_id``: value of ``id_col`` for the source document
            - ``chunk_id``: 0-based position of the chunk within its document
        The three columns are present even when no chunks are produced.
    """
    passages, parent_ids, chunk_ids = [], [], []
    # Context manager closes the progress bar even if split_text raises mid-loop
    with tqdm(total=len(dataset), desc='Chunking') as pbar:
        for doc in dataset:
            chunks = chunker.split_text(doc[text_col])
            passages.extend(chunks)
            parent_ids.extend([doc[id_col]] * len(chunks))
            chunk_ids.extend(range(len(chunks)))
            pbar.update(1)
    # Column-oriented construction: identical to building from a list of row dicts
    # for non-empty input, but an empty list of rows carries no column names, so
    # this form keeps the expected schema when nothing is chunked.
    return Dataset.from_dict({
        'passage': passages,
        'parent_id': parent_ids,
        'chunk_id': chunk_ids,
    })

In [5]:
# Define reranking function (optimized for GPU)
def rerank_results(queries, retrieved_passages, retrieved_ids, reranker, batch_size=256):
    """
    Rerank retrieved passages using a reranker model (optimized for GPU).

    All (query, passage) pairs are flattened into one list and scored in a single
    call so the reranker can batch across queries.

    Args:
        queries: Sequence of query strings, one per row of ``retrieved_passages``
        retrieved_passages: 2D array of retrieved passage texts [n_queries, k]
        retrieved_ids: 2D array of retrieved passage IDs [n_queries, k], aligned
            element-wise with ``retrieved_passages``
        reranker: Reranker model (CrossEncoder or FlagReranker)
        batch_size: Batch size for reranking (larger is better for GPU utilization)

    Returns:
        reranked_ids: 2D array of passage IDs [n_queries, k]; each row is sorted by
            descending reranker score. Ties keep their original retrieval order and
            NaN scores are ranked last.

    Raises:
        ValueError: if the reranker returns a different number of scores than the
            number of pairs submitted.
    """
    retrieved_ids = np.asarray(retrieved_ids)
    n_queries = len(queries)
    k = retrieved_passages.shape[1]

    # Nothing to score: avoid invoking the model on an empty batch
    if n_queries == 0 or k == 0:
        return retrieved_ids.copy()

    # Flatten all query-passage pairs (row-major) to maximize GPU utilization
    all_pairs = [
        [queries[i], passage]
        for i in range(n_queries)
        for passage in retrieved_passages[i]
    ]

    # Score all pairs in one batch (much faster!)
    print(f"    Scoring {len(all_pairs):,} pairs with batch_size={batch_size}...")

    # Use appropriate scoring method based on reranker type
    if isinstance(reranker, FlagReranker):
        # FlagReranker.compute_score expects list of [query, doc] pairs
        all_scores = reranker.compute_score(all_pairs, batch_size=batch_size, normalize=True)
    else:
        # CrossEncoder from sentence-transformers
        all_scores = reranker.predict(all_pairs, batch_size=batch_size, show_progress_bar=True)

    # compute_score may return a plain list (or a bare float for a single pair)
    all_scores = np.asarray(all_scores).reshape(-1)
    if all_scores.size != len(all_pairs):
        raise ValueError(
            f"Reranker returned {all_scores.size} scores for {len(all_pairs)} pairs"
        )
    scores_2d = all_scores.reshape(n_queries, k)

    # Ascending argsort of the negated scores gives descending order in one pass.
    # NaN (always sorted last by argsort) stays at the bottom instead of being
    # flipped to the top as argsort(...)[::-1] would do, and the stable sort keeps
    # tied passages in their original retrieval order.
    order = np.argsort(-scores_2d, axis=1, kind='stable')
    return np.take_along_axis(retrieved_ids, order, axis=1)


def deduplicate_retrieved_ids(retrieved_ids, max_k=100, pad_value=0):
    """
    Deduplicate document IDs while preserving rank order.

    Args:
        retrieved_ids: 2D array-like of retrieved IDs [n_queries, k] (may contain
            duplicates and ``pad_value`` padding)
        max_k: Maximum number of unique IDs to keep per query (must be >= 0)
        pad_value: Sentinel treated as padding. It is skipped in the input and used
            to fill unused output slots. Defaults to 0; pass a value that can never
            be a real ID if 0 is a valid document ID in your corpus.

    Returns:
        deduplicated_ids: 2D array [n_queries, max_k] with the dtype of
            ``retrieved_ids``. Each row holds the first occurrence of every distinct
            ID in original order, padded with ``pad_value`` if fewer than ``max_k``
            unique IDs exist.

    Raises:
        ValueError: if ``max_k`` is negative.
    """
    if max_k < 0:
        raise ValueError(f"max_k must be non-negative, got {max_k}")

    retrieved_ids = np.asarray(retrieved_ids)
    n_queries = retrieved_ids.shape[0]
    deduplicated = np.full((n_queries, max_k), pad_value, dtype=retrieved_ids.dtype)
    if max_k == 0:
        # Nothing to keep; also avoids writing an ID into a zero-width row
        return deduplicated

    for i, row in enumerate(retrieved_ids):
        seen = set()
        unique_ids = []
        for doc_id in row:
            if doc_id == pad_value or doc_id in seen:
                continue
            seen.add(doc_id)
            unique_ids.append(doc_id)
            # Stop as soon as the row is full; it never exceeds max_k entries
            if len(unique_ids) == max_k:
                break
        # Remaining positions keep pad_value
        deduplicated[i, :len(unique_ids)] = unique_ids

    return deduplicated

In [6]:
# Experiment grid: every embedder x chunk size x reranker combination is run
# once and appended as one row to results.csv.

# Embedder models to test
embedder_models = [
    "all-MiniLM-L6-v2",
    "all-MiniLM-L12-v2",
    "BAAI/bge-small-en-v1.5",
    "BAAI/bge-base-en-v1.5",
    "BAAI/bge-large-en-v1.5",
]

# Reranker models to test. Names containing "BAAI" or "bge-reranker" are loaded
# with FlagReranker below; all others with sentence-transformers' CrossEncoder.
reranker_models = [
    "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "cross-encoder/ms-marco-MiniLM-L-12-v2",
    # Disabled in this run; uncomment to include the BGE rerankers
    # "BAAI/bge-reranker-base",
    # "BAAI/bge-reranker-large",
]

# Chunking parameters (lengths are measured in tokens via count_tokens below).
# chunk_size is not fixed here: it is swept over [128, 256] in the loop below.
# chunk_size = 256
chunk_overlap = 50

# Initial retrieval depth: number of chunks fetched from FAISS per query and
# handed to the reranker (n_queries * initial_k pairs scored per reranker)
initial_k = 100

for emb_idx, embedder_name in enumerate(embedder_models):
    print("=" * 20, f"[{emb_idx + 1}/{len(embedder_models)}] Embedder: {embedder_name}", "=" * 20)

    for chunk_size in [128, 256]:
        try:
            # Create embedder. It is (re)loaded for every chunk size and
            # released at the end of each chunk-size iteration.
            embedder = LocalEmbedder(embedder_name, device="cuda")

            # Create tokenizer-aware chunker: chunk_size / chunk_overlap are
            # counted in this embedder's tokens rather than in characters.
            # NOTE(review): HF tokenizers' encode() adds special tokens by
            # default, so counts likely include them; confirm if exact limits matter.
            tokenizer = embedder.model.tokenizer
            def count_tokens(text):
                return len(tokenizer.encode(text))

            chunker = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=count_tokens,
                separators=["\n\n", "\n", ". ", " ", ""],
            )

            # Chunk documents
            chunked_ds = chunk_documents(doc_ds, chunker)
            print(f"Created {len(chunked_ds)} chunks from {len(doc_ds)} documents")

            # Embed chunked documents and queries. embed_time covers only these
            # two calls, not model loading or chunking.
            embed_start = time()
            chunked_ds = embed_dataset(chunked_ds, embedder, column='passage')
            query_ds_embedded = embed_dataset(query_ds, embedder, column='question')
            embed_time = time() - embed_start

            # Create mapping from chunk index to parent document ID. FAISS search
            # results are treated below as row positions in chunked_ds, so these
            # arrays translate them back to parent IDs and chunk texts.
            index_to_parent_id = np.array(chunked_ds['parent_id'])
            chunk_passages = np.array(chunked_ds['passage'])

        except Exception as e:
            # Skip this (embedder, chunk_size) combination, releasing GPU memory first
            print(f"Failed to embed with {embedder_name}: {e}")
            if 'embedder' in locals():
                del embedder
            gc.collect()
            torch.cuda.empty_cache()
            continue

        # Only inner-product search is evaluated in this run; setting this to
        # 'L2' would switch the index below to METRIC_L2.
        faiss_metric = 'IP'


        # Add an exact (Flat) FAISS index over the chunk embeddings
        chunked_ds.add_faiss_index(
            column='embedding',
            string_factory='Flat',
            metric_type=faiss.METRIC_L2 if faiss_metric == 'L2' else faiss.METRIC_INNER_PRODUCT,
            batch_size=128,
        )

        # Retrieve top-k candidates (initial_k chunks per query)
        res = chunked_ds.get_index('embedding').search_batch(
            np.array(query_ds_embedded['embedding']), k=initial_k
        )

        # Map chunk indices to parent document IDs (and chunk texts for reranking).
        # Several retrieved chunks may share a parent, so parent IDs can repeat.
        # NOTE(review): FAISS pads missing results with -1, which numpy indexing
        # would silently map to the last chunk. This should not occur while
        # initial_k <= number of chunks, but guard it if initial_k grows.
        retrieved_chunk_ids = res.total_indices
        retrieved_parent_ids = index_to_parent_id[retrieved_chunk_ids]
        retrieved_passages = chunk_passages[retrieved_chunk_ids]

        # Test each reranker against the same retrieved candidates
        for reranker_idx, reranker_name in enumerate(reranker_models):
            print(f"    [{reranker_idx + 1}/{len(reranker_models)}] Reranker: {reranker_name}")

            try:
                # Load reranker with correct library
                # BAAI rerankers use FlagEmbedding, others use sentence-transformers CrossEncoder
                if "BAAI" in reranker_name or "bge-reranker" in reranker_name:
                    reranker = FlagReranker(reranker_name, use_fp16=True)
                else:
                    reranker = CrossEncoder(reranker_name, device="cuda")

                # Rerank results: each chunk is scored and its parent ID is
                # reordered by that chunk's score
                rerank_start = time()
                reranked_parent_ids = rerank_results(
                    queries, retrieved_passages, retrieved_parent_ids, reranker
                )

                # DEDUPLICATE: Remove duplicate doc IDs while preserving rank order.
                # This is critical for accurate IR metrics: each parent document is
                # kept at the position of its best-ranked chunk.
                # (The message below prints the pre-dedup array shape.)
                reranked_parent_ids_dedup = deduplicate_retrieved_ids(reranked_parent_ids, max_k=initial_k)
                print(f"    Deduplicated: {reranked_parent_ids.shape} -> unique docs per query")

                # rerank_time covers scoring and deduplication, not reranker loading
                rerank_time = time() - rerank_start

                # Reported time = embedding + reranking only; chunking, FAISS
                # indexing/search and model loading are excluded.
                total_time = embed_time + rerank_time

                # Compute metrics for different k values (using deduplicated results).
                # Rows with fewer than k unique docs are zero-padded by the dedup
                # step. NOTE(review): confirm get_metrics treats 0 as "no document".
                metrics = {}
                for k in [1, 3, 5, 10]:
                    reranked_top_k = reranked_parent_ids_dedup[:, :k]
                    metrics = {
                        **metrics,
                        **get_metrics(reranked_top_k, query_ds, k),
                    }

                # Save results: one row per run, metric values rounded to 3 decimals
                res_dict = {
                    'model': embedder_name,
                    'faiss_metric': faiss_metric,
                    'chunked': True,
                    'chunk_size': chunk_size,
                    'chunk_overlap': chunk_overlap,
                    'rerank_model': reranker_name,
                    **{k: round(v, 3) for k, v in metrics.items()},
                    "elapsed_time": round(total_time, 1),
                }

                # Append to results.csv; the header is written only when the file is
                # missing or empty. NOTE(review): rows with a different column set
                # than earlier runs would be misaligned under the existing header.
                res_df = pd.DataFrame([res_dict])
                csv_path = "results.csv"
                append = os.path.exists(csv_path) and os.path.getsize(csv_path) > 0
                res_df.to_csv(csv_path, mode='a', header=not append, index=False)

                # Print summary (expects get_metrics keys such as 'P@10' and 'nDCG@10')
                print(f"      P@10: {metrics['P@10']:.3f}, R@10: {metrics['R@10']:.3f}, "
                      f"MRR@10: {metrics['MRR@10']:.3f}, nDCG@10: {metrics['nDCG@10']:.3f}, "
                      f"Time: {total_time:.1f}s")

                # Clean up reranker and release cached GPU memory before the next one
                del reranker
                gc.collect()
                torch.cuda.empty_cache()

            except Exception as e:
                # Report and move on to the next reranker, releasing GPU memory
                print(f"      Failed to rerank with {reranker_name}: {e}")
                if 'reranker' in locals():
                    del reranker
                gc.collect()
                torch.cuda.empty_cache()
                continue

        # Detach the FAISS index from chunked_ds before moving on to the next
        # chunk size / embedder
        chunked_ds.drop_index('embedding')

        # Clean up embedder (it is re-created at the top of the next iteration)
        del embedder
        gc.collect()
        torch.cuda.empty_cache()
        print()

print("\nReranking comparison complete! Results saved to results.csv")



Chunking: 100%|██████████| 28001/28001 [00:42<00:00, 657.77it/s]


Created 146055 chunks from 28001 documents


Map: 100%|██████████| 146055/146055 [02:27<00:00, 991.61 examples/s] 
Map: 100%|██████████| 4719/4719 [00:04<00:00, 948.96 examples/s] 
100%|██████████| 1142/1142 [00:00<00:00, 4200.31it/s]


    [1/2] Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [01:12<00:00, 25.37it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.361, R@10: 0.475, MRR@10: 0.771, nDCG@10: 0.600, Time: 249.1s
    [2/2] Reranker: cross-encoder/ms-marco-MiniLM-L-12-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [02:19<00:00, 13.26it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.361, R@10: 0.476, MRR@10: 0.773, nDCG@10: 0.602, Time: 315.1s



Chunking: 100%|██████████| 28001/28001 [00:42<00:00, 659.03it/s]


Created 64042 chunks from 28001 documents


Map: 100%|██████████| 64042/64042 [01:40<00:00, 636.12 examples/s] 
100%|██████████| 501/501 [00:00<00:00, 4901.90it/s]


    [1/2] Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [02:12<00:00, 13.91it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.355, R@10: 0.467, MRR@10: 0.764, nDCG@10: 0.593, Time: 257.3s
    [2/2] Reranker: cross-encoder/ms-marco-MiniLM-L-12-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [04:19<00:00,  7.10it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.356, R@10: 0.467, MRR@10: 0.767, nDCG@10: 0.594, Time: 384.8s



Chunking: 100%|██████████| 28001/28001 [00:44<00:00, 633.32it/s]


Created 146055 chunks from 28001 documents


Map: 100%|██████████| 146055/146055 [04:11<00:00, 581.40 examples/s] 
Map: 100%|██████████| 4719/4719 [00:06<00:00, 759.26 examples/s] 
100%|██████████| 1142/1142 [00:00<00:00, 4550.05it/s]


    [1/2] Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [01:12<00:00, 25.53it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.351, R@10: 0.463, MRR@10: 0.764, nDCG@10: 0.588, Time: 355.7s
    [2/2] Reranker: cross-encoder/ms-marco-MiniLM-L-12-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [02:18<00:00, 13.29it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.352, R@10: 0.464, MRR@10: 0.767, nDCG@10: 0.590, Time: 422.3s



Chunking: 100%|██████████| 28001/28001 [00:43<00:00, 650.83it/s]


Created 64042 chunks from 28001 documents


Map: 100%|██████████| 64042/64042 [01:52<00:00, 570.37 examples/s] 
100%|██████████| 501/501 [00:00<00:00, 4719.14it/s]


    [1/2] Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [02:13<00:00, 13.82it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.348, R@10: 0.456, MRR@10: 0.760, nDCG@10: 0.583, Time: 270.7s
    [2/2] Reranker: cross-encoder/ms-marco-MiniLM-L-12-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [04:18<00:00,  7.12it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.348, R@10: 0.457, MRR@10: 0.765, nDCG@10: 0.585, Time: 396.5s



Chunking: 100%|██████████| 28001/28001 [00:46<00:00, 597.71it/s]


Created 146055 chunks from 28001 documents


Map: 100%|██████████| 146055/146055 [04:12<00:00, 578.84 examples/s] 
Map: 100%|██████████| 4719/4719 [00:05<00:00, 810.77 examples/s] 
100%|██████████| 1142/1142 [00:00<00:00, 4360.68it/s]


    [1/2] Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [01:13<00:00, 25.09it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.377, R@10: 0.501, MRR@10: 0.784, nDCG@10: 0.623, Time: 357.4s
    [2/2] Reranker: cross-encoder/ms-marco-MiniLM-L-12-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [02:19<00:00, 13.18it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.377, R@10: 0.501, MRR@10: 0.787, nDCG@10: 0.625, Time: 423.2s



Chunking: 100%|██████████| 28001/28001 [00:43<00:00, 643.73it/s]


Created 64042 chunks from 28001 documents


Map: 100%|██████████| 64042/64042 [02:10<00:00, 490.34 examples/s]
100%|██████████| 501/501 [00:00<00:00, 4534.74it/s]


    [1/2] Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [02:13<00:00, 13.79it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.377, R@10: 0.499, MRR@10: 0.784, nDCG@10: 0.624, Time: 290.4s
    [2/2] Reranker: cross-encoder/ms-marco-MiniLM-L-12-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [04:21<00:00,  7.05it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.377, R@10: 0.500, MRR@10: 0.789, nDCG@10: 0.625, Time: 418.4s



Chunking: 100%|██████████| 28001/28001 [00:48<00:00, 573.78it/s]


Created 146055 chunks from 28001 documents


Map: 100%|██████████| 146055/146055 [05:23<00:00, 450.97 examples/s]
Map: 100%|██████████| 4719/4719 [00:08<00:00, 568.10 examples/s]
100%|██████████| 1142/1142 [00:00<00:00, 2753.14it/s]


    [1/2] Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [01:13<00:00, 25.01it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.383, R@10: 0.511, MRR@10: 0.789, nDCG@10: 0.632, Time: 432.7s
    [2/2] Reranker: cross-encoder/ms-marco-MiniLM-L-12-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [02:17<00:00, 13.41it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.384, R@10: 0.512, MRR@10: 0.791, nDCG@10: 0.634, Time: 496.1s



Chunking: 100%|██████████| 28001/28001 [00:42<00:00, 655.28it/s]


Created 64042 chunks from 28001 documents


Map: 100%|██████████| 64042/64042 [02:54<00:00, 367.12 examples/s]
100%|██████████| 501/501 [00:00<00:00, 3315.45it/s]


    [1/2] Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [02:15<00:00, 13.62it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.381, R@10: 0.508, MRR@10: 0.789, nDCG@10: 0.631, Time: 337.5s
    [2/2] Reranker: cross-encoder/ms-marco-MiniLM-L-12-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [04:23<00:00,  6.99it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.382, R@10: 0.509, MRR@10: 0.793, nDCG@10: 0.632, Time: 465.6s



Chunking: 100%|██████████| 28001/28001 [00:48<00:00, 578.23it/s]


Created 146055 chunks from 28001 documents


Map: 100%|██████████| 146055/146055 [10:40<00:00, 227.98 examples/s]
Map: 100%|██████████| 4719/4719 [00:12<00:00, 378.79 examples/s]
100%|██████████| 1142/1142 [00:00<00:00, 2471.12it/s]


    [1/2] Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [01:12<00:00, 25.41it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.383, R@10: 0.509, MRR@10: 0.787, nDCG@10: 0.630, Time: 754.3s
    [2/2] Reranker: cross-encoder/ms-marco-MiniLM-L-12-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [02:21<00:00, 13.07it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.383, R@10: 0.509, MRR@10: 0.788, nDCG@10: 0.632, Time: 822.1s



Chunking: 100%|██████████| 28001/28001 [00:43<00:00, 644.47it/s]


Created 64042 chunks from 28001 documents


Map: 100%|██████████| 64042/64042 [06:43<00:00, 158.74 examples/s]
100%|██████████| 501/501 [00:00<00:00, 2850.47it/s]


    [1/2] Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [02:16<00:00, 13.50it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.383, R@10: 0.509, MRR@10: 0.787, nDCG@10: 0.631, Time: 569.2s
    [2/2] Reranker: cross-encoder/ms-marco-MiniLM-L-12-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [04:22<00:00,  7.02it/s]


    Deduplicated: (4719, 100) -> unique docs per query
      P@10: 0.383, R@10: 0.509, MRR@10: 0.791, nDCG@10: 0.632, Time: 695.0s


Reranking comparison complete! Results saved to results.csv
