In [1]:
%load_ext autoreload
%autoreload 2

import faiss
import pandas as pd
from datasets import load_dataset

from rag.embeddings import create_embedder

  from .autonotebook import tqdm as notebook_tqdm


# Load RAG bioasq dataset

In [13]:
ds = load_dataset("rag-datasets/rag-mini-bioasq", "text-corpus")['passages']
ds

Dataset({
    features: ['passage', 'id'],
    num_rows: 40221
})

In [7]:
print(ds[0]['passage'])

New data on viruses isolated from patients with subacute thyroiditis de Quervain 
are reported. Characteristic morphological, cytological, some physico-chemical 
and biological features of the isolated viruses are described. A possible role 
of these viruses in human and animal health disorders is discussed. The isolated 
viruses remain unclassified so far.


In [14]:
test_ds = load_dataset("rag-datasets/rag-mini-bioasq", "question-answer-passages")['test']
test_ds

Dataset({
    features: ['question', 'answer', 'relevant_passage_ids', 'id'],
    num_rows: 4719
})

In [9]:
test_ds[0]

{'question': 'Is Hirschsprung disease a mendelian or a multifactorial disorder?',
 'answer': "Coding sequence mutations in RET, GDNF, EDNRB, EDN3, and SOX10 are involved in the development of Hirschsprung disease. The majority of these genes was shown to be related to Mendelian syndromic forms of Hirschsprung's disease, whereas the non-Mendelian inheritance of sporadic non-syndromic Hirschsprung disease proved to be complex; involvement of multiple loci was demonstrated in a multiplicative model.",
 'relevant_passage_ids': '[20598273, 6650562, 15829955, 15617541, 23001136, 8896569, 21995290, 12239580, 15858239]',
 'id': 0}

# Embed with Local Embedderm

In [201]:
from rag.embeddings import create_embedder
from rag.embeddings import LocalEmbedder
from rag.config import settings


# embedder = LocalEmbedder("pritamdeka/S-BioBert-snli-multinli-stsb")
embedder = LocalEmbedder("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
# embedder = create_embedder(settings)
embedder

LocalEmbedder("pritamdeka/S-BioBert-snli-multinli-stsb", dim=768)

In [205]:
from rag.config import PROJECT_ROOT

def embed(batch, column):
    embs = embedder.embed_batch(batch[column])
    return {'embedding': embs}

ds = ds.map(
    embed,
    batched=True,
    fn_kwargs={'column': 'passage'},
    # load_from_cache_file=True,
    # cache_file_name= str(PROJECT_ROOT / ".cache/rag_bioasq_mini.arrow")
)
ds

Map: 100%|██████████| 40221/40221 [00:15<00:00, 2657.03 examples/s]


Dataset({
    features: ['passage', 'id', 'embedding'],
    num_rows: 40221
})

In [206]:
test_ds = test_ds.map(
    embed,
    batched=True,
    fn_kwargs={'column': 'question'},
    # load_from_cache_file=True,
    # cache_file_name= str(PROJECT_ROOT / ".cache/rag_bioasq_mini_test.arrow")
)
test_ds

Map: 100%|██████████| 4719/4719 [00:01<00:00, 4573.82 examples/s]


Dataset({
    features: ['question', 'answer', 'relevant_passage_ids', 'id', 'embedding'],
    num_rows: 4719
})

# Add native faiss index by arrow

In [207]:
ds = ds.add_faiss_index(
    column='embedding',
    string_factory='Flat',
    metric_type=faiss.METRIC_INNER_PRODUCT,
    batch_size=128,
)
ds

100%|██████████| 315/315 [00:00<00:00, 3417.69it/s]


Dataset({
    features: ['passage', 'id', 'embedding'],
    num_rows: 40221
})

In [25]:
ds.save_faiss_index("embedding", "bioasq-mini-arrow.index")
ds.drop_index("embedding")
ds.save_to_disk("bioasq-mini-arrow.docs")

Saving the dataset (1/1 shards): 100%|██████████| 40221/40221 [00:00<00:00, 1121341.50 examples/s]


In [72]:
test_ds.save_to_disk("bioasq-mini-arrow.qrels")

Saving the dataset (1/1 shards): 100%|██████████| 4719/4719 [00:00<00:00, 665935.02 examples/s]


# Check native arrow searchm

In [208]:
import numpy as np

res = ds.get_index('embedding').search_batch(np.array(test_ds['embedding']), k=5)
res

BatchedSearchResults(total_scores=array([[0.76342267, 0.7587638 , 0.7465499 , 0.7461885 , 0.7456034 ],
       [0.7370676 , 0.7298986 , 0.71763325, 0.6868064 , 0.68301857],
       [0.47025472, 0.45580637, 0.4535181 , 0.45100364, 0.44160718],
       ...,
       [0.45328766, 0.42834824, 0.41921014, 0.41430444, 0.41389298],
       [0.8394382 , 0.81101155, 0.757771  , 0.7550446 , 0.7212372 ],
       [0.7403879 , 0.696525  , 0.68430775, 0.6641858 , 0.6601447 ]],
      shape=(4719, 5), dtype=float32), total_indices=array([[  691,  5660, 22260, 18975,  8872],
       [23647, 18918, 28563, 23167,  6232],
       [39440,  7902,  8430,  4422, 35272],
       ...,
       [23197, 24967, 35449, 24662, 28307],
       [39704, 12552, 37768,  9760, 12296],
       [37276,  3970, 23266, 11042,  9249]], shape=(4719, 5)))

In [62]:
res.total_indices

array([[22260,   435,  3307, 11166, 21024],
       [23647, 32793, 21714, 40093, 24595],
       [ 7017,  1072,  4422, 39902,  6083],
       ...,
       [38617, 34598, 12269,  8809, 14798],
       [34585, 12552, 39704,  1676, 12296],
       [37276, 26956,  2099, 35832,  5886]], shape=(4719, 5))

# Index using faiss lib

In [43]:
import numpy as np
import faiss

index = faiss.IndexFlatIP(embedder.dimension)
# index = faiss.IndexFlatL2(embedder.dimension)

In [44]:
index.add(np.array(ds['embedding']))

In [38]:
index

<faiss.swigfaiss_avx512.IndexFlatIP; proxy of <Swig Object of type 'faiss::IndexFlatIP *' at 0x77d380582d30> >

In [39]:
index.ntotal

40221

In [48]:
k = 5

distances, pred_ids = index.search(np.array(test_ds['embedding']), k=k)

In [41]:
distances.dtype

dtype('float32')

In [46]:
pred_ids

array([[22260,   435,  3307, 11166, 21024],
       [23647, 32793, 21714, 40093, 24595],
       [ 7017,  1072,  4422, 39902,  6083],
       ...,
       [38617, 34598, 12269,  8809, 14798],
       [34585, 12552, 39704,  1676, 12296],
       [37276, 26956,  2099, 35832,  5886]], shape=(4719, 5))

# Evaluate with RAGAS

In [67]:
# load dot env for OPENAI API keys
from dotenv import load_dotenv

load_dotenv(PROJECT_ROOT / '.env')

True

In [68]:
res = ds.get_index('embedding').search_batch(np.array(test_ds['embedding']), k=5)
res

BatchedSearchResults(total_scores=array([[0.6631991 , 0.6363485 , 0.61760235, 0.60579264, 0.6037774 ],
       [0.5931436 , 0.54778117, 0.5428606 , 0.5413512 , 0.5342575 ],
       [0.61819255, 0.5982094 , 0.5524874 , 0.47593287, 0.45271417],
       ...,
       [0.3502987 , 0.35011643, 0.32357284, 0.31465802, 0.2996544 ],
       [0.7555245 , 0.72723746, 0.71668005, 0.6510484 , 0.6401484 ],
       [0.626945  , 0.5422195 , 0.53387046, 0.51882434, 0.51590806]],
      shape=(4719, 5), dtype=float32), total_indices=array([[22260,   435,  3307, 11166, 21024],
       [23647, 32793, 21714, 40093, 24595],
       [ 7017,  1072,  4422, 39902,  6083],
       ...,
       [38617, 34598, 12269,  8809, 14798],
       [34585, 12552, 39704,  1676, 12296],
       [37276, 26956,  2099, 35832,  5886]], shape=(4719, 5)))

In [70]:
res.total_indices

array([[22260,   435,  3307, 11166, 21024],
       [23647, 32793, 21714, 40093, 24595],
       [ 7017,  1072,  4422, 39902,  6083],
       ...,
       [38617, 34598, 12269,  8809, 14798],
       [34585, 12552, 39704,  1676, 12296],
       [37276, 26956,  2099, 35832,  5886]], shape=(4719, 5))

In [77]:
index_to_passage_id = np.array(ds['id'])

retrieved_ids = index_to_passage_id[res.total_indices]
retrieved_ids

array([[23001136,  1785632,  9727738, 17965226, 22584707],
       [23382875, 27426127, 22829865, 34667080, 23637683],
       [15094122,  3320045, 11076767, 34489718, 12666201],
       ...,
       [32529410, 28624872, 18637493, 16394582, 20007317],
       [28614408, 18824533, 34169075,  7676521, 18654798],
       [30580288, 24310308,  8275569, 29383495, 12497758]],
      shape=(4719, 5))

In [78]:
test_ds

Dataset({
    features: ['question', 'answer', 'relevant_passage_ids', 'id', 'embedding'],
    num_rows: 4719
})

In [105]:
gold_sets = [np.array(eval(gold)) for gold in test_ds['relevant_passage_ids']]
len(gold_sets)

4719

In [150]:
num_q, k = retrieved_ids.shape
hit_flags = np.zeros((num_q, k), dtype=int)
hit_flags

array([[0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       ...,
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0]], shape=(4719, 5))

In [151]:
retrieved_ids[0]

array([23001136,  1785632,  9727738, 17965226, 22584707])

In [152]:
gold_sets[0]

array([20598273,  6650562, 15829955, 15617541, 23001136,  8896569,
       21995290, 12239580, 15858239])

In [153]:
np.isin(retrieved_ids[0], gold_sets[0])

array([ True, False, False, False, False])

In [154]:
for i in range(num_q):
    hit_flags[i] = np.isin(retrieved_ids[i], gold_sets[i])
hit_flags

array([[1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0],
       [1, 1, 1, 0, 1],
       ...,
       [0, 0, 0, 0, 0],
       [1, 1, 1, 0, 1],
       [1, 0, 0, 0, 0]], shape=(4719, 5))

In [155]:
gold_counts = np.array([len(s) for s in gold_sets])

In [156]:
def precision_at_k(hit_flags):
    return float(hit_flags.mean(axis=1).mean())

def recall_at_k(hit_flags, gold_counts):
    retrieved = hit_flags.sum(axis=1)
    recall = retrieved / gold_counts
    return float(recall.mean())


precision_at_k(hit_flags)
recall_at_k(hit_flags, gold_counts)

0.2926906106714452

In [157]:
# argmax to get first
ranks = np.argmax(hit_flags, axis=1) + 1
ranks[:100]

array([1, 1, 1, 2, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       2, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 2, 5, 1, 3, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 5, 1, 1,
       1, 1, 1, 1, 4, 5, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1])

In [158]:
has_hit = hit_flags.any(axis=1)
reciprocal = np.zeros_like(ranks, dtype=np.float32)
reciprocal[has_hit] = 1.0 / ranks[has_hit]
reciprocal[:100]

array([1.        , 1.        , 1.        , 0.5       , 0.        ,
       0.        , 1.        , 0.5       , 0.5       , 0.        ,
       1.        , 1.        , 1.        , 0.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 0.5       , 1.        , 1.        ,
       0.33333334, 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       0.        , 1.        , 1.        , 1.        , 1.        ,
       0.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 0.5       , 0.2       ,
       1.        , 0.33333334, 0.5       , 0.5       , 0.5       ,
       1.        , 1.        , 1.        , 0.        , 0.        ,
       0.        , 0.5       , 1.        , 0.2       , 1.        ,
       1.        , 0.        , 1.        , 1.        , 0.        ,
       0.25      , 0.2       , 0.5       , 1.        , 0.5    

In [159]:
float(reciprocal.mean())

0.6261213421821594

In [160]:
def mrr_at_k(hit_flags):
    has_hit = hit_flags.any(axis=1)
    ranks = np.argmax(hit_flags, axis=1) + 1
    reciprocal = np.zeros_like(ranks, dtype=np.float32)
    reciprocal[has_hit] = 1.0 / ranks[has_hit]
    return float(reciprocal.mean())

mrr_at_k(hit_flags)

0.6261213421821594

In [187]:
def ndcg_at_k(hit_flags: np.ndarray) -> float:
    # Binary relevance (1 if retrieved id is in gold set)
    # DCG = sum_{i=1..k} (rel_i / log2(i+1)), since rel∈{0,1} we can simplify
    k = hit_flags.shape[1]
    discounts = 1.0 / np.log2(np.arange(2, k+2))  # [1/log2(2), 1/log2(3), ...]
    dcg = (hit_flags * discounts).sum(axis=1)
    # Ideal DCG for binary relevance = sum of top min(k, |gold|) discounts
    # We need |gold| per query:
    ndcg = dcg / np.array([discounts[:min(k, c)].sum() if c > 0 else 1.0
                           for c in gold_counts])
    return float(ndcg.mean())

ndcg_at_k(hit_flags)

0.5596524687433778

# Embed

In [2]:
from datasets import load_dataset
from

ds = load_dataset("rag-datasets/rag-mini-bioasq", "text-corpus")['passages']
test_df = load_dataset("rag-datasets/rag-mini-bioasq", "question-answer-passages")['test']

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# embedder = LocalEmbedder("BAAI/bge-large-en-v1.5", device='cpu')
# embedder = LocalEmbedder("BAAI/bge-small-en-v1.5", device='cuda')
# embedder = LocalEmbedder("Qwen/Qwen3-Embedding-8B", device='cuda')
embedder = LocalEmbedder("Qwen/Qwen3-Embedding-0.6B", device='cuda')

def embed(batch, column):
    embs = embedder.embed_batch(batch[column])
    return {'embedding': embs}

ds = ds.map(
    embed,
    batched=True,
    fn_kwargs={'column': 'passage'},
)
test_ds = test_ds.map(
    embed,
    batched=True,
    fn_kwargs={'column': 'question'},
)

NameError: name 'LocalEmbedder' is not defined

In [252]:
ds = ds.add_faiss_index(
    column='embedding',
    string_factory='Flat',
    metric_type=faiss.METRIC_INNER_PRODUCT,
    batch_size=128,
)

100%|██████████| 315/315 [00:00<00:00, 5794.48it/s]


# Precompute,m

In [275]:
passage_id_to_text = ds.select_columns(['id', 'passage']).to_pandas().set_index('id')['passage'].to_dict()

In [276]:
index_to_passage_id = np.array(ds['id'])
queries = np.array(test_ds['question'])

gold_sets = [np.array(eval(gold)) for gold in test_ds['relevant_passage_ids']]
gold_counts = [len(s) for s in gold_sets]

# Actual Search

In [308]:
k0 = 100
res = ds.get_index('embedding').search_batch(np.array(test_ds['embedding']), k=k0)
retrieved_ids = index_to_passage_id[res.total_indices]
retrieved_ids

array([[15617541, 23001136,  2309705, ..., 27714920,  8862623, 18091433],
       [23382875, 34667080, 29680500, ..., 16730855, 16049312, 24307346],
       [15094122,  3320045, 11076767, ..., 20130175, 24022122, 22751350],
       ...,
       [33826820, 32176765, 30559259, ..., 24298040, 25457975, 23647909],
       [34169075, 28614408, 17042799, ..., 19929788, 28381231, 25470471],
       [30580288, 22318908, 21834047, ..., 24840526, 16937455, 19822006]],
      shape=(4719, 100))

# Get Metrics

In [298]:
def get_hit_flags(retrieved_ids):
    hit_flags = np.zeros_like(retrieved_ids, dtype=np.bool)
    for i in range(len(retrieved_ids)):
        hit_flags[i] = np.isin(retrieved_ids[i], gold_sets[i])
    return hit_flags

# hit_flags = get_hit_flags(retrieved_ids)

def print_metrics(retrieved_ids, gold_counts, k):
    hit_flags = get_hit_flags(retrieved_ids)

    print(f'Embedding model: {embedder.model_name}')
    print(f"P@{k}:    {precision_at_k(hit_flags):.3f}")
    print(f"R@{k}:    {recall_at_k(hit_flags, gold_counts):.3f}")
    print(f"MRR@{k}:  {mrr_at_k(hit_flags):.3f}")
    print(f"nDCG@{k}: {ndcg_at_k(hit_flags):.3f}")

print_metrics(retrieved_ids, gold_counts, k0)

Embedding model: BAAI/bge-small-en-v1.5
P@100:    0.459
R@100:    0.363
MRR@100:  0.734
nDCG@100: 0.598


In [299]:
retrieved_ids

array([[15617541, 23001136,  2309705, 16816022, 12239580],
       [23382875, 34667080, 29680500, 16159418, 15871762],
       [15094122,  3320045, 11076767, 12666201, 11419941],
       ...,
       [33826820, 32176765, 30559259, 30462303, 19351152],
       [34169075, 28614408, 17042799, 18824533, 18654798],
       [30580288, 22318908, 21834047, 10620111,  2300390]],
      shape=(4719, 5))

# Reranker

In [311]:
from tqdm import trange
from sentence_transformers import CrossEncoder
from FlagEmbedding import FlagReranker
import numpy as np

k = 5

ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=512, device='cuda')
# fr = FlagReranker("BAAI/bge-reranker-base", max_length=512, device='cuda')
num_q, k0 = retrieved_ids.shape
pairs, ptr = [], []

for i in trange(num_q):
    pids = index_to_passage_id[res.total_indices[i]]
    ctxs = [passage_id_to_text[pid] for pid in pids]
    for j, ctx in enumerate(ctxs):
        pairs.append((queries[i], ctx))
        ptr.append((i, j))

# scores_flat = fr.compute_score(pairs, batch_size=128)  # GPU if available
scores_flat = ce.predict(pairs, show_progress_bar=True, batch_size=128)  # GPU if available
scores = np.zeros((num_q, k0), dtype=np.float32)

for s, (i, j) in zip(scores_flat, ptr):
    scores[i, j] = s

order = np.argsort(-scores, axis=1)
topk = order[:, :k]  # final K
reranked_ids = np.take_along_axis(retrieved_ids, topk, axis=1)

print_metrics(reranked_ids, gold_counts, k)

100%|██████████| 4719/4719 [00:00<00:00, 19540.25it/s]
pre tokenize:   0%|          | 0/3687 [00:00<?, ?it/s]You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
pre tokenize: 100%|██████████| 3687/3687 [00:58<00:00, 63.39it/s]
Compute Scores: 100%|██████████| 3687/3687 [12:52<00:00,  4.78it/s]


Embedding model: BAAI/bge-small-en-v1.5
P@5:    0.406
R@5:    0.331
MRR@5:  0.649
nDCG@5: 0.524
