# Conditional Retrieval and Generation (CRAG) Demo

This notebook demonstrates a graph-based orchestration that conditionally combines retrieval, document grading, query rewriting, web search, and RAG-style generation to produce high-quality answers. The pipeline uses an LLM both for generation and for structured decision-making (grading/re-writing), and a small state graph (`CARG`) to control flow based on document relevance.

## Key components
- Loading and chunking web documents, creating embeddings, and constructing a retriever for initial candidates.
- `retrieval_grader`: a structured-LLM grader that labels retrieved documents as relevant (`yes`) or not (`no`).
- `rag_chain`: a RAG prompt + LLM chain that generates answers from retrieved context.
- `question_rewriter`: an LLM chain that rewrites queries to improve web search effectiveness.
- `web_search_tool`: an external web search (Tavily) used when retrieved documents are insufficient.
- A state graph (`CARG`) that composes nodes for `retrieve`, `grade_documents`, `rewrite_query`, `web_search_node`, and `generate` and conditionally routes execution.

## What this notebook contains
- Building and indexing a document collection with OpenAI embeddings and Chroma.
- Implementing a binary document grader using `ChatOpenAI.with_structured_output`.
- Defining RAG and rewrite prompts and their runnable chains.
- Implementing graph nodes for retrieval, grading, rewriting, web search, and generation.
- Compiling and streaming execution of the state graph to observe node-by-node behavior and final generations.

## Workflow (high level)
1. Retrieve top candidate documents for the input question.
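2. Grade each retrieved document for relevance and keep only the relevant ones. If any document is graded irrelevant, rewrite the query and run a web search, appending the results to the relevant documents.
3. Run the RAG chain to generate an answer from the remaining context (the relevant documents, plus the web results when a search ran).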
4. Stream and inspect intermediate states and the final generation.
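The routing in steps 1 to 3 can be sketched without any LLM calls. This is a minimal sketch that assumes a keyword-overlap grader and stub rewrite and search functions. `keyword_grade`, `stub_rewrite`, `stub_web_search`, and `run_crag` are hypothetical stand-ins for `retrieval_grader`, `question_rewriter`, `web_search_tool`, and the compiled graph.

```python
from typing import Dict, List


def keyword_grade(question: str, doc: str) -> str:
    # Hypothetical stand-in for retrieval_grader: "yes" if any word overlaps.
    overlap = set(question.lower().split()) & set(doc.lower().split())
    return "yes" if overlap else "no"


def stub_rewrite(question: str) -> str:
    # Hypothetical stand-in for question_rewriter.
    return f"{question} (rewritten for web search)"


def stub_web_search(question: str) -> str:
    # Hypothetical stand-in for web_search_tool; returns one text blob.
    return f"web results for: {question}"


def run_crag(question: str, retrieved: List[str]) -> Dict[str, object]:
    trace = ["retrieve", "grade_documents"]
    relevant = [d for d in retrieved if keyword_grade(question, d) == "yes"]
    # Mirrors the notebook: any irrelevant document triggers rewrite + web search.
    if len(relevant) < len(retrieved):
        trace += ["rewrite_query", "web_search_node"]
        question = stub_rewrite(question)
        relevant.append(stub_web_search(question))
    trace.append("generate")
    return {"trace": trace, "question": question, "documents": relevant}


result = run_crag("agent memory", ["agent memory types", "prompt engineering tips"])
print(result["trace"])
# -> ['retrieve', 'grade_documents', 'rewrite_query', 'web_search_node', 'generate']
```

If every retrieved document passes the grader, the trace shortens to `retrieve`, `grade_documents`, `generate`, and the question is left unchanged.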

In [1]:
# Standard library and third-party utilities
import os
from pprint import pprint
from typing import List

import yaml
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

# LangChain / LangGraph components used in this notebook
from langchain import hub
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.graph import END, START, StateGraph


In [2]:
# Get the current working directory
cwd = os.getcwd()

# Build the path to config.yaml
config_path = os.path.join(cwd, '..', 'configs', 'config.yaml')

# Normalize the path
config_path = os.path.abspath(config_path)

# Load credential from config file
with open(config_path, 'r') as file:
    config = yaml.safe_load(file)

# Set environment variables
os.environ['LANGCHAIN_API_KEY'] = config['API']['LANGCHAIN']
os.environ['OPENAI_API_KEY'] = config['API']['OPENAI']
os.environ['TAVILY_API_KEY'] = config['API']['TAVILY']

# Configure chat LLM (deterministic)
llm = ChatOpenAI(temperature=0) 

In [3]:
# Load documents from web pages
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# Load documents from the URLs
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Split documents into chunks of roughly 300 tokens with no overlap (from_tiktoken_encoder measures length in tokens rather than characters, matching how LLMs count input)
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=300, chunk_overlap=0)
splits = text_splitter.split_documents(docs_list)  # list of smaller document chunks

# Create embeddings and store them in a vector DB (Chroma)
vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings())  # uses OpenAI embeddings under the hood

# Create a retriever to fetch relevant docs (return the top 10 results)
retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
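Conceptually, `as_retriever(search_kwargs={"k": 10})` embeds the question and returns the 10 stored chunks whose vectors are closest to it. The toy version below swaps OpenAI embeddings for bag-of-words counts and uses cosine similarity. Both choices are illustrative assumptions, and the distance metric configured in Chroma may differ. `embed`, `cosine`, and `top_k` are hypothetical names.

```python
import math
from collections import Counter
from typing import List, Tuple


def embed(text: str) -> Counter:
    # Toy "embedding": bag-of-words counts (stand-in for OpenAIEmbeddings).
    return Counter(text.lower().split())


def cosine(a: Counter, b: Counter) -> float:
    # Counter returns 0 for missing tokens, so only shared tokens add to the dot product.
    dot = sum(count * b[token] for token, count in a.items())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def top_k(query: str, chunks: List[str], k: int = 10) -> List[Tuple[float, str]]:
    q = embed(query)
    scored = [(cosine(q, embed(chunk)), chunk) for chunk in chunks]
    # Highest similarity first, keep at most k (the notebook uses k=10).
    return sorted(scored, key=lambda pair: pair[0], reverse=True)[:k]


chunks = ["agent memory is short term", "prompt engineering", "long term memory store"]
for score, chunk in top_k("agent memory", chunks, k=2):
    print(f"{score:.3f}  {chunk}")
```

Real embeddings capture meaning rather than exact word overlap, but the ranking and truncation to `k` work the same way.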

In [4]:
# Define the grade documents output schema
class GradeDocuments(BaseModel):
    """
    Binary score for relevance check on retrieved documents.
    """

    binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")

# LLM with structured output for grading
structured_llm_grader = llm.with_structured_output(GradeDocuments)
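`binary_score` is typed as a free-form `str`, but the grading node later checks it with an exact `grade == "yes"` comparison. An answer such as `"Yes"` would therefore be treated as irrelevant. A small normalizer is one way to guard against that. `normalize_grade` is a hypothetical helper, not part of LangChain.

```python
def normalize_grade(raw: str) -> str:
    """Map a free-form grader answer onto 'yes' or 'no' (hypothetical helper)."""
    # Drop surrounding whitespace, quotes and periods, then compare case-insensitively.
    cleaned = raw.strip().strip(".'\"").lower()
    return "yes" if cleaned in {"yes", "y"} else "no"  # anything unexpected counts as not relevant


for raw in ["yes", "Yes", " 'YES'. ", "no", "maybe"]:
    print(repr(raw), "->", normalize_grade(raw))
```

In `grade_documents` the check would then read `normalize_grade(score.binary_score) == "yes"`. Another option is to type the field as `Literal["yes", "no"]`, so the structured-output schema itself restricts the allowed values.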



In [5]:
# Define system prompt for grading
system = """
You are a grader assessing relevance of a retrieved document to a user question. \n 
If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant. \n
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
"""

# Create prompt template for grading
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

In [6]:
# Combine prompt and LLM into a runnable grader
retrieval_grader = grade_prompt | structured_llm_grader

# Test the retrieval grader with a sample question and document
question = "agent memory"
docs = retriever.invoke(question)  # retrievers are runnables; get_relevant_documents() is deprecated
doc_txt = docs[1].page_content
print(retrieval_grader.invoke({"question": question, "document": doc_txt}))


binary_score='yes'


In [7]:
# RAG Prompt
prompt = hub.pull("rlm/rag-prompt")

# Post-processing: join retrieved chunks into one plain-text context string
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# Create RAG LLM chain
rag_chain = prompt | llm | StrOutputParser()

# Run the RAG chain on the retrieved documents, formatted as plain text
generation = rag_chain.invoke({"context": format_docs(docs), "question": question})
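If a list of `Document` objects is passed straight into the prompt as `context`, the template fills that slot with the list's string representation: roughly each document's repr, metadata included. `format_docs` sends only the chunk text. The sketch below uses a hypothetical `FakeDocument` dataclass as a stand-in for LangChain's `Document`, so it runs without LangChain installed.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FakeDocument:
    # Hypothetical stand-in for langchain_core.documents.Document.
    page_content: str
    metadata: Dict[str, str] = field(default_factory=dict)


def format_docs(docs: List[FakeDocument]) -> str:
    # Same logic as the notebook's format_docs: join chunk texts with blank lines.
    return "\n\n".join(doc.page_content for doc in docs)


sample_docs = [
    FakeDocument("Short-term memory is in-context learning.", {"source": "a"}),
    FakeDocument("Long-term memory uses an external store.", {"source": "b"}),
]
context = format_docs(sample_docs)
print(context)            # only the chunk texts
print(str(sample_docs))   # the repr also carries field names and metadata
```

Sending only text keeps the prompt shorter and avoids exposing metadata the model does not need.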

In [8]:
# Rewrite prompt
system = """
You are a question re-writer that converts an input question to a better version that is optimized \n 
for web search. Look at the input and try to reason about the underlying semantic intent / meaning.
"""

# Create prompt template for question rewriting
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "Here is the initial question: \n\n {question} \n Formulate an improved question.",
        ),
    ]
)

# Create question re-writer chain
question_rewriter = re_write_prompt | llm | StrOutputParser()


In [9]:
# Create Tavily web search tool (return up to 3 results)
web_search_tool = TavilySearchResults(max_results=3)
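Each search returns a list of result dictionaries. The `web_search` node defined later keeps only each result's `content` field and joins those strings into a single `Document`. The minimal sketch below shows that joining step, assuming results shaped like `{"url": ..., "content": ...}`. `join_search_results` and the sample data are hypothetical.

```python
from typing import Dict, List


def join_search_results(results: List[Dict[str, str]]) -> str:
    # Same joining rule as the web_search node (one line per result's "content"),
    # plus a hypothetical refinement: skip results whose content is missing or empty.
    return "\n".join(r["content"] for r in results if r.get("content"))


sample_results = [
    {"url": "https://example.com/a", "content": "Agents use short-term and long-term memory."},
    {"url": "https://example.com/b", "content": ""},
    {"url": "https://example.com/c", "content": "Long-term memory is often a vector store."},
]
print(join_search_results(sample_results))
```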


In [10]:
# Create graph state
class GraphState(TypedDict):
    """
    Represents the state of our graph.

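    Attributes:
        question: the user question (replaced by the rewritten question after rewrite_query)
        generation: LLM generation
        web_search: "Yes" if at least one retrieved document was graded irrelevant, else "No"
        documents: list of retrieved documents (plus web results after a web search)
    """

    question: str
    generation: str
    web_search: str
    documents: List[Document]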

In [11]:
# Create retriever graph node
def retrieve(state):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    question = state["question"]

    # Retrieval (invoke() replaces the deprecated get_relevant_documents())
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

# Create generator graph node
def generate(state):
    """
    Generate answer

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # RAG generation over the documents, formatted as plain text
    generation = rag_chain.invoke({"context": format_docs(documents), "question": question})
    return {"documents": documents, "question": question, "generation": generation}

# Create grade documents graph node
def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with only filtered relevant documents
    """

    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    # Score each doc
    filtered_docs = []
    web_search = "No"
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score.binary_score
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            web_search = "Yes"
            continue
    return {"documents": filtered_docs, "question": question, "web_search": web_search}

# Create rewrite query graph node
def rewrite_query(state):
    """
    Rewrite the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """

    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state["documents"]

    # Re-write question
    better_question = question_rewriter.invoke({"question": question})
    return {"documents": documents, "question": better_question}

# Create web search graph node
def web_search(state):
    """
    Web search based on the re-phrased question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with appended web results
    """

    print("---WEB SEARCH---")
    question = state["question"]
    documents = state["documents"]

    # Web search
    docs = web_search_tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    documents.append(web_results)

    return {"documents": documents, "question": question}

# Routing function for the conditional edge after grade_documents
def decide_to_generate(state):
    """
    Determines whether to generate an answer or to rewrite the question first.

    Args:
        state (dict): The current graph state

    Returns:
        str: Name of the next node to call
    """

    print("---ASSESS GRADED DOCUMENTS---")
    web_search = state["web_search"]

    if web_search == "Yes":
        # At least one document was graded irrelevant:
        # rewrite the query and supplement the context with a web search
        print(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
        )
        return "rewrite_query"
    else:
        # Every document was graded relevant, so generate the answer
        print("---DECISION: GENERATE---")
        return "generate"

In [12]:
# Create the state graph
CARG = StateGraph(GraphState)

# Define the nodes
CARG.add_node("retrieve", retrieve)  # retrieve
CARG.add_node("grade_documents", grade_documents)  # grade documents
CARG.add_node("generate", generate)  # generate
CARG.add_node("rewrite_query", rewrite_query)  # rewrite_query
CARG.add_node("web_search_node", web_search)  # web search

# Build graph
CARG.add_edge(START, "retrieve")
CARG.add_edge("retrieve", "grade_documents")
CARG.add_conditional_edges("grade_documents", decide_to_generate,
    {
        "rewrite_query": "rewrite_query",
        "generate": "generate",
    },
)
CARG.add_edge("rewrite_query", "web_search_node")
CARG.add_edge("web_search_node", "generate")
CARG.add_edge("generate", END)

# Compile
app = CARG.compile()

In [13]:
# Stream execution
inputs = {"question": "What are the types of agent memory?"}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint(f"Node '{key}':")
        # Optional: print the node's state update
        # pprint(value, indent=2, width=80, depth=None)
    pprint("\n---\n")

# Final generation
pprint(value["generation"])

---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---
"Node 'grade_documents':"
'\n---\n'
---TRANSFORM QUERY---
"Node 'rewrite_query':"
'\n---\n'
---WEB SEARCH---
"Node 'web_search_node':"
'\n---\n'
---GENERATE---
"Node 'generate':"
'\n---\n'
('The different types of memory used by agents are short-term memory and '
 'long-term memory. Short-term memory is utilized for in-context learning, '
 'while long-term memory allows agents to retain and recall information over '
 'extended periods. Long-term memory can be store
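The cell above prints `value["generation"]` after the loop ends. This works only because `generate` is the last node to run and the loop variable keeps its final value. As the printed node names show, each streamed item maps a node name to that node's update. One way to keep the complete final state is to merge the updates as they arrive. The sketch below does this over a hand-written list of events shaped like those updates. `merge_stream_updates` is a hypothetical helper and the event contents are illustrative, not recorded output. In the notebook, `app.stream(inputs)` would take the place of `events`. Alternatively, `app.invoke(inputs)` returns the final state directly.

```python
from typing import Any, Dict, Iterable


def merge_stream_updates(
    initial: Dict[str, Any], events: Iterable[Dict[str, Dict[str, Any]]]
) -> Dict[str, Any]:
    # Each event maps one node name to the partial state that node returned.
    state = dict(initial)
    for event in events:
        for update in event.values():
            state.update(update)  # last write wins per key
    return state


# Illustrative events shaped like the streamed updates (contents are made up).
events = [
    {"retrieve": {"question": "q", "documents": ["d1", "d2"]}},
    {"grade_documents": {"question": "q", "documents": ["d1"], "web_search": "Yes"}},
    {"rewrite_query": {"question": "q2", "documents": ["d1"]}},
    {"web_search_node": {"question": "q2", "documents": ["d1", "web"]}},
    {"generate": {"question": "q2", "documents": ["d1", "web"], "generation": "answer"}},
]
final_state = merge_stream_updates({"question": "q"}, events)
print(final_state["generation"], final_state["web_search"])  # -> answer Yes
```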

In [14]:
# Another example
inputs = {"question": "How does the AlphaCodium paper work?"}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint(f"Node '{key}':")
        # Optional: print the node's state update
        # pprint(value, indent=2, width=80, depth=None)
    pprint("\n---\n")

# Final generation
pprint(value["generation"])


---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---
"Node 'grade_documents':"
'\n---\n'
---TRANSFORM QUERY---
"Node 'rewrite_query':"
'\n---\n'
---WEB SEARCH---
"Node 'web_search_node':"
'\n---\n'
---GENERATE---
"Node 'generate':"
'\n---\n'
('The AlphaCodium paper introduces a test-driven, multi-step code generation '
 'flow that enhances the performance of LLMs in generating code. It '
 'incorporates elements of the Generative Adversarial Network architecture and '
 'uses an adversarial model to ensure code