# Tutorial 10: CRAG (Corrective RAG)

In this tutorial, you'll build a **Corrective RAG** system that uses web search as a fallback when local document retrieval is insufficient.

**What you'll learn:**
- **Knowledge Assessment**: Determine if retrieved documents are sufficient
- **Web Search Fallback**: Use external search when needed
- **Knowledge Fusion**: Combine local and web sources
- **Corrective Flow**: Fix retrieval failures dynamically

By the end, you'll have a RAG system that falls back to web search instead of giving up when the local corpus doesn't contain the answer.

## Why CRAG?

Self-RAG grades retrieved documents, but it has nowhere else to look when retrieval comes up short. **CRAG** adds a correction step:

```
                    ┌─── Local docs sufficient ───▶ Generate
Retrieve → Grade ──┤
                    └─── Insufficient ───▶ Web Search ───▶ Generate
```

This is particularly useful when:
- Your document corpus doesn't cover the topic
- The user asks about recent events
- You need to supplement local knowledge
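
The two branches in the diagram above can be sketched in plain Python before any graph framework is involved. This is a minimal illustration, not the tutorial's implementation: `corrective_answer_context` and the three `stub_*` callables are hypothetical stand-ins for the retriever, grader, and search tool.

```python
from typing import Callable, List


def corrective_answer_context(
    question: str,
    retrieve: Callable[[str], List[str]],
    is_relevant: Callable[[str, str], bool],
    search_web: Callable[[str], List[str]],
) -> dict:
    """Return the context to generate from, and where it came from."""
    relevant = [doc for doc in retrieve(question) if is_relevant(question, doc)]
    if relevant:
        # Local docs are sufficient: skip the web entirely.
        return {"source": "local", "context": relevant}
    # Insufficient local knowledge: correct the retrieval with a web search.
    return {"source": "web", "context": search_web(question)}


# Hypothetical stubs standing in for the retriever, grader, and search tool.
def stub_retrieve(question: str) -> List[str]:
    return ["Notes on sourdough baking."]


def stub_is_relevant(question: str, doc: str) -> bool:
    return "RAG" in doc


def stub_search_web(question: str) -> List[str]:
    return [f"Web result for: {question}"]


result = corrective_answer_context(
    "What is RAG?", stub_retrieve, stub_is_relevant, stub_search_web
)
print(result["source"], len(result["context"]))
```

Passing the components in as callables keeps the control flow testable without an LLM or network access.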

## Prerequisites

For web search, we'll use Tavily (free tier available) or DuckDuckGo:
```bash
pip install tavily-python duckduckgo-search
```

Set your Tavily API key (optional; without it, the search tool falls back to DuckDuckGo and then to mock results):
```bash
export TAVILY_API_KEY="your-key-here"
```

In [None]:
# Setup
import os
from langgraph_ollama_local import LocalAgentConfig
from langchain_ollama import ChatOllama

config = LocalAgentConfig()
llm = ChatOllama(
    model=config.ollama.model,
    base_url=config.ollama.base_url,
    temperature=0,
)

# Check for Tavily API key
has_tavily = bool(os.environ.get("TAVILY_API_KEY"))
print(f"Using model: {config.ollama.model}")
print(f"Tavily API available: {has_tavily}")

## Step 1: Define the State

In [None]:
from typing import List, Literal
from typing_extensions import TypedDict
from langchain_core.documents import Document

class CRAGState(TypedDict):
    """State for Corrective RAG pipeline."""
    question: str                      # User's question
    documents: List[Document]          # Retrieved local documents
    web_results: List[Document]        # Web search results
    combined_documents: List[Document] # Merged documents for generation
    knowledge_source: str              # "local", "web", or "combined"
    generation: str                    # Final answer

print("State defined!")
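
Each node returns only the keys it changes, and LangGraph merges that partial update into the state. For keys declared without a reducer, as in `CRAGState`, the returned value replaces the previous one. The plain-dict sketch below imitates that behaviour; `apply_update` is a hypothetical helper written for this illustration, not a LangGraph API.

```python
from typing import Any, Dict


def apply_update(state: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: merge a node's partial update, last write wins."""
    return {**state, **update}


# Only "question" is supplied at invocation time.
state = {"question": "What is CRAG?"}

# Retrieval step writes the raw local documents.
state = apply_update(state, {"documents": ["doc-a", "doc-b"]})

# Grading step keeps the relevant subset and records the decision.
state = apply_update(
    state, {"combined_documents": ["doc-a"], "knowledge_source": "combined"}
)

# Web search step overwrites combined_documents with the merged list.
state = apply_update(state, {"combined_documents": ["doc-a", "web-1"]})

print(state["combined_documents"])
print("generation" in state)
```

Keys that no step has written yet are simply absent from the dict, which is one reason to read optional keys with `.get`.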

## Step 2: Create Web Search Tool

In [None]:
def web_search(query: str, max_results: int = 3) -> List[Document]:
    """
    Search the web for information.
    Tries Tavily when TAVILY_API_KEY is set, then DuckDuckGo, and finally
    returns a mock result so the tutorial still runs if both fail.
    """
    # Try Tavily first
    if os.environ.get("TAVILY_API_KEY"):
        try:
            from tavily import TavilyClient
            client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
            response = client.search(query, max_results=max_results)
            
            return [
                Document(
                    page_content=r.get("content", ""),
                    metadata={
                        "source": r.get("url", ""),
                        "title": r.get("title", ""),
                        "type": "web",
                    }
                )
                for r in response.get("results", [])
            ]
        except Exception as e:
            print(f"Tavily search failed: {e}")
    
    # Try DuckDuckGo as fallback
    try:
        from duckduckgo_search import DDGS
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=max_results))
            return [
                Document(
                    page_content=r.get("body", ""),
                    metadata={
                        "source": r.get("href", ""),
                        "title": r.get("title", ""),
                        "type": "web",
                    }
                )
                for r in results
            ]
    except Exception as e:
        print(f"DuckDuckGo search failed: {e}")
    
    # Mock fallback for demonstration
    print("Using mock web search results")
    return [
        Document(
            page_content=f"Mock web result for: {query}. This would contain actual search results in production.",
            metadata={"source": "https://example.com", "type": "web_mock"}
        )
    ]

# Test web search
test_results = web_search("What is RAG in AI?", max_results=2)
print(f"Web search returned {len(test_results)} results")
for r in test_results:
    print(f"  - {r.metadata.get('title', 'No title')[:50]}...")
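
The try-then-fall-back pattern used by `web_search` can be factored into a small generic helper. The sketch below is an assumption about how one might structure it, not part of the tutorial's package: `search_with_fallback`, `broken_provider`, and `mock_provider` are hypothetical names, and no real search API is called.

```python
from typing import Callable, List, Optional, Sequence, Tuple

SearchFn = Callable[[str], List[str]]


def search_with_fallback(
    query: str, providers: Sequence[Tuple[str, SearchFn]]
) -> Tuple[Optional[str], List[str]]:
    """Try each provider in order; return the first one that does not raise."""
    for name, search in providers:
        try:
            return name, search(query)
        except Exception as exc:  # broad on purpose: any failure triggers the next provider
            print(f"{name} failed: {exc}")
    # Every provider raised.
    return None, []


# Hypothetical providers for illustration; neither calls a real API.
def broken_provider(query: str) -> List[str]:
    raise RuntimeError("no API key configured")


def mock_provider(query: str) -> List[str]:
    return [f"Mock web result for: {query}"]


used, results = search_with_fallback(
    "What is RAG in AI?", [("primary", broken_provider), ("mock", mock_provider)]
)
print(used, results)
```

Returning the provider name alongside the results makes it easy to record which source actually answered.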

## Step 3: Setup Components

In [None]:
from langgraph_ollama_local.rag import LocalRetriever, DocumentGrader

# Local retriever
retriever = LocalRetriever()

# Document grader
doc_grader = DocumentGrader(llm)

print("Components initialized!")

## Step 4: Define Node Functions

In [None]:
def retrieve_local(state: CRAGState) -> dict:
    """Retrieve from local document store."""
    print("--- RETRIEVE LOCAL ---")
    question = state["question"]
    docs = retriever.retrieve_documents(question, k=4)
    print(f"Retrieved {len(docs)} local documents")
    return {"documents": docs}

print("Local retrieve node defined!")

In [None]:
def grade_documents(state: CRAGState) -> dict:
    """Grade documents and decide if web search is needed."""
    print("--- GRADE DOCUMENTS ---")
    question = state["question"]
    documents = state["documents"]
    
    if not documents:
        print("No documents retrieved, need web search")
        return {
            "combined_documents": [],
            "knowledge_source": "web",
        }
    
    # Grade each document
    relevant, irrelevant = doc_grader.grade_documents(documents, question)
    print(f"Relevant: {len(relevant)}, Irrelevant: {len(irrelevant)}")
    
    # Decide based on relevance
    if len(relevant) >= 2:
        # Enough relevant documents
        return {
            "combined_documents": relevant,
            "knowledge_source": "local",
        }
    elif len(relevant) == 1:
        # Some relevant, supplement with web
        return {
            "combined_documents": relevant,
            "knowledge_source": "combined",
        }
    else:
        # No relevant documents, use web
        return {
            "combined_documents": [],
            "knowledge_source": "web",
        }

print("Grade documents node defined!")
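
The threshold logic inside `grade_documents` is easier to tune and test when it is pulled out into a pure function. The sketch below assumes a configurable cutoff; `decide_knowledge_source` and its `min_local` parameter are hypothetical names, and the default of 2 mirrors the value used in the node above.

```python
def decide_knowledge_source(num_relevant: int, min_local: int = 2) -> str:
    """Map the number of relevant local documents to a knowledge source label."""
    if num_relevant < 0:
        raise ValueError("num_relevant must be non-negative")
    if num_relevant >= min_local:
        return "local"  # enough local evidence, no web search needed
    if num_relevant > 0:
        return "combined"  # keep what is relevant, supplement from the web
    return "web"  # nothing usable locally


for n in range(4):
    print(n, decide_knowledge_source(n))
```

With the default cutoff this reproduces the three branches of the node; raising `min_local` makes the pipeline reach for the web more often.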

In [None]:
def search_web(state: CRAGState) -> dict:
    """Search the web for additional information."""
    print("--- WEB SEARCH ---")
    question = state["question"]
    
    web_docs = web_search(question, max_results=3)
    print(f"Found {len(web_docs)} web results")
    
    # Combine with any existing relevant docs
    existing = state.get("combined_documents", [])
    combined = existing + web_docs
    
    return {
        "web_results": web_docs,
        "combined_documents": combined,
    }

print("Web search node defined!")
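
The node above fuses knowledge by simple concatenation, local documents first. If local and web sources may overlap, one optional refinement is to drop repeated text while preserving that order. This is a sketch under that assumption; `fuse_documents` is a hypothetical helper that works on plain strings rather than `Document` objects.

```python
from typing import Iterable, List


def fuse_documents(local_docs: Iterable[str], web_docs: Iterable[str]) -> List[str]:
    """Concatenate local then web documents, dropping repeated text."""
    seen = set()
    fused = []
    for text in [*local_docs, *web_docs]:
        # Normalise whitespace and case so near-identical copies compare equal.
        key = " ".join(text.split()).lower()
        if key and key not in seen:
            seen.add(key)
            fused.append(text)
    return fused


fused = fuse_documents(
    ["CRAG adds a corrective step."],
    ["crag  adds a corrective step.", "CRAG falls back to web search."],
)
print(fused)
```

Only exact matches after normalisation are removed; paraphrased duplicates would need a similarity measure, which is outside the scope of this sketch.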

In [None]:
from langchain_core.prompts import ChatPromptTemplate

CRAG_PROMPT = ChatPromptTemplate.from_template(
    """You are an assistant answering questions using provided context.
The context may come from local documents, web search, or both.
Use the information to provide an accurate, well-sourced answer.

Context:
{context}

Question: {question}

Provide a comprehensive answer based on the context. If citing web sources, mention them.

Answer:"""
)

def generate(state: CRAGState) -> dict:
    """Generate answer from combined documents."""
    print("--- GENERATE ---")
    question = state["question"]
    documents = state["combined_documents"]
    source = state["knowledge_source"]
    
    print(f"Using {len(documents)} documents from: {source}")
    
    if not documents:
        return {
            "generation": "I could not find any relevant information to answer this question."
        }
    
    # Format context with source attribution
    context_parts = []
    for i, doc in enumerate(documents, 1):
        source_type = doc.metadata.get("type", "local")
        source_name = doc.metadata.get("filename", doc.metadata.get("title", "Unknown"))
        context_parts.append(f"[Source {i} ({source_type}): {source_name}]\n{doc.page_content}")
    
    context = "\n\n".join(context_parts)
    
    # Generate
    messages = CRAG_PROMPT.format_messages(
        context=context,
        question=question
    )
    response = llm.invoke(messages)
    
    return {"generation": response.content}

print("Generate node defined!")
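
The source-attribution step in `generate` can be checked in isolation. The sketch below reproduces the same labelling format with a minimal `Doc` dataclass standing in for `Document`; both `Doc` and `format_context` are hypothetical names for this illustration. One deliberate difference is hedged here: it uses `or` so that an empty title also falls through to "Unknown".

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Doc:
    """Minimal stand-in for langchain_core's Document, for illustration only."""

    page_content: str
    metadata: Dict[str, str] = field(default_factory=dict)


def format_context(documents: List[Doc]) -> str:
    """Label each document with its index, origin type, and a readable name."""
    parts = []
    for i, doc in enumerate(documents, 1):
        source_type = doc.metadata.get("type", "local")
        # Prefer a filename (local docs), then a title (web docs), else "Unknown".
        name = doc.metadata.get("filename") or doc.metadata.get("title") or "Unknown"
        parts.append(f"[Source {i} ({source_type}): {name}]\n{doc.page_content}")
    return "\n\n".join(parts)


docs = [
    Doc("Self-RAG grades its own retrievals.", {"filename": "self_rag.pdf"}),
    Doc("CRAG adds web search.", {"title": "CRAG overview", "type": "web"}),
]
context = format_context(docs)
print(context)
```

Labelling each chunk with its origin gives the model what it needs to follow the prompt's instruction to mention web sources.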

## Step 5: Define Routing

In [None]:
def route_after_grading(state: CRAGState) -> Literal["generate", "web_search"]:
    """Route based on knowledge source decision."""
    source = state["knowledge_source"]
    
    if source == "local":
        print("→ Local documents sufficient, generating...")
        return "generate"
    else:
        print(f"→ Need web search (source: {source})")
        return "web_search"

print("Routing defined!")

## Step 6: Build the CRAG Graph

In [None]:
from langgraph.graph import StateGraph, START, END

# Build graph
graph_builder = StateGraph(CRAGState)

# Add nodes
graph_builder.add_node("retrieve_local", retrieve_local)
graph_builder.add_node("grade_documents", grade_documents)
graph_builder.add_node("web_search", search_web)
graph_builder.add_node("generate", generate)

# Add edges
graph_builder.add_edge(START, "retrieve_local")
graph_builder.add_edge("retrieve_local", "grade_documents")

# Conditional edge after grading
graph_builder.add_conditional_edges(
    "grade_documents",
    route_after_grading,
    {
        "generate": "generate",
        "web_search": "web_search",
    }
)

# Web search leads to generate
graph_builder.add_edge("web_search", "generate")

# Generate leads to end
graph_builder.add_edge("generate", END)

# Compile
crag_graph = graph_builder.compile()

print("CRAG graph compiled!")
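
To see what the fixed and conditional edges do at run time, it can help to trace the same topology with a toy executor. The sketch below is a simplified model, not how LangGraph is implemented: `run_graph`, `make_nodes`, and the `STOP` sentinel are hypothetical, and the nodes are stubs that fake the grader's verdict.

```python
from typing import Callable, Dict, List

STOP = "__stop__"  # local sentinel for this sketch, not a LangGraph constant


def run_graph(
    state: dict,
    entry: str,
    nodes: Dict[str, Callable[[dict], dict]],
    edges: Dict[str, str],
    conditional: Dict[str, Callable[[dict], str]],
) -> List[str]:
    """Walk the graph from `entry`, merging each node's update; return the path."""
    path, current = [], entry
    while current != STOP:
        path.append(current)
        state.update(nodes[current](state))
        # A conditional edge picks the next node from the state;
        # otherwise follow the fixed edge.
        if current in conditional:
            current = conditional[current](state)
        else:
            current = edges[current]
    return path


def make_nodes(num_relevant: int) -> Dict[str, Callable[[dict], dict]]:
    """Stub nodes; `num_relevant` fakes the grader's verdict."""
    source = "local" if num_relevant >= 2 else "web"
    return {
        "retrieve_local": lambda s: {"documents": ["d1", "d2"]},
        "grade_documents": lambda s: {"knowledge_source": source},
        "web_search": lambda s: {"web_results": ["w1"]},
        "generate": lambda s: {"generation": "answer"},
    }


edges = {"retrieve_local": "grade_documents", "web_search": "generate", "generate": STOP}
conditional = {
    "grade_documents": lambda s: "generate" if s["knowledge_source"] == "local" else "web_search"
}

local_path = run_graph({"question": "q"}, "retrieve_local", make_nodes(2), edges, conditional)
web_path = run_graph({"question": "q"}, "retrieve_local", make_nodes(0), edges, conditional)
print(local_path)
print(web_path)
```

The two printed paths differ only in whether `web_search` is visited, which is the whole effect of the conditional edge after grading.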

In [None]:
# Visualize
from IPython.display import Image, display

try:
    display(Image(crag_graph.get_graph().draw_mermaid_png()))
except Exception:
    # PNG rendering can fail (for example, without network access); fall back to ASCII
    print(crag_graph.get_graph().draw_ascii())

## Step 7: Test CRAG

In [None]:
# Test with a question covered by local docs
question1 = "What is Self-RAG and how does it work?"

print(f"Question: {question1}\n")
print("=" * 60)

result1 = crag_graph.invoke({"question": question1})

print("=" * 60)
print(f"\nKnowledge source: {result1['knowledge_source']}")
print(f"\nAnswer:\n{result1['generation']}")

In [None]:
# Test with a question NOT in local docs (triggers web search)
question2 = "What are the latest developments in AI in 2024?"

print(f"Question: {question2}\n")
print("=" * 60)

result2 = crag_graph.invoke({"question": question2})

print("=" * 60)
print(f"\nKnowledge source: {result2['knowledge_source']}")
print(f"Web results used: {len(result2.get('web_results', []))}")
print(f"\nAnswer:\n{result2['generation']}")

## Complete CRAG Implementation

In [None]:
# Complete CRAG implementation
# Note: reuses web_search() from Step 2, so run that cell first.

from typing import List, Literal
from typing_extensions import TypedDict
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from langgraph.graph import StateGraph, START, END
from langgraph_ollama_local import LocalAgentConfig
from langgraph_ollama_local.rag import LocalRetriever, DocumentGrader

# State
class CRAGState(TypedDict):
    question: str
    documents: List[Document]
    web_results: List[Document]
    combined_documents: List[Document]
    knowledge_source: str
    generation: str

# Components
config = LocalAgentConfig()
llm = ChatOllama(model=config.ollama.model, base_url=config.ollama.base_url, temperature=0)
retriever = LocalRetriever()
grader = DocumentGrader(llm)

# Nodes
def retrieve(state):
    return {"documents": retriever.retrieve_documents(state["question"], k=4)}

def grade(state):
    relevant, _ = grader.grade_documents(state["documents"], state["question"])
    source = "local" if len(relevant) >= 2 else ("combined" if relevant else "web")
    return {"combined_documents": relevant, "knowledge_source": source}

def web_search_node(state):
    results = web_search(state["question"], 3)
    return {"web_results": results, "combined_documents": state["combined_documents"] + results}

def generate_answer(state):
    context = "\n\n".join([d.page_content for d in state["combined_documents"]])
    response = llm.invoke(f"Context: {context}\n\nQuestion: {state['question']}\n\nAnswer:")
    return {"generation": response.content}

def route(state) -> str:
    return "generate" if state["knowledge_source"] == "local" else "web_search"

# Build
g = StateGraph(CRAGState)
g.add_node("retrieve", retrieve)
g.add_node("grade", grade)
g.add_node("web_search", web_search_node)
g.add_node("generate", generate_answer)
g.add_edge(START, "retrieve")
g.add_edge("retrieve", "grade")
g.add_conditional_edges("grade", route, {"generate": "generate", "web_search": "web_search"})
g.add_edge("web_search", "generate")
g.add_edge("generate", END)

crag = g.compile()

# Use
result = crag.invoke({"question": "What is CRAG?"})
print(f"Source: {result['knowledge_source']}")
print(f"Answer: {result['generation'][:200]}...")

## Key Concepts Recap

| Component | Purpose |
|-----------|--------|
| **Knowledge Assessment** | Determine if local docs suffice |
| **Web Search Fallback** | External search when needed |
| **Source Tracking** | Know where answers came from |
| **Conditional Routing** | Dynamic path selection |

## CRAG vs Self-RAG

| Aspect | Self-RAG | CRAG |
|--------|----------|------|
| Focus | Quality checking | Knowledge gaps |
| Fallback | Retry generation | Web search |
| Best for | Accuracy | Coverage |

## What's Next?

In [Tutorial 11: Adaptive RAG](11_adaptive_rag.ipynb), you'll learn:
- Query classification and routing
- Multiple retrieval strategies
- Intelligent source selection