In [1]:
# Imports for the retrieval evaluation of the nlp_chat_bot RAG pipelines on enelpol/rag-mini-bioasq.
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

from nlp_chat_bot.doc_loader.test_data_csv_loader import TestDataCSVLoader
# NOTE(review): duplicate of the two imports above; harmless (rebinds the same names).
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Project components. Some are alternatives that can be swapped into `test_params`
# (e.g. Gemma, the query-translation RAGs, late chunking) and are not all used by the
# current configuration.
from nlp_chat_bot.model.embedding.minilm import MiniLM
from nlp_chat_bot.model.llm.gemma import Gemma
from nlp_chat_bot.rag.classic_rag import ClassicRAG
from nlp_chat_bot.rag.query_translation_rag_decomposition import QueryTranslationRAGDecomposition
from nlp_chat_bot.rag.query_translation_rag_fusion import QueryTranslationRAGFusion
from nlp_chat_bot.vector_store.late_chunking_chroma_vector_store_builder import LateChunkingChromaVectorStoreBuilder
from nlp_chat_bot.vector_store.naive_chunking_chroma_vector_store_builder import NaiveChunkingChromaVectorStoreBuilder
# NOTE(review): `tqdm` is taken from the `datasets` package namespace; it is used for the
# progress bar in test_retrieval.
from datasets import load_dataset, tqdm
from dotenv import load_dotenv
# Load environment variables from a .env file. Presumably this provides the API key that
# ChatGoogleGenerativeAI needs — confirm.
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
import os

# Retrieval depth used by default in our RAG pipelines.
MIN_RELEVANT_PASSAGES = 3
QA_CSV_PATH = "../test_datasets/rag-mini-bioasq/qa/qa.csv"
CORPUS_CSV_PATH = "../test_datasets/rag-mini-bioasq/corpus/corpus.csv"

ds_qa = load_dataset("enelpol/rag-mini-bioasq", "question-answer-passages")["test"]
ds_corpus = load_dataset("enelpol/rag-mini-bioasq", "text-corpus")["test"]

# Only keep questions with at least MIN_RELEVANT_PASSAGES relevant passages, because the
# retriever returns that many documents by default in our case.
ds_qa = ds_qa.filter(lambda x: len(x["relevant_passage_ids"]) >= MIN_RELEVANT_PASSAGES)

# Make sure the output folders exist so the export also works on a fresh checkout.
for csv_path in (QA_CSV_PATH, CORPUS_CSV_PATH):
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)

ds_qa.to_csv(QA_CSV_PATH)
# Trailing ";" hides the number of bytes written that to_csv returns.
ds_corpus.to_csv(CORPUS_CSV_PATH);

Creating CSV from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 31.17ba/s]
Creating CSV from Arrow format: 100%|██████████| 41/41 [00:00<00:00, 52.05ba/s]


60209989

In [3]:
# Report how many rows each dataset holds (QA split after filtering, full corpus).
print("Datasets sizes:", f"ds_qa: {len(ds_qa)}", f"ds_corpus: {len(ds_corpus)}", sep="\n")

Datasets sizes:
ds_qa: 497
ds_corpus: 40181


In [4]:
# Peek at the first row of each dataset. Select the row before converting to pandas so the
# whole corpus is not materialized as a DataFrame just to show a single line.
print(ds_qa.select(range(1)).to_pandas())
print(ds_corpus.select(range(1)).to_pandas())

                                          question  \
0  Describe the mechanism of action of ibalizumab.   

                                              answer    id  \
0  Ibalizumab is a humanized monoclonal antibody ...  2835   

                                relevant_passage_ids  
0  [29675744, 24853313, 29689540, 21289125, 20698...  
                                             passage    id
0  New data on viruses isolated from patients wit...  9797


In [47]:
from nlp_chat_bot.model.embedding.late_chunking_embedding import LateChunkingEmbedding


corpus_path = "../test_datasets/rag-mini-bioasq/corpus"
vector_store_path = "../test_chromadb"
model_download_path = "../models"
# Flags forwarded to the vector store builder's build(reload, reset) call below.
# With both False, build() neither adds nor removes documents, i.e. the existing store is
# presumably reused as-is — confirm in the builder.
reload_vector_store = False # set True to add non existing documents
reset_vector_store = False # set True to remove previous documents

test_params = {
    "splitter": {
        "class": RecursiveCharacterTextSplitter,
        "params": {
            "chunk_size": 1000,  # chunk size (characters)
            "chunk_overlap": 0,  # chunk overlap (characters)
            "add_start_index": True,  # track index in original document
        }
    },
    "embedding_function": {
        "class": MiniLM
    },
    "llm": {
        "class": ChatGoogleGenerativeAI,
        "params": {
            "model": "gemini-1.5-flash"
        }
        # "class": Gemma,
        # "params": {
        #     "model_download_path": model_download_path
        # }
    },
    "rag": {
        "class": ClassicRAG
    },
    "vector_store_builder": {
        "class": NaiveChunkingChromaVectorStoreBuilder
    }
}
# splitter = None
splitter = test_params["splitter"]["class"](**test_params["splitter"]["params"])

embedding_function = test_params["embedding_function"]["class"](model_download_path=model_download_path)
vector_store = test_params["vector_store_builder"]["class"](
    corpus_path,
    embedding_function,
    vector_store_path,
    splitter,
    document_loader=TestDataCSVLoader(),
).build(reload_vector_store, reset_vector_store)
llm = test_params["llm"]["class"](**test_params["llm"]["params"])
rag = test_params["rag"]["class"](vector_store, llm=llm)



In [48]:
# test if our Llama installation supports GPU

# import os
# from llama_cpp.llama_cpp import load_shared_library
# import llama_cpp
# 
# llama_root_path_module = os.path.dirname(llama_cpp.__file__)
# import pathlib
# 
# def is_gpu_available_v3() -> bool:
# 
#     lib = load_shared_library('llama',pathlib.Path(llama_root_path_module+'/lib'))
#     return bool(lib.llama_supports_gpu_offload())
# 
# print(is_gpu_available_v3())

In [49]:
def test_retrieval(ds_qa, retrieve_function):
    total_num_documents_considered = 0
    num_valid_docs = 0
    for test_item in tqdm(ds_qa):
        question = test_item["question"]
        expected_documents_ids = test_item["relevant_passage_ids"]
        
        response = retrieve_function(state = {"question": question, "context": []})
        docs_retrieved = response["context"]
        num_documents_considered = min(len(expected_documents_ids), len(docs_retrieved))
        
        
        # print("Question:",question)
        
        # if it's a dict of docs (e.g. with QueryTranslationRAGDecomposition)
        if isinstance(docs_retrieved, dict):
            for question, docs in docs_retrieved.items():
                total_num_documents_considered += len(docs)
                for doc in docs:
                    if doc.metadata["id"] in expected_documents_ids:
                        num_valid_docs += 1
        else:
            total_num_documents_considered += num_documents_considered
            # print("Expected:",expected_documents_ids,"Got:",[doc.metadata["id"] for doc in docs_retrieved])
            # print("Expected:",expected_documents_ids)
            for doc in docs_retrieved:
                # print("Got:",doc.metadata["id"])
                if int(doc.metadata["id"]) in expected_documents_ids:
                    num_valid_docs += 1
        
        
    return num_valid_docs / total_num_documents_considered

In [50]:
# Evaluate the configured RAG pipeline's retriever on the filtered QA split.
score = test_retrieval(ds_qa=ds_qa, retrieve_function=rag.retrieve)

100%|██████████| 497/497 [00:03<00:00, 161.19it/s]


In [51]:
# Report the retrieval score computed above.
print("RAG score:", score)

RAG score: 0.5372233400402414
