#### Corrective RAG
Self-reflection can enhance RAG, enabling correction of poor-quality retrieval or generations.
Several relevant papers focus on this theme, but implementing the ideas can be tricky.

Here we show how to implement ideas from `Corrective RAG (CRAG)`.


##### CRAG details
Corrective-RAG (CRAG) is a recent paper that introduces an interesting approach for self-reflective RAG.

The framework grades retrieved documents relative to the question
1. Correct documents
   1. If at least one document exceeds the threshold of relevance, then it proceeds to generation
   2. Before generation, it performs knowledge refinement
   3. This partitions the document into `knowledge strips`
   4. It grades each strip, and filters out irrelevant ones
2. Ambiguous or incorrect documents
   1. If all documents fall below the relevance threshold or if the grader is unsure, then the framework seeks an additional datasource
   2. It will use web search to supplement retrieval
   3. This is done using query re-writing

In [3]:
from langchain_core.embeddings import Embeddings
from langchain.pydantic_v1 import BaseModel
from dotenv import load_dotenv

In [5]:
# Load variables from a local .env file into the process environment
# (presumably the API credentials used by the clients created later — confirm).
load_dotenv()

True

In [6]:
from openai import OpenAI

# Module-level client shared by get_embeddings(). It is constructed without
# arguments, so any key / endpoint configuration presumably comes from the
# environment — TODO confirm which variables are expected.
client = OpenAI()

def get_embeddings(texts, model="togethercomputer/m2-bert-80M-32k-retrieval"):
    """Embed every string in ``texts`` with a single embeddings request.

    Newlines inside each text are replaced by spaces before the request is
    sent through the module-level ``client``.

    Args:
        texts: Sequence of strings to embed.
        model: Name of the embedding model passed to the API.

    Returns:
        A list with one embedding per input text, read from the response by
        position.
    """
    flattened = []
    for raw in texts:
        flattened.append(raw.replace("\n", " "))

    response = client.embeddings.create(input=flattened, model=model)

    vectors = []
    for position in range(len(flattened)):
        vectors.append(response.data[position].embedding)
    return vectors

# Smoke test: embed one string and check how many vectors come back.
texts=["hello"]
len(get_embeddings(texts))

1

In [8]:
from typing import Coroutine, List, Any


class NewOpenAIEmbeddings(BaseModel, Embeddings):
    """LangChain ``Embeddings`` adapter that delegates to ``get_embeddings``.

    The synchronous methods call ``get_embeddings`` directly. The asynchronous
    methods are real coroutines (``async def``) so that callers can ``await``
    them; previously they were plain functions that returned a list, which
    made ``await embeddings.aembed_query(...)`` fail with ``TypeError``.
    """

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding per entry of ``texts``."""
        return get_embeddings(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single query string."""
        return get_embeddings([text])[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously return one embedding per entry of ``texts``.

        The blocking ``embed_documents`` call is pushed to the default
        executor so the event loop is not stalled while the request runs.
        """
        import asyncio  # local import: only the async path needs it

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.embed_documents, texts)

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously return the embedding of a single query string.

        Runs the blocking ``embed_query`` call in the default executor.
        """
        import asyncio  # local import: only the async path needs it

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.embed_query, text)

In [10]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores.chroma import Chroma
from langchain_openai import OpenAIEmbeddings


# Web pages to index.
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/"
]


# Load every URL (one list of documents per URL), then flatten the
# per-URL lists into a single list.
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Split the pages into chunks of size 250 with no overlap, using the
# tiktoken-based length function of the splitter.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)
# Keep only the first 10 chunks (presumably to keep the number of embedded
# texts small for this demo — confirm before reusing).
doc_splits = doc_splits[:10]
# Embed the chunks with the custom embeddings class and store them in a
# Chroma collection named "rag-chroma".
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=NewOpenAIEmbeddings(),
)

#### State
We can now define the state of our graph. Our state will be a dict, and we will define it using the `TypedDict` type.
Every node in the graph will have access to the state.

In [11]:
from typing import Dict, TypedDict
from langchain_core.messages import BaseMessage


class GraphState(TypedDict):
    """State object handed from node to node in the graph.

    It has a single entry, ``keys``, which is itself a dictionary. The node
    functions in this file read and write string keys such as ``"question"``,
    ``"documents"``, ``"generation"`` and ``"run_web_search"`` inside it.
    """

    # Free-form payload shared between graph nodes.
    keys: Dict[str, Any]

In [18]:
import json
import operator
from typing import Annotated, Sequence, TypedDict


from langchain import hub
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores.chroma import Chroma
from langchain_core.messages import BaseMessage, FunctionMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

In [19]:
# Retriever view over the vector store; the `retrieve` graph node queries it.
retriever = vectorstore.as_retriever()

In [20]:
def retrieve(state: GraphState):
    """Look up documents for the question stored in the graph state.

    Args:
        state: Current graph state; the query is read from
            ``state["keys"]["question"]``.

    Returns:
        A state update whose ``"keys"`` dict holds the retrieved
        ``"documents"`` together with the unchanged ``"question"``.
    """
    print("---RETRIEVE---")
    query = state["keys"]["question"]
    hits = retriever.get_relevant_documents(query)
    updated_keys = {"documents": hits, "question": query}
    return {"keys": updated_keys}

In [42]:
def generate(state: GraphState):
    """Generate an answer to the question from the documents in the state.

    Args:
        state: Current graph state; reads ``"question"`` and ``"documents"``
            from ``state["keys"]``.

    Returns:
        A state update whose ``"keys"`` dict holds the unchanged
        ``"documents"`` and ``"question"`` plus the model's ``"generation"``.
    """
    print("--GENERATE--")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    prompt = hub.pull("rlm/rag-prompt")
    llm = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")

    # Feed the prompt the plain text of the documents. Previously the raw
    # Document list was passed, so the context contained object reprs and
    # metadata, and the formatting helper defined here was never called.
    context = "\n\n".join(doc.page_content for doc in documents)

    rag_chain = prompt | llm | StrOutputParser()
    generation = rag_chain.invoke({"context": context, "question": question})

    return {
        "keys": {
            "documents": documents,
            "question": question,
            "generation": generation,
        }
    }



def _extract_binary_score(message):
    """Return the ``binary_score`` carried by a tool-calling reply, or ``None``.

    Looks at the ``function`` entry of the first tool call in
    ``message.additional_kwargs`` and tries, in order:

    1. the score stored directly on that entry,
    2. the score nested under ``function`` -> ``parameters``,
    3. the score inside ``arguments``, given either as a dict or as a
       JSON-encoded string.

    ``None`` is returned when there is no tool call or no shape matches.
    """
    try:
        payload = message.additional_kwargs["tool_calls"][0]["function"]
    except (AttributeError, KeyError, IndexError, TypeError):
        return None
    if not isinstance(payload, dict):
        return None

    if "binary_score" in payload:
        return payload["binary_score"]

    try:
        return payload["function"]["parameters"]["binary_score"]
    except (KeyError, TypeError):
        pass

    arguments = payload.get("arguments")
    if isinstance(arguments, str):
        try:
            arguments = json.loads(arguments)
        except ValueError:  # json.JSONDecodeError is a ValueError
            return None
    if isinstance(arguments, dict):
        return arguments.get("binary_score")
    return None


def grade_documents(state: GraphState):
    """Grade each retrieved document for relevance to the question.

    Every document is sent to the model together with the question; the model
    is forced to answer through the ``Grade`` tool. Documents graded "yes" are
    kept. As soon as one document is graded otherwise, the returned state asks
    for a web search.

    Args:
        state: Current graph state; reads ``"question"`` and ``"documents"``
            from ``state["keys"]``.

    Returns:
        A state update whose ``"keys"`` dict holds the filtered
        ``"documents"``, the unchanged ``"question"`` and ``"run_web_search"``
        set to ``"Yes"`` or ``"No"``.
    """
    print("--CHECK RELEVANCE--")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    class Grade(BaseModel):
        """ 
        Binary score for relevance check
        """
        binary_score: str = Field(description="Binary score for relevenance 'yes' or 'no'")

    model = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")

    # Convert the schema once, and force the tool by the name the conversion
    # actually produced so that tool_choice can never drift from the schema.
    grade_tool_oai = convert_to_openai_tool(Grade)
    llm_with_tool = model.bind(
        tools=[grade_tool_oai],
        tool_choice={
            "type": "function",
            "function": {"name": grade_tool_oai["function"]["name"]},
        },
    )

    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question . \n
                    Here is the retrieved document: \n\n {context} \n\n
                    Here is the question: {question} \n
                    If the document contains keyword(s) or sementic meaning related to the user question, grade it as relevant. 
                    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question""",
        input_variables=["context", "question"]
    )

    chain = prompt | llm_with_tool

    filtered_docs = []
    search = "No"

    for d in documents:
        score = chain.invoke({"question": question, "context": d.page_content})
        grade = _extract_binary_score(score)
        if grade is None:
            # Best effort: an unparseable reply keeps the document.
            print("Failed using default")
            grade = "yes"
        # Compare case-insensitively so that "Yes" / "YES " also count.
        if str(grade).strip().lower() == "yes":
            print("--GRADE: DOCUMENT RELEVANT--")
            filtered_docs.append(d)
        else:
            print("--GRADE: DOCUMENT NOT RELEVANT --")
            search = "Yes"

    return {
        "keys": {
            "documents": filtered_docs,
            "question": question,
            "run_web_search": search
        }
    }
    
def transform_query(state: GraphState):
    """Ask the model to rephrase the question for better retrieval.

    Args:
        state: Current graph state; reads ``"question"`` and ``"documents"``
            from ``state["keys"]``.

    Returns:
        A state update whose ``"keys"`` dict holds the unchanged
        ``"documents"`` and the rewritten ``"question"``.
    """
    print("---TRANSFORM QUERY---")
    keys = state["keys"]
    original_question = keys["question"]
    carried_documents = keys["documents"]

    rewrite_prompt = PromptTemplate(
        template="""You are generating questions that is well optimized for retrieval. \n
        Look at the input and try to reason about the underlying semantic intent /meaning. \n
        Here is the initial question:
        \n-----\n
        {question}
        \n --- \n
        Formulate an improved question:
        """,
        input_variables=["question"]
    )

    # prompt -> chat model -> plain string
    rewriter = (
        rewrite_prompt
        | ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
        | StrOutputParser()
    )
    rewritten = rewriter.invoke({"question": original_question})

    return {"keys": {"documents": carried_documents, "question": rewritten}}


def web_search(state:GraphState):
    """Run a Tavily web search for the question and add the result to the documents.

    The ``content`` fields of the search results are joined with newlines into
    one ``Document`` that is appended to the document list taken from the
    state (the list is extended in place).

    Args:
        state: Current graph state; reads ``"question"`` and ``"documents"``
            from ``state["keys"]``.

    Returns:
        A state update whose ``"keys"`` dict holds the extended
        ``"documents"`` and the unchanged ``"question"``.
    """
    print("---WEB SEARCH ---")
    keys = state["keys"]
    question = keys["question"]
    documents = keys["documents"]

    search_tool = TavilySearchResults(max_results=1)
    hits = search_tool.invoke({"query": question})

    snippets = []
    for hit in hits:
        snippets.append(hit["content"])
    documents.append(Document(page_content="\n".join(snippets)))

    return {"keys": {"documents": documents, "question": question}}


def decide_to_generate(state):
    """Pick the next graph node after document grading.

    Args:
        state: Current graph state; reads ``"run_web_search"`` from
            ``state["keys"]``.

    Returns:
        ``"transform_query"`` when the flag is exactly ``"Yes"``, otherwise
        ``"generate"``.
    """
    print("---DECIDE TO GENERATE ---")
    needs_search = state["keys"]["run_web_search"] == "Yes"

    if not needs_search:
        print("generate response")
        return "generate"

    print("transform query and run web search")
    return "transform_query"

#### Build Graph
This just follows the flow we outlined in the figure above.

In [43]:
import pprint
from langgraph.graph import END, StateGraph

# Graph whose nodes exchange a GraphState.
workflow = StateGraph(GraphState)

# Register one node per step; each node is one of the functions defined above.
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search", web_search)


# Wire the nodes: retrieve -> grade_documents, then branch on the string
# returned by decide_to_generate.
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        # return value of decide_to_generate -> name of the next node
        "transform_query": "transform_query",
        "generate": "generate"
    }
)

# Search branch: transform_query -> web_search -> generate; generate ends the run.
workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

# Compile the graph into the runnable application used by the cells below.
app = workflow.compile()

In [44]:
# Initial state: only the question is supplied; the nodes add the other keys.
inputs = {"keys": {
    "question": "Explain how the different types of agent memory work?"
}}

# Stream the run and print the name of each node as its output arrives.
for output in app.stream(inputs):
    for key, value in output.items():
        pprint.pprint(f"Node '{key}'")
    pprint.pprint("\n --- \n")
    
# `value` still refers to the last streamed output, which is expected to
# carry the final "generation".
pprint.pprint(value["keys"]["generation"])

---RETRIEVE---
"Node 'retrieve'"
'\n --- \n'
--CHECK RELEVANCE--
--GRADE: DOCUMENT RELEVANT--
Failed using default
--GRADE: DOCUMENT RELEVANT--
--GRADE: DOCUMENT RELEVANT--
--GRADE: DOCUMENT RELEVANT--
"Node 'grade_documents'"
'\n --- \n'
---DECIDE TO GENERATE ---
generate response
--GENERATE--
"Node 'generate'"
'\n --- \n'
"Node '__end__'"
'\n --- \n'
(' In a LLM-powered autonomous agent system, there are three types of memory: '
 '1) Short-term memory, used for in-context learning. 2) Long-term memory, '
 'which retains and recalls information over extended periods, facilitated by '
 'an external vector store. 3) Additionally, the agent learns to use tools, '
 'such as external APIs, for extra information or capabilities. '
 'Self-reflection is another key aspect, allowing the agent to improve '
 'iteratively by refining past action decisions and correcting previous '
 'mistakes.')


In [45]:
### Correction for a question that is not covered by the indexed context
inputs = {
    "keys": {
        "question": "What is the approach for code generation taken in the AlphaCodium paper?"
    }
}

# Stream the run, printing each node's name together with the state keys it returned.
for output in app.stream(inputs):
    for key, value in output.items():
        pprint.pprint(f"Node '{key}' :")
        pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint.pprint("\n --- \n")

# `value` still refers to the last streamed output, which is expected to
# carry the final "generation".
print("Final output")
pprint.pprint(value["keys"]["generation"])

---RETRIEVE---
"Node 'retrieve' :"
{ 'documents': [ Document(page_content='Chain of Hindsight (CoH; Liu et al. 2023) encourages the model to improve on its own outputs by explicitly presenting it with a sequence of past outputs, each annotated with feedback. Human feedback data is a collection of $D_h = \\{(x, y_i , r_i , z_i)\\}_{i=1}^n$, where $x$ is the prompt, each $y_i$ is a model completion, $r_i$ is the human rating of $y_i$, and $z_i$ is the corresponding human-provided hindsight feedback. Assume the feedback tuples are ranked by reward, $r_n \\geq r_{n-1} \\geq \\dots \\geq r_1$ The process is supervised fine-tuning where the data is a sequence in the form of $\\tau_h = (x, z_i, y_i, z_j, y_j, \\dots, z_n, y_n)$, where $\\leq i \\leq j \\leq n$. The model is finetuned to only predict $y_n$', metadata={'description': 'Building agents with LLM (large language model) as its core controller is a cool concept. Several proof-of-concepts demos, such as AutoGPT, GPT-Engineer and BabyA