In [20]:
# Enable IPython's autoreload extension in mode 2 so that imported modules
# (e.g. the project's `src` package) are re-imported before each cell executes.
%load_ext autoreload
%autoreload 2

In [21]:
import sys

sys.path.append('../../')

In [22]:
from src.indexing import get_multivector_retriever, get_parent_child_splits
from src.generation import QA_SYSTEM_PROMPT, QA_PROMPT, LLAMA_PROMPT_TEMPLATE, MIXTRAL_PROMPT_TEMPLATE
from src.generation import get_model, format_docs, get_rag_chain
from langchain_core.documents import Document

from src.ingestion import load_pdf

import os
import chromadb
import uuid
import pickle

In [23]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

from tqdm import tqdm

In [24]:
# Root data directory. A raw string is used because "\A", "\s" and "\d" are
# invalid escape sequences in a normal literal: Python keeps the backslashes
# but emits a SyntaxWarning for them. The resulting value is unchanged.
# NOTE(review): this is a machine-specific absolute path; adjust per machine.
DATA_PATH = r'D:\Ahmed\saudi-rag-project\data'
RAW_DOCS_PATH = os.path.join(DATA_PATH, "raw")          # source PDFs
CHROMA_PATH = os.path.join(DATA_PATH, "chroma")         # persistent Chroma storage
INTERIM_DATA_PATH = os.path.join(DATA_PATH, "interim")  # pickled IDs / generated docs

# Embedding models; the cells below build one set of collections per entry.
EMBEDDING_MODEL_NAMES = [
    "intfloat/multilingual-e5-small",
    "intfloat/multilingual-e5-base",
    "text-embedding-3-small",
    "text-embedding-3-large",
    "text-embedding-ada-002",
]

# Generation models passed to `get_model` to produce questions and summaries.
MODEL_NAMES = [
    "meta-llama/Llama-3-8b-chat-hf",
    "meta-llama/Llama-3-70b-chat-hf",
    "mistralai/Mixtral-8x22B-Instruct-v0.1",
]

  DATA_PATH = 'D:\Ahmed\saudi-rag-project\data'


In [25]:
# Load every PDF found in the raw-data directory.
# The listing is sorted because os.listdir() returns entries in arbitrary
# order, while the parent chunks derived from `docs` are matched to the
# persisted parent IDs purely by position -- so the load order has to be the
# same on every run.
# NOTE(review): if the persisted IDs were created from an unsorted listing,
# confirm that the sorted order matches the order used back then.
docs = [
    load_pdf(os.path.join(RAW_DOCS_PATH, f))
    for f in sorted(os.listdir(RAW_DOCS_PATH))
    if ".pdf" in f
]

In [26]:
# print(docs[1].page_content)

In [27]:
# Chroma client backed by on-disk storage under CHROMA_PATH; it is passed to
# every `get_multivector_retriever` call in this notebook.
persistent_client = chromadb.PersistentClient(path=CHROMA_PATH)

In [28]:
# Split the loaded documents into overlapping "parent" chunks.
parent_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1200,
    chunk_overlap=400,
    # NOTE(review): r'\.\s+' looks like a regex, but `is_separator_regex=True`
    # is not passed, so it is presumably matched as literal text. Left as-is on
    # purpose: changing it would alter the chunks and invalidate the persisted
    # parent IDs -- confirm before touching.
    separators=['\n\n\n', '\n\n', '\n', r'\.\s+', ' ', '']
)

parent_docs = parent_splitter.split_documents(docs)

# One stable ID per parent chunk, persisted so that collections and generated
# documents created in different sessions keep referring to the same parents.
# The IDs are created on the first run and re-loaded afterwards.
# (The pickle is produced by this notebook itself; never load untrusted pickles.)
parent_docs_ids_path = os.path.join(INTERIM_DATA_PATH, "parent_docs_ids")
if os.path.exists(parent_docs_ids_path):
    with open(parent_docs_ids_path, 'rb') as ids_file:
        parent_docs_ids = pickle.load(ids_file)
else:
    parent_docs_ids = [str(uuid.uuid4()) for _ in parent_docs]
    with open(parent_docs_ids_path, 'wb') as ids_file:
        pickle.dump(parent_docs_ids, ids_file)

# IDs and chunks are paired by position everywhere below, so a length mismatch
# would silently attach children to the wrong parent.
if len(parent_docs_ids) != len(parent_docs):
    raise ValueError(
        f"{len(parent_docs_ids)} stored parent IDs do not match "
        f"{len(parent_docs)} parent chunks; regenerate {parent_docs_ids_path}"
    )

# Metadata key under which each child document stores its parent's ID.
id_key = "parent_doc_id"

In [12]:
# parent_docs_ids

#### Create regular child splitters for each embedding model, with different chunk sizes

In [18]:
# Display how many parent chunks the parent splitter produced.
len(parent_docs)

48

In [41]:
# Build one parent/child collection per (embedding model, child chunk size).
for embedding_model_name in tqdm(EMBEDDING_MODEL_NAMES):
    # Collection names use the model name without its provider prefix and with
    # dashes replaced. Computed outside the f-string: reusing the enclosing
    # quote character inside an f-string is a SyntaxError before Python 3.12.
    embedding_tag = embedding_model_name.split('/')[-1].replace('-', '_')

    for child_chunk_size in [500, 300, 100]:
        collection_name = f"PC_{child_chunk_size}_{embedding_tag}"

        child_splitter = RecursiveCharacterTextSplitter(
            chunk_size=child_chunk_size,
            chunk_overlap=0,
        )

        # Split each parent chunk separately so every child can be tagged with
        # the ID of the parent it came from (paired by position).
        child_docs = []
        for i, doc in enumerate(parent_docs):
            _id = parent_docs_ids[i]
            _child_docs = child_splitter.split_documents([doc])
            for _doc in _child_docs:
                _doc.metadata[id_key] = _id
            child_docs.extend(_child_docs)

        _ = get_multivector_retriever(
            persistent_client,
            embedding_model_name,
            collection_name,
            DATA_PATH,
            parent_docs=parent_docs,
            parent_docs_ids=parent_docs_ids,
            child_docs=child_docs,
            id_key=id_key,  # same key that was written into the child metadata
        )

100%|██████████| 5/5 [02:24<00:00, 28.81s/it]


### Create retrievers that combine several child chunk sizes

In [46]:
# Build collections whose children mix several chunk sizes: for every
# embedding model and every size combination, the children of all listed sizes
# are stored together in a single collection.
for embedding_model_name in tqdm(EMBEDDING_MODEL_NAMES):
    # Computed outside the f-string: reusing the enclosing quote character
    # inside an f-string is a SyntaxError before Python 3.12.
    embedding_tag = embedding_model_name.split('/')[-1].replace('-', '_')

    for child_chunk_sizes in [(500, 100), (500, 300), (300, 100), (500, 300, 100)]:
        sizes_tag = '_'.join(str(x) for x in child_chunk_sizes)
        collection_name = f"PC_{sizes_tag}_{embedding_tag}"

        all_child_docs = []
        for child_chunk_size in child_chunk_sizes:
            child_splitter = RecursiveCharacterTextSplitter(
                chunk_size=child_chunk_size,
                chunk_overlap=0,
            )

            # Split each parent chunk separately so every child can be tagged
            # with the ID of the parent it came from (paired by position).
            child_docs = []
            for i, doc in enumerate(parent_docs):
                _id = parent_docs_ids[i]
                _child_docs = child_splitter.split_documents([doc])
                for _doc in _child_docs:
                    _doc.metadata[id_key] = _id
                child_docs.extend(_child_docs)

            all_child_docs.extend(child_docs)

        _ = get_multivector_retriever(
            persistent_client,
            embedding_model_name,
            collection_name,
            DATA_PATH,
            parent_docs=parent_docs,
            parent_docs_ids=parent_docs_ids,
            child_docs=all_child_docs,
            id_key=id_key,  # same key that was written into the child metadata
        )

100%|██████████| 5/5 [13:24<00:00, 160.99s/it]


#### Generate questions for each document.

In [30]:
# System prompt asking the model for Arabic, fact-based questions about a
# document. The sentences are joined with single spaces into one line.
# NOTE(review): the wording (including "should be require" and "indepdent") is
# kept exactly as it was, since editing the prompt changes what the model sees.
QUESTIONS_SYSTEM_PROMPT = " ".join((
    "Write a list of 30 fact based simple questions in Arabic that can be answered using the document.",
    "The questions should be require specific numbers or information mentioned in the document for an answer.",
    "But the questions should include information present in the document itself.",
    "Each question should be understandable indepdent from any context.",
    "You SHOULD NOT use any numbers in the question (except for years).",
    "Write the questions in Arabic only and don't write the answers.",
    "Output the questions directly without an introduction.",
))

# System prompt asking the model for an Arabic title and summary of the context.
SUMMARY_SYSTEM_PROMPT = " ".join((
    "You are an expert summary writer.",
    "Write a title and summarize the provided context in an entity dense way.",
    "The summary and title should be in Arabic only.",
))

# User-turn template; `{context}` is the placeholder filled in per document.
USER_MESSAGE = "Context: {context}"

In [31]:
from langchain_core.prompts import PromptTemplate

### Generate questions from all models

In [32]:
import langid

def is_english(text):
    """Return True when langid's top-ranked language for `text` is English ('en')."""
    detected_language = langid.classify(text)[0]
    return detected_language == 'en'

In [13]:
# Display which generation models the slice MODEL_NAMES[2:] selects.
MODEL_NAMES[2:]

['mistralai/Mixtral-8x22B-Instruct-v0.1']

In [38]:
# Generate questions for every parent chunk and persist them in two forms:
#   PQ_COMB_<model>  - one document per parent holding the full generated text
#   PQ_SPLIT_<model> - one document per non-English line of that text
# NOTE: the slice below restricts this cell to a single entry of MODEL_NAMES;
# widen it to (re)generate questions for the other models.
for model_name in tqdm(MODEL_NAMES[1:2]):

    # Pick the chat template and model settings for the model family. An
    # unknown family is an error; otherwise the template/LLM of a previous
    # iteration (or nothing at all) would be used silently.
    if "llama" in model_name:
        chat_template = LLAMA_PROMPT_TEMPLATE
        llm = get_model(model_name)
    elif "mistral" in model_name:
        chat_template = MIXTRAL_PROMPT_TEMPLATE
        llm = get_model(model_name, max_tokens=2048)
    else:
        raise ValueError(f"No prompt template configured for model {model_name!r}")

    prompt_template = PromptTemplate.from_template(
        chat_template.format(system_prompt=QUESTIONS_SYSTEM_PROMPT, user_message=USER_MESSAGE)
    )

    question_chain = (
        prompt_template
        | llm
        | StrOutputParser()
    )

    # NOTE(review): batch() receives Document objects rather than dicts;
    # presumably the template's single `context` variable is filled with each
    # Document's string form -- confirm against the LangChain version in use.
    _question_docs = question_chain.batch(parent_docs)

    question_docs = []
    split_question_docs = []
    for i, generated_text in enumerate(_question_docs):
        # Generated texts are paired with their parent chunk by position.
        _id = parent_docs_ids[i]
        source = parent_docs[i].metadata['source']

        question_docs.append(
            Document(page_content=generated_text, metadata={id_key: _id, 'source': source})
        )

        for line in generated_text.split('\n'):
            # Blank lines carry no question; skip them instead of storing
            # empty documents.
            if not line.strip():
                continue
            # Keep only lines that are not classified as English.
            if not is_english(line):
                split_question_docs.append(
                    Document(page_content=line, metadata={id_key: _id, 'source': source})
                )

    # File names use the model name without provider prefix, dashes replaced.
    model_tag = model_name.split('/')[-1].replace('-', '_')

    question_file = f"PQ_COMB_{model_tag}"
    question_file_path = os.path.join(INTERIM_DATA_PATH, question_file)

    split_question_file = f"PQ_SPLIT_{model_tag}"
    split_question_file_path = os.path.join(INTERIM_DATA_PATH, split_question_file)

    # Context managers make sure the pickles are flushed and closed.
    with open(question_file_path, 'wb') as out_file:
        pickle.dump(question_docs, out_file)
    with open(split_question_file_path, 'wb') as out_file:
        pickle.dump(split_question_docs, out_file)

100%|██████████| 1/1 [08:03<00:00, 483.14s/it]


#### Generate summaries from all models

In [42]:
# Generate a title + summary for every parent chunk and persist the result as
# PS_<model> (one document per parent chunk).
# NOTE: the slice below restricts this cell to the tail of MODEL_NAMES; widen
# it to (re)generate summaries for the other models.
for model_name in tqdm(MODEL_NAMES[2:]):

    # Pick the chat template for the model family. An unknown family is an
    # error; otherwise the template of a previous iteration (or nothing at
    # all) would be used silently.
    if "llama" in model_name:
        chat_template = LLAMA_PROMPT_TEMPLATE
    elif "mistral" in model_name:
        chat_template = MIXTRAL_PROMPT_TEMPLATE
    else:
        raise ValueError(f"No prompt template configured for model {model_name!r}")

    prompt_template = PromptTemplate.from_template(
        chat_template.format(system_prompt=SUMMARY_SYSTEM_PROMPT, user_message=USER_MESSAGE)
    )

    llm = get_model(model_name)

    summary_chain = (
        prompt_template
        | llm
        | StrOutputParser()
    )

    _summary_docs = summary_chain.batch(parent_docs)

    # Generated summaries are paired with their parent chunk by position.
    summary_docs = [
        Document(
            page_content=summary_text,
            metadata={id_key: parent_docs_ids[i], 'source': parent_docs[i].metadata['source']},
        )
        for i, summary_text in enumerate(_summary_docs)
    ]

    # File name uses the model name without provider prefix, dashes replaced.
    model_tag = model_name.split('/')[-1].replace('-', '_')
    summary_file = f"PS_{model_tag}"
    summary_file_path = os.path.join(INTERIM_DATA_PATH, summary_file)

    # Context manager makes sure the pickle is flushed and closed.
    with open(summary_file_path, 'wb') as out_file:
        pickle.dump(summary_docs, out_file)

100%|██████████| 1/1 [11:28<00:00, 688.96s/it]


### Create vector stores

In [44]:
# Load the generated questions / split questions / summaries of every model,
# accumulate them for the combined collections, and create one collection per
# (document type, generation model, embedding model) that does not exist yet.
all_questions_docs = []
all_split_questions_docs = []
all_summaries_docs = []

for model_name in tqdm(MODEL_NAMES):

    # Interim files are keyed by the model name without its provider prefix.
    file_tag = model_name.split('/')[-1].replace('-', '_')

    question_file = f"PQ_COMB_{file_tag}"
    question_file_path = os.path.join(INTERIM_DATA_PATH, question_file)

    split_question_file = f"PQ_SPLIT_{file_tag}"
    split_question_file_path = os.path.join(INTERIM_DATA_PATH, split_question_file)

    summary_file = f"PS_{file_tag}"
    summary_file_path = os.path.join(INTERIM_DATA_PATH, summary_file)

    # These pickles are written by this notebook; never load untrusted pickles.
    with open(question_file_path, 'rb') as in_file:
        question_docs = pickle.load(in_file)
    with open(split_question_file_path, 'rb') as in_file:
        split_question_docs = pickle.load(in_file)
    with open(summary_file_path, 'rb') as in_file:
        summary_docs = pickle.load(in_file)

    all_questions_docs.extend(question_docs)
    all_split_questions_docs.extend(split_question_docs)
    all_summaries_docs.extend(summary_docs)

    # Tag used inside collection names. It is kept in its own variable so the
    # loop variable `model_name` is never overwritten. Mistral-family names are
    # cut to their first 13 characters (presumably to keep collection names
    # within Chroma's length limit -- confirm).
    if "mistral" in model_name:
        model_tag = file_tag[:13]
    else:
        model_tag = file_tag

    for embedding_model_name in EMBEDDING_MODEL_NAMES:

        # Computed outside the f-strings: reusing the enclosing quote character
        # inside an f-string is a SyntaxError before Python 3.12.
        embedding_tag = embedding_model_name.replace('-', '_').replace('/', '_')

        questions_collection_name = f"PQ_COMB_{model_tag}_{embedding_tag}"
        split_questions_collection_name = f"PQ_SPLIT_{model_tag}_{embedding_tag}"
        summary_collection_name = f"PS_{model_tag}_{embedding_tag}"

        for collection_name, generated_docs in (
            (questions_collection_name, question_docs),
            (split_questions_collection_name, split_question_docs),
            (summary_collection_name, summary_docs),
        ):
            # list_collections() yields Collection objects or plain names
            # depending on the Chroma version; normalise to names so the
            # membership test compares like with like.
            existing_names = {
                getattr(collection, "name", collection)
                for collection in persistent_client.list_collections()
            }
            if collection_name not in existing_names:
                print("Creating", collection_name)
                _ = get_multivector_retriever(
                    persistent_client,
                    embedding_model_name,
                    collection_name,
                    DATA_PATH,
                    parent_docs=parent_docs,
                    parent_docs_ids=parent_docs_ids,
                    child_docs=generated_docs,
                    id_key="parent_doc_id",
                )

  0%|          | 0/3 [00:00<?, ?it/s]

Creating PQ_COMB_Llama_3_8b_chat_hf_intfloat_multilingual_e5_small
Creating PQ_SPLIT_Llama_3_8b_chat_hf_intfloat_multilingual_e5_small
Creating PS_Llama_3_8b_chat_hf_intfloat_multilingual_e5_small
Creating PQ_COMB_Llama_3_8b_chat_hf_intfloat_multilingual_e5_base
Creating PQ_SPLIT_Llama_3_8b_chat_hf_intfloat_multilingual_e5_base
Creating PS_Llama_3_8b_chat_hf_intfloat_multilingual_e5_base
Creating PQ_COMB_Llama_3_8b_chat_hf_text_embedding_3_small


  warn_deprecated(


Creating PQ_SPLIT_Llama_3_8b_chat_hf_text_embedding_3_small
Creating PS_Llama_3_8b_chat_hf_text_embedding_3_small
Creating PQ_COMB_Llama_3_8b_chat_hf_text_embedding_3_large
Creating PQ_SPLIT_Llama_3_8b_chat_hf_text_embedding_3_large
Creating PS_Llama_3_8b_chat_hf_text_embedding_3_large
Creating PQ_COMB_Llama_3_8b_chat_hf_text_embedding_ada_002
Creating PQ_SPLIT_Llama_3_8b_chat_hf_text_embedding_ada_002
Creating PS_Llama_3_8b_chat_hf_text_embedding_ada_002


 33%|███▎      | 1/3 [16:22<32:44, 982.13s/it]

Creating PQ_COMB_Llama_3_70b_chat_hf_intfloat_multilingual_e5_small
Creating PQ_SPLIT_Llama_3_70b_chat_hf_intfloat_multilingual_e5_small
Creating PS_Llama_3_70b_chat_hf_intfloat_multilingual_e5_small
Creating PQ_COMB_Llama_3_70b_chat_hf_intfloat_multilingual_e5_base
Creating PQ_SPLIT_Llama_3_70b_chat_hf_intfloat_multilingual_e5_base
Creating PS_Llama_3_70b_chat_hf_intfloat_multilingual_e5_base
Creating PQ_COMB_Llama_3_70b_chat_hf_text_embedding_3_small
Creating PQ_SPLIT_Llama_3_70b_chat_hf_text_embedding_3_small
Creating PS_Llama_3_70b_chat_hf_text_embedding_3_small
Creating PQ_COMB_Llama_3_70b_chat_hf_text_embedding_3_large
Creating PQ_SPLIT_Llama_3_70b_chat_hf_text_embedding_3_large
Creating PS_Llama_3_70b_chat_hf_text_embedding_3_large
Creating PQ_COMB_Llama_3_70b_chat_hf_text_embedding_ada_002
Creating PQ_SPLIT_Llama_3_70b_chat_hf_text_embedding_ada_002
Creating PS_Llama_3_70b_chat_hf_text_embedding_ada_002


 67%|██████▋   | 2/3 [24:08<11:18, 678.65s/it]

Creating PQ_COMB_Mixtral_8x22B_intfloat_multilingual_e5_small
Creating PQ_SPLIT_Mixtral_8x22B_intfloat_multilingual_e5_small
Creating PS_Mixtral_8x22B_intfloat_multilingual_e5_small
Creating PQ_COMB_Mixtral_8x22B_intfloat_multilingual_e5_base
Creating PQ_SPLIT_Mixtral_8x22B_intfloat_multilingual_e5_base
Creating PS_Mixtral_8x22B_intfloat_multilingual_e5_base
Creating PQ_COMB_Mixtral_8x22B_text_embedding_3_small
Creating PQ_SPLIT_Mixtral_8x22B_text_embedding_3_small
Creating PS_Mixtral_8x22B_text_embedding_3_small
Creating PQ_COMB_Mixtral_8x22B_text_embedding_3_large
Creating PQ_SPLIT_Mixtral_8x22B_text_embedding_3_large
Creating PS_Mixtral_8x22B_text_embedding_3_large
Creating PQ_COMB_Mixtral_8x22B_text_embedding_ada_002
Creating PQ_SPLIT_Mixtral_8x22B_text_embedding_ada_002
Creating PS_Mixtral_8x22B_text_embedding_ada_002


100%|██████████| 3/3 [24:11<00:00, 483.75s/it]


#### Create combined vector stores with the questions and summaries from all models

In [55]:
# For every embedding model, (re)build four collections that pool the
# generated documents of all generation models: combined questions, split
# questions, summaries, and all three together.

# Independent of the embedding model, so built once up front.
all_generated_docs = all_questions_docs + all_split_questions_docs + all_summaries_docs

for embedding_model_name in tqdm(EMBEDDING_MODEL_NAMES):

    # Computed outside the f-strings: reusing the enclosing quote character
    # inside an f-string is a SyntaxError before Python 3.12.
    embedding_tag = embedding_model_name.replace('-', '_').replace('/', '_')

    all_questions_collection_name = f"PQ_COMB_ALL_{embedding_tag}"
    all_split_questions_collection_name = f"PQ_SPLIT_ALL_{embedding_tag}"
    all_summary_collection_name = f"PS_ALL_{embedding_tag}"
    all_collection_name = f"PQS_ALL_{embedding_tag}"

    combined_collections = (
        (all_questions_collection_name, all_questions_docs),
        (all_split_questions_collection_name, all_split_questions_docs),
        (all_summary_collection_name, all_summaries_docs),
        (all_collection_name, all_generated_docs),
    )

    # Drop stale collections first so a re-run starts from empty collections.
    # list_collections() yields Collection objects or plain names depending on
    # the Chroma version; normalise to names so the membership test can match.
    existing_names = {
        getattr(collection, "name", collection)
        for collection in persistent_client.list_collections()
    }
    for collection_name, _unused_docs in combined_collections:
        if collection_name in existing_names:
            persistent_client.delete_collection(collection_name)

    for collection_name, generated_docs in combined_collections:
        _ = get_multivector_retriever(
            persistent_client,
            embedding_model_name,
            collection_name,
            DATA_PATH,
            parent_docs=parent_docs,
            parent_docs_ids=parent_docs_ids,
            child_docs=generated_docs,
            id_key="parent_doc_id",
        )

 20%|██        | 1/5 [17:03<1:08:15, 1023.85s/it]

In [43]:
# Sanity check: for every generation model, print how many distinct parent IDs
# are referenced by its question, split-question and summary documents.
for model_name in MODEL_NAMES:
    print(model_name)

    # Interim files are keyed by the model name without its provider prefix.
    model_tag = model_name.split('/')[-1].replace('-', '_')

    question_file = f"PQ_COMB_{model_tag}"
    question_file_path = os.path.join(INTERIM_DATA_PATH, question_file)

    split_question_file = f"PQ_SPLIT_{model_tag}"
    split_question_file_path = os.path.join(INTERIM_DATA_PATH, split_question_file)

    summary_file = f"PS_{model_tag}"
    summary_file_path = os.path.join(INTERIM_DATA_PATH, summary_file)

    # Context managers close the files again (these pickles are written by
    # this notebook; never load untrusted pickles).
    with open(question_file_path, 'rb') as in_file:
        question_docs = pickle.load(in_file)
    with open(split_question_file_path, 'rb') as in_file:
        split_question_docs = pickle.load(in_file)
    with open(summary_file_path, 'rb') as in_file:
        summary_docs = pickle.load(in_file)

    # One count per document type, in the order: questions, split questions,
    # summaries.
    for generated_docs in (question_docs, split_question_docs, summary_docs):
        _ids = [d.metadata['parent_doc_id'] for d in generated_docs]
        print(len(set(parent_docs_ids).intersection(_ids)))

meta-llama/Llama-3-8b-chat-hf
48
48
48
meta-llama/Llama-3-70b-chat-hf
48
48
48
mistralai/Mixtral-8x22B-Instruct-v0.1
48
48
48


In [47]:
 _ids = [d.metadata['parent_doc_id'] for d in all_questions_docs]
print(len(set(parent_docs_ids).intersection(_ids)))


48


In [48]:
 _ids = [d.metadata['parent_doc_id'] for d in all_split_questions_docs]
print(len(set(parent_docs_ids).intersection(_ids)))


48


In [49]:
 _ids = [d.metadata['parent_doc_id'] for d in all_summaries_docs]
print(len(set(parent_docs_ids).intersection(_ids)))


48
