In [None]:
! pip install langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph tavily-python



# Corrective RAG (CRAG)

Self-reflection can enhance RAG by enabling correction of poor-quality retrieval or generation.

Several recent papers focus on this theme, but implementing the ideas can be tricky.

Here we show how to implement ideas from the `Corrective RAG (CRAG)` paper [here](https://arxiv.org/pdf/2401.15884.pdf) using LangGraph.

## Dependencies

Set `GOOGLE_API_KEY`

Set `TAVILY_API_KEY` to enable web search [here](https://app.tavily.com/sign-in)

## CRAG Detail

Corrective-RAG (CRAG) is a recent paper that introduces an interesting approach for self-reflective RAG.

The framework grades retrieved documents relative to the question:

1. Correct documents -

* If at least one document exceeds the threshold for relevance, then it proceeds to generation
* Before generation, it performs knowledge refinement
* This partitions the document into "knowledge strips"
* It grades each strip, and filters out irrelevant ones

2. Ambiguous or incorrect documents -

* If all documents fall below the relevance threshold or if the grader is unsure, then the framework seeks an additional datasource
* It will use web search to supplement retrieval
* The diagrams in the paper also suggest that query re-writing is used here
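
The grading logic above can be sketched as a small routing function. This is a minimal illustration, not the paper's implementation: the two-threshold scheme for detecting an "unsure" grader, the threshold values, and the name `select_action` are all assumptions made for the example.

```python
from typing import List

# Hypothetical thresholds; the values are illustrative and not taken from the paper.
UPPER_THRESHOLD = 0.6
LOWER_THRESHOLD = 0.2


def select_action(
    scores: List[float],
    upper: float = UPPER_THRESHOLD,
    lower: float = LOWER_THRESHOLD,
) -> str:
    """Map per-document relevance scores to one of three actions."""
    if not scores:
        return "incorrect"  # nothing retrieved: fall back to web search
    best = max(scores)
    if best > upper:
        return "correct"  # at least one document clears the relevance threshold
    if best < lower:
        return "incorrect"  # every document falls below the lower threshold
    return "ambiguous"  # grader is unsure: supplement retrieval with web search


print(select_action([0.1, 0.9, 0.3]))  # correct
print(select_action([0.05, 0.1]))      # incorrect
print(select_action([0.3, 0.5]))       # ambiguous
```

Only the "correct" branch goes straight to knowledge refinement and generation; the other two branches bring in web search.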

![Screenshot 2024-02-04 at 2.50.32 PM.png](attachment:5bfa38a2-78a1-4e99-80a2-d98c8a440ea2.png)

---

Let's implement some of these ideas from scratch using [LangGraph](https://python.langchain.com/docs/langgraph).

## Retriever

Let's index 3 blog posts.

In [None]:
!pip install google-generativeai

In [None]:
!pip install langchain_google_genai

Collecting langchain_google_genai
  Downloading langchain_google_genai-0.0.9-py3-none-any.whl (17 kB)
Installing collected packages: langchain_google_genai
Successfully installed langchain_google_genai-0.0.9


In [None]:
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
import google.generativeai as genai
import chromadb.utils.embedding_functions as embedding_functions
from getpass import getpass

from langchain_google_genai import GoogleGenerativeAIEmbeddings




if "GOOGLE_API_KEY" not in os.environ:
    os.environ["GOOGLE_API_KEY"] = getpass("APi_KEY")


urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)


embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

# Add to vectorDB
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=embeddings,
)
retriever = vectorstore.as_retriever()

APi_KEY··········


## State

We will define a graph.

Our state will be a `dict`.

We can access this from any graph node as `state['keys']`.

In [None]:
from typing import Any, Dict, TypedDict


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        keys: A dictionary where each key is a string.
    """

    # `Any` is the typing construct; the builtin `any` is a function, not a type
    keys: Dict[str, Any]
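
As a quick illustration of this state convention, the sketch below shows a node reading from `state["keys"]` and returning an updated state. The node `add_note` and the `note` key are hypothetical and exist only for this example; the `GraphState` definition is repeated so the snippet runs on its own.

```python
from typing import Any, Dict, TypedDict


class GraphState(TypedDict):
    keys: Dict[str, Any]


def add_note(state: GraphState) -> GraphState:
    """Hypothetical node: reads the question and returns an updated state."""
    state_dict = state["keys"]
    question = state_dict["question"]
    # Return a fresh dict that carries the old value forward plus one new key.
    return {"keys": {"question": question, "note": f"seen: {question}"}}


state: GraphState = {"keys": {"question": "What is agent memory?"}}
new_state = add_note(state)
print(new_state["keys"]["note"])  # seen: What is agent memory?
```

Every node in this notebook follows the same pattern: unpack `state["keys"]`, do its work, and return a new `{"keys": {...}}` dict that includes the values later nodes will need.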

## Nodes and Edges

Each `node` will simply modify the `state`.

Each `edge` will choose which `node` to call next.

We can make some simplifications from the paper:

* Let's skip the knowledge refinement phase as a first pass. This can be added back as a node, if desired.
* If *any* document is irrelevant, let's opt to supplement retrieval with web search.
* We'll use [Tavily Search](https://python.langchain.com/docs/integrations/tools/tavily_search) for web search.
* Let's use query re-writing to optimize the query for web search.
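
For reference, the skipped knowledge refinement step could look roughly like the sketch below. It makes simplifying assumptions: strips are sentences, and the grader is a hypothetical keyword-overlap score standing in for an LLM or a trained evaluator. The names `split_into_strips`, `keyword_overlap`, and `refine`, and the threshold value, are illustrative.

```python
import re
from typing import Callable, List


def split_into_strips(text: str) -> List[str]:
    """Partition a document into sentence-level strips (a simplifying assumption)."""
    parts = re.split(r"(?<=[.!?])\s+", text.strip())
    return [p for p in parts if p]


def keyword_overlap(question: str, strip: str) -> float:
    """Hypothetical stand-in grader: fraction of question words found in the strip."""
    q_words = set(re.findall(r"\w+", question.lower()))
    s_words = set(re.findall(r"\w+", strip.lower()))
    return len(q_words & s_words) / len(q_words) if q_words else 0.0


def refine(
    question: str,
    text: str,
    grader: Callable[[str, str], float] = keyword_overlap,
    threshold: float = 0.3,
) -> str:
    """Grade each strip and keep only those at or above the threshold."""
    strips = split_into_strips(text)
    kept = [s for s in strips if grader(question, s) >= threshold]
    return " ".join(kept)


doc = (
    "Agents use short-term memory for context. "
    "The weather was nice. "
    "Long-term memory uses a vector store."
)
print(refine("agent memory", doc))
# Agents use short-term memory for context. Long-term memory uses a vector store.
```

To add this back to the graph, a node could apply `refine` to each document's `page_content` before generation.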

Here is our graph flow:

![Screenshot 2024-02-04 at 1.32.52 PM.png](attachment:3b65f495-5fc4-497b-83e2-73844a97f6cc.png)

In [None]:
!pip install langchain


In [None]:
!pip install langchain_openai



In [None]:
import json
import operator
import os
from typing import Annotated, Sequence, TypedDict
import google.generativeai as genai

# Read the key set in the Retriever section rather than hard-coding it
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

from langchain import hub
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import Chroma
from langchain_core.messages import BaseMessage, FunctionMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

### Nodes ###


def retrieve(state):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = retriever.get_relevant_documents(question)
    return {"keys": {"documents": documents, "question": question}}


def generate(state):
    """
    Generate answer

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """
    print("---GENERATE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

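    # LLM: use the LangChain chat wrapper so it composes with `|` as a Runnable.
    # The raw `genai.GenerativeModel` is not a Runnable and raises a TypeError in a chain.
    from langchain_google_genai import ChatGoogleGenerativeAI

    llm = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0)

    # Post-processing
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Chain
    rag_chain = prompt | llm | StrOutputParser()

    # Run: pass the documents as one formatted string
    generation = rag_chain.invoke(
        {"context": format_docs(documents), "question": question}
    )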
    return {
        "keys": {"documents": documents, "question": question, "generation": generation}
    }


def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with relevant documents
    """

    print("---CHECK RELEVANCE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # LLM: use the LangChain chat wrapper so it composes with `|` as a Runnable.
    # The raw `genai.GenerativeModel` is not a Runnable and raises a TypeError in a chain.
    from langchain_google_genai import ChatGoogleGenerativeAI

    model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0)

    # Prompt
    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
        Give a binary score, 'yes' or 'no', to indicate whether the document is relevant to the question. \n
        Answer with the single word 'yes' or 'no' and nothing else.""",
        input_variables=["context", "question"],
    )

    # Chain: no tool is bound to the model, so the tool-calling parser is replaced
    # by a plain string parser and the grade is read from the text reply.
    chain = prompt | model | StrOutputParser()

    # Score
    filtered_docs = []
    search = "No"  # Default: do not opt for web search to supplement retrieval
    for d in documents:
        score = chain.invoke({"question": question, "context": d.page_content})
        grade = score.strip().lower()
        if grade.startswith("yes"):
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            search = "Yes"  # Perform web search

    return {
        "keys": {
            "documents": filtered_docs,
            "question": question,
            "run_web_search": search,
        }
    }


def transform_query(state):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """

    print("---TRANSFORM QUERY---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Create a prompt template with format instructions and the query
    prompt = PromptTemplate(
        template="""You are generating a question that is well optimized for retrieval. \n
        Look at the input and try to reason about the underlying semantic intent / meaning. \n
        Here is the initial question:
        \n ------- \n
        {question}
        \n ------- \n
        Formulate an improved question: """,
        input_variables=["question"],
    )

    # LLM: use the LangChain chat wrapper so it composes with `|` as a Runnable.
    # The raw `genai.GenerativeModel` is not a Runnable and raises a TypeError in a chain.
    from langchain_google_genai import ChatGoogleGenerativeAI

    model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0)

    # Chain
    chain = prompt | model | StrOutputParser()
    better_question = chain.invoke({"question": question})

    return {"keys": {"documents": documents, "question": better_question}}


def web_search(state):
    """
    Web search based on the re-phrased question using Tavily API.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with appended web results
    """

    print("---WEB SEARCH---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    tool = TavilySearchResults()
    docs = tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    documents.append(web_results)

    return {"keys": {"documents": documents, "question": question}}


### Edges


def decide_to_generate(state):
    """
    Determines whether to generate an answer or re-generate a question for web search.

    Args:
        state (dict): The current state of the agent, including all keys.

    Returns:
        str: Next node to call
    """

    print("---DECIDE TO GENERATE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    filtered_documents = state_dict["documents"]
    search = state_dict["run_web_search"]

    if search == "Yes":
        # At least one document was graded as not relevant,
        # so we re-write the query and supplement retrieval with web search
        print("---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---")
        return "transform_query"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "generate"

## Build Graph

This just follows the flow we outlined in the figure above.

In [None]:
!pip install langgraph



In [None]:
# No install is needed here: GraphState is the TypedDict defined in the State section, not a package this notebook depends on.


In [None]:
import pprint

from langgraph.graph import END, StateGraph


workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate", generate)  # generate
workflow.add_node("transform_query", transform_query)  # transform_query
workflow.add_node("web_search", web_search)  # web search

# Build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

# Compile
app = workflow.compile()
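
To sanity-check this wiring without API keys, the routing can be traced in plain Python. The sketch below does not call LangGraph; `run_flow` and `relevant_flags` are hypothetical names, and the per-document grades are stubbed as booleans.

```python
from typing import List


def run_flow(relevant_flags: List[bool]) -> List[str]:
    """Return the ordered node names visited for the given (stubbed) grades."""
    trace = ["retrieve", "grade_documents"]
    # Mirrors grade_documents + decide_to_generate:
    # any document graded not relevant triggers the web search branch.
    run_web_search = "Yes" if not all(relevant_flags) else "No"
    if run_web_search == "Yes":
        trace += ["transform_query", "web_search"]
    trace.append("generate")
    return trace


print(run_flow([True, True]))
# ['retrieve', 'grade_documents', 'generate']
print(run_flow([True, False]))
# ['retrieve', 'grade_documents', 'transform_query', 'web_search', 'generate']
```

Note that an empty retrieval result (`run_flow([])`) goes straight to `generate`, because the default in `grade_documents` is not to search. That edge case may be worth handling explicitly.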

In [None]:
# Run
inputs = {"keys": {"question": "Explain how the different types of agent memory work?"}}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint.pprint(f"Node '{key}':")
        # Optional: print full state at each node
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint.pprint("\n---\n")

# Final generation
pprint.pprint(value['keys']['generation'])

---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK RELEVANCE---


TypeError: Expected a Runnable, callable or dict.Instead got an unsupported type: <class 'google.generativeai.generative_models.GenerativeModel'>

In [None]:
# Correction for question not present in context
inputs = {"keys": {"question": "What is the approach for code generation taken in the AlphaCodium paper?"}}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint.pprint(f"Node '{key}':")
        # Optional: print full state
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint.pprint("\n---\n")

# Final generation
pprint.pprint(value['keys']['generation'])

---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK RELEVANCE---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
"Node 'grade_documents':"
'\n---\n'
---DECIDE TO GENERATE---
---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---
---TRANSFORM QUERY---
"Node 'transform_query':"
'\n---\n'
---WEB SEARCH---
"Node 'web_search':"
'\n---\n'
---GENERATE---
"Node 'generate':"
'\n---\n'
"Node '__end__':"
'\n---\n'
('The AlphaCodium paper uses a test-based, iterative approach for code '
 'generation. It employs a multi-stage, code-oriented flow that addresses the '
 'specific challenges of coding problems. Unlike traditional models, '
 'AlphaCodium actively engages in problem self-reflection, reasoning, and '
 'iterative code solution generation.')
