In [1]:
%load_ext autoreload
%autoreload 2

import faiss
import pandas as pd
from datasets import load_dataset

from rag.embeddings import create_embedder

ds = load_dataset("rag-datasets/rag-mini-bioasq", "text-corpus")['passages']
ds

  from .autonotebook import tqdm as notebook_tqdm


Dataset({
    features: ['passage', 'id'],
    num_rows: 40221
})

In [27]:
from rag.utils import chunked

for batch in chunked([1, 2, 3, 4], size=2):
    print(batch)

[1, 2]
[3, 4]


In [2]:
print(ds[0]['passage'])

New data on viruses isolated from patients with subacute thyroiditis de Quervain 
are reported. Characteristic morphological, cytological, some physico-chemical 
and biological features of the isolated viruses are described. A possible role 
of these viruses in human and animal health disorders is discussed. The isolated 
viruses remain unclassified so far.


In [3]:
test_ds = load_dataset("rag-datasets/rag-mini-bioasq", "question-answer-passages")['test']
test_ds

Dataset({
    features: ['question', 'answer', 'relevant_passage_ids', 'id'],
    num_rows: 4719
})

In [4]:
test_ds[0]

{'question': 'Is Hirschsprung disease a mendelian or a multifactorial disorder?',
 'answer': "Coding sequence mutations in RET, GDNF, EDNRB, EDN3, and SOX10 are involved in the development of Hirschsprung disease. The majority of these genes was shown to be related to Mendelian syndromic forms of Hirschsprung's disease, whereas the non-Mendelian inheritance of sporadic non-syndromic Hirschsprung disease proved to be complex; involvement of multiple loci was demonstrated in a multiplicative model.",
 'relevant_passage_ids': '[20598273, 6650562, 15829955, 15617541, 23001136, 8896569, 21995290, 12239580, 15858239]',
 'id': 0}

In [5]:
pd.DataFrame.from_dict(ds[:5])

Unnamed: 0,passage,id
0,New data on viruses isolated from patients wit...,9797
1,We describe an improved method for detecting d...,11906
2,We have studied the effects of curare on respo...,16083
3,Kinetic and electrophoretic properties of 230-...,23188
4,Male Wistar specific-pathogen-free rats aged 2...,23469


In [9]:
from rag.embeddings import create_embedder
from rag.config import settings

embedder = create_embedder(settings)
embedder

LocalEmbedder("all-MiniLM-L6-v2", dim=384)

In [10]:
def embed(batch, column):
    embs = embedder.embed_batch(batch[column])
    return {'embedding': embs}

In [11]:
ds = ds.map(embed, batched=True, fn_kwargs={'column': 'passage'}, load_from_cache_file=True)
ds

Map: 100%|██████████| 40221/40221 [03:22<00:00, 198.62 examples/s]


Dataset({
    features: ['passage', 'id', 'embedding'],
    num_rows: 40221
})

In [24]:
ds = ds.add_faiss_index(
    column='embedding',
    string_factory='Flat',
    metric_type=faiss.METRIC_INNER_PRODUCT,
)
ds

100%|██████████| 41/41 [00:00<00:00, 834.52it/s]


Dataset({
    features: ['passage', 'id', 'embedding'],
    num_rows: 40221
})

# Embed questions

In [20]:
test_ds = test_ds.map(embed, batched=True, fn_kwargs={'column': 'question'}, load_from_cache_file=True)
test_ds

Map: 100%|██████████| 4719/4719 [00:03<00:00, 1563.26 examples/s]


Dataset({
    features: ['question', 'answer', 'relevant_passage_ids', 'id', 'embedding'],
    num_rows: 4719
})

In [11]:
test_ds['embedding'][0]

[0.07215375453233719,
 -0.016066282987594604,
 -0.0032892690505832434,
 0.04249546304345131,
 -0.07963531464338303,
 0.00860125944018364,
 -0.03193581849336624,
 0.008533375337719917,
 -0.03716431185603142,
 0.007979831658303738,
 -0.03151911124587059,
 0.022167010232806206,
 -0.09924126416444778,
 0.04778265580534935,
 -0.11697079241275787,
 -0.02087186463177204,
 -0.12253177165985107,
 -0.021098876371979713,
 -0.015072673559188843,
 0.042420029640197754,
 0.0001467171823605895,
 -0.03716767206788063,
 0.0016443022759631276,
 -0.02426200732588768,
 -0.047754351049661636,
 -0.07843972742557526,
 -0.029823824763298035,
 -0.014371898025274277,
 -0.007594732567667961,
 0.030507974326610565,
 -0.0160763431340456,
 -4.933291347697377e-05,
 -0.0009512815740890801,
 0.04093099758028984,
 0.029748782515525818,
 0.04580683633685112,
 0.049364253878593445,
 -0.05096195265650749,
 -0.08573810011148453,
 0.07842497527599335,
 0.010967549867928028,
 -0.02801661565899849,
 -0.0022133474703878164,
 0

In [12]:
import numpy as np

queries = np.asarray(test_ds["embedding"], dtype="float32")
queries.shape

(4719, 384)

In [13]:
i = 2

true_ids = eval(test_ds[i]['relevant_passage_ids'])
len(true_ids)

10

In [14]:
pred_ids = ds.get_nearest_examples("embedding", np.array(test_ds['embedding'][i]), k=5).examples['id']
len(pred_ids)

5

In [15]:
recall_at_top_k = len(set(pred_ids) & set(true_ids)) / len(pred_ids)
recall_at_top_k

0.8

In [16]:
from tqdm import tqdm, trange

n = len(test_ds)
true_ids_len = []
recalls = []
precisions = []
k = 1

for i in trange(n):
    true_ids = eval(test_ds[i]['relevant_passage_ids'])
    pred_ids = ds.get_nearest_examples("embedding", np.array(test_ds['embedding'][i]), k=30).examples['id'][:k]
    recalls.append(len(set(pred_ids) & set(true_ids)) / len(true_ids))
    precisions.append(len(set(pred_ids) & set(true_ids)) / len(pred_ids))

    true_ids_len.append(len(true_ids))


print(f"Recall@{k}: {np.array(recalls).mean():.3f}")
print(f"Precision@{k}: {np.array(precisions).mean():.3f}")
# print(f"True average length: {np.mean(true_ids_len):.3f}")

100%|██████████| 4719/4719 [00:16<00:00, 281.66it/s]

Recall@1: 0.132
Precision@1: 0.560





In [18]:
import numpy as np
import faiss

# index = faiss.IndexFlatIP(embedder.dimension)
index = faiss.IndexFlatL2(embedder.dimension)
index.add(np.array(ds['embedding']))

In [14]:
index

<faiss.swigfaiss_avx512.IndexFlatIP; proxy of <Swig Object of type 'faiss::IndexFlatIP *' at 0x77ac71a8d230> >

In [15]:
index.ntotal

40221

In [21]:
distances, indices = index.search(np.array(test_ds['embedding']), k=5)

In [25]:
distances.dtype

dtype('float32')

In [24]:
indices

array([[22260,   435,  3307, 11166, 21024],
       [23647, 32793, 21714, 40093, 24595],
       [ 7017,  1072,  4422, 39902,  6083],
       ...,
       [38617, 34598, 12269,  8809, 14798],
       [34585, 12552, 39704,  1676, 12296],
       [37276, 26956,  2099, 35832,  5886]], shape=(4719, 5))