# Corrective RAG (CRAG)
---

### What is Corrective RAG?

Corrective RAG (CRAG) is a methodology that adds a step to the Retrieval-Augmented Generation (RAG) strategy: it evaluates the documents retrieved during the search process and refines the knowledge they provide. Concretely, it checks the search results before generation and, if necessary, performs supplementary searches so that the answer is generated from high-quality context.

- Retrieval Grader: Evaluates the relevance of each retrieved document to the question and assigns it a score.
- Web Search Integration: If the quality of the retrieved documents is low, CRAG augments the retrieval results with a web search. Before searching, it rewrites the query to improve the web search results.
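The flow described above can be sketched in plain Python before it is rebuilt as a LangGraph graph later in this notebook. This is a minimal sketch under simplifying assumptions: it mirrors the binary version implemented here (web search runs only when no document passes the grader), and `corrective_rag` and the lambda stand-ins are hypothetical, not part of any library.

```python
from typing import Callable, List


def corrective_rag(
    question: str,
    retrieve: Callable[[str], List[str]],
    is_relevant: Callable[[str, str], bool],
    rewrite: Callable[[str], str],
    web_search: Callable[[str], List[str]],
    generate: Callable[[str, List[str]], str],
) -> str:
    """Simplified CRAG flow: grade retrieved docs, fall back to web search if none pass."""
    docs = retrieve(question)
    # Keep only the documents the grader marks as relevant
    relevant = [d for d in docs if is_relevant(question, d)]
    if not relevant:
        # Nothing relevant: rewrite the query and use web results as context instead
        question = rewrite(question)
        relevant = web_search(question)
    return generate(question, relevant)


# Toy stand-ins for the retriever, grader, rewriter, search tool, and LLM
corpus = ["AutoGen is a multi-agent conversation framework.", "Appendix: hyperparameters."]
answer = corrective_rag(
    "AutoGen features",
    retrieve=lambda q: corpus,
    is_relevant=lambda q, d: "AutoGen" in d,
    rewrite=lambda q: q + " overview",
    web_search=lambda q: [f"web result for '{q}'"],
    generate=lambda q, docs: f"{q} -> {len(docs)} document(s)",
)
print(answer)  # AutoGen features -> 1 document(s)
```

Note that the rewritten question replaces the original one before generation, which matches how the `rewrite_query` node updates the graph state.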

**Reference**

- [Corrective RAG paper](https://arxiv.org/pdf/2401.15884)  

In [None]:
import os
from dotenv import load_dotenv
from azure_genai_utils.tracer import get_langchain_api_key, set_langsmith

load_dotenv(override=True)

# To trace your RAG API calls, set tracing=True below. This requires a valid LangChain API key.
langchain_key, has_langchain_key = get_langchain_api_key()
set_langsmith("[RAG Innv Lab] 1_Agentic-Design-Pattern", tracing=False)

azure_openai_chat_deployment_name = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME")
azure_openai_embedding_deployment_name = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME")

<br>

## 🧪 Step 1. Test and Construct each module
---

Before building the entire graph pipeline, we test and construct each module separately.

- **Retrieval Grader**
- **Answer Generator**
- **Question Re-writer**
- **Web Search Tool**

### Construct Retrieval Chain based on PDF

In [None]:
from azure_genai_utils.rag.pdf import PDFRetrievalChain

pdf_path = "../../../sample-docs/AutoGen-paper.pdf"

pdf = PDFRetrievalChain(
    source_uri=[pdf_path],
    loader_type="PDFPlumber",
    model_name=azure_openai_chat_deployment_name,
    embedding_name=azure_openai_embedding_deployment_name,
    chunk_size=500,
    chunk_overlap=50,
).create_chain()

pdf_retriever = pdf.retriever
pdf_chain = pdf.chain

question = "What are AutoGen's main features?"
docs = pdf_retriever.invoke(question)

# Non-streaming
# results = pdf_chain.invoke({"chat_history": "", "question": question, "context": docs})

# Streaming
for text in pdf_chain.stream(
    {"chat_history": "", "question": question, "context": docs}
):
    print(text, end="", flush=True)

### Define your LLM

This hands-on lab uses only `gpt-4o-mini`, but you can use different models for different steps of the pipeline.

In [None]:
from langchain_openai import AzureChatOpenAI

llm = AzureChatOpenAI(model=azure_openai_chat_deployment_name, temperature=0)
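If you want separate models for separate steps, one option is to resolve a deployment name per role from environment variables. The sketch below assumes hypothetical per-role variables such as `AZURE_OPENAI_GRADER_DEPLOYMENT_NAME`, which this lab does not define; when one is missing, it falls back to `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME`.

```python
import os
from typing import Dict, Iterable


def resolve_deployments(
    roles: Iterable[str], default_var: str = "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"
) -> Dict[str, str]:
    """Return a deployment name per pipeline role.

    Looks up a hypothetical per-role variable such as AZURE_OPENAI_GRADER_DEPLOYMENT_NAME
    and falls back to the shared chat deployment when it is not set.
    """
    default = os.getenv(default_var, "")
    return {
        role: os.getenv(f"AZURE_OPENAI_{role.upper()}_DEPLOYMENT_NAME", default)
        for role in roles
    }
```

Each resolved name could then be passed to its own `AzureChatOpenAI` instance, constructed the same way as `llm` above.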

### Question-Retrieval Grader

Construct a retrieval grader that evaluates the relevance of a retrieved document to the input question. The grader takes the question and one document as input and outputs a binary relevance score (`'yes'` or `'no'`).<br>
Because the grader scores **one document per call**, handling **multiple documents** means invoking it once for each retrieved document, as the `grade_documents` node does later in this notebook.

In [None]:
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate


class GradeDocuments(BaseModel):
    """A binary score to determine the relevance of the retrieved documents."""

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


structured_llm_grader = llm.with_structured_output(GradeDocuments)

system = """You are a grader assessing relevance of a retrieved document to a user question. \n 
    If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
    Give a binary score, 'yes' or 'no', to indicate whether the document is relevant to the question."""

grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

retrieval_grader = grade_prompt | structured_llm_grader

Test the retrieval grader. For brevity, we show the result for a single document (the second one retrieved) rather than for the entire document set.

In [None]:
question = "What are AutoGen's main features?"
docs = pdf_retriever.invoke(question)

# Extract the page content of the second document retrieved
doc_txt = docs[1].page_content
print(retrieval_grader.invoke({"question": question, "document": doc_txt}))
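The `grade_documents` node in Step 2 compares `binary_score` with the exact string `"yes"`, so if the model ever replies with a variant such as `"Yes"`, the document would be treated as not relevant. A small normalization step guards against that. `normalize_binary_score` below is a hypothetical helper; constraining the field in `GradeDocuments` with `typing.Literal["yes", "no"]` is an alternative worth trying.

```python
from typing import Optional


def normalize_binary_score(raw: str) -> Optional[str]:
    """Map loosely formatted grader output ("Yes", " no.", "'yes'") to 'yes' or 'no'.

    Returns None when the text is neither, so the caller can decide how to handle it.
    """
    # Trim whitespace, then surrounding punctuation and quotes, then lowercase
    cleaned = raw.strip().strip(".!\"'").lower()
    if cleaned in ("yes", "no"):
        return cleaned
    return None
```

In the grading loop, `grade = normalize_binary_score(score.binary_score)` would then replace the direct attribute read.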

### Answer Generator

Construct an LLM generation node. This is a naive RAG chain that generates an answer from the retrieved documents.

For production, we recommend a more advanced RAG chain.

In [None]:
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import load_prompt

if has_langchain_key:
    print("Load prompt from LangChain Hub.")
    prompt = hub.pull("daekeun-ml/rag-baseline")
else:
    print("LANGCHAIN_API_KEY is not set. Load prompt from YAML file.")
    prompt = load_prompt("prompts/rag-baseline.yaml")


def format_docs(docs):
    return "\n\n".join(
        [
            f'<document><content>{doc.page_content}</content><source>{doc.metadata["source"]}</source><page>{doc.metadata["page"]+1}</page></document>'
            for doc in docs
        ]
    )


rag_chain = prompt | llm | StrOutputParser()
generation = rag_chain.invoke({"context": format_docs(docs), "question": question})
print(generation)
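`format_docs` reads `doc.metadata["source"]` and `doc.metadata["page"]` directly. The `web_search` node defined in Step 2 wraps search results in a `Document` with no metadata, so calling `format_docs` on that list would raise a `KeyError` (the `generate` node sidesteps this by passing the raw document list to the chain). A more defensive variant is sketched below. `Doc` is a hypothetical stand-in exposing the same `page_content` and `metadata` attributes as LangChain's `Document`, used only to keep the example self-contained.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Doc:
    """Hypothetical stand-in for a LangChain Document (page_content + metadata)."""

    page_content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


def format_docs_safe(docs: List[Doc]) -> str:
    """Like format_docs, but tolerates documents without 'source'/'page' metadata."""
    parts = []
    for doc in docs:
        source = doc.metadata.get("source", "unknown")
        page = doc.metadata.get("page")
        # format_docs adds 1 to the stored page number; keep the same convention
        page_tag = f"<page>{page + 1}</page>" if isinstance(page, int) else ""
        parts.append(
            f"<document><content>{doc.page_content}</content>"
            f"<source>{source}</source>{page_tag}</document>"
        )
    return "\n\n".join(parts)


pdf_doc = Doc("AutoGen text", {"source": "AutoGen-paper.pdf", "page": 0})
web_doc = Doc("Web result text")  # like the web search Document: no metadata
print(format_docs_safe([pdf_doc, web_doc]))
```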

### Question Re-writer

Construct a `question_rewriter` node that rewrites the question to make it better suited to web search.

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

system = """You are a question re-writer that converts an input question to a better version that is optimized 
for web search. Look at the input and try to reason about the underlying semantic intent / meaning."""

re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "Here is the initial question: \n\n {question} \n Formulate an improved question.",
        ),
    ]
)

question_rewriter = re_write_prompt | llm | StrOutputParser()

In [None]:
print(f"[Original question] {question}")
question_rewriter.invoke({"question": question})
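For comparison, a deterministic keyword-extraction baseline is sketched below. It is a swapped-in technique, not what this notebook uses: unlike the LLM rewriter, it cannot reason about intent, but it is cheap and predictable, which can make it useful as a fallback or in offline tests. The stopword list is a small hypothetical example.

```python
import re

# Small, hypothetical stopword list covering conversational filler
STOPWORDS = {
    "a", "an", "the", "of", "is", "are", "what", "s",
    "please", "tell", "give", "me",
}


def keyword_query(question: str) -> str:
    """Lowercase the question, split it into word tokens, and drop stopwords."""
    tokens = re.findall(r"[a-z0-9]+", question.lower())
    return " ".join(t for t in tokens if t not in STOPWORDS)


print(keyword_query("What is AutoGen's main features?"))  # autogen main features
```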

### Web Search Tool

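The web search tool supplements the context with external information. <br>

In this notebook, it is used only when none of the retrieved documents is graded as relevant. (The CRAG paper also falls back to web search when the evaluator's confidence is ambiguous, combining the web results with the retrieved documents; this simplified version does not implement that case.)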
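The CRAG paper maps the evaluator's per-document confidence to one of three actions using an upper and a lower threshold: **Correct** (at least one document scores above the upper threshold), **Incorrect** (all documents fall below the lower threshold, so web search is used), and **Ambiguous** (everything in between, where retrieved and web knowledge are combined). A minimal sketch follows; the threshold values are illustrative assumptions, not the paper's settings. With this notebook's binary grader, a `'yes'` behaves like a score of 1.0 and a `'no'` like 0.0, so the Ambiguous branch never fires.

```python
from typing import List


def crag_action(scores: List[float], upper: float = 0.7, lower: float = 0.3) -> str:
    """Map per-document relevance confidences to a CRAG-style action.

    The default thresholds are illustrative placeholders, not tuned values.
    """
    if any(s >= upper for s in scores):
        return "correct"    # at least one confidently relevant document
    if all(s < lower for s in scores):
        return "incorrect"  # nothing usable (or no documents): rely on web search
    return "ambiguous"      # in between: combine retrieved and web knowledge


# The notebook's binary grader, expressed as scores: 'yes' -> 1.0, 'no' -> 0.0
grades = ["no", "yes", "no"]
print(crag_action([1.0 if g == "yes" else 0.0 for g in grades]))  # correct
```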

In [None]:
from azure_genai_utils.tools import BingSearch

WEB_SEARCH_FORMAT_OUTPUT = True

web_search_tool = BingSearch(
    max_results=3,
    locale="en-US",
    include_news=True,
    include_entity=False,
    format_output=WEB_SEARCH_FORMAT_OUTPUT,
)

In [None]:
results = web_search_tool.invoke({"query": question})
print(results)

<br>

## 🧪 Step 2. Define the Graph
---

### State Definition

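- `question`: Question from the user (replaced by the rewritten question when web search is used)
- `generation`: Generated answer
- `web_search`: Whether to run web search (`"Yes"` or `"No"`)
- `documents`: Retrieved documents (including web search results, if any)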

In [None]:
from typing import Annotated, List
from typing_extensions import TypedDict
from langchain_core.documents import Document


class GraphState(TypedDict):
    question: Annotated[str, "Question"]
    generation: Annotated[str, "LLM Generation"]
    web_search: Annotated[str, "Whether to run web search ('Yes' or 'No')"]
    documents: Annotated[List[Document], "Retrieved Documents"]

### Define Nodes

We will define the following nodes in the graph:

- `retrieve`: Retrieve documents based on the user question.
- `grade_documents`: Grade each retrieved document for relevance to the user question.
- `generate`: Generate an answer based on the retrieved documents and the user question.
- `rewrite_query`: Rewrite the user question to improve web search results.
- `web_search`: Search the web for additional information (registered in the graph as `web_search_node`).
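Each node returns only the keys it changes, and LangGraph merges that partial update into the shared state. Because the `Annotated` metadata in `GraphState` is just a description string (not a reducer function), each returned key should simply overwrite the previous value. The sketch below simulates that overwrite behavior in plain Python for the web-search path, using hard-coded, hypothetical node outputs.

```python
from typing import Any, Dict, List


def apply_update(state: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Overwrite-style merge: keys returned by a node replace the previous values."""
    merged = dict(state)
    merged.update(update)
    return merged


# Hypothetical outputs of each node on the web-search path
state: Dict[str, Any] = {"question": "What are AutoGen's main features?"}
node_outputs: List[Dict[str, Any]] = [
    {"documents": ["chunk-1", "chunk-2"]},                  # retrieve
    {"documents": [], "web_search": "Yes"},                  # grade_documents
    {"question": "AutoGen framework key features"},          # rewrite_query
    {"documents": ["web result"]},                           # web_search_node
    {"generation": "AutoGen is a multi-agent framework."},   # generate
]
for update in node_outputs:
    state = apply_update(state, update)

print(state["question"])   # AutoGen framework key features
print(state["documents"])  # ['web result']
```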

In [None]:
from langchain_core.documents import Document


def retrieve(state: GraphState):
    """
    Retrieve documents based on the user question.
    """
    print("==== [RETRIEVE] ====")
    question = state["question"]

    documents = pdf_retriever.invoke(question)
    return {"documents": documents}


def generate(state: GraphState):
    """Generate an answer based on the retrieved documents and user question."""
    print("==== [GENERATE] ====")
    question = state["question"]
    documents = state["documents"]

    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"generation": generation}


def grade_documents(state: GraphState):
    """Grade documents based on their relevance to the user question."""
    print("\n==== [CHECK DOCUMENT RELEVANCE TO QUESTION] ====\n")
    question = state["question"]
    documents = state["documents"]

    filtered_docs = []
    relevant_doc_count = 0

    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score.binary_score

        if grade == "yes":
            print("==== [GRADE: DOCUMENT RELEVANT] ====")
            # Add related documents to filtered_docs
            filtered_docs.append(d)
            relevant_doc_count += 1
        else:
            print("==== [GRADE: DOCUMENT NOT RELEVANT] ====")
            continue

    # Web search if no relevant documents
    web_search = "Yes" if relevant_doc_count == 0 else "No"
    return {"documents": filtered_docs, "web_search": web_search}


def rewrite_query(state: GraphState):
    """Rewrite the user question to improve web search results"""
    print("\n==== [REWRITE QUERY] ====\n")
    question = state["question"]

    better_question = question_rewriter.invoke({"question": question})
    return {"question": better_question}


def web_search(state: GraphState):
    """Search the web for additional information."""
    print("\n==== [WEB SEARCH] ====\n")
    question = state["question"]
    documents = state["documents"]

    docs = web_search_tool.invoke({"query": question})
    # Convert search results to document format
    if WEB_SEARCH_FORMAT_OUTPUT:
        web_results = "\n".join(docs)
    else:
        web_results = "\n".join([d["content"] for d in docs])

    web_results = Document(page_content=web_results)

    # Return a new list instead of mutating the list stored in the graph state
    return {"documents": documents + [web_results]}

### Define Conditional Edges

- `decide_to_generate`: Based on the grading results, route to `generate` if at least one document is relevant, or to `rewrite_query` (followed by web search) otherwise.

In [None]:
def decide_to_generate(state: GraphState):
    """
    Assess the graded documents and decide whether to generate an answer or rewrite the query.
    """
    print("==== [ASSESS GRADED DOCUMENTS] ====")
    web_search = state["web_search"]

    if web_search == "Yes":
        # If none of the documents is relevant to the question, rewrite the query and search the web
        print(
            "==== [DECISION: NO DOCUMENTS ARE RELEVANT TO QUESTION, REWRITE QUERY] ===="
        )
        # Route to the query rewrite node
        return "rewrite_query"
    else:
        # Generate an answer based on the retrieved documents
        print("==== [DECISION: GENERATE] ====")
        return "generate"

### Construct the Graph

In [None]:
from langgraph.graph import END, StateGraph, START

workflow = StateGraph(GraphState)

# Node definition
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("rewrite_query", rewrite_query)
workflow.add_node("web_search_node", web_search)

# Edge connections
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "rewrite_query": "rewrite_query",
        "generate": "generate",
    },
)

workflow.add_edge("rewrite_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)

# Compile the workflow
app = workflow.compile()
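To check the wiring by hand, the edges registered above can be mirrored in a plain dictionary and walked for each routing decision. This is an illustrative sketch, not a LangGraph API: `trace_path` and the `"__end__"` marker standing in for `END` are assumptions of the example. It also shows that the longest route visits five nodes, which fits within the `recursion_limit=10` used in Step 3.

```python
from typing import Dict, List

# Plain-Python mirror of the unconditional edges above; "__end__" stands in for END
EDGES: Dict[str, str] = {
    "retrieve": "grade_documents",
    "rewrite_query": "web_search_node",
    "web_search_node": "generate",
    "generate": "__end__",
}
# Path map passed to add_conditional_edges for "grade_documents"
PATH_MAP: Dict[str, str] = {"rewrite_query": "rewrite_query", "generate": "generate"}


def route(state: Dict[str, str]) -> str:
    """Same decision rule as decide_to_generate, without the logging."""
    return "rewrite_query" if state["web_search"] == "Yes" else "generate"


def trace_path(state: Dict[str, str]) -> List[str]:
    """Return the nodes visited from START to END for a given grading outcome."""
    node, visited = "retrieve", []
    while node != "__end__":
        visited.append(node)
        if node == "grade_documents":
            node = PATH_MAP[route(state)]  # conditional edge
        else:
            node = EDGES[node]  # fixed edge
    return visited


print(trace_path({"web_search": "No"}))   # ['retrieve', 'grade_documents', 'generate']
print(trace_path({"web_search": "Yes"}))
```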

### Visualize the graph

In [None]:
from azure_genai_utils.graphs import visualize_langgraph

visualize_langgraph(app, xray=True)

<br>

## 🧪 Step 3. Execute the Graph
---

### Execute the graph

In [None]:
from langchain_core.runnables import RunnableConfig
from azure_genai_utils.messages import stream_graph, invoke_graph, random_uuid

config = RunnableConfig(recursion_limit=10, configurable={"thread_id": random_uuid()})

inputs = {
    "question": "What are AutoGen's main features?",
}

stream_graph(
    app,
    inputs,
    config,
    ["retrieve", "grade_documents", "rewrite_query", "web_search_node", "generate"],
)

In [None]:
config = RunnableConfig(recursion_limit=10, configurable={"thread_id": random_uuid()})


inputs = {
    "question": "Please give me a brief history of Microsoft.",
}
stream_graph(
    app,
    inputs,
    config,
    ["retrieve", "grade_documents", "rewrite_query", "web_search_node", "generate"],
)