In [3]:
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

from nlp_chat_bot.doc_loader.test_data_csv_loader import TestDataCSVLoader
from nlp_chat_bot.model.minilm import MiniLM
from nlp_chat_bot.rag.classic_rag import ClassicRAG

In [5]:
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment; the
# displayed return value (True) is the library's report that it found and loaded one.
# NOTE(review): presumably this provides the API key used by the Gemini LLM
# instantiated later in the notebook — confirm.
load_dotenv()

True

In [1]:
from datasets import load_dataset

# Test split of the question/answer configuration, restricted to questions that
# reference at least 3 relevant passages (3 is the number we use by default).
ds_qa = load_dataset("enelpol/rag-mini-bioasq", "question-answer-passages")["test"].filter(
    lambda row: len(row["relevant_passage_ids"]) >= 3
)

# Test split of the text corpus (the passages themselves).
ds_corpus = load_dataset("enelpol/rag-mini-bioasq", "text-corpus")["test"]

# Export both datasets to CSV. The last call stays the final expression of the
# cell so that its return value is still shown as the cell output.
ds_qa.to_csv("../test_datasets/rag-mini-bioasq/qa/qa.csv")
ds_corpus.to_csv("../test_datasets/rag-mini-bioasq/corpus/corpus.csv")

  from .autonotebook import tqdm as notebook_tqdm
Creating CSV from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 52.31ba/s]
Creating CSV from Arrow format: 100%|██████████| 41/41 [00:00<00:00, 49.62ba/s]


60209989

In [4]:
# Number of QA rows remaining after the ">= 3 relevant passages" filter.
ds_qa.num_rows

497

In [5]:
# Show the first row of each dataset (QA first, then corpus) to check the columns.
print(
    ds_qa.to_pandas().head(1),
    ds_corpus.to_pandas().head(1),
    sep="\n",
)

                                          question  \
0  Describe the mechanism of action of ibalizumab.   

                                              answer    id  \
0  Ibalizumab is a humanized monoclonal antibody ...  2835   

                                relevant_passage_ids  
0  [29675744, 24853313, 29689540, 21289125, 20698...  
                                             passage    id
0  New data on viruses isolated from patients wit...  9797


In [10]:
from nlp_chat_bot.model.late_chunking_embedding import LateChunkingEmbedding

# Locations used by the run (relative to the notebook's working directory).
corpus_path = "../test_datasets/rag-mini-bioasq/corpus"
vector_store_path = "../test_chromadb"
model_download_path = "../models"

# Declarative description of the components under test: each entry names the
# class to instantiate and, where relevant, the keyword arguments to pass to it.
test_params = {
    "splitter": {
        "class": RecursiveCharacterTextSplitter,
        "params": {
            "chunk_size": 1000,  # chunk size (characters)
            "chunk_overlap": 50,  # chunk overlap (characters)
            "add_start_index": True,  # track index in original document
        }
    },
    "embedding_function": {
        "class": LateChunkingEmbedding
    },
    "llm": {
        "class": ChatGoogleGenerativeAI,
        "params": {
            "model": "gemini-1.5-flash"
        }
    },
    "rag": {
        "class": ClassicRAG
    }
}

# No splitter is handed to the RAG in this run (late_chunking=True is used instead).
# To try character splitting, build one from the config:
#   splitter = test_params["splitter"]["class"](**test_params["splitter"]["params"])
splitter = None

embedding_function = test_params["embedding_function"]["class"](model_download_path=model_download_path)
# Forward every configured LLM parameter, so additions to test_params are honoured.
# (The original line also bound a stray `model_name` global to this object.)
llm_gemini = test_params["llm"]["class"](**test_params["llm"]["params"])
rag = test_params["rag"]["class"](
    corpus_path,
    embedding_function,
    vector_store_path,
    splitter,
    llm=llm_gemini,
    late_chunking=True,
    document_loader=TestDataCSVLoader(),
)

# NOTE(review): this question is about an image-inpainting project report, while
# corpus_path points at the rag-mini-bioasq corpus — confirm it is the intended query.
docs_retrieved = rag.retrieve(
    state={"question": "What is my conclusion in my project report on image inpainting?", "context": []}
)

# Print every retrieved document with its rank and retrieval score.
for i, doc in enumerate(docs_retrieved["context"]):
    print("\n\n", "#"*30,"\n")
    print(f"doc {i}: (score: {doc.metadata['score']})")
    print(doc.page_content)

100%|██████████| 1/1 [00:01<00:00,  1.05s/it]

KeyboardInterrupt

