In [1]:
# Reload imported modules (e.g. the local `src.*` package) before each cell runs,
# so edits to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the project root (two directories up from this notebook) importable so that
# `src.*` resolves. Guarded so re-running the cell does not append duplicate entries.
PROJECT_ROOT = '../../'
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
from src.indexing import get_multivector_retriever, get_parent_child_splits
from src.generation import QA_SYSTEM_PROMPT, QA_PROMPT, LLAMA_PROMPT_TEMPLATE, MIXTRAL_PROMPT_TEMPLATE
from src.generation import get_model, format_docs, get_rag_chain
from langchain_core.documents import Document

from src.ingestion import load_pdf

import os
import chromadb
import uuid
import pickle

In [7]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

from tqdm import tqdm

import pandas as pd

In [5]:
# Root of the project's data directory. Set the SAUDI_RAG_DATA_PATH environment
# variable to run on another machine; the default is the original author's path.
# A raw string keeps the Windows backslashes from being read as escape sequences.
DATA_PATH = os.environ.get('SAUDI_RAG_DATA_PATH', r'D:\Ahmed\saudi-rag-project\data')
RAW_DOCS_PATH = os.path.join(DATA_PATH, "raw")
CHROMA_PATH = os.path.join(DATA_PATH, "chroma")
INTERIM_DATA_PATH = os.path.join(DATA_PATH, "interim")

# Candidate embedding model names.
EMBEDDING_MODEL_NAMES = [
    "intfloat/multilingual-e5-small",
    "intfloat/multilingual-e5-base",
    "text-embedding-3-small",
    "text-embedding-3-large",
    "text-embedding-ada-002",
]

# Chat models evaluated in this notebook (indexed as MODEL_NAMES[i] further down).
MODEL_NAMES = [
    "meta-llama/Llama-3-8b-chat-hf",
    "meta-llama/Llama-3-70b-chat-hf",
    "mistralai/Mixtral-8x22B-Instruct-v0.1",
]


In [25]:
# Persistent Chroma client rooted at CHROMA_PATH; passed to get_multivector_retriever below
# to open each configuration's collection.
persistent_client = chromadb.PersistentClient(path=CHROMA_PATH)

In [8]:
# Per-configuration retrieval metrics. Columns used below: collection_name, k,
# recall, precision, average_precision.
# NOTE(review): relative path, resolved against the notebook's working directory;
# presumably written by an earlier retrieval-evaluation step — confirm.
retrievers_results = pd.read_csv('retrieval_results.csv')

#### Select the top retriever configurations (top 5 by recall, precision, and average precision)

In [16]:
# Label each row as "<collection_name>-K_<k>", e.g. "PQS_ALL_text_embedding_3_small-K_5".
# Vectorized concatenation instead of a row-wise apply(axis=1): same labels, faster,
# and it also works on an empty results frame.
retrievers_results['configuration'] = (
    retrievers_results['collection_name']
    + '-K_'
    + retrievers_results['k'].astype(str)
)

In [17]:
# Union of the 5 best configurations under each retrieval metric (duplicates collapse in the set).
top_collection_names = {
    configuration
    for metric in ('recall', 'precision', 'average_precision')
    for configuration in retrievers_results.sort_values(metric, ascending=False).head(5)['configuration']
}

In [18]:
# Inspect the selected configurations.
top_collection_names

{'PQS_ALL_text_embedding_3_small-K_11',
 'PQS_ALL_text_embedding_3_small-K_13',
 'PQS_ALL_text_embedding_3_small-K_3',
 'PQS_ALL_text_embedding_3_small-K_5',
 'PQS_ALL_text_embedding_3_small-K_9',
 'PQ_COMB_Llama_3_70b_chat_hf_intfloat_multilingual_e5_base-K_17',
 'PQ_COMB_Llama_3_70b_chat_hf_intfloat_multilingual_e5_base-K_19',
 'PQ_COMB_Llama_3_70b_chat_hf_intfloat_multilingual_e5_base-K_3',
 'PQ_COMB_Llama_3_70b_chat_hf_text_embedding_3_large-K_13',
 'PQ_COMB_Llama_3_70b_chat_hf_text_embedding_3_large-K_15',
 'PQ_SPLIT_ALL_text_embedding_3_small-K_11',
 'PQ_SPLIT_ALL_text_embedding_3_small-K_3',
 'PQ_SPLIT_ALL_text_embedding_3_small-K_5'}

#### Create a list of configurations

In [57]:
# Tag found in a (sanitized) collection name -> embedding model that built that collection.
# The tags do not overlap, so at most one matches a given name.
EMBEDDING_MODEL_BY_TAG = {
    'text_embedding_3_large': 'text-embedding-3-large',
    'text_embedding_3_small': 'text-embedding-3-small',
    'text_embedding_ada_002': 'text-embedding-ada-002',
    'multilingual_e5_small': 'intfloat/multilingual-e5-small',
    'multilingual_e5_base': 'intfloat/multilingual-e5-base',
}

configs = []

# Iterate in sorted order: `top_collection_names` is a set of strings, whose iteration
# order can change between kernel sessions, and `configs[0]` is used further down.
for collection in sorted(top_collection_names):
    # Configuration labels look like "<collection_name>-K_<k>".
    collection_name, k_label = collection.rsplit('-', 1)

    embedding_model_name = next(
        (model for tag, model in EMBEDDING_MODEL_BY_TAG.items() if tag in collection),
        None,
    )
    if embedding_model_name is None:
        # Fail here with a clear message instead of a KeyError when the retriever is built.
        raise ValueError(f"Cannot infer embedding model from configuration {collection!r}")

    configs.append({
        'config_name': collection,
        'collection_name': collection_name,
        # Parsed as int so the retriever receives a numeric top-k rather than a string.
        'k': int(k_label.split('_')[-1]),
        'embedding_model_name': embedding_model_name,
    })

#### Load benchmark

In [22]:
# Question/answer benchmark; its `question` and `answer` columns are used by evaluate_model_config.
# NOTE(review): resolved relative to the notebook's working directory rather than DATA_PATH —
# confirm both point at the same data folder.
benchmark = pd.read_csv("../../data/benchmark.csv")

In [42]:
import re

def extract_numbers(text):
    """Extract all non-negative integers and decimals appearing in ``text``.

    Commas used as thousands separators (``1,234``) are removed first, so the
    number is read as a whole. Other commas (e.g. ``1,2,3``) are kept and act as
    separators between numbers instead of fusing them into one value.

    Args:
        text: The string to scan. Non-string values (e.g. ``NaN`` from an empty
            CSV cell, or ``None``) yield an empty list instead of raising.

    Returns:
        A list of numbers in order of appearance: ``float`` when the match
        contains a decimal point, ``int`` otherwise. A leading minus sign is
        not captured.
    """
    if not isinstance(text, str):
        return []

    # Remove a comma only when it sits between a digit and a group of exactly three digits.
    text = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', text)

    # Integers or decimals delimited by word boundaries.
    numbers = re.findall(r'\b\d+\.?\d*\b', text)

    return [float(num) if '.' in num else int(num) for num in numbers]

#### Evaluate a model on one retriever configuration

In [55]:
def evaluate_model_config(qa_chain, benchmark, config_name, model_name):
    """Run every benchmark question through ``qa_chain`` and score the answers.

    A generated answer is marked correct when it shares at least one extracted
    number (see ``extract_numbers``) with the reference answer.

    Args:
        qa_chain: Object exposing ``batch(list_of_questions)`` that returns one
            generated answer per question.
        benchmark: DataFrame with ``question`` and ``answer`` columns. It is not
            modified; a copy is returned.
        config_name: Value written to the ``config`` column.
        model_name: Value written to the ``model`` column.

    Returns:
        A copy of ``benchmark`` with ``generated_answer``, ``correct`` (1/0),
        ``config`` and ``model`` columns added.
    """
    generated_answers = qa_chain.batch(benchmark.question.tolist())

    # 1 when the reference and generated answers share any number, else 0.
    hits = [
        1 if set(extract_numbers(reference)).intersection(extract_numbers(generated)) else 0
        for reference, generated in zip(benchmark.answer.tolist(), generated_answers)
    ]

    scored = benchmark.copy()
    scored['generated_answer'] = generated_answers
    scored['correct'] = hits
    scored['config'] = config_name
    scored['model'] = model_name
    return scored

In [24]:
# LLM for the single-configuration trial run below (first entry of MODEL_NAMES).
llm = get_model(MODEL_NAMES[0])

In [26]:
# Pick one configuration for a trial run of the RAG chain; which one is "first"
# depends on the order in which `configs` was built.
config_dict = configs[0]

In [28]:
# Build the retriever for this configuration. `k` is cast to int because it is parsed
# out of the configuration label and may otherwise arrive as a string.
retriever = get_multivector_retriever(
    persistent_client,
    config_dict['embedding_model_name'],
    config_dict['collection_name'],
    DATA_PATH,
    k=int(config_dict['k']),
)

In [32]:
from langchain_core.prompts import PromptTemplate

# Fill the Llama chat template with the QA system prompt and user message, then wrap it
# as a LangChain prompt.
qa_prompt_template = PromptTemplate.from_template(
    LLAMA_PROMPT_TEMPLATE.format(
        system_prompt=QA_SYSTEM_PROMPT,
        user_message=QA_PROMPT,
    )
)

In [33]:
# Combine the LLM, retriever, document formatter and prompt into a QA chain
# (composition is defined in src.generation.get_rag_chain).
qa_chain = get_rag_chain(llm, retriever, format_docs, qa_prompt_template)

In [43]:
# %%time

# Trial run of the single configuration above (disabled; the recorded wall time was ~1 min 22 s).
# evaluate_model_config takes four arguments: chain, benchmark, config name, model name.
# model_benchmark = evaluate_model_config(qa_chain, benchmark, config_dict['config_name'], MODEL_NAMES[0])

CPU times: total: 33.8 s
Wall time: 1min 22s


In [45]:
# Number of benchmark questions answered correctly in the trial run (recorded output: 35).
# model_benchmark.correct.sum()

35

#### Evaluate every selected configuration for each model

In [52]:
from tqdm import tqdm
from collections import defaultdict

In [53]:
# Model name -> list of scored benchmark DataFrames, one per retriever configuration.
model_benchmarks = defaultdict(list)

In [58]:
model_name = MODEL_NAMES[0]
llm = get_model(model_name)

# Select the chat prompt format matching the model family.
prompt_format = MIXTRAL_PROMPT_TEMPLATE if 'mistral' in model_name else LLAMA_PROMPT_TEMPLATE
qa_prompt_template = PromptTemplate.from_template(
    prompt_format.format(system_prompt=QA_SYSTEM_PROMPT, user_message=QA_PROMPT)
)

# Reset this model's results so re-running the cell does not append duplicates.
# Results are still appended one at a time, so an interrupted run keeps the
# configurations it already finished.
model_benchmarks[model_name] = []

for config_dict in tqdm(configs, desc=model_name):
    retriever = get_multivector_retriever(
        persistent_client,
        config_dict['embedding_model_name'],
        config_dict['collection_name'],
        DATA_PATH,
        k=int(config_dict['k']),  # k may have been parsed as a string
    )
    qa_chain = get_rag_chain(llm, retriever, format_docs, qa_prompt_template)
    model_benchmark = evaluate_model_config(qa_chain, benchmark, config_dict['config_name'], model_name)
    model_benchmarks[model_name].append(model_benchmark)

 23%|██▎       | 3/13 [04:14<13:59, 83.94s/it]

In [None]:
model_name = MODEL_NAMES[1]
llm = get_model(model_name)

# Select the chat prompt format matching the model family.
prompt_format = MIXTRAL_PROMPT_TEMPLATE if 'mistral' in model_name else LLAMA_PROMPT_TEMPLATE
qa_prompt_template = PromptTemplate.from_template(
    prompt_format.format(system_prompt=QA_SYSTEM_PROMPT, user_message=QA_PROMPT)
)

# Reset this model's results so re-running the cell does not append duplicates.
# Results are still appended one at a time, so an interrupted run keeps the
# configurations it already finished.
model_benchmarks[model_name] = []

for config_dict in tqdm(configs, desc=model_name):
    retriever = get_multivector_retriever(
        persistent_client,
        config_dict['embedding_model_name'],
        config_dict['collection_name'],
        DATA_PATH,
        k=int(config_dict['k']),  # k may have been parsed as a string
    )
    qa_chain = get_rag_chain(llm, retriever, format_docs, qa_prompt_template)
    model_benchmark = evaluate_model_config(qa_chain, benchmark, config_dict['config_name'], model_name)
    model_benchmarks[model_name].append(model_benchmark)

In [None]:
model_name = MODEL_NAMES[2]
llm = get_model(model_name)

# Select the chat prompt format matching the model family.
prompt_format = MIXTRAL_PROMPT_TEMPLATE if 'mistral' in model_name else LLAMA_PROMPT_TEMPLATE
qa_prompt_template = PromptTemplate.from_template(
    prompt_format.format(system_prompt=QA_SYSTEM_PROMPT, user_message=QA_PROMPT)
)

# Reset this model's results so re-running the cell does not append duplicates.
# Results are still appended one at a time, so an interrupted run keeps the
# configurations it already finished.
model_benchmarks[model_name] = []

for config_dict in tqdm(configs, desc=model_name):
    retriever = get_multivector_retriever(
        persistent_client,
        config_dict['embedding_model_name'],
        config_dict['collection_name'],
        DATA_PATH,
        k=int(config_dict['k']),  # k may have been parsed as a string
    )
    qa_chain = get_rag_chain(llm, retriever, format_docs, qa_prompt_template)
    model_benchmark = evaluate_model_config(qa_chain, benchmark, config_dict['config_name'], model_name)
    model_benchmarks[model_name].append(model_benchmark)