**Installation of Libraries**

In [2]:
!pip install -q langgraph langchain_google_genai sqlalchemy pandas matplotlib
!pip install -q mysql-connector-python psycopg2


**LLM API Key & Database URL**

In [3]:
import os

# Placeholder settings: substitute your real database URL and LLM API key
# before running the cells below, which read both values from the environment.
os.environ.update(
    {
        "DB_CONNECTION": "Enter Yours Database URL",
        "LLM_API_KEY": "Enter Yours API Key",
    }
)


**Setting Up Database Connection**

In [4]:
from sqlalchemy import create_engine, text

# Smoke-test the configured database before building the agent.
# The engine is created inside the try block as well: an unset or unparsable
# DB_CONNECTION value makes create_engine() itself raise, and that failure
# should be reported the same way as a refused connection.
db_engine = None
try:
    db_engine = create_engine(os.getenv("DB_CONNECTION"))
    with db_engine.connect() as connection:
        # Trivial round trip; only success or failure matters, not the result.
        connection.execute(text("SELECT 1"))
    print("Database connection successful!")
except Exception as e:
    print(f"Database connection failed: {e}")


Database connection successful!


**SQL Agent Workflow**

In [None]:
import os
from typing import List, Dict, Any, Optional
import pandas as pd
from sqlalchemy import create_engine, text
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
import matplotlib.pyplot as plt
from decimal import Decimal

class State:
    """Mutable record carried through the workflow for a single question.

    Every workflow step receives a ``State``, fills in the fields it is
    responsible for, and hands the same object back. All fields start out
    empty, so a fresh instance represents "nothing asked yet".
    """

    def __init__(self):
        # The user's natural-language question.
        self.query: str = ""
        # SQL text produced for the question.
        self.sql: str = ""
        # Result rows, one dict per row keyed by column name.
        self.data: List[Dict[str, Any]] = []
        # Chart object produced from ``data``; None until a chart exists.
        self.chart: Optional[Any] = None
        # Free-text answer produced by the LLM.
        self.insight: str = ""
        # Final response text (not written by any step visible in this file).
        self.response: str = ""
        # Message describing the most recent failure; None when nothing failed.
        self.error: Optional[str] = None
        # Routing label chosen by the classifier step.
        self.path_decision: str = ""
        # Schema description: table name -> list of column dicts.
        self.schema: Optional[Dict[str, List[Dict[str, Any]]]] = None


class SQLAgent:
    """Question-answering agent that routes questions to SQL or to the LLM.

    Each question is first classified as ``database`` or ``general``.
    Database questions are translated to SQL by the LLM, executed on the
    configured engine and charted; general questions are answered directly by
    the LLM. Every step takes a ``State``, updates it, and returns it; step
    failures are reported through ``state.error``.
    """

    # Prompt templates. The question and schema are injected through template
    # variables at invoke() time rather than being baked in with f-strings, so
    # literal "{" / "}" characters inside a question or schema description are
    # never re-parsed as template placeholders.
    _CLASSIFIER_TEMPLATE = (
        "Question: '{query}'. Does it require data from a database? "
        "Answer strictly with 'database' or 'general'."
    )
    _SQL_TEMPLATE = """Given the following schema and query, generate an optimized SQL query.

    Query: {query}

    Schema Details:
    {schema_description}

    Requirements:
    1. Use appropriate indexes when available
    2. Consider performance optimization
    3. Handle NULL values appropriately
    4. Follow the constraints defined in the schema
    5. Generate only the SQL query, no explanation needed

    SQL Query:"""
    _INSIGHT_TEMPLATE = "Provide a detailed answer for: '{query}'"

    # Characters trimmed from both ends of the classifier's answer before it is
    # compared with the route labels (quotes, markdown emphasis, punctuation).
    _LABEL_NOISE = "'\"`*.!? \t\r\n"

    def __init__(self, db_url: str, llm_api_key: str, model: str = "gemini-2.0-flash-exp"):
        """Create the database engine and LLM client and cache the schema.

        Args:
            db_url: SQLAlchemy connection URL of the database to query.
            llm_api_key: API key passed to the Gemini chat model.
            model: Name of the Gemini model to use.
        """
        self.db_engine = create_engine(db_url)
        self.llm = ChatGoogleGenerativeAI(
            model=model, api_key=llm_api_key, temperature=0
        )
        # Store schema at initialization so every new State can reuse it.
        self._cached_schema = self.get_optimized_schema()

    # ------------------------------------------------------------------
    # Pure helpers (no I/O)
    # ------------------------------------------------------------------
    @staticmethod
    def _strip_code_fences(raw_sql: str) -> str:
        """Remove a surrounding markdown code fence from LLM output.

        Handles an opening fence with or without a ``sql`` tag and a closing
        fence. Text without fences is returned unchanged apart from
        surrounding whitespace. (``str.strip('```sql')`` must not be used for
        this: its argument is a *set of characters*, so it would also eat a
        leading or trailing ``s``, ``q`` or ``l`` of the query itself.)
        """
        cleaned = raw_sql.strip()
        if cleaned.startswith("```"):
            cleaned = cleaned[3:]
            # Drop an optional "sql" language tag glued to the opening fence.
            if cleaned[:3].lower() == "sql" and (len(cleaned) == 3 or cleaned[3].isspace()):
                cleaned = cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
        return cleaned.strip()

    @staticmethod
    def _normalize_path_decision(answer: str) -> str:
        """Reduce the classifier's raw answer to a bare lowercase label.

        Surrounding whitespace, quotes, markdown emphasis and sentence
        punctuation are removed, so answers such as ``'Database'.`` still
        match the ``database`` route.
        """
        return answer.strip().lower().strip(SQLAgent._LABEL_NOISE)

    @staticmethod
    def _build_schema_description(schema: Optional[Dict[str, Any]]) -> str:
        """Format schema data into a human-readable string for the prompt.

        Every key except ``"metadata"`` is treated as a table whose value is a
        list of column dicts. A column must have ``name``; ``type``,
        ``description``, ``constraints`` and ``index`` are included when
        present. Index hints found under
        ``schema["metadata"]["performance_hints"]["use_indexes"]`` are
        appended once, after all tables.
        """
        if not schema:
            return ""

        sections = []
        for table_name, table_info in schema.items():
            if table_name == "metadata":
                continue

            column_lines = []
            for column in table_info:
                if "type" in column:
                    parts = [f"{column['name']} ({column['type']})"]
                else:
                    parts = [str(column["name"])]
                if "description" in column:
                    parts.append(f"Description: {column['description']}")
                if "constraints" in column:
                    parts.append(f"Constraints: {column['constraints']}")
                if "index" in column:
                    parts.append(f"Indexed: {column['index']}")
                column_lines.append(" - ".join(parts))

            sections.append(f"Table {table_name}:\n" + "\n".join(column_lines))

        # Hints describe the whole schema, so they are emitted once at the end
        # instead of being repeated after every table.
        metadata = schema.get("metadata")
        hints = metadata.get("performance_hints") if isinstance(metadata, dict) else None
        index_hints = hints.get("use_indexes") if isinstance(hints, dict) else None
        if index_hints:
            sections.append(
                f"Performance Hints: Use indexes on {', '.join(map(str, index_hints))}"
            )

        return "\n".join(sections)

    # ------------------------------------------------------------------
    # Workflow steps
    # ------------------------------------------------------------------
    def test_queries(self, queries: Optional[List[str]] = None):
        """Run a set of test queries non-interactively and print the results.

        Args:
            queries: Questions to run. Defaults to two sample questions.
        """
        if queries is None:
            queries = [
                "Who is the most popular artist?",
                "What is the average danceability of songs?",
                # Add more test queries here...
            ]
        for query in queries:
            # reset_state() supplies the schema cached at construction time.
            state = self.reset_state()
            state.query = query
            state = self.question_classifier(state)
            if state.error:
                print(f"Query: {query}\nError: {state.error}\n")
                continue

            if state.path_decision == "database":
                state = self.sql_generator(state)
                # Only execute when generation succeeded, so the generator's
                # error is not overwritten by the executor's "no SQL" error.
                if not state.error:
                    state = self.sql_executor(state)
                print(f"Query: {query}\nSQL: {state.sql}\nResult: {state.data}\n")
            else:
                state = self.insight_generator(state)
                print(f"Query: {query}\nInsight: {state.insight}\n")

            if state.error:
                print(f"Error: {state.error}\n")

    def question_classifier(self, state: "State") -> "State":
        """Classify the user's query as ``database`` or ``general``.

        When ``state.query`` is empty the question is read from standard
        input; a query that is already set is used as-is. The normalized label
        is stored in ``state.path_decision``; failures are stored in
        ``state.error``.
        """
        if not state.query:
            state.query = input("Question: ").strip()
        if not state.query:
            state.error = "No question provided."
            return state

        try:
            prompt = ChatPromptTemplate.from_messages(
                [("human", self._CLASSIFIER_TEMPLATE)]
            )
            answer = (prompt | self.llm | StrOutputParser()).invoke(
                {"query": state.query}
            )
            state.path_decision = self._normalize_path_decision(answer)
        except Exception as e:
            state.error = f"Error classifying question: {e}"
        return state

    def sql_generator(self, state: "State") -> "State":
        """Generate a SQL query from ``state.query`` and ``state.schema``.

        The generated statement, with any markdown code fence removed, is
        stored in ``state.sql``. If the schema is missing or empty, or the LLM
        call fails, ``state.error`` is set instead.
        """
        if not state.schema:
            state.error = "Schema information is required but not provided."
            return state

        try:
            schema_description = self._build_schema_description(state.schema)
            prompt = ChatPromptTemplate.from_messages([("human", self._SQL_TEMPLATE)])
            raw_sql = (prompt | self.llm | StrOutputParser()).invoke(
                {"query": state.query, "schema_description": schema_description}
            )
            state.sql = self._strip_code_fences(raw_sql)
        except Exception as e:
            state.error = f"Error generating SQL query: {e}"

        return state

    def sql_executor(self, state: "State") -> "State":
        """Execute ``state.sql`` and store the rows in ``state.data``.

        Rows are stored as a list of dicts keyed by column name. An empty
        result set and any execution failure are both reported through
        ``state.error``; on failure ``state.data`` is reset to ``[]``.
        """
        if not state.sql:
            state.error = "No SQL query to execute. Generate a query first."
            return state

        # NOTE(security): state.sql is LLM output derived from user input and is
        # executed verbatim. Connect with a read-only database account so a
        # generated INSERT/UPDATE/DELETE/DROP cannot modify data.
        try:
            with self.db_engine.connect() as connection:
                result = connection.execute(text(state.sql))
                # Materialize the column names once; they are reused per row.
                keys = list(result.keys())
                state.data = [dict(zip(keys, row)) for row in result]
                if not state.data:  # Handle empty result sets
                    state.error = "Query executed successfully, but no data was returned."
        except Exception as e:
            state.error = f"Error executing SQL query: {e}"
            state.data = []  # Ensure state.data is an empty list on error
        return state

    def chart_generator(self, state: "State", chart_type: str = "bar") -> "State":
        """Plot the query result and store the Axes in ``state.chart``.

        The first numeric column becomes the y-axis and the first non-numeric
        column the x-axis. If there is no data, no numeric column, or no
        non-numeric column, ``state.error`` is set and nothing is drawn.

        Args:
            state: Workflow state whose ``data`` holds the rows to plot.
            chart_type: Any ``kind`` accepted by ``DataFrame.plot``.
        """
        try:
            if not state.data:
                state.error = "No data available to generate a chart."
                return state

            df_data = pd.DataFrame(state.data)

            # Database drivers may return Decimal values, which pandas keeps in
            # object columns; convert those columns to float so they count as
            # numeric for plotting.
            for col in df_data.select_dtypes(include=["object"]).columns:
                if df_data[col].apply(lambda x: isinstance(x, Decimal)).any():
                    df_data[col] = df_data[col].astype(float)

            numeric_columns = df_data.select_dtypes(include=["number"]).columns
            if numeric_columns.empty:
                state.error = "No numeric data available in the result to generate a chart."
                return state
            y_column = numeric_columns[0]

            x_columns = df_data.select_dtypes(exclude=["number"]).columns
            if x_columns.empty:
                state.error = "No categorical data available in the result to use as x-axis."
                return state
            x_column = x_columns[0]

            ax = df_data.plot(
                x=x_column,
                y=y_column,
                kind=chart_type,
                title=f"{chart_type.capitalize()} Chart of {y_column} vs {x_column}",
            )

            # Label through the returned Axes rather than pyplot's implicit
            # "current axes", then render the figure in the notebook.
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.figure.tight_layout()
            plt.show()

            state.chart = ax  # Keep the Axes for further manipulation if needed
        except Exception as e:
            state.error = f"Error generating chart: {e}"
        return state

    def insight_generator(self, state: "State") -> "State":
        """Answer ``state.query`` directly with the LLM.

        The answer is stored in ``state.insight``; failures are stored in
        ``state.error``.
        """
        try:
            prompt = ChatPromptTemplate.from_messages(
                [("human", self._INSIGHT_TEMPLATE)]
            )
            state.insight = (
                prompt | self.llm | StrOutputParser()
            ).invoke({"query": state.query}).strip()
        except Exception as e:
            state.error = f"Error generating insight: {e}"
        return state

    def reset_state(self) -> "State":
        """Return a fresh ``State`` that carries the cached schema."""
        new_state = State()
        new_state.schema = self._cached_schema  # Preserve schema
        return new_state

    def get_optimized_schema(self) -> Dict:
        """Return the schema description handed to the SQL generator.

        This is a placeholder that returns an empty dict; fill it in with your
        tables. Expected shape: ``{table_name: [column_dict, ...]}`` plus an
        optional ``"metadata"`` entry. While it is empty, ``sql_generator``
        reports "Schema information is required but not provided."
        """
        return {
            # provide your optimized schema here along with metadata and constraints
        }

    def run_workflow(self):
        """Run the interactive question loop until the user interrupts it.

        Each iteration reads one question, routes it, prints the outcome and
        starts over with a fresh state. Ctrl-C / kernel interrupt or a closed
        input stream ends the loop.
        """
        state = self.reset_state()  # Initial state with schema

        while True:
            try:
                state = self.question_classifier(state)
                if state.error:
                    print(f"Error: {state.error}")
                    state = self.reset_state()  # Reset state but keep schema
                    continue

                if state.path_decision == "database":
                    state = self.sql_generator(state)
                    if state.error:
                        print(f"Error: {state.error}")
                        state = self.reset_state()
                        continue

                    print(f"Generated SQL Query:\n{state.sql}")

                    state = self.sql_executor(state)
                    if state.error:
                        print(f"Error: {state.error}")
                        state = self.reset_state()
                        continue

                    print("Query Results:")
                    for row in state.data:
                        print(row)

                    state = self.chart_generator(state, chart_type="bar")
                    if state.error:
                        print(f"Error: {state.error}")
                    else:
                        print("Chart displayed successfully in the notebook.")

                    print(f"Data: {state.data}")

                elif state.path_decision == "general":
                    state = self.insight_generator(state)
                    if state.error:
                        print(f"Error: {state.error}")
                        state = self.reset_state()
                        continue

                    print(f"Insight: {state.insight}")

                else:
                    # The classifier answered with something other than the
                    # two expected labels; say so instead of silently dropping
                    # the question.
                    print(f"Error: Unrecognized classification: {state.path_decision!r}")

                # Start the next question from a clean state.
                state = self.reset_state()

            except (KeyboardInterrupt, EOFError):
                # EOFError means input() has no more input to read; looping on
                # it would print the same error forever.
                print("Workflow terminated by user.")
                break
            except Exception as e:
                print(f"Unexpected error: {e}")
                state = self.reset_state()




if __name__ == "__main__":
    # Both settings come from the environment; the agent cannot be built
    # without them, so fail fast when either one is missing or empty.
    DB_URL = os.environ.get("DB_CONNECTION")
    LLM_API_KEY = os.environ.get("LLM_API_KEY")

    if not (DB_URL and LLM_API_KEY):
        raise EnvironmentError(
            "Database URL or LLM API key not set in environment variables."
        )

    # Build the agent and enter its interactive question loop.
    agent = SQLAgent(DB_URL, LLM_API_KEY)
    agent.run_workflow()
