In [None]:
from langchain_core.tools import Tool, tool
from langchain_groq import ChatGroq
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import MessagesState
from dotenv import load_dotenv
from typing import Dict, Any
# Load variables from a local .env file into the process environment so the
# os.getenv(...) lookups further down can find them.
load_dotenv()
llm = ChatGroq(
    model="openai/gpt-oss-120b",
)

from src.Tools.nlp_gen import nlp_chain
from src.Tools.fetch_db import fetch_db
from src.Tools.db_connector import DBConnector
from src.Tools.execute_sql import execute_sql
from src.Tools.summary import SummaryGenerator
import os

# Shared service objects used by every tool defined below. Connection settings
# are read from the environment (populated from .env by load_dotenv() above).
nlp_generator = nlp_chain()
db_connector = DBConnector(
    os.getenv("POSTGRES_HOST"),
    # Fall back to PostgreSQL's default port: int(None) would otherwise fail with
    # an unhelpful TypeError when POSTGRES_PORT is missing from the environment.
    int(os.getenv("POSTGRES_PORT", "5432")),
    os.getenv("POSTGRES_DB_NAME"),
    os.getenv("POSTGRES_USERNAME"),
    os.getenv("POSTGRES_PASSWORD"),
)
# NOTE: these two lines rebind the imported `fetch_db` / `execute_sql` names to
# the objects they return. The tool functions below refer to these names, so
# keep them unchanged (re-running this whole cell is safe because the imports
# above re-bind the original names first).
fetch_db = fetch_db(db_connector)
execute_sql = execute_sql(db_connector)
summary_generator = SummaryGenerator()


@tool
def fetch_db_schema() -> str:
    """Fetches the database schema and returns it as a string."""
    # The docstring above is used by @tool as the LLM-facing tool description,
    # so maintainer notes are kept in comments instead.
    # Raises ValueError (wrapping the underlying error) if anything fails.
    try:
        print("Fetching the database schema...")
        # A None connection string is treated as "no usable database connection".
        connection_string = db_connector.get_connection_string()
        if connection_string is None:
            raise ValueError("Database connection failed.")
        print("Database connected successfully.")
        db_schema = fetch_db.get_db_schema()
        print("Database schema fetched successfully.")
        return str(db_schema)
    except Exception as e:
        # Explicit chaining marks the original error as the direct cause.
        raise ValueError(f"Error occurred with exception : {e}") from e


@tool
def generate_sql(question: str) -> str:
    """Generates an SQL query based on the user's question. This tool fetches the database schema internally."""
    # The docstring above is the LLM-facing tool description (via @tool).
    # The schema is re-fetched on every call so the SQL chain always sees the
    # current database structure. Raises ValueError wrapping any failure.
    try:
        print("Fetching schema for SQL generation...")
        db_schema = fetch_db.get_db_schema()
        print("Generating SQL query for Read operation...")
        sql_chain = nlp_generator.get_sql_chain()
        generated_sql = sql_chain.invoke({"question": question, "db_schema": db_schema})
        print(f"Generated SQL: {generated_sql}")
        return str(generated_sql)
    except Exception as e:
        # Explicit chaining marks the original error as the direct cause.
        raise ValueError(f"Error occurred with exception : {e}") from e
        
        
@tool
def execute_sql_query(query: str) -> str:
    """Executes the generated SQL query and returns the query results."""
    # The docstring above is the LLM-facing tool description (via @tool).
    # SECURITY NOTE(review): `query` comes from the LLM and is passed straight to
    # execute_sql.execute_query with no filtering here. The demo cells below ask
    # the agent to update and delete rows, so writes are reachable. Run the agent
    # with a least-privilege database role, or add an allow-list / confirmation
    # step before data-modifying statements.
    try:
        result = execute_sql.execute_query(query)
        print(f"Query Result: {result}")
        return str(result)
    except Exception as e:
        # Explicit chaining marks the original error as the direct cause.
        raise ValueError(f"Error occurred with exception : {e}") from e
        

@tool
def get_summary(question_and_result: str) -> str:
    """Generates a summary of SQL query results. Input should be formatted as 'QUESTION: <question> RESULT: <result>'"""
    # The docstring above is the LLM-facing tool description (via @tool); keep it
    # in sync with the parsing below.
    # A malformed input returns an error string (so the model can retry) rather
    # than raising. Other failures are raised as ValueError.
    try:
        print("Generating summary of the SQL query result...")
        # Split on the FIRST "RESULT:" only. Query results are free text and may
        # themselves contain "RESULT:", which must not invalidate the input.
        question_part, separator, result = question_and_result.partition("RESULT:")
        if not separator:
            return "Error: Please format input as 'QUESTION: <question> RESULT: <result>'"

        # Remove only the leading "QUESTION:" tag, not occurrences inside the text.
        question = question_part.strip()
        if question.startswith("QUESTION:"):
            question = question[len("QUESTION:"):].strip()
        result = result.strip()

        summary = summary_generator.generate_summary(question, result)
        return str(summary)
    except Exception as e:
        # Explicit chaining marks the original error as the direct cause.
        raise ValueError(f"Error occurred with exception : {e}") from e

           
        
# Tools offered to the model (via llm.bind_tools) and executed by the graph's
# ToolNode; both consume this same list.
toolbox = [
    fetch_db_schema,
    generate_sql,
    execute_sql_query,
    get_summary,
]

In [None]:
# Toolbox is now defined in the cell above
# toolbox = [fetch_db_schema, generate_sql, execute_sql_query, get_summary]

In [None]:
# Attach the tool schemas so the model can emit tool calls for them.
llm_with_tools = llm.bind_tools(tools=toolbox)

In [None]:
# System prompt sent ahead of the conversation on every assistant() call.
# Schema questions are routed to fetch_db_schema; data questions follow the
# generate -> execute -> summarize pipeline.
assistant_system_message = """
You are an expert SQL agent designed to assist users with database-related queries.
Your primary functions include:
1. Understanding user questions and generating appropriate SQL queries.
2. Executing SQL queries against a PostgreSQL database.
3. Summarizing the results of SQL queries in a user-friendly manner.

You will utilize a set of specialized tools to accomplish these tasks:
- fetch_db_schema: Fetches the database schema. No input required.
- generate_sql: Generates an SQL query based on the user's question. Requires only the 'question' as input.
- execute_sql_query: Executes the generated SQL query. Requires the 'query' as input.
- get_summary: Generates a summary of SQL query results. Requires input formatted as 'QUESTION: <question> RESULT: <result>'.

Your responses should be concise and focused on the task at hand.

Guidelines:
- For questions about the data itself, follow this workflow: generate_sql -> execute_sql_query -> get_summary
- For questions about the database structure (tables, columns, relationships), call fetch_db_schema and answer from the schema; no SQL needs to be generated or executed
- You do not need to call fetch_db_schema before generate_sql; the generate_sql tool fetches the schema internally
- Before passing a query to execute_sql_query, check that it is syntactically correct and matches the user's intent
- When summarizing results, focus on the key insights and avoid unnecessary technical jargon
- For get_summary, format the input as 'QUESTION: <original_question> RESULT: <query_result>'
"""

In [None]:
def assistant(state: MessagesState):
    """LangGraph node: run the tool-bound LLM over the conversation so far.

    Args:
        state: Graph state; ``state["messages"]`` holds the conversation history.

    Returns:
        A partial state update ``{"messages": [ai_message]}``, which
        MessagesState's reducer merges into the history.
    """
    # Wrap the prompt in a SystemMessage. A bare str inside a message list is
    # coerced to a HumanMessage, so the instructions would arrive as a user turn.
    # The system message is not stored in state, so it is never duplicated.
    prompt = [SystemMessage(content=assistant_system_message)] + state["messages"]
    return {"messages": [llm_with_tools.invoke(prompt)]}


In [None]:
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from IPython.display import Image, display
from langgraph.checkpoint.memory import MemorySaver

# Graph: a ReAct loop. The "assistant" node calls the LLM; the "tools" node runs
# any tool calls it requested and hands the results back to the assistant.
builder = StateGraph(MessagesState)

# Define nodes: these do the work
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(toolbox))

# Define edges: these determine how the control flow moves
builder.add_edge(START, "assistant")
builder.add_conditional_edges(
    "assistant",
    # If the latest message (result) from assistant is a tool call -> tools_condition routes to "tools"
    # If the latest message (result) from assistant is not a tool call -> tools_condition routes to END
    tools_condition,
)
builder.add_edge("tools", "assistant")

# In-memory checkpointer: conversation state is kept per thread_id, only for the
# lifetime of this kernel.
memory = MemorySaver()
react_graph = builder.compile(checkpointer=memory)

# Show
# draw_mermaid_png() renders through the remote mermaid.ink service by default
# and raises when offline. Fall back to the Mermaid source text so Run All still
# completes.
graph_view = react_graph.get_graph(xray=True)
try:
    display(Image(graph_view.draw_mermaid_png()))
except Exception as exc:
    print(f"Could not render graph image ({exc}); Mermaid source follows:")
    print(graph_view.draw_mermaid())

In [None]:
def sql_advisor(user_request: str, thread_id: str = "1", verbose: bool = False) -> None:
    """Send one request to the SQL agent graph and pretty-print the reply.

    Args:
        user_request: Natural-language request, sent as a HumanMessage.
        thread_id: Checkpointer conversation id; calls with the same id share
            history (the default "1" is shared by every call that omits it).
        verbose: If True, print every message in the thread's state after the
            run (including earlier turns), not just the final one.
    """
    run_config = {"configurable": {"thread_id": thread_id}}
    new_input = {"messages": [HumanMessage(content=user_request)]}
    final_state = react_graph.invoke(new_input, run_config)
    conversation = final_state["messages"]
    if not verbose:
        conversation[-1].pretty_print()
        return
    for msg in conversation:
        msg.pretty_print()

In [None]:
# Schema-level question (thread "1"); expected to be answered via fetch_db_schema.
sql_advisor(user_request="provide me with a breif about my database schema")

In [None]:
# Read request. The table-name misspelling is left as typed; presumably it
# exercises the model's tolerance to typos. TODO confirm.
sql_advisor(user_request="show me the first 5 rows in the peopel table")

In [None]:
# WARNING: data-modifying request. If the agent succeeds, the connected
# database is changed; re-running this cell is not a no-op.
sql_advisor(user_request="update Alice age to 36")

In [None]:
# WARNING: destructive request. If the agent succeeds, rows are deleted from
# the connected database.
sql_advisor(user_request="delete Bob from the people table")

In [None]:
# New conversation (thread "2"): the history from thread "1" is not available,
# so the agent has no record of Bob's deleted row here.
sql_advisor("add bob back", thread_id="2")

In [None]:
# Follow-up in thread "2" that supplies the details for the previous request.
sql_advisor("bob in people age is 20 and email bob@exp.com", thread_id="2")