In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Add the directory two levels above the notebook to the import path; the
# `src.*` imports in the setup cell rely on it (presumably the repository
# root - confirm against the project layout).
sys.path.append('../../')

In [9]:
from src.indexing import get_multivector_retriever, get_parent_child_splits
from src.generation import QA_SYSTEM_PROMPT, QA_PROMPT, LLAMA_PROMPT_TEMPLATE, MIXTRAL_PROMPT_TEMPLATE
from src.generation import get_model, format_docs, get_rag_chain
from langchain_core.documents import Document

from src.ingestion import load_pdf

import os
import chromadb
import uuid
import pickle

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

from tqdm import tqdm

# Storage locations used by the rest of the notebook.
#
# DATA_PATH is a Windows path and is therefore written as a raw string: in a
# normal literal `\A` and `\s` are invalid escape sequences, which Python 3.12
# reports as a SyntaxWarning (and which may become an error in the future).
# The resulting value is unchanged.
# NOTE(review): this is a machine-specific absolute path; anyone running the
# notebook elsewhere has to edit it.
DATA_PATH = r'D:\Ahmed\saudi-rag-project\storage'
# Raw input PDFs and intermediate pickles live relative to the notebook.
RAW_DOCS_PATH = os.path.join('../../data', "raw")
# Directory handed to the persistent Chroma client.
CHROMA_PATH = os.path.join(DATA_PATH, "chroma")
INTERIM_DATA_PATH = os.path.join('../../data', "interim")

# Embedding models to build collections for. Only the uncommented entry is
# active; the other candidates are kept so they can be switched back on.
EMBEDDING_MODEL_NAMES = [
    "text-embedding-3-large",
    # "intfloat/multilingual-e5-small",
    # "intfloat/multilingual-e5-base",
    # "text-embedding-3-small",
    # "text-embedding-ada-002",
]

# Model identifiers; their last path component is used to locate the
# per-model pickles and to name the per-model collections.
MODEL_NAMES = [
    "meta-llama/Llama-3-8b-chat-hf",
    "meta-llama/Llama-3-70b-chat-hf",
    "mistralai/Mixtral-8x22B-Instruct-v0.1",
]

# Allow-list: only collections whose name appears here are built.
COLLECTIONS = [
    "PQ_COMB_Llama_3_70b_chat_hf_text_embedding_3_large",
    "PQ_COMB_S_Llama_3_70b_chat_hf_text_embedding_3_large",
]

# Load every PDF found directly inside RAW_DOCS_PATH, one load_pdf call per file.
# The filter is a suffix check: the previous substring test (".pdf" in name)
# also accepted names such as "report.pdf.bak".
# NOTE(review): os.listdir returns entries in arbitrary order, while
# parent_docs_ids is loaded from a pickle and passed alongside the chunks
# derived from `docs`. If the two are matched by position, the listing order
# must be the same as when the ids were generated - confirm before sorting.
docs = [
    load_pdf(os.path.join(RAW_DOCS_PATH, file_name))
    for file_name in os.listdir(RAW_DOCS_PATH)
    if file_name.endswith(".pdf")
]
# Chroma client persisted under CHROMA_PATH.
persistent_client = chromadb.PersistentClient(path=CHROMA_PATH)

# Split the loaded documents into overlapping "parent" chunks.
# NOTE(review): r'\.\s+' looks like a regular expression, but the splitter is
# created without `is_separator_regex=True`, so it is presumably matched as a
# literal string - confirm. Do not change this casually: different chunk
# boundaries would no longer line up with the pickled parent_docs_ids.
parent_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1200,
    chunk_overlap=400,
    separators=['\n\n\n', '\n\n', '\n', r'\.\s+', ' ', '']
)

parent_docs = parent_splitter.split_documents(docs)

# The ids were generated once with the two commented lines below and are
# reloaded from disk on every run so that they stay stable.
# parent_docs_ids = [str(uuid.uuid4()) for _ in parent_docs]
# pickle.dump(parent_docs_ids, open(os.path.join(INTERIM_DATA_PATH, "parent_docs_ids"), 'wb'))

# Use a context manager so the file handle is closed right after loading.
# NOTE: pickle can execute arbitrary code - only load files produced by this project.
with open(os.path.join(INTERIM_DATA_PATH, "parent_docs_ids"), 'rb') as ids_file:
    parent_docs_ids = pickle.load(ids_file)

# Metadata key under which a child document stores the id of its parent chunk.
id_key = "parent_doc_id"

# Load the pickled child documents of every model, pool them into the
# `all_*` lists, and call get_multivector_retriever for each per-model
# collection whose name is listed in COLLECTIONS.
all_questions_docs = []
all_split_questions_docs = []
all_summaries_docs = []

for model_name in tqdm(MODEL_NAMES):

    # Last path component with '-' replaced by '_'. A '/' cannot survive the
    # split, so no further replacement is needed.
    model_slug = model_name.split('/')[-1].replace('-', '_')

    # Files are named <prefix>_<model_slug> inside INTERIM_DATA_PATH.
    # Each file is opened in a `with` block so its handle is closed promptly.
    # NOTE: pickle can execute arbitrary code - only load files produced by this project.
    per_model_docs = {}
    for file_prefix in ("PQ_COMB", "PQ_SPLIT", "PS"):
        pickle_path = os.path.join(INTERIM_DATA_PATH, f"{file_prefix}_{model_slug}")
        with open(pickle_path, 'rb') as pickle_file:
            per_model_docs[file_prefix] = pickle.load(pickle_file)

    question_docs = per_model_docs["PQ_COMB"]
    split_question_docs = per_model_docs["PQ_SPLIT"]
    summary_docs = per_model_docs["PS"]

    all_questions_docs.extend(question_docs)
    all_split_questions_docs.extend(split_question_docs)
    all_summaries_docs.extend(summary_docs)

    # Running union over all models processed so far. It is still assigned so
    # the name stays available, but it is no longer used for the per-model
    # collections below.
    all_questions_summaries_docs = all_questions_docs + all_summaries_docs

    # Collection names use the slug; for the "mistral" model it is cut to its
    # first 13 characters. This is computed once in a separate name instead of
    # overwriting the loop variable `model_name`.
    collection_model_tag = model_slug[:13] if "mistral" in model_name else model_slug

    for embedding_model_name in EMBEDDING_MODEL_NAMES:

        embedding_slug = embedding_model_name.replace("-", "_").replace("/", "_")
        collection_suffix = f"{collection_model_tag}_{embedding_slug}"

        # Collection name -> child documents of the *current* model.
        # Fixes relative to the previous version:
        #   * PS_* receives the summaries only (it used to get the accumulated
        #     questions + summaries of every model seen so far);
        #   * PQ_COMB_S_* receives this model's questions + summaries instead
        #     of the cross-model running union;
        #   * the progress message shows the collection actually being built.
        children_by_collection = {
            f"PQ_COMB_{collection_suffix}": question_docs,
            f"PQ_SPLIT_{collection_suffix}": split_question_docs,
            f"PS_{collection_suffix}": summary_docs,
            f"PQ_COMB_S_{collection_suffix}": list(question_docs) + list(summary_docs),
        }

        for collection_name, child_docs in children_by_collection.items():
            if collection_name not in COLLECTIONS:
                continue
            print("Creating", collection_name)
            # The returned retriever is not needed here. The result is not
            # bound to `_`, so IPython's last-output variable is not shadowed.
            get_multivector_retriever(
                persistent_client,
                embedding_model_name,
                collection_name,
                DATA_PATH,
                parent_docs=parent_docs,
                parent_docs_ids=parent_docs_ids,
                child_docs=child_docs,
                id_key="parent_doc_id",
            )

# Holder for retrievers. It used to be reset inside the loop above on every
# iteration; it is initialised once here instead.
x = []

# Derive the names of the cross-model ("ALL") collections for each embedding
# model. Only the names are computed in the active code; the statements that
# would use them are commented out further down in this loop body.
for embedding_model_name in tqdm(EMBEDDING_MODEL_NAMES):
    print(embedding_model_name)

    # '-' and '/' are both mapped to '_' to form the name suffix.
    embedding_slug = embedding_model_name.replace("-", "_").replace("/", "_")

    all_questions_collection_name = "PQ_COMB_ALL_" + embedding_slug
    all_split_questions_collection_name = "PQ_SPLIT_ALL_" + embedding_slug
    all_summary_collection_name = "PS_ALL_" + embedding_slug
    all_collection_name = "PQS_ALL_" + embedding_slug

    # if all_questions_collection_name in persistent_client.list_collections():
    #     persistent_client.delete_collection(all_questions_collection_name)

    # if all_split_questions_collection_name in persistent_client.list_collections():
    #     persistent_client.delete_collection(all_split_questions_collection_name)

    # if all_summary_collection_name in persistent_client.list_collections():
    #     persistent_client.delete_collection(all_summary_collection_name)

    # if all_collection_name in persistent_client.list_collections():
    #     persistent_client.delete_collection(all_collection_name)

    # all_generated_docs = all_questions_docs + all_split_questions_docs + all_summaries_docs

    # print(all_questions_collection_name)
    # r = get_multivector_retriever(persistent_client, embedding_model_name, all_questions_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=all_questions_docs, id_key="parent_doc_id")
    # x.append(r)

    # if all_split_questions_collection_name in COLLECTIONS: 
    # print(all_split_questions_collection_name)
    # r = get_multivector_retriever(persistent_client, embedding_model_name, all_split_questions_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=all_split_questions_docs, id_key="parent_doc_id")
    # x.append(r)
    # print(all_summary_collection_name)
    # r = get_multivector_retriever(persistent_client, embedding_model_name, all_summary_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=all_summaries_docs, id_key="parent_doc_id")
    # x.append(r)

    # if all_collection_name in COLLECTIONS:
    #     _ = get_multivector_retriever(persistent_client, embedding_model_name, all_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=all_generated_docs, id_key="parent_doc_id")

  DATA_PATH = 'D:\Ahmed\saudi-rag-project\storage'


  DATA_PATH = 'D:\Ahmed\saudi-rag-project\data\storage'
  RAW_DOCS_PATH = os.path.join('..\..\data', "raw")
  INTERIM_DATA_PATH = os.path.join('..\..\data', "interim")


In [24]:
# print(docs[1].page_content)

In [7]:
# Re-create the persistent Chroma client on CHROMA_PATH (same call as in the setup cell).
persistent_client = chromadb.PersistentClient(path=CHROMA_PATH)

In [8]:
# Show the collections currently held by the persistent client.
persistent_client.list_collections()

[]

In [27]:
# Drop the two "ALL" collections built with text-embedding-3-small, in this order.
for stale_collection_name in (
    'PS_ALL_text_embedding_3_small',
    'PQ_SPLIT_ALL_text_embedding_3_small',
):
    persistent_client.delete_collection(stale_collection_name)

In [10]:
# Split the loaded documents into overlapping "parent" chunks (same settings
# as in the setup cell).
# NOTE(review): r'\.\s+' looks like a regular expression, but the splitter is
# created without `is_separator_regex=True`, so it is presumably matched as a
# literal string - confirm. Changing it would alter chunk boundaries, which
# would no longer line up with the pickled parent_docs_ids.
parent_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1200,
    chunk_overlap=400,
    separators=['\n\n\n', '\n\n', '\n', r'\.\s+', ' ', '']
)

parent_docs = parent_splitter.split_documents(docs)

# The ids were generated once with the two commented lines below and are
# reloaded from disk on every run so that they stay stable.
# parent_docs_ids = [str(uuid.uuid4()) for _ in parent_docs]
# pickle.dump(parent_docs_ids, open(os.path.join(INTERIM_DATA_PATH, "parent_docs_ids"), 'wb'))

# Use a context manager so the file handle is closed right after loading.
# NOTE: pickle can execute arbitrary code - only load files produced by this project.
with open(os.path.join(INTERIM_DATA_PATH, "parent_docs_ids"), 'rb') as ids_file:
    parent_docs_ids = pickle.load(ids_file)

# Metadata key under which a child document stores the id of its parent chunk.
id_key = "parent_doc_id"

In [11]:
# Allow-list override for this run. Expands to:
#   PQ_SPLIT_ALL_text_embedding_3_small
#   PS_ALL_text_embedding_3_small
COLLECTIONS = [
    f"{collection_prefix}_ALL_text_embedding_3_small"
    for collection_prefix in ("PQ_SPLIT", "PS")
]

In [12]:
# Load the pickled child documents of every model in MODEL_NAMES and pool
# them into three cross-model lists.
all_questions_docs = []
all_split_questions_docs = []
all_summaries_docs = []

for model_name in tqdm(MODEL_NAMES):

    # Last path component with '-' replaced by '_'. A '/' cannot survive the
    # split, so the extra replace('/', '_') was redundant.
    model_slug = model_name.split('/')[-1].replace('-', '_')

    question_file_path = os.path.join(INTERIM_DATA_PATH, f"PQ_COMB_{model_slug}")
    split_question_file_path = os.path.join(INTERIM_DATA_PATH, f"PQ_SPLIT_{model_slug}")
    summary_file_path = os.path.join(INTERIM_DATA_PATH, f"PS_{model_slug}")

    # Context managers close each file as soon as it has been read; the bare
    # pickle.load(open(...)) form left the handles open.
    # NOTE: pickle can execute arbitrary code - only load files produced by this project.
    with open(question_file_path, 'rb') as pickle_file:
        question_docs = pickle.load(pickle_file)
    with open(split_question_file_path, 'rb') as pickle_file:
        split_question_docs = pickle.load(pickle_file)
    with open(summary_file_path, 'rb') as pickle_file:
        summary_docs = pickle.load(pickle_file)

    all_questions_docs.extend(question_docs)
    all_split_questions_docs.extend(split_question_docs)
    all_summaries_docs.extend(summary_docs)

    # for embedding_model_name in EMBEDDING_MODEL_NAMES:

    #     if "mistral" in model_name:
    #         model_name = model_name.split('/')[-1].replace('-', '_').replace('/', '_')[:13]
    #     else:
    #         model_name = model_name.split('/')[-1].replace('-', '_').replace('/', '_')

    #     questions_collection_name = f"PQ_COMB_{model_name}_{embedding_model_name.replace("-", "_").replace("/", "_")}"
    #     split_questions_collection_name = f"PQ_SPLIT_{model_name}_{embedding_model_name.replace("-", "_").replace("/", "_")}"
    #     summary_collection_name = f"PS_{model_name}_{embedding_model_name.replace("-", "_").replace("/", "_")}"

    #     if questions_collection_name not in persistent_client.list_collections():
    #         print("Creating", questions_collection_name)
    #         _ = get_multivector_retriever(persistent_client, embedding_model_name, questions_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=question_docs, id_key="parent_doc_id")
        
    #     if split_questions_collection_name not in persistent_client.list_collections():
    #         print("Creating", split_questions_collection_name)
    #         _ = get_multivector_retriever(persistent_client, embedding_model_name, split_questions_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=split_question_docs, id_key="parent_doc_id")
        
    #     if summary_collection_name not in persistent_client.list_collections():
    #         print("Creating", summary_collection_name)
    #         _ = get_multivector_retriever(persistent_client, embedding_model_name, summary_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=summary_docs, id_key="parent_doc_id")

100%|██████████| 3/3 [00:00<00:00, 85.95it/s]


In [13]:
import os

# List the contents of the storage directory relative to the notebook.
# NOTE(review): `os` is already imported in the setup cell. DATA_PATH there is
# an absolute path, so this relative directory may or may not be the one the
# retrievers write to - confirm.
os.listdir('../../storage')

[]

In [18]:
# Retrievers created by this cell, in creation order; later cells index into it.
x = []

for embedding_model_name in tqdm(EMBEDDING_MODEL_NAMES):

    print(embedding_model_name)

    # '-' and '/' are both mapped to '_' to form the name suffix.
    embedding_slug = embedding_model_name.replace("-", "_").replace("/", "_")

    all_questions_collection_name = "PQ_COMB_ALL_" + embedding_slug
    all_split_questions_collection_name = "PQ_SPLIT_ALL_" + embedding_slug
    all_summary_collection_name = "PS_ALL_" + embedding_slug
    all_collection_name = "PQS_ALL_" + embedding_slug

    # Only the split-question and summary collections are built, in that
    # order. The PQ_COMB_ALL and PQS_ALL variants and the delete-before-rebuild
    # step are disabled; only their names are computed above.
    for collection_name, child_docs in (
        (all_split_questions_collection_name, all_split_questions_docs),
        (all_summary_collection_name, all_summaries_docs),
    ):
        print(collection_name)
        x.append(
            get_multivector_retriever(
                persistent_client,
                embedding_model_name,
                collection_name,
                DATA_PATH,
                parent_docs=parent_docs,
                parent_docs_ids=parent_docs_ids,
                child_docs=child_docs,
                id_key="parent_doc_id",
            )
        )

    # if all_collection_name in COLLECTIONS:
    #     _ = get_multivector_retriever(persistent_client, embedding_model_name, all_collection_name, DATA_PATH, parent_docs=parent_docs, parent_docs_ids=parent_docs_ids, child_docs=all_generated_docs, id_key="parent_doc_id")

  0%|          | 0/1 [00:00<?, ?it/s]

text-embedding-3-small
PQ_SPLIT_ALL_text_embedding_3_small
../storage\docstore\PQ_SPLIT_ALL_text_embedding_3_small
False





ValidationError: 1 validation error for OpenAIEmbeddings
__root__
  Did not find openai_api_key, please add an environment variable `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a named parameter. (type=value_error)

In [61]:
# Smoke test: send a throwaway query to the first retriever in `x` and display what it returns.
x[0].invoke('sdfds')

[Document(page_content='ﺑﻴﺎن أرﺑﺎح اﻟﻨﺘﺎﺋﺞ اﻟﻤﺎﻟﻴﺔ اﻟﺴﻨﻮﻳﺔ ﻟﻠﻌﺎم م2022\n26 ﻓﺒﺮاﻳﺮ م2023\nو ﻋ ﻠﻖ اﻟﻤﻬﻨﺪس ﺧﺎﻟﺪ ﺑﻦ ﻋﺒﺪﷲ اﻟﺤﺼﺎن اﻟﺮﺋﻴﺲ اﻟﺘﻨﻔﻴﺬي ﻟﻤﺠﻤﻮﻋﺔ ﺗﺪاول اﻟﺴﻌﻮدﻳﺔ: "ﺧﻼل اﻟﻌﺎم م2022 ﻋﻤﻠﺖ\nاﻟﻤﺠﻤﻮﻋﺔ ﻋﻠﻰ ﺗﻘﺪﻳﻢ ﻋﺪد ﻣﻦ اﻟﺘﺤﺴﻴﻨﺎت ﻋﻠﻰ اﻟﺒﻨﻴﺔ اﻟﺘﺤﺘﻴﺔ ﻟﻠﺴﻮق اﻟﻤﺎﻟﻴﺔ اﻟﺴﻌﻮدﻳﺔ ﺑﻬﺪف اﺳﺘﻀﺎﻓﺔ ﻣﺠﻤﻮﻋﺔ\nﻣﺘﻨﻮﻋﺔ ﻣﻦ اﻟﻤﺼﺪرﻳﻦ واﻟﻤﺴﺘﺜﻤﺮﻳﻦ . وﻳﻌ ﺪ ذﻟﻚ دﻟﻴ ﻼ ﻋﻠﻰ ﻧﺸﺎط اﻹدراﺟﺎت اﻟﻘﻮي اﻟﺬي ﺳﺎﻫﻢ ﻓﻲ ﺗﺮﺳﻴﺦ ﻣﻜﺎﻧﺔ اﻟﺴﻮق\nاﻟﻤﺎﻟﻴﺔ اﻟﺴﻌﻮدﻳﺔ ﺿﻤﻦ أﺳﻮاق اﻹدراﺟﺎت اﻷﻓﻀﻞ أداء ﻋﻠﻰ ﻣﺴﺘﻮى اﻟﻌﺎﻟﻢ . واﺳﺘﺜﻤﺮﻧﺎ أﻳﻀﺎ ﻓﻲ ﺗﻄﻮﻳﺮ ﺑﻨﻴﺘﻨﺎ اﻟﺘﺤﺘﻴﺔ\nوﺧﺪﻣﺎﺗﻨﺎ ﻓﻲ إﻃﺎر ﺳﻌﻴﻨﺎ اﻟﻤﺴﺘﻤﺮ ﻟﺘﺤﻘﻴﻖ أﻫﺪاﻓﻨﺎ اﻻﺳﺘﺮاﺗﻴﺠﻴﺔ".\nوأﺿﺎف اﻟﺤﺼﺎن: "ﺗﻤﺎﺷﻴﺎ ﻣﻊ ﺟﻬﻮدﻧﺎ اﻟﻤﺒﺬوﻟﺔ ﻟﺘﻌﺰﻳﺰ اﻟﺒﻨﻴﺔ اﻟﺘﺤﺘﻴﺔ ﻟﺨﺪﻣﺎت اﻟﺘﺪاول وﻣﺎ ﺑﻌﺪ اﻟﺘﺪاول وﺗﺸﺠﻴﻊ ﻋﻤﻠﻴﺎت\nاﻹدرا ج ﺷﻬﺪﻧﺎ أول ﻋﻤﻠﻴﺔ إدراج ﻣﺰدوج وﻣﺘﺰاﻣﻦ ﺑﻴﻦ ﺗﺪاول اﻟﺴﻌﻮدﻳﺔ وﺳﻮق أﺑﻮﻇﺒﻲ ﻟﻸوراق اﻟﻤﺎﻟﻴﺔ واﻟﺬي ﻳﻤﺜﻞ ﻣﺮﺣﻠﺔ\nﺟﺪﻳﺪة ﻣﻦ اﻟﺘﻌﺎون ﺑﻴﻦ اﻟﺴﻮق اﻟﻤﺎﻟﻴﺔ اﻟﺴﻌﻮدﻳﺔ واﻷﺳﻮاق اﻟﻤﺎﻟﻴﺔ اﻟﺨﻠﻴﺠﻴﺔ واﻟﺪوﻟﻴﺔ . وﻋﻠﻰ ﻣﺪار اﻟﻌﺎم واﺻﻠﻨﺎ اﻟﺘﻌﺎون\nﻣﻊ اﻷﺳﻮاق اﻟﻤﺎﻟﻴﺔ اﻹﻗﻠﻴﻤﻴﺔ واﻟﺪوﻟﻴﺔ ﻟﺘﻤﻬﻴﺪ اﻟﻄﺮﻳﻖ ﻟﻠﻤﺰﻳﺪ ﻣﻦ ﻋﻤﻠﻴﺎت اﻹدراج اﻟﻤﺰدوج ﻓﻲ اﻟﻤﺴﺘﻘﺒﻞ اﻟﻘﺮﻳ ﺐ".\nواﺧﺘﺘﻢ ﻗ

In [50]:
# Display the current allow-list of collection names.
COLLECTIONS

['PQ_SPLIT_ALL_text_embedding_3_small', 'PS_ALL_text_embedding_3_small']

In [54]:
# Display the retrievers collected so far.
x

[MultiVectorRetriever(vectorstore=<langchain_community.vectorstores.chroma.Chroma object at 0x000002A7288030E0>, byte_store=<langchain.storage.file_system.LocalFileStore object at 0x000002A722566690>, docstore=<langchain.storage.encoder_backed.EncoderBackedStore object at 0x000002A7282DD610>, id_key='parent_doc_id')]

In [41]:
# Query the second retriever in `x` with a short Arabic prompt.
# NOTE(review): this raises IndexError whenever `x` holds fewer than two
# retrievers, e.g. if the cell that fills `x` stopped after the first one.
x[1].invoke("ما هو")

IndexError: list index out of range

In [36]:
 _ids = [d.metadata['parent_doc_id'] for d in all_questions_docs]
print(len(set(parent_docs_ids).intersection(_ids)))


48


In [37]:
 _ids = [d.metadata['parent_doc_id'] for d in all_summaries_docs]
print(len(set(parent_docs_ids).intersection(_ids)))


48


In [None]:
# parent_docs_ids