# Corrective Retrieval Augmented Generation (CRAG)

Corrective Retrieval Augmented Generation (CRAG) is a recent paper that introduces an interesting approach to self-reflective RAG. You can read the paper [here](https://arxiv.org/pdf/2401.15884.pdf).

Retrieval augmented generation (RAG) uses retrieval to incorporate relevant knowledge into the model's input, thereby improving output generation. Within this framework, the model receives an augmented input built by adding relevant documents retrieved from an external knowledge collection. While RAG serves as a practical complement to LLMs, its effectiveness is contingent upon the relevance and accuracy of the retrieved documents. Because generation depends so heavily on the retrieved knowledge, there are significant concerns about the model's behavior and performance in scenarios where retrieval fails or returns inaccurate results.

A low-quality retriever can bring in a lot of irrelevant information. This can make it hard for models to acquire accurate knowledge and might even mislead them, causing problems like hallucinations.

Figure 1 shows how CRAG works at inference time to make generation more resilient. Given an input query and the documents returned by a retriever, CRAG uses a lightweight retrieval evaluator to estimate a relevance score for each retrieved document with respect to the query. Based on these scores, the evaluator assigns one of three confidence degrees, {Correct, Incorrect, Ambiguous}, and each one triggers a corresponding action. If the retrieval is Correct, the retrieved documents are refined to be more precise through a knowledge refinement process consisting of knowledge decomposition, filtering, and recomposition. If it is Incorrect, the retrieved documents are discarded and web search is used instead as a complementary knowledge source for correction. If it is unclear whether the documents are correct, the Ambiguous action is taken, which combines both. Once the retrieved knowledge has been refined, any generative model can be used to produce the answer.

<center><figure><img src="imgs/CRAG.jpg" alt="drawing" width="700"/><figcaption>Fig. 1: An overview of CRAG at inference.</figcaption></figure></center>    
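The three-way decision described above can be sketched as a small function over the evaluator's per-document scores. This assumes, as in the paper, that retrieval is Correct when at least one score exceeds an upper threshold and Incorrect when every score falls below a lower threshold, with Ambiguous covering everything else. The function name `decide_action` and the threshold values `0.6` and `0.2` are hypothetical placeholders, not values taken from the paper.

```python
from typing import List


def decide_action(scores: List[float], upper: float = 0.6, lower: float = 0.2) -> str:
    """Map per-document relevance scores to a CRAG action.

    The default thresholds are hypothetical placeholders, not the paper's values.
    """
    # Correct: at least one document is confidently relevant
    if any(score > upper for score in scores):
        return "Correct"
    # Incorrect: every document falls below the lower threshold
    # (an empty score list also lands here, because all() over nothing is True)
    if all(score < lower for score in scores):
        return "Incorrect"
    # Ambiguous: neither confidently relevant nor confidently irrelevant
    return "Ambiguous"


print(decide_action([0.9, 0.1]))   # Correct -> refine the retrieved documents
print(decide_action([0.1, 0.05]))  # Incorrect -> discard them and search the web
print(decide_action([0.4, 0.1]))   # Ambiguous -> combine refinement and web search
```

In the real system the scores come from the fine-tuned evaluator; here they are plain numbers so the routing logic can be read in isolation.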

## Knowledge Refinement 
A retrieval is considered Correct if the confidence score of at least one retrieved document exceeds the upper threshold, which indicates that relevant documents are present in the retrieval results. However, even a relevant document may contain some irrelevant information. To extract the most important information, CRAG applies knowledge refinement: the content of each relevant retrieved document is decomposed and then recomposed. First, each document is divided into smaller knowledge segments using heuristic rules. Next, the fine-tuned retrieval evaluator computes a relevance score for each segment. Finally, irrelevant segments are filtered out based on these scores, and the remaining relevant segments are recomposed by concatenating them in their original order.
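The decompose, filter, and recompose steps can be sketched as follows. This is a minimal illustration under stated assumptions: splitting on sentence-ending punctuation stands in for the paper's heuristic rules, and `overlap_score` (the fraction of query words found in a segment) is a hypothetical stand-in for the fine-tuned retrieval evaluator. The threshold of `0.5` is likewise arbitrary.

```python
import re
from typing import Callable, List


def decompose(document: str) -> List[str]:
    # Assumed heuristic: split after '.', '!' or '?' followed by whitespace
    segments = re.split(r"(?<=[.!?])\s+", document.strip())
    return [s for s in segments if s]


def overlap_score(query: str, segment: str) -> float:
    # Hypothetical stand-in for the fine-tuned evaluator:
    # the fraction of distinct query words that also appear in the segment
    query_words = set(re.findall(r"\w+", query.lower()))
    segment_words = set(re.findall(r"\w+", segment.lower()))
    if not query_words:
        return 0.0
    return len(query_words & segment_words) / len(query_words)


def refine(
    query: str,
    documents: List[str],
    scorer: Callable[[str, str], float] = overlap_score,
    threshold: float = 0.5,
) -> str:
    kept = []
    for doc in documents:
        for segment in decompose(doc):
            # Filter: keep only segments scored as relevant
            if scorer(query, segment) >= threshold:
                kept.append(segment)
    # Recompose: concatenate the surviving segments in their original order
    return " ".join(kept)


doc = "Seattle is a city. The weather is rainy. Seattle population is large."
print(refine("Seattle population", [doc]))
```

Here the middle sentence shares no words with the query, so it is filtered out and the two remaining segments are joined in order.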


# CRAG implementation in LangChain

We can use LangChain's LangGraph library to implement CRAG. First, let's see what [LangGraph](https://python.langchain.com/docs/langgraph) is.

## LangGraph
LLMs can be used for reasoning tasks, which can essentially be thought of as running an LLM in a for-loop. These types of systems are often called agents. Sometimes, however, you need more control than a free-running agent offers: you may want to always force an agent to call a particular tool first, or to control more precisely how tools are called. LangGraph refers to these more controlled flows as "state machines" and lets you create them by specifying them as graphs.

The primary purpose of LangGraph is to add cycles to LLM applications. Cycles play a vital role in agent scenarios: for example, you might repeatedly invoke an LLM within a loop to determine the next course of action.
LangGraph is a library for building complex, stateful applications that involve multiple actors powered by LLMs. It is built on top of LangChain and extends its capabilities by coordinating multiple chains or actors across several steps of computation in a cyclic fashion.
It's important to note that, because its graphs can contain cycles, LangGraph is not limited to a **Directed Acyclic Graph (DAG)**.

### How to build a graph using LangGraph

#### Creating StateGraph
Graphs in LangGraph are instances of `StateGraph`. A graph is parameterized by a state object that it passes to each node. This state represents a central object that is updated over time by each node. A node's update can either `set` a specific attribute of the state (i.e., overwrite its existing value) or `add` to an existing attribute (e.g., append new items to a list).
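The difference between `set` and `add` updates can be sketched in plain Python. The helper `apply_update` below is hypothetical and is not LangGraph's internal code; it only mimics the idea that a key without a reducer is overwritten, while a key with a reducer (here `operator.add`) combines the old and new values. In LangGraph itself, the `add` behavior is declared on the state schema, for example with `Annotated[list, operator.add]`.

```python
import operator
from typing import Any, Callable, Dict


def apply_update(
    state: Dict[str, Any],
    update: Dict[str, Any],
    reducers: Dict[str, Callable[[Any, Any], Any]],
) -> Dict[str, Any]:
    """Merge a node's returned update into the state (hypothetical helper)."""
    new_state = dict(state)  # shallow copy so the previous state is left untouched
    for key, value in update.items():
        if key in reducers and key in new_state:
            # "add": combine the existing value with the new one
            new_state[key] = reducers[key](new_state[key], value)
        else:
            # "set": overwrite (or create) the attribute
            new_state[key] = value
    return new_state


state = {"question": "Tell me about Seattle", "documents": ["doc1"]}
update = {"question": "Seattle, history", "documents": ["doc2"]}
print(apply_update(state, update, reducers={"documents": operator.add}))
```

The example later in this notebook declares a single `attribs` key without a reducer, so each node's returned `attribs` dictionary replaces the previous one; that is why every node passes `question` and `documents` along explicitly.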
#### Adding Nodes
Nodes are added to the graph as `(name, value)` pairs. The `name` is used to refer to the node when adding edges, and the `value` is a function or LCEL runnable that will be called when the node runs.
#### Adding Edges
After adding nodes, we can then add edges to the graph. We have two types of edges.
* ##### Normal Edges
These are edges where one node should ALWAYS be called after another.
* ##### Conditional Edges
These are edges where one of several paths is taken based on the output of a node. For example, after an agent node is called: if the agent decided to take an action, the function that invokes tools should be called next; if the agent said it was finished, the graph should finish.
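To make nodes, normal edges, conditional edges, and cycles concrete, here is a toy state machine in plain Python. The `TinyGraph` class and its `END` sentinel are hypothetical illustrations, not LangGraph's implementation; only the method names mirror the LangGraph calls used later in this notebook. The conditional edge on `draft` loops back to itself until a router function decides to finish, which is exactly the kind of cycle a DAG cannot express.

```python
from typing import Any, Callable, Dict, Optional, Tuple

END = "__end__"  # hypothetical sentinel for this toy, not imported from LangGraph

Node = Callable[[Dict[str, Any]], Dict[str, Any]]
Router = Callable[[Dict[str, Any]], str]


class TinyGraph:
    """Hypothetical toy state machine illustrating nodes and edges."""

    def __init__(self) -> None:
        self.nodes: Dict[str, Node] = {}
        self.edges: Dict[str, str] = {}
        self.conditional: Dict[str, Tuple[Router, Dict[str, str]]] = {}
        self.entry: Optional[str] = None

    def add_node(self, name: str, fn: Node) -> None:
        self.nodes[name] = fn

    def add_edge(self, source: str, target: str) -> None:
        # Normal edge: target ALWAYS runs after source
        self.edges[source] = target

    def add_conditional_edges(self, source: str, router: Router, mapping: Dict[str, str]) -> None:
        # Conditional edge: router's output picks the next node via mapping
        self.conditional[source] = (router, mapping)

    def set_entry_point(self, name: str) -> None:
        self.entry = name

    def run(self, state: Dict[str, Any], max_steps: int = 25) -> Dict[str, Any]:
        current = self.entry
        steps = 0
        while current != END:
            if steps >= max_steps:
                raise RuntimeError("step limit reached; the graph may be looping forever")
            # Each node returns updates that overwrite keys in the state
            state = {**state, **self.nodes[current](state)}
            if current in self.conditional:
                router, mapping = self.conditional[current]
                current = mapping[router(state)]
            else:
                current = self.edges[current]
            steps += 1
        return state


def plan(state):
    return {"plan": "outline"}


def draft(state):
    return {"attempts": state["attempts"] + 1}


def is_done(state):
    return "finish" if state["attempts"] >= 3 else "retry"


g = TinyGraph()
g.add_node("plan", plan)
g.add_node("draft", draft)
g.set_entry_point("plan")
g.add_edge("plan", "draft")  # normal edge
g.add_conditional_edges("draft", is_done, {"retry": "draft", "finish": END})  # cycle
final = g.run({"attempts": 0})
print(final)
```

The `max_steps` guard is a design choice for the sketch: once cycles are allowed, a router that never returns the finishing branch would otherwise loop forever.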

# CRAG example 
Let's code an example of CRAG in LangGraph!
In the following example, I load a PDF file about the United States. First, I split the document into chunks and store their embeddings in a vector database. Then, I retrieve the chunks relevant to the input question. Using a prompt, I ask the underlying LLM to assess each retrieved chunk more carefully and grade it YES or NO based on its relevance to the input question. If the LLM marks any chunk as NO, we perform a web search to obtain better context. Otherwise (i.e., all chunks are marked as YES), we use only the retrieved documents to produce the final output. Finally, with or without web knowledge, we run a classic RAG step in which the LLM generates the final output. Note that this is a simplified version of CRAG: irrelevant chunks are filtered out as whole documents rather than refined segment by segment. Our graph consists of the following nodes and edges (see Figure 2).

### Graph nodes
Our graph consists of the following nodes:
* ### retrieve:
This node retrieves documents relevant to the input question.
* ### evaluate_documents:
This node evaluates each document retrieved by the `retrieve` node and updates the `documents` attribute to include only the relevant ones. If not all documents are evaluated as relevant, it sets the `run_web_search` flag to `"Yes"` so that a web search on the topic is run later (by the `web_search` node).
* ### generate_knowledge_keywords:
If the search flag is on, this node extracts knowledge keywords from the question to use as web search queries.
* ### generate:
This node generates the answer, using as context either the retrieved documents alone or the retrieved documents together with the web search results. It performs a classic RAG step on the question with the provided context.
* ### web_search:
After the `generate_knowledge_keywords` node has extracted keywords from the question, this node runs a web search to find documents relevant to the input question and adds the results to the documents.

#### Each node receives the graph state as an argument and returns the attributes it adds or updates.

### Graph edges
#### Normal edges:
* (retrieve, evaluate_documents)
* (generate_knowledge_keywords, web_search)
* (web_search, generate)
* (generate, END)

#### Conditional edges:
##### should_generate: This routing function runs after `evaluate_documents`. Based on the search flag, it decides either to generate keywords for the web search (i.e., goes to the `generate_knowledge_keywords` node if the flag is on) or to generate the final answer (i.e., goes to the `generate` node if the flag is off).
* (evaluate_documents, generate_knowledge_keywords)
* (evaluate_documents, generate)

<center><figure><img src="imgs/crag_langGraph.jpg" alt="drawing" width="900"/><figcaption>Fig. 2: example's graph.</figcaption></figure></center>  

In [None]:
! pip install langchain_community faiss-cpu tiktoken langchain-openai langchainhub langchain langgraph tavily-python pypdf

In [3]:
import os
from getpass import getpass
os.environ['OPENAI_API_KEY'] = getpass('Enter your OpenAI API Key: ')
os.environ['TAVILY_API_KEY'] = getpass('Enter your Tavily API Key: ')

Enter your OpenAI API Key:  ········
Enter your Tavily API Key:  ········


In [4]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

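# Split the PDF into chunks of about 250 tokens (counted with tiktoken), with a small overlap
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=10
)
chunks = text_splitter.split_documents(PyPDFLoader("docs/the-usa.pdf").load())

# Embed the chunks and index them in an in-memory FAISS vector store
vectorstore = FAISS.from_documents(
    documents=chunks,
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()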

In [5]:
from langchain import hub
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI


def retrieve(state):

    print("Retrieving documents---")
    question = state["attribs"]["question"]
    documents = retriever.get_relevant_documents(question)
    return {"attribs": {"documents": documents, "question": question}}


def evaluate_documents(state):
    
    print("Evaluating documents------")
    question = state["attribs"]["question"]
    documents = state["attribs"]["documents"]   

    tools = [
        {
            "type": "function",
            "function": {
                "name": "evaluate",
                "description": "Predict the relevance score for each question-document pair",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "score": {
                            "type": "string",
                            "description": "Score could be yes or no",
                        },                       
                    },
                    "required": ["score"],
                },
            },
        }
    ]
    
    prompt = PromptTemplate(
        template="""Given a question, does the following document have exact information to answer the question?
                    Question: {question}
                    Document: {context}
                    Think Step by step, and answer with yes or no only""",
        input_variables=["context", "question"],
    )

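    llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo", streaming=True)
    # Force the model to call the `evaluate` tool so that every response carries a score;
    # with tool_choice="auto" the model may answer in plain text and eval_result would be empty
    llm_with_tool = llm.bind(
        tools=tools,
        tool_choice={"type": "function", "function": {"name": "evaluate"}},
    )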
    # Chain
    chain = prompt | llm_with_tool | JsonOutputToolsParser()


    relevant_docs = []
    for d in documents:
        eval_result = chain.invoke({"question": question, "context": d.page_content})
        # Normalize the score so that "Yes", " yes", etc. are treated alike
        score = eval_result[0]["args"]["score"].strip().lower()
        print("score", score)
        if score == "yes":
            print("---Evaluation Result: DOCUMENT RELEVANT---")
            relevant_docs.append(d)
        else:
            print("---Evaluation Result: DOCUMENT IRRELEVANT---")
    # Run a web search unless every retrieved document was judged relevant
    search = "Yes" if not relevant_docs or len(relevant_docs) < len(documents) else "No"


    return {"attribs": { "documents": relevant_docs, "question": question, "run_web_search": search}}


def generate_knowledge_keywords(state):
    
    print("Generating knowledge keywords from the question for web search-----")
    question = state["attribs"]["question"]
    documents = state["attribs"]["documents"]
    
    prompt = PromptTemplate(
        template="""Extract at most three keywords separated by comma from the following dialogues and questions as queries for the
                    web search, including topic background within dialogues and main intent within questions.
                    question: What is Henry Feilden’s occupation?
                    query: Henry Feilden, occupation
                    question: In what city was Billy Carlson born?
                    query: city, Billy Carlson, born
                    question: What is the religion of John Gwynn?
                    query: religion of John Gwynn
                    question: What sport does Kiribati men’s national basketball team play?
                    query: sport, Kiribati men’s national basketball team play
                    question: {question}
                    query:""",
        input_variables=["question"],
    )

    model = ChatOpenAI(temperature=0, model="gpt-3.5-turbo", streaming=True)

    # Prompt
    chain = prompt | model | StrOutputParser()
    knowledge_keywords = chain.invoke({"question": question})
    
    print("--- knowledge keywords-----")
    print(knowledge_keywords)

    # Keep the original question for the final answer; store the keywords separately for the web search
    return {"attribs": {"documents": documents, "question": question, "keywords": knowledge_keywords}}


def generate(state):
    
    print("Generating output---")
    question = state["attribs"]["question"]
    documents = state["attribs"]["documents"]

    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # LLM
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True)

    # Chain
    rag_chain = prompt | llm | StrOutputParser()

    # Run: pass the documents to the prompt as plain text rather than a list of Document objects
    context = "\n\n".join(d.page_content for d in documents)
    generation = rag_chain.invoke({"context": context, "question": question})
    
    return {
        "attribs": {"documents": documents, "question": question, "generation": generation}
    }

def should_generate(state):
    
    print("Deciding on next node---")
    search = state["attribs"]["run_web_search"]

    if search == "Yes":
        print("DECISION: Generate knowledge keywords and run web search---")
        return "generate_knowledge_keywords"
    else:
        # We have relevant documents, so generate answer
        print("DECISION: Generate final output")
        return "generate"


def web_search(state):
    
    print("Running web search---")
    question = state["attribs"]["question"]
    documents = state["attribs"]["documents"]
    # Search with the extracted keywords when available, otherwise with the question itself
    query = state["attribs"].get("keywords", question)

    tool = TavilySearchResults()
    docs = tool.invoke({"query": query})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    documents.append(web_results)

    return {"attribs": {"documents": documents, "question": question}}


## Defining a graph

In [6]:
from typing import Any, Dict, TypedDict
from langgraph.graph import END, StateGraph

# The whole graph state lives under a single "attribs" key
class State(TypedDict):
    attribs: Dict[str, Any]
    
graph = StateGraph(State)

## Adding nodes, edges and building the graph

In [7]:
graph.add_node("retrieve", retrieve)  # retrieve
graph.add_node("evaluate_documents", evaluate_documents)  # evaluate documents
graph.add_node("generate", generate)  # generate
graph.add_node("generate_knowledge_keywords", generate_knowledge_keywords)  # generate_knowledge_keywords
graph.add_node("web_search", web_search)  # web search

graph.set_entry_point("retrieve")
graph.add_edge("retrieve", "evaluate_documents")
graph.add_conditional_edges(
    "evaluate_documents",
    should_generate,
    {
        "generate_knowledge_keywords": "generate_knowledge_keywords",
        "generate": "generate",
    },
)
graph.add_edge("generate_knowledge_keywords", "web_search")
graph.add_edge("web_search", "generate")
graph.add_edge("generate", END)

# Compile
app = graph.compile()

In [8]:
inputs = {"attribs": {"question": "Tell me more about Seattle"}}
gen = []
for output in app.stream(inputs):  
    gen.append(output)

print("Final Output--------------------------------------------------------------------------------")
print(f"{gen[0]['retrieve']['attribs']['question']} : {gen[-1]['__end__']['attribs']['generation']}")

Retrieving documents---
Evaluating documents------
score no
---Evaluation Result: DOCUMENT IRRELEVANT---
score no
---Evaluation Result: DOCUMENT IRRELEVANT---
score no
---Evaluation Result: DOCUMENT IRRELEVANT---
score no
---Evaluation Result: DOCUMENT IRRELEVANT---
Deciding on next node---
DECISION: Generate knowledge keywords and run web search---

