In [1]:
import re

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys

sys.path.append('../../')

In [4]:
from src.indexing import get_multivector_retriever
from src.generation import QA_SYSTEM_PROMPT, QA_PROMPT, LLAMA_PROMPT_TEMPLATE, MIXTRAL_PROMPT_TEMPLATE
from src.generation import get_model, format_docs, get_rag_chain
from langchain_core.documents import Document

from src.ingestion import load_pdf
from tqdm import tqdm

import os
import chromadb
import uuid
import math

import pandas as pd

In [5]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

In [6]:
DATA_PATH = 'D:\Ahmed\saudi-rag-project\data'
RAW_DOCS_PATH = os.path.join(DATA_PATH, "raw")
CHROMA_PATH = os.path.join(DATA_PATH, "chroma")

MODEL_NAME = "meta-llama/Llama-3-8b-chat-hf"

  DATA_PATH = 'D:\Ahmed\saudi-rag-project\data'


In [7]:
persistent_client = chromadb.PersistentClient(path=CHROMA_PATH)

In [8]:
collections = []

for collection in persistent_client.list_collections():

    collection_dict = dict()

    collection_dict['collection_name'] = collection.name

    if 'text_embedding_3_large' in collection.name:
        collection_dict['embedding_model_name'] = 'text-embedding-3-large'

    elif 'text_embedding_3_small' in collection.name:
        collection_dict['embedding_model_name'] = 'text-embedding-3-small'

    elif 'text_embedding_ada_002' in collection.name:
        collection_dict['embedding_model_name'] = 'text-embedding-ada-002'

    elif 'multilingual_e5_small' in collection.name:
        collection_dict['embedding_model_name'] = 'intfloat/multilingual-e5-small'

    elif 'multilingual_e5_base' in collection.name:
        collection_dict['embedding_model_name'] = 'intfloat/multilingual-e5-base'

    collections.append(collection_dict)

In [9]:
len(collections)

100

In [10]:
collections[:5]

[{'collection_name': 'PQ_COMB_Llama_3_70b_chat_hf_text_embedding_ada_002',
  'embedding_model_name': 'text-embedding-ada-002'},
 {'collection_name': 'PQ_COMB_ALL_text_embedding_3_large',
  'embedding_model_name': 'text-embedding-3-large'},
 {'collection_name': 'PC_100_text_embedding_ada_002',
  'embedding_model_name': 'text-embedding-ada-002'},
 {'collection_name': 'PS_Mixtral_8x22B_intfloat_multilingual_e5_small',
  'embedding_model_name': 'intfloat/multilingual-e5-small'},
 {'collection_name': 'PQ_SPLIT_Mixtral_8x22B_intfloat_multilingual_e5_small',
  'embedding_model_name': 'intfloat/multilingual-e5-small'}]

### Load benchmark

In [11]:
benchmark = pd.read_csv("../../data/benchmark.csv")

In [12]:
def extract_numbers(text):
    text = re.sub(',', '', text)

    # This pattern matches both integers and decimal numbers
    pattern = r'\b\d+\.?\d*\b'

    # Find all matches in the text and return them as a list of floats or integers
    numbers = re.findall(pattern, text)

    # Convert the extracted number strings to appropriate float or int types
    return [float(num) if '.' in num else int(num) for num in numbers]

### Define a function to test one retriver

What are the metrics that I'll assess by?
1. Recall
2. Precision@k
3. AP@k
4. Retrieval count

In [13]:
def calculate_average_precision(hits):
    if not hits:
        return 0.0  # Return 0 if hits list is empty
    
    cum_sum = 0
    total_hits = 0
    precision_at_i = 0
    
    for i, hit in enumerate(hits):
        if hit == 1:
            total_hits += 1
            precision_at_i += total_hits / (i + 1)
    
    if total_hits == 0:
        return 0.0  # Avoid division by zero if there are no positive hits
    
    return precision_at_i / total_hits

In [14]:
def evaluate_retriever(collection_dict, k, benchmark):

    benchmark = benchmark.copy()

    retriever = get_multivector_retriever(persistent_client, collection_dict['embedding_model_name'], collection_dict['collection_name'], DATA_PATH, k=k)

    questions_retrieved_docs = retriever.batch(benchmark.question.tolist())

    recall = []
    recall_at_1 = []
    recall_at_2 = []
    recall_at_3 = []
    precision = []
    precision_at_1 = []
    precision_at_2 = []
    precision_at_3 = []
    average_precision = []
    retrieval_count = []

    missing_collections = []

    if not questions_retrieved_docs[0]:
        print(collection_dict['collection_name'])
        return collection_dict

    for docs, answer in zip(questions_retrieved_docs, benchmark.answer.tolist()):

        hits = []
        
        for d in docs:
            if set(extract_numbers(answer)).intersection(extract_numbers(d.page_content)):
                hits.append(1)
            else:
                hits.append(0)
        
        recall.append(max(hits))
        recall_at_1.append(max(hits[:1]))
        recall_at_2.append(max(hits[:2]))
        recall_at_3.append(max(hits[:3]))

        precision.append(sum(hits) / len(hits))
        precision_at_1.append(sum(hits[:1]))
        precision_at_2.append(sum(hits[:2])/2)
        precision_at_3.append(sum(hits[:3])/3)

        average_precision.append(calculate_average_precision(hits))
        retrieval_count.append(len(docs))

    benchmark['collection_name'] = collection_dict['collection_name']
    benchmark['embedding_model'] = collection_dict['embedding_model_name']
    benchmark['k'] = k
    benchmark['recall'] = recall
    benchmark['recall@1'] = recall_at_1
    benchmark['recall@2'] = recall_at_2
    benchmark['recall@3'] = recall_at_3
    benchmark['precision'] = precision
    benchmark['precision@1'] = precision_at_1
    benchmark['precision@2'] = precision_at_1
    benchmark['precision@3'] = precision_at_1
    benchmark['average_precision'] = average_precision
    benchmark['retrieval_count'] = retrieval_count

    return benchmark

##### Test one retriver

In [72]:
# Let's load a test retriever
collection_dict = collections[12]
retriever_benchmark = evaluate_retriever(collection_dict, 10, benchmark)

PQ_COMB_Llama_3_8b_chat_hf_text_embedding_3_small


In [18]:
metrics = ["recall", "recall@1", "recall@2", "recall@3", "precision", "precision@1", "precision@2", "precision@3", "average_precision", "retrieval_count"]

In [77]:
# retriever_benchmark.groupby(["collection_name", "embedding_model"])[metrics].mean()

##### Do all retrievers with different k

In [15]:
retrievers_benchmarks = []
missing_collections = []

for collection_dict in tqdm(collections):
    for k in [3, 5, 9, 11, 13, 15, 17, 19]:
        retriever_benchmark = evaluate_retriever(collection_dict, k, benchmark)
        if isinstance(retriever_benchmark, dict):
            missing_collections.append(retriever_benchmark)
        else:
            retrievers_benchmarks.append(retriever_benchmark)

  warn_deprecated(
100%|██████████| 100/100 [48:03<00:00, 28.83s/it]


In [16]:
retrievers_benchmarks = pd.concat(retrievers_benchmarks, axis=0)

In [17]:
retrievers_benchmarks.to_csv("retrievers_benchmark.csv", index=False)

What is the best embedding model?

In [19]:
retrievers_benchmarks.groupby("embedding_model")[metrics].mean()

Unnamed: 0_level_0,recall,recall@1,recall@2,recall@3,precision,precision@1,precision@2,precision@3,average_precision,retrieval_count
embedding_model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
intfloat/multilingual-e5-base,0.802778,0.614815,0.731481,0.787037,0.469059,0.614815,0.614815,0.614815,0.688606,2.969444
intfloat/multilingual-e5-small,0.796296,0.599074,0.714815,0.77037,0.464275,0.599074,0.599074,0.599074,0.672351,3.013889
text-embedding-3-large,0.748032,0.531944,0.661343,0.719213,0.430527,0.531944,0.531944,0.531944,0.617191,3.12419
text-embedding-3-small,0.641898,0.484259,0.579398,0.614815,0.374421,0.484259,0.484259,0.484259,0.542612,3.146296
text-embedding-ada-002,0.575,0.409259,0.515625,0.550926,0.311699,0.409259,0.409259,0.409259,0.475948,3.287269


What is the best retriever?

In [24]:
retrievers_benchmarks.groupby(["collection_name", "k"])[metrics].mean().sort_values("recall@1", ascending=False).head(20)

Unnamed: 0_level_0,Unnamed: 1_level_0,recall,recall@1,recall@2,recall@3,precision,precision@1,precision@2,precision@3,average_precision,retrieval_count
collection_name,k,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
PQS_ALL_text_embedding_3_small,19,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.611111
PQS_ALL_text_embedding_3_small,17,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.62963
PQS_ALL_text_embedding_3_small,15,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.62963
PQS_ALL_text_embedding_3_small,13,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.62963
PQS_ALL_text_embedding_3_small,11,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.62963
PQS_ALL_text_embedding_3_small,9,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.62963
PQS_ALL_text_embedding_3_small,5,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.62963
PQS_ALL_text_embedding_3_small,3,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.62963
PQS_ALL_intfloat_multilingual_e5_small,3,0.796296,0.796296,0.796296,0.796296,0.709877,0.796296,0.796296,0.796296,0.79321,1.611111
PQ_SPLIT_ALL_text_embedding_3_small,3,0.888889,0.796296,0.888889,0.888889,0.756173,0.796296,0.796296,0.796296,0.839506,1.62963


In [26]:
retrievers_benchmarks.groupby(["collection_name", "k"])[metrics].mean().sort_values("recall@2", ascending=False).head(20)

Unnamed: 0_level_0,Unnamed: 1_level_0,recall,recall@1,recall@2,recall@3,precision,precision@1,precision@2,precision@3,average_precision,retrieval_count
collection_name,k,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
PQ_SPLIT_Llama_3_70b_chat_hf_text_embedding_3_small,19,0.907407,0.740741,0.907407,0.907407,0.665123,0.740741,0.740741,0.740741,0.825617,2.037037
PQ_SPLIT_Llama_3_70b_chat_hf_text_embedding_3_small,17,0.907407,0.740741,0.907407,0.907407,0.665123,0.740741,0.740741,0.740741,0.825617,2.037037
PQ_SPLIT_Llama_3_70b_chat_hf_text_embedding_3_small,13,0.907407,0.740741,0.907407,0.907407,0.665123,0.740741,0.740741,0.740741,0.825617,2.037037
PQ_SPLIT_Llama_3_70b_chat_hf_text_embedding_3_small,11,0.907407,0.740741,0.907407,0.907407,0.665123,0.740741,0.740741,0.740741,0.825617,2.037037
PQ_SPLIT_Llama_3_70b_chat_hf_text_embedding_3_small,9,0.907407,0.740741,0.907407,0.907407,0.665123,0.740741,0.740741,0.740741,0.825617,2.037037
PQ_SPLIT_Llama_3_70b_chat_hf_text_embedding_3_small,5,0.907407,0.740741,0.907407,0.907407,0.665123,0.740741,0.740741,0.740741,0.825617,2.037037
PQ_SPLIT_Llama_3_70b_chat_hf_text_embedding_3_small,3,0.907407,0.740741,0.907407,0.907407,0.665123,0.740741,0.740741,0.740741,0.825617,2.037037
PQ_SPLIT_ALL_text_embedding_3_small,9,0.888889,0.796296,0.888889,0.888889,0.753086,0.796296,0.796296,0.796296,0.839506,1.648148
PQ_SPLIT_ALL_text_embedding_3_small,11,0.888889,0.796296,0.888889,0.888889,0.756173,0.796296,0.796296,0.796296,0.839506,1.62963
PQ_SPLIT_Llama_3_70b_chat_hf_text_embedding_3_small,15,0.907407,0.740741,0.888889,0.907407,0.665123,0.740741,0.740741,0.740741,0.822531,2.037037


In [23]:
retrievers_benchmarks.groupby(["collection_name", "k"])[metrics].mean().sort_values("recall", ascending=False).head(20)

Unnamed: 0_level_0,Unnamed: 1_level_0,recall,recall@1,recall@2,recall@3,precision,precision@1,precision@2,precision@3,average_precision,retrieval_count
collection_name,k,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
PQ_COMB_Llama_3_70b_chat_hf_intfloat_multilingual_e5_base,17,0.944444,0.703704,0.814815,0.944444,0.425926,0.703704,0.703704,0.703704,0.776749,4.0
PQ_COMB_Llama_3_70b_chat_hf_text_embedding_3_large,13,0.944444,0.574074,0.87037,0.888889,0.393519,0.574074,0.574074,0.574074,0.712963,4.0
PQ_COMB_Llama_3_70b_chat_hf_text_embedding_3_large,15,0.944444,0.574074,0.87037,0.888889,0.393519,0.574074,0.574074,0.574074,0.712963,4.0
PQ_COMB_Llama_3_70b_chat_hf_intfloat_multilingual_e5_base,19,0.944444,0.703704,0.814815,0.944444,0.425926,0.703704,0.703704,0.703704,0.776749,4.0
PQ_COMB_Llama_3_70b_chat_hf_intfloat_multilingual_e5_base,3,0.944444,0.703704,0.814815,0.944444,0.425926,0.703704,0.703704,0.703704,0.776749,4.0
PQ_COMB_Llama_3_70b_chat_hf_intfloat_multilingual_e5_base,5,0.944444,0.703704,0.814815,0.944444,0.425926,0.703704,0.703704,0.703704,0.776749,4.0
PQ_COMB_Llama_3_70b_chat_hf_intfloat_multilingual_e5_base,9,0.944444,0.703704,0.814815,0.944444,0.425926,0.703704,0.703704,0.703704,0.776749,4.0
PQ_COMB_Llama_3_70b_chat_hf_intfloat_multilingual_e5_base,11,0.944444,0.703704,0.814815,0.944444,0.425926,0.703704,0.703704,0.703704,0.776749,4.0
PQ_COMB_Llama_3_70b_chat_hf_intfloat_multilingual_e5_base,13,0.944444,0.703704,0.814815,0.944444,0.425926,0.703704,0.703704,0.703704,0.776749,4.0
PQ_COMB_Llama_3_70b_chat_hf_intfloat_multilingual_e5_base,15,0.944444,0.703704,0.814815,0.944444,0.425926,0.703704,0.703704,0.703704,0.776749,4.0


In [25]:
retrievers_benchmarks.groupby(["collection_name", "k"])[metrics].mean().sort_values("average_precision", ascending=False).head(20)

Unnamed: 0_level_0,Unnamed: 1_level_0,recall,recall@1,recall@2,recall@3,precision,precision@1,precision@2,precision@3,average_precision,retrieval_count
collection_name,k,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
PQ_SPLIT_ALL_text_embedding_3_small,3,0.888889,0.796296,0.888889,0.888889,0.756173,0.796296,0.796296,0.796296,0.839506,1.62963
PQ_SPLIT_ALL_text_embedding_3_small,5,0.888889,0.796296,0.888889,0.888889,0.756173,0.796296,0.796296,0.796296,0.839506,1.62963
PQS_ALL_text_embedding_3_small,3,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.62963
PQS_ALL_text_embedding_3_small,5,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.62963
PQS_ALL_text_embedding_3_small,9,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.62963
PQS_ALL_text_embedding_3_small,11,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.62963
PQS_ALL_text_embedding_3_small,13,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.62963
PQS_ALL_text_embedding_3_small,15,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.62963
PQS_ALL_text_embedding_3_small,17,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.62963
PQS_ALL_text_embedding_3_small,19,0.87037,0.814815,0.87037,0.87037,0.756173,0.814815,0.814815,0.814815,0.839506,1.611111
