In [1]:
!pip install -qU langchain-community langchain-openai rank_bm25 tiktoken faiss-gpu datasets sentence-transformers

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━[0m [32m2.1/2.5 MB[0m [31m63.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m45.8 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/54.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.2/54.2 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m66.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m32.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.schema import Document

from langchain.vectorstores import Chroma, FAISS
from datasets import load_dataset, tqdm

In [3]:
from sentence_transformers import SentenceTransformer
from tqdm import tqdm


class MiniLM:
    """Embedding wrapper around sentence-transformers ``all-MiniLM-L6-v2``.

    Duck-types the ``embed_documents`` / ``embed_query`` interface so an instance
    can be handed to a vector store (e.g. ``FAISS.from_texts``), and is also
    callable so it can be used where a plain embedding function is expected.
    """

    def __init__(self, model_download_path):
        """Load the model, caching the downloaded weights under ``model_download_path``."""
        self.model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=model_download_path)

    def embed_documents(self, docs, verbose=False, batch_size=32):
        """Embed a collection of texts.

        Args:
            docs: iterable of strings to embed.
            verbose: if True, show a progress bar while encoding.
            batch_size: number of texts encoded per forward pass.

        Returns:
            A list with one embedding (``list[float]``) per input text, in input order.
        """
        texts = list(docs)
        if not texts:
            # Nothing to encode; don't call the model with an empty batch.
            return []
        # One batched call instead of one model call per text: lets the model
        # process ``batch_size`` texts per forward pass, which matters when
        # embedding a whole corpus.
        embeddings = self.model.encode(texts, batch_size=batch_size, show_progress_bar=verbose)
        return embeddings.tolist()

    def embed_query(self, query):
        """Embed a single query string.

        Returns the vector produced by ``SentenceTransformer.encode`` unchanged
        (a numpy array — unlike ``embed_documents``, it is not converted to a list).
        """
        return self.model.encode(query)

    def get_id(self):
        """Return a short identifier for this embedding model."""
        return "minilm"

    def __call__(self, query):
        """Allow the instance to be used as a plain embedding function (same as ``embed_query``)."""
        return self.embed_query(query)

In [4]:
ds_qa = load_dataset("enelpol/rag-mini-bioasq", "question-answer-passages")["test"]
ds_corpus = load_dataset("enelpol/rag-mini-bioasq", "text-corpus")["test"]

# The retrievers below return k=3 documents each, so keep only questions that
# have at least that many relevant passages.
MIN_RELEVANT_PASSAGES = 3
ds_qa = ds_qa.filter(lambda x: len(x["relevant_passage_ids"]) >= MIN_RELEVANT_PASSAGES)

ds_qa.to_csv("../test_datasets/rag-mini-bioasq/qa/qa.csv")
# Trailing `;` suppresses the integer returned by `to_csv`, which would
# otherwise be echoed as this cell's output.
ds_corpus.to_csv("../test_datasets/rag-mini-bioasq/corpus/corpus.csv");

README.md:   0%|          | 0.00/1.76k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/1.12M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/187k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4012 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/707 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/35.3M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/40181 [00:00<?, ? examples/s]

Filter:   0%|          | 0/707 [00:00<?, ? examples/s]

Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/41 [00:00<?, ?ba/s]

60169807

In [5]:
# Report how many QA items survived the filter and how many passages the corpus has.
print(
    "Datasets sizes:",
    f"ds_qa: {len(ds_qa)}",
    f"ds_corpus: {len(ds_corpus)}",
    sep="\n",
)

Datasets sizes:
ds_qa: 497
ds_corpus: 40181


In [6]:
# Preview the first row of each split. `select` builds a (at most) 1-row dataset
# first, so the whole corpus is not converted to a DataFrame just to show one row.
print(ds_qa.select(range(min(1, len(ds_qa)))).to_pandas())
print(ds_corpus.select(range(min(1, len(ds_corpus)))).to_pandas())

                                          question  \
0  Describe the mechanism of action of ibalizumab.   

                                              answer    id  \
0  Ibalizumab is a humanized monoclonal antibody ...  2835   

                                relevant_passage_ids  
0  [29675744, 24853313, 29689540, 21289125, 20698...  
                                             passage    id
0  New data on viruses isolated from patients wit...  9797


In [7]:
# Texts to index (prefixed with "passage : ") and their metadata, in corpus order.
# Read straight from the dataset columns instead of `to_pandas().iterrows()`,
# which converts the whole corpus and then builds one Series per row.
doc_list = [f"passage : {passage}" for passage in ds_corpus["passage"]]
metadata = [{"id": passage_id} for passage_id in ds_corpus["id"]]

In [8]:
# Initialize the sparse BM25 retriever over the corpus; it returns the top 3 passages per query.
bm25_retriever = BM25Retriever.from_texts(texts=doc_list, metadatas=metadata, k=3)

In [9]:
# Sanity check: a keyword query against the BM25 index.
bm25_retriever.invoke(input="Titanic")

[Document(metadata={'id': 9987477}, page_content='passage : The Titanic has become a metaphor for the disastrous consequences of an \nunqualified belief in the safety and invincibility of new technology. Similarly, \nthe thalidomide tragedy stands for all of the "monsters" that can be \ninadvertently or negligently created by modern medicine. Thalidomide, once \nbanned, has returned to the center of controversy with the Food and Drug \nAdministration\'s (FDA\'s) announcement that thalidomide will be placed on the \nmarket for the treatment of erythema nodosum leprosum, a severe dermatological \ncomplication of Hansen\'s disease. Although this indication is very restricted, \nthalidomide will be available for off-label uses once it is on the market. New \nlaws regarding abortion and a new technology, ultrasound, make reasonable the \napproval of thalidomide for patients who suffer from serious conditions it can \nalleviate. In addition, the FDA and the manufacturer have proposed the mos

In [10]:
# Inspect the retriever's configuration without echoing the whole indexed corpus.
{"k": bm25_retriever.k, "num_docs": len(bm25_retriever.docs)}

### Dense retriever: MiniLM embeddings indexed with FAISS

In [11]:
# Dense retriever: embed every passage with MiniLM (weights cached in the current
# directory), index the vectors with FAISS, and return the top 3 per query.
faiss_vectorstore = FAISS.from_texts(
    texts=doc_list,
    embedding=MiniLM("."),
    metadatas=metadata,
)
faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs=dict(k=3))

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]



In [12]:
# Sanity check: an off-domain query against the dense index.
faiss_retriever.invoke(input="Dragon")

[Document(id='adf76445-46cd-49f9-9af0-b2051a469535', metadata={'id': 25336862}, page_content='passage : Idelalisib (Zydelig) for certain types of leukemia and lymphoma, peginterferon \nbeta-1a (Plegridy) for relapsing forms of multiple sclerosis, and suvorexant \n(Belsomra) for insomnia.'),
 Document(id='a58abbc5-21bc-4ec7-8b3c-94c0c30f06a6', metadata={'id': 23056472}, page_content="passage : The C. elegans nervous system is particularly well suited for optogenetic \nanalyses of circuit function: Essentially all connections have been mapped, and \nlight can be directed at the neuron of interest in the freely moving, \ntransparent animals, while behavior is observed. Thus, different nodes of a \nneuronal network can be probed for their role in controlling a particular \nbehavior, using different optogenetic tools for photo-activation or -inhibition, \nwhich respond to different colors of light. As neurons may act in concert or in \nopposing ways to affect a behavior, one would further l

In [13]:
# Hybrid retriever: combine the BM25 and FAISS results with equal weights.
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, faiss_retriever],
    weights=[0.5, 0.5],
)

In [14]:
# Sanity check: the BM25 "Titanic" query, now through the hybrid retriever.
docs = ensemble_retriever.invoke(input="Titanic")
docs

[Document(metadata={'id': 9987477}, page_content='passage : The Titanic has become a metaphor for the disastrous consequences of an \nunqualified belief in the safety and invincibility of new technology. Similarly, \nthe thalidomide tragedy stands for all of the "monsters" that can be \ninadvertently or negligently created by modern medicine. Thalidomide, once \nbanned, has returned to the center of controversy with the Food and Drug \nAdministration\'s (FDA\'s) announcement that thalidomide will be placed on the \nmarket for the treatment of erythema nodosum leprosum, a severe dermatological \ncomplication of Hansen\'s disease. Although this indication is very restricted, \nthalidomide will be available for off-label uses once it is on the market. New \nlaws regarding abortion and a new technology, ultrasound, make reasonable the \napproval of thalidomide for patients who suffer from serious conditions it can \nalleviate. In addition, the FDA and the manufacturer have proposed the mos

In [15]:
def test_retrieval(ds_qa, retrieve_function):
    total_num_documents_considered = 0
    num_valid_docs = 0
    for test_item in tqdm(ds_qa):
        question = test_item["question"]
        expected_documents_ids = test_item["relevant_passage_ids"]

        response = retrieve_function(state = {"question": question, "context": []})
        docs_retrieved = response["context"]

        # print("Question:",question)

        # if it's a dict of docs (e.g. with QueryTranslationRAGDecomposition)
        if isinstance(docs_retrieved, dict):
            num_docs_retrieved = 0
            for question, docs in docs_retrieved.items():
                num_docs_retrieved += len(docs)
                for doc in docs:
                    if doc.metadata["id"] in expected_documents_ids:
                        num_valid_docs += 1
            total_num_documents_considered += min(len(expected_documents_ids), num_docs_retrieved)
        else:
            num_documents_considered = min(len(expected_documents_ids), len(docs_retrieved))
            total_num_documents_considered += num_documents_considered
            # print("Expected:",expected_documents_ids,"Got:",[doc.metadata["id"] for doc in docs_retrieved])
            # print("Expected:",expected_documents_ids)
            for doc in docs_retrieved:
                # print("Got:",doc.metadata["id"])
                if int(doc.metadata["id"]) in expected_documents_ids:
                    num_valid_docs += 1


    return num_valid_docs / total_num_documents_considered

In [23]:
def retrieve(state):
    """Fill ``state["context"]`` with the hybrid retriever's documents for ``state["question"]``.

    The dict is updated in place and the same object is returned.
    """
    question = state["question"]
    retrieved = ensemble_retriever.invoke(question)
    state["context"] = retrieved
    return state

In [24]:
# Evaluate the hybrid retriever on every filtered QA item.
score = test_retrieval(ds_qa=ds_qa, retrieve_function=retrieve)

100%|██████████| 497/497 [01:26<00:00,  5.73it/s]


In [25]:
# Report the aggregated retrieval score computed above.
score_message = f"RAG score: {score}"
print(score_message)

RAG score: 0.553484602917342
