In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append('../../')

In [3]:
from src.indexing import get_multivector_retriever, get_parent_child_splits
from src.generation import QA_SYSTEM_PROMPT, QA_PROMPT, LLAMA_PROMPT_TEMPLATE, MIXTRAL_PROMPT_TEMPLATE
from src.generation import get_model, format_docs, get_rag_chain
from langchain_core.documents import Document

from src.ingestion import load_pdf

import os
import chromadb
import uuid
import pickle

In [4]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

from tqdm import tqdm

In [15]:
DATA_PATH = 'D:\Ahmed\saudi-rag-project\storage'
RAW_DOCS_PATH = os.path.join('..\..\data', "raw")
CHROMA_PATH = os.path.join(DATA_PATH, "chroma")
INTERIM_DATA_PATH = os.path.join('..\..\data', "interim")

EMBEDDING_MODEL_NAMES = [
    "intfloat/multilingual-e5-small", 
    "intfloat/multilingual-e5-base", 
    "text-embedding-3-small", 
    "text-embedding-3-large",
    "text-embedding-ada-002"
 ]
MODEL_NAMES = ["meta-llama/Llama-3-8b-chat-hf", "meta-llama/Llama-3-70b-chat-hf", "mistralai/Mixtral-8x22B-Instruct-v0.1"]

  DATA_PATH = 'D:\Ahmed\saudi-rag-project\storage'
  RAW_DOCS_PATH = os.path.join('..\..\data', "raw")
  INTERIM_DATA_PATH = os.path.join('..\..\data', "interim")


In [16]:
docs = [load_pdf(os.path.join(RAW_DOCS_PATH, f)) for f in os.listdir(RAW_DOCS_PATH) if ".pdf" in f]

In [17]:
# print(docs[1].page_content)

In [18]:
persistent_client = chromadb.PersistentClient(path=CHROMA_PATH)

In [19]:
parent_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1200,
    chunk_overlap=400,
    separators=['\n\n\n', '\n\n', '\n', r'\.\s+', ' ', '']
)

parent_docs = parent_splitter.split_documents(docs)
# parent_docs_ids = [str(uuid.uuid4()) for _ in parent_docs]
# pickle.dump(parent_docs_ids, open(os.path.join(INTERIM_DATA_PATH, "parent_docs_ids"), 'wb'))
parent_docs_ids = pickle.load(open(os.path.join(INTERIM_DATA_PATH, "parent_docs_ids"), 'rb'))
id_key = "parent_doc_id"

In [20]:
COLLECTIONS = [
    "PQ_SPLIT_ALL_text_embedding_3_small",
    "PS_ALL_text_embedding_3_small"
]

In [21]:
all_questions_docs = []
all_split_questions_docs = []
all_summaries_docs = []

for model_name in tqdm(MODEL_NAMES):

    question_file = f"PQ_COMB_{model_name.split('/')[-1].replace('-', '_').replace('/', '_')}"
    question_file_path = os.path.join(INTERIM_DATA_PATH, question_file)

    split_question_file = f"PQ_SPLIT_{model_name.split('/')[-1].replace('-', '_').replace('/', '_')}"
    split_question_file_path = os.path.join(INTERIM_DATA_PATH, split_question_file)

    summary_file = f"PS_{model_name.split('/')[-1].replace('-', '_').replace('/', '_')}"
    summary_file_path = os.path.join(INTERIM_DATA_PATH, summary_file)

    question_docs = pickle.load(open(question_file_path, 'rb'))
    split_question_docs = pickle.load(open(split_question_file_path, 'rb'))
    summary_docs = pickle.load(open(summary_file_path, 'rb'))

    all_questions_docs.extend(question_docs)
    all_split_questions_docs.extend(split_question_docs)
    all_summaries_docs.extend(summary_docs)

    # for embedding_model_name in EMBEDDING_MODEL_NAMES:

    #     if "mistral" in model_name:
    #         model_name = model_name.split('/')[-1].replace('-', '_').replace('/', '_')[:13]
    #     else:
    #         model_name = model_name.split('/')[-1].replace('-', '_').replace('/', '_')

    #     questions_collection_name = f"PQ_COMB_{model_name}_{embedding_model_name.replace("-", "_").replace("/", "_")}"
    #     split_questions_collection_name = f"PQ_SPLIT_{model_name}_{embedding_model_name.replace("-", "_").replace("/", "_")}"
    #     summary_collection_name = f"PS_{model_name}_{embedding_model_name.replace("-", "_").replace("/", "_")}"

    #     if questions_collection_name not in persistent_client.list_collections():
    #         print("Creating", questions_collection_name)
    #         _ = get_multivector_retriever(persistent_client, embedding_model_name, questions_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=question_docs, id_key="parent_doc_id")
        
    #     if split_questions_collection_name not in persistent_client.list_collections():
    #         print("Creating", split_questions_collection_name)
    #         _ = get_multivector_retriever(persistent_client, embedding_model_name, split_questions_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=split_question_docs, id_key="parent_doc_id")
        
    #     if summary_collection_name not in persistent_client.list_collections():
    #         print("Creating", summary_collection_name)
    #         _ = get_multivector_retriever(persistent_client, embedding_model_name, summary_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=summary_docs, id_key="parent_doc_id")

100%|██████████| 3/3 [00:00<00:00, 42.66it/s]


In [22]:
for embedding_model_name in tqdm(EMBEDDING_MODEL_NAMES):

    all_questions_collection_name = f"PQ_COMB_ALL_{embedding_model_name.replace("-", "_").replace("/", "_")}"
    all_split_questions_collection_name = f"PQ_SPLIT_ALL_{embedding_model_name.replace("-", "_").replace("/", "_")}"
    all_summary_collection_name = f"PS_ALL_{embedding_model_name.replace("-", "_").replace("/", "_")}"
    all_collection_name = f"PQS_ALL_{embedding_model_name.replace("-", "_").replace("/", "_")}"

    if all_questions_collection_name in persistent_client.list_collections():
        persistent_client.delete_collection(all_questions_collection_name)

    if all_split_questions_collection_name in persistent_client.list_collections():
        persistent_client.delete_collection(all_split_questions_collection_name)

    if all_summary_collection_name in persistent_client.list_collections():
        persistent_client.delete_collection(all_summary_collection_name)

    if all_collection_name in persistent_client.list_collections():
        persistent_client.delete_collection(all_collection_name)

    all_generated_docs = all_questions_docs + all_split_questions_docs + all_summaries_docs

    if all_questions_collection_name in COLLECTIONS:
        _ = get_multivector_retriever(persistent_client, embedding_model_name, all_questions_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=all_questions_docs, id_key="parent_doc_id")

    if all_split_questions_collection_name in COLLECTIONS: 
        _ = get_multivector_retriever(persistent_client, embedding_model_name, all_split_questions_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=all_split_questions_docs, id_key="parent_doc_id")
    
    if all_summary_collection_name in COLLECTIONS:
        _ = get_multivector_retriever(persistent_client, embedding_model_name, all_summary_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=all_summaries_docs, id_key="parent_doc_id")
    
    if all_collection_name in COLLECTIONS:
        _ = get_multivector_retriever(persistent_client, embedding_model_name, all_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=all_generated_docs, id_key="parent_doc_id")

  warn_deprecated(
100%|██████████| 5/5 [00:51<00:00, 10.32s/it]


In [None]:
# parent_docs_ids