# Corrective RAG

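Corrective RAG, or CRAG, is a RAG framework extended with corrective steps that check whether the retrieved information is actually relevant and, if it is not, correct the retrieval (for example by rewriting the query or falling back to web search) before an answer is shown. This keeps the pipeline grounded in the information it retrieves. A CRAG pipeline can be broken down into four main parts: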

1. **Generative model** --> generation of the initial sequence (invisible to the user)
2. **Retrieval model** --> retrieval of context from the storage based on the initial sequence (invisible to the user)
3. **Evaluation-correction model** --> back-and-forth between steps 1 and 2 to correct the results and decide on the final sequence that will be presented to the user (correction step), for example by:
    - Keeping track of generated responses
    - Rewriting or expanding the question
    - Adding web search (in case the retrieved documents are irrelevant)
    - Re-ranking the results
4. **Final generation model** --> presentation of the results/responses (visible to the user)

Here's the inference pseudo-code from the official [CRAG paper](https://arxiv.org/abs/2401.15884):

![CRAG inference pseudo-code](../images/crag_inference.png)
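The decision step in the paper's pseudo-code can be sketched in plain Python. In the paper, a lightweight retrieval evaluator scores every retrieved document: if at least one score exceeds an upper threshold the action is *Correct* (refine the retrieved documents), if all scores fall below a lower threshold the action is *Incorrect* (fall back to web search), and otherwise it is *Ambiguous* (combine both). This is a minimal sketch under toy assumptions: `overlap_score` is a keyword-overlap scorer rather than the paper's fine-tuned evaluator, `web_search` is a hypothetical stub, "refinement" is reduced to dropping low-scoring documents, and the thresholds 0.6/0.2 are arbitrary.

```python
from enum import Enum
from typing import Callable, List, Tuple


class Action(Enum):
    CORRECT = "correct"
    INCORRECT = "incorrect"
    AMBIGUOUS = "ambiguous"


def decide_action(scores: List[float], upper: float, lower: float) -> Action:
    """Map per-document relevance scores to one of the three CRAG actions."""
    if any(s > upper for s in scores):
        return Action.CORRECT
    # all() of an empty list is True, so "nothing retrieved" also counts as Incorrect
    if all(s < lower for s in scores):
        return Action.INCORRECT
    return Action.AMBIGUOUS


def overlap_score(question: str, doc: str) -> float:
    # Toy evaluator (assumption): fraction of question words that also appear in the document
    q_words = set(question.lower().split())
    d_words = set(doc.lower().split())
    return len(q_words & d_words) / len(q_words) if q_words else 0.0


def web_search(question: str) -> List[str]:
    # Hypothetical stub standing in for a real web-search call
    return [f"web result for: {question}"]


def corrective_knowledge(
    question: str,
    docs: List[str],
    score_fn: Callable[[str, str], float] = overlap_score,
    upper: float = 0.6,
    lower: float = 0.2,
) -> Tuple[Action, List[str]]:
    """Return the chosen action and the knowledge that would be passed to the generator."""
    scores = [score_fn(question, d) for d in docs]
    action = decide_action(scores, upper, lower)
    # Simplified "refinement" (assumption): keep documents scoring at least the lower threshold
    refined = [d for d, s in zip(docs, scores) if s >= lower]
    if action is Action.CORRECT:
        return action, refined
    if action is Action.INCORRECT:
        return action, web_search(question)
    return action, refined + web_search(question)


# Usage with two toy documents: one on-topic, one off-topic
action, knowledge = corrective_knowledge(
    "sabre fencing footwork",
    ["sabre fencing footwork drills", "history of the olympic games"],
)
print(action, knowledge)
```

In a real pipeline the scorer would be an LLM or a trained evaluator, and the thresholds would need tuning per dataset.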

In [1]:
# import modules
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings

In [2]:
# Azure OpenAI settings (read from environment variables)

AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_VERSION = os.environ.get("AZURE_OPENAI_VERSION")
AZURE_OPENAI_DEPLOYMENT_NAME = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME = os.environ.get("AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME")
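`os.environ.get` silently returns `None` for a missing variable, so a typo in a variable name only surfaces later as a confusing authentication or connection error. A small fail-fast check like the one below can catch this early; `load_required_env` and `REQUIRED_VARS` are hypothetical names, not part of LangChain or the Azure SDK.

```python
import os
from typing import Dict, Iterable

# The variable names this notebook reads (taken from the cell above)
REQUIRED_VARS = (
    "AZURE_OPENAI_API_KEY",
    "AZURE_OPENAI_ENDPOINT",
    "AZURE_OPENAI_VERSION",
    "AZURE_OPENAI_DEPLOYMENT_NAME",
    "AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME",
)


def load_required_env(names: Iterable[str] = REQUIRED_VARS) -> Dict[str, str]:
    """Return the requested environment variables, raising if any is missing or empty."""
    values = {name: os.environ.get(name, "") for name in names}
    missing = [name for name, value in values.items() if not value]
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
    return values
```

Calling `settings = load_required_env()` at the top of the notebook would then stop execution with a clear message listing every variable that still needs to be set.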

In [3]:
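# generation model (endpoint and key passed explicitly, as for the embedding model below)
oai = AzureChatOpenAI(
    openai_api_version=AZURE_OPENAI_VERSION,
    azure_deployment=AZURE_OPENAI_DEPLOYMENT_NAME,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    api_key=AZURE_OPENAI_API_KEY,
    temperature=0,
)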

# embedding model
azure_oai_emb_model: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings(
    azure_deployment=AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME,
    openai_api_version=AZURE_OPENAI_VERSION,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    api_key=AZURE_OPENAI_API_KEY,
)

## Load Data & Set Up a Basic Retriever

In [4]:
# articles about sabre fencing

urls = [
    "https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10234521/",
    "https://www.usafencing.org/news_article/show/1265759",
    "https://commons.nmu.edu/cgi/viewcontent.cgi?article=2108&context=isbs"
]

docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

In [5]:
# document chunking: chunks of up to 350 tokens (measured with tiktoken), no overlap

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=350, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)
print(f"TOTAL NUMBER OF CHUNKS: {len(doc_splits)}")

TOTAL NUMBER OF CHUNKS: 516
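`RecursiveCharacterTextSplitter` tries a list of separators from coarsest to finest (by default paragraphs `"\n\n"`, then lines `"\n"`, then spaces `" "`, then single characters) and only falls back to a finer separator when a piece is still larger than `chunk_size`; `from_tiktoken_encoder` makes it measure that size in tiktoken tokens. The simplified sketch below mimics the idea without LangChain. It assumes a word count as a stand-in for the token count, uses no overlap (like `chunk_overlap=0` above), and is not LangChain's actual merging logic; `recursive_split` and `word_count` are hypothetical names.

```python
from typing import Callable, List, Sequence


def word_count(text: str) -> int:
    # Stand-in for a tiktoken token count (assumption, keeps the sketch dependency-free)
    return len(text.split())


def recursive_split(
    text: str,
    chunk_size: int,
    length_fn: Callable[[str], int] = word_count,
    separators: Sequence[str] = ("\n\n", "\n", " "),
) -> List[str]:
    """Split on the coarsest separator first, then greedily merge pieces up to chunk_size."""
    if length_fn(text) <= chunk_size:
        return [text] if text.strip() else []
    # Pick the first (coarsest) separator that actually occurs in the text
    for i, sep in enumerate(separators):
        if sep in text:
            break
    else:
        return [text]  # nothing left to split on: keep the oversized piece as-is
    chunks: List[str] = []
    current = ""
    for piece in text.split(sep):
        candidate = f"{current}{sep}{piece}" if current else piece
        if length_fn(candidate) <= chunk_size:
            current = candidate  # piece still fits into the chunk being built
            continue
        if current.strip():
            chunks.append(current)
        if length_fn(piece) <= chunk_size:
            current = piece
        else:
            # Piece alone is too long: recurse using only the finer separators
            chunks.extend(recursive_split(piece, chunk_size, length_fn, separators[i + 1:]))
            current = ""
    if current.strip():
        chunks.append(current)
    return chunks


sample = (
    "Sabre is the fastest fencing weapon.\n\n"
    "Footwork drills build explosive lunges and quick recoveries."
)
for chunk in recursive_split(sample, chunk_size=8):
    print(repr(chunk))
```

Splitting on paragraphs first keeps semantically related sentences together, which tends to give the retriever more self-contained chunks than a fixed-size window would.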


In [6]:
# embed the chunks into a Chroma vector store and retrieve with Maximal Marginal Relevance (MMR)
vector_store = Chroma.from_documents(documents=doc_splits, embedding=azure_oai_emb_model)
retriever = vector_store.as_retriever(search_type="mmr")
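The retriever above uses `search_type="mmr"`, i.e. Maximal Marginal Relevance: documents are picked one at a time, each time maximizing `lambda_mult * sim(query, doc) - (1 - lambda_mult) * max sim(doc, already_selected)`, which trades relevance against redundancy so that near-duplicate chunks do not crowd the context. The pure-Python sketch below assumes cosine similarity and toy 2-D vectors in place of real embeddings. LangChain's version additionally pre-fetches `fetch_k` candidates by plain similarity before re-ranking them, and these values can be tuned through `search_kwargs` (for example `{"k": 4, "fetch_k": 20, "lambda_mult": 0.5}`).

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def mmr(
    query_vec: Sequence[float],
    doc_vecs: List[Sequence[float]],
    k: int = 4,
    lambda_mult: float = 0.5,
) -> List[int]:
    """Return the indices of up to k documents chosen by Maximal Marginal Relevance."""
    relevance = [cosine(query_vec, d) for d in doc_vecs]
    selected: List[int] = []
    candidates = list(range(len(doc_vecs)))
    while candidates and len(selected) < k:
        # Relevance to the query minus similarity to the closest already-selected document
        best = max(
            candidates,
            key=lambda i: lambda_mult * relevance[i]
            - (1 - lambda_mult)
            * max((cosine(doc_vecs[i], doc_vecs[j]) for j in selected), default=0.0),
        )
        selected.append(best)
        candidates.remove(best)
    return selected


query = [1.0, 1.0]
docs = [
    [1.0, 0.9],   # most relevant document
    [1.0, 0.85],  # near-duplicate of document 0
    [0.5, 1.0],   # slightly less relevant but different
]
print(mmr(query, docs, k=2, lambda_mult=0.5))  # diversity-aware ranking
print(mmr(query, docs, k=2, lambda_mult=1.0))  # pure relevance ranking
```

With `lambda_mult=1.0` the redundancy term vanishes and the near-duplicate is picked second; at `0.5` it is penalized and the more distinct document wins.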

## CRAG Components

#### Graph State

In [9]:
from typing_extensions import TypedDict
from typing import List


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the user's question
        generation: the LLM generation
        web_search: flag indicating whether to add web search results
        documents: list of retrieved documents
    """

    question: str
    generation: str
    web_search: str
    documents: List[str]
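In LangGraph, each node receives this state and returns a dict containing only the keys it wants to update. The sketch below shows what a document-grading node and the routing function after it could look like. `is_relevant`, `grade_documents` and `decide_next_step` are hypothetical names, the keyword check stands in for an LLM grader, the `"Yes"`/`"No"` values for `web_search` are an assumption, and flagging web search as soon as one document is dropped is a simpler policy than the paper's two thresholds. The merge `{**state, **update}` mimics how LangGraph overwrites state keys that have no reducer.

```python
from typing import Any, Dict, List, TypedDict


class GraphState(TypedDict):
    question: str
    generation: str
    web_search: str
    documents: List[str]


def is_relevant(question: str, doc: str) -> bool:
    # Hypothetical stand-in for an LLM relevance grader: any shared word counts as relevant
    return bool(set(question.lower().split()) & set(doc.lower().split()))


def grade_documents(state: GraphState) -> Dict[str, Any]:
    """Node: keep relevant documents; request web search if any were dropped or none exist."""
    kept = [d for d in state["documents"] if is_relevant(state["question"], d)]
    needs_search = not kept or len(kept) < len(state["documents"])
    return {"documents": kept, "web_search": "Yes" if needs_search else "No"}


def decide_next_step(state: GraphState) -> str:
    """Conditional edge: rewrite the query if web search is needed, otherwise generate."""
    return "transform_query" if state["web_search"] == "Yes" else "generate"


state: GraphState = {
    "question": "sabre fencing footwork",
    "generation": "",
    "web_search": "No",
    "documents": ["Footwork drills for sabre", "Olympic ticket prices"],
}
state = {**state, **grade_documents(state)}  # apply the node's partial update
print(state["documents"], state["web_search"], decide_next_step(state))
```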

#### Query Re-writer