### SQL Agent: Natural Language to SQL Query Generator

This notebook will teach you step-by-step how to build a robust SQL agent that converts natural language queries into PostgreSQL queries.

What We'll Build:

Database Connection Setup - Connect to PostgreSQL database

Schema Discovery - Automatically discover and understand database structure

Query Context Builder - Create context for better query generation

AI-Powered SQL Generation - Use LLM to convert natural language to SQL

Query Validation - Validate and optimize generated queries
Interactive Interface - Create a user-friendly interface

In [1]:
# Install required packages
%pip install -q psycopg2-binary pandas sqlalchemy openai anthropic python-dotenv langchain langchain-openai langchain-anthropic
%pip install -q ipywidgets

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
2+2

4

In [3]:
# Core libraries
import os
import re
import json
from typing import List, Dict, Any, Optional
import warnings
warnings.filterwarnings('ignore')

# Database libraries
import psycopg2
from psycopg2.extras import RealDictCursor
import pandas as pd
from sqlalchemy import create_engine, text, inspect

# AI libraries
import openai
from langchain.schema import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic

# Utility libraries
from datetime import datetime
import time

print("✅ All libraries imported successfully!")

✅ All libraries imported successfully!


# ESTABLISHING SQL CONNECTION

In [4]:
class PostgreSQLConnection:
    """
    A robust PostgreSQL connection handler with error handling and connection management.
    """
    
    def __init__(self, host, port, database, user, password):
        self.host = host
        self.port = port
        self.database = database
        self.user = user
        self.password = password
        self.connection = None
        self.engine = None
        
    def connect(self):
        """Establish connection to PostgreSQL database"""
        try:
            # Create connection string
            connection_string = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
            
            # Create SQLAlchemy engine
            self.engine = create_engine(connection_string)
            
            # Test connection
            with self.engine.connect() as conn:
                result = conn.execute(text("SELECT version()"))
                version = result.fetchone()[0]
                print(f"✅ Connected to PostgreSQL!")
                print(f"📊 Database: {self.database}")
                print(f"🔧 Version: {version[:50]}...")
                
            return True
            
        except Exception as e:
            print(f"❌ Connection failed: {str(e)}")
            return False
    
    def execute_query(self, query, return_df=True):
        """Execute a SQL query and return results"""
        try:
            if return_df:
                df = pd.read_sql_query(query, self.engine)
                return df
            else:
                with self.engine.connect() as conn:
                    result = conn.execute(text(query))
                    return result.fetchall()
                    
        except Exception as e:
            print(f"❌ Query execution failed: {str(e)}")
            return None
    
    def get_table_info(self):
        """Get information about all tables in the database"""
        try:
            inspector = inspect(self.engine)
            tables_info = {}
            
            for table_name in inspector.get_table_names():
                columns = inspector.get_columns(table_name)
                tables_info[table_name] = {
                    'columns': [col['name'] for col in columns],
                    'column_details': columns
                }
                
            return tables_info
            
        except Exception as e:
            print(f"❌ Failed to get table info: {str(e)}")
            return None

# Initialize database connection
DB_CONFIG = {
    'host': '54.251.218.166',
    'port': 5432,
    'database': 'dummy',
    'user': 'rajesh',
    'password': 'rajesh123'
}

# Create database connection
db = PostgreSQLConnection(**DB_CONFIG)
connection_success = db.connect()







✅ Connected to PostgreSQL!
📊 Database: dummy
🔧 Version: PostgreSQL 15.13 (Debian 15.13-1.pgdg120+1) on x86...


In [5]:
db.execute_query("SELECT * FROM actor a")

Unnamed: 0,actor_id,first_name,last_name,last_update
0,1,Penelope,Guiness,2013-05-26 14:47:57.620
1,2,Nick,Wahlberg,2013-05-26 14:47:57.620
2,3,Ed,Chase,2013-05-26 14:47:57.620
3,4,Jennifer,Davis,2013-05-26 14:47:57.620
4,5,Johnny,Lollobrigida,2013-05-26 14:47:57.620
...,...,...,...,...
195,196,Bela,Walken,2013-05-26 14:47:57.620
196,197,Reese,West,2013-05-26 14:47:57.620
197,198,Mary,Keitel,2013-05-26 14:47:57.620
198,199,Julia,Fawcett,2013-05-26 14:47:57.620


# CHECKING AVAILBLE TABLES

In [8]:
if connection_success:
     tables_query = """
    SELECT 
        table_name,
        table_schema,
        table_type
    FROM information_schema.tables 
    WHERE table_schema = 'public'
    ORDER BY table_name;
    """
     tables_df = db.execute_query(tables_query)
     print("Available Tables")
     print("="*50)
     for idx, row in tables_df.iterrows():
        print(f"  {idx+1}. {row['table_name']} ({row['table_type']})")
     print(f"\n🔢 Total tables found: {len(tables_df)}")


else:
     print("trouble to connect with database")
     

Available Tables
  1. actor (BASE TABLE)
  2. actor_info (VIEW)
  3. address (BASE TABLE)
  4. category (BASE TABLE)
  5. city (BASE TABLE)
  6. country (BASE TABLE)
  7. customer (BASE TABLE)
  8. customer_list (VIEW)
  9. film (BASE TABLE)
  10. film_actor (BASE TABLE)
  11. film_category (BASE TABLE)
  12. film_list (VIEW)
  13. inventory (BASE TABLE)
  14. language (BASE TABLE)
  15. nicer_but_slower_film_list (VIEW)
  16. payment (BASE TABLE)
  17. rental (BASE TABLE)
  18. sales_by_film_category (VIEW)
  19. sales_by_store (VIEW)
  20. staff (BASE TABLE)
  21. staff_list (VIEW)
  22. store (BASE TABLE)

🔢 Total tables found: 22


# EXPLORING TABLE STRUCTURES

In [7]:
def explore_table_structure(table_name, limit=5):
    columns_query = f"""
    SELECT 
        column_name,
        data_type,
        is_nullable,
        column_default
    FROM information_schema.columns 
    WHERE table_name = '{table_name}'
    ORDER BY ordinal_position;
    """
    columns_df = db.execute_query(columns_query)
    print(f"🔍 Table: {table_name}")
    print("=" * 60)
    print("📊 Column Structure:")
    for idx, row in columns_df.iterrows():
        nullable = "NULL" if row['is_nullable'] == 'YES' else "NOT NULL"
        default = f" DEFAULT {row['column_default']}" if row['column_default'] else ""
        print(f"  • {row['column_name']}: {row['data_type']} ({nullable}){default}")
    # Get sample data
    sample_query = f"SELECT * FROM {table_name} LIMIT {limit};"
    sample_df = db.execute_query(sample_query)
    
    print(f"\n📝 Sample Data (first {limit} rows):")
    if sample_df is not None and not sample_df.empty:
        print(sample_df.to_string())
    else:
        print("  No data found or query failed")
    
    print("\n" + "=" * 60)
    return columns_df, sample_df


if connection_success and not tables_df.empty:
    # Take first few tables to explore
    tables_to_explore = tables_df['table_name'].head(3).tolist()
    
    for table in tables_to_explore:
        try:
            explore_table_structure(table)
            print()
        except Exception as e:
            print(f"❌ Error exploring {table}: {str(e)}")
            print()


🔍 Table: actor
📊 Column Structure:
  • actor_id: integer (NOT NULL) DEFAULT nextval('actor_actor_id_seq'::regclass)
  • first_name: character varying (NOT NULL)
  • last_name: character varying (NOT NULL)
  • last_update: timestamp without time zone (NOT NULL) DEFAULT now()

📝 Sample Data (first 5 rows):
   actor_id first_name     last_name             last_update
0         1   Penelope       Guiness 2013-05-26 14:47:57.620
1         2       Nick      Wahlberg 2013-05-26 14:47:57.620
2         3         Ed         Chase 2013-05-26 14:47:57.620
3         4   Jennifer         Davis 2013-05-26 14:47:57.620
4         5     Johnny  Lollobrigida 2013-05-26 14:47:57.620


🔍 Table: actor_info
📊 Column Structure:
  • actor_id: integer (NULL)
  • first_name: character varying (NULL)
  • last_name: character varying (NULL)
  • film_info: text (NULL)

📝 Sample Data (first 5 rows):
   actor_id first_name     last_name                                                                                  

# SCHEMA CONTEXT BUILDER

In [9]:
class SchemaContextBuilder:
    """
    Builds context about database schema for AI models to generate accurate SQL queries
    """
    def __init__(self, db_connection):
        self.db = db_connection
        self.schema_cache = {}
        self.build_full_schema_context()

    def build_full_schema_context(self):
        """Build complete schema context for all tables"""
        
        # Get all tables
        tables_query = """
        SELECT table_name, table_schema 
        FROM information_schema.tables 
        WHERE table_schema = 'public'
        ORDER BY table_name;
        """
        tables_df = self.db.execute_query(tables_query)
        if tables_df is None:
            return
        for _, row in tables_df.iterrows():
            table_name = row['table_name']
            self.schema_cache[table_name] = self.get_table_schema(table_name)

    def get_table_schema(self, table_name):
        """Get detailed schema for a specific table"""

        columns_query = f"""
        SELECT 
            column_name,
            data_type,
            is_nullable,
            column_default,
            character_maximum_length
        FROM information_schema.columns 
        WHERE table_name = '{table_name}'
        ORDER BY ordinal_position;
        """
        columns_df = self.db.execute_query(columns_query)

        if columns_df is None:
            return None
        
        # Get foreign key relationships
        fk_query = f"""
        SELECT
            kcu.column_name,
            ccu.table_name AS foreign_table_name,
            ccu.column_name AS foreign_column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
            ON tc.constraint_name = kcu.constraint_name
        JOIN information_schema.constraint_column_usage AS ccu
            ON ccu.constraint_name = tc.constraint_name
        WHERE tc.constraint_type = 'FOREIGN KEY'
            AND tc.table_name = '{table_name}';
        """

        fk_df = self.db.execute_query(fk_query)

        # Build schema info
        schema_info = {
            'table_name': table_name,
            'columns': [],
            'foreign_keys': []
        }

        for _, col in columns_df.iterrows():
            col_info = {
                'name': col['column_name'],
                'type': col['data_type'],
                'nullable': col['is_nullable'] == 'YES',
                'default': col['column_default'],
                'max_length': col['character_maximum_length']
            }
            schema_info['columns'].append(col_info)

        if fk_df is not None and not fk_df.empty:
            for _, fk in fk_df.iterrows():
                fk_info = {
                    'column': fk['column_name'],
                    'references_table': fk['foreign_table_name'],
                    'references_column': fk['foreign_column_name']
                }
                schema_info['foreign_keys'].append(fk_info)
        
        return schema_info
    

    def get_relevant_tables(self, query_text):
        """Identify tables that might be relevant to the query"""
        query_lower = query_text.lower()
        relevant_tables = []
        
        for table_name in self.schema_cache.keys():
            # Check if table name appears in query
            if table_name.lower() in query_lower:
                relevant_tables.append(table_name)
                continue
                
            # Check if any column names appear in query
            schema = self.schema_cache[table_name]
            if schema:
                for col in schema['columns']:
                    if col['name'].lower() in query_lower:
                        relevant_tables.append(table_name)
                        break
        
        # If no specific tables found, return first few tables
        if not relevant_tables:
            relevant_tables = list(self.schema_cache.keys())[:5]
            
        return relevant_tables
    
    def build_context_for_query(self, query_text):
        """Build focused context for a specific query"""
        relevant_tables = self.get_relevant_tables(query_text)
        
        context = f"""
DATABASE SCHEMA INFORMATION:
Database: {self.db.database}
Relevant Tables for Query: "{query_text}"

"""
        
        for table_name in relevant_tables:
            schema = self.schema_cache.get(table_name)
            if not schema:
                continue
                
            context += f"TABLE: {table_name}\n"
            context += "Columns:\n"
            
            for col in schema['columns']:
                nullable = "NULL" if col['nullable'] else "NOT NULL"
                context += f"  - {col['name']}: {col['type']} ({nullable})\n"
            
            if schema['foreign_keys']:
                context += "Foreign Keys:\n"
                for fk in schema['foreign_keys']:
                    context += f"  - {fk['column']} -> {fk['references_table']}.{fk['references_column']}\n"
            
            context += "\n"
        
        return context
    

# Initialize schema builder
if connection_success:
    schema_builder = SchemaContextBuilder(db)
    print("✅ Schema context builder initialized!")
    print(f"📊 Cached schema for {len(schema_builder.schema_cache)} tables")
else:
    print("❌ Cannot initialize schema builder - no database connection")

✅ Schema context builder initialized!
📊 Cached schema for 22 tables


# AGENT BUILDING LOGIC

In [10]:
# AI Configuration
# You'll need to set your API keys here
# Option 1: Set as environment variables
# export OPENAI_API_KEY="your-openai-key"
# export ANTHROPIC_API_KEY="your-anthropic-key"

# Option 2: Set directly in code (less secure)
# os.environ["OPENAI_API_KEY"] = "your-openai-key"
# os.environ["ANTHROPIC_API_KEY"] = "your-anthropic-key"

def get_available_models():
    """Check which AI models are available based on API keys"""
    models = {}
    
    # Check OpenAI
    if os.getenv("OPENAI_API_KEY"):
        try:
            models["openai"] = ChatOpenAI(
                model="gpt-4o-mini",  # Cost-effective but powerful
                temperature=0.1,      # Low temperature for consistent SQL generation
                max_tokens=1000
            )
            print("✅ OpenAI GPT-4o-mini available")
        except Exception as e:
            print(f"❌ OpenAI setup failed: {str(e)}")
    
    # Check Anthropic
    if os.getenv("ANTHROPIC_API_KEY"):
        try:
            models["anthropic"] = ChatAnthropic(
                model="claude-3-haiku-20240307",  # Fast and cost-effective
                temperature=0.1,
                max_tokens=1000
            )
            print("✅ Anthropic Claude available")
        except Exception as e:
            print(f"❌ Anthropic setup failed: {str(e)}")
    
    if not models:
        print("⚠️  No AI models available. Please set your API keys.")
        print("   You can use OpenAI, Anthropic, or other compatible models.")
        print("   For this tutorial, we'll create a mock model for demonstration.")
        
        # Create a mock model for demonstration
        class MockModel:
            def invoke(self, messages):
                # Simple pattern matching for demo
                user_msg = messages[-1].content.lower()
                
                if "count" in user_msg and "table" in user_msg:
                    return type('Response', (), {'content': 'SELECT COUNT(*) FROM your_table_name;'})()
                elif "select" in user_msg or "show" in user_msg:
                    return type('Response', (), {'content': 'SELECT * FROM your_table_name LIMIT 10;'})()
                else:
                    return type('Response', (), {'content': 'SELECT * FROM your_table_name WHERE condition = value;'})()
        
        models["mock"] = MockModel()
        print("✅ Mock model created for demonstration")
    
    return models

# Initialize available models
available_models = get_available_models()
print(f"\n📊 Available models: {list(available_models.keys())}")

⚠️  No AI models available. Please set your API keys.
   You can use OpenAI, Anthropic, or other compatible models.
   For this tutorial, we'll create a mock model for demonstration.
✅ Mock model created for demonstration

📊 Available models: ['mock']
