# Self-RAG

Self-RAG is a strategy for RAG that incorporates self-reflection / self-grading on retrieved documents and generations.

In the [paper](https://arxiv.org/abs/2310.11511), a few decisions are made:

1. Should I retrieve from the retriever `R`?
   - Input: `x (question)` OR `x (question)`, `y (generation)`
   - Decides when to retrieve `D` chunks with `R`
   - Output: `yes`, `no`, `continue`
2. Are the retrieved passages `D` relevant to the question `x`?
   - Input: (`x (question)`, `d (chunk)`) for `d` in `D`
   - `d` provides useful information to solve `x`
   - Output: `relevant`, `irrelevant`
3. Is the LLM generation from each chunk in `D` supported by that chunk (hallucinations, etc.)?
   - Input: `x (question)`, `d (chunk)`, `y (generation)` for `d` in `D`
   - All of the verification-worthy statements in `y (generation)` are supported by `d`
   - Output: `{fully supported, partially supported, no support}`
4. Is the LLM generation from each chunk in `D` a useful response to `x (question)`?
   - Input: `x (question)`, `y (generation)` for `d` in `D`
   - `y (generation)` is a useful response to `x (question)`
   - Output: `{5, 4, 3, 2, 1}`
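
The four decisions above can be written down as typed labels. This is a minimal sketch, not the paper's implementation: the names `Reflection` and `is_acceptable`, and the usefulness threshold of 4, are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Literal

# Label sets copied from the list above.
Retrieve = Literal["yes", "no", "continue"]
Relevance = Literal["relevant", "irrelevant"]
Support = Literal["fully supported", "partially supported", "no support"]
Usefulness = Literal[1, 2, 3, 4, 5]


@dataclass(frozen=True)
class Reflection:
    """Hypothetical container for the grades attached to one (x, d, y) triple."""
    relevance: Relevance
    support: Support
    usefulness: Usefulness


def is_acceptable(r: Reflection, min_usefulness: int = 4) -> bool:
    """Accept a generation only if its chunk is relevant, the text is fully
    supported, and the usefulness score meets an assumed threshold."""
    return (
        r.relevance == "relevant"
        and r.support == "fully supported"
        and r.usefulness >= min_usefulness
    )


good = Reflection("relevant", "fully supported", 5)
weak = Reflection("relevant", "partially supported", 5)
print(is_acceptable(good), is_acceptable(weak))
```

The graders built in this notebook simplify these label sets to binary `yes` / `no` scores.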

We will implement some of these ideas from scratch using LangGraph.

<img src='./images/self_rag.png'>



## Setup

First, let's install the required packages and set our API keys.

In [None]:
pip install langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph


In [None]:
import os
from google.colab import userdata

# API KEY
OPENAI_API_KEY = userdata.get('openai')
ANTHROPIC_API_KEY = userdata.get('anthropic')
TAVILY_API_KEY = userdata.get('tavily')

os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
os.environ['ANTHROPIC_API_KEY'] = ANTHROPIC_API_KEY
os.environ['TAVILY_API_KEY'] = TAVILY_API_KEY

# Set up LangSmith observability
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ['LANGCHAIN_ENDPOINT'] = "https://api.smith.langchain.com"
os.environ['LANGCHAIN_API_KEY'] = userdata.get('langsmith')
os.environ['LANGCHAIN_PROJECT'] = "pr-stupendous-hood-8"

## Retriever

Let's index 3 blog posts.

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

# Add to vectorDB
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name='rag-chroma',
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()

## LLMs

In [None]:
### Retriever Grader

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

from pydantic import BaseModel, Field

# Data model
class GradeDocuments(BaseModel):
  """Binary score for relevance check on retrieved documents."""
  binary_score: str = Field(
      description="Documents are relevant to the question, 'yes' or 'no'"
  )

# LLM with function call
llm = ChatOpenAI(model='gpt-4o-mini', temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Prompt
system = """You are a grader assessing relevance of a retrieved document to a user question. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}."),
    ]
)

retrieval_grader = grade_prompt | structured_llm_grader
question = "agent memory"
docs = retriever.invoke(question)
doc_txt = docs[1].page_content
print(retrieval_grader.invoke({"document": doc_txt, "question": question}))

In [None]:
### Generate

from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# Prompt
prompt = hub.pull("rlm/rag-prompt")

# LLM
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Post-processing
def format_docs(docs):
  return "\n\n".join(doc.page_content for doc in docs)

# Chain
rag_chain = prompt | llm | StrOutputParser()

# Run
generation = rag_chain.invoke({"context": format_docs(docs), "question": question})
print(generation)

In [None]:
### Hallucination grader

# Data model
class GradeHallucinations(BaseModel):
  """Binary score for hallucination present in generated answer."""

  binary_score: str = Field(
      description="Answer is grounded in the facts, 'yes' or 'no'"
  )

# LLM with function call
llm = ChatOpenAI(model='gpt-4o-mini', temperature=0)
structured_llm_grader = llm.with_structured_output(GradeHallucinations)

# Prompt
system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
hallucination_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}."),
    ]
)

hallucination_grader = hallucination_prompt | structured_llm_grader
hallucination_grader.invoke({"documents": docs, "generation": generation})

In [None]:
### Answer Grader

# Data model
class GradeAnswer(BaseModel):
  """Binary score to assess whether the answer addresses the question."""

  binary_score: str = Field(
      description="Answer addresses the question, 'yes' or 'no'"
  )

# LLM with function call
llm = ChatOpenAI(model='gpt-4o-mini', temperature=0)
structured_llm_grader = llm.with_structured_output(GradeAnswer)

# Prompt
system = """You are a grader assessing whether an answer addresses / resolves a question. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
    ]
)

answer_grader = answer_prompt | structured_llm_grader
answer_grader.invoke({"question": question, "generation": generation})

In [None]:
### Question Re-writer

# LLM
llm = ChatOpenAI(model='gpt-4o-mini', temperature=0)

# Prompt
system = """You are a question re-writer that converts an input question to a better version that is optimized \n
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "Here is the initial question: \n\n{question} \n Formulate an improved question."
        ),
    ]
)

question_rewriter = re_write_prompt | llm | StrOutputParser()
question_rewriter.invoke({"question": question})

## Graph

Capture the flow as a graph.

### Graph state

In [None]:
from typing import List

from typing_extensions import TypedDict

class GraphState(TypedDict):
  """
  Represents the state of our graph.

  Attributes:
    question: question
    generation: LLM generation
    documents: list of documents
  """
  question: str
  generation: str
  documents: List[str]
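
Each node below returns a dict holding only the keys it wants to set. As far as we understand LangGraph's default behaviour for a `TypedDict` state without reducers, returned keys overwrite the existing values and all other keys carry over. The following plain-Python approximation shows that merge; the helper `apply_update` is hypothetical and not part of LangGraph.

```python
from typing import List, TypedDict


class GraphState(TypedDict, total=False):
    question: str
    generation: str
    documents: List[str]


def apply_update(state: GraphState, update: dict) -> GraphState:
    """Hypothetical helper: overwrite the keys a node returned, keep the rest."""
    merged = dict(state)
    merged.update(update)
    return merged  # type: ignore[return-value]


state: GraphState = {"question": "agent memory"}
# Each call stands in for the dict returned by one node.
state = apply_update(state, {"documents": ["chunk A", "chunk B"]})  # like retrieve
state = apply_update(state, {"documents": ["chunk A"]})             # like grade_documents
state = apply_update(state, {"generation": "Agents use ..."})       # like generate
print(state)
```

This is why a node such as `grade_documents` can shrink the document list simply by returning a shorter one under the same key.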


In [None]:
### Nodes

def retrieve(state):
  """
  Retrieve documents

  Args:
    state (dict): The current graph state

  Returns:
    state (dict): New key added to state, documents, that contains retrieved documents
  """
  print("---RETRIEVE---")
  question = state["question"]

  # Retrieval
  documents = retriever.invoke(question)
  return {"documents": documents, "question": question}

def generate(state):
  """
  Generate answer

  Args:
    state (dict): The current graph state

  Returns:
    state (dict): New key added to state, generation, that contains LLM generation
  """
  print("---GENERATE---")
  question = state["question"]
  documents = state["documents"]

  # RAG generation (join the chunk texts into a single context string)
  generation = rag_chain.invoke({"context": format_docs(documents), "question": question})
  return {"documents": documents, "generation": generation, "question": question}

def grade_documents(state):
  """
  Determine whether the retrieved documents are relevant to the question

  Args:
    state (dict): The current graph state
  Returns:
    state (dict): Updates documents key with only filtered relevant documents
  """

  print("---CHECK DOCUMENT RELEVANT TO QUESTION---")
  question = state["question"]
  documents = state["documents"]

  # Score each doc
  filtered_docs = []
  for d in documents:
    score = retrieval_grader.invoke(
        {"question": question, "document": d.page_content}
    )
    grade = score.binary_score
    if grade == "yes":
      print("---GRADE: DOCUMENT RELEVANT---")
      filtered_docs.append(d)
    else:
      print("---GRADE: DOCUMENT NOT RELEVANT---")
      continue

  return {"documents": filtered_docs, "question": question}

def transform_query(state):
  """
  Transform the query to produce a better question.

  Args:
    state (dict): The current graph state
  Returns:
    state (dict): Updates question key with a re-phrased question
  """

  print("---TRANSFORM QUERY---")
  question = state["question"]
  documents = state["documents"]

  # Re-write question
  better_question = question_rewriter.invoke({"question": question})
  return {"question": better_question, "documents": documents}

### Edges

def decide_to_generate(state):
  """
  Determines whether to generate an answer, or re-generate a question.

  Args:
    state (dict): The current graph state

  Returns:
    str: Binary decisions for next node to call
  """

  print("---ASSESS GRADED DOCUMENTS---")
  filtered_documents = state["documents"]

  if not filtered_documents:
    # All documents were filtered out by the relevance check,
    # so we rewrite the query and retrieve again
    print(
        "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
    )
    return "transform_query"

  else:
    # We have relevant documents, so generate answer
    print("---DECISION: GENERATE---")
    return "generate"

def grade_gneration_v_documents_and_question(state):
  """
  Determines whether the generation is grounded in the documents and answers the question.

  Args:
    state (dict): The current graph state

  Returns:
    str: Binary decisions for next node to call
  """

  print("---CHECK HALLUCINATIONS---")
  question = state["question"]
  documents = state["documents"]
  generation = state["generation"]

  score = hallucination_grader.invoke(
      {"documents": documents, "generation": generation}
  )
  grade = score.binary_score

  # Check hallucination
  if grade == "yes":
    print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
    # Check question-answering
    print("---GRADE GENERATION vs QUESTION---")
    score = answer_grader.invoke({"question": question, "generation": generation})
    grade = score.binary_score
    if grade == "yes":
      print("---DECISION: GENERATION ADDRESSES QUESTION---")
      return "useful"
    else:
      print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
      return "not useful"

  else:
    print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
    return "not supported"
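
To check the routing without calling any model, the two-stage decision in the last edge can be reduced to a pure function. This is a sketch under assumptions: `route_generation` is a hypothetical name, and the boolean arguments stand in for the `yes` / `no` scores of the hallucination grader and the answer grader.

```python
def route_generation(grounded: bool, answers_question: bool) -> str:
    """Mirror the generation edge: check grounding first, then usefulness."""
    if not grounded:
        # The answer grader is never consulted for an ungrounded generation.
        return "not supported"
    return "useful" if answers_question else "not useful"


# Enumerate every combination of the two grader results.
for grounded in (True, False):
    for answers in (True, False):
        print(grounded, answers, "->", route_generation(grounded, answers))
```

Only three labels are possible, and each one has to appear as a key in the mapping given to `add_conditional_edges` when the graph is built.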

## Build Graph

This just follows the flow we outlined in the figure above.

In [None]:
from langgraph.graph import StateGraph, START, END

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)   # retrieve
workflow.add_node("grade_documents", grade_documents) # grade documents
workflow.add_node("generate", generate) # generate
workflow.add_node("transform_query", transform_query) # transform query

# Define the edges
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    }
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
    "generate",
    grade_gneration_v_documents_and_question,
    {
        "not supported": "generate",
        "useful": END,
        "not useful": "transform_query",
    }
)

# Compile
app = workflow.compile()
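
The graph contains two loops (`transform_query` back to `retrieve`, and `generate` back to itself), so a run can cycle. The plain-Python walk-through below traces the same topology with scripted decisions and a step cap. It is only an illustration: `FIXED`, `ROUTES`, `walk`, and the scripted decision list are hypothetical and do not use LangGraph.

```python
from typing import Dict, Iterator, List

# Unconditional edges from the graph definition above.
FIXED: Dict[str, str] = {
    "START": "retrieve",
    "retrieve": "grade_documents",
    "transform_query": "retrieve",
}

# Conditional edges: node -> {decision label -> next node}.
ROUTES: Dict[str, Dict[str, str]] = {
    "grade_documents": {
        "transform_query": "transform_query",
        "generate": "generate",
    },
    "generate": {
        "not supported": "generate",
        "useful": "END",
        "not useful": "transform_query",
    },
}


def walk(decisions: Iterator[str], max_steps: int = 10) -> List[str]:
    """Follow the edges, pulling one scripted decision per conditional node."""
    node: str = "START"
    visited: List[str] = []
    for _ in range(max_steps):
        if node in FIXED:
            node = FIXED[node]
        else:
            node = ROUTES[node][next(decisions)]
        if node == "END":
            return visited
        visited.append(node)
    raise RuntimeError(f"no END reached within {max_steps} steps")


# The first retrieval finds nothing relevant; the rewritten query then succeeds.
path = walk(iter(["transform_query", "generate", "useful"]))
print(" -> ".join(path))
```

LangGraph applies its own cap on the number of steps through the `recursion_limit` entry of the config dict, for example `app.invoke(inputs, {"recursion_limit": 10})`.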

In [None]:
from pprint import pprint

# Run
inputs = {"question": "Explain how the different types of agent memory work?"}
for output in app.stream(inputs):
  for key, value in output.items():
    # Nodes
    pprint(f"Node '{key}':")
    # Optional: print full state at each node
    # pprint(value, indent=2, width=80, depth=None)

# Final generation
pprint(value["generation"])