# 04. Agentic RAG Workflow (LangGraph)

This notebook builds an agentic RAG pipeline that grades its own retrievals and retries with a rewritten query when none of them are relevant.
We use **LangGraph** to create a state machine that orchestrates the RAG process.

**Agent Capabilities:**
1.  **Retrieve**: Fetch documents.
2.  **Grade**: Evaluate if retrieved documents are relevant to the question.
3.  **Rewrite**: If none of the retrieved documents are relevant, rewrite the query to improve retrieval.
4.  **Generate**: Synthesize the answer.

**Graph Flow:**
`Start` -> `Retrieve` -> `Grade` -> `(Decide)` -> 
   - If Relevant -> `Generate` -> `End`
   - If Irrelevant -> `Rewrite Query` -> `Retrieve` (loop, capped at a few rewrites before falling back to `Generate`)

In [None]:
from typing import List
from typing_extensions import TypedDict

from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.documents import Document

from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_ollama import ChatOllama

from langgraph.graph import END, StateGraph

# --- CONFIGURATION ---
DB_DIR = "data/chroma_db"
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
LLM_MODEL = "llama3"

# Setup Components
# Embeddings are computed on the CPU and back the Chroma store persisted
# under DB_DIR; the retriever is that store's default retriever.
embedding_model = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL,
    model_kwargs=dict(device="cpu"),
)
vector_store = Chroma(
    persist_directory=DB_DIR,
    embedding_function=embedding_model,
)
retriever = vector_store.as_retriever()

# Two handles on the same Ollama model, both at temperature 0:
#   llm      - JSON mode, used for grading
#   llm_gen  - normal text mode, used for generation and query rewriting
llm = ChatOllama(
    model=LLM_MODEL,
    temperature=0,
    format="json",
)
llm_gen = ChatOllama(
    model=LLM_MODEL,
    temperature=0,
)

## 1. Define State

In [None]:
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Every node receives this mapping and returns a dict containing only the
    keys it wants to update. The run cell seeds the state with ``question``
    alone, which is why ``reformulated_count`` is read with
    ``state.get("reformulated_count", 0)`` by the nodes that use it.
    """
    # The user question; replaced by transform_query when it rewrites it.
    question: str
    # Answer text produced by the generate node.
    generation: str
    # Documents returned by retrieve, then narrowed by grade_documents to
    # the ones graded relevant.
    documents: List[Document]
    # Number of times transform_query has run for this question.
    reformulated_count: int

## 2. Define Nodes

In [None]:
def retrieve(state):
    """
    Fetch candidate documents for the current question.

    Reads ``state["question"]``, queries the module-level ``retriever`` with
    it, and returns a state update carrying the hits under ``"documents"``
    together with the unchanged ``"question"``.
    """
    print("---RETRIEVE---")
    query = state["question"]
    return {"documents": retriever.invoke(query), "question": query}

def grade_documents(state):
    """
    Keep only the retrieved documents that are relevant to the question.

    Each document in ``state["documents"]`` is graded on its own by the
    JSON-mode ``llm``; a document is kept when the grader's ``score`` is
    "yes" (compared case-insensitively, ignoring surrounding whitespace).
    A reply without a usable ``score`` is treated as "not relevant".

    Args:
        state: Graph state; ``"question"`` and ``"documents"`` are read.

    Returns:
        A state update with the relevant documents under ``"documents"``
        (possibly an empty list) and the unchanged ``"question"``.
    """
    print("---CHECK DOCUMENT RELEVANCE---")
    question = state["question"]
    documents = state["documents"]

    # Nothing was retrieved: there is nothing to grade, so skip building the
    # grading chain altogether.
    if not documents:
        return {"documents": [], "question": question}

    # LLM with JSON output to score relevancy
    prompt = ChatPromptTemplate.from_template(
        """You are a grader assessing relevance of a retrieved document to a user question. \n 
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keyword(s) or semantic meaning useful to the question, grade it as relevant. \n
        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.\n
        Provide the binary score as a JSON with a single key 'score' and no premable or explaination."""
    )
    chain = prompt | llm | JsonOutputParser()

    filtered_docs = []
    for d in documents:
        score = chain.invoke({"question": question, "context": d.page_content})
        # The model does not always follow the requested shape exactly
        # ("Yes", " yes", a different key, a non-object reply). Normalise the
        # grade instead of raising KeyError or silently dropping a
        # capitalised "Yes".
        raw_grade = score.get("score", "") if isinstance(score, dict) else ""
        if str(raw_grade).strip().lower() == "yes":
            print("   - Grade: RELEVANT")
            filtered_docs.append(d)
        else:
            print("   - Grade: NOT RELEVANT")

    return {"documents": filtered_docs, "question": question}

def generate(state):
    """
    Generate the final answer from the graded documents.

    Args:
        state: Graph state; ``"question"`` and ``"documents"`` are read.

    Returns:
        A state update with the answer text under ``"generation"``.
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # Feed the model the document text only. Formatting the raw list into
    # the prompt would insert the repr of each Document (metadata included)
    # rather than readable context.
    context = "\n\n".join(doc.page_content for doc in documents)

    prompt = ChatPromptTemplate.from_template(
        """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
        Question: {question} 
        Context: {context} 
        Answer:"""
    )
    chain = prompt | llm_gen | StrOutputParser()
    generation = chain.invoke({"context": context, "question": question})
    return {"generation": generation}

def transform_query(state):
    """
    Rewrite the question so that retrieval has a better chance of succeeding.

    The attempt counter (``reformulated_count``, defaulting to 0 when absent)
    is incremented on every call. Up to and including the third attempt the
    question is rewritten by ``llm_gen``; beyond that the question is
    returned unchanged. In both cases the incremented counter is returned.
    """
    print("---TRANSFORM QUERY---")
    original_question = state["question"]
    attempt = state.get("reformulated_count", 0) + 1

    if attempt <= 3:
        rewrite_prompt = ChatPromptTemplate.from_messages([
            ("system", "You are a helper that re-writes questions to improve retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."),
            ("human", "Here is the initial question: \n\n {question} \n Formulate an improved question. Output only the improved question string.")
        ])
        rewriter = rewrite_prompt | llm_gen | StrOutputParser()
        rewritten = rewriter.invoke({"question": original_question})
        print(f"   - Modified: {rewritten}")
        return {"question": rewritten, "reformulated_count": attempt}

    # Limit loops: past the third attempt the question is passed through.
    print("   - Max retries reached, stopping.")
    return {"question": original_question, "reformulated_count": attempt}

## 3. Define Conditional Edges

In [None]:
def decide_to_generate(state):
    """
    Choose the next node after grading: rewrite the query or generate.

    Args:
        state: Graph state; ``"documents"`` (the graded, relevant documents)
            and ``"reformulated_count"`` (defaults to 0 when absent) are read.

    Returns:
        ``"transform_query"`` when no relevant document survived grading and
        fewer than three rewrites have been made; ``"generate"`` otherwise,
        i.e. when relevant documents exist or the rewrite budget is used up.
    """
    filtered_documents = state["documents"]
    count = state.get("reformulated_count", 0)

    # transform_query increments the counter before checking its own limit
    # and declines to rewrite past the third attempt. Routing there with the
    # counter already at 3 would only trigger one more retrieve/grade pass on
    # the very same question, so stop at 3.
    if not filtered_documents and count < 3:
        # No relevant documents found and rewrites remain: rewrite the query.
        return "transform_query"

    # Either relevant documents exist or the rewrite budget is exhausted.
    return "generate"

## 4. Build Graph

In [None]:
# State machine over GraphState.
workflow = StateGraph(GraphState)

# Add Nodes: each graph node name maps to the function that implements it.
for _node_name, _node_fn in (
    ("retrieve", retrieve),
    ("grade_documents", grade_documents),
    ("generate", generate),
    ("transform_query", transform_query),
):
    workflow.add_node(_node_name, _node_fn)

# Add Edges: the run starts at retrieval, and every retrieval is graded.
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")

# Conditional Edge: decide_to_generate returns the name of the next node, so
# each possible return value maps onto the node of the same name.
_routes = {target: target for target in ("transform_query", "generate")}
workflow.add_conditional_edges("grade_documents", decide_to_generate, _routes)

# A rewritten query goes back to retrieval; a generated answer ends the run.
workflow.add_edge("transform_query", "retrieve")
workflow.add_edge("generate", END)

# Compile
app = workflow.compile()

## 5. Visualise Graph (Optional)
Rendering the PNG may need network access or extra dependencies; if it fails, the cell prints a message instead of raising.

In [None]:
from IPython.display import Image, display

# Best-effort rendering: a failure here must never stop the notebook.
_graph = None
try:
    _graph = app.get_graph()
    display(Image(_graph.draw_mermaid_png()))
except Exception as exc:
    # Report the real cause instead of guessing at a missing dependency;
    # PNG rendering can fail for several reasons (e.g. no network access
    # for the default remote Mermaid renderer).
    print(f"Graph PNG rendering failed: {exc!r}")
    if _graph is not None:
        # The Mermaid source is plain text and needs no renderer, so the
        # structure can still be inspected.
        print(_graph.draw_mermaid())

## 6. Run the Agent

In [None]:
inputs = {"question": "What is Multi-Head Attention?"}

# Track the answer explicitly rather than reading it off whatever the loop
# variable happens to hold after the stream ends: that would raise NameError
# if the stream yields nothing and depends on the last event being the
# generate node.
final_generation = None
for output in app.stream(inputs):
    # Each streamed item maps the name of the node that just finished to the
    # state update it returned.
    for key, value in output.items():
        print(f"Finished Node: {key}")
        if isinstance(value, dict) and "generation" in value:
            final_generation = value["generation"]

# Final Answer
print("\n--- Final Result ---")
print(final_generation)