#### Corrective RAG 
Self-correction can enhance RAG by catching and correcting poor-quality retrievals or generations.
Several recent papers focus on this theme, but implementing the ideas can be tricky.
Here we show how to implement self-reflective RAG using `Mistral` and `LangGraph`.


##### Indexing
First, let's index a popular blog post on agents.

In [1]:
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_community.vectorstores.chroma import Chroma

# Load API keys (e.g. OPENAI_API_KEY, TAVILY_API_KEY) from a local .env file
load_dotenv()

True

In [2]:
url = "https://lilianweng.github.io/posts/2023-06-23-agent/"
loader = WebBaseLoader(url)
docs = loader.load()

# Split into chunks of up to 500 tokens, with 100 tokens shared between neighbours
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=500,
    chunk_overlap=100,
)
all_splits = text_splitter.split_documents(docs)

# Embed the chunks locally with GPT4All and store them in Chroma
embeddings = GPT4AllEmbeddings()
vectorstore = Chroma.from_documents(
    documents=all_splits,
    embedding=embeddings,
    collection_name="rag-chroma",
)

retriever = vectorstore.as_retriever()

bert_load_from_file: gguf version     = 2
bert_load_from_file: gguf alignment   = 32
bert_load_from_file: gguf data offset = 695552
bert_load_from_file: model name           = BERT
bert_load_from_file: model architecture   = bert
bert_load_from_file: model file type      = 1
bert_load_from_file: bert tokenizer vocab = 30522
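The splitter cuts the post into chunks of at most 500 tokens, with 100 tokens repeated between consecutive chunks so that a sentence straddling a boundary still appears whole in at least one chunk. The sketch below shows only that sliding-window idea on a plain list of tokens. It is a simplification: `RecursiveCharacterTextSplitter` also prefers to break on separators (paragraphs, newlines, spaces) and measures length with a tiktoken encoder, neither of which this sketch does. `sliding_window_chunks` is a hypothetical helper.

```python
from typing import List


def sliding_window_chunks(tokens: List[str], chunk_size: int, chunk_overlap: int) -> List[List[str]]:
    """Split tokens into windows of at most chunk_size, overlapping by chunk_overlap (hypothetical helper)."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each new window advances
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # this window already reaches the end of the input
    return chunks


tokens = [f"t{i}" for i in range(12)]
chunks = sliding_window_chunks(tokens, chunk_size=5, chunk_overlap=2)
for chunk in chunks:
    print(chunk)
```

With `chunk_size=5` and `chunk_overlap=2`, each window starts 3 tokens after the previous one, so the last two tokens of one chunk are the first two of the next.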


##### Corrective RAG
Let's implement self-reflective RAG with some ideas from CRAG (Corrective RAG):
- Grade each retrieved document for relevance to the question.
- If any document is irrelevant, supplement the context used for generation with web search.
- For web search, re-phrase the question and query the Tavily Search API.
- Pass the relevant retrieved documents and the web results to an LLM for final answer generation.
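Stripped of LLM calls, the steps above reduce to one filtering pass and one branch. The plain-Python sketch below captures that control flow; `grade`, `rewrite`, `search` and `answer` are hypothetical stand-ins for the LLM grader, the query rewriter, the Tavily search and the generator. As a simplification, it answers with the original question, whereas the graph built below passes the re-phrased question on to generation.

```python
from typing import Callable, List


def corrective_rag(
    question: str,
    docs: List[str],
    grade: Callable[[str, str], bool],
    rewrite: Callable[[str], str],
    search: Callable[[str], str],
    answer: Callable[[str, List[str]], str],
) -> str:
    # 1. Keep only the documents the grader judges relevant.
    relevant = [d for d in docs if grade(question, d)]
    context = list(relevant)
    # 2. If any document was dropped, add a web result for a re-phrased question.
    if len(relevant) < len(docs):
        context.append(search(rewrite(question)))
    # 3. Generate the final answer from the (possibly supplemented) context.
    return answer(question, context)


result = corrective_rag(
    "What is agent memory?",
    ["agents use short-term memory", "a recipe for bread"],
    grade=lambda q, d: "memory" in d,
    rewrite=lambda q: q + " (LLM agents)",
    search=lambda q: f"web: {q}",
    answer=lambda q, ctx: " | ".join(ctx),
)
print(result)
```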

In [3]:
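from typing import Any, Dict, TypedDict


class GraphState(TypedDict):
    """State passed between the nodes of the graph.

    `keys` holds everything the nodes share: the question, the documents,
    the `local` flag, the `run_web_search` decision and the final generation.
    """

    keys: Dict[str, Any]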

#### Nodes and edges
Each node in the graph outlined above is a function that modifies the state in some way.
Each edge decides which node to call next.
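Because `GraphState` declares a single `keys` field with no reducer, a node's return value replaces `keys` wholesale rather than merging into it (this is our understanding of LangGraph's default for fields declared without one). That is why every node below copies `question` and `local` forward; `generate` drops `local`, which is harmless only because it is the last node. The sketch imitates that replace-not-merge update in plain Python; `apply_update` is a hypothetical helper, not a LangGraph API.

```python
from typing import Any, Dict


def apply_update(state: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Replace each top-level field named in `update`, as a field without a reducer would be (hypothetical)."""
    new_state = dict(state)
    new_state.update(update)  # top-level replace: the nested "keys" dict is NOT merged
    return new_state


state = {"keys": {"question": "q", "local": False}}

# A node that forgets to carry "local" forward silently drops it.
state = apply_update(state, {"keys": {"question": "q", "documents": ["d1"]}})
print(state["keys"])  # {'question': 'q', 'documents': ['d1']}
```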

In [4]:
from langchain import hub
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.documents import Document
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai.chat_models import ChatOpenAI


In [5]:
def retrieve(state: GraphState):
    """ 
    Retrieve documents
    """
    print("--RETRIEVE--")
    state_dict = state["keys"]
    question = state_dict["question"]
    local = state_dict["local"]
    
    documents = retriever.get_relevant_documents(question)
    return {"keys": {
        "documents": documents,
        "local": local,
        "question": question
    }}
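The retriever embeds the question and returns the stored chunks closest to it; by default LangChain's vector-store retriever returns 4, which matches the four grading lines in the runs below. The sketch shows the nearest-neighbour idea with brute-force cosine similarity over made-up 3-dimensional vectors. Chroma's real index and distance metric differ from this loop; `cosine`, `top_k` and the vectors are illustrative assumptions.

```python
import math
from typing import Dict, List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query_vec: Sequence[float], index: Dict[str, Sequence[float]], k: int = 4) -> List[str]:
    """Return the ids of the k stored vectors most similar to query_vec."""
    ranked = sorted(index, key=lambda doc_id: cosine(query_vec, index[doc_id]), reverse=True)
    return ranked[:k]


# Toy "embeddings", made up for illustration.
index = {
    "memory-chunk": [0.9, 0.1, 0.0],
    "planning-chunk": [0.2, 0.9, 0.1],
    "tools-chunk": [0.1, 0.2, 0.9],
}
print(top_k([1.0, 0.0, 0.0], index, k=2))  # ['memory-chunk', 'planning-chunk']
```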

In [7]:
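def generate(state: GraphState):
    """
    Generate an answer from the question and the retrieved documents.
    """
    print("--GENERATE--")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # RAG prompt from the LangChain Hub; it expects {context} and {question}
    prompt = hub.pull("rlm/rag-prompt")

    # Mixtral served through an OpenAI-compatible endpoint
    # (assumes the base URL and API key are configured in the environment)
    model = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")

    def format_docs(docs):
        # Join the chunk texts so the prompt receives plain text, not Document objects
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = prompt | model | StrOutputParser()

    generation = rag_chain.invoke({"context": format_docs(documents), "question": question})
    return {
        "keys": {
            "documents": documents,
            "question": question,
            "generation": generation
        }
    }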
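The `{context}` slot of the prompt should receive plain text. The `format_docs` helper defined in the cell does this by joining each chunk's `page_content` with a blank line, so metadata never reaches the model. The sketch reproduces that helper outside LangChain with a minimal stand-in `Doc` class and a made-up template (both hypothetical; the real `rlm/rag-prompt` wording differs) to show what the model actually sees.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Doc:
    """Minimal stand-in for LangChain's Document (hypothetical, for illustration)."""
    page_content: str
    metadata: Dict[str, str] = field(default_factory=dict)


def format_docs(docs: List[Doc]) -> str:
    # Same logic as the helper in the cell: one blank line between chunks.
    return "\n\n".join(doc.page_content for doc in docs)


# Made-up template; the real rlm/rag-prompt wording differs.
TEMPLATE = "Use the context to answer.\nQuestion: {question}\nContext: {context}\nAnswer:"

docs = [
    Doc("Short-term memory is in-context learning."),
    Doc("Long-term memory uses a vector store.", {"source": "blog"}),
]
print(TEMPLATE.format(question="What is agent memory?", context=format_docs(docs)))
```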


In [8]:
from textwrap import dedent

In [9]:
def grade_documents(state: GraphState):
    """
    Determine whether the retrieved documents are relevant to the question.
    Irrelevant documents are dropped and a web search is requested.
    """
    print("---CHECK RELEVANCE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    local = state_dict["local"]

    llm = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")

    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n 
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keywords related to the user question, grade it as relevant. \n
        It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
        Give a binary score, 'yes' or 'no', to indicate whether the document is relevant to the question. \n
        Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
        input_variables=["question", "context"],
    )

    # Parse the model's JSON reply into a dict such as {"score": "yes"}
    chain = prompt | llm | JsonOutputParser()

    filtered_docs = []
    search = "No"

    for d in documents:
        score = chain.invoke({
            "question": question,
            "context": d.page_content,
        })
        # Normalise case in case the model answers "Yes" / "No"
        grade = str(score.get("score", "")).strip().lower()

        if grade == "yes":
            print("-- document relevant --")
            filtered_docs.append(d)
        else:
            # Drop the document and fall back to web search
            print("-- document not relevant --")
            search = "Yes"

    return {
        "keys": {
            "documents": filtered_docs,
            "question": question,
            "local": local,
            "run_web_search": search
        }
    }
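The grader is asked for JSON such as `{"score": "yes"}`, but models do not always comply exactly: they may capitalise the value or wrap the JSON in extra text. The sketch below shows a more tolerant parse-then-filter step in plain Python, with raw strings standing in for LLM replies. `parse_score` and `filter_documents` are hypothetical helpers; treating an unparseable reply as "no" is a design choice that errs toward running the web search.

```python
import json
import re
from typing import List, Tuple


def parse_score(raw: str) -> bool:
    """Return True if the grader reply says 'yes'; anything unparseable counts as 'no'."""
    match = re.search(r"\{.*\}", raw, flags=re.DOTALL)  # outermost {...} span in the reply
    if not match:
        return False
    try:
        payload = json.loads(match.group(0))
    except json.JSONDecodeError:
        return False
    return str(payload.get("score", "")).strip().lower() == "yes"


def filter_documents(docs: List[str], replies: List[str]) -> Tuple[List[str], str]:
    """Keep docs graded relevant; request web search if any doc was rejected."""
    kept = [d for d, r in zip(docs, replies) if parse_score(r)]
    run_web_search = "Yes" if len(kept) < len(docs) else "No"
    return kept, run_web_search


docs = ["memory chunk", "unrelated chunk", "planning chunk"]
replies = ['{"score": "yes"}', 'Sure! {"score": "no"}', '{"score": "Yes"}']
print(filter_documents(docs, replies))  # (['memory chunk', 'planning chunk'], 'Yes')
```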

In [10]:
def transform_query(state: GraphState):
    """
    Transform the query to produce a better question for retrieval.
    """
    print("-- transform query --")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    local = state_dict["local"]
    
    
    prompt = PromptTemplate(
        template=dedent("""
        You are generating a question that is well optimized for retrieval. \n 
        Look at the input and try to reason about the underlying semantic intent / meaning. \n 
        Here is the initial question:
        \n ------- \n
        {question} 
        \n ------- \n
        Provide an improved question without any preamble, only respond with the updated question: 
        """),
        input_variables=["question"]
    )
    
    llm = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
    
    chain = prompt | llm | StrOutputParser()
    better_question = chain.invoke({"question": question})
    
    return {
        "keys": {
            "documents": documents,
            "question": better_question,
            "local": local
        }
    }

In [11]:
def web_search(state: GraphState):
    """
    Run a web search on the re-phrased question using the Tavily API.
    """
    print("-- web search --")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    local = state_dict["local"]
    
    tool = TavilySearchResults(max_results=2)
    
    docs = tool.invoke({"query": question})
    
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    documents.append(web_results)
    
    return {
        "keys": {
            "documents": documents,
            "local": local,
            "question": question
        }
    }
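`TavilySearchResults` returns a list of result dicts, each carrying (among other fields) a `url` and a `content`. The node keeps only the `content` of each hit, joins them with newlines into a single `Document`, and appends it to the documents that survived grading. The sketch mirrors that merge with hand-written result dicts (fake URLs and text) and a plain string instead of `Document`; `merge_web_results` is a hypothetical helper. Because the URLs are discarded, the final answer cannot cite its web sources.

```python
from typing import Dict, List


def merge_web_results(results: List[Dict[str, str]]) -> str:
    """Join the 'content' of each search hit into one newline-separated text block (hypothetical)."""
    return "\n".join(r["content"] for r in results)


# Hypothetical search hits in the shape {"url": ..., "content": ...}.
results = [
    {"url": "https://example.com/a", "content": "Attention weighs tokens by relevance."},
    {"url": "https://example.com/b", "content": "Self-attention relates positions in a sequence."},
]
documents = ["graded local chunk"]
documents.append(merge_web_results(results))
print(len(documents))  # 2
```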

In [12]:
def decide_to_generate(state: GraphState):
    """ 
    Determine whether to generate an answer or to re-phrase the question for web search.
    """
    print("-- decide to generate --")
    state_dict = state["keys"]
    search = state_dict["run_web_search"]
    
    if search == "Yes":
        print("--Transform qeury and run web search --")
        return "transform_query"
    
    else:
        print("-- generate response --")
        return "generate"

#### Build Graph
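This follows the flow outlined above: retrieve and grade the documents, then either generate an answer directly or transform the query and run a web search before generating.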

In [13]:
from pprint import pprint
from langgraph.graph import StateGraph, END

# define the nodes
workflow = StateGraph(GraphState)

workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search", web_search)


# build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges("grade_documents", decide_to_generate, {
    "transform_query": "transform_query",
    "generate": "generate",
})

workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

app = workflow.compile()
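The compiled graph has one fixed path into `grade_documents` and one branch out of it. The plain-Python sketch below encodes the same edges as data and walks them for both values of `run_web_search`; `EDGES`, `route` and `walk` are hypothetical names for illustration, not LangGraph internals.

```python
from typing import Dict, List

END = "__end__"

# Unconditional edges, mirroring the add_edge calls above.
EDGES: Dict[str, str] = {
    "retrieve": "grade_documents",
    "transform_query": "web_search",
    "web_search": "generate",
    "generate": END,
}


def route(run_web_search: str) -> str:
    """Same decision as decide_to_generate."""
    return "transform_query" if run_web_search == "Yes" else "generate"


def walk(run_web_search: str) -> List[str]:
    """Return the sequence of nodes visited, starting at the entry point."""
    path, node = [], "retrieve"
    while node != END:
        path.append(node)
        node = route(run_web_search) if node == "grade_documents" else EDGES[node]
    return path


print(walk("No"))   # ['retrieve', 'grade_documents', 'generate']
print(walk("Yes"))  # ['retrieve', 'grade_documents', 'transform_query', 'web_search', 'generate']
```

These two paths match the two runs below: the first question finds only relevant documents and goes straight to generation, the second finds none and takes the web-search detour.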

### Run the app
We now run the compiled graph and watch how execution flows between the nodes.
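The loops below print the name of each node as it finishes, then read the answer from `value`, which after the loop still holds the output of the last chunk streamed. The sketch shows that pattern with a hypothetical `fake_stream` generator that yields chunks in the `{node_name: node_output}` shape visible in the printed output; it does not call LangGraph, and the state contents are made up.

```python
from typing import Any, Dict, Iterator


def fake_stream() -> Iterator[Dict[str, Dict[str, Any]]]:
    """Hypothetical stand-in for app.stream: one {node_name: node_output} dict per step."""
    yield {"retrieve": {"keys": {"question": "q", "documents": ["d"]}}}
    yield {"grade_documents": {"keys": {"question": "q", "documents": ["d"], "run_web_search": "No"}}}
    yield {"generate": {"keys": {"question": "q", "documents": ["d"], "generation": "answer"}}}


visited = []
for output in fake_stream():
    for key, value in output.items():
        visited.append(key)

# After the loop, `value` still holds the last node's output.
print(visited)                      # ['retrieve', 'grade_documents', 'generate']
print(value["keys"]["generation"])  # answer
```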

In [20]:
inputs = {
    "keys": {
        "question": "Explain how the different types of agent memory work?",
        "local": False,
    }
}

for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Node {key} :")
        
    pprint("\n --- \n")
    
pprint(value["keys"]["generation"])

--RETRIEVE--
'Node retrieve :'
'\n --- \n'
---CHECK RELEVANCE---
-- document relevant --
-- document relevant --
-- document relevant --
-- document relevant --
'Node grade_documents :'
'\n --- \n'
-- decide to generate --
-- generate response --
--GENERATE--
'Node generate :'
'\n --- \n'
'Node __end__ :'
'\n --- \n'
(' In the context of LLM-powered autonomous agents, memory is categorized into '
 'short-term and long-term memory. Short-term memory, also known as working '
 'memory, stores information currently needed for complex tasks. It has a '
 'capacity of around 7 items and lasts for 20-30 seconds. Long-term memory, on '
 'the other hand, can store information for extended periods, even decades, '
 'with an essentially unlimited storage capacity. It includes explicit or '
 'declarative memory for consciously recalling facts and events, and implicit '
 'or procedural memory for unconsciously storing skills and routines. In terms '
 'of operation, short-term memory is utilized for 

In [14]:
# Run
inputs = {
    "keys": {
        "question": "Explain how attention works in the transformer architecture?",
        "local": False,
    }
}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint(f"Node '{key}':")
        # Optional: print full state at each node
        # pprint(value["keys"], indent=2, width=80, depth=None)
        pprint("\n---\n")

# Final generation
pprint(value["keys"]["generation"])

--RETRIEVE--
"Node 'retrieve':"
'\n---\n'
---CHECK RELEVANCE---
-- document not relevant --
-- document not relevant --
-- document not relevant --
-- document not relevant --
"Node 'grade_documents':"
'\n---\n'
-- decide to generate --
--Transform qeury and run web search --
-- transform query --
"Node 'transform_query':"
'\n---\n'
-- web search --
"Node 'web_search':"
'\n---\n'
--GENERATE--
"Node 'generate':"
'\n---\n'
"Node '__end__':"
'\n---\n'
(' The Transformer architecture is built on the concept of attention, which '
 'assigns weights to various parts of an input sequence. This allows the model '
 'to focus on different parts of the sequence and understand it better. The '
 'role of attention in the Transformer architecture is crucial, as it enables '
 'the model to selectively concentrate on relevant features within the input '
 'data.')
