<a href="https://colab.research.google.com/github/Nischal2015/ncit-workshop/blob/main/3_rag/2_agentic_rag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Lab 2
### Building an "Agentic" RAG System

### Run this if you are using Google Colab

In [None]:
# !pip install langchain langchain-openai langchain-qdrant

In [None]:
# import os
# from google.colab import userdata

# os.environ["OPENAI_API_KEY"] = userdata.get("OPENAI_API_KEY")
# os.environ["LANGSMITH_API_KEY"] = userdata.get("LANGSMITH_API_KEY")
# os.environ["LANGSMITH_PROJECT"] = "ncit-workshop"
# os.environ["LANGSMITH_TRACING"] = "true"
# os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
# os.environ["QDRANT_API_KEY"] = userdata.get("QDRANT_API_KEY")
# os.environ["QDRANT_URL"] = "qdrant-host"

### Run this if you are using VS Code

In [None]:
# Make the parent of the current working directory importable so the shared
# `core` helper module can be found.
import sys
from pathlib import Path

sys.path.append(str(Path().resolve().parent))
from core import load_vault_env

# NOTE(review): presumably populates os.environ with the API keys used below
# (OpenAI, Qdrant, Tavily, LangSmith) — confirm in core.load_vault_env.
load_vault_env()

### Imports

In [1]:
import os
from typing import Literal, TypedDict

from pydantic import BaseModel, Field

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.messages import SystemMessage, ToolMessage
from langchain_qdrant import QdrantVectorStore
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain.agents import create_agent
from langchain.tools import tool

from langgraph.types import Command
from langgraph.graph import StateGraph, START

from tavily import TavilyClient

from qdrant_client.models import models

## RAG Agent

### Initialization

#### Credentials

In [None]:
# Qdrant connection settings read from the environment; each is None if unset.
QDRANT_URL, QDRANT_API_KEY = os.getenv("QDRANT_URL"), os.getenv("QDRANT_API_KEY")

#### Initialize clients

In [None]:
# Chat model client; a low temperature keeps answers mostly deterministic.
llm = ChatOpenAI(model="gpt-4.1-mini", temperature=0.2)

# Web-search client. Indexing os.environ (not getenv) fails fast with KeyError
# when TAVILY_API_KEY is missing.
tavily_client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])

# Query embeddings. NOTE(review): must match the model the collection was
# indexed with — confirm against the ingestion notebook.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

# Attach to an already-populated Qdrant collection (nothing is indexed here).
vector_store = QdrantVectorStore.from_existing_collection(
    collection_name="ncit-workshop-simple-rag",
    embedding=embeddings,
    url=QDRANT_URL,
    api_key=QDRANT_API_KEY,
    prefer_grpc=True,
)

### Retrieval Tool

In [None]:
# Defining tool using LangChain's tool decorator
@tool(
    name_or_callable="retrieve_relevant_docs",
    description="Retrieve relevant policy documents based on a question and optional filter.",
)
def retrieve_relevant_docs(
    question: str,
    filter: Literal[
        "finance",
        "it_policy",
        "hr_policy",
        "legal_policy",
        "operations_policy",
        "engineering_policy",
    ]
    | None = None,
    k: int = 3,
):
    """
    Retrieve relevant documents from the vector store based on the question and optional filter.

    Args:
        question: Natural-language query searched against the collection.
        filter: Optional category; when given, only points whose
            ``metadata.category`` equals it are searched.
        k: Maximum number of hits requested from the vector store.

    Returns:
        Hits scoring at least 0.5, each formatted as policy ID / topic / rule
        and separated by blank lines, or ``"NO RELEVANT DOCUMENT FOUND."`` when
        no hit passes the threshold.
    """

    print(f"\n[CHAIN LOG] Searching for: '{question}' in '{filter or 'ALL'}'")

    q_filter = None
    if filter:
        # Restrict the search to one category via a Qdrant payload condition.
        q_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.category", match=models.MatchValue(value=filter)
                )
            ]
        )

    # Perform search with scores
    results = vector_store.similarity_search_with_score(
        query=question, k=k, filter=q_filter
    )

    # Keep only reasonably close matches. This assumes a similarity metric where
    # higher is better (e.g. cosine) — TODO confirm the collection's distance.
    # Missing metadata renders as N/A instead of raising KeyError.
    valid_context = [
        f"Policy ID: {doc.metadata.get('policy_id', 'N/A')}\n"
        f"Topic: {doc.metadata.get('topic', 'N/A')}\n"
        f"Rule: {doc.page_content}"
        for doc, score in results
        if score >= 0.5
    ]

    if not valid_context:
        return "NO RELEVANT DOCUMENT FOUND."

    return "\n\n".join(valid_context)

### RAG

In [None]:
# Policy Q&A agent: answers only from what retrieve_relevant_docs returns.
rag_agent = create_agent(
    model="gpt-4.1-mini",
    tools=[retrieve_relevant_docs],
    system_prompt=SystemMessage(
        content=[
            {
                "type": "text",
                # Adjacent literals are concatenated verbatim, so each sentence
                # ends with its own space (they previously ran together).
                "text": (
                    "You are a strictly factual HR Policy Bot. "
                    "Answer the question based ONLY on the context returned by "
                    "the retrieve_relevant_docs tool. "
                    "Cite the Policy ID and topic for every fact you state."
                ),
            }
        ]
    ),
)

In [None]:
# Also provide the filter in the question text so the agent narrows the search.
question = "What is the hotel spending limit for major metro areas like NYC? (filter: 'finance')"
# Each streamed step is the agent state; print only its newest message.
for step in rag_agent.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()

## SQL Agent

In [None]:
# Open the SQLite file data.db (relative to the working directory).
db = SQLDatabase.from_uri("sqlite:///data.db")

# Sanity checks: dialect, tables visible to the agent, and the rows of `trips`.
print(f"Dialect: {db.dialect}")
print(f"Available Tables: {db.get_usable_table_names()}")
print(f"Sample output: {db.run('SELECT * from trips')}")

#### Initialize LLM model

In [None]:
# `model` selects the OpenAI model. `name` is only the runnable's display label,
# so passing the model id there left the client on the library's default model.
llm = ChatOpenAI(model="gpt-4.1-mini", temperature=0)

#### Tools available in SQLDatabaseToolkit

In [None]:
# Wrap the database (plus `llm`) in LangChain's SQL toolkit, which exposes
# database operations as agent tools.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

tools = toolkit.get_tools()

# Show what the agent will be able to call.
for toolkit_tool in tools:
    print(f"{toolkit_tool.name}: {toolkit_tool.description}\n")

### Create Agent

#### Prompt

In [None]:
# Generic SQL-agent instructions. {dialect} and {top_k} are filled in by
# .format() so the prompt matches the connected database.
system_prompt = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the columns that are relevant to the question.

You MUST double check your query before executing it. If you get an error while
executing a query, try to fix the query and execute it again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema for the most relevant tables.
""".format(
    dialect=db.dialect,
    top_k=5,
)

#### Langchain Agent

In [None]:
# SQL agent: the toolkit's database tools driven by `llm` under the SQL prompt.
# (`create_agent` comes from the notebook's top-level imports cell.)
agent = create_agent(
    model=llm,
    tools=tools,
    system_prompt=system_prompt,
)

#### Run the Agent

In [None]:
question = "Which trip had the most expense?"

# Stream the agent state after each step and print the newest message, so the
# tool calls (list tables, schema, query) are visible as they happen.
for step in agent.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()

## LangGraph RAG

#### Create a retriever tool

In [None]:
# Rebind the module-level `vector_store` to the `apex_policies` collection.
# Functions that read `vector_store` at call time (the retriever tool below)
# now search this collection instead of the workshop one.
vector_store = QdrantVectorStore.from_existing_collection(
    collection_name="apex_policies",
    embedding=embeddings,
    url=QDRANT_URL,
    api_key=QDRANT_API_KEY,
    prefer_grpc=True,
)

In [None]:
# Defining tool using LangChain's tool decorator
@tool(
    name_or_callable="retrieve_relevant_docs",
    description="Retrieve relevant policy documents based on a question and optional filter.",
)
def retriever_tool(
    question: str,
    filter: Literal["compliance_rule", "it_policy", "per_diem"] | None = None,
    k: int = 3,
):
    """
    Retrieve relevant documents from the vector store based on the question and optional filter.

    Args:
        question: Natural-language query searched against the collection.
        filter: Optional category; when given, only points whose
            ``metadata.category`` equals it are searched.
        k: Maximum number of hits returned.

    Returns:
        Every hit (no score threshold) formatted as topic / rule and separated
        by blank lines, or ``"NO RELEVANT DOCUMENT FOUND."`` when the search
        returns nothing.
    """

    print(f"\n[CHAIN LOG] Searching for: '{question}' in '{filter or 'ALL'}'")

    q_filter = None
    if filter:
        # Restrict the search to one category via a Qdrant payload condition.
        q_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.category", match=models.MatchValue(value=filter)
                )
            ]
        )

    # Perform search with scores (reads the module-level `vector_store` at call time)
    results = vector_store.similarity_search_with_score(
        query=question, k=k, filter=q_filter
    )

    valid_context = [
        f"Topic: {doc.metadata.get('topic', 'N/A')}\nRule: {doc.page_content}"
        for doc, _ in results
    ]

    # An explicit marker (rather than "") lets the agent tell "no hits" apart
    # from a silent or empty tool result.
    if not valid_context:
        return "NO RELEVANT DOCUMENT FOUND."

    return "\n\n".join(valid_context)

In [None]:
class AuditState(TypedDict):
    # Inputs
    query: str

    # Internal Reasoning State
    required_info_types: list[str]  # ["sql", "policy", "web"]

    # Data Buckets
    sql_results: str  # Raw CSV/Text from SQL
    policy_context: str  # Retrieved chunks
    web_regulations: str  # External laws

    # Analysis
    final_report: str

    # Error Handling for SQL Self-Correction
    sql_retry_count: int
    sql_error: str

#### Orchestrator Node

In [None]:
# Structured output of the planner: one flag per evidence source plus a rationale.
# NOTE: deliberately no class docstring — pydantic copies it into the JSON
# schema description, which is part of the structured-output request.
class Plan(BaseModel):
    needs_database: bool = Field(
        ...,
        description="Set True if we need internal transaction/employee data.",
    )
    needs_policy: bool = Field(
        ...,
        description="Set True if we need internal company rules.",
    )
    needs_external_law: bool = Field(
        ...,
        description="Set True if we need public laws/tax info.",
    )
    # Free-text justification for the flags above.
    reasoning: str = Field(...)


# Tool-less agent whose only job is to fill in a Plan for the audit request.
planner_agent = create_agent(
    model="gpt-4.1-mini",
    response_format=Plan,
    system_prompt=(
        "You are an expert auditor assistant. "
        "Given an audit request, determine what information is needed to complete the audit."
    ),
)


def orchestrator_node(
    state: AuditState,
) -> Command[Literal["sql_node", "policy_node", "web_node", "compliance_analyst"]]:
    """Plan which evidence sources the audit needs and route to the first one.

    All of the planner's flags are recorded in ``required_info_types``
    (``"sql"``, ``"policy"``, ``"web"``). Routing goes to the first needed
    source in the order SQL -> policy -> web. If no source is needed, it goes
    straight to the analyst.
    """
    print(f"\n--- [ORCHESTRATOR] Analyzing Audit Request: '{state['query']}' ---")

    plan = planner_agent.invoke(
        {
            "messages": [
                {
                    "role": "user",
                    "content": f"Analyze this audit request: {state['query']}",
                }
            ]
        }
    )
    response: Plan = plan["structured_response"]

    print(
        f"   -> Plan: DB={response.needs_database}, Policy={response.needs_policy}, Web={response.needs_external_law}"
    )

    # In a real async system these branches would fan out in parallel via
    # langgraph's Send() API; this workshop demo chains them sequentially.
    required_info_types = [
        info_type
        for info_type, needed in (
            ("sql", response.needs_database),
            ("policy", response.needs_policy),
            ("web", response.needs_external_law),
        )
        if needed
    ]

    # Entry point = first required source. A plan that needs only external law
    # must still visit web_node rather than jumping straight to the analyst.
    if response.needs_database:
        goto = "sql_node"
    elif response.needs_policy:
        goto = "policy_node"
    elif response.needs_external_law:
        goto = "web_node"
    else:
        goto = "compliance_analyst"

    return Command(update={"required_info_types": required_info_types}, goto=goto)

#### SQL Node

In [None]:
# One expense record as the SQL agent could report it.
# NOTE(review): these models are only referenced by the commented-out
# response_format of sql_agent, so they are unused at runtime for now.
class ExpenseData(BaseModel):
    city: str = Field(...)
    amount: float = Field(...)
    currency: str = Field(...)
    category: str = Field(...)


# Envelope holding a list of ExpenseData records.
class SQLAgentResponse(BaseModel):
    expense_data: list[ExpenseData] = Field(
        ...,
        description="List of expense records with city, amount, currency, and category.",
    )


# Audit SQL-agent prompt: the generic SQL prompt above plus schema-adherence
# rules that stop the model from inventing tables (e.g. a per-diem table) for
# data this database does not contain.
sql_agent_prompt = (
    system_prompt
    + """
    ### KEY RULES FOR SCHEMA ADHERENCE:
    1. **ONLY use the tables listed in `sql_db_list_tables`.**
       Do NOT assume common business tables exist (e.g., do not assume a 'per_diem', 'users', or 'exchange_rates' table exists unless you see it).
    2. **If the user asks for information not present in the tables:**
       - Do NOT try to invent a table to find it.
       - Do NOT try to use external knowledge to fill in the gaps.
       - Simply state: "I cannot answer that part of the question because the table [X] does not exist."
    3. **Check for existence before querying:**
       - You have `expenses` and `trips`. You do NOT have `city_per_diem`.
       - Never query `city_per_diem` or any other unlisted table.
    4. **Error Handling:**
       - If you receive a "no such table" error, STOP immediately. Do not try to rewrite the query using a similar name.
       - Report to the user that the data is missing.
    """
)

# SQL sub-agent used by sql_node: same toolkit tools, stricter audit prompt.
sql_agent = create_agent(
    model="gpt-5-mini",
    tools=tools,
    system_prompt=sql_agent_prompt,
    # NOTE(review): re-enabling this requires importing `ToolStrategy`
    # (langchain.agents.structured_output); it is not imported in this notebook.
    # response_format=ToolStrategy(SQLAgentResponse),
)


def sql_node(
    state: AuditState,
) -> Command[Literal["sql_node", "policy_node", "web_node", "compliance_analyst"]]:
    """Answer the audit query from the database, then hand off to policy_node.

    Runs ``sql_agent`` on the raw audit query and stores the content of its
    final message under ``sql_results``, the state key the analyst reads.
    """
    print("  [SQL AGENT] Querying database...")

    response = sql_agent.invoke(
        {"messages": [{"role": "user", "content": state["query"]}]}
    )
    # Content of the last message of the agent run, presumably its final answer.
    expense_data = response["messages"][-1].content
    print(f"  -> Retrieved {len(expense_data) if expense_data else 0} chars of data")

    return Command(
        # Key must be `sql_results` (the AuditState field); a misspelled key
        # keeps the SQL evidence from ever reaching compliance_analyst.
        update={"sql_results": expense_data, "sql_error": None},
        goto="policy_node",
    )

#### Policy Node

In [None]:
# Structured output of policy_agent: the policy topic it chose to search.
# Distinct from qdrant's `models.Filter`. No docstring on purpose: pydantic
# would add it to the schema sent to the model.
class Filter(BaseModel):
    topic: str


# Policy sub-agent: may call the retriever tool, then returns a Filter naming
# the topic it searched. policy_node reads both the tool output and the Filter.
policy_agent = create_agent(
    model="gpt-4.1-mini",
    tools=[retriever_tool],
    system_prompt="""You are an expert policy classification agent. Given an audit query, determine the most relevant policy topic to search for.

    You are provided with the following tools:
    1. retrieve_relevant_docs: A tool to retrieve relevant policy documents based on a question and optional filter.
    """,
    response_format=Filter,
)


def policy_node(state: AuditState) -> Command[Literal["web_node"]]:
    """Retrieve corporate policy text relevant to the audit query.

    ``policy_agent`` picks a topic and calls the retriever tool, possibly
    several times. The outputs of every retriever call are joined into
    ``policy_context``. If the tool was never called (or returned nothing),
    ``"NO RELEVANT DOCUMENT FOUND."`` is stored instead of ``None``.
    """
    print("--- [POLICY AGENT] Retrieving Corporate Rules ---")

    result = policy_agent.invoke(
        {"messages": [{"role": "user", "content": state["query"]}]}
    )
    topic_filter: Filter = result["structured_response"]

    print(f"  -> Filter Topic: {topic_filter.topic}")

    # Keep every retriever result, not just the last ToolMessage. Match on the
    # tool name so that other ToolMessages (e.g. one carrying a tool-based
    # structured response) are not mistaken for policy text.
    retrieved_chunks = [
        str(message.content)
        for message in result.get("messages", [])
        if isinstance(message, ToolMessage)
        and message.name == retriever_tool.name
        and message.content
    ]
    policy_context = (
        "\n\n".join(retrieved_chunks)
        if retrieved_chunks
        else "NO RELEVANT DOCUMENT FOUND."
    )

    return Command(
        update={"policy_context": policy_context},
        goto="web_node",
    )

#### Web Node

In [None]:
def web_node(state: AuditState) -> Command[Literal["compliance_analyst"]]:
    """Search the web for external (IRS) rules relevant to the audit query.

    Best-effort: if the search call fails, the error is logged and
    ``"Search failed."`` is stored, so the analyst still runs.
    """
    print("--- [WEB AGENT] Checking IRS/External Regulations ---")

    # Bias the search toward tax/deduction limits for the audited expenses.
    query = f"IRS expense deduction limits 2024 for {state['query']}"

    try:
        results = tavily_client.search(query=query, max_results=2)
    except Exception as exc:  # boundary: network/API errors must not abort the audit
        print(f"  -> Web search failed: {exc!r}")
        context = "Search failed."
    else:
        hits = results.get("results", [])
        context = "\n".join(hit.get("content", "") for hit in hits)
        print(f"  -> Found {len(hits)} external result(s).")

    return Command(update={"web_regulations": context}, goto="compliance_analyst")

#### Compliance Analyst

In [None]:
def compliance_analyst(state: AuditState) -> Command[Literal["__end__"]]:
    """Synthesize the gathered evidence into the final audit report.

    Evidence that is missing, ``None`` or empty is rendered as ``"No data"``
    so the LLM never sees a literal ``None``. Uses the module-level ``llm``.
    """
    print("--- [ANALYST] Synthesizing Audit Report ---")

    # `or` (rather than a .get default) also covers keys that exist but hold
    # None or "" — e.g. a node that stored an empty result.
    sql_evidence = state.get("sql_results") or "No data"
    policy_evidence = state.get("policy_context") or "No data"
    web_evidence = state.get("web_regulations") or "No data"

    prompt = f"""
    You are the Chief Compliance Officer.
    
    AUDIT QUERY: {state["query"]}
    
    EVIDENCE GATHERED:
    1. ERP TRANSACTIONS (SQL):
    {sql_evidence}
    
    2. INTERNAL POLICY (VECTOR DB):
    {policy_evidence}
    
    3. EXTERNAL REGULATION (WEB):
    {web_evidence}
    
    INSTRUCTIONS:
    - Review every transaction listed in the SQL results.
    - Cross-reference with Policy and External Regulation.
    - Flag VIOLATIONS (Red) and WARNINGS (Yellow).
    - If a transaction is valid, mark it PASS (Green).
    - Provide a short executive summary.
    """

    report = llm.invoke(prompt).content

    return Command(
        update={"final_report": report},
        goto="__end__",
    )

#### LangGraph Workflow

In [None]:
workflow = StateGraph(AuditState)

# Nodes. Routing between them comes from the Command(goto=...) each node
# returns, so only the entry edge is declared explicitly.
workflow.add_node("orchestrator", orchestrator_node)
workflow.add_node("sql_node", sql_node)
workflow.add_node("policy_node", policy_node)
workflow.add_node("web_node", web_node)
workflow.add_node("compliance_analyst", compliance_analyst)

# Edge (Start)
workflow.add_edge(START, "orchestrator")

# Compile
app = workflow.compile()

# --- RUN THE AUDIT ---
audit_query = "Audit Alice's trip expenses. Flag any meal violations based on the city's per diem limits."

print(f"🚀 STARTING ENTERPRISE AUDIT: {audit_query}\n")

# sql_retry_count is seeded for the SQL self-correction fields; no node in this
# notebook increments it yet.
final_state = app.invoke({"query": audit_query, "sql_retry_count": 0})

In [None]:
# Separator line, then the analyst's report from the finished graph run.
print("#" * 60)
print(final_state["final_report"])

In [None]:
# Render the compiled graph as a Mermaid PNG inside the notebook.
from IPython.display import Image, display

display(Image(app.get_graph().draw_mermaid_png()))

In [None]:
# Show the full final graph state (all evidence buckets plus the report).
final_state