In [1]:
import os
from dotenv import load_dotenv
from pyprojroot import here
from typing import List
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_sql_query_chain
from operator import itemgetter
from pprint import pprint

load_dotenv()

True

In [2]:
os.environ['OPENAI_API_KEY'] = os.getenv("OPEN_AI_API_KEY")

sql_agent_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

In [5]:
# Connect to database
sqldb_directory = here("data/csv_sql.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
print("Database dialect:", db.dialect)
print("Available tables:", db.get_usable_table_names())

# Get table schema information for better context
table_info = db.get_table_info()
print("\nTable schema information:")

Database dialect: sqlite
Available tables: ['cluster_data']

Table schema information:


In [6]:
class SingleTableSQLAgent:
    """
    A specialized SQL agent for querying a single large table in a SQL database.
    
    This agent is optimized for scenarios where you have one large table and need
    to generate efficient SQL queries based on natural language questions.
    
    Attributes:
        sql_agent_llm (ChatOpenAI): The language model for SQL generation
        db (SQLDatabase): The SQL database connection
        query_chain (Runnable): Chain for generating SQL queries
    """
    
    def __init__(self, sqldb_directory: str, llm_model: str = "gpt-4o-mini", llm_temperature: float = 0):
        """
        Initialize the SingleTableSQLAgent.
        
        Args:
            sqldb_directory (str): Path to the SQLite database file
            llm_model (str): The LLM model to use
            llm_temperature (float): Temperature for LLM responses
        """
        self.sql_agent_llm = ChatOpenAI(model=llm_model, temperature=llm_temperature)
        self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
        
        # Since we have only one table, we can directly create the query chain
        # with enhanced context about the table structure
        self.query_chain = create_sql_query_chain(self.sql_agent_llm, self.db)
        
        # Get table information for better context
        self.table_info = self.db.get_table_info()
        print(f"Connected to database with table: {self.db.get_usable_table_names()}")
    
    def query(self, question: str) -> str:
        """
        Execute a natural language query against the database.
        
        Args:
            question (str): Natural language question about the data
            
        Returns:
            str: SQL query result
        """
        try:
            # Generate SQL query
            raw_sql_query = self.query_chain.invoke({"question": question})
            
            # Clean the SQL query by removing any prefixes like "SQLQuery: "
            sql_query = self._clean_sql_query(raw_sql_query)
            print(f"Generated SQL Query:\n{sql_query}")
            
            # Execute query
            result = self.db.run(sql_query)
            return result
            
        except Exception as e:
            return f"Error executing query: {str(e)}"
    
    def _clean_sql_query(self, raw_query: str) -> str:
        """
        Clean the SQL query by removing any prefixes or formatting artifacts.
        
        Args:
            raw_query (str): Raw SQL query from the LLM
            
        Returns:
            str: Cleaned SQL query
        """
        # Remove common prefixes that LLMs might add
        prefixes_to_remove = [
            "SQLQuery: ",
            "SQL Query: ",
            "Query: ",
            "SQL: ",
            "```sql\n",
            "```\n",
            "sql\n"
        ]
        
        cleaned_query = raw_query.strip()
        
        for prefix in prefixes_to_remove:
            if cleaned_query.startswith(prefix):
                cleaned_query = cleaned_query[len(prefix):].strip()
        
        # Remove trailing backticks if present
        if cleaned_query.endswith("```"):
            cleaned_query = cleaned_query[:-3].strip()
        
        return cleaned_query
    
    def get_sample_data(self, limit: int = 5) -> str:
        """
        Get sample data from the table to understand its structure.
        
        Args:
            limit (int): Number of sample rows to return
            
        Returns:
            str: Sample data from the table
        """
        table_name = self.db.get_usable_table_names()[0]
        query = f"SELECT * FROM {table_name} LIMIT {limit};"
        return self.db.run(query)
    
    def get_column_info(self) -> str:
        """
        Get detailed information about the table columns.
        
        Returns:
            str: Column information
        """
        table_name = self.db.get_usable_table_names()[0]
        # Get column info using PRAGMA (SQLite specific)
        query = f"PRAGMA table_info({table_name});"
        return self.db.run(query)
    
    def get_table_stats(self) -> dict:
        """
        Get basic statistics about the table.
        
        Returns:
            dict: Table statistics including row count
        """
        table_name = self.db.get_usable_table_names()[0]
        
        # Get row count
        count_query = f"SELECT COUNT(*) as total_rows FROM {table_name};"
        row_count = self.db.run(count_query)
        
        return {
            "table_name": table_name,
            "total_rows": row_count,
            "column_info": self.get_column_info()
        }

In [7]:
# Initialize the agent
agent = SingleTableSQLAgent(
    sqldb_directory=sqldb_directory,
    llm_model="gpt-4o-mini",
    llm_temperature=0
)

Connected to database with table: ['cluster_data']


In [8]:
# Example usage
if __name__ == "__main__":
    # Get basic table information
    print("\n" + "="*50)
    print("TABLE STATISTICS")
    print("="*50)
    stats = agent.get_table_stats()
    pprint(stats)
    
    # Get sample data
    print("\n" + "="*50)
    print("SAMPLE DATA")
    print("="*50)
    sample_data = agent.get_sample_data(limit=3)
    print(sample_data)
    
    # Example queries
    print("\n" + "="*50)
    print("EXAMPLE QUERIES")
    print("="*50)
    
    # Query 1: Count records
    print("\nQuery 1: How many records are in the database?")
    result1 = agent.query("How many records are in the database?")
    print(f"Result: {result1}")
    
    # Query 2: Unique values in a specific column (adjust based on your data)
    print("\nQuery 2: How many unique clusters are there?")
    result2 = agent.query("How many unique clusters are there in the data?")
    print(f"Result: {result2}")
    
    # Query 3: Filter and count (adjust based on your data structure)
    print("\nQuery 3: Count records by status")
    result3 = agent.query("Show me the count of records grouped by status")
    print(f"Result: {result3}")
    


TABLE STATISTICS
{'column_info': "[(0, 'ALERT_ID', 'TEXT', 0, None, 0), (1, 'CUSTOMER_ID', "
                "'REAL', 0, None, 0), (2, 'PROJECT_ID', 'REAL', 0, None, 0), "
                "(3, 'SUBSCRIPTION_ID', 'REAL', 0, None, 0), (4, "
                "'RESOURCE_ID', 'REAL', 0, None, 0), (5, 'RESOURCE_TYPE', "
                "'REAL', 0, None, 0), (6, 'AGENT_ID', 'TEXT', 0, None, 0), (7, "
                "'PLUGIN_NAME', 'TEXT', 0, None, 0), (8, 'ALERT_TYPE', 'TEXT', "
                "0, None, 0), (9, 'PLUGIN_TYPE', 'TEXT', 0, None, 0), (10, "
                "'PLUGIN_SUBTYPE', 'TEXT', 0, None, 0), (11, 'TITLE', 'TEXT', "
                "0, None, 0), (12, 'DESCRIPTION', 'TEXT', 0, None, 0), (13, "
                "'SEVERITY', 'TEXT', 0, None, 0), (14, 'VALUE_LABELS', 'TEXT', "
                "0, None, 0), (15, 'THRESHOLD_EXPRESSION', 'TEXT', 0, None, "
                "0), (16, 'RULE_ID', 'REAL', 0, None, 0), (17, 'PLUGIN_ID', "
                "'REAL', 0, None, 0), (18, 'IPADDR

In [9]:
# Tool function for use in agent frameworks
def query_single_table_db(query: str) -> str:
    """
    Tool function to query the single table database.
    
    Args:
        query (str): Natural language query
        
    Returns:
        str: Query result
    """
    agent = SingleTableSQLAgent(
        sqldb_directory=sqldb_directory,
        llm_model="gpt-4o-mini",
        llm_temperature=0
    )
    return agent.query(query)