# Graph 6: Self-Reflective RAG

In [None]:
import boto3
from getpass import getpass

from langchain_aws import ChatBedrockConverse, BedrockEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_core.output_parsers import PydanticOutputParser
from langchain.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, START, END


from typing import TypedDict, List, Literal
from pydantic import BaseModel, Field
import json
import os
from dotenv import load_dotenv
load_dotenv()

In [None]:
# LangSmith tracing setup. The API key is taken from the environment (e.g. loaded
# from .env by load_dotenv); if it is missing or empty the user is prompted for it.
if not os.getenv("LANGCHAIN_API_KEY"):
    os.environ["LANGCHAIN_API_KEY"] = getpass("Enter LangSmith API Key: ")

os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
    }
)

# Fall back to a fixed project name when none (or an empty one) is configured.
if not os.getenv("LANGCHAIN_PROJECT"):
    os.environ["LANGCHAIN_PROJECT"] = "default-project"

print("LangSmith configured with project:", os.environ["LANGCHAIN_PROJECT"])

In [None]:
# Bedrock client factory shared by the chat model and the embedding model.

def _build_bedrock_client():
    """Create a boto3 ``bedrock-runtime`` client pinned to ``ap-south-1``.

    Credentials are read from AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY. When a
    variable is unset, ``None`` is passed for it and boto3 falls back to its own
    credential resolution.
    """
    client_kwargs = {
        "region_name": "ap-south-1",
        "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
        "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
    }
    return boto3.client("bedrock-runtime", **client_kwargs)

# Chat model served through the Bedrock Converse API.
# MODEL_NAME overrides the default model id.
_chat_model_id = os.getenv("MODEL_NAME", "mistral.magistral-small-2509")
llm = ChatBedrockConverse(
    model=_chat_model_id,
    temperature=0.5,
    max_tokens=1024,
    region_name="ap-south-1",
    client=_build_bedrock_client(),
)

In [None]:
# Quick sanity check of the LLM configuration: one plain call whose response is
# displayed as the cell output.
llm.invoke("What is the capital of India?")

In [None]:
# This is a RAG pipeline, so the source PDF must be loaded and split into chunks
# before any embeddings can be stored.
rag_docs_path = os.getenv("RAG_DOCS_PATH")
if not rag_docs_path:
    # Without this check PyPDFLoader receives None and fails with an unrelated
    # low-level error; fail early with an actionable message instead.
    raise ValueError(
        "RAG_DOCS_PATH is not set; point it at the PDF to index (e.g. in your .env file)."
    )

docs = PyPDFLoader(rag_docs_path).load()

# Consecutive chunks share 500 units of overlap so text that straddles a chunk
# boundary still appears intact in at least one chunk.
chunks = RecursiveCharacterTextSplitter(
    chunk_size=2000,
    chunk_overlap=500,
).split_documents(docs)

In [None]:
# Number of chunks produced by the splitter (displayed as the cell output).
len(chunks)

In [None]:
# Embedding model on Bedrock (defaults to Amazon Titan Text Embeddings v2);
# EMBEDDING_MODEL overrides the model id.
_embedding_model_id = os.getenv("EMBEDDING_MODEL", "amazon.titan-embed-text-v2:0")
embeddings = BedrockEmbeddings(
    model_id=_embedding_model_id,
    region_name="ap-south-1",
    client=_build_bedrock_client(),
)

In [None]:
# Build the FAISS index on the first run and cache it on disk; later runs reload it.
# NOTE: the cache is only checked for existence. It is NOT rebuilt when the PDF at
# RAG_DOCS_PATH or the embedding model changes, so delete the directory to force a
# rebuild. (`chunks` is still computed in the cell above even when reloading.)
embeddings_dir = os.getenv("EMBEDDINGS_STORE_PATH", "./embeddings_store")

# FAISS.save_local writes "index.faiss" and "index.pkl" into the folder. Requiring
# both files, rather than just the directory, avoids calling load_local on an empty
# or partially written directory.
_index_files = ("index.faiss", "index.pkl")
_store_is_complete = all(
    os.path.exists(os.path.join(embeddings_dir, name)) for name in _index_files
)

if _store_is_complete:
    # SECURITY: load_local unpickles index.pkl. allow_dangerous_deserialization is
    # only acceptable for a store this notebook wrote itself; never point
    # EMBEDDINGS_STORE_PATH at a directory from an untrusted source.
    vectorstore = FAISS.load_local(embeddings_dir, embeddings, allow_dangerous_deserialization=True)
else:
    vectorstore = FAISS.from_documents(chunks, embeddings)
    vectorstore.save_local(embeddings_dir)


In [None]:
# Retriever over the FAISS store: each query returns the top k=3 chunks.
retriever = vectorstore.as_retriever(search_kwargs=dict(k=3))

In [None]:
# Sanity-check the index: show the top-3 raw similarity hits for a sample query.
vectorstore.similarity_search("Cryptographic keys & Algorithms", k=3)

In [None]:
# Agent State Schema: the dict that LangGraph passes between nodes. Each node
# returns a partial update containing only the keys it writes.
#   question       - the user's question (graph input)
#   need_retrieval - verdict of retrieve_decision_node, read by router_node
#   retrieved_docs - documents returned by retrieve_docs_node
#   relevant_docs  - subset of retrieved_docs kept by the relevancy check
#   answer         - generated answer text
AgentState = TypedDict(
    "AgentState",
    {
        "question": str,
        "need_retrieval": bool,
        "retrieved_docs": List[Document],
        "relevant_docs": List[Document],
        "answer": str,
    },
)

In [None]:
# Structured output for the Retrieve Decision Node.
# NOTE: intentionally no class docstring. Pydantic would copy it into the JSON
# schema that PydanticOutputParser embeds in the prompt sent to the LLM.
class RetrieveDecisionSchema(BaseModel):
    # Whether answering the question requires a document-store lookup.
    need_retrieval: bool = Field(
        default=...,
        description="True if the agent needs to perform retrieval to answer the question reliably, False otherwise.",
    )
    # The model's short justification; required, so every instance must supply it.
    reasoning: str = Field(
        default=...,
        description="A brief explanation of why the agent decided to retrieve or not.",
    )
    
    
# Structured output for the per-document relevancy check.
# Each instance grades exactly ONE retrieved document, so the field descriptions
# (which the LLM sees via PydanticOutputParser's format instructions) speak about a
# single document.
# NOTE: intentionally no class docstring; pydantic would inject it into that schema.
class RelevancyCheckSchema(BaseModel):
    # Verdict for the single document being checked.
    is_relevant: bool = Field(
        ...,
        description="True if the document is relevant to the question, False otherwise."
    )
    # The model's short justification for the verdict.
    reasoning: str = Field(
        ...,
        description="A brief explanation of why this document is or is not relevant to the question."
    )
    

In [None]:
# Prompt for the Retrieve Decision Node.
# Template variables: {json_schema} (the output parser's format instructions) and
# {question}.
_RETRIEVE_DECISION_SYSTEM = """You are an intelligent agent designed to answer questions based on a given set of documents. Your task is to determine whether you need to perform retrieval from the document store to answer the question reliably. Guidelines to decide if retrieval is needed:
        - True: If the question is specific and likely requires information that is not commonly known or is detailed in the documents.
        - False: If the question is general and can be answered based on common knowledge or does not require specific information from the documents.
        Respond with Valid JSON Only:
        {json_schema}
        """

retrieve_decision_prompt = ChatPromptTemplate.from_messages(
    [("system", _RETRIEVE_DECISION_SYSTEM), ("human", "<Question>{question}</Question>")]
)

# Prompt for the Direct Answer Generation Node: answer from the model's own knowledge.
# Template variable: {question}.
_DIRECT_ANSWER_SYSTEM = """You are an intelligent agent designed to answer questions based on your knowledge. Your task is to generate a direct answer to the question based on your existing knowledge. Guidelines for generating the answer:
        - Provide a concise and accurate answer to the question.
        - Do not include any information that is not relevant to the question.
        - If you are unsure about the answer, you can state that you do not know.
        """

direct_answer_prompt = ChatPromptTemplate.from_messages(
    [("system", _DIRECT_ANSWER_SYSTEM), ("human", "<Question>{question}</Question>")]
)

# Prompt for the Relevancy Check Node.
# It is formatted once per retrieved document: {retrieved_docs} holds a SINGLE
# document's page_content. The variable name is kept for caller compatibility.
# Template variables: {json_schema}, {question}, {retrieved_docs}.
relevancy_check_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", """You are an intelligent agent designed to filter retrieved documents based on their relevance to a given question. Your task is to determine whether the single retrieved document inside the <Retrieved_Docs> tag is relevant to the question. Guidelines for determining relevance:
        - True: If the document contains information that is directly related to the question and can help in answering it.
        - False: If the document does not contain relevant information or is not related to the question.
        Respond with Valid JSON Only:
        {json_schema}
        """),
        ("human", "<Question>{question}</Question><Retrieved_Docs>{retrieved_docs}</Retrieved_Docs>")
    ]
)

In [None]:
# Node: Retrieve Decision Node
def retrieve_decision_node(agent_state: AgentState) -> AgentState:
    """Ask the LLM whether answering ``agent_state["question"]`` needs document retrieval.

    The model is prompted to reply with JSON matching ``RetrieveDecisionSchema``,
    which is parsed with ``PydanticOutputParser``. This works with models that
    lack tool-calling support. For models that support it,
    ``llm.with_structured_output(RetrieveDecisionSchema)`` is an alternative.

    Returns:
        Partial state update ``{"need_retrieval": bool}``. If the LLM call or the
        parse fails, retrieval is requested (``True``) as the safe default.
    """
    parser = PydanticOutputParser(pydantic_object=RetrieveDecisionSchema)
    messages = retrieve_decision_prompt.format_messages(
        question=agent_state["question"],
        json_schema=parser.get_format_instructions(),
    )

    try:
        raw_response = llm.invoke(messages)
        decision = parser.parse(raw_response.content)
    except Exception as e:
        # Deliberate best-effort: any LLM or parsing failure falls back to retrieving.
        # Return the default directly. Building RetrieveDecisionSchema(need_retrieval=True)
        # here would itself raise, because its `reasoning` field is required.
        print("Error parsing the response, defaulting to need_retrieval=True. Error:", e)
        return {"need_retrieval": True}

    return {"need_retrieval": decision.need_retrieval}

In [None]:
### Scratchpad for testing

## Here, I'm testing whether structured-output parsing works correctly; this also shows whether the model supports tool calling. (Note: retrieve_decision_prompt also requires a `json_schema` argument, so format_messages below needs it as well.)
# check_retrieval_response = llm.with_structured_output(RetrieveDecisionSchema).invoke(
#         retrieve_decision_prompt.format_messages(question="What are the CO's and POs of case 56")
#     )
# print(check_retrieval_response)


## Here, I'm testing smaller models that might not support tool calling, to check whether they can still follow the structured-output instructions.
# from langchain_core.output_parsers import PydanticOutputParser
# parser = PydanticOutputParser(pydantic_object=RetrieveDecisionSchema)
# response = llm.invoke(
#     retrieve_decision_prompt.format_messages(question="Who is the team player in last month") +
#     [("system", f"Respond with valid JSON only:\n{parser.get_format_instructions()}")]
# )
# check_retrieval_response = parser.parse(response.content)
# print(check_retrieval_response)
# # print(parser.get_format_instructions())


In [None]:
# Node: Generate an answer from the LLM's own knowledge (no retrieval)
def direct_answer_node(agent_state: AgentState) -> AgentState:
    """Answer the question directly from the model's built-in knowledge.

    Reached when router_node decides no document lookup is needed.
    Returns the partial state update ``{"answer": <LLM response content>}``.
    """
    prompt_messages = direct_answer_prompt.format_messages(question=agent_state["question"])
    reply = llm.invoke(prompt_messages)
    return {"answer": reply.content}

In [None]:
# Node: Retrieve Docs Node
def retrieve_docs_node(agent_state: AgentState) -> AgentState:
    """Fetch candidate documents for the question from the module-level retriever.

    Returns the partial state update ``{"retrieved_docs": [...]}``; how many
    documents come back is governed by the retriever's ``k`` setting.
    """
    return {"retrieved_docs": retriever.invoke(agent_state["question"])}

In [None]:
# Router: choose the next node from the retrieval decision
def router_node(agent_state: AgentState) -> Literal["direct_answer_node", "retrieve_docs_node"]:
    """Return the name of the node to run after retrieve_decision_node.

    A truthy ``need_retrieval`` sends the flow to retrieval; anything else goes
    straight to the direct-answer node.
    """
    needs_docs = agent_state["need_retrieval"]
    return "retrieve_docs_node" if needs_docs else "direct_answer_node"

In [None]:
# Node: Filter out non-relevant docs from the retrieved docs
def filter_relevant_docs_node(agent_state: AgentState) -> AgentState:
    """Keep only the retrieved documents the LLM judges relevant to the question.

    Each document in ``agent_state["retrieved_docs"]`` is graded in its own LLM
    call with ``relevancy_check_prompt``, and the JSON reply is parsed into
    ``RelevancyCheckSchema``. Documents are checked sequentially and keep their
    original order.

    Returns:
        Partial state update ``{"relevant_docs": [...]}``. A document whose check
        fails (LLM error or unparseable reply) is kept, so a flaky grader never
        silently drops context.
    """
    parser = PydanticOutputParser(pydantic_object=RelevancyCheckSchema)
    # The format instructions are the same for every document; build them once.
    format_instructions = parser.get_format_instructions()
    relevant_docs: List[Document] = []

    for doc in agent_state["retrieved_docs"]:
        messages = relevancy_check_prompt.format_messages(
            question=agent_state["question"],
            retrieved_docs=doc.page_content,
            json_schema=format_instructions,
        )
        try:
            # The LLM call is inside the try as well, matching retrieve_decision_node:
            # a transient LLM error degrades to "keep the doc" instead of crashing the graph.
            verdict = parser.parse(llm.invoke(messages).content)
        except Exception as e:
            # metadata may lack "source"; the error path itself must not raise KeyError.
            source = doc.metadata.get("source", "<unknown source>")
            print(f"Error parsing relevancy check response for doc {source}, defaulting to relevant. Error: {e}")
            relevant_docs.append(doc)
            continue

        if verdict.is_relevant:
            relevant_docs.append(doc)

    return {"relevant_docs": relevant_docs}

In [None]:
# # Note: The above implementation of filter_relevant_docs_node is sequential and can be slow if there are many retrieved documents. We can optimize it by parallelizing the relevancy checks using ThreadPoolExecutor.
# from concurrent.futures import ThreadPoolExecutor
# def check_relevancy(doc, question, parser):
#     try:
#         response = llm.invoke(
#             relevancy_check_prompt.format_messages(
#                 question=question,
#                 retrieved_docs=doc.page_content,
#                 json_schema=parser.get_format_instructions()
#             )
#         )
#         result = parser.parse(response.content)
#         return doc if result.is_relevant else None
#     except:
#         return doc  # Default to relevant on error


# def filter_relevant_docs_node(agent_state: AgentState) -> AgentState:
#     parser = PydanticOutputParser(pydantic_object=RelevancyCheckSchema)

#     with ThreadPoolExecutor(max_workers=5) as executor:
#         results = executor.map(
#             lambda doc: check_relevancy(doc, agent_state["question"], parser),
#             agent_state["retrieved_docs"]
#         )

#     relevant_docs = [doc for doc in results if doc is not None]
#     return {"relevant_docs": relevant_docs}

In [None]:
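# # Note: Alternative that grades all retrieved documents in ONE batched LLM call instead of one call per document. It needs fewer LLM calls, but it depends on the model returning a well-formed JSON object with correct indices; on any parse failure it falls back to treating every document as relevant. Also note the prompt labels documents from 1 ("Document {i+1}") while doc_index is used as a 0-based list index, so align the two before enabling this. If the number of retrieved documents is small (like 3-5), the sequential version is simpler and usually sufficient.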
# def filter_relevant_docs_node(agent_state: AgentState) -> AgentState:
#     if not agent_state["retrieved_docs"]:
#         return {"relevant_docs": []}

#     # Combine all docs into one prompt
#     docs_text = "\n\n---\n\n".join([
#         f"Document {i+1}:\n{doc.page_content}"
#         for i, doc in enumerate(agent_state["retrieved_docs"])
#     ])

#     batch_prompt = f"""Question: {agent_state["question"]}

# Documents:
# {docs_text}

# For each document, return JSON array with relevancy:
# {{"results": [{{"doc_index": 0, "is_relevant": true}}, ...]}}"""

#     response = llm.invoke(batch_prompt)

#     try:
#         result = json.loads(response.content)
#         relevant_docs = [
#             agent_state["retrieved_docs"][r["doc_index"]]
#             for r in result["results"]
#             if r["is_relevant"]
#         ]
#     except Exception as e:
#         print("Error parsing batch relevancy response, defaulting to all docs relevant. Error:", e)
#         relevant_docs = agent_state["retrieved_docs"]  # Fallback

#     return {"relevant_docs": relevant_docs}

In [None]:
# Assemble the graph:
#   START -> retrieve_decision_node --router_node--> direct_answer_node -> END
#                                                \-> retrieve_docs_node -> relevancy_check_node -> END
# NOTE(review): in this graph the retrieval branch ends right after filtering, so
# "answer" is only written on the direct-answer path. A node that generates an
# answer from relevant_docs still appears to be missing.
graph = StateGraph(AgentState)

_node_functions = {
    "retrieve_decision_node": retrieve_decision_node,
    "direct_answer_node": direct_answer_node,
    "retrieve_docs_node": retrieve_docs_node,
    "relevancy_check_node": filter_relevant_docs_node,
}
for _node_name, _node_fn in _node_functions.items():
    graph.add_node(_node_name, _node_fn)

graph.add_edge(START, "retrieve_decision_node")
# router_node returns a node name; the path map sends each name to the node of the same name.
graph.add_conditional_edges(
    "retrieve_decision_node",
    router_node,
    {_target: _target for _target in ("direct_answer_node", "retrieve_docs_node")},
)
for _src, _dst in (
    ("direct_answer_node", END),
    ("retrieve_docs_node", "relevancy_check_node"),
    ("relevancy_check_node", END),
):
    graph.add_edge(_src, _dst)

app = graph.compile()

# Render the compiled graph inline as a Mermaid PNG.
from IPython.display import display, Image
display(Image(app.get_graph().draw_mermaid_png()))


In [None]:
# Test case 1: run the full graph on a question about the indexed documents' subject.
_test_case_1_input = dict(
    question="What are the cryptographic algorithms ?",
    need_retrieval=False,
    retrieved_docs=[],
    answer="",
)
test_case_1 = app.invoke(_test_case_1_input)

In [None]:
def _state_json_default(obj):
    """``json.dumps`` fallback: render LangChain Documents as plain dicts.

    Documents in retrieved_docs/relevant_docs are not JSON-serializable, so without
    this hook json.dumps raises TypeError whenever the retrieval branch ran.
    Any other unsupported type still raises TypeError, as json.dumps normally does.
    """
    if isinstance(obj, Document):
        return {"page_content": obj.page_content, "metadata": obj.metadata}
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


print("Test Case 1 Output:\n", json.dumps(test_case_1, indent=2, default=_state_json_default))

In [None]:
# Test case 2: run the full graph on a second question and print the final state.
_test_case_2_input = dict(
    question="Who is the managing the infosec team",
    need_retrieval=False,
    retrieved_docs=[],
    answer="",
)
test_case_2 = app.invoke(_test_case_2_input)

print("Test Case 2 Output:\n", test_case_2)

In [None]:
# "relevant_docs" is only written on the retrieval branch and is not part of the
# initial state, so it can be missing when the router chose direct_answer_node.
# Treat a missing key as "no relevant documents" instead of raising KeyError.
print(len(test_case_2.get("relevant_docs", [])))