## 1. Installation des bibliothèques

In [3]:
!pip install -qU langchain sentence_transformers langchain_community langchain-huggingface faiss-cpu kagglehub tiktoken transformers sentencepiece langchain-google-genai datasets

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/480.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/179.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.3/179.3 kB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/134.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## 2. Téléchargement et chargement du dataset

In [4]:
from datasets import load_dataset, tqdm

ds_qa = load_dataset("enelpol/rag-mini-bioasq", "question-answer-passages")["test"]
ds_corpus = load_dataset("enelpol/rag-mini-bioasq", "text-corpus")["test"]

# only keep ds_qa rows with 3 items or more because we use 3 by default in our case
ds_qa = ds_qa.filter(lambda x: len(x["relevant_passage_ids"]) >= 3)

ds_qa.to_csv("../test_datasets/rag-mini-bioasq/qa/qa.csv")
ds_corpus.to_csv("../test_datasets/rag-mini-bioasq/corpus/corpus.csv")

README.md:   0%|          | 0.00/1.76k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/1.12M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/187k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4012 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/707 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/35.3M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/40181 [00:00<?, ? examples/s]

Filter:   0%|          | 0/707 [00:00<?, ? examples/s]

Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/41 [00:00<?, ?ba/s]

60169807

In [5]:
print("Datasets sizes:")
print(f"ds_qa: {len(ds_qa)}")
print(f"ds_corpus: {len(ds_corpus)}")

Datasets sizes:
ds_qa: 497
ds_corpus: 40181


In [6]:
print(ds_qa.to_pandas().head(1))
print(ds_corpus.to_pandas().head(1))

                                          question  \
0  Describe the mechanism of action of ibalizumab.   

                                              answer    id  \
0  Ibalizumab is a humanized monoclonal antibody ...  2835   

                                relevant_passage_ids  
0  [29675744, 24853313, 29689540, 21289125, 20698...  
                                             passage    id
0  New data on viruses isolated from patients wit...  9797


In [10]:
doc_list, metadata = [], []
for index, row in ds_corpus.to_pandas().iterrows():
    doc_list.append(f"passage : {row['passage']}")
    metadata.append({"id": row["id"]})

In [12]:
from langchain.docstore.document import Document

documents = []
for i, doc_text in enumerate(doc_list):
    doc = Document(
        page_content=doc_text,
        metadata=metadata[i]
    )
    documents.append(doc)

# 3. Embeddings

In [7]:
from sentence_transformers import SentenceTransformer
from tqdm import tqdm


class MiniLM:
    def __init__(self, model_download_path):
        self.model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=model_download_path)

    def embed_documents(self, docs, verbose=False):
        output = []

        if verbose:
            docs = tqdm(docs)

        for d in docs:
            output.append(self.model.encode(d).tolist())
        return output

    def embed_query(self, query):
        return self.model.encode(query) #self.model.encode(query).tolist()

    def get_id(self):
        return "minilm"

    def __call__(self, query):
        return self.embed_query(query)

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

# 4. Création d'embeddings et Vector Store

In [11]:
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np


# 1) Instanciation du modèle MiniLM (avec un chemin local de cache)
model_download_path = "./models"  # par exemple
embedding_model = MiniLM(model_download_path=model_download_path)

# 2) Récupération du corpus
corpus = doc_list

# 3) Génération des embeddings via votre classe
#    On peut passer `verbose=True` pour avoir la barre de progression.
embeddings = embedding_model.embed_documents(corpus, verbose=True)

# 4) Indexation dans Faiss
if len(embeddings) > 0:
    embedding_dim = len(embeddings[0])  # taille du premier vecteur
else:
    embedding_dim = 0  # si corpus vide

index = faiss.IndexFlatL2(embedding_dim)
index.add(np.array(embeddings, dtype=np.float32))

print("Nombre de vecteurs dans l'index:", index.ntotal)

100%|██████████| 40181/40181 [05:45<00:00, 116.44it/s]


Nombre de vecteurs dans l'index: 40181


In [17]:
from langchain.embeddings.base import Embeddings

class MiniLMEmbeddings(Embeddings):
    def __init__(self, mini_lm_model):
        self.mini_lm_model = mini_lm_model

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return self.mini_lm_model.embed_documents(texts)

    def embed_query(self, text: str) -> list[float]:
        return self.mini_lm_model.embed_query(text)

mini_lm_embeddings = MiniLMEmbeddings(mini_lm_model=embedding_model)

In [28]:
from langchain.docstore.in_memory import InMemoryDocstore

# Map doc_id -> Document
docstore_dict = {str(i): doc for i, doc in enumerate(documents)}
docstore = InMemoryDocstore(docstore_dict)

# Index to docstore ID
index_to_docstore_id = {i: str(i) for i in range(len(documents))}

In [29]:
from langchain.vectorstores import FAISS

vectorstore = FAISS(
    embedding_function=mini_lm_embeddings,
    index=index,
    docstore=docstore,
    index_to_docstore_id=index_to_docstore_id
)

In [30]:
retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3}
)

In [31]:
def retrieve(state: dict):
    question = state["question"]
    docs = retriever.invoke(question)
    return {"context": docs}

# 5 évualuation du RAG reranking

In [32]:
def test_retrieval(ds_qa, retrieve_function):
    total_num_documents_considered = 0
    num_valid_docs = 0
    for test_item in tqdm(ds_qa):
        question = test_item["question"]
        expected_documents_ids = test_item["relevant_passage_ids"]

        response = retrieve_function(state = {"question": question, "context": []})
        docs_retrieved = response["context"]

        # print("Question:",question)

        # if it's a dict of docs (e.g. with QueryTranslationRAGDecomposition)
        if isinstance(docs_retrieved, dict):
            num_docs_retrieved = 0
            for question, docs in docs_retrieved.items():
                num_docs_retrieved += len(docs)
                for doc in docs:
                    if doc.metadata["id"] in expected_documents_ids:
                        num_valid_docs += 1
            total_num_documents_considered += min(len(expected_documents_ids), num_docs_retrieved)
        else:
            num_documents_considered = min(len(expected_documents_ids), len(docs_retrieved))
            total_num_documents_considered += num_documents_considered
            # print("Expected:",expected_documents_ids,"Got:",[doc.metadata["id"] for doc in docs_retrieved])
            # print("Expected:",expected_documents_ids)
            for doc in docs_retrieved:
                # print("Got:",doc.metadata["id"])
                if int(doc.metadata["id"]) in expected_documents_ids:
                    num_valid_docs += 1


    return num_valid_docs / total_num_documents_considered

In [33]:
score = test_retrieval(ds_qa, retrieve)

100%|██████████| 497/497 [00:09<00:00, 53.98it/s]


In [34]:
print("Score RAG:", score)

Score RAG: 0.6183769282360831


## 5. test du reranking avec la génération (gemini-1.5-flash)

In [35]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

cross_encoder_model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
tokenizer = AutoTokenizer.from_pretrained(cross_encoder_model_name)
cross_encoder = AutoModelForSequenceClassification.from_pretrained(cross_encoder_model_name)

def rerank(query, documents):
    pairs = [(query, doc.page_content) for doc in documents]
    inputs = tokenizer(
        pairs,
        padding=True, truncation=True,
        return_tensors='pt'
    )
    with torch.no_grad():
        scores = cross_encoder(**inputs).logits.squeeze()
    doc_scores = list(zip(documents, scores.tolist()))
    doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True)
    return doc_scores

tokenizer_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/794 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

### 5.1 Test du Reranking

In [36]:
query_test = "A story about a group of friends who go on a road trip."
candidate_docs = retriever.get_relevant_documents(query_test)

print("---- Documents avant Reranking ----")
for i, doc in enumerate(candidate_docs):
    print(f"{i+1}. {doc.metadata} -> {doc.page_content[:100]}...")

reranked = rerank(query_test, candidate_docs)
print("\n---- Documents après Reranking ----")
for i, (doc, score) in enumerate(reranked):
    print(f"{i+1}. Score={score:.4f} | {doc.metadata} -> {doc.page_content[:100]}...")

  candidate_docs = retriever.get_relevant_documents(query_test)


---- Documents avant Reranking ----
1. {'id': 2279154} -> passage : OBJECTIVE: To compare the long term survival of a group of athletes taking 
prolonged vigo...
2. {'id': 21618162} -> passage : It is widely held among the general population and even among health 
professionals that m...
3. {'id': 21199140} -> passage : BACKGROUND: Health-care professionals can help travelers by providing accurate 
pre-travel...

---- Documents après Reranking ----
1. Score=-10.3014 | {'id': 21199140} -> passage : BACKGROUND: Health-care professionals can help travelers by providing accurate 
pre-travel...
2. Score=-10.8338 | {'id': 2279154} -> passage : OBJECTIVE: To compare the long term survival of a group of athletes taking 
prolonged vigo...
3. Score=-11.0469 | {'id': 21618162} -> passage : It is widely held among the general population and even among health 
professionals that m...


In [None]:
import os
from langchain.chains import RetrievalQA
from langchain_google_genai import ChatGoogleGenerativeAI
from typing import List, Callable
from langchain.schema import Document as LCDocument
from langchain.schema import BaseRetriever

# ---- Récupération de la clé API Google (si besoin) ----
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY") or input("Veuillez saisir votre clé Google : ")

In [41]:
# ---- 1) Instanciation du modèle Gemini ----
model = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    google_api_key=GOOGLE_API_KEY,
    temperature=0.0
)

# ---- 2) Votre RerankingRetriever ----
class RerankingRetriever(BaseRetriever):
    """
    Un retriever qui encapsule un retriever de base + un rerank_fn
    """
    def __init__(
        self,
        base_retriever: BaseRetriever,
        rerank_fn: Callable,
        k: int = 5
    ):
        """
        On stocke les paramètres dans des attributs "privés" pour éviter
        le conflit avec Pydantic (champ inconnu).
        """
        super().__init__()
        self._base_retriever = base_retriever
        self._rerank_fn = rerank_fn
        self._k = k

    def _get_relevant_documents(self, query: str) -> List[LCDocument]:
        docs = self._base_retriever.get_relevant_documents(query)
        reranked = self._rerank_fn(query, docs)
        top_k_docs = [doc_score[0] for doc_score in reranked[: self._k]]
        return top_k_docs

    async def _aget_relevant_documents(self, query: str) -> List[LCDocument]:
        raise NotImplementedError

# ---- 3) On encapsule votre retriever précédent (celui basé sur doc_list + Faiss)
#     "retriever" est celui créé par
#     retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})

reranking_retriever = RerankingRetriever(
    base_retriever=retriever,  # votre retriever actuel
    rerank_fn=rerank,          # la fonction Cross-Encoder
    k=5
)

# ---- 4) Construction de la chaîne RAG avec le modèle Gemini ----
qa_chain_rerank = RetrievalQA.from_chain_type(
    llm=model,
    chain_type="stuff",  # "stuff" = on inclut les docs in extenso dans le prompt
    retriever=reranking_retriever
)

# ---- 5) Test de la chaîne ----
query_rerank = "Which passage references something about immunology?"
result_rerank = qa_chain_rerank.invoke(query_rerank)
print("Réponse du LLM (avec Reranking) :", result_rerank)

Réponse du LLM (avec Reranking) : {'query': 'Which passage references something about immunology?', 'result': 'The first passage references immunology, stating that Edward Jenner, who discovered vaccination against smallpox, started the science of immunology.'}
