In [1]:
import os
from datasets import load_from_disk, disable_caching, load_dataset
from rag.config import PROJECT_ROOT

# Passage corpus, minus rows whose text is the literal string 'nan'
# (presumably missing passages serialised as text).
doc_ds = (
    load_dataset("rag-datasets/rag-mini-bioasq", "text-corpus", split="passages")
    .filter(lambda passage: passage != 'nan', input_columns='passage')
)
# Evaluation questions together with their gold passage ids.
query_ds = load_dataset("rag-datasets/rag-mini-bioasq", "question-answer-passages", split="test")

# Turn off the datasets transform cache for everything that follows (e.g. the
# embedding maps). The filter above has already run, so it is not affected.
disable_caching()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from pathlib import Path
from rag.tracking import ExperimentTracker

# Smoke test of the tracking setup: records one run named 'test' in the
# 'test_experiment' experiment. The param/metric values are placeholders,
# not real results.
tracker = ExperimentTracker('test_experiment')

bioasq_dvc_file = PROJECT_ROOT / 'data' / 'rag-mini-bioasq.dvc'
with tracker.start_run(run_name='test'):
    tracker.log_dvc_data(bioasq_dvc_file)
    tracker.log_params({"chunk_size": 512})
    tracker.log_metrics({"P@10": 0.25})


[2025-10-26 08:41:34] [rag.tracking] [INFO] Tracking to: http://localhost:5000
[2025-10-26 08:41:34] [rag.tracking] [INFO] Experiment: test_experiment
[2025-10-26 08:41:34] [rag.tracking] [INFO] Logged DVC hash for rag-mini-bioasq.dvc: 0b9d04ec...
🏃 View run test at: http://localhost:5000/#/experiments/521976983953332643/runs/8c53a1d7929d46afa6ddbfd53bd19094
🧪 View experiment at: http://localhost:5000/#/experiments/521976983953332643


In [3]:
# Every embedder run in the sweep below is logged under this experiment.
EXPERIMENT_NAME = 'embedder-comparison-bioasq'
tracker = ExperimentTracker(EXPERIMENT_NAME)

[2025-10-26 08:41:34] [rag.tracking] [INFO] Tracking to: http://localhost:5000
[2025-10-26 08:41:34] [rag.tracking] [INFO] Experiment: embedder-comparison-bioasq


In [4]:
import numpy as np

import ast

# Lookup tables precomputed once and shared by every embedder run below.
# Passage id -> passage text.
doc_id_to_text = doc_ds.select_columns(['id', 'passage']).to_pandas().set_index('id')['passage'].to_dict()
# Row position in doc_ds -> passage id. The retrieval loop indexes this array
# with FAISS row positions to turn search hits into passage ids.
index_to_doc_id = np.array(doc_ds['id'])
queries = np.array(query_ds['question'])

# `relevant_passage_ids` holds a string containing a Python literal (a sequence
# of passage ids, presumably a list). Parse it with ast.literal_eval instead of
# eval(): only literals are accepted, so text from the downloaded dataset can
# never execute arbitrary code.
qrels = [np.array(ast.literal_eval(gold)) for gold in query_ds['relevant_passage_ids']]
# Number of gold passages per query.
qrels_counts = [len(s) for s in qrels]

In [5]:
import faiss
import torch
import gc
from time import time
from rag.utils import embed_dataset, get_metrics
from rag.embeddings import LocalEmbedder

# Embedding models compared in this sweep (Hugging Face ids or
# sentence-transformers short names). Commented-out entries are excluded
# from the current run.
embedder_models = [
    # "all-MiniLM-L6-v2",
    # "all-MiniLM-L12-v2",
    # "all-mpnet-base-v2",
    # "BAAI/bge-small-en-v1.5",
    # "BAAI/bge-base-en-v1.5",
    # "BAAI/bge-large-en-v1.5",
    # "Snowflake/snowflake-arctic-embed-l-v2.0",
    "jinaai/jina-embeddings-v3",
    "intfloat/e5-base-v2",
    "BAAI/bge-m3",
    "Lajavaness/bilingual-embedding-base",
]

# Retrieval settings held fixed across the sweep and logged as run params:
# inner-product search, whole passages (no chunking), no reranker.
faiss_metric = 'IP'
chunk_size = chunk_overlap = rerank_model = None


# Map the configured metric name onto the FAISS constant so that the
# `faiss_metric` value logged with each run matches the index actually built.
FAISS_METRIC_TYPES = {'IP': faiss.METRIC_INNER_PRODUCT, 'L2': faiss.METRIC_L2}
faiss_metric_type = FAISS_METRIC_TYPES[faiss_metric]
# Metric cutoffs; a single search at the largest cutoff serves all of them.
K_VALUES = [1, 3, 5, 10]

# NOTE(review): jina-embeddings-v3 and bilingual-embedding-base failed in the
# recorded run because they ship custom model code (`trust_remote_code=True`);
# check whether LocalEmbedder can forward that flag.
# NOTE(review): e5 models are documented to expect "query: " / "passage: "
# input prefixes -- confirm LocalEmbedder / embed_dataset add them.
for embedder_name in embedder_models:
    embedder_name_short = embedder_name.split('/')[-1]
    embedder = None
    try:
        try:
            embedder = LocalEmbedder(embedder_name, device="cuda")
            start_time = time()
            # Timing covers embedding both the corpus and the queries.
            doc_ds = embed_dataset(doc_ds, embedder, column="passage")
            query_ds = embed_dataset(query_ds, embedder, column="question")
            elapsed_time = time() - start_time
        except Exception as e:
            print(f"Failed to embed {embedder_name}: {e}")
            continue  # the outer `finally` still releases the model

        # Exact (flat) search of every query against the whole corpus. The
        # index is dropped right after searching, even if the search fails,
        # so no stale index stays attached to doc_ds.
        doc_ds.add_faiss_index(
            column='embedding',
            string_factory='Flat',
            metric_type=faiss_metric_type,
            batch_size=128,
        )
        try:
            res = doc_ds.get_index('embedding').search_batch(
                np.array(query_ds['embedding']), k=max(K_VALUES)
            )
        finally:
            doc_ds.drop_index('embedding')
        # Row positions -> passage ids, one row of ranked hits per query.
        retrieved_ids_all = index_to_doc_id[res.total_indices]

        metrics = {}
        for k in K_VALUES:
            retrieved_ids = retrieved_ids_all[:, :k]
            metrics.update(get_metrics(retrieved_ids, query_ds, k))
        metrics = {name: round(value, 4) for name, value in metrics.items()}
        metrics["elapsed_time"] = round(elapsed_time, 1)

        params = {
            'embed_model': embedder_name,
            'rerank_model': rerank_model,
            'chunked': False,
            'chunk_size': chunk_size,
            'chunk_overlap': chunk_overlap,
            'faiss_metric': faiss_metric,
        }

        res_dict = {
            **params,
            **metrics,
        }

        run_name = f"{embedder_name_short}"
        tags = {
            'experiment_type': 'embedder',
            'phase': 'exploration',
            'dataset': 'bioasq-mini',
            'embedder':  embedder_name_short,
        }
        with tracker.start_run(run_name=run_name, tags=tags):
            tracker.log_params(params)
            tracker.log_metrics(metrics)
            tracker.log_dvc_data(PROJECT_ROOT / 'data' / 'rag-mini-bioasq.dvc')
            tracker.log_dvc_data(PROJECT_ROOT / 'data' / 'rag-mini-bioasq-qrels.dvc')
    finally:
        # Release the model on every path (success, embedding failure, or an
        # error while indexing/logging) before the next one is loaded. The
        # reference must go before gc.collect()/empty_cache(); otherwise a
        # model that loaded but failed mid-embedding keeps its GPU memory.
        del embedder
        gc.collect()
        torch.cuda.empty_cache()

Failed to embed jinaai/jina-embeddings-v3: No module named 'custom_st'


Map: 100%|██████████| 40221/40221 [01:34<00:00, 427.61 examples/s]
Map: 100%|██████████| 4719/4719 [00:03<00:00, 1527.60 examples/s]
100%|██████████| 315/315 [00:00<00:00, 2616.95it/s]


[2025-10-26 08:43:40] [rag.tracking] [INFO] Logged DVC hash for rag-mini-bioasq.dvc: 0b9d04ec...
[2025-10-26 08:43:40] [rag.tracking] [INFO] Logged DVC hash for rag-mini-bioasq-qrels.dvc: c39757e3...
🏃 View run e5-base-v2 at: http://localhost:5000/#/experiments/803843227128346260/runs/1b231e46bc5e43e8aaa820a90e0a2e8a
🧪 View experiment at: http://localhost:5000/#/experiments/803843227128346260


Map: 100%|██████████| 40221/40221 [04:55<00:00, 136.17 examples/s]
Map: 100%|██████████| 4719/4719 [00:06<00:00, 708.12 examples/s]
100%|██████████| 315/315 [00:00<00:00, 4196.89it/s]


[2025-10-26 08:49:13] [rag.tracking] [INFO] Logged DVC hash for rag-mini-bioasq.dvc: 0b9d04ec...
[2025-10-26 08:49:13] [rag.tracking] [INFO] Logged DVC hash for rag-mini-bioasq-qrels.dvc: c39757e3...
🏃 View run bge-m3 at: http://localhost:5000/#/experiments/803843227128346260/runs/1b64799b81034a9ab23371700720d5e1
🧪 View experiment at: http://localhost:5000/#/experiments/803843227128346260
Failed to embed Lajavaness/bilingual-embedding-base: dangvantuan/bilingual_impl You can inspect the repository content at https://hf.co/Lajavaness/bilingual-embedding-base.
Please pass the argument `trust_remote_code=True` to allow custom code to be run.
