# Tutorial 09: Self-RAG (Self-Reflective RAG)

In this tutorial, you'll build a Self-RAG system that **grades its own work**: it checks document relevance, detects hallucinations, and verifies answer quality.

**What you'll learn:**
- **Document Grading**: Filter irrelevant retrieved documents
- **Hallucination Detection**: Check whether answers are grounded in the retrieved documents
- **Answer Grading**: Verify the answer addresses the question
- **Retry Logic**: Re-generate when quality checks fail

By the end, you'll have a RAG system that self-corrects to produce higher-quality outputs.

## Why Self-RAG?

Basic RAG has blind spots:
- Retrieved documents might be irrelevant
- The LLM might hallucinate despite having context
- The answer might not address the actual question

**Self-RAG** adds reflection steps:

```
Retrieve → Grade Docs → Generate → Check Hallucination → Check Answer → Return/Retry
```
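
The flow above can be sketched in plain Python before any graph library is involved. Every callable passed in below is a hypothetical stand-in for the retriever, LLM, and graders built later in this tutorial; only the control flow is the point, and the stubs are deliberately trivial.

```python
from typing import Callable, List

def self_rag_loop(
    question: str,
    retrieve: Callable[[str], List[str]],
    is_relevant: Callable[[str, str], bool],
    generate: Callable[[str, List[str]], str],
    is_grounded: Callable[[List[str], str], bool],
    addresses_question: Callable[[str, str], bool],
    max_retries: int = 2,
) -> str:
    """Retrieve, filter, generate, then re-generate while a check fails."""
    # Retrieve and keep only the documents graded as relevant.
    docs = [d for d in retrieve(question) if is_relevant(d, question)]
    if not docs:
        return "No relevant documents found."

    answer = generate(question, docs)
    for _ in range(max_retries):
        # Stop as soon as both reflection checks pass.
        if is_grounded(docs, answer) and addresses_question(question, answer):
            break
        answer = generate(question, docs)
    # Once the retry budget is spent, the last answer is returned as-is.
    return answer

# --- Demo with stub components (all hypothetical) ---
corpus = ["Self-RAG grades its own outputs.", "Bananas are yellow."]
attempts = []

def generate_stub(question: str, docs: List[str]) -> str:
    attempts.append(1)
    # The first attempt is deliberately ungrounded to exercise the retry path.
    return "Self-RAG is a database." if len(attempts) == 1 else docs[0]

answer = self_rag_loop(
    "What is Self-RAG?",
    retrieve=lambda q: corpus,
    is_relevant=lambda d, q: "Self-RAG" in d,
    generate=generate_stub,
    is_grounded=lambda docs, a: a in docs,
    addresses_question=lambda q, a: "Self-RAG" in a,
)
print(answer)
print(f"Generation attempts: {len(attempts)}")
```

The rest of the tutorial expresses the same loop as a LangGraph state graph, which makes each step a separately inspectable node.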

## Prerequisites

Ensure you've run Tutorial 08 to index documents, or run:
```python
from langgraph_ollama_local.rag import DocumentIndexer
indexer = DocumentIndexer()
indexer.index_directory("sources/")
```

In [None]:
# Setup
from langgraph_ollama_local import LocalAgentConfig
from langchain_ollama import ChatOllama

config = LocalAgentConfig()
llm = ChatOllama(
    model=config.ollama.model,
    base_url=config.ollama.base_url,
    temperature=0,
)

print(f"Using model: {config.ollama.model}")

## Step 1: Define the State

Beyond the question and the retrieved documents, Self-RAG also tracks the filtered documents, the generated answer, and a retry counter:

In [None]:
from typing import List, Literal
from typing_extensions import TypedDict
from langchain_core.documents import Document

class SelfRAGState(TypedDict):
    """State for Self-RAG pipeline."""
    question: str                      # User's question
    documents: List[Document]          # Retrieved documents
    filtered_documents: List[Document] # Relevant documents only
    generation: str                    # Generated answer
    retry_count: int                   # Number of retries
    max_retries: int                   # Maximum retries allowed

print("State schema defined!")
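
Each node in the graph returns only the keys it changes, and those keys are merged into the shared state. The sketch below mimics that merge in plain Python, assuming the default behaviour where a returned key overwrites the existing value. `apply_update` is a hypothetical helper for illustration, not a LangGraph function.

```python
from typing import Any, Dict

def apply_update(state: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new state with a node's partial update applied (hypothetical helper)."""
    return {**state, **update}

# Initial input, as passed to the graph at invocation time.
state = {"question": "What is Self-RAG?", "retry_count": 0, "max_retries": 2}

# Each call stands in for one node returning a partial update.
state = apply_update(state, {"documents": ["doc A", "doc B"]})
state = apply_update(state, {"filtered_documents": ["doc A"]})
state = apply_update(state, {"retry_count": state["retry_count"] + 1})

print(state)
```

Keys a node does not return, such as `question`, are left untouched.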

## Step 2: Create the Graders

We'll create three graders using the LLM:
1. **Document Grader**: Is this document relevant to the question?
2. **Hallucination Grader**: Is the answer grounded in the documents?
3. **Answer Grader**: Does the answer address the question?
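
The grader classes come from `langgraph_ollama_local.rag`, and their internals are not shown in this tutorial. As a rough sketch of the general idea, assume a grader boils down to a yes/no prompt plus tolerant parsing of the reply. `GRADE_PROMPT`, `parse_binary_grade`, `grade_relevance`, and `fake_llm` are all hypothetical names, and `fake_llm` stands in for a real model call.

```python
from typing import Callable

GRADE_PROMPT = (
    "Is the document relevant to the question? Answer 'yes' or 'no'.\n"
    "Document: {document}\n"
    "Question: {question}\n"
    "Answer:"
)

def parse_binary_grade(reply: str) -> bool:
    """Treat a reply as 'yes' only if it starts with yes; anything else counts as no."""
    return reply.strip().lower().startswith("yes")

def grade_relevance(llm: Callable[[str], str], document: str, question: str) -> bool:
    """Format the prompt, call the model, and parse its verdict."""
    prompt = GRADE_PROMPT.format(document=document, question=question)
    return parse_binary_grade(llm(prompt))

def fake_llm(prompt: str) -> str:
    # Hypothetical stand-in for a model: answers "yes" when the question
    # part of the prompt mentions Self-RAG.
    question_part = prompt.split("Question:")[1]
    return "Yes, it is relevant." if "Self-RAG" in question_part else "No."

doc = "Self-RAG uses reflection tokens to grade its own outputs."
print(grade_relevance(fake_llm, doc, "What is Self-RAG?"))
print(grade_relevance(fake_llm, doc, "What is the weather today?"))
```

Defaulting unparseable replies to "no" is a conservative choice: an ambiguous verdict drops the document rather than letting it into the context.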

In [None]:
from langgraph_ollama_local.rag import DocumentGrader, HallucinationGrader, AnswerGrader

# Create graders
doc_grader = DocumentGrader(llm)
hallucination_grader = HallucinationGrader(llm)
answer_grader = AnswerGrader(llm)

print("Graders created!")

In [None]:
# Test the document grader
from langchain_core.documents import Document

test_doc = Document(
    page_content="Self-RAG is a framework that enhances language models with self-reflection capabilities. It retrieves documents on-demand and uses reflection tokens to grade its own outputs."
)

relevant_q = "What is Self-RAG?"
irrelevant_q = "What is the weather today?"

print(f"Document: {test_doc.page_content[:100]}...")
print(f"\nQuestion 1: '{relevant_q}'")
print(f"Relevant: {doc_grader.grade(test_doc, relevant_q)}")
print(f"\nQuestion 2: '{irrelevant_q}'")
print(f"Relevant: {doc_grader.grade(test_doc, irrelevant_q)}")

## Step 3: Create the Retriever

In [None]:
from langgraph_ollama_local.rag import LocalRetriever

retriever = LocalRetriever()

# Test retrieval
test_results = retriever.retrieve("What is Self-RAG?", k=2)
print(f"Retrieved {len(test_results)} documents")
for doc, score in test_results:
    print(f"  Score: {score:.3f}")
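
The test above unpacks each result as a `(document, score)` pair. How `LocalRetriever` computes its scores is not shown here, so the following is only a sketch of one common approach: cosine similarity over embedding vectors, assuming a higher score means a closer match. The two-dimensional vectors and the `top_k` helper are made up for illustration.

```python
import math
from typing import List, Sequence, Tuple

def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors; 0.0 if either has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def top_k(
    query_vec: Sequence[float],
    indexed: List[Tuple[str, Sequence[float]]],
    k: int = 2,
) -> List[Tuple[str, float]]:
    """Score every indexed text against the query and keep the k best."""
    scored = [(text, cosine(query_vec, vec)) for text, vec in indexed]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]

# Toy index with hand-written 2-D "embeddings" (illustrative values only).
indexed = [
    ("Self-RAG overview", [1.0, 0.0]),
    ("Cake recipes", [0.0, 1.0]),
    ("RAG basics", [1.0, 1.0]),
]

results = top_k([1.0, 0.1], indexed, k=2)
for text, score in results:
    print(f"  {text}: {score:.3f}")
```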

## Step 4: Define Node Functions

Now we'll define the nodes for our Self-RAG graph:

In [None]:
def retrieve(state: SelfRAGState) -> dict:
    """Retrieve documents for the question."""
    print("--- RETRIEVE ---")
    question = state["question"]
    docs = retriever.retrieve_documents(question, k=5)
    print(f"Retrieved {len(docs)} documents")
    return {
        "documents": docs,
        "retry_count": state.get("retry_count", 0),
        "max_retries": state.get("max_retries", 3),
    }

print("Retrieve node defined!")

In [None]:
def grade_documents(state: SelfRAGState) -> dict:
    """Grade documents for relevance."""
    print("--- GRADE DOCUMENTS ---")
    question = state["question"]
    documents = state["documents"]
    
    # Grade each document
    relevant, irrelevant = doc_grader.grade_documents(documents, question)
    
    print(f"Relevant: {len(relevant)}, Irrelevant: {len(irrelevant)}")
    return {"filtered_documents": relevant}

print("Grade documents node defined!")
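
The node above relies on `doc_grader.grade_documents` returning a `(relevant, irrelevant)` pair. A minimal sketch of that partitioning step, assuming it amounts to grading each document once and sorting it into one of two lists. `partition_documents` and the keyword-based `is_relevant` stub are hypothetical; the real grader asks the LLM.

```python
from typing import Callable, List, Tuple

def partition_documents(
    docs: List[str],
    question: str,
    is_relevant: Callable[[str, str], bool],
) -> Tuple[List[str], List[str]]:
    """Split docs into (relevant, irrelevant), preserving retrieval order."""
    relevant: List[str] = []
    irrelevant: List[str] = []
    for doc in docs:
        (relevant if is_relevant(doc, question) else irrelevant).append(doc)
    return relevant, irrelevant

docs = [
    "Self-RAG adds reflection steps to RAG.",
    "Chocolate cake needs cocoa powder.",
    "Self-RAG grades retrieved documents.",
]

relevant, irrelevant = partition_documents(
    docs,
    "What is Self-RAG?",
    # Stub grader: a keyword match instead of an LLM call.
    is_relevant=lambda doc, q: "self-rag" in doc.lower(),
)
print(f"Relevant: {len(relevant)}, Irrelevant: {len(irrelevant)}")
```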

In [None]:
from langchain_core.prompts import ChatPromptTemplate

# RAG generation prompt
GENERATE_PROMPT = ChatPromptTemplate.from_template(
    """You are an assistant for question-answering tasks.
Use ONLY the following context to answer the question.
If you cannot answer from the context, say "I cannot answer this from the provided documents."

Context:
{context}

Question: {question}

Answer:"""
)

def generate(state: SelfRAGState) -> dict:
    """Generate answer using filtered documents."""
    print("--- GENERATE ---")
    question = state["question"]
    documents = state["filtered_documents"]
    
    if not documents:
        return {"generation": "I could not find any relevant documents to answer this question."}
    
    # Format context
    context = "\n\n".join([doc.page_content for doc in documents])
    
    # Generate
    messages = GENERATE_PROMPT.format_messages(
        context=context,
        question=question
    )
    response = llm.invoke(messages)
    
    print(f"Generated {len(response.content)} characters")
    return {"generation": response.content}

print("Generate node defined!")

In [None]:
def check_hallucination(state: SelfRAGState) -> dict:
    """Check if generation is grounded in documents."""
    print("--- CHECK HALLUCINATION ---")
    documents = state["filtered_documents"]
    generation = state["generation"]
    
    is_grounded = hallucination_grader.grade(documents, generation)
    print(f"Grounded in facts: {is_grounded}")
    
    return {}  # Just for routing, no state update

def check_answer(state: SelfRAGState) -> dict:
    """Check if answer addresses the question."""
    print("--- CHECK ANSWER ---")
    question = state["question"]
    generation = state["generation"]
    
    is_useful = answer_grader.grade(question, generation)
    print(f"Addresses question: {is_useful}")
    
    return {}  # Just for routing

print("Check nodes defined!")

## Step 5: Define Routing Functions

Self-RAG uses conditional edges to route based on grading results:

In [None]:
def route_after_grading(state: SelfRAGState) -> Literal["generate", "no_docs"]:
    """Route based on whether we have relevant documents."""
    filtered = state.get("filtered_documents", [])
    if not filtered:
        print("No relevant documents found")
        return "no_docs"
    return "generate"

def route_after_hallucination_check(state: SelfRAGState) -> Literal["check_answer", "retry"]:
    """Route based on hallucination check."""
    documents = state["filtered_documents"]
    generation = state["generation"]
    retry_count = state.get("retry_count", 0)
    max_retries = state.get("max_retries", 3)
    
    is_grounded = hallucination_grader.grade(documents, generation)
    
    if is_grounded:
        return "check_answer"
    elif retry_count < max_retries:
        print(f"Hallucination detected, retrying ({retry_count + 1}/{max_retries})")
        return "retry"
    else:
        print("Max retries reached, proceeding anyway")
        return "check_answer"

def route_after_answer_check(state: SelfRAGState) -> Literal["end", "retry"]:
    """Route based on answer quality check."""
    question = state["question"]
    generation = state["generation"]
    retry_count = state.get("retry_count", 0)
    max_retries = state.get("max_retries", 3)
    
    is_useful = answer_grader.grade(question, generation)
    
    if is_useful:
        return "end"
    elif retry_count < max_retries:
        print(f"Answer not useful, retrying ({retry_count + 1}/{max_retries})")
        return "retry"
    else:
        print("Max retries reached, returning current answer")
        return "end"

print("Routing functions defined!")
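
Both retry-aware routers above share the same three-way decision: pass, retry, or give up and accept the current answer. Factoring that decision into a pure function makes the retry budget easy to test without an LLM. `next_step` and its return labels are hypothetical names used only in this sketch.

```python
def next_step(check_passed: bool, retry_count: int, max_retries: int) -> str:
    """Decide what to do after a quality check."""
    if check_passed:
        return "continue"
    if retry_count < max_retries:
        return "retry"
    # Budget exhausted: accept the current answer rather than loop forever.
    return "continue"

# Walk through the cases with max_retries=2.
for passed, count in [(True, 0), (False, 0), (False, 1), (False, 2)]:
    print(f"passed={passed}, retry_count={count} -> {next_step(passed, count, 2)}")
```

With `max_retries=2`, a failing check triggers a retry at counts 0 and 1, and at count 2 the answer is accepted as-is.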

In [None]:
def increment_retry(state: SelfRAGState) -> dict:
    """Increment retry counter."""
    return {"retry_count": state.get("retry_count", 0) + 1}

def handle_no_docs(state: SelfRAGState) -> dict:
    """Handle case when no relevant documents found."""
    return {
        "generation": "I could not find any relevant documents to answer this question. Please try rephrasing or ask about a topic covered in the indexed documents."
    }

print("Helper nodes defined!")

## Step 6: Build the Self-RAG Graph

In [None]:
from langgraph.graph import StateGraph, START, END

# Build the graph
graph_builder = StateGraph(SelfRAGState)

# Add nodes
graph_builder.add_node("retrieve", retrieve)
graph_builder.add_node("grade_documents", grade_documents)
graph_builder.add_node("generate", generate)
graph_builder.add_node("handle_no_docs", handle_no_docs)
graph_builder.add_node("increment_retry", increment_retry)

# Add edges
graph_builder.add_edge(START, "retrieve")
graph_builder.add_edge("retrieve", "grade_documents")

# Conditional edge after grading
graph_builder.add_conditional_edges(
    "grade_documents",
    route_after_grading,
    {
        "generate": "generate",
        "no_docs": "handle_no_docs",
    }
)

# No docs path ends
graph_builder.add_edge("handle_no_docs", END)

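# After generation, check for hallucinations.
# Simplified wiring: a grounded answer goes straight to END, so the
# check_hallucination/check_answer nodes and route_after_answer_check
# defined earlier are not part of this graph.
graph_builder.add_conditional_edges(
    "generate",
    route_after_hallucination_check,
    {
        "check_answer": END,  # Grounded, or out of retries: finish
        "retry": "increment_retry",
    }
)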

# Retry loops back to generate
graph_builder.add_edge("increment_retry", "generate")

# Compile
self_rag_graph = graph_builder.compile()

print("Self-RAG graph compiled!")
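
To see how fixed and conditional edges interact, it can help to walk a graph by hand. The toy executor below is a plain-Python sketch, not how LangGraph is implemented. `run_graph`, the `END` sentinel, and the stub nodes are hypothetical, and the router simply pretends that only the second draft is grounded.

```python
from typing import Any, Callable, Dict, List, Tuple

END = "__end__"  # Sentinel used by this sketch only

State = Dict[str, Any]
Node = Callable[[State], State]
Router = Callable[[State], str]

def run_graph(
    nodes: Dict[str, Node],
    edges: Dict[str, str],
    conditional: Dict[str, Tuple[Router, Dict[str, str]]],
    start: str,
    state: State,
    max_steps: int = 50,
) -> Tuple[State, List[str]]:
    """Run nodes until END, returning the final state and the visit order."""
    current, trace = start, []
    while current != END:
        if len(trace) >= max_steps:
            raise RuntimeError("step limit reached")
        trace.append(current)
        # Merge the node's partial update into the state.
        state = {**state, **nodes[current](state)}
        if current in conditional:
            # The router returns a label; the mapping turns it into a target.
            router, mapping = conditional[current]
            current = mapping[router(state)]
        else:
            current = edges[current]
    return state, trace

nodes = {
    "generate": lambda s: {"generation": f"draft {s['retry_count']}"},
    "increment_retry": lambda s: {"retry_count": s["retry_count"] + 1},
}
edges = {"increment_retry": "generate"}
conditional = {
    "generate": (
        # Stub router: only the second draft counts as grounded.
        lambda s: "check_answer" if s["generation"] == "draft 1" else "retry",
        {"check_answer": END, "retry": "increment_retry"},
    )
}

final_state, trace = run_graph(nodes, edges, conditional, "generate", {"retry_count": 0})
print(trace)
print(final_state)
```

The trace shows the retry loop: one failed generation, one counter increment, then a second generation that passes.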

In [None]:
# Visualize the graph
from IPython.display import Image, display

try:
    display(Image(self_rag_graph.get_graph().draw_mermaid_png()))
except Exception as e:
    print(f"Could not render graph: {e}")
    print(self_rag_graph.get_graph().draw_ascii())

## Step 7: Test Self-RAG

In [None]:
# Test with a relevant question
question = "What is Self-RAG and how does it improve upon traditional RAG?"

print(f"Question: {question}\n")
print("=" * 50)

result = self_rag_graph.invoke({
    "question": question,
    "retry_count": 0,
    "max_retries": 2,
})

print("=" * 50)
print("\nFINAL ANSWER:")
print(result["generation"])
print(f"\nUsed {len(result.get('filtered_documents', []))} relevant documents")
print(f"Retries: {result.get('retry_count', 0)}")

In [None]:
# Test with an irrelevant question (should handle gracefully)
irrelevant_question = "What is the best recipe for chocolate cake?"

print(f"Question: {irrelevant_question}\n")
print("=" * 50)

result2 = self_rag_graph.invoke({
    "question": irrelevant_question,
    "retry_count": 0,
    "max_retries": 2,
})

print("=" * 50)
print("\nFINAL ANSWER:")
print(result2["generation"])

## Complete Self-RAG Implementation

In [None]:
# Condensed Self-RAG implementation: document grading plus a hallucination check with retries

from typing import List, Literal
from typing_extensions import TypedDict
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from langgraph.graph import StateGraph, START, END
from langgraph_ollama_local import LocalAgentConfig
from langgraph_ollama_local.rag import (
    LocalRetriever,
    DocumentGrader,
    HallucinationGrader,
    AnswerGrader,
)

# 1. State
class SelfRAGState(TypedDict):
    question: str
    documents: List[Document]
    filtered_documents: List[Document]
    generation: str
    retry_count: int
    max_retries: int

# 2. Components
config = LocalAgentConfig()
llm = ChatOllama(model=config.ollama.model, base_url=config.ollama.base_url, temperature=0)
retriever = LocalRetriever()
doc_grader = DocumentGrader(llm)
hallucination_grader = HallucinationGrader(llm)

# 3. Prompt
PROMPT = ChatPromptTemplate.from_template(
    """Answer from context only. If unknown, say so.
Context: {context}
Question: {question}
Answer:"""
)

# 4. Nodes
def retrieve(state): 
    return {"documents": retriever.retrieve_documents(state["question"], k=5)}

def grade_docs(state):
    relevant, _ = doc_grader.grade_documents(state["documents"], state["question"])
    return {"filtered_documents": relevant}

def generate(state):
    if not state["filtered_documents"]:
        return {"generation": "No relevant documents found."}
    context = "\n".join([d.page_content for d in state["filtered_documents"]])
    response = llm.invoke(PROMPT.format_messages(context=context, question=state["question"]))
    return {"generation": response.content}

def retry(state):
    return {"retry_count": state.get("retry_count", 0) + 1}

# 5. Routing
def check_and_route(state) -> str:
    if not state["filtered_documents"]:
        return "end"
    is_grounded = hallucination_grader.grade(state["filtered_documents"], state["generation"])
    if is_grounded or state.get("retry_count", 0) >= state.get("max_retries", 2):
        return "end"
    return "retry"

# 6. Build graph
g = StateGraph(SelfRAGState)
g.add_node("retrieve", retrieve)
g.add_node("grade", grade_docs)
g.add_node("generate", generate)
g.add_node("retry", retry)
g.add_edge(START, "retrieve")
g.add_edge("retrieve", "grade")
g.add_edge("grade", "generate")
g.add_conditional_edges("generate", check_and_route, {"end": END, "retry": "retry"})
g.add_edge("retry", "generate")

self_rag = g.compile()

# 7. Use it
result = self_rag.invoke({"question": "What is Self-RAG?", "retry_count": 0, "max_retries": 2})
print(result["generation"])

## Key Concepts Recap

| Component | Purpose |
|-----------|--------|
| **DocumentGrader** | Filters irrelevant retrieved documents |
| **HallucinationGrader** | Checks whether the answer is grounded in the retrieved documents |
| **AnswerGrader** | Verifies answer addresses the question |
| **Retry Logic** | Re-generates when quality checks fail |
| **Conditional Edges** | Routes based on grading results |

## Self-RAG vs Basic RAG

| Aspect | Basic RAG | Self-RAG |
|--------|-----------|----------|
| Document filtering | None | Graded for relevance |
| Hallucination check | None | LLM-based verification |
| Answer quality | Assumed | Verified |
| Error recovery | None | Retry mechanism |

## What's Next?

In [Tutorial 10: CRAG](10_crag.ipynb), you'll learn:
- Web search as a fallback when retrieval fails
- Combining multiple knowledge sources
- Corrective retrieval strategies