In [None]:
"please have a look at the readme file on how to execute the below code and in this notebook run every cell"

In [16]:
import os
import re
from typing import Annotated, TypedDict, List, Union

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage, SystemMessage
from langchain_core.tools import tool
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore
from langgraph.graph import StateGraph, START, END, add_messages
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from sqlalchemy import create_engine, text

In [17]:

# File-backed SQLite database that the SQL tool further down queries.
db_engine = create_engine("sqlite:///employee_db.sqlite")

# Rebuild the demo table on every run so re-executing this cell always
# leaves the same single seed row behind. The statements run in order on
# one connection and are committed together at the end.
with db_engine.connect() as conn:
    for statement in (
        "DROP TABLE IF EXISTS employees",
        """
        CREATE TABLE employees (
            id INTEGER PRIMARY KEY, 
            name TEXT UNIQUE, 
            designation TEXT, 
            dept TEXT, 
            salary REAL
        )
    """,
        "INSERT INTO employees (name, designation, dept, salary) VALUES ('Hrithik', 'AI Engineer', 'AI', 120000)",
    ):
        conn.execute(text(statement))
    conn.commit()

In [18]:
# One natural-language description per column of the employees table; these
# are the documents the retriever searches to give the LLM schema context.
schema_descriptions = [
    "name: The full name of the employee which is unique",
    "designation: The job title or role",
    "dept: The department name (AI, HR, Sales)",
    "salary: The numeric value representing annual pay",
    "id: The unique primary key"
]

# Sentence-embedding model used for both indexing and querying.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# In-memory Qdrant instance: the collection is recreated every time this cell runs.
collection = "employee_schema"
client = QdrantClient(":memory:")
client.create_collection(
    collection_name=collection,
    # The vector size must equal the embedding model's output dimension.
    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
)

vectorstore = QdrantVectorStore(client=client, collection_name=collection, embedding=embeddings)

# Left as the final bare expression so the cell displays the generated point ids.
vectorstore.add_texts(schema_descriptions)

['337c8d01b505449abf4c3ca16bc7ce65',
 '57d84d394ad54791bed1a790093fb232',
 'bc184c17fdb742d1b9bf1bcb48eece97',
 '5976957e88684f47b9fa9a670a8034aa',
 'f35bb482d153482c88ff66e4076a2ebc']

In [20]:
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the SQL-agent graph."""
    # Conversation history. The `add_messages` annotation is LangGraph's
    # reducer for this key: messages returned by a node are merged into the
    # existing list instead of replacing it.
    messages: Annotated[List[BaseMessage], add_messages]
    # Guardrail verdict written by classifier_node and read by
    # retriever_node, agent_node and router to decide whether to continue.
    is_valid: bool  

In [21]:
# Whole-word, case-insensitive match for the blocked DDL keywords. Matching on
# word boundaries (rather than plain substrings) keeps ordinary values such as
# the name 'Walter' (contains "alter") from being rejected as schema changes.
_FORBIDDEN_SQL = re.compile(r"\b(?:drop|alter|truncate)\b", re.IGNORECASE)


@tool
def execute_sql(sql_query: str) -> str:
    """
    Executes a SQLite query against the employee database.
    SELECT, INSERT, UPDATE, or DELETE.
    """
    # NOTE: the docstring above doubles as the tool description handed to the
    # LLM by @tool, so implementation notes live in comments instead.
    #
    # Returns one of four strings, never raises:
    #   "Security Error: ..."          - the statement contains a blocked keyword
    #   "Data: [...]"                  - rows of a statement that returns rows
    #   "Success: N row(s) affected."  - any other statement
    #   "Database Error: ..."          - the database rejected the statement
    if _FORBIDDEN_SQL.search(sql_query):
        return "Security Error: Schema changes are not allowed."

    try:
        with db_engine.connect() as conn:
            result = conn.execute(text(sql_query))
            # Ask the cursor whether it produced rows instead of sniffing the
            # SQL text: this also covers CTEs (WITH ... SELECT), PRAGMA and
            # ... RETURNING statements, which do not start with "select".
            if result.returns_rows:
                # Read the rows before committing so the cursor is still open.
                rows = result.fetchall()
                conn.commit()
                return f"Data: {str(rows)}"
            affected = result.rowcount
            conn.commit()
            return f"Success: {affected} row(s) affected."
    except Exception as e:
        # Hand the error text back to the LLM rather than crashing the graph.
        return f"Database Error: {str(e)}"

In [None]:
# Groq-hosted chat model, sampled at temperature 0, shared by the graph nodes.
llm = ChatGroq(
    model_name="openai/gpt-oss-20b",
    temperature=0,
)

# Same model with the SQL tool bound so its replies may contain tool calls.
llm_with_tools = llm.bind_tools([execute_sql])

def classifier_node(state: AgentState):
    """Guardrail: Ensures the query is actually about the employee database.

    Asks the plain LLM (no tools) for a Yes/No verdict on the most recent
    message in ``state["messages"]``.

    Returns:
        ``{"is_valid": True}`` when the verdict is "yes"; otherwise
        ``{"is_valid": False}`` together with a refusal ``AIMessage``.
    """
    last_user_msg = state["messages"][-1].content
    prompt = f"Is the following question strictly related to managing or querying employee data? Question: '{last_user_msg}'. Answer ONLY 'Yes' if the Question is completely related to managing Database otherwise even there is slightly unrelated content then Answer 'No'."
    res = llm.invoke(prompt).content.strip().lower()

    # Judge only the first word of the reply, ignoring surrounding quotes or
    # punctuation. A substring test would also accept replies such as
    # "No - yesterday's weather is unrelated", letting off-topic input through.
    words = res.split()
    verdict = words[0].strip("'\"`*.,:;!") if words else ""

    if verdict == "yes":
        return {"is_valid": True}
    return {"is_valid": False, "messages": [AIMessage(content="I'm sorry, I can only assist with employee database operations.")]}

def retriever_node(state: AgentState):
    """Retrieves schema context from Qdrant to guide the LLM.

    Looks up the four column descriptions closest to the latest message and
    appends them to the conversation as a ``SystemMessage``.

    Returns:
        A state update whose ``"messages"`` list holds either the schema
        ``SystemMessage`` or, when the guardrail rejected the request, a
        refusal ``AIMessage``.
    """
    if not state.get("is_valid"):
        return {"messages": [AIMessage(content="I am SQL agent only answers about the Database related queries")]}

    # Search with the same message the classifier validated (the latest one),
    # so that any earlier history in the state cannot steer the schema lookup.
    query = state["messages"][-1].content
    hits = vectorstore.similarity_search(query, k=4)
    context = " | ".join(hit.page_content for hit in hits)

    sys_msg = SystemMessage(content=f"Available Table: employees. Relevant Columns: {context}")
    return {"messages": [sys_msg]}

def agent_node(state: AgentState):
    """Processes history and decides next tool call.

    Prepends the fixed behaviour prompt to the conversation and invokes the
    tool-enabled LLM.

    Returns:
        A state update containing the model's reply (which may carry tool
        calls), or an update that only re-asserts ``is_valid=False`` when the
        guardrail rejected the request.
    """
    if not state.get("is_valid"):
        # A node has to return a state update, not a routing target such as
        # END; stopping the graph is the job of `router`, which checks
        # is_valid right after this node.
        return {"is_valid": False}

    core_behavior = SystemMessage(content=(
        """You are a Senior SQL Analyst. 
        Rules: 
        1. Be extremely concise. 
        2. If a query fails, don't try any alternative just explain the error. 
        3. Never explain the SQL unless asked. 
        4. Always show the data results as Markdown tables.
        5. if the question has other content that is not related to managing DB then don't respond"""
    ))

    # The behaviour prompt is added per call and is never stored in the state.
    full_messages = [core_behavior] + state["messages"]

    response = llm_with_tools.invoke(full_messages)
    return {"messages": [response]}

def tool_node(state: AgentState):
    """Execute every tool call requested by the latest message.

    Each call's arguments are passed to ``execute_sql`` and the textual
    result is wrapped in a ``ToolMessage`` tied to the originating call id.
    """
    requested_calls = state["messages"][-1].tool_calls
    return {
        "messages": [
            ToolMessage(
                content=str(execute_sql.invoke(tool_call["args"])),
                tool_call_id=tool_call["id"],
            )
            for tool_call in requested_calls
        ]
    }

In [30]:
def router(state: AgentState):
    """Pick the next hop after the agent: the tools node or the end of the graph.

    The run continues to "tools" only when the request passed the guardrail
    and the latest message carries at least one tool call; in every other
    case it ends.
    """
    # Short-circuits: the message list is only inspected for valid requests.
    needs_tools = bool(state.get("is_valid")) and bool(state["messages"][-1].tool_calls)
    return "tools" if needs_tools else END

In [31]:
# Assemble the agent graph over the shared AgentState.
builder = StateGraph(AgentState)

# Register the four processing steps under the names used by the edges below.
builder.add_node("classify", classifier_node)
builder.add_node("retrieve", retriever_node)
builder.add_node("agent", agent_node)
builder.add_node("tools", tool_node)

# Every run starts with the guardrail.
builder.add_edge(START, "classify")


# After the guardrail: continue to schema retrieval only for valid requests,
# otherwise finish immediately (the classifier has already added its refusal).
builder.add_conditional_edges(
    "classify",
    lambda state: "retrieve" if state["is_valid"] else END
)

# retrieve -> agent, then `router` decides between running tools and ending;
# tool results always flow back to the agent, forming the agent/tools loop.
builder.add_edge("retrieve", "agent")
builder.add_conditional_edges("agent", router)
builder.add_edge("tools", "agent")

# Compiled, runnable graph used by the chat helper.
graph = builder.compile()

In [34]:
def chat_with_agent(input):
    """Run one user message through the graph and print the agent's replies.

    Args:
        input: The user's question as plain text.

    Only non-empty ``AIMessage`` replies are printed. Internal traffic
    (the schema ``SystemMessage``, raw ``ToolMessage`` results and empty
    tool-call messages) is not echoed under the "SQL Agent" label.
    """
    inputs = {"messages": [HumanMessage(content=input)]}

    for output in graph.stream(inputs):
        for update in output.values():
            # A node may emit nothing, or an update without a "messages" key
            # (e.g. only is_valid); skip those explicitly instead of hiding
            # every error behind a bare except.
            if not isinstance(update, dict):
                continue
            emitted = update.get("messages") or []
            if not isinstance(emitted, (list, tuple)):
                # Tolerate a node returning a single message instead of a list.
                emitted = [emitted]

            for message in emitted:
                if isinstance(message, AIMessage) and message.content:
                    print("SQL Agent :", message.content)



In [None]:
"""
Run this cell to open an input box for your query; type exit (or quit) in the input box to stop.
"""
# Interactive loop: each line typed at the prompt is sent to the agent.
print("SQL Agent Terminal (Type 'exit' to quit)")
while True:
    try:
        # Strip so that inputs like " exit " are still recognised.
        user_query = input("\nYou: ").strip()
    except (EOFError, KeyboardInterrupt):
        # Interrupting the kernel or closing stdin ends the session cleanly
        # instead of leaving a traceback in the cell output.
        break

    if not user_query:
        # Nothing typed: prompt again without spending LLM calls.
        continue
    if user_query.lower() in ("exit", "quit"):
        break

    chat_with_agent(user_query)