In [5]:
import ast

import faiss
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset

from rag.embeddings import LocalEmbedder

# The passage corpus and the question/answer split are two configs of the same Hub repo.
bioasq_repo = "rag-datasets/rag-mini-bioasq"
ds = load_dataset(bioasq_repo, "text-corpus")['passages']
test_ds = load_dataset(bioasq_repo, "question-answer-passages")['test']

In [6]:
# Embedding model under evaluation. Alternative configurations (swap one in to compare):
#   LocalEmbedder("BAAI/bge-large-en-v1.5", device='cuda')
#   LocalEmbedder("BAAI/bge-small-en-v1.5", device='cuda')
#   LocalEmbedder("Qwen/Qwen3-Embedding-8B", device='cuda')
#   LocalEmbedder("Qwen/Qwen3-Embedding-0.6B", device='cuda', model_kwargs={"dtype": torch.bfloat16})
embedder = LocalEmbedder(
    "Qwen/Qwen3-Embedding-4B",
    device='cuda',
    model_kwargs={"dtype": torch.bfloat16},
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 85.51it/s]


In [7]:
from tqdm import tqdm

# Requires `embedder` from the cell above. The tokenizer is taken from the
# embedder's wrapped model so token counts are measured with that same object.
tokenizer = embedder.model.tokenizer

def get_seq_lengths(ds, column="passage"):
    """Return the token count of every text in ``ds[column]`` as a pandas Series.

    Each text is tokenized on its own with the module-level ``tokenizer``,
    without truncation or padding, and the reported ``length`` is recorded.
    Progress is shown with a tqdm bar labelled "Tokenizing".
    """
    def _token_count(text):
        encoded = tokenizer(text, truncation=False, padding=False, return_length=True)
        return encoded["length"][0]

    return pd.Series([_token_count(text) for text in tqdm(ds[column], desc="Tokenizing")])

# Token length of every corpus passage; summarized with .describe() in the next cell.
ds_lengths = get_seq_lengths(ds, column="passage")

Tokenizing: 100%|██████████| 40221/40221 [00:10<00:00, 3989.56it/s]


In [8]:
# Summary statistics of the per-passage token counts.
# NOTE(review): in the recorded output both the minimum and the 25th percentile
# are 2 tokens, so a large share of passages look empty or placeholder-like —
# consider inspecting/filtering them before embedding.
ds_lengths.describe()

count    40221.000000
mean       249.147460
std        233.892122
min          2.000000
25%          2.000000
50%        267.000000
75%        387.000000
max       9579.000000
dtype: float64

In [2]:
def embed(batch, column):
    """Batched ``Dataset.map`` callback: embed the texts found in ``batch[column]``.

    Returns a dict with a single ``'embedding'`` key holding whatever
    ``embedder.embed_batch`` produces for those texts.
    """
    texts = batch[column]
    return {'embedding': embedder.embed_batch(texts)}

# Embed corpus passages and test questions with the same batched-map settings;
# only the source text column differs between the two calls.
embed_map_opts = {'batched': True, 'batch_size': 8}
ds = ds.map(embed, fn_kwargs={'column': 'passage'}, **embed_map_opts)
test_ds = test_ds.map(embed, fn_kwargs={'column': 'question'}, **embed_map_opts)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 100.62it/s]
Map:  15%|█▌        | 6160/40221 [02:10<11:58, 47.38 examples/s] 

KeyboardInterrupt



In [None]:
# Attach a FAISS index over the 'embedding' column. The 'Flat' factory string
# with METRIC_INNER_PRODUCT requests an un-compressed index scored by dot product.
# NOTE(review): dot product only equals cosine similarity if embed_batch returns
# L2-normalized vectors — TODO confirm against LocalEmbedder.
ds = ds.add_faiss_index(
    column='embedding',
    string_factory='Flat',
    metric_type=faiss.METRIC_INNER_PRODUCT,
    batch_size=128,  # rows added to the index per call
)

# Precompute lookup tables and gold passage sets

In [None]:
# passage id -> passage text, for inspecting retrieved results.
passage_id_to_text = ds.select_columns(['id', 'passage']).to_pandas().set_index('id')['passage'].to_dict()
# Row position in ds (what the index search returns) -> passage id.
index_to_passage_id = np.array(ds['id'])
queries = np.array(test_ds['question'])

# 'relevant_passage_ids' is parsed from its string form. The strings come from a
# downloaded dataset, so parse them with ast.literal_eval (Python literals only)
# rather than eval(), which would execute arbitrary expressions.
gold_sets = [np.array(ast.literal_eval(gold)) for gold in test_ds['relevant_passage_ids']]
# Number of gold passages per query (denominator for recall / ideal DCG).
gold_counts = [len(s) for s in gold_sets]

# Search

In [None]:
k = 5
# FAISS operates on float32 vectors; building the array from Python lists would
# otherwise give float64, which some faiss versions reject outright.
query_embeddings = np.asarray(test_ds['embedding'], dtype=np.float32)
res = ds.get_index('embedding').search_batch(query_embeddings, k=k)

# FAISS pads with -1 when it finds fewer than k neighbours; indexing with -1
# would silently map to the last passage, so fail loudly instead.
retrieved_positions = np.asarray(res.total_indices)
assert (retrieved_positions >= 0).all(), "FAISS returned fewer than k results for some query"

# Convert row positions in ds into passage ids.
retrieved_ids = index_to_passage_id[retrieved_positions]
retrieved_ids

# Metrics

In [None]:
def precision_at_k(hit_flags):
    """Mean precision@k over all queries.

    ``hit_flags`` is a 2-D boolean array (one row per query, one column per
    retrieved rank). Precision is the fraction of hits in each row, averaged
    across rows.
    """
    per_query_precision = hit_flags.mean(axis=1)
    return float(per_query_precision.mean())

def recall_at_k(hit_flags, gold_counts):
    """Mean recall@k over all queries.

    Args:
        hit_flags: 2-D boolean array; row i flags which of the k retrieved
            items for query i are gold passages.
        gold_counts: per-query number of gold passages (sequence or array).

    Returns:
        The average over queries of ``hits / gold_count``. A query with zero
        gold passages contributes a recall of 0 instead of producing a
        division by zero (NaN), matching how ``ndcg_at_k`` scores such queries.
    """
    hits_per_query = hit_flags.sum(axis=1)
    counts = np.asarray(gold_counts, dtype=np.float64)
    has_gold = counts > 0
    # Substitute a dummy denominator where there is no gold so the division
    # itself never sees a zero; those entries are then forced to 0.
    safe_counts = np.where(has_gold, counts, 1.0)
    recall = np.where(has_gold, hits_per_query / safe_counts, 0.0)
    return float(recall.mean())

def mrr_at_k(hit_flags):
    """Mean reciprocal rank@k over all queries.

    For each row of the boolean ``hit_flags`` array the 1-based position of the
    first hit is found; its reciprocal is the query's score, or 0 when the row
    has no hit. Scores are held as float32 before averaging.
    """
    found = hit_flags.any(axis=1)
    # argmax returns the first True column (or 0 for all-False rows, masked below).
    first_hit_rank = np.argmax(hit_flags, axis=1) + 1
    reciprocal_ranks = np.where(found, 1.0 / first_hit_rank, 0.0).astype(np.float32)
    return float(reciprocal_ranks.mean())

def ndcg_at_k(hit_flags: np.ndarray, gold_counts=None) -> float:
    """Mean nDCG@k with binary relevance.

    Args:
        hit_flags: 2-D boolean array; ``hit_flags[i, j]`` is True when the
            j-th retrieved item for query i is a gold passage.
        gold_counts: per-query number of gold passages. Optional for backward
            compatibility: when omitted, the module-level ``gold_counts`` is
            used, as the original one-argument form did.

    Returns:
        The average over queries of DCG / ideal DCG. Queries with no gold
        passages use an ideal DCG of 1.0 and therefore score 0.

    Raises:
        NameError: if ``gold_counts`` is omitted and no module-level
            ``gold_counts`` exists.
    """
    if gold_counts is None:
        try:
            gold_counts = globals()["gold_counts"]
        except KeyError:
            raise NameError("name 'gold_counts' is not defined") from None

    k = hit_flags.shape[1]
    # Binary relevance: DCG = sum over ranks i=1..k of rel_i / log2(i + 1).
    discounts = 1.0 / np.log2(np.arange(2, k + 2))  # [1/log2(2), 1/log2(3), ...]
    dcg = (hit_flags * discounts).sum(axis=1)
    # Ideal DCG places all gold passages first: sum of the top min(k, |gold|) discounts.
    ideal_dcg = np.array([discounts[:min(k, c)].sum() if c > 0 else 1.0
                          for c in gold_counts])
    ndcg = dcg / ideal_dcg
    return float(ndcg.mean())

def get_hit_flags(retrieved_ids, gold_sets=None):
    """Flag which retrieved ids are gold passages for their query.

    Args:
        retrieved_ids: 2-D array; row i holds the ids retrieved for query i.
        gold_sets: sequence whose i-th element holds the gold ids for query i.
            Optional for backward compatibility: when omitted, the module-level
            ``gold_sets`` is used, as the original one-argument form did.

    Returns:
        Boolean array with the same shape as ``retrieved_ids``; True where the
        retrieved id appears in that query's gold set.

    Raises:
        NameError: if ``gold_sets`` is omitted and no module-level
            ``gold_sets`` exists.
    """
    if gold_sets is None:
        try:
            gold_sets = globals()["gold_sets"]
        except KeyError:
            raise NameError("name 'gold_sets' is not defined") from None

    # Use the builtin bool: the np.bool alias does not exist in NumPy 1.24-1.26.
    hit_flags = np.zeros_like(retrieved_ids, dtype=bool)
    for i, row in enumerate(retrieved_ids):
        hit_flags[i] = np.isin(row, gold_sets[i])
    return hit_flags

# hit_flags = get_hit_flags(retrieved_ids)

def print_metrics(retrieved_ids, gold_counts, k):
    """Print the embedding model name followed by P@k, R@k, MRR@k and nDCG@k.

    Hit flags are derived from ``retrieved_ids`` via ``get_hit_flags``; each
    metric is computed and printed in turn, formatted to three decimals.
    ``k`` is used only for the labels.
    """
    flags = get_hit_flags(retrieved_ids)

    print(f'Embedding model: {embedder.model_name}')

    precision = precision_at_k(flags)
    print(f"P@{k}:    {precision:.3f}")

    recall = recall_at_k(flags, gold_counts)
    print(f"R@{k}:    {recall:.3f}")

    mrr = mrr_at_k(flags)
    print(f"MRR@{k}:  {mrr:.3f}")

    ndcg = ndcg_at_k(flags)
    print(f"nDCG@{k}: {ndcg:.3f}")

# Report retrieval quality for the top-k results computed in the search cell.
print_metrics(retrieved_ids, gold_counts, k)