# Corrective Retrieval Augmented Generation (CRAG)

Corrective Retrieval Augmented Generation (CRAG) was introduced in a recent paper as an approach to self-reflective RAG. You can read the paper [here](https://arxiv.org/pdf/2401.15884.pdf).

Retrieval augmented generation (RAG) incorporates relevant knowledge into a model's input in order to improve its output. Within this framework, the model receives an augmented input that includes relevant documents retrieved from an external knowledge collection. While RAG is a practical complement to LLMs, its effectiveness is contingent upon the relevance and accuracy of the retrieved documents. This heavy reliance of generation on retrieved knowledge raises significant concerns about the model's behavior and performance when retrieval fails or returns inaccurate results.

A low-quality retriever can bring in a lot of irrelevant information. This can make it hard for models to acquire accurate knowledge and might even mislead them, causing problems like hallucinations.

Figure 1 shows how CRAG works at inference time to make generation more robust. Given an input query and the documents returned by a retriever, CRAG uses a lightweight evaluator to estimate the relevance score of each retrieved document with respect to the query. The scores are mapped to one of three confidence degrees, {Correct, Incorrect, Ambiguous}, each of which triggers a corresponding action. If the result is Correct, the retrieved documents are made more precise through a knowledge refinement process, which involves knowledge decomposition, filtering, and recomposition. If it is Incorrect, the retrieved documents are discarded and web search is used instead as a complementary knowledge source for correction. If it is unclear whether the documents are correct, the Ambiguous action is taken, which combines both strategies (Section 4.3 of the paper). Once the retrieved knowledge has been refined, any generative model can be used.

<center><figure><img src="imgs/CRAG.jpg" alt="drawing" width="700"/><figcaption>Fig. 1: An overview of CRAG at inference.</figcaption></figure></center>    
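
The action trigger described above can be sketched in a few lines of Python. This is only an illustration, assuming a threshold rule in the spirit of the paper: Correct when at least one score exceeds an upper threshold, Incorrect when every score falls below a lower threshold, and Ambiguous otherwise. The function name `trigger_action` and the threshold values are hypothetical, not taken from the paper or from any library.

```python
from typing import List

UPPER = 0.7  # hypothetical upper threshold, chosen only for this sketch
LOWER = 0.3  # hypothetical lower threshold, chosen only for this sketch


def trigger_action(scores: List[float], upper: float = UPPER, lower: float = LOWER) -> str:
    """Map per-document relevance scores to one of the three CRAG actions."""
    if not scores:
        # Assumption: an empty retrieval is treated as a failed retrieval.
        return "Incorrect"
    if any(s > upper for s in scores):
        return "Correct"
    if all(s < lower for s in scores):
        return "Incorrect"
    return "Ambiguous"


print(trigger_action([0.9, 0.1]))  # Correct
print(trigger_action([0.1, 0.2]))  # Incorrect
print(trigger_action([0.5, 0.1]))  # Ambiguous
```

In a real system the scores would come from the retrieval evaluator, and the thresholds would need to be tuned for the dataset.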

## Knowledge Refinement 
A retrieval is considered Correct if the confidence score of at least one retrieved document exceeds the upper threshold, which indicates that relevant documents are present in the retrieval results. However, even a relevant document may contain some irrelevant information. To extract the most important information from it, a method called knowledge refinement is applied: the content of each relevant document is decomposed and then recomposed. First, each document is divided into smaller knowledge segments using heuristic rules. Then, a fine-tuned retrieval evaluator computes the relevance score of each segment. Based on these scores, irrelevant segments are filtered out and the relevant ones are recomposed by concatenating them in order.
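
The decompose, filter, and recompose steps can be illustrated with a minimal sketch. It makes two loud assumptions: sentences stand in for the "knowledge segments", and a toy word-overlap score (`overlap_score`) replaces the fine-tuned retrieval evaluator. All function names and the threshold are hypothetical.

```python
import re
from typing import Callable, List


def split_into_segments(document: str) -> List[str]:
    """Heuristic decomposition: split after sentence-ending punctuation."""
    parts = re.split(r"(?<=[.!?])\s+", document.strip())
    return [p for p in parts if p]


def overlap_score(query: str, segment: str) -> float:
    """Toy stand-in for the evaluator: fraction of query words found in the segment."""
    query_words = set(re.findall(r"\w+", query.lower()))
    segment_words = set(re.findall(r"\w+", segment.lower()))
    if not query_words:
        return 0.0
    return len(query_words & segment_words) / len(query_words)


def refine(
    query: str,
    document: str,
    scorer: Callable[[str, str], float] = overlap_score,
    threshold: float = 0.5,  # hypothetical cut-off
) -> str:
    """Keep the segments scoring at or above the threshold, in their original order."""
    segments = split_into_segments(document)
    kept = [s for s in segments if scorer(query, s) >= threshold]
    return " ".join(kept)


doc = "Seattle is a port city in Washington. Bananas are yellow. The port of Seattle handles cargo."
print(refine("port of Seattle", doc))
# Seattle is a port city in Washington. The port of Seattle handles cargo.
```

Swapping `scorer` for a learned relevance model would bring the sketch closer to the method described above without changing the rest of the code.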


# CRAG implementation in LangChain

We can implement CRAG with LangGraph, a library from the LangChain ecosystem. First, let's see what [LangGraph](https://python.langchain.com/docs/langgraph) is.

## LangGraph
LLMs can be used for reasoning tasks, which can essentially be thought of as running an LLM in a for-loop. Systems of this type are often called agents. In practice, however, you may want more control over an agent: you may want to always force it to call a particular tool first, or to control how tools are called. This is where LangGraph comes in. These more controlled flows are referred to as "state machines" in LangGraph terminology, and LangGraph is a way to create these state machines by specifying them as graphs.

The primary function of LangGraph is to add cycles into LLM applications. Cycles play a vital role in scenarios with agents. For example, you might repeatedly invoke an LLM within a loop to determine the next course of action.
LangGraph is a tool designed for creating complex, stateful applications that involve multiple actors using LLMs. It is built upon LangChain and expands its capabilities by enabling the coordination of multiple chains or actors through various steps of computation in a cyclic fashion. 
Note that LangGraph is not restricted to **Directed Acyclic Graphs (DAGs)**: unlike a DAG, a LangGraph graph may contain cycles.

### How to build a graph using LangGraph

#### Creating StateGraph
Graphs in LangGraph are built with the `StateGraph` class. This graph is parameterized by a state object that it passes around to each node. This state definition represents a central state object that is updated over time by each node. The state takes the form of a key-value store, and each node's update can either `set` specific attributes (i.e., overwrite the existing values) or `add` to an existing attribute.
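
The difference between `set` and `add` updates can be mimicked in plain Python. The sketch below is not LangGraph code: `REDUCERS` and `apply_update` are hypothetical names, and the key names are only examples. In LangGraph itself the `add` behavior is typically declared by annotating a state field with a reducer such as `operator.add`.

```python
import operator
from typing import Any, Callable, Dict

# Keys listed here are merged with the given function ("add");
# every other key is simply overwritten ("set"). Key names are hypothetical.
REDUCERS: Dict[str, Callable[[Any, Any], Any]] = {"documents": operator.add}


def apply_update(state: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new state with a node's update applied."""
    new_state = dict(state)
    for key, value in update.items():
        if key in REDUCERS and key in new_state:
            new_state[key] = REDUCERS[key](new_state[key], value)  # add
        else:
            new_state[key] = value  # set
    return new_state


state = {"question": "Tell me more about Seattle", "documents": ["doc A"]}
state = apply_update(state, {"question": "Seattle, more information"})  # overwrites
state = apply_update(state, {"documents": ["doc B"]})                   # appends
print(state)
# {'question': 'Seattle, more information', 'documents': ['doc A', 'doc B']}
```
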
#### Adding Nodes
We add nodes to the graph as `(name, value)` pairs, where `name` is the node's name, used to refer to the node when adding edges, and `value` is a function or LCEL runnable that is called when the node runs.
#### Adding Edges
After adding nodes, we can then add edges to the graph. We have two types of edges.
* ##### Normal Edges
These are edges where one node should ALWAYS be called after another.
* ##### Conditional Edges
These are edges where, based on the output of a node, one of several paths is taken. For example, after the agent is called: if the agent said to take an action, then the function to invoke tools should be called; if the agent said that it was finished, then the graph should finish.
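
To make the roles of nodes, normal edges, and conditional edges concrete, here is a tiny plain-Python graph runner that includes a cycle. It is a conceptual sketch only and not how LangGraph is implemented; the names `agent`, `tool`, `route`, and `run` and the `END` sentinel are all hypothetical.

```python
from typing import Callable, Dict, List, Tuple

State = Dict[str, int]
END = "__end__"  # sentinel name chosen for this sketch


def agent(state: State) -> State:
    # Pretend the LLM reasons once per visit.
    return {**state, "steps": state["steps"] + 1}


def tool(state: State) -> State:
    # Pretend a tool was invoked.
    return {**state, "tool_calls": state["tool_calls"] + 1}


def route(state: State) -> str:
    # Conditional edge: keep calling the tool until three agent steps are done.
    return "tool" if state["steps"] < 3 else END


nodes: Dict[str, Callable[[State], State]] = {"agent": agent, "tool": tool}
normal_edges: Dict[str, str] = {"tool": "agent"}  # tool is ALWAYS followed by agent
conditional_edges: Dict[str, Callable[[State], str]] = {"agent": route}


def run(entry: str, state: State) -> Tuple[State, List[str]]:
    """Walk the graph from the entry node until END is reached."""
    current, visited = entry, []
    while current != END:
        visited.append(current)
        state = nodes[current](state)
        if current in conditional_edges:
            current = conditional_edges[current](state)
        else:
            current = normal_edges[current]
    return state, visited


final_state, path = run("agent", {"steps": 0, "tool_calls": 0})
print(path)         # ['agent', 'tool', 'agent', 'tool', 'agent']
print(final_state)  # {'steps': 3, 'tool_calls': 2}
```

The `tool` to `agent` edge closes the loop, which is exactly the kind of cycle a DAG cannot express.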

### Graph nodes
Our graph consists of the following nodes:
* ### retrieve
This node is responsible for retrieving documents relevant to the input question.
* ### grade documents
This node grades each document retrieved by the `retrieve` node. If any document is graded as not relevant, the node sets the `run_web_search` flag to `"Yes"` so that a web search on the topic is run later by the `web_search` node. It also updates the `documents` attribute to include only the relevant documents.
* ### improve question
If the search flag is on, this node improves the quality of the question in order to get better results from the `web_search` node.
* ### generate
This node generates the answer, whether from the internally retrieved documents or from external results collected by web search. It runs a classic RAG step on the question with the provided context.
* ### web search
After the question has been improved by the `improve_question` node, this node performs a web search to find documents relevant to the input question.
* ### should_generate
This is a decision function rather than a regular node: it is attached to the graph as a conditional edge. Based on the search flag, it decides whether to generate the final answer (i.e., go to the `generate` node if the flag is off) or to improve the question for the web search (i.e., go to the `improve_question` node if the flag is on).

#### Each node receives the graph state as an argument and adds or updates its attributes upon returning it.
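
Before wiring in real LLM calls, the state passing and routing can be dry-run with stubs. The sketch below assumes the same `attribs` dictionary and `run_web_search` flag used in the implementation later in this notebook, but replaces the LLM grader with a hypothetical keyword check (`fake_grade`) and represents documents as plain strings, so it needs no API keys.

```python
import re
from typing import Any, Dict


def fake_grade(question: str, document: str) -> str:
    """Stub grader (assumption): 'yes' if the document shares a word of 4+ letters with the question."""
    keywords = {w for w in re.findall(r"\w+", question.lower()) if len(w) > 3}
    doc_words = set(re.findall(r"\w+", document.lower()))
    return "yes" if keywords & doc_words else "no"


def grade_documents(state: Dict[str, Any]) -> Dict[str, Any]:
    attribs = state["attribs"]
    question, documents = attribs["question"], attribs["documents"]
    relevant = [d for d in documents if fake_grade(question, d) == "yes"]
    # Flag a web search as soon as any document was filtered out.
    search = "Yes" if len(relevant) < len(documents) else "No"
    return {"attribs": {"documents": relevant, "question": question, "run_web_search": search}}


def should_generate(state: Dict[str, Any]) -> str:
    return "improve_question" if state["attribs"]["run_web_search"] == "Yes" else "generate"


state = {
    "attribs": {
        "question": "Tell me more about Seattle",
        "documents": ["Seattle is a port city.", "Bananas are yellow."],
    }
}
state = grade_documents(state)
print(state["attribs"]["documents"])  # ['Seattle is a port city.']
print(should_generate(state))         # improve_question
```
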

In [2]:
! pip install langchain_community faiss-cpu tiktoken langchain-openai langchainhub langchain langgraph tavily-python pypdf

Collecting langchain_community
  Using cached langchain_community-0.0.24-py3-none-any.whl.metadata (8.1 kB)
Collecting faiss-cpu
  Using cached faiss_cpu-1.7.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)
Collecting tiktoken
  Using cached tiktoken-0.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Collecting langchain-openai
  Using cached langchain_openai-0.0.8-py3-none-any.whl.metadata (2.5 kB)
Collecting langchainhub
  Using cached langchainhub-0.1.14-py3-none-any.whl.metadata (478 bytes)
Collecting langchain
  Using cached langchain-0.1.9-py3-none-any.whl.metadata (13 kB)
Collecting langgraph
  Using cached langgraph-0.0.26-py3-none-any.whl.metadata (34 kB)
Collecting tavily-python
  Using cached tavily_python-0.3.1-py3-none-any.whl.metadata (4.4 kB)
Collecting pypdf
  Using cached pypdf-4.0.2-py3-none-any.whl.metadata (7.4 kB)
Collecting aiohttp<4.0.0,>=3.8.3 (from langchain_community)
  Using cached aiohttp-3.9.3-cp

In [3]:
import os
from getpass import getpass
os.environ['OPENAI_API_KEY'] = getpass('Enter your OpenAI API Key: ')
os.environ['TAVILY_API_KEY'] = getpass('Enter your Tavily API Key: ')

Enter your OpenAI API Key:  ········
Enter your Tavily API Key:  ········


In [4]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

doc = PyPDFLoader("docs/the-usa.pdf").load()

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=10
)
doc_splits = text_splitter.split_documents(doc)

# Add to vectorDB
vectorstore = FAISS.from_documents(
    documents=doc_splits,
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()

In [15]:
from typing import Any, Dict, TypedDict
from langgraph.graph import END, StateGraph

class State(TypedDict):
    attribs: Dict[str, Any]
    
graph = StateGraph(State)

In [16]:
from langchain import hub
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import BaseMessage, FunctionMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai import ChatOpenAI


def retrieve(state):
    """Retrieve documents relevant to the question from the vector store."""
    print("Retrieving documents---")
    state_dict = state["attribs"]
    question = state_dict["question"]
    documents = retriever.get_relevant_documents(question)
    return {"attribs": {"documents": documents, "question": question}}


def grade_documents(state):
    """Grade each retrieved document and flag a web search if any is not relevant."""
    print("Grading documents------")
    state_dict = state["attribs"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    
    # Data model
    class grade(BaseModel):
        """Binary score for relevance check."""

        score: str = Field(description="Relevance score 'yes' or 'no'")

    # LLM
    model = ChatOpenAI(temperature=0, model="gpt-3.5-turbo", streaming=True)

    # Tool
    grade_tool_oai = convert_to_openai_tool(grade)
    print("grade_tool_oai",grade_tool_oai)

    # LLM with tool and enforce invocation
    llm_with_tool = model.bind(
        tools=[grade_tool_oai],
        tool_choice={"type": "function", "function": {"name": "grade"}},
    )

    # Parser
    parser_tool = PydanticToolsParser(tools=[grade])
    
    prompt = PromptTemplate(
        template="""Given a question, does the following document have exact information to answer the question?
                    Question: {question}
                    Document: {context}
                    Think Step by step, and answer with yes or no only""",
        input_variables=["context", "question"],
    )

    # Chain
    chain = prompt | llm_with_tool | parser_tool

    # If any document is graded 'no', flag a web search; otherwise,
    # generate the answer from the internal source (e.g., indexed documents)
    filtered_docs = []
    search = "No"  # By default, do not run a web search to supplement retrieval
    for d in documents:
        score = chain.invoke({"question": question, "context": d.page_content})
        print("score",score)
        doc_grade = score[0].score  # separate name so the `grade` class is not shadowed
        if doc_grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            search = "Yes"  # Perform web search

    return {
        "attribs": {
            "documents": filtered_docs,
            "question": question,
            "run_web_search": search,
        }
    }


def improve_question(state):
    """Rewrite the question as a short keyword query for the web search."""
    print("Improving the question-----")
    state_dict = state["attribs"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    
    prompt = PromptTemplate(
        template="""Extract at most three keywords separated by comma from the following dialogues and questions as queries for the
                    web search, including topic background within dialogues and main intent within questions.
                    question: What is Henry Feilden’s occupation?
                    query: Henry Feilden, occupation
                    question: In what city was Billy Carlson born?
                    query: city, Billy Carlson, born
                    question: What is the religion of John Gwynn?
                    query: religion of John Gwynn
                    question: What sport does Kiribati men’s national basketball team play?
                    query: sport, Kiribati men’s national basketball team play
                    question: {question}
                    query:""",
        input_variables=["question"],
    )

    model = ChatOpenAI(temperature=0, model="gpt-3.5-turbo", streaming=True)

    # Prompt
    chain = prompt | model | StrOutputParser()
    better_question = chain.invoke({"question": question})
    
    print("--- better question-----")
    print(better_question)

    return {"attribs": {"documents": documents, "question": better_question}}


def generate(state):
    """Generate the final answer from the question and the collected documents."""
    print("Generating output---")
    state_dict = state["attribs"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # LLM
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True)

    # Post-processing: join the page contents into a single context string
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Chain
    rag_chain = prompt | llm | StrOutputParser()

    # Run (pass the formatted text rather than the raw Document objects)
    generation = rag_chain.invoke({"context": format_docs(documents), "question": question})
    
    return {
        "attribs": {"documents": documents, "question": question, "generation": generation}
    }

def should_generate(state):
    """Route to the web-search branch or straight to answer generation."""
    print("Deciding on next node---")
    state_dict = state["attribs"]
    search = state_dict["run_web_search"]

    if search == "Yes":
        # At least one retrieved document was graded as not relevant,
        # so rewrite the question and supplement with a web search
        print("DECISION: Improve question and run web search---")
        return "improve_question"
    else:
        # All retrieved documents are relevant, so generate the answer
        print("DECISION: Generate final output")
        return "generate"


def web_search(state):
    """Run a Tavily web search and append the results to the documents."""
    print("Running web search---")
    state_dict = state["attribs"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    tool = TavilySearchResults()
    docs = tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    documents.append(web_results)

    return {"attribs": {"documents": documents, "question": question}}


## Adding nodes, edges and building graph

In [17]:
graph.add_node("retrieve", retrieve)  # retrieve
graph.add_node("grade_documents", grade_documents)  # grade documents
graph.add_node("generate", generate)  # generate
graph.add_node("improve_question", improve_question)  # improve_question
graph.add_node("web_search", web_search)  # web search

graph.set_entry_point("retrieve")
graph.add_edge("retrieve", "grade_documents")
graph.add_conditional_edges(
    "grade_documents",
    should_generate,
    {
        "improve_question": "improve_question",
        "generate": "generate",
    },
)
graph.add_edge("improve_question", "web_search")
graph.add_edge("web_search", "generate")
graph.add_edge("generate", END)

# Compile
app = graph.compile()

In [18]:
import pprint
inputs = {"attribs": {"question": "Tell me more about Seattle"}}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint.pprint(f"Current node : '{key}'")

# Final generation
pprint.pprint(value["attribs"]["generation"])

Retrieving documents---
"Current node : 'retrieve'"
Grading documents------
grade_tool_oai {'type': 'function', 'function': {'name': 'grade', 'description': 'Binary score for relevance check.', 'parameters': {'type': 'object', 'properties': {'score': {'description': "Relevance score 'yes' or 'no'", 'type': 'string'}}, 'required': ['score']}}}
score [grade(score='no')]
---GRADE: DOCUMENT NOT RELEVANT---
score [grade(score='no')]
---GRADE: DOCUMENT NOT RELEVANT---
score [grade(score='no')]
---GRADE: DOCUMENT NOT RELEVANT---
score [grade(score='no')]
---GRADE: DOCUMENT NOT RELEVANT---
"Current node : 'grade_documents'"
Deciding on next node---
DECISION: Improve question and run web search---
Improving the question-----
--- better question-----
Seattle, more information
"Current node : 'improve_question'"
Running web search---
"Current node : 'web_search'"
Generating output---
"Current node : 'generate'"
"Current node : '__end__'"
('Seattle is a major port of entry and an air and sea gatew