# Introduction  

In my last post, I built a RAG pipeline using LangChain. After testing it, I quickly realized the results weren't great, which pushed me to look for ways to improve its performance. One of the first things that came to mind was a startup called **ZeroEntropy**, which works on something known as **reranking**.
So I decided to dig into this concept: **rerankers**.

# The problem 

Imagine searching for **'football shoes'** on Google. The search engine returns hundreds of options: some are highly relevant, others not so much. You will find sneakers or even tennis sandals.  
The fix is a **reranker**, which reorders the results to put the most relevant items at the top of the page (for example CR7 -the goat- shoes) and the less useful ones (golf shoes, for example) at the bottom.

We have the same issue with RAG. In a naive RAG pipeline, the user's query is embedded and sent to a retriever. The retriever then returns the documents whose embeddings are closest to the query's embedding according to a similarity measure. The catch is that the retriever can, and does, return irrelevant documents, which leads to poor-quality answers.  
The solution? A **reranker**: after retrieving, say, the 10 documents closest to the query, we reorder them and keep only the most relevant ones. Here is a quick recap to make it clear:  
- **Step 1: Broad retrieval**  
The retriever pulls a large set of candidate documents. It is quick but not very accurate.

- **Step 2: Reranking**  
The reranker examines each retrieved document alongside the user's query and assigns a precise relevance score to each query-document pair. The documents are then reordered based on this score. This step prioritizes accuracy.
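
To make the two steps concrete, here is a minimal sketch of the retrieve-then-rerank flow. Everything in it is made up for illustration: `cheap_score` and `precise_score` are toy stand-ins for the embedding similarity and the reranker, and the four-document corpus is invented.

```python
from typing import Callable, List

Scorer = Callable[[str, str], float]


def retrieve_then_rerank(
    query: str,
    corpus: List[str],
    cheap_score: Scorer,
    precise_score: Scorer,
    k_retrieval: int,
    k_rerank: int,
) -> List[str]:
    # Step 1: broad retrieval, keep the k_retrieval best candidates by the cheap score.
    candidates = sorted(corpus, key=lambda d: cheap_score(query, d), reverse=True)[:k_retrieval]
    # Step 2: reranking, rescore only those candidates with the precise scorer.
    reranked = sorted(candidates, key=lambda d: precise_score(query, d), reverse=True)
    return reranked[:k_rerank]


def cheap_score(query: str, doc: str) -> float:
    # Toy retriever (assumption): number of query words that appear in the document.
    return float(len(set(query.lower().split()) & set(doc.lower().split())))


def precise_score(query: str, doc: str) -> float:
    # Toy reranker (assumption): reward the exact phrase, otherwise fall back to the cheap score.
    return 10.0 if query.lower() in doc.lower() else cheap_score(query, doc)


corpus = [
    "tennis shoes for clay courts",
    "shoes for football training",
    "football shoes worn by CR7",
    "golf clubs and bags",
]
top = retrieve_then_rerank("football shoes", corpus, cheap_score, precise_score, 3, 1)
print(top)
```

The cheap scorer cannot tell the two football documents apart (both contain the two query words), so the order it produces between them is arbitrary. The second pass only has to look at three candidates, which is why it can afford a more expensive comparison.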


# What is a Reranker?  

A reranker is neither more nor less than an **encoder**, or more precisely a cross-encoder. In the broad retrieval step, we embed the query and the documents separately with a classical encoder (the embedding model), and then retrieve the k most relevant documents. After that, we take each retrieved document, compute its embedding **alongside** the query, and pass this joint embedding to a **scoring head** that returns a **relevance score**.
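
For contrast, the broad retrieval step reduces to a nearest-neighbour search over vectors that were computed independently of each other. Below is a minimal sketch with hand-written 3-dimensional vectors. The vectors and document names are made up for illustration; a real embedding model produces hundreds of dimensions.

```python
import math
from typing import Dict, List


def cosine(a: List[float], b: List[float]) -> float:
    # Cosine similarity: dot product divided by the product of the norms.
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def top_k(query_vec: List[float], doc_vecs: Dict[str, List[float]], k: int) -> List[str]:
    # Rank document names by similarity to the query vector and keep the k best.
    ranked = sorted(doc_vecs, key=lambda name: cosine(query_vec, doc_vecs[name]), reverse=True)
    return ranked[:k]


# Invented embeddings: each document was embedded once, without seeing the query.
doc_vecs = {
    "football shoes": [0.9, 0.1, 0.0],
    "running sneakers": [0.6, 0.6, 0.1],
    "golf shoes": [0.1, 0.2, 0.9],
}
query_vec = [1.0, 0.2, 0.0]
print(top_k(query_vec, doc_vecs, k=2))
```

Because the document vectors never depend on the query, they can be computed once and stored, which is what makes this step fast. A cross-encoder gives up that shortcut in exchange for reading the query and the document together.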

Here is a quick example to show you what happens:

**Step 1: Create the query/document pair**
$$
\text{textToScore}_{i} = \texttt{[CLS]}\;\text{query}\;\texttt{[SEP]}\;\text{document}_{i}\;\texttt{[SEP]}
$$
**Step 2: Generate its embedding**
$$
\text{embedding}_{i} = \text{embedding}(\text{textToScore}_{i})
$$
**Step 3: Compute the relevance score**

$$
\text{crossEncoderScore}_{i} = \text{relevanceScore}(\text{embedding}_{i})
$$

**Step 4: Reorder**  

The reranker reorders the documents based on their relevance scores.
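
The steps above can be mimicked end to end with a toy model. In this sketch `encode_pair` is a hypothetical stand-in for the transformer (it only computes two hand-picked features), and the weights of `scoring_head` are invented. Only the shape of the computation mirrors a cross-encoder.

```python
from typing import List, Tuple

WEIGHTS = [2.0, -0.1]  # invented weights for the linear scoring head
BIAS = 0.0


def build_pair(query: str, document: str) -> str:
    # Create one input sequence that contains both texts.
    return f"[CLS] {query} [SEP] {document} [SEP]"


def encode_pair(text_to_score: str) -> List[float]:
    # Toy "embedding" computed from the pair jointly (assumption, not a real encoder):
    # [number of document words that appear in the query, document length in words].
    parts = text_to_score.split(" [SEP]")
    query = parts[0][len("[CLS] "):]
    document = parts[1].strip()
    query_words = set(query.lower().split())
    doc_words = document.lower().split()
    shared = sum(1 for word in doc_words if word in query_words)
    return [float(shared), float(len(doc_words))]


def scoring_head(embedding: List[float]) -> float:
    # A linear layer turns the pair embedding into a single relevance score.
    return sum(w * x for w, x in zip(WEIGHTS, embedding)) + BIAS


def rerank(query: str, documents: List[str]) -> List[Tuple[str, float]]:
    scored = [(doc, scoring_head(encode_pair(build_pair(query, doc)))) for doc in documents]
    # Reorder by descending score.
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


docs = ["golf shoes on sale", "football shoes for firm ground", "tennis sandals"]
ranking = rerank("football shoes", docs)
for doc, score in ranking:
    print(f"{score:5.2f}  {doc}")
```

The important detail is that `encode_pair` receives the query and the document in the same input, so the features it produces can depend on how the two texts relate to each other. A real cross-encoder does the same thing with attention instead of word counting.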

**Okay, now that we know how rerankers work, let's code this and see if the answers are better!**

In [None]:
from langchain_community.document_loaders import DirectoryLoader
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from typing import List
from transformers import AutoTokenizer
from langchain_chroma import Chroma
import torch
from sentence_transformers import CrossEncoder
from langchain_huggingface import HuggingFaceEmbeddings
import hashlib
from dotenv import load_dotenv
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI

In [None]:
class Rag:
    def __init__(self, name:str):
        self.name = name
    
    def load_documents(self,path:str) -> List[Document]:
        self.path = path 
        loader = DirectoryLoader(
            path=path,
            glob="**/*.pdf",
            show_progress=True,
            recursive=True
        )
        documents = loader.load()
        return documents
    
    def split_documents(self):
        doc_list = self.load_documents(self.path)
        tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-small')
        textSplitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
            tokenizer=tokenizer,
            separators=["\n\n", "\n", " ", ""],
            chunk_size = 512,
            chunk_overlap = int(512/10),
            add_start_index = True,               
            strip_whitespace = True ,
        )
        chunks = textSplitter.split_documents(doc_list)
        self.len_chunks = len(chunks)
        return chunks
     
    def createVectorStore(self ,collection_name:str, persist_dir:str, chunks: List[Document]):
        self.collection_name = collection_name
        self.persist_dir = persist_dir
        if torch.cuda.is_available():
            model_kwargs = {"device": "cuda"}
            print("Using Cuda to generate embeddings")
        else:
            model_kwargs=  {"device": "cpu"}
            print("Using CPU to generate embeddings")
            
        embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-small", model_kwargs=model_kwargs)
        
        vector_store = Chroma(
            collection_name=collection_name,
            embedding_function=embeddings,
            persist_directory=self.persist_dir
        )
        
        current_content = vector_store.get()
        existing_ids = set(current_content["ids"]) if current_content["ids"] else set()
        
        docs_to_add = []
        ids_to_add = []
        
        for doc in chunks:
            doc_id = hashlib.md5(doc.page_content.encode()).hexdigest()
            
            # Skip chunks whose ID is already in the DB or appeared earlier in this batch
            if doc_id not in existing_ids:
                docs_to_add.append(doc)
                ids_to_add.append(doc_id)
                existing_ids.add(doc_id)
        if len(docs_to_add) == 0:
            print("Documents already in the DB nothing will be added to the vector store")
            return vector_store
        try:
            vector_store.add_documents(documents=docs_to_add, ids=ids_to_add)
        except Exception as e:
            print("Error while adding to ChromaDB:", e)
        else:
            # Only report success when add_documents did not raise
            print(f"Successfully added {len(docs_to_add)} vectors to the vector database")
        return vector_store
    
    
    def late_interaction(self, query:str, docs :List[Document]):
        pass 
    
    def answer_query(self,query:str, vector_store: Chroma, k_retrieval: int,k_reranker:int,reranker: bool = False ):
        
        if k_reranker > self.len_chunks:
            raise ValueError(f"k_reranker {k_reranker} cannot be greater than the number of chunks {self.len_chunks}")
        
        if k_retrieval > self.len_chunks:
            raise ValueError(f"k_retrieval {k_retrieval} cannot be greater than the number of chunks {self.len_chunks}")
        
        retrieved_doc = vector_store.similarity_search(query,k_retrieval)
        
        load_dotenv()
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash",
            temperature = 0.5,
            max_retries = 2
        )
        prompt = ChatPromptTemplate.from_template(
            """Use the following context to answer the question at the end. 
           You must be respectful and helpful, and answer in the language of the question.
           If you don't know the answer, say that you don't know.

           Context: {context}

           Question: {question}
           """
        )   
        runnable_query = RunnablePassthrough()                         
        prompt_runnable = RunnableLambda(lambda args: prompt.format_messages(context = args["context"] , question=args["question"]))
        if reranker:
            # Score each (query, chunk) pair with a cross-encoder and keep the k_reranker best
            model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
            chunks_content = [chunk.page_content for chunk in retrieved_doc]
            ranks = model.rank(query, chunks_content, top_k=k_reranker)
            # Each entry of ranks stores the index of its chunk in chunks_content under "corpus_id"
            context = RunnableLambda(lambda _: "\n\n".join(chunks_content[rank["corpus_id"]] for rank in ranks))
        else:
            # No reranking: use the retrieved chunks in retrieval order
            context = RunnableLambda(lambda _: "\n\n".join(doc.page_content for doc in retrieved_doc))

        # The pipeline is the same in both cases, only the context differs
        pipeline = (
            {
                "context": context,
                "question": runnable_query,
            }
            | prompt_runnable
            | llm
            | StrOutputParser()
        )
            
        answer = pipeline.invoke(query)
        return answer
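
The deduplication logic in `createVectorStore` does not depend on Chroma itself. Here is a minimal sketch of the same idea, with a plain dictionary standing in for the vector store. The `add_new_chunks` and `content_id` helpers and the dictionary are hypothetical; only the MD5-of-content ID scheme comes from the class above.

```python
import hashlib
from typing import Dict, List


def content_id(text: str) -> str:
    # Same text always gives the same ID, so re-ingesting a chunk is detectable.
    return hashlib.md5(text.encode()).hexdigest()


def add_new_chunks(store: Dict[str, str], chunks: List[str]) -> int:
    # Insert only the chunks whose content hash is not in the store yet.
    added = 0
    for chunk in chunks:
        chunk_id = content_id(chunk)
        if chunk_id not in store:
            store[chunk_id] = chunk
            added += 1
    return added


store: Dict[str, str] = {}
first = add_new_chunks(store, ["IP addresses", "TCP segments"])
second = add_new_chunks(store, ["TCP segments", "Routers"])
print(first, second, len(store))
```

Because each new ID is recorded as soon as it is stored, a chunk repeated within the same batch is only added once. The trade-off of hashing the content alone is that two identical chunks from different files collapse into a single entry.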


In [51]:
rag_1 = Rag("Saad_rag")
docs = rag_1.load_documents(path="/home/pepito/Documents/Python/ML/GenAI/RAG/pdf_documents")
chunks = rag_1.split_documents()
vector_store = rag_1.createVectorStore(
    collection_name="test_1",
    persist_dir="/home/pepito/Documents/Python/ML/GenAI/RAG/persist_dir_oop",
    chunks=chunks
)

answer = rag_1.answer_query(
    query="De quoi parle le document",  # French for "What is the document about?"
    vector_store=vector_store,
    k_retrieval=6,
    k_reranker=3,
    reranker=True
)


answer_no_reranking = rag_1.answer_query(
    query="De quoi parle le document",  # French for "What is the document about?"
    vector_store=vector_store,
    k_retrieval=6,
    k_reranker = 0,
    reranker=False
)




Using Cuda to generate embeddings
Documents already in the DB nothing will be added to the vector store


In [55]:
import pprint
pprint.pp(answer)
print("*"*120)
pprint.pp(answer_no_reranking)

('Le document parle des adresses IP (IPv4 et IPv6), de leur structure, du '
 'masque de sous-réseau, du protocole TCP/IP et du routage des données sur '
 'Internet. Il explique comment les données sont acheminées sous forme de '
 'paquets entre les machines via des routeurs.')
************************************************************************************************************************
('Le document parle principalement des protocoles **TCP (Transmission Control '
 "Protocol)** et **IP (Internet Protocol)**, réunis sous l'appellation "
 "**TCP/IP**, qui sont fondamentaux pour le fonctionnement d'Internet et "
 "l'échange de données entre ordinateurs.\n"
 '\n'
 'Il détaille le rôle de chaque protocole :\n'
 '*   Le **TCP** assure la fiabilité de la transmission des paquets de données '
 "(segments), leur numérotation, l'envoi d'accusés de réception et la "
 "reconstitution des fichiers à l'arrivée.\n"
 "*   L'**IP** gère l'adressage unique des machines sur le réseau (avec les 