In [15]:
from rag.config import Settings
%load_ext autoreload
%autoreload 2

import os
import gc
import torch
import faiss
import numpy as np
import pandas as pd
from datasets import load_dataset, Dataset
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder
from FlagEmbedding import FlagReranker
from time import time
from tqdm import tqdm

from rag.embeddings import LocalEmbedder
from rag.utils import get_metrics, embed_dataset

# Load the two configurations of the rag-mini-bioasq benchmark: the passage corpus
# to retrieve from, and the test questions with their gold passage IDs.
# Passing `split=` makes load_dataset return that split's Dataset directly.
doc_ds = load_dataset("rag-datasets/rag-mini-bioasq", name="text-corpus", split="passages")
query_ds = load_dataset("rag-datasets/rag-mini-bioasq", name="question-answer-passages", split="test")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
# Precompute query information
import ast  # cell-local import: only needed to parse the serialized gold-ID lists

# Question strings, in the same order as the rows of `query_ds`.
queries = np.array(query_ds['question'])

# Each `relevant_passage_ids` entry is parsed from its string form into an array of
# gold passage IDs. SECURITY: this used to call eval() on text that comes from a
# downloaded dataset, which would execute arbitrary code embedded in the data.
# ast.literal_eval only accepts Python literals (lists, numbers, strings, ...) and
# raises ValueError/SyntaxError on anything else.
qrels = [np.array(ast.literal_eval(gold)) for gold in query_ds['relevant_passage_ids']]

# Number of gold passages per query.
qrels_counts = [len(gold_ids) for gold_ids in qrels]

In [17]:
# Define chunking function
def chunk_documents(dataset, chunker, text_col='passage', id_col='id'):
    """
    Split every document of `dataset` into chunks and return them as a new Dataset.

    Args:
        dataset: Iterable of row mappings that also supports len()
        chunker: Object exposing split_text(text) -> list of chunk strings
        text_col: Key of the row field holding the text to split
        id_col: Key of the row field holding the document identifier

    Returns:
        Dataset with one row per chunk and the columns 'passage' (chunk text),
        'parent_id' (identifier of the source document) and 'chunk_id'
        (0-based position of the chunk within its document).
    """
    records = []
    # Iterating through tqdm advances the bar once per document.
    for row in tqdm(dataset, total=len(dataset), desc='Chunking'):
        body = row[text_col]
        source_id = row[id_col]
        records.extend(
            {'passage': piece, 'parent_id': source_id, 'chunk_id': position}
            for position, piece in enumerate(chunker.split_text(body))
        )
    return Dataset.from_list(records)

In [18]:
# Define reranking function (optimized for GPU)
def rerank_results(queries, retrieved_passages, retrieved_ids, reranker, batch_size=256):
    """
    Rerank retrieved passages using a reranker model (optimized for GPU).

    All (query, passage) pairs are flattened into a single list so the model can
    work through them in chunks of `batch_size`; each query's candidates are then
    reordered by descending relevance score.

    Args:
        queries: Sequence of query strings, one per row of `retrieved_passages`
        retrieved_passages: 2D array of retrieved passage texts [n_queries, k]
        retrieved_ids: 2D array of retrieved passage IDs [n_queries, k]
        reranker: Reranker model. An object with a callable `compute_score`
            (e.g. FlagReranker) is scored through
            `compute_score(pairs, batch_size=..., normalize=True)`; anything else
            (e.g. a sentence-transformers CrossEncoder) through
            `predict(pairs, batch_size=..., show_progress_bar=True)`.
        batch_size: Batch size for reranking (larger is better for GPU utilization)

    Returns:
        reranked_ids: 2D array of reranked passage IDs [n_queries, k], same dtype
        as `retrieved_ids`. Candidates with equal scores keep their original
        (retrieval) order.

    Raises:
        ValueError: If `retrieved_passages` is not 2D with one row per query, or
            if `retrieved_ids` does not have the same shape as `retrieved_passages`.
    """
    retrieved_passages = np.asarray(retrieved_passages)
    retrieved_ids = np.asarray(retrieved_ids)
    n_queries = len(queries)

    # Fail loudly on misaligned inputs: a row-count mismatch would otherwise either
    # crash deep inside reshape or silently leave all-zero ID rows in the output.
    if retrieved_passages.ndim != 2 or retrieved_passages.shape[0] != n_queries:
        raise ValueError(
            f"retrieved_passages must have shape [n_queries={n_queries}, k], "
            f"got {retrieved_passages.shape}"
        )
    if retrieved_ids.shape != retrieved_passages.shape:
        raise ValueError(
            f"retrieved_ids shape {retrieved_ids.shape} does not match "
            f"retrieved_passages shape {retrieved_passages.shape}"
        )

    k = retrieved_passages.shape[1]
    if n_queries == 0 or k == 0:
        # Nothing to score; do not call the model with an empty pair list.
        return retrieved_ids.copy()

    # Flatten all query-passage pairs (row-major) to maximize GPU utilization
    all_pairs = [
        [query, passage]
        for query, passages in zip(queries, retrieved_passages)
        for passage in passages
    ]

    # A single call; the model itself splits the pairs into batches of `batch_size`.
    print(f"    Scoring {len(all_pairs):,} pairs with batch_size={batch_size}...")

    # Pick the scoring method by capability rather than by class, so the function
    # does not depend on which reranker library happens to be imported.
    if callable(getattr(reranker, "compute_score", None)):
        # FlagReranker-style API: expects a list of [query, doc] pairs
        raw_scores = reranker.compute_score(all_pairs, batch_size=batch_size, normalize=True)
    else:
        # CrossEncoder-style API from sentence-transformers
        raw_scores = reranker.predict(all_pairs, batch_size=batch_size, show_progress_bar=True)

    # Reshape scores back to [n_queries, k]; the float cast also covers a scalar
    # or plain-list return value and makes the negation below safe.
    scores_2d = np.asarray(raw_scores, dtype=float).reshape(n_queries, k)

    # Descending, stable sort per query. Negating and sorting stably keeps the
    # retriever's order among tied scores, whereas reversing an ascending argsort
    # would flip tied candidates.
    order = np.argsort(-scores_2d, axis=1, kind="stable")
    return np.take_along_axis(retrieved_ids, order, axis=1)

In [19]:
# Experiment configuration for the embedder x reranker comparison below.

# Embedder models to test (each name is passed to LocalEmbedder).
# Commented-out entries are disabled candidates; uncomment to include them.
embedder_models = [
    # "all-MiniLM-L6-v2",
    # "all-MiniLM-L12-v2",
    "BAAI/bge-small-en-v1.5",
    # "BAAI/bge-base-en-v1.5",
    # "BAAI/bge-large-en-v1.5",
]

# Reranker models to test. Names containing "BAAI" or "bge-reranker" are loaded
# with FlagReranker in the loop below, all others with CrossEncoder.
reranker_models = [
    # "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "cross-encoder/ms-marco-MiniLM-L-12-v2",
    # "BAAI/bge-reranker-base",
    # "BAAI/bge-reranker-large",
]

# Chunking parameters. The splitter below measures length with the embedder's
# tokenizer, so both values are in tokens, not characters.
chunk_size = 256
chunk_overlap = 50

# Initial retrieval depth: number of chunks fetched per query from the FAISS
# index and handed to the reranker.
initial_k = 100

# Main experiment: for every embedder, chunk + embed the corpus, retrieve the top
# `initial_k` chunks per query with FAISS, then rerank them with every reranker and
# append one metrics row per (embedder, reranker) pair to results.csv.
for emb_idx, embedder_name in enumerate(embedder_models):
    print("=" * 20, f"[{emb_idx + 1}/{len(embedder_models)}] Embedder: {embedder_name}", "=" * 20)
    
    try:
        # Create embedder
        embedder = LocalEmbedder(embedder_name, device="cuda")
        
        # Create tokenizer-aware chunker: chunk length is measured with the
        # embedder's own tokenizer instead of in characters.
        tokenizer = embedder.model.tokenizer
        def count_tokens(text):
            # NOTE(review): encode() presumably also counts the tokenizer's special
            # tokens, making chunks slightly shorter than chunk_size — confirm.
            return len(tokenizer.encode(text))
        
        chunker = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=count_tokens,
            separators=["\n\n", "\n", ". ", " ", ""],
        )
        
        # Chunk documents
        chunked_ds = chunk_documents(doc_ds, chunker)
        print(f"Created {len(chunked_ds)} chunks from {len(doc_ds)} documents")
        
        # Embed chunked documents and queries (timed together as embed_time;
        # chunking above is not included in the timing)
        embed_start = time()
        chunked_ds = embed_dataset(chunked_ds, embedder, column='passage')
        query_ds_embedded = embed_dataset(query_ds, embedder, column='question')
        embed_time = time() - embed_start
        
        # Create mapping from chunk index to parent document ID, plus the chunk
        # texts, as arrays so they can be fancy-indexed with the FAISS results
        index_to_parent_id = np.array(chunked_ds['parent_id'])
        chunk_passages = np.array(chunked_ds['passage'])
        
    except Exception as e:
        # Best effort: report the failure, release GPU memory and move on to the
        # next embedder. Only the exception message is printed, not the traceback.
        print(f"Failed to embed with {embedder_name}: {e}")
        # At notebook top level locals() is the global namespace, so this checks
        # whether `embedder` was bound before the failure.
        if 'embedder' in locals():
            del embedder
        gc.collect()
        torch.cuda.empty_cache()
        continue
    
    # Similarity metric for the FAISS index (only inner product is evaluated here;
    # 'L2' is the other value the expression below understands)
    faiss_metric = 'IP'


    # Add FAISS index (exact, brute-force 'Flat' index over the embeddings)
    chunked_ds.add_faiss_index(
        column='embedding',
        string_factory='Flat',
        metric_type=faiss.METRIC_L2 if faiss_metric == 'L2' else faiss.METRIC_INNER_PRODUCT,
        batch_size=128,
    )

    # Retrieve top-k candidates
    res = chunked_ds.get_index('embedding').search_batch(
        np.array(query_ds_embedded['embedding']), k=initial_k
    )

    # Map chunk indices to parent document IDs and chunk texts
    # NOTE(review): several chunks of the same document can be retrieved for one
    # query, so a row of retrieved_parent_ids may contain repeated IDs — confirm
    # that get_metrics treats duplicates in the top-k as intended.
    retrieved_chunk_ids = res.total_indices
    retrieved_parent_ids = index_to_parent_id[retrieved_chunk_ids]
    retrieved_passages = chunk_passages[retrieved_chunk_ids]

    # Test each reranker
    for reranker_idx, reranker_name in enumerate(reranker_models):
        print(f"    [{reranker_idx + 1}/{len(reranker_models)}] Reranker: {reranker_name}")

        try:
            # Load reranker with correct library
            # BAAI rerankers use FlagEmbedding, others use sentence-transformers CrossEncoder
            if "BAAI" in reranker_name or "bge-reranker" in reranker_name:
                reranker = FlagReranker(reranker_name, use_fp16=True)
            else:
                reranker = CrossEncoder(reranker_name, device="cuda")

            # Rerank results (`queries` comes from the query-precompute cell and
            # must be in the same order as the rows of query_ds_embedded)
            rerank_start = time()
            reranked_parent_ids = rerank_results(
                queries, retrieved_passages, retrieved_parent_ids, reranker
            )
            rerank_time = time() - rerank_start

            # Reported time = embedding + reranking only; chunking, FAISS indexing
            # and search are not part of it.
            total_time = embed_time + rerank_time

            # Compute metrics for different k values, merging each cutoff's
            # metrics into one flat dict
            metrics = {}
            for k in [1, 3, 5, 10]:
                reranked_top_k = reranked_parent_ids[:, :k]
                metrics = {
                    **metrics,
                    **get_metrics(reranked_top_k, query_ds, k),
                }

            # Save results
            res_dict = {
                'model': embedder_name,
                'faiss_metric': faiss_metric,
                'chunked': True,
                'chunk_size': chunk_size,
                'chunk_overlap': chunk_overlap,
                'rerank_model': reranker_name,
                **{k: round(v, 3) for k, v in metrics.items()},
                "elapsed_time": round(total_time, 1),
            }

            # Append one row to the CSV; the header is written only when the file
            # does not exist yet or is empty.
            res_df = pd.DataFrame([res_dict])
            csv_path = "results.csv"
            append = os.path.exists(csv_path) and os.path.getsize(csv_path) > 0
            res_df.to_csv(csv_path, mode='a', header=not append, index=False)

            # Print summary
            print(f"      P@10: {metrics['P@10']:.3f}, R@10: {metrics['R@10']:.3f}, "
                  f"MRR@10: {metrics['MRR@10']:.3f}, nDCG@10: {metrics['nDCG@10']:.3f}, "
                  f"Time: {total_time:.1f}s")

            # Clean up reranker so the next model starts with free GPU memory
            del reranker
            gc.collect()
            torch.cuda.empty_cache()

        except Exception as e:
            # Best effort: report, free GPU memory, and continue with the next
            # reranker. Only the exception message is printed, not the traceback.
            print(f"      Failed to rerank with {reranker_name}: {e}")
            if 'reranker' in locals():
                del reranker
            gc.collect()
            torch.cuda.empty_cache()
            continue

    # Remove the FAISS index now that all rerankers have used its search results
    chunked_ds.drop_index('embedding')
    
    # Clean up embedder
    del embedder
    gc.collect()
    torch.cuda.empty_cache()
    print()

print("\nReranking comparison complete! Results saved to results.csv")



Chunking: 100%|██████████| 40221/40221 [00:38<00:00, 1048.35it/s]


Created 76262 chunks from 40221 documents


Map: 100%|██████████| 76262/76262 [01:03<00:00, 1202.47 examples/s]
Map: 100%|██████████| 4719/4719 [00:02<00:00, 2134.04 examples/s]
100%|██████████| 596/596 [00:00<00:00, 4680.02it/s]


    [1/1] Reranker: cross-encoder/ms-marco-MiniLM-L-12-v2
    Scoring 471,900 pairs with batch_size=256...


Batches: 100%|██████████| 1844/1844 [02:37<00:00, 11.67it/s]


      P@10: 0.443, R@10: 0.681, MRR@10: 0.787, nDCG@10: 0.734, Time: 244.6s


Reranking comparison complete! Results saved to results.csv
