# Chunking + Reranking Experiments

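This notebook combines chunking strategies with reranking to find the best-performing RAG retrieval pipeline on BioASQ. For each configuration, the pipeline runs these steps:

1. Chunk, embed, and index the documents with FAISS.
2. Rerank the top-100 retrieved chunks per query with a cross-encoder.
3. Collapse the reranked chunks to unique parent documents before computing metrics.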

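## Experimental Design:
- **Chunkers**: Recursive, Semantic
- **Chunk sizes**: 64, 128 tokens
- **Overlaps**: 0, 50, 100 tokens (combinations where overlap ≥ chunk size are skipped)
- **Semantic thresholds**: 0.01, 0.02, 0.05 (semantic chunker only)
- **Embedders**: all-MiniLM-L6-v2, BGE-small (BGE-base is listed but commented out)
- **Rerankers**: MiniLM-L6, BGE-reranker-base (MiniLM-L12 is listed but commented out)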

All experiments tracked in MLflow for comparison.

In [1]:
import ast
import gc
import logging
import warnings
from time import time

import numpy as np
import faiss
import torch
from datasets import load_dataset, disable_caching
from sentence_transformers import CrossEncoder
from FlagEmbedding import FlagReranker

from rag.tracking import ExperimentTracker
from rag.utils import embed_dataset, get_metrics
from rag.embeddings import LocalEmbedder
from rag.ingestion.chunker import RecursiveChunker, SemanticChunker

# Suppress warnings and logs
warnings.filterwarnings('ignore', message='.*Model.*was trained with spaCy.*')
warnings.filterwarnings('ignore', message='.*rule-based lemmatizer did not find POS annotation.*')
logging.getLogger('langchain_text_splitters').setLevel(logging.ERROR)

print("✓ Setup complete")

  from .autonotebook import tqdm as notebook_tqdm


✓ Setup complete


In [2]:
# Load datasets
print("Loading BioASQ dataset...")
doc_ds = load_dataset("rag-datasets/rag-mini-bioasq", "text-corpus", split="passages")
# Drop passages whose text is the literal string 'nan'
doc_ds = doc_ds.filter(lambda row: row['passage'] != 'nan')
query_ds = load_dataset("rag-datasets/rag-mini-bioasq", "question-answer-passages", split="test")

print(f"✓ Loaded {len(doc_ds):,} documents, {len(query_ds):,} queries")

# Precompute query texts and gold passage IDs
queries = np.array(query_ds['question'])
# `relevant_passage_ids` holds a stringified Python literal (presumably a list of
# ints). Parse it with ast.literal_eval rather than eval() so dataset content can
# never execute code.
qrels = [np.array(ast.literal_eval(gold)) for gold in query_ds['relevant_passage_ids']]
qrels_counts = [len(s) for s in qrels]

# Disable the `datasets` on-disk transform cache for the rest of the notebook
disable_caching()

Loading BioASQ dataset...
✓ Loaded 28,001 documents, 4,719 queries


In [3]:
def deduplicate_retrieved_docs(retrieved_ids_all, k, pad_value=0):
    """
    Deduplicate document IDs per query, keeping only first (highest-ranked) occurrence.

    `pad_value` doubles as a sentinel. Input IDs equal to it are skipped, and
    rows with fewer than `k` unique IDs are right-padded with it.

    Args:
        retrieved_ids_all: (n_queries, n_retrieved) array of document IDs, best first
        k: Number of unique documents to keep per query (k >= 0)
        pad_value: Sentinel/filler ID (default 0, matching the previous behaviour)

    Returns:
        (n_queries, k) array of unique document IDs
    """
    deduped = []
    for query_results in retrieved_ids_all:
        seen = set()
        unique_docs = []
        for doc_id in query_results:
            # Check the budget *before* adding, so k == 0 keeps nothing.
            if len(unique_docs) >= k:
                break
            if doc_id == pad_value or doc_id in seen:
                continue
            seen.add(doc_id)
            unique_docs.append(doc_id)
        # Pad short rows so every query has exactly k entries
        unique_docs.extend([pad_value] * (k - len(unique_docs)))
        deduped.append(unique_docs)
    # reshape keeps the (n_queries, k) shape even for zero queries or k == 0
    return np.array(deduped).reshape(len(deduped), k)


def rerank_results(queries, retrieved_passages, retrieved_ids, reranker, batch_size=256):
    """
    Rerank retrieved passages using a reranker model (GPU optimized).

    All (query, passage) pairs are scored in one batched call. Each query's
    candidates are then ordered by descending score. The sort is stable, so
    equal scores keep their original retrieval order instead of being reversed.

    Args:
        queries: Sequence of query strings, length n_queries
        retrieved_passages: 2D array of passage texts [n_queries, k]
        retrieved_ids: 2D array of passage IDs [n_queries, k], aligned with retrieved_passages
        reranker: FlagReranker (scored via `compute_score`) or a CrossEncoder-like
            object exposing `predict(pairs, batch_size=..., show_progress_bar=...)`
        batch_size: Batch size for GPU processing

    Returns:
        reranked_ids: 2D array of reranked passage IDs [n_queries, k], best first
    """
    retrieved_ids = np.asarray(retrieved_ids)
    n_queries = len(queries)
    if n_queries == 0:
        # Nothing to score; don't invoke the model with an empty batch
        return retrieved_ids.copy()
    k = retrieved_passages.shape[1]

    # Query-major flattening: pairs i*k .. i*k + k - 1 belong to query i
    all_pairs = [
        [query, passage]
        for query, passages in zip(queries, retrieved_passages)
        for passage in passages
    ]

    # Score all pairs
    if isinstance(reranker, FlagReranker):
        all_scores = reranker.compute_score(all_pairs, batch_size=batch_size, normalize=True)
    else:
        all_scores = reranker.predict(all_pairs, batch_size=batch_size, show_progress_bar=False)
    # compute_score may return a list (or a bare float for one pair); predict an array
    scores_2d = np.asarray(all_scores, dtype=np.float64).reshape(n_queries, k)

    # Descending, stable sort. Ties keep first-stage order, and NaN scores sort
    # last rather than first.
    order = np.argsort(-scores_2d, axis=1, kind='stable')
    return np.take_along_axis(retrieved_ids, order, axis=1)

print("✓ Helper functions defined")

✓ Helper functions defined


In [None]:
# Tracker that every run in the grid logs to
tracker = ExperimentTracker('chunking-reranking-bioasq')

# ---- Search space ------------------------------------------------------------
# Bi-encoders for chunk/query embeddings (uncomment to widen the sweep)
embedder_models = [
    "sentence-transformers/all-MiniLM-L6-v2",
    "BAAI/bge-small-en-v1.5",
    # "BAAI/bge-base-en-v1.5",
]

# Second-stage rerankers applied to the retrieved candidates
reranker_models = [
    "cross-encoder/ms-marco-MiniLM-L-6-v2",
    # "cross-encoder/ms-marco-MiniLM-L-12-v2",
    "BAAI/bge-reranker-base",
]

chunkers = ['recursive', 'semantic']
chunk_sizes = [64, 128]
chunk_overlaps = [0, 50, 100]          # pairs with overlap >= size are skipped by the run loop
chunk_thresholds = [0.01, 0.02, 0.05]  # only consulted for the semantic chunker

# ---- Retrieval / reranking ---------------------------------------------------
initial_retrieve_k = 100   # FAISS candidates per query handed to the reranker
rerank_batch_size = 256
faiss_metric = 'IP'        # logged as a run parameter

_rule = '=' * 80
print("\n".join([
    "",
    _rule,
    "EXPERIMENT CONFIGURATION",
    _rule,
    f"Embedders: {len(embedder_models)}",
    f"Rerankers: {len(reranker_models)}",
    f"Chunkers: {chunkers}",
    f"Chunk sizes: {chunk_sizes}",
    f"Overlaps: {chunk_overlaps}",
    f"Semantic thresholds: {chunk_thresholds}",
    f"Initial retrieve: top-{initial_retrieve_k}",
    _rule,
    "",
]))

In [None]:
import itertools
import traceback

# Cut-offs reported for every run; deduplication only needs the deepest one
eval_ks = [1, 3, 5, 10]


def _faiss_metric_type(metric_name):
    """Translate the config label `faiss_metric` ('IP' or 'L2') into the FAISS metric constant.

    Raises:
        ValueError: for any other label.
    """
    metric_types = {
        'IP': faiss.METRIC_INNER_PRODUCT,
        'L2': faiss.METRIC_L2,
    }
    if metric_name not in metric_types:
        raise ValueError(
            f"Unsupported faiss_metric {metric_name!r}; expected one of {sorted(metric_types)}"
        )
    return metric_types[metric_name]


def _free_gpu_memory():
    """Collect unreachable objects, then release unused cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()


def _build_chunker(chunker_type, chunk_size, chunk_overlap, threshold, embedder_name):
    """Create the chunker for one grid point; any type other than 'recursive' is treated as semantic."""
    if chunker_type == 'recursive':
        return RecursiveChunker(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            embedder_model=embedder_name,
        )
    return SemanticChunker(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        embedder_model=embedder_name,
        threshold=threshold,
    )


def _load_reranker(reranker_name):
    """Load BGE rerankers via FlagEmbedding (fp16) and any other model as a CUDA CrossEncoder."""
    if "BAAI" in reranker_name or "bge-reranker" in reranker_name:
        return FlagReranker(reranker_name, use_fp16=True)
    return CrossEncoder(reranker_name, device="cuda")


def _compute_metrics(ranked_parent_ids, ks):
    """Merge `get_metrics` output for each cut-off k, evaluated on the first k parent docs per query."""
    metrics = {}
    for k in ks:
        metrics.update(get_metrics(ranked_parent_ids[:, :k], query_ds, k))
    return metrics


# Run experiments
for embedder_name in embedder_models:
    embedder_name_short = embedder_name.split('/')[-1]

    for chunker_type in chunkers:
        for chunk_size, chunk_overlap in itertools.product(chunk_sizes, chunk_overlaps):
            # Skip invalid combinations: the overlap must be smaller than the chunk
            if chunk_size <= chunk_overlap:
                continue

            # Thresholds only apply to semantic chunking
            thresholds = chunk_thresholds if chunker_type == 'semantic' else [None]

            for threshold in thresholds:
                # `is not None` so a legitimate threshold of 0.0 is not logged as 'none'
                threshold_label = threshold if threshold is not None else 'none'

                print(f"\n{'='*80}")
                print(f"Embedder: {embedder_name_short} | Chunker: {chunker_type} | "
                      f"Size: {chunk_size} | Overlap: {chunk_overlap} | Threshold: {threshold}")
                print(f"{'='*80}")

                # Pre-bind so the `finally` cleanup works wherever a step fails
                chunker = embedder = chunked_ds = query_ds_embedded = None
                try:
                    # ========== STEP 1: CHUNKING ==========
                    print("[1/5] Chunking documents...")
                    chunker = _build_chunker(
                        chunker_type, chunk_size, chunk_overlap, threshold, embedder_name
                    )
                    chunked_ds = chunker.chunk_dataset(doc_ds, text_col='passage', id_col='id')
                    chunked_ds = chunked_ds.rename_column('doc_id', 'parent_id')
                    chunked_ds = chunked_ds.rename_column('text', 'passage')
                    print(f"  ✓ Created {len(chunked_ds):,} chunks from {len(doc_ds):,} documents")

                    # ========== STEP 2: EMBEDDING ==========
                    print("[2/5] Embedding...")
                    embedder = LocalEmbedder(embedder_name, device="cuda")

                    embed_start = time()
                    chunked_ds = embed_dataset(chunked_ds, embedder, column="passage")
                    query_ds_embedded = embed_dataset(query_ds, embedder, column="question")
                    embed_time = time() - embed_start
                    print(f"  ✓ Embedded in {embed_time:.1f}s")

                    # ========== STEP 3: FAISS RETRIEVAL ==========
                    print(f"[3/5] Retrieving top-{initial_retrieve_k}...")
                    # Use the configured metric so the logged `faiss_metric` matches the index
                    chunked_ds.add_faiss_index(
                        column='embedding',
                        string_factory='Flat',
                        metric_type=_faiss_metric_type(faiss_metric),
                        batch_size=128,
                    )

                    res = chunked_ds.get_index('embedding').search_batch(
                        np.array(query_ds_embedded['embedding']),
                        k=initial_retrieve_k
                    )

                    retrieved_chunk_ids = np.asarray(res.total_indices)
                    # FAISS pads missing hits with -1, which numpy indexing would
                    # silently map to the *last* chunk, so fail loudly instead.
                    if (retrieved_chunk_ids < 0).any():
                        raise ValueError(
                            f"FAISS returned fewer than {initial_retrieve_k} hits for some "
                            f"queries ({len(chunked_ds):,} chunks indexed)"
                        )

                    # Map chunk rows to their parent document IDs and chunk texts
                    index_to_parent_id = np.array(chunked_ds['parent_id'])
                    chunk_passages = np.array(chunked_ds['passage'])
                    retrieved_parent_ids = index_to_parent_id[retrieved_chunk_ids]
                    retrieved_passages = chunk_passages[retrieved_chunk_ids]
                    print(f"  ✓ Retrieved {initial_retrieve_k} chunks per query")

                    # Test each reranker
                    for reranker_name in reranker_models:
                        reranker_name_short = reranker_name.split('/')[-1]
                        print(f"\n  [4/5] Reranking with {reranker_name_short}...")

                        reranker = None
                        try:
                            reranker = _load_reranker(reranker_name)

                            # Rerank
                            rerank_start = time()
                            reranked_parent_ids = rerank_results(
                                queries, retrieved_passages, retrieved_parent_ids,
                                reranker, batch_size=rerank_batch_size
                            )
                            rerank_time = time() - rerank_start

                            # Collapse chunks to unique parent docs (best-ranked chunk wins).
                            # The helper's keyword is `k`; only the deepest cut-off is needed.
                            reranked_parent_ids_dedup = deduplicate_retrieved_docs(
                                reranked_parent_ids, k=max(eval_ks)
                            )
                            print(f"    ✓ Reranked + deduplicated in {rerank_time:.1f}s")

                            # ========== STEP 5: EVALUATION ==========
                            print("  [5/5] Calculating metrics...")
                            metrics = _compute_metrics(reranked_parent_ids_dedup, eval_ks)

                            # Only embedding + reranking are timed (not chunking/retrieval)
                            total_time = embed_time + rerank_time
                            metrics = {
                                **{name: round(value, 4) for name, value in metrics.items()},
                                "embed_time": round(embed_time, 1),
                                "rerank_time": round(rerank_time, 1),
                                "total_time": round(total_time, 1),
                                "num_chunks": len(chunked_ds),
                            }

                            # Log to MLflow
                            params = {
                                'embedder': embedder_name,
                                'reranker': reranker_name,
                                'chunker': chunker_type,
                                'chunk_size': chunk_size,
                                'chunk_overlap': chunk_overlap,
                                'chunk_threshold': threshold_label,
                                'faiss_metric': faiss_metric,
                                'initial_k': initial_retrieve_k,
                            }

                            run_name = (
                                f"{embedder_name_short}_{chunker_type}_"
                                f"cs{chunk_size}_ov{chunk_overlap}_"
                                f"thr{threshold_label}_"
                                f"{reranker_name_short}"
                            )

                            tags = {
                                'experiment_type': 'chunking+reranking',
                                'phase': 'exploration',
                                'dataset': 'bioasq-mini',
                                'embedder': embedder_name_short,
                                'reranker': reranker_name_short,
                                'chunker': chunker_type,
                            }

                            with tracker.start_run(run_name=run_name, tags=tags):
                                tracker.log_params(params)
                                tracker.log_metrics(metrics)

                            print("\n  RESULTS:")
                            print(f"    P@10:    {metrics['P@10']:.4f}")
                            print(f"    R@10:    {metrics['R@10']:.4f}")
                            print(f"    MRR@10:  {metrics['MRR@10']:.4f}")
                            print(f"    nDCG@10: {metrics['nDCG@10']:.4f}")
                            print(f"    Time:    {total_time:.1f}s")

                        except Exception as e:
                            print(f"    ✗ Reranker failed: {e}")
                            # Show the traceback so bugs (e.g. a bad call signature)
                            # aren't disguised as model failures
                            traceback.print_exc()
                        finally:
                            # Release the reranker on both success and failure
                            reranker = None
                            _free_gpu_memory()

                except Exception as e:
                    print(f"\n✗ FAILED: {e}")
                    traceback.print_exc()
                finally:
                    # Drop every per-config reference (the chunker may hold its own
                    # model) so memory is reclaimed even when a step above failed.
                    # Releasing the dataset also releases its attached FAISS index.
                    chunker = embedder = chunked_ds = query_ds_embedded = None
                    _free_gpu_memory()

print(f"\n{'='*80}")
print("✓ ALL EXPERIMENTS COMPLETED!")
print(f"{'='*80}")

## Analysis

View all results in MLflow UI at: http://localhost:5000

### Key Questions:
1. Does semantic chunking outperform recursive chunking with reranking?
2. What is the optimal chunk size when reranking is applied?
3. Does semantic clustering threshold affect reranking performance?
4. Which embedder + chunker + reranker combination works best?
5. How much does reranking improve over retrieval-only (compare with notebook 15_2)?

### Expected Insights:
- Reranking should improve MRR@10 significantly (better at finding the best match)
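- Smaller chunks might benefit more from reranking, because each of the fixed top-`initial_retrieve_k` candidates is a shorter, more focused span for the reranker to judge.
- Semantic chunking will not reduce the *number* of pairs scored, since every run reranks `initial_retrieve_k` chunks per query regardless of corpus size. It does change chunk length, and therefore the tokens per pair and the reranking time.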