# Build a Simple Retrieval-Augmented Generation (RAG) Pipeline

In this use case, we show how to build and evaluate a simple RAG pipeline with LightRAG. A RAG (Retrieval-Augmented Generation) pipeline uses a retriever to fetch relevant context from a knowledge base (e.g., a document database); that context is then passed to an LLM generator together with the query to produce the answer. This lets the model ground its answers in the retrieved context.
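
As a minimal, library-free sketch of the retrieve-then-generate flow (not LightRAG's implementation), the retriever below ranks documents by word overlap with the query, and a prompt is assembled from the retrieved context and the query. The helpers `tokenize`, `retrieve`, and `build_prompt`, as well as the toy knowledge base, are assumptions made for illustration; a real pipeline would use embeddings for retrieval and send the prompt to an LLM.

```python
import re


def tokenize(text: str) -> set[str]:
    """Lowercase the text and return its set of alphanumeric word tokens."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))


def retrieve(query: str, knowledge_base: list[str], top_k: int = 2) -> list[str]:
    """Return the top_k documents that share the most words with the query."""
    query_tokens = tokenize(query)
    ranked = sorted(
        knowledge_base,
        key=lambda doc: len(query_tokens & tokenize(doc)),
        reverse=True,
    )
    return ranked[:top_k]


def build_prompt(query: str, contexts: list[str]) -> str:
    """Combine the retrieved context and the query into a single prompt."""
    context_str = "\n".join(contexts)
    return f"Context:\n{context_str}\n\nQuery: {query}\nAnswer:"


# Toy knowledge base (hypothetical data).
knowledge_base = [
    "The Eiffel Tower is located in Paris.",
    "Python is a programming language.",
    "Paris is the capital of France.",
]
query = "Where is the Eiffel Tower located?"
contexts = retrieve(query, knowledge_base, top_k=1)
prompt = build_prompt(query, contexts)
print(prompt)  # this string would be sent to the LLM generator
```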

In [29]:
# Import needed modules, including modules for loading datasets, constructing a RAG pipeline, and evaluating the performance of the RAG pipeline.
import yaml
from typing import Any, List, Optional, Union

from datasets import load_dataset

from lightrag.core.types import Document
from lightrag.core.component import Component, Sequential
from lightrag.core.embedder import Embedder
from lightrag.core.document_splitter import DocumentSplitter
from lightrag.core.data_components import (
    RetrieverOutputToContextStr,
    ToEmbeddings,
)
from lightrag.components.retriever import FAISSRetriever
from lightrag.core.generator import Generator
from lightrag.core.db import LocalDocumentDB
from lightrag.core.string_parser import JsonParser

from lightrag.eval import (
    AnswerMatchAcc,
    RetrieverRecall,
    RetrieverRelevance,
    LLMasJudge,
)

In [30]:
# Here, we use the OpenAIClient in the Generator as an example, but you can use any other client (with the corresponding API key as needed).
from lightrag.components.model_client import OpenAIClient
import dotenv

# Never hardcode a real API key in a notebook. Put OPENAI_API_KEY="YOUR_API_KEY" in a .env file instead.
# Load environment variables from the .env file.
dotenv.load_dotenv(dotenv_path=".env", override=True)

False

**Define the configuration for the RAG pipeline**. We load the configuration from a YAML file. This configuration specifies the components of the RAG pipeline, including the text_splitter, vectorizer, retriever, and generator.

In [31]:
# Define the configuration settings for the RAG pipeline.
with open("./simple_rag.yaml", "r") as file:
    settings = yaml.safe_load(file)
print(settings)

{'vectorizer': {'batch_size': 100, 'model_kwargs': {'model': 'text-embedding-3-small', 'dimensions': 256, 'encoding_format': 'float'}}, 'retriever': {'top_k': 2}, 'generator': {'model': 'gpt-3.5-turbo', 'temperature': 0.3, 'stream': False}, 'text_splitter': {'split_by': 'sentence', 'chunk_size': 1, 'chunk_overlap': 0}}
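
A quick sanity check on the loaded settings can catch a missing section before the pipeline is constructed. The sketch below is an illustration only: `REQUIRED_KEYS` and `missing_settings` are hypothetical names, and the required keys are those that the `RAG` class in this notebook reads from the settings dictionary.

```python
# Sections and keys that the RAG class in this notebook reads (assumed list).
REQUIRED_KEYS = {
    "vectorizer": ["batch_size", "model_kwargs"],
    "retriever": ["top_k"],
    "generator": [],  # passed through unchanged as the generator's model_kwargs
    "text_splitter": ["split_by", "chunk_size", "chunk_overlap"],
}


def missing_settings(settings: dict) -> list[str]:
    """Return the names of required sections or keys absent from settings."""
    missing = []
    for section, keys in REQUIRED_KEYS.items():
        if section not in settings:
            missing.append(section)
            continue
        for key in keys:
            if key not in settings[section]:
                missing.append(f"{section}.{key}")
    return missing


# The settings dictionary as printed above.
settings = {
    "vectorizer": {
        "batch_size": 100,
        "model_kwargs": {
            "model": "text-embedding-3-small",
            "dimensions": 256,
            "encoding_format": "float",
        },
    },
    "retriever": {"top_k": 2},
    "generator": {"model": "gpt-3.5-turbo", "temperature": 0.3, "stream": False},
    "text_splitter": {"split_by": "sentence", "chunk_size": 1, "chunk_overlap": 0},
}
print(missing_settings(settings))
```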


**Load a dataset**. Here, we use the [HotpotQA](https://huggingface.co/datasets/hotpotqa/hotpot_qa) dataset as an example. Each sample in HotpotQA has a *question*, an *answer*, a *context*, and *supporting_facts*, which identify the sentences in the context that support the answer.

In [32]:
# Load the HotpotQA dataset. We select a subset of the dataset for demonstration purposes.
dataset = load_dataset(path="hotpot_qa", name="fullwiki")
selected_dataset = dataset["train"].select(range(5))
print(f"example: {selected_dataset[0]}")
print(f"ground truth context: {selected_dataset[0]['supporting_facts']}")

example: {'id': '5a7a06935542990198eaf050', 'question': "Which magazine was started first Arthur's Magazine or First for Women?", 'answer': "Arthur's Magazine", 'type': 'comparison', 'level': 'medium', 'supporting_facts': {'title': ["Arthur's Magazine", 'First for Women'], 'sent_id': [0, 0]}, 'context': {'title': ['Radio City (Indian radio station)', 'History of Albanian football', 'Echosmith', "Women's colleges in the Southern United States", 'First Arthur County Courthouse and Jail', "Arthur's Magazine", '2014–15 Ukrainian Hockey Championship', 'First for Women', 'Freeway Complex Fire', 'William Rast'], 'sentences': [["Radio City is India's first private FM radio station and was started on 3 July 2001.", ' It broadcasts on 91.1 (earlier 91.0 in most cities) megahertz from Mumbai (where it was started in 2004), Bengaluru (started first in 2001), Lucknow and New Delhi (since 2003).', ' It plays Hindi, English and regional songs.', ' It was launched in Hyderabad in March 2006, in Chenna

**Define a simple RAG pipeline**. Define a RAG pipeline by specifying the key components, such as *vectorizer*, *retriever*, and *generator*. For more information on these components, refer to the developer notes.

In [33]:
# The defined RAG pipeline.
class RAG(Component):

    def __init__(self, settings: dict):
        super().__init__()
        self.vectorizer_settings = settings["vectorizer"]
        self.retriever_settings = settings["retriever"]
        self.generator_model_kwargs = settings["generator"]
        self.text_splitter_settings = settings["text_splitter"]

        vectorizer = Embedder(
            model_client=OpenAIClient(),
            model_kwargs=self.vectorizer_settings["model_kwargs"],
        )

        text_splitter = DocumentSplitter(
            split_by=self.text_splitter_settings["split_by"],
            split_length=self.text_splitter_settings["chunk_size"],
            split_overlap=self.text_splitter_settings["chunk_overlap"],
        )
        self.data_transformer = Sequential(
            text_splitter,
            ToEmbeddings(
                vectorizer=vectorizer,
                batch_size=self.vectorizer_settings["batch_size"],
            ),
        )
        self.data_transformer_key = self.data_transformer._get_name()
        # initialize retriever, which depends on the vectorizer too
        self.retriever = FAISSRetriever(
            top_k=self.retriever_settings["top_k"],
            dimensions=self.vectorizer_settings["model_kwargs"]["dimensions"],
            vectorizer=vectorizer,
        )
        self.retriever_output_processors = RetrieverOutputToContextStr(deduplicate=True)

        self.db = LocalDocumentDB()

        # initialize generator
        self.generator = Generator(
            preset_prompt_kwargs={
                "task_desc_str": r"""
                    You are a helpful assistant.

                    Your task is to answer the query that may or may not come with context information.
                    When context is provided, you should rely on the context rather than your prior knowledge to answer the query.

                    Output JSON format:
                    {
                        "answer": "The answer to the query"
                    }"""
            },
            model_client=OpenAIClient(),
            model_kwargs=self.generator_model_kwargs,
            output_processors=JsonParser(),
        )
        self.tracking = {"vectorizer": {"num_calls": 0, "num_tokens": 0}}

    def build_index(self, documents: List[Document]):
        self.db.load_documents(documents)
        self.map_key = self.db.map_data()
        print(f"map_key: {self.map_key}")
        self.data_key = self.db.transform_data(self.data_transformer)
        print(f"data_key: {self.data_key}")
        self.transformed_documents = self.db.get_transformed_data(self.data_key)
        self.retriever.build_index_from_documents(self.transformed_documents)

    def generate(self, query: str, context: Optional[str] = None) -> Any:
        if not self.generator:
            raise ValueError("Generator is not set")

        prompt_kwargs = {
            "context_str": context,
            "input_str": query,
        }
        response = self.generator(prompt_kwargs=prompt_kwargs)
        if response.error:
            raise ValueError(f"Error in generator: {response.error}")
        return response.data

    def call(self, query: str) -> Any:
        retrieved_documents = self.retriever(query)
        # fill in the document
        for i, retriever_output in enumerate(retrieved_documents):
            retrieved_documents[i].documents = [
                self.transformed_documents[doc_index]
                for doc_index in retriever_output.doc_indexes
            ]
        # convert all the documents to context string
        context_str = self.retriever_output_processors(retrieved_documents)

        return self.generate(query, context=context_str), context_str
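
The `text_splitter` settings control how each document is chunked before it is embedded. As a rough illustration of what `split_by: sentence`, `chunk_size`, and `chunk_overlap` mean, here is a plain-Python sketch. It is not `DocumentSplitter`'s implementation: the naive regex sentence splitter is an assumption, and `split_into_chunks` is a hypothetical helper.

```python
import re


def split_into_chunks(
    text: str, chunk_size: int = 1, chunk_overlap: int = 0
) -> list[str]:
    """Group sentences into chunks of chunk_size sentences each, where
    consecutive chunks share chunk_overlap sentences."""
    if chunk_size <= 0 or not 0 <= chunk_overlap < chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= chunk_overlap < chunk_size")
    # Naive sentence split: break after '.', '!' or '?' followed by whitespace.
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(sentences), step):
        chunks.append(" ".join(sentences[start : start + chunk_size]))
        # Stop once the current chunk reaches the last sentence.
        if start + chunk_size >= len(sentences):
            break
    return chunks


text = "First sentence. Second sentence. Third sentence."
one_per_chunk = split_into_chunks(text, chunk_size=1, chunk_overlap=0)
overlapping = split_into_chunks(text, chunk_size=2, chunk_overlap=1)
print(one_per_chunk)
print(overlapping)
```

With the settings used in this notebook (`chunk_size: 1`, `chunk_overlap: 0`), each sentence becomes its own chunk, so the retriever returns individual sentences.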

To run the RAG pipeline for each example in the dataset, we first **build the index** and then **call the pipeline**. For each sample, we create the list of documents to retrieve from using the sample's *context*. Each document has a title and a list of sentences. We use the `Document` class from `lightrag.core.types` to represent each document.

In [None]:
# Get the ground truth context sentences from the supporting_facts field in HotpotQA. This function is specific to the HotpotQA dataset.
def get_supporting_sentences(
    supporting_facts: dict[str, list[Union[str, int]]], context: dict[str, list[str]]
) -> List[str]:
    """
    Extract the supporting sentences from the context based on the supporting facts.
    """
    extracted_sentences = []
    for title, sent_id in zip(supporting_facts["title"], supporting_facts["sent_id"]):
        if title in context["title"]:
            index = context["title"].index(title)
            sentence = context["sentences"][index][sent_id]
            extracted_sentences.append(sentence)
    return extracted_sentences


questions = []
retrieved_contexts = []
gt_contexts = []
pred_answers = []
gt_answers = []
for data in selected_dataset:
    # build the document list
    num_docs = len(data["context"]["title"])
    doc_list = [
        Document(
            meta_data={"title": data["context"]["title"][i]},
            text=" ".join(data["context"]["sentences"][i]),
        )
        for i in range(num_docs)
    ]
    rag = RAG(settings)
    # build the index
    rag.build_index(doc_list)
    # call the pipeline
    query = data["question"]
    response, context_str = rag.call(query)
    gt_context_sentence_list = get_supporting_sentences(
        data["supporting_facts"], data["context"]
    )
    questions.append(query)
    retrieved_contexts.append(context_str)
    gt_contexts.append(gt_context_sentence_list)
    pred_answers.append(response["answer"])
    gt_answers.append(data["answer"])
    print(f"query: {query}")
    print(f"response: {response['answer']}")
    print(f"ground truth response: {data['answer']}")
    print(f"context_str: {context_str}")
    print(f"ground truth context_str: {gt_context_sentence_list}")
    break  # Stops after the first sample; remove this line to run on every selected sample.


**Evaluate the performance of the RAG pipeline**. We first evaluate the performance of the retriever component by calculating the *recall* of the retrieved context and the *relevance* score of the retrieved context.

In [34]:
# Compute the recall.
retriever_recall = RetrieverRecall()
avg_recall, recall_list = retriever_recall.compute(retrieved_contexts, gt_contexts)
print(f"average recall: {avg_recall}")
print(f"recall list: {recall_list}")


In [None]:
# Compute the relevance.
retriever_relevance = RetrieverRelevance()
avg_relevance, relevance_list = retriever_relevance.compute(
    retrieved_contexts, gt_contexts
)
print(f"average relevance: {avg_relevance}")
print(f"relevance list: {relevance_list}")
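
For intuition about what these two numbers measure, the sketch below uses simple string-containment definitions: recall is the fraction of ground truth sentences found in the retrieved context, and relevance is the share of the retrieved context's characters that belong to ground truth sentences. These definitions and the helper names are assumptions for illustration; the `RetrieverRecall` and `RetrieverRelevance` classes may normalize text or count differently.

```python
def normalize(text: str) -> str:
    """Lowercase and collapse whitespace so comparisons ignore formatting."""
    return " ".join(text.lower().split())


def context_recall(retrieved_context: str, gt_sentences: list[str]) -> float:
    """Fraction of ground truth sentences contained in the retrieved context."""
    if not gt_sentences:
        return 0.0
    retrieved = normalize(retrieved_context)
    found = sum(1 for s in gt_sentences if normalize(s) in retrieved)
    return found / len(gt_sentences)


def context_relevance(retrieved_context: str, gt_sentences: list[str]) -> float:
    """Share of retrieved characters that come from ground truth sentences."""
    retrieved = normalize(retrieved_context)
    if not retrieved:
        return 0.0
    relevant_chars = sum(
        len(normalize(s)) for s in gt_sentences if normalize(s) in retrieved
    )
    return relevant_chars / len(retrieved)


# Hypothetical example: one of two ground truth sentences was retrieved,
# along with one unrelated sentence.
retrieved_context = "Alpha was founded in 1900. Beta is a city."
gt_sentences = ["Alpha was founded in 1900.", "Gamma was founded in 1950."]
recall = context_recall(retrieved_context, gt_sentences)
relevance = context_relevance(retrieved_context, gt_sentences)
print(f"recall: {recall}, relevance: {relevance:.3f}")
```

Under these definitions, recall penalizes missing supporting sentences, while relevance penalizes retrieved text that is not part of the ground truth.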

Next, we evaluate the generated answers using the AnswerMatchAcc metric, which compares the predicted answer with the ground truth answer.

In [None]:
# Compute the answer match accuracy.
answer_match_acc = AnswerMatchAcc(type="exact_match")
avg_acc, acc_list = answer_match_acc.compute(pred_answers, gt_answers)
print(f"average accuracy: {avg_acc}")
print(f"accuracy list: {acc_list}")
answer_match_acc = AnswerMatchAcc(type="fuzzy_match")
avg_acc, acc_list = answer_match_acc.compute(pred_answers, gt_answers)
print(f"average accuracy: {avg_acc}")
print(f"accuracy list: {acc_list}")
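
The two match types differ in strictness. As a hedged sketch of the idea (the helpers `answer_match` and `average_accuracy` and their definitions are assumptions, not necessarily what `AnswerMatchAcc` does internally), exact match requires the predicted and ground truth strings to be equal, while fuzzy match only requires the ground truth to appear somewhere in the prediction.

```python
def answer_match(pred: str, gt: str, match_type: str = "exact_match") -> float:
    """Score one prediction against its ground truth answer (1.0 or 0.0)."""
    if match_type == "exact_match":
        return 1.0 if pred == gt else 0.0
    if match_type == "fuzzy_match":
        return 1.0 if gt in pred else 0.0
    raise ValueError(f"unknown match type: {match_type}")


def average_accuracy(
    preds: list[str], gts: list[str], match_type: str
) -> tuple[float, list[float]]:
    """Return the mean score and the per-sample scores."""
    scores = [answer_match(p, g, match_type) for p, g in zip(preds, gts)]
    return (sum(scores) / len(scores) if scores else 0.0), scores


# Hypothetical predictions: the second one is verbose but contains the answer.
preds = ["Arthur's Magazine", "The answer is Paris."]
gts = ["Arthur's Magazine", "Paris"]
exact_avg, exact_scores = average_accuracy(preds, gts, "exact_match")
fuzzy_avg, fuzzy_scores = average_accuracy(preds, gts, "fuzzy_match")
print(exact_avg, exact_scores)
print(fuzzy_avg, fuzzy_scores)
```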

We finally use an LLM as the judge for evaluating the performance. The task description in the `DEFAULT_LLM_EVALUATOR_PROMPT` is "You are a helpful assistant. Given the question, ground truth answer, and predicted answer, you need to answer the judgement query. Output True or False according to the judgement query." You can customize the task description as needed. See the `lightrag.eval.LLMasJudge` class for more details.

In [None]:
llm_judge = LLMasJudge()
judgement_query = (
    "For the question, does the predicted answer contain the ground truth answer?"
)
avg_judgement, judgement_list = llm_judge.compute(
    questions, gt_answers, pred_answers, judgement_query
)
print(f"average judgement: {avg_judgement}")
print(f"judgement list: {judgement_list}")
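
To show the shape of LLM-as-judge evaluation without calling a real model, the sketch below formats one judgement prompt per sample and averages the True/False verdicts. Everything here is hypothetical: `stub_judge` is a stand-in for an LLM call that simply checks substring containment, and `build_judge_prompt`, `parse_verdict`, and `judge_all` are illustrative helpers, not the `LLMasJudge` API.

```python
from typing import Callable


def build_judge_prompt(
    question: str, gt_answer: str, pred_answer: str, judgement_query: str
) -> str:
    """Format a single sample into a prompt for the judge."""
    return (
        f"Question: {question}\n"
        f"Ground truth answer: {gt_answer}\n"
        f"Predicted answer: {pred_answer}\n"
        f"Judgement query: {judgement_query}\n"
        "Output True or False."
    )


def parse_verdict(raw: str) -> bool:
    """Interpret the judge's raw text output as a boolean."""
    return raw.strip().lower().startswith("true")


def judge_all(
    questions: list[str],
    gt_answers: list[str],
    pred_answers: list[str],
    judgement_query: str,
    llm: Callable[[str], str],
) -> tuple[float, list[bool]]:
    """Return the fraction of True verdicts and the per-sample verdicts."""
    verdicts = []
    for question, gt, pred in zip(questions, gt_answers, pred_answers):
        prompt = build_judge_prompt(question, gt, pred, judgement_query)
        verdicts.append(parse_verdict(llm(prompt)))
    return (sum(verdicts) / len(verdicts) if verdicts else 0.0), verdicts


def stub_judge(prompt: str) -> str:
    """Stand-in for a real LLM: answers 'True' when the ground truth answer
    occurs inside the predicted answer (single-line fields only)."""
    fields = dict(
        line.split(": ", 1) for line in prompt.splitlines() if ": " in line
    )
    contained = fields["Ground truth answer"] in fields["Predicted answer"]
    return "True" if contained else "False"


# Hypothetical samples.
sample_questions = ["Which magazine was started first?", "What is the capital of France?"]
sample_gt = ["Arthur's Magazine", "Paris"]
sample_pred = ["Arthur's Magazine", "Lyon"]
query_text = "For the question, does the predicted answer contain the ground truth answer?"
avg, verdicts = judge_all(sample_questions, sample_gt, sample_pred, query_text, stub_judge)
print(avg, verdicts)
```

Passing the model as a callable keeps the scoring logic testable; swapping `stub_judge` for a function that calls a real LLM would leave the rest unchanged.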