In [1]:
%load_ext autoreload
%autoreload 2


import ast

import faiss
import numpy as np
import pandas as pd
from datasets import load_dataset

from rag.embeddings import LocalEmbedder
from rag.utils import embed_biorag_datasets, precision_at_k, recall_at_k, mrr_at_k, ndcg_at_k, get_hit_flags, \
    get_metrics

# The passage corpus and the evaluation questions are two configurations of the
# same Hugging Face dataset repository.
bioasq_repo = "rag-datasets/rag-mini-bioasq"

corpus_splits = load_dataset(bioasq_repo, "text-corpus")
doc_ds = corpus_splits["passages"]

question_splits = load_dataset(bioasq_repo, "question-answer-passages")
query_ds = question_splits["test"]

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Similarity metric for the FAISS index built further down: 'L2' is mapped to
# faiss.METRIC_L2 and 'IP' to faiss.METRIC_INNER_PRODUCT when the index is created.
# The value is also printed next to the final metrics to label the run.
faiss_metric = 'IP' #  L2 or IP

# Embedding model under evaluation; it is passed to embed_biorag_datasets together
# with both the corpus and the queries. The commented-out lines are alternative
# model/device choices kept for quick comparison runs.
embedder = LocalEmbedder("all-MiniLM-L6-v2", device="cuda")
# embedder = LocalEmbedder("BAAI/bge-large-en-v1.5", device='cuda')
# embedder = LocalEmbedder("BAAI/bge-small-en-v1.5", device='cpu')

In [3]:
from tqdm import tqdm

# Requires `embedder` from the configuration cell. The tokenizer is taken from the
# embedder's underlying model and is read as a module-level name by get_seq_lengths.
tokenizer = embedder.model.tokenizer

def get_seq_lengths(ds, column="passage"):
    """Return the token count of every text in ``ds[column]`` as a pandas Series.

    Each text is tokenized on its own with the module-level ``tokenizer``, with
    truncation and padding disabled, so the reported length covers the full text.
    A tqdm progress bar labelled "Tokenizing" tracks the pass over the column.

    Args:
        ds: Object supporting ``ds[column]`` that yields an iterable of texts.
        column: Name of the text column to measure. Defaults to "passage".

    Returns:
        pd.Series with one length per text, in the column's order.
    """

    def _token_count(text):
        # The tokenizer reports lengths as a sequence; a single input gives one entry.
        encoded = tokenizer(text, truncation=False, padding=False, return_length=True)
        return encoded["length"][0]

    return pd.Series([_token_count(text) for text in tqdm(ds[column], desc="Tokenizing")])

# Token count of every corpus passage, measured without truncation.
ds_lengths = get_seq_lengths(doc_ds, column="passage")

Tokenizing:   0%|          | 0/40221 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (340 > 256). Running this sequence through the model will result in indexing errors
Tokenizing: 100%|██████████| 40221/40221 [00:08<00:00, 4517.34it/s]


In [4]:
# Summary statistics of the per-passage token counts.
# NOTE(review): in the recorded output the 25th percentile equals the minimum
# (3 tokens), so a large share of passages look near-empty, and the median (258)
# already exceeds the 256-token limit reported in the tokenizer warning. Verify how
# the embedder and the evaluation handle both cases.
ds_lengths.describe()

count    40221.000000
mean       238.106586
std        218.792328
min          3.000000
25%          3.000000
50%        258.000000
75%        371.000000
max       8099.000000
dtype: float64

In [5]:
# Embed the corpus and the queries with the configured embedder. Both names are
# rebound to the returned datasets. Later cells read an 'embedding' column from
# doc_ds and query_ds, so presumably that column is added here — confirm in
# rag.utils.embed_biorag_datasets.
doc_ds, query_ds = embed_biorag_datasets(doc_ds, query_ds, embedder)

Map: 100%|██████████| 40221/40221 [00:22<00:00, 1810.40 examples/s]
Map: 100%|██████████| 4719/4719 [00:01<00:00, 3122.70 examples/s]


In [6]:
# Map the configured metric name to its FAISS constant. An explicit lookup makes an
# unrecognised name (e.g. a typo such as 'l2') fail with a KeyError instead of
# silently building an inner-product index that the results would then mislabel.
faiss_metric_types = {
    'L2': faiss.METRIC_L2,
    'IP': faiss.METRIC_INNER_PRODUCT,
}

# Attach a 'Flat' FAISS index over the 'embedding' column of the corpus. The call is
# kept as the last expression so the dataset summary is still displayed.
doc_ds.add_faiss_index(
    column='embedding',
    string_factory='Flat',
    metric_type=faiss_metric_types[faiss_metric],
    batch_size=128,
)

100%|██████████| 315/315 [00:00<00:00, 4561.99it/s]


Dataset({
    features: ['passage', 'id', 'embedding'],
    num_rows: 40221
})

# Precompute lookup tables and relevance labels

In [7]:
# Lookup tables: passage text by document id, and document id by row position.
doc_id_to_text = doc_ds.select_columns(['id', 'passage']).to_pandas().set_index('id')['passage'].to_dict()
index_to_doc_id = np.array(doc_ds['id'])
queries = np.array(query_ds['question'])

# 'relevant_passage_ids' holds each gold list as a string. Parse it with
# ast.literal_eval, which accepts only Python literals, rather than eval(), which
# would execute any expression embedded in the downloaded dataset.
qrels = [np.array(ast.literal_eval(gold)) for gold in query_ds['relevant_passage_ids']]
# Number of gold passages per query.
qrels_counts = [len(s) for s in qrels]

# Search

In [8]:
# Retrieval cutoff: number of passages fetched per query.
k = 5

# Run one batched search of all query embeddings against the 'embedding' index.
query_vectors = np.array(query_ds['embedding'])
embedding_index = doc_ds.get_index('embedding')
res = embedding_index.search_batch(query_vectors, k=k)

# Convert the returned row positions into document ids.
retrieved_ids = index_to_doc_id[res.total_indices]
retrieved_ids

array([[23001136,  1785632,  9727738, 17965226, 22584707],
       [23382875, 27426127, 22829865, 34667080, 23637683],
       [15094122,  3320045, 11076767, 34489718, 12666201],
       ...,
       [32529410, 28624872, 18637493, 16394582, 20007317],
       [28614408, 18824533, 34169075,  7676521, 18654798],
       [30580288, 24310308,  8275569, 29383495, 12497758]],
      shape=(4719, 5))

# Metrics

In [9]:
# Compute the retrieval metrics for the retrieved ids at cutoff k.
metrics = get_metrics(retrieved_ids, query_ds, k)

# Label the run with the model name and the FAISS metric in use.
print(embedder.model_name, faiss_metric)
# Dedicated loop names are used on purpose: iterating with `k` would overwrite the
# retrieval cutoff `k`, and re-running this cell would then pass a metric label
# string to get_metrics as the cutoff.
for metric_name, metric_value in metrics.items():
    print(f"{metric_name:6s}", f"{metric_value:.3f}")

all-MiniLM-L6-v2 IP
P@5   0.380
R@5   0.293
MRR@5 0.626
nDCG@5 0.489
