In [1]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma  import Chroma
from nlp_chat_bot.model.minilm import MiniLM
from nlp_chat_bot.rag import RAG

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Configuration -----------------------------------------------------------
# Relative paths: resolved against the notebook's working directory.
dataset_path = "../data"
model_download_path = "../models"

# --- Build the retrieval pipeline --------------------------------------------
splitter = RecursiveCharacterTextSplitter(
    chunk_size=100,  # chunk size (characters)
    chunk_overlap=10,  # chunk overlap (characters)
    add_start_index=True,  # track index in original document
)

minilm = MiniLM(model_download_path=model_download_path)
vector_store = Chroma(embedding_function=minilm)
# NOTE(review): RAG presumably loads and indexes the dataset into `vector_store`
# on construction -- confirm against nlp_chat_bot.rag.
rag = RAG(dataset_path, vector_store, splitter)

# Sanity check: how many chunks the vector store currently holds.
print("LENGTH", len(vector_store.get()['documents']))

# --- Retrieve ----------------------------------------------------------------
docs_retrieved = rag.retrieve(state = {"question": "What is the acronym AIA?", "context": []})

# Bind the retrieved list once instead of re-indexing the result dict each time.
retrieved_context = docs_retrieved["context"]
print("Num docs:", len(retrieved_context))

# --- Display -----------------------------------------------------------------
for i, doc in enumerate(retrieved_context):
    print("\n\n", "#"*30,"\n")
    # NOTE(review): whether 'score' is a distance or a similarity is not visible
    # here -- check RAG.retrieve before interpreting the ordering.
    print(f"doc {i}: (score: {doc.metadata['score']})")
    print(doc.page_content)

# Echo the first retrieved document; guard against an empty retrieval, which
# would otherwise raise IndexError on `[0]`.
if retrieved_context:
    print(retrieved_context[0].page_content)
else:
    print("No documents retrieved.")



0it [00:00, ?it/s]
0it [00:00, ?it/s]
100%|██████████| 2/2 [00:03<00:00,  1.72s/it]
0it [00:00, ?it/s]


LENGTH 847
Num docs: 3


 ############################## 

doc 0: (score: 0.9017001390457153)
AAAI – (i) Association for the Advancement of Artificial Intelligence, formerly American


 ############################## 

doc 1: (score: 0.9144654273986816)
AP

AQ

AR

AS

AT

AU

AV

AW

AX

AY

AZ

List of acronyms: A

1 language


 ############################## 

doc 2: (score: 0.9341943264007568)
AIAA – (i) American Institute of Aeronautics and Astronautics

AIC

(i) African Independent Church
AAAI – (i) Association for the Advancement of Artificial Intelligence, formerly American
