In [1]:
%load_ext autoreload
%autoreload 2

# Standard library
import ast
import gc
import os
from time import time

# Third-party (torch is deliberately kept ahead of faiss, as in the original order)
import torch
import faiss
import numpy as np
import pandas as pd
from datasets import load_dataset, Dataset
from langchain_text_splitters import RecursiveCharacterTextSplitter
from tqdm import tqdm

# Project-local
from rag.embeddings import LocalEmbedder
from rag.utils import get_metrics, embed_dataset

# Load both configurations of the same Hugging Face dataset repository:
#   - "text-corpus" / split 'passages'           -> documents to index
#   - "question-answer-passages" / split 'test'  -> evaluation queries
bioasq_repo = "rag-datasets/rag-mini-bioasq"
doc_ds = load_dataset(bioasq_repo, "text-corpus")['passages']
query_ds = load_dataset(bioasq_repo, "question-answer-passages")['test']

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Precompute query information
queries = np.array(query_ds['question'])

# Each 'relevant_passage_ids' entry is a string that has to be parsed into a
# Python value (presumably a list of passage IDs -- TODO confirm against the
# dataset card). `ast.literal_eval` only accepts literals, so downloaded
# dataset content can never execute code the way a bare `eval` would allow.
qrels = [
    np.array(ast.literal_eval(gold))
    for gold in query_ds['relevant_passage_ids']
]

# Number of relevant passages per query, aligned with `queries` / `qrels`.
qrels_counts = [len(s) for s in qrels]

In [3]:
# Define chunking function
def chunk_documents(dataset, chunker, text_col='passage', id_col='id'):
    """Split every document of ``dataset`` into chunks and return them as a new Dataset.

    Args:
        dataset: Sized iterable of dict-like rows; each row is indexed with
            ``text_col`` and ``id_col``.
        chunker: Object exposing ``split_text(text)`` that returns the chunks
            of a single text in order.
        text_col: Key of the text to split in each input row.
        id_col: Key of the document identifier in each input row.

    Returns:
        ``Dataset.from_list`` of one record per chunk with the keys
        ``'passage'`` (chunk text), ``'parent_id'`` (identifier of the source
        row) and ``'chunk_id'`` (0-based position of the chunk within its
        document). The output text key is always ``'passage'``, whatever
        ``text_col`` is.
    """
    records = []
    progress = tqdm(total=len(dataset), desc='Chunking')
    for row in dataset:
        body = row[text_col]
        source_id = row[id_col]
        # One record per chunk, preserving the order returned by the chunker.
        records.extend(
            {'passage': piece, 'parent_id': source_id, 'chunk_id': position}
            for position, piece in enumerate(chunker.split_text(body))
        )
        progress.update(1)
    progress.close()
    return Dataset.from_list(records)


In [4]:
# Embedder models to compare
embedder_models = [
    "all-MiniLM-L6-v2",
    "all-MiniLM-L12-v2",
    "all-mpnet-base-v2",
    "nomic-ai/nomic-embed-text-v1.5",
    "BAAI/bge-small-en-v1.5",
    "BAAI/bge-base-en-v1.5",
    "BAAI/bge-large-en-v1.5",
    "Alibaba-NLP/gte-multilingual-base",
    "Snowflake/snowflake-arctic-embed-l-v2.0",
    "jinaai/jina-embeddings-v3",
    "intfloat/e5-base-v2",
    "BAAI/bge-m3",
    "Lajavaness/bilingual-embedding-base",
    "Qwen/Qwen3-Embedding-0.6B",
]

# Chunking parameters. Lengths are measured by `count_tokens` below, i.e. with
# each model's own tokenizer, so the same settings can yield a different
# number of chunks per model.
chunk_size = 256
chunk_overlap = 50

# Every (model, FAISS metric) combination appends one row to this file.
csv_path = "results.csv"

for i, model_name in enumerate(embedder_models):
    print("=" * 20, f"[{i + 1}/{len(embedder_models)}]", "=" * 20)
    print(f"Model: {model_name}")

    # Drop the previous iteration's model *before* loading the next one.
    # Otherwise the old model stays referenced by `embedder` while the new one
    # is being constructed and both are held at the same time.
    embedder = None
    gc.collect()
    torch.cuda.empty_cache()

    try:
        # Create embedder
        embedder = LocalEmbedder(model_name, device="cuda")

        # Create tokenizer-aware chunker: chunk lengths are counted in tokens
        # of the model being evaluated, not in characters.
        # NOTE(review): `encode` may include special tokens in the count,
        # depending on the tokenizer -- confirm if exact budgets matter.
        tokenizer = embedder.model.tokenizer
        def count_tokens(text):
            return len(tokenizer.encode(text))

        chunker = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=count_tokens,
            separators=["\n\n", "\n", ". ", " ", ""],
        )

        # Chunk documents
        chunked_ds = chunk_documents(doc_ds, chunker)
        print(f"Created {len(chunked_ds)} chunks from {len(doc_ds)} documents")

        # Timed region covers embedding of chunks and queries only
        # (chunking and retrieval are excluded).
        start_time = time()
        chunked_ds = embed_dataset(chunked_ds, embedder, column='passage')
        query_ds = embed_dataset(query_ds, embedder, column='question')
        elapsed_time = time() - start_time

        # Create mapping from chunk index to parent document ID
        index_to_parent_id = np.array(chunked_ds['parent_id'])

    except Exception as e:
        # Best-effort benchmark: report the failure, free whatever was loaded
        # and move on to the next model.
        print(f"Failed to embed {model_name}: {e}")
        if 'embedder' in locals():
            del embedder
        gc.collect()
        torch.cuda.empty_cache()
        continue

    # The query embeddings do not depend on the FAISS metric or on k, so
    # materialise them once per model instead of once per search.
    query_embeddings = np.array(query_ds['embedding'])

    # Test with both FAISS metrics
    for faiss_metric in ["IP", "L2"]:
        # Add FAISS index
        chunked_ds.add_faiss_index(
            column='embedding',
            string_factory='Flat',
            metric_type=faiss.METRIC_L2 if faiss_metric == 'L2' else faiss.METRIC_INNER_PRODUCT,
            batch_size=128,
        )
        faiss_index = chunked_ds.get_index('embedding')

        metrics = {}

        # Compute metrics for different k values
        for k in [1, 3, 5, 10]:
            res = faiss_index.search_batch(query_embeddings, k=k)
            # Map chunk indices to parent document IDs.
            # NOTE(review): several chunks of one document can be retrieved
            # for the same query, so a row of `retrieved_ids` may repeat a
            # parent ID -- confirm `get_metrics` treats duplicates as intended.
            retrieved_ids = index_to_parent_id[res.total_indices]
            metrics.update(get_metrics(retrieved_ids, query_ds, k))

        # Save results. The chunking configuration actually used is recorded
        # (previously written as None), so chunked runs with different
        # settings stay distinguishable in the CSV.
        res_dict = {
            'model': model_name,
            'faiss_metric': faiss_metric,
            'chunked': True,
            'chunk_size': chunk_size,
            'chunk_overlap': chunk_overlap,
            'rerank_model': None,
            **{name: round(value, 3) for name, value in metrics.items()},
            "elapsed_time": round(elapsed_time, 1),
        }

        res_df = pd.DataFrame([res_dict])
        # Write the header only when the file is new or empty.
        append = os.path.exists(csv_path) and os.path.getsize(csv_path) > 0
        res_df.to_csv(csv_path, mode='a', header=not append, index=False)

        # Remove FAISS index for next metric
        chunked_ds.drop_index('embedding')

    # Print summary. `metrics` is rebuilt for every FAISS metric, so these
    # numbers are those of the last one evaluated ("L2").
    print(f"P@10    {metrics['P@10']:.3f}")
    print(f"R@10    {metrics['R@10']:.3f}")
    print(f"MRR@10  {metrics['MRR@10']:.3f}")
    print(f"nDCG@10 {metrics['nDCG@10']:.3f}")
    print(f"Time: {elapsed_time:.1f}s")
    print()

# Clean up
gc.collect()
torch.cuda.empty_cache()

print("Comparison complete! Results saved to results.csv")

Model: all-MiniLM-L6-v2


Chunking: 100%|██████████| 40221/40221 [00:36<00:00, 1094.15it/s]


Created 76262 chunks from 40221 documents


Map: 100%|██████████| 76262/76262 [00:40<00:00, 1893.54 examples/s]
Map: 100%|██████████| 4719/4719 [00:01<00:00, 3093.77 examples/s]
100%|██████████| 596/596 [00:00<00:00, 4391.82it/s]
100%|██████████| 596/596 [00:00<00:00, 4627.39it/s]


P@10    0.356
R@10    0.538
MRR@10  0.651
nDCG@10 0.573
Time: 56.6s

Model: all-MiniLM-L12-v2


Chunking: 100%|██████████| 40221/40221 [00:37<00:00, 1075.73it/s]


Created 76262 chunks from 40221 documents


Map: 100%|██████████| 76262/76262 [00:51<00:00, 1495.24 examples/s]
Map: 100%|██████████| 4719/4719 [00:02<00:00, 2068.20 examples/s]
100%|██████████| 596/596 [00:00<00:00, 4358.89it/s]
100%|██████████| 596/596 [00:00<00:00, 4408.07it/s]


P@10    0.331
R@10    0.496
MRR@10  0.626
nDCG@10 0.535
Time: 68.9s

Model: all-mpnet-base-v2


Chunking: 100%|██████████| 40221/40221 [00:35<00:00, 1139.28it/s]


Created 76262 chunks from 40221 documents


Map: 100%|██████████| 76262/76262 [01:47<00:00, 710.80 examples/s]
Map: 100%|██████████| 4719/4719 [00:03<00:00, 1458.66 examples/s]
100%|██████████| 596/596 [00:00<00:00, 2645.10it/s]
100%|██████████| 596/596 [00:00<00:00, 2828.89it/s]


P@10    0.336
R@10    0.508
MRR@10  0.609
nDCG@10 0.538
Time: 127.0s

Model: nomic-ai/nomic-embed-text-v1.5
Failed to embed nomic-ai/nomic-embed-text-v1.5: nomic-ai/nomic-bert-2048 You can inspect the repository content at https://hf.co/nomic-ai/nomic-embed-text-v1.5.
Please pass the argument `trust_remote_code=True` to allow custom code to be run.
Model: BAAI/bge-small-en-v1.5


Chunking: 100%|██████████| 40221/40221 [00:39<00:00, 1013.21it/s]


Created 76262 chunks from 40221 documents


Map: 100%|██████████| 76262/76262 [01:03<00:00, 1195.04 examples/s]
Map: 100%|██████████| 4719/4719 [00:02<00:00, 2089.07 examples/s]
100%|██████████| 596/596 [00:00<00:00, 4142.60it/s]
100%|██████████| 596/596 [00:00<00:00, 4265.76it/s]


P@10    0.419
R@10    0.652
MRR@10  0.750
nDCG@10 0.694
Time: 81.7s

Model: BAAI/bge-base-en-v1.5


Chunking: 100%|██████████| 40221/40221 [00:37<00:00, 1082.70it/s]


Created 76262 chunks from 40221 documents


Map: 100%|██████████| 76262/76262 [01:39<00:00, 765.44 examples/s]
Map: 100%|██████████| 4719/4719 [00:03<00:00, 1557.03 examples/s]
100%|██████████| 596/596 [00:00<00:00, 2768.41it/s]
100%|██████████| 596/596 [00:00<00:00, 2839.20it/s]


P@10    0.437
R@10    0.687
MRR@10  0.757
nDCG@10 0.721
Time: 119.3s

Model: BAAI/bge-large-en-v1.5


Chunking: 100%|██████████| 40221/40221 [00:37<00:00, 1075.24it/s]


Created 76262 chunks from 40221 documents


Map: 100%|██████████| 76262/76262 [04:04<00:00, 311.76 examples/s]
Map: 100%|██████████| 4719/4719 [00:06<00:00, 747.42 examples/s]
100%|██████████| 596/596 [00:00<00:00, 2398.06it/s]
100%|██████████| 596/596 [00:00<00:00, 2454.25it/s]


P@10    0.442
R@10    0.699
MRR@10  0.760
nDCG@10 0.731
Time: 267.8s

Model: Alibaba-NLP/gte-multilingual-base
Failed to embed Alibaba-NLP/gte-multilingual-base: Alibaba-NLP/new-impl You can inspect the repository content at https://hf.co/Alibaba-NLP/gte-multilingual-base.
Please pass the argument `trust_remote_code=True` to allow custom code to be run.
Model: Snowflake/snowflake-arctic-embed-l-v2.0


Chunking: 100%|██████████| 40221/40221 [00:38<00:00, 1033.92it/s]


Created 82827 chunks from 40221 documents


Map: 100%|██████████| 82827/82827 [04:29<00:00, 307.75 examples/s]
Map: 100%|██████████| 4719/4719 [00:06<00:00, 705.74 examples/s]
100%|██████████| 648/648 [00:00<00:00, 2440.32it/s]
100%|██████████| 648/648 [00:00<00:00, 2466.21it/s]


P@10    0.402
R@10    0.644
MRR@10  0.711
nDCG@10 0.667
Time: 293.9s

Model: jinaai/jina-embeddings-v3
Failed to embed jinaai/jina-embeddings-v3: No module named 'custom_st'
Model: intfloat/e5-base-v2


Chunking: 100%|██████████| 40221/40221 [00:45<00:00, 883.09it/s] 


Created 76262 chunks from 40221 documents


Map: 100%|██████████| 76262/76262 [01:40<00:00, 757.66 examples/s]
Map: 100%|██████████| 4719/4719 [00:03<00:00, 1533.29 examples/s]
100%|██████████| 596/596 [00:00<00:00, 3084.96it/s]
100%|██████████| 596/596 [00:00<00:00, 3764.10it/s]


P@10    0.422
R@10    0.662
MRR@10  0.752
nDCG@10 0.700
Time: 120.5s

Model: BAAI/bge-m3


Chunking: 100%|██████████| 40221/40221 [00:38<00:00, 1055.09it/s]


Created 82827 chunks from 40221 documents


Map: 100%|██████████| 82827/82827 [04:24<00:00, 312.73 examples/s]
Map: 100%|██████████| 4719/4719 [00:06<00:00, 693.38 examples/s]
100%|██████████| 648/648 [00:00<00:00, 3180.72it/s]
100%|██████████| 648/648 [00:00<00:00, 3337.51it/s]


P@10    0.415
R@10    0.663
MRR@10  0.745
nDCG@10 0.694
Time: 290.1s

Model: Lajavaness/bilingual-embedding-base
Failed to embed Lajavaness/bilingual-embedding-base: dangvantuan/bilingual_impl You can inspect the repository content at https://hf.co/Lajavaness/bilingual-embedding-base.
Please pass the argument `trust_remote_code=True` to allow custom code to be run.
Model: Qwen/Qwen3-Embedding-0.6B


Chunking: 100%|██████████| 40221/40221 [00:46<00:00, 871.63it/s] 


Created 75905 chunks from 40221 documents


Map: 100%|██████████| 75905/75905 [06:38<00:00, 190.68 examples/s]
Map: 100%|██████████| 4719/4719 [00:10<00:00, 461.84 examples/s]
100%|██████████| 594/594 [00:00<00:00, 1337.34it/s]
100%|██████████| 594/594 [00:00<00:00, 2221.51it/s]


P@10    0.405
R@10    0.639
MRR@10  0.715
nDCG@10 0.670
Time: 426.3s

Comparison complete! Results saved to results.csv
