In [17]:
# %%
import logging
import os
import re
import time  # For timing LLM calls
from datetime import datetime
from typing import Annotated, Literal, TypedDict, List, Dict, Any, Optional

from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import AIMessage, ToolMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langchain_core.tools import tool
from langchain_groq import ChatGroq
from langgraph.errors import GraphRecursionError
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import AnyMessage, add_messages
from langgraph.prebuilt import ToolNode
# %%
# === LOGGING SETUP (Console logging DISABLED - ALL LEVELS TO FILE) ===
def get_logger(name: str) -> logging.Logger:
    """
    Build (or fetch) a file-only logger registered under ``name``.

    Every record from DEBUG upward is written to ``logs/<YYYY-MM-DD>.log``
    (the ``logs`` directory is relative to the current working directory and
    is created when missing). The logger does not propagate and owns no
    console handler, so nothing is echoed to the console. A logger that
    already has a handler is returned as-is, which keeps repeated calls from
    stacking duplicate handlers.
    """
    log = logging.getLogger(name)
    log.setLevel(logging.DEBUG)
    log.propagate = False

    # Wire up the file handler only on the first call for this name.
    if not log.handlers:
        directory = "logs"
        os.makedirs(directory, exist_ok=True)

        # The file name is fixed to the date on which the handler is created.
        today = datetime.now().strftime('%Y-%m-%d')
        file_handler = logging.FileHandler(
            os.path.join(directory, f"{today}.log"), encoding='utf-8'
        )
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(
            logging.Formatter(
                "%(asctime)s [%(levelname)-8s] %(name)s: %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S",
            )
        )
        log.addHandler(file_handler)

    return log
# %%
# Module-wide logger (file only). One record per severity is emitted right
# away so the log file shows that every level reaches it.
logger = get_logger("SQLAgent")
logger.log(logging.INFO, "=== SQL Agent Logger Initialized (Console Output Disabled, ALL Log Levels Captured) ===")
logger.log(logging.DEBUG, "DEBUG level logging enabled - all detailed information will be captured")
logger.log(logging.WARNING, "WARNING level logging enabled - all warnings will be captured")
logger.log(logging.ERROR, "ERROR level logging enabled - all errors will be captured")
logger.log(logging.CRITICAL, "CRITICAL level logging enabled - all critical issues will be captured")
# %%
class State(TypedDict):
    """Shared state that the LangGraph workflow passes from node to node."""
    # Conversation history. Annotated with langgraph's `add_messages` reducer;
    # the nodes in this file return only their newly produced messages.
    messages: Annotated[list[AnyMessage], add_messages]
    # Number of SQL generation attempts so far. Incremented by the query_gen
    # node and compared against the retry limit to prevent infinite loops.
    query_attempts: int  # Track query attempts to prevent infinite loops
    # Final natural-language answer. Starts as None and is filled in by the
    # interpret_results node, or by query_gen when it gives up.
    final_answer: Optional[str]  # Store the final answer
# %%
class SQLAgent:
    """SQL Agent that uses LangGraph to interact with a SQLite database."""
    def __init__(
        self,
        db_path: str,
        model_name: str = "llama-3.1-8b-instant",
        groq_api_key: Optional[str] = None,
    ):
        """
        Open the SQLite database, create the Groq chat model and build the graph.

        Args:
            db_path: Path of the SQLite database file; it is embedded in a
                ``sqlite:///`` connection string.
            model_name: Model identifier handed to ``ChatGroq``.
            groq_api_key: Groq API key. When falsy, the ``GROQ_API_KEY``
                environment variable is used instead.
        """
        logger.info("Initializing SQLAgent...")
        logger.debug(f"Input - db_path: {db_path}, model_name: {model_name}")

        # Database connection.
        uri = f"sqlite:///{db_path}"
        self.connection_string = uri
        logger.debug(f"Connection string created: {self.connection_string}")
        self.db = SQLDatabase.from_uri(uri)
        logger.info("Connected to SQLite database.")
        logger.debug(f"Database usable tables: {self.db.get_usable_table_names()}")

        # Chat model; temperature 0 is requested for every call.
        api_key = groq_api_key if groq_api_key else os.getenv("GROQ_API_KEY")
        self.llm = ChatGroq(model=model_name, api_key=api_key, temperature=0)
        logger.info(f"Groq LLM initialized with model: {model_name}")

        # Tools first, then prompts, then the graph that uses both.
        for build_step in (self._setup_tools, self._setup_prompts, self._build_graph):
            build_step()
        logger.info("SQLAgent initialization completed.")

    def _setup_tools(self) -> None:
        """
        Load the toolkit tools the graph needs and define the query tool.

        Sets ``self.list_tables_tool``, ``self.get_schema_tool`` and
        ``self.db_query_tool``.

        Raises:
            RuntimeError: If the toolkit does not provide ``sql_db_list_tables``
                or ``sql_db_schema``. (Previously a bare ``StopIteration``
                escaped from ``next()`` with no hint about what was missing.)
        """
        logger.info("Starting _setup_tools...")
        toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
        toolkit_tools = toolkit.get_tools()
        tool_names = [t.name for t in toolkit_tools]
        logger.debug(f"Fetched {len(toolkit_tools)} tools from toolkit: {tool_names}")

        def require_tool(wanted: str):
            """Return the toolkit tool called ``wanted`` or fail with a clear error."""
            # `t` (not `tool`) so the loop variable cannot be confused with the
            # imported `tool` decorator used further down.
            found = next((t for t in toolkit_tools if t.name == wanted), None)
            if found is None:
                message = (
                    f"Toolkit did not provide required tool '{wanted}'. "
                    f"Available tools: {tool_names}"
                )
                logger.error(message)
                raise RuntimeError(message)
            logger.debug(f"Tool '{wanted}' loaded.")
            return found

        # Extract standard tools
        self.list_tables_tool = require_tool("sql_db_list_tables")
        self.get_schema_tool = require_tool("sql_db_schema")

        # Define the query execution tool.
        # NOTE: the docstring below is used by the `tool` decorator as the
        # tool description, so it is part of runtime behavior.
        @tool
        def db_query_tool(query: str) -> str:
            """Execute a SQL query against the SQLite database and get back the result."""
            logger.info(f"db_query_tool called with query: {query}")
            try:
                result = self.db.run_no_throw(query)
                if result is None or result == "" or result == []:
                    logger.info("Query executed but returned no results.")
                    return "Query executed successfully but returned no results."
                logger.info(f"Query result: {result}")
                return str(result)
            except Exception as e:
                logger.error(f"Error in db_query_tool: {repr(e)}")
                return f"Error: {str(e)}"

        self.db_query_tool = db_query_tool
        logger.info("All tools setup completed.")

    def _setup_prompts(self) -> None:
        """
        Create the three prompt templates used by the agent.

        Sets ``self.query_gen_prompt`` (SQL generation),
        ``self.interpret_prompt`` (turning results into an answer) and
        ``self.intent_prompt`` (greeting vs. data-query classification).
        The completion message is logged only after all three exist.
        """
        logger.info("Setting up prompts...")

        # The prompt texts are sent to the model verbatim, including their
        # leading whitespace, so they must not be re-indented.
        query_gen_system = """You are an expert SQLite database assistant. Your task is to generate SQL queries to answer user questions.
                    IMPORTANT RULES:
                    1. Generate ONLY valid SQLite SELECT queries
                    2. Use proper SQLite syntax (no MySQL/PostgreSQL specific functions)
                    3. Be careful with table and column names - use exact names from the schema
                    4. When in doubt about column names, use SELECT * to see all columns first
                    5. Use LIMIT to prevent huge result sets
                    6. Return ONLY the SQL query, nothing else
                    Example good responses:
                    - SELECT * FROM employees LIMIT 10;
                    - SELECT name, salary FROM employees WHERE department = 'IT';
                    - SELECT COUNT(*) FROM orders;"""
        self.query_gen_prompt = ChatPromptTemplate.from_messages([
            ("system", query_gen_system),
            ("placeholder", "{messages}"),
        ])
        logger.debug("Query generation prompt created.")

        interpret_system = """You are a data analyst. Your job is to interpret SQL query results and provide clear answers.
                    Given the query results, provide a clear, human-readable answer that directly addresses the user's question.
                    Start your response with "Answer: " followed by your interpretation.
                    If the results are empty, explain that no matching records were found.
                    If there's an error, suggest what might be wrong and how to fix it."""
        self.interpret_prompt = ChatPromptTemplate.from_messages([
            ("system", interpret_system),
            ("placeholder", "{messages}"),
        ])
        logger.debug("Interpretation prompt created.")

        # Intent classification prompt: lets the LLM detect pure greetings so
        # they can be answered without running the SQL workflow.
        intent_system = """
                You are an intent classifier. Determine if the user's message is:
                - A greeting only (e.g., 'hi', 'hello', 'good morning', 'how are you')
                - OR a real question about data (e.g., asking for email, orders, users)

                Respond ONLY with:
                - "greeting" → if it's small talk with no request
                - "query" → if there's any data request, even after a greeting

                Examples:
                User: Hi
                AI: greeting

                User: Hello, how are you?
                AI: greeting

                User: Hi, can you show me Arun Pandey's email?
                AI: query

                User: Good morning, what is my balance?
                AI: query

                User: Hey!
                AI: greeting

                User: Find all users named Arun
                AI: query
                """
        self.intent_prompt = ChatPromptTemplate.from_messages([
            ("system", intent_system),
            ("human", "{input}")
        ])
        logger.debug("Intent classification prompt added.")

        # Logged last so "completed" really means every prompt was built.
        logger.info("Prompts setup completed.")

    def _create_tool_node_with_fallback(self, tools: list) -> RunnableWithFallbacks:
        """
        Wrap ``ToolNode(tools)`` with a fallback that reports tool failures.

        When the tool node raises, the fallback receives the exception under
        the ``"error"`` key and answers every tool call of the last message
        with a ``ToolMessage`` whose content is ``"Error: <repr of exception>"``.
        """
        logger.debug(f"Creating tool node with fallback for tools: {[t.name for t in tools]}")

        def handle_tool_error(state: Dict) -> Dict:
            """Fallback: turn the captured exception into one ToolMessage per call."""
            error = state.get("error")
            history = state.get("messages")
            # No history (missing or empty) means there is nothing to answer.
            pending_calls = history[-1].tool_calls if history else []
            logger.error(f"Tool error caught: {repr(error)}")

            error_messages = []
            for call in pending_calls:
                error_messages.append(
                    ToolMessage(
                        content=f"Error: {repr(error)}",
                        tool_call_id=call["id"],
                    )
                )
            return {"messages": error_messages}

        fallback = RunnableLambda(handle_tool_error)
        guarded_node = ToolNode(tools).with_fallbacks([fallback], exception_key="error")
        logger.debug("Tool node with fallback created.")
        return guarded_node

    def _build_graph(self) -> None:
        """
        Assemble the LangGraph workflow and store the compiled graph in ``self.app``.

        Node order: ``first_tool_call`` -> ``list_tables_tool`` ->
        ``model_get_schema`` -> ``get_schema_tool`` -> ``query_gen``; from
        there the conditional routers decide between ``execute_query``,
        ``interpret_results``, another ``query_gen`` attempt, or END.

        Compared with a naive "first line containing SELECT" extraction, the
        execution node here:
          * keeps multi-line statements intact and unwraps markdown fences,
          * only runs statements that start with SELECT or a CTE header, and
            refuses CTEs that contain a write keyword,
          * sets ``final_answer`` to the give-up text when the last allowed
            attempt fails, so the caller gets that text instead of a raw
            database error.
        """
        logger.info("Building LangGraph workflow...")
        workflow = StateGraph(State)

        # One limit and one give-up text shared by every retry check, so the
        # node and the routers cannot disagree about when to stop.
        max_query_attempts = 3
        give_up_answer = "Unable to generate a working query after multiple attempts."

        # First ```...``` fence in a model reply (a language tag, if any, is
        # skipped later because extraction starts at the SQL keyword).
        fenced_block = re.compile(r"```(.*?)```", re.DOTALL)
        # A statement starts at a line-leading CTE header or at the word SELECT.
        statement_start = re.compile(
            r"^[ \t]*WITH\s+(?:RECURSIVE\s+)?\w+\s*(?:\([^)]*\)\s*)?AS\b|\bSELECT\b",
            re.IGNORECASE | re.MULTILINE,
        )
        # SQLite allows WITH ... INSERT/UPDATE/DELETE/REPLACE; those are refused.
        write_keyword = re.compile(r"\b(?:INSERT|UPDATE|DELETE|REPLACE)\b", re.IGNORECASE)

        def extract_read_only_sql(text: str) -> Optional[str]:
            """Return the first SELECT/CTE statement found in ``text``, else None."""
            fenced = fenced_block.search(text)
            candidate = fenced.group(1) if fenced else text
            start = statement_start.search(candidate)
            if start is None:
                return None
            tail = candidate[start.start():].lstrip()

            # End the statement at the first ';' outside a single-quoted
            # literal (a doubled '' toggles twice, so escapes are harmless).
            end = None
            in_literal = False
            for index, char in enumerate(tail):
                if char == "'":
                    in_literal = not in_literal
                elif char == ";" and not in_literal:
                    end = index + 1
                    break
            if end is None:
                # No terminator: assume prose may follow after a blank line.
                statement = re.split(r"\n\s*\n", tail, maxsplit=1)[0]
            else:
                statement = tail[:end]
            statement = statement.strip()
            if statement.endswith('.'):
                statement = statement[:-1]
            if not statement:
                return None
            # Conservative check: a write keyword inside a literal also
            # rejects the CTE, which only costs a retry.
            if statement[:4].upper() == "WITH" and write_keyword.search(statement):
                return None
            return statement

        def first_tool_call(state: State) -> Dict:
            """Initial node to list database tables."""
            logger.info("Executing first_tool_call: requesting list of tables.")
            # A hand-built tool call so the table listing always happens first.
            response = {
                "messages": [
                    AIMessage(
                        content="",
                        tool_calls=[
                            {
                                "name": "sql_db_list_tables",
                                "args": {},
                                "id": "tool_abcd123",
                            }
                        ],
                    )
                ],
                "query_attempts": 0,
                "final_answer": None
            }
            logger.debug("first_tool_call response prepared.")
            return response

        def model_get_schema(state: State) -> Dict:
            """Get database schema information."""
            messages = state["messages"]
            logger.info("Calling model_get_schema to fetch schema.")
            logger.debug(f"Current message history: {[(type(m).__name__, m.content) for m in messages]}")
            chat_with_get_schema = self.llm.bind_tools([self.get_schema_tool])
            start_time = time.time()
            result = chat_with_get_schema.invoke(messages)
            llm_time = time.time() - start_time
            logger.info(f"LLM schema request took {llm_time:.2f} seconds.")
            logger.debug(f"Schema response: {result}")
            return {"messages": [result]}

        def query_gen_node(state: State) -> Dict:
            """Generate SQL query based on user question and context."""
            messages = state["messages"]
            logger.info("Entering query_gen_node to generate SQL query.")
            logger.debug(f"Message history: {[(type(m).__name__, m.content) for m in messages]}")
            # Increment query attempts
            query_attempts = state.get("query_attempts", 0) + 1
            logger.debug(f"Query attempt #{query_attempts}")
            # Defensive guard: the execution router normally stops the loop
            # before this node is entered with an exhausted counter.
            if query_attempts > max_query_attempts:
                logger.critical(f"CRITICAL: Max query attempts ({max_query_attempts}) exceeded. Unable to generate valid SQL query.")
                logger.warning(f"WARNING: Query generation failed after {query_attempts} attempts for question: {messages[0].content if messages else 'Unknown'}")
                return {
                    "messages": [AIMessage(content=give_up_answer)],
                    "query_attempts": query_attempts,
                    "final_answer": give_up_answer
                }
            # Generate query using the prompt
            logger.info("Invoking LLM for SQL generation...")
            start_time = time.time()
            query_response = (self.query_gen_prompt | self.llm).invoke({"messages": messages})
            llm_time = time.time() - start_time
            logger.info(f"LLM generated SQL in {llm_time:.2f} seconds.")
            logger.debug(f"LLM response content: {query_response.content}")
            return {
                "messages": [query_response],
                "query_attempts": query_attempts
            }

        def execute_query_node(state: State) -> Dict:
            """Execute the SQL query and report the outcome as a ToolMessage."""
            last_message = state["messages"][-1]
            logger.info("Executing SQL query from last message.")
            logger.debug(f"Last message type: {type(last_message).__name__}, content: {last_message.content}")
            raw_sql = last_message.content.strip()
            logger.debug(f"Raw SQL content: {raw_sql}")
            sql_query = extract_read_only_sql(raw_sql)
            logger.info(f"Extracted SQL query: {sql_query}")

            if sql_query is None:
                # Starts with "Error" so the router treats it as a failed attempt.
                content = "Error: no read-only SELECT statement found in the model response."
                logger.error(content)
            else:
                try:
                    start_time = time.time()
                    result = self.db.run_no_throw(sql_query)
                    exec_time = time.time() - start_time
                    logger.info(f"Query executed in {exec_time:.2f} seconds.")
                    if result is None or result == "" or result == []:
                        content = "Query executed successfully but returned no results."
                        logger.info("Query returned no results.")
                    else:
                        content = str(result)
                        logger.info(f"Query result: {content[:200]}{'...' if len(content) > 200 else ''}")
                except Exception as e:
                    content = f"Error executing query: {str(e)}"
                    logger.error(content)

            update: Dict[str, Any] = {
                "messages": [
                    ToolMessage(
                        content=content,
                        tool_call_id="manual_query_execution"
                    )
                ]
            }
            # Last allowed attempt failed: the router will end the run, so
            # record the give-up text as the final answer now.
            if content.startswith("Error") and state.get("query_attempts", 0) >= max_query_attempts:
                update["final_answer"] = give_up_answer
            return update

        def interpret_results_node(state: State) -> Dict:
            """Interpret the query results and provide final answer."""
            messages = state["messages"]
            logger.info("Interpreting query results into natural language answer.")
            logger.debug(f"Message history for interpretation: {[(type(m).__name__, m.content) for m in messages]}")
            logger.info("Invoking LLM for result interpretation...")
            start_time = time.time()
            interpretation = (self.interpret_prompt | self.llm).invoke({"messages": messages})
            llm_time = time.time() - start_time
            logger.info(f"LLM interpretation completed in {llm_time:.2f} seconds.")
            logger.debug(f"Interpretation result: {interpretation.content}")
            return {
                "messages": [interpretation],
                "final_answer": interpretation.content
            }

        def should_continue_after_query_gen(state: State) -> Literal[END, "execute_query"]:
            """Determine if we should execute the query or end."""
            messages = state["messages"]
            if not messages:
                logger.debug("No messages in state. Ending.")
                return END
            last_message = messages[-1]
            if state.get("query_attempts", 0) > max_query_attempts:
                logger.debug("Max attempts reached. Ending.")
                return END
            if (hasattr(last_message, 'content') and
                    last_message.content and
                    'SELECT' in last_message.content.upper()):
                logger.info("Valid SELECT query detected. Proceeding to execute_query.")
                return "execute_query"
            logger.debug("No valid query found. Ending.")
            return END

        def should_continue_after_execution(state: State) -> Literal[END, "interpret_results", "query_gen"]:
            """Determine what to do after query execution."""
            messages = state["messages"]
            if not messages:
                logger.debug("No messages after execution. Ending.")
                return END
            last_message = messages[-1]
            if not isinstance(last_message, ToolMessage):
                logger.debug("Unknown state after execution. Ending.")
                return END
            if last_message.content.startswith("Error"):
                if state.get("query_attempts", 0) >= max_query_attempts:
                    logger.warning("Error after max attempts. Ending.")
                    return END
                logger.warning("Query execution error. Retrying with new query.")
                return "query_gen"
            logger.info("Query executed successfully. Proceeding to interpretation.")
            return "interpret_results"

        def should_continue_after_interpretation(state: State) -> Literal[END]:
            """Always end after interpretation."""
            logger.info("Final answer generated. Ending workflow.")
            return END

        # Add nodes to graph
        workflow.add_node("first_tool_call", first_tool_call)
        logger.debug("Node 'first_tool_call' added to graph.")
        workflow.add_node(
            "list_tables_tool",
            self._create_tool_node_with_fallback([self.list_tables_tool]),
        )
        logger.debug("Node 'list_tables_tool' added to graph.")
        workflow.add_node("model_get_schema", model_get_schema)
        logger.debug("Node 'model_get_schema' added to graph.")
        workflow.add_node(
            "get_schema_tool",
            self._create_tool_node_with_fallback([self.get_schema_tool]),
        )
        logger.debug("Node 'get_schema_tool' added to graph.")
        workflow.add_node("query_gen", query_gen_node)
        logger.debug("Node 'query_gen' added to graph.")
        workflow.add_node("execute_query", execute_query_node)
        logger.debug("Node 'execute_query' added to graph.")
        workflow.add_node("interpret_results", interpret_results_node)
        logger.debug("Node 'interpret_results' added to graph.")

        # Fixed edges: the schema-discovery prefix always runs in this order.
        fixed_edges = (
            (START, "first_tool_call", "START"),
            ("first_tool_call", "list_tables_tool", "first_tool_call"),
            ("list_tables_tool", "model_get_schema", "list_tables_tool"),
            ("model_get_schema", "get_schema_tool", "model_get_schema"),
            ("get_schema_tool", "query_gen", "get_schema_tool"),
        )
        for source, target, label in fixed_edges:
            workflow.add_edge(source, target)
            logger.debug(f"Edge: {label} → {target}")

        # Conditional edges: targets come from the routers' return values.
        routers = (
            ("query_gen", should_continue_after_query_gen),
            ("execute_query", should_continue_after_execution),
            ("interpret_results", should_continue_after_interpretation),
        )
        for source, router in routers:
            workflow.add_conditional_edges(source, router)
            logger.debug(f"Conditional edges added from '{source}'")

        # Compile the workflow. The recursion limit is not set here; it is
        # supplied per invocation through the config passed to `invoke`.
        self.app = workflow.compile()
        logger.info("LangGraph workflow compiled successfully.")

    def query(self, question: str, recursion_limit: int = 10) -> Dict[str, Any]:
        """
        Answer ``question``: greet back on small talk, otherwise run the SQL workflow.

        Args:
            question: Natural-language user input.
            recursion_limit: Step limit handed to the graph through the
                ``recursion_limit`` config key.

        Returns:
            A dict with ``"sql_query"`` (the extracted SQL, or None) and
            ``"answer"`` (text for the user). Failures of the intent check or
            of the workflow are caught and reported through ``"answer"``.
        """
        logger.info(f"Received query: '{question}'")

        # Nothing to classify or answer: skip both LLM round-trips.
        if not question or not question.strip():
            logger.warning("Empty question received. Skipping agent workflow.")
            return {
                "sql_query": None,
                "answer": "Please enter a question about the data."
            }

        # === USE LLM TO CHECK IF IT'S A PURE GREETING ===
        try:
            intent_response = (self.intent_prompt | self.llm).invoke({"input": question})
            # The prompt shows the labels inside double quotes, so the model
            # may echo quotes or add punctuation; normalise before comparing.
            intent = intent_response.content.strip().strip("\"'`.!").strip().lower()
            logger.info(f"Intent classification result: '{intent}'")
            if intent == "greeting":
                logger.info("Pure greeting detected via LLM. Responding directly.")
                return {
                    "sql_query": None,
                    "answer": "Hello! How can I assist you today?"
                }
        except Exception as e:
            # Best effort: a failed classification must not block a real query.
            logger.warning(f"Failed to classify intent using LLM: {e}. Proceeding as query.", exc_info=True)
        # === END OF GREETING CHECK ===

        logger.info(f"Setting recursion limit: {recursion_limit}")
        try:
            # Invoke with recursion limit
            config = {"recursion_limit": recursion_limit}
            logger.info("Invoking agent workflow...")
            start_time = time.time()
            final_state = self.app.invoke(
                {"messages": [HumanMessage(content=question)], "query_attempts": 0, "final_answer": None},
                config=config
            )
            total_time = time.time() - start_time
            logger.info(f"Agent workflow completed in {total_time:.2f} seconds.")

            # Extract results
            final_sql_query = self._extract_final_sql_query(final_state)
            final_answer = final_state.get("final_answer")
            if not final_answer:
                # No node stored an answer: fall back to the last message text.
                history = final_state.get("messages") or []
                last_message = history[-1] if history else None
                if last_message and hasattr(last_message, 'content'):
                    final_answer = last_message.content
                    logger.debug("Final answer extracted from last message.")
            logger.info(f"Final SQL Query: {final_sql_query}")
            logger.info(f"Final Answer: {final_answer}")
            return {
                "sql_query": final_sql_query,
                "answer": final_answer
            }
        except GraphRecursionError:
            error_msg = "Unable to process the query due to recursion limit. The query may be too complex or the database structure unclear."
            logger.critical(f"CRITICAL: GraphRecursionError occurred - {error_msg}")
            # Provide a more user-friendly error message
            return {
                "sql_query": None,
                "answer": (
                    "Sorry, I couldn't process your request. "
                    "Please make sure your question is about retrieving data (e.g., SELECT queries). "
                    "If you asked to modify or drop tables, those actions are not allowed for safety reasons."
                ),
            }
        except Exception as e:
            error_msg = (
                "Sorry, something went wrong while processing your request. "
                "Please try again or rephrase your question."
            )
            # exc_info keeps the traceback in the log file; the user only
            # sees the generic message.
            logger.critical(f"CRITICAL: Unexpected exception in query execution - {str(e)}", exc_info=True)
            return {
                "sql_query": None,
                "answer": error_msg,
            }

    def _extract_final_sql_query(self, messages: Dict) -> Optional[str]:
        """
        Return the most recently executed SQL statement from a final graph state.

        The execute_query node tags its result with
        ``tool_call_id="manual_query_execution"``, and the message right
        before that result is the model reply that contained the SQL. Only
        those replies are inspected, newest first. Scanning every message for
        the substring "SELECT" could instead return a line of the final
        answer (e.g. "selected"), of a result row, or of the user's question.

        Args:
            messages: Final state dict; its ``"messages"`` entry is the history.

        Returns:
            The SQL statement (multi-line statements are kept whole), or None
            when no executed query is found.
        """
        logger.debug("Extracting final SQL query from message history.")
        history = messages.get("messages", [])
        for position in range(len(history) - 1, 0, -1):
            if getattr(history[position], "tool_call_id", None) != "manual_query_execution":
                continue
            content = getattr(history[position - 1], "content", None)
            if not isinstance(content, str) or not content:
                continue
            query = self._sql_statement_from_text(content)
            if query:
                logger.debug(f"Extracted SQL query: {query}")
                return query
        logger.debug("No SQL query found in message history.")
        return None

    @staticmethod
    def _sql_statement_from_text(text: str) -> Optional[str]:
        """Pull the first SELECT/CTE statement out of a model reply, else None."""
        # Prefer the content of a markdown fence when the reply has one.
        fenced = re.search(r"```(.*?)```", text, re.DOTALL)
        candidate = fenced.group(1) if fenced else text
        # Start at a line-leading CTE header or at the word SELECT.
        start = re.search(
            r"^[ \t]*WITH\s+(?:RECURSIVE\s+)?\w+\s*(?:\([^)]*\)\s*)?AS\b|\bSELECT\b",
            candidate,
            re.IGNORECASE | re.MULTILINE,
        )
        if start is None:
            return None
        tail = candidate[start.start():].lstrip()

        # End at the first ';' outside a single-quoted literal.
        end = None
        in_literal = False
        for index, char in enumerate(tail):
            if char == "'":
                in_literal = not in_literal
            elif char == ";" and not in_literal:
                end = index + 1
                break
        if end is None:
            # No terminator: assume prose may follow after a blank line.
            statement = re.split(r"\n\s*\n", tail, maxsplit=1)[0]
        else:
            statement = tail[:end]
        statement = statement.strip()
        if statement.endswith('.'):
            statement = statement[:-1]
        return statement or None

    def get_table_info(self) -> str:
        """
        Return the table information text reported by the database wrapper.

        On any exception the error is logged and a string starting with
        ``"Error getting table information: "`` is returned instead of raising.
        """
        logger.info("get_table_info called.")
        try:
            table_info = self.db.get_table_info()
            # Only the first 300 characters go to the log.
            logger.debug(f"Table info retrieved: {table_info[:300]}...")
        except Exception as exc:
            failure = f"Error getting table information: {str(exc)}"
            logger.error(failure)
            return failure
        return table_info

    def list_tables(self) -> List[str]:
        """
        Return the usable table names reported by the database wrapper.

        On any exception the error is logged and a one-element list holding
        the ``"Error: ..."`` text is returned instead of raising.
        """
        logger.info("list_tables called.")
        try:
            table_names = self.db.get_usable_table_names()
            logger.debug(f"Tables found: {table_names}")
        except Exception as exc:
            failure = f"Error: {str(exc)}"
            logger.error(failure)
            return [failure]
        return table_names


In [19]:

# %%
# Example usage
if __name__ == "__main__":
    logger.info("=== Starting SQL Agent  ===")
    # The database location can be overridden with the SQL_AGENT_DB_PATH
    # environment variable, so the demo also runs on machines where the
    # author's local path does not exist. Without the variable the original
    # path is used.
    default_db_path = r"D:\ML(ExtraClass Project)\AGENT AI\SQL-Sage-Intelligent-DB-Agent-with-Gemini-LangGraph\database\final_ecommerce.db"
    # Initialize the SQL Agent with SQLite database
    agent = SQLAgent(
        db_path=os.getenv("SQL_AGENT_DB_PATH", default_db_path),
        model_name="llama-3.1-8b-instant",
    )
    logger.info("SQLAgent instance created.")
    # Execute example query (results already logged by the query method)
    logger.info("=== Query Execution ===")
    question = "hi whats upp show me the email of Arun Pandey user?"
    result = agent.query(question, recursion_limit=15)
    print("SQL Query:", result["sql_query"])
    print("Question:", question)
    print(result["answer"])
    logger.info("===  Query Execution Completed ===")
# %%

SQL Query: SELECT email FROM users WHERE name = 'Arun Pandey';
Question: hi whats upp show me the email of Arun Pandey user?
Answer: The email of Arun Pandey is arun@example.com.
