In [75]:
import os
import openai
import pyodbc
from dotenv import load_dotenv
import pandas as pd
from tabulate import tabulate

def load_api_key():
    """Return the OpenAI API key found in the environment, or None if unset.

    A local ``.env`` file, when present, is merged into the environment first.
    """
    load_dotenv()  # optional: pulls variables from a .env file
    api_key = os.environ.get("OPENAI_API_KEY")
    return api_key

def get_user_query():
    """Ask the user for a request on stdin and return the raw text typed."""
    user_text = input("Give me a list of brands: ")
    return user_text

def generate_sql_query(natural_language_query, schema_info=""):
    """
    Send the user query (and optional schema info) to the OpenAI API
    and return the generated SQL query.

    Args:
        natural_language_query: The user's request in plain English.
        schema_info: Optional schema description embedded in the system prompt.

    Returns:
        The SQL text starting at the first SELECT keyword of the model's reply,
        with any trailing markdown code fence (and text after it) removed.

    Raises:
        ValueError: If the model's reply contains no SELECT statement.
    """
    # Local import: this cell's import block does not bring `re` into scope.
    import re

    # Craft a system prompt or instructions:
    system_prompt = f"""
    You are an expert SQL assistant. Only write the Select queries for Microsoft SQL server.  Given the database schema:
    {schema_info}

    Convert the following natural language request into a SQL query:
    """

    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": natural_language_query}
        ],
        max_tokens=150,
        temperature=0.0,  # Keep it deterministic for predictable SQL
    )

    # Extract the generated text (SQL query) from the response
    raw_reply = response.choices[0].message["content"] or ""

    # Case-insensitive so a lowercase "select" is accepted as well.
    select_match = re.search(r"\bSELECT\b", raw_reply, re.IGNORECASE)
    if select_match is None:
        raise ValueError(
            f"The model did not return a SELECT statement: {raw_reply!r}"
        )

    sql_query = raw_reply[select_match.start():]

    # Cut at a closing code fence only when one exists; a blind rfind() would
    # return -1 for fence-less replies and chop the last character of the SQL.
    fence_at = sql_query.find("```")
    if fence_at != -1:
        sql_query = sql_query[:fence_at]

    return sql_query.strip()

def execute_query(sql_query):
    """Execute the generated SQL query on a SQL Server instance running in Docker.

    Args:
        sql_query: SQL text to run.

    Returns:
        A pandas DataFrame with the result set when the statement produces one;
        None for statements without a result set and when any error occurs
        (the error is printed rather than raised).
    """
    # Pre-bind both handles so the cleanup in `finally` is safe even when
    # connect() itself fails; otherwise an UnboundLocalError would replace the
    # error that was just reported.
    conn = None
    cursor = None
    try:
        # Update server=localhost,1433 -- no space between host and port
        conn = pyodbc.connect(
            "DRIVER={ODBC Driver 17 for SQL Server};"
            "SERVER=localhost,1433;"
            "DATABASE=SampleDB;"
            "UID=sa;"
            "PWD=dockerStrongPwd123;"
        )
        cursor = conn.cursor()

        cursor.execute(sql_query)
        if cursor.description:  # A result set exists (e.g. a SELECT query)
            records = cursor.fetchall()
            column_names = [column[0] for column in cursor.description]
            return pd.DataFrame.from_records(records, columns=column_names)

        # For INSERT, UPDATE, etc., commit changes
        conn.commit()
        print("Query executed successfully!")
        return None
    except Exception as e:
        print("Error executing query:", e)
        return None
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
def display_results(df):
    """Print the query result as a grid table followed by its row count.

    A missing (None) or empty DataFrame prints a short notice instead.
    """
    nothing_to_show = df is None or df.empty
    if nothing_to_show:
        print("No results to display.")
        return

    try:
        # Preferred rendering: ASCII grid produced by tabulate
        print("\nQuery Results:")
        print(tabulate(df, headers=df.columns, tablefmt="grid", showindex=False))
        print(f"\nTotal rows: {len(df)}")
    except ImportError:
        # Plain pandas rendering, configured to show the frame in full
        print("\nQuery Results:")
        pandas_display_options = (
            ('display.max_rows', None),            # show every row
            ('display.max_columns', None),         # show every column
            ('display.width', None),               # auto-detect width
            ('display.expand_frame_repr', False),  # keep each row on one line
        )
        for option_name, option_value in pandas_display_options:
            pd.set_option(option_name, option_value)
        print(df)
        print(f"\nTotal rows: {len(df)}")
def main():
    """Run one natural-language-to-SQL round trip.

    Loads the API key, reads a request from the user, asks the model for SQL,
    executes it, and prints the resulting table. Returns early with a message
    when no API key is configured.
    """
    # 1. Load API key
    openai_api_key = load_api_key()
    if not openai_api_key:
        print("Error: OpenAI API key not found.")
        return
    openai.api_key = openai_api_key

    # 2. Get user input
    user_query = get_user_query()

    # 3. (Optional) Provide schema info (hardcoded example)
    # The last column is spelled `quarter`, matching the column name the
    # database reports in query results; a misspelled name here makes the
    # model generate SQL that fails with an invalid-column error.
    schema_info = """
    Table:
      - BrandData(brand VARCHAR(50), count VARCHAR(255), percentage FLOAT, state VARCHAR(50), year VARCHAR(255), quarter INT)
    """

    # 4. Generate the SQL query from the user's request
    print("Generating SQL from natural language...")
    sql_query = generate_sql_query(user_query, schema_info)
    print(f"Generated SQL query:\n{sql_query}\n")

    # 5. Execute the SQL query
    print("Executing query...")
    res = execute_query(sql_query)
    display_results(res)

if __name__ == "__main__":
    main()


Generating SQL from natural language...
Generated SQL query:
SELECT *
FROM BrandData
WHERE count = 'Apple'
AND percentage = (SELECT MAX(percentage) FROM BrandData WHERE count = 'Apple');

Executing query...

Query Results:
+---------+---------+--------------+-----------+-------------+-----------+
|   brand | count   |   percentage |     state | year        | quarter   |
|       5 | Apple   |   1.5108e+06 | 0.0313265 | maharashtra | 2022,1    |
+---------+---------+--------------+-----------+-------------+-----------+

Total rows: 1


In [12]:
import pyodbc

# Connectivity smoke test: open a connection, report success, close it again.
conn = pyodbc.connect(
    ";".join(
        (
            "DRIVER={ODBC Driver 17 for SQL Server}",
            "SERVER=localhost,1433",
            "DATABASE=master",  # or any DB you want to connect to
            "UID=sa",
            "PWD=dockerStrongPwd123",
        )
    )
    + ";"
)
print("Connection successful!")
conn.close()

Connection successful!


In [88]:
import os
import openai  
import pyodbc
from dotenv import load_dotenv
import pandas as pd
import re
from tabulate import tabulate
import time
import sys

def load_api_key():
    """Return the OpenAI API key from the environment, exiting if it is absent.

    A local ``.env`` file, when present, is merged into the environment first.
    When no key is found, two explanatory lines are printed and the process
    exits with status 1.
    """
    load_dotenv()
    api_key = os.getenv("OPENAI_API_KEY")
    if api_key:
        return api_key

    for line in (
        "Error: OpenAI API key not found in environment variables or .env file.",
        "Please set the OPENAI_API_KEY environment variable and try again.",
    ):
        print(line)
    sys.exit(1)

def get_user_query():
    """Keep prompting until the user types a non-blank question, then return it.

    Typing ``exit`` (any casing, surrounding whitespace ignored) prints a
    farewell and terminates the process with status 0. The returned text is
    exactly what the user typed, without stripping.
    """
    prompt = "\nWhat do you want to know about the brands? (type 'exit' to quit): "
    while True:
        query = input(prompt)
        trimmed = query.strip()
        if trimmed.lower() == 'exit':
            print("Exiting program. Goodbye!")
            sys.exit(0)
        if trimmed:
            return query
        print("Please enter a valid query.")

def generate_sql_query(natural_language_query, schema_info=""):
    """
    Translate a natural-language request into a T-SQL query via the OpenAI API.

    The model's reply is cleaned with clean_sql_query() and checked with
    validate_sql_with_ai(). When the check rejects the query, or anything in
    the pipeline raises, a fixed sample query is returned instead.
    """
    fallback_query = "SELECT TOP 5 * FROM BrandData"

    # System prompt carrying the schema and SQL Server specific instructions
    system_prompt = f"""
    You are an expert Microsoft SQL Server assistant. Given the database schema:
    {schema_info}

    IMPORTANT: Generate SQL specifically for Microsoft SQL Server (T-SQL).
    Use T-SQL syntax conventions:
    - Use TOP instead of LIMIT
    - Use [] for table/column names with spaces or special characters
    - No trailing semicolons
    - Always return valid SQL, even if the request is ambiguous
    - For strange or unclear requests, default to a simple query that shows sample data

    Convert the following natural language request into a SQL query and ONLY return the executable query with no markdown or explanation:
    """

    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": natural_language_query},
    ]

    try:
        completion = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=chat_messages,
            max_tokens=150,
            temperature=0.0,
        )

        # Pull the reply text out of the response and tidy it up
        raw_sql = completion.choices[0].message.content.strip()
        candidate = clean_sql_query(raw_sql)

        # Only hand back queries that the validator accepts
        is_valid, reason = validate_sql_with_ai(candidate, schema_info)
        if is_valid:
            return candidate

        print(f"Warning: {reason}")
        print("Using a safe fallback query instead.")
        return fallback_query

    except Exception as exc:
        print(f"Error: Failed to generate SQL: {exc}")
        # Any failure degrades to the safe sample query
        return fallback_query

def clean_sql_query(sql_query):
    """Clean up the SQL query by removing markdown formatting and fixing syntax.

    Args:
        sql_query: Raw text returned by the model; may be None or empty.

    Returns:
        The query without markdown code fences (including a language tag such
        as ``sql``/``SQL``/``tsql`` directly after a fence), without
        surrounding whitespace, and without trailing semicolons. Returns an
        empty string for None or empty input.
    """
    if not sql_query:
        return ""

    # Remove markdown code fences. The language tag is matched
    # case-insensitively so "```SQL" does not leave a stray "SQL" token.
    if "```" in sql_query:
        sql_query = re.sub(r'```(?:t-?sql|mssql|sql)\b', '', sql_query, flags=re.IGNORECASE)
        sql_query = sql_query.replace('```', '')

    # Drop every trailing semicolon (and whitespace between them), not just
    # the last one, so "SELECT 1;;" or "SELECT 1; ;" come out clean.
    return sql_query.strip().rstrip('; \t\r\n')

def validate_sql_with_ai(query, schema_info):
    """
    Check that a query is a read-only SELECT that is plausibly valid T-SQL.

    Cheap local checks run first (non-empty, contains SELECT, no
    data-modifying keywords, balanced parentheses). Only queries that pass
    them are sent to the OpenAI API for a syntax verdict.

    Args:
        query: SQL text to validate.
        schema_info: Schema description embedded in the validator prompt.

    Returns:
        (is_valid, message) tuple. When the API call raises, or the model's
        reply is neither VALID nor INVALID, the query is accepted on the
        strength of the local checks alone.
    """
    # Basic validation first before spending API calls
    if not query or not query.strip():
        return False, "Empty query."

    # Blank out string literals and bracketed identifiers so keywords or
    # parentheses that appear inside them do not trip the checks below.
    code_only = re.sub(r"'(?:[^']|'')*'", "''", query)
    code_only = re.sub(r"\[[^\]]*\]", "[]", code_only)

    # Check if it contains basic SQL keywords
    if not re.search(r'\bSELECT\b', code_only, re.IGNORECASE):
        return False, "Query doesn't contain SELECT keyword."

    # Enforce read-only locally: the AI verdict below is best-effort and the
    # function proceeds when that call fails, so it must not be the only guard.
    forbidden = re.search(
        r'\b(INSERT|UPDATE|DELETE|MERGE|DROP|ALTER|CREATE|TRUNCATE|'
        r'EXEC|EXECUTE|GRANT|REVOKE|DENY|INTO)\b',
        code_only,
        re.IGNORECASE,
    )
    if forbidden:
        keyword = forbidden.group(1).upper()
        return False, f"Only read-only SELECT queries are allowed (found {keyword})."

    # Check for balanced parentheses
    open_count = code_only.count('(')
    close_count = code_only.count(')')
    if open_count != close_count:
        return False, f"Unbalanced parentheses: {open_count} opening vs {close_count} closing."

    try:
        # The reply format is spelled out because the parsing below depends on
        # it; without it the model answers in free text and every verdict
        # ends up "inconclusive".
        system_prompt = f"""
        You are an expert in SQL Server (T-SQL) validation. Given the database schema:
        {schema_info}

        Your task is to validate if the following SQL query is valid T-SQL syntax that would work in SQL Server.
        Do not execute the query, just analyze its syntax.
        JUST CHECK THE SYNTAX ONLY AND ENSURE IT ONLY RETURNS THE DATA.
        IF THE QUERY IS ABOUT ANY UPDATE, DELETE, OR ANY DATA MODIFICATION, MARK IT AS INVALID.
        IGNORE any semantic issues, such as filtering based on data type mismatches. Focus solely on the syntax.

        Respond with exactly one line, in one of these two forms and nothing else:
        VALID
        INVALID: <short reason>
        """

        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": query}
            ],
            max_tokens=100,
            temperature=0.0
        )

        validation_result = response.choices[0].message.content.strip()
        verdict = validation_result.upper()

        if verdict.startswith("VALID"):
            return True, "AI validation passed."
        elif verdict.startswith("INVALID"):
            # Drop the "INVALID" marker and any separator that follows it
            reason = validation_result[len("INVALID"):].lstrip(" :-").strip()
            return False, f"AI validation failed: {reason}"
        else:
            # If response doesn't follow expected format, assume it's ok
            # This avoids false negatives on unusual but valid SQL
            return True, "AI validation inconclusive, proceeding with query."

    except Exception as e:
        # If AI validation fails, fall back to basic validation
        print(f"AI validation error: {e}")
        return True, "AI validation failed, proceeding with caution."

def test_connection(connection_string, timeout=5):
    """Test database connection before executing queries."""
    try:
        conn = pyodbc.connect(connection_string, timeout=timeout)
        conn.close()
        return True, "Connection successful"
    except Exception as e:
        return False, f"Connection failed: {str(e)}"

def execute_query(sql_query, connection_string, timeout=30):
    """
    Execute the generated SQL query with robust error handling.

    Args:
        sql_query: SQL text to run.
        connection_string: ODBC connection string passed to pyodbc.
        timeout: Seconds allowed for the login and for the query itself.

    Returns:
        A tuple (success_flag, result_or_error_message):
        - (True, (DataFrame, stats)) when the statement produced a result set,
        - (True, (None, stats)) for statements without a result set
          (these are committed),
        - (False, message) on any failure.
    """
    conn = None
    cursor = None

    try:
        # First test the connection
        connection_ok, message = test_connection(connection_string)
        if not connection_ok:
            return False, f"Database connection error: {message}"

        # Connect to the database
        conn = pyodbc.connect(connection_string, timeout=timeout)
        # connect(timeout=...) only bounds the login; the per-statement limit
        # is the connection's `timeout` attribute, so set it explicitly.
        conn.timeout = timeout
        cursor = conn.cursor()

        # Time how long the query takes
        start_time = time.time()

        # Execute the query (bounded by conn.timeout set above)
        cursor.execute(sql_query)

        execution_time = time.time() - start_time

        # For SELECT queries, fetch results
        if cursor.description:
            # Get column names
            columns = [column[0] for column in cursor.description]

            # Fetch all rows
            rows = cursor.fetchall()

            # Convert to DataFrame
            df = pd.DataFrame.from_records(rows, columns=columns)

            # Add query statistics
            query_stats = {
                'success': True,
                'rows_returned': len(df),
                'execution_time': f"{execution_time:.2f}s",
                'query_type': 'SELECT'
            }

            return True, (df, query_stats)
        else:
            # For non-SELECT queries (INSERT, UPDATE, DELETE, etc.)
            affected_rows = cursor.rowcount
            conn.commit()

            query_stats = {
                'success': True,
                'rows_affected': affected_rows,
                'execution_time': f"{execution_time:.2f}s",
                'query_type': 'UPDATE/INSERT/DELETE'
            }

            return True, (None, query_stats)

    except pyodbc.Error as e:
        # Handle specific database errors
        error_code = e.args[0] if e.args else "Unknown"
        error_message = e.args[1] if len(e.args) > 1 else str(e)
        sqlstate = str(error_code)

        # ODBC reports timeouts through SQLSTATE HYT00/HYT01; the word
        # "timeout" shows up in the message text, not in the code. Requiring
        # a cursor limits this branch to failures after the connection was
        # established, i.e. query timeouts rather than login timeouts.
        timed_out = (
            sqlstate in ("HYT00", "HYT01")
            or 'timeout' in sqlstate.lower()
            or 'timeout' in str(error_message).lower()
        )

        if timed_out and cursor is not None:
            return False, f"Query timed out after {timeout} seconds. Try a more specific query."
        elif '42S02' in sqlstate:  # Invalid object name
            return False, f"Table not found: {error_message}"
        elif '42000' in sqlstate:  # Syntax error
            return False, f"SQL syntax error: {error_message}"
        else:
            return False, f"Database error ({error_code}): {error_message}"
    except Exception as e:
        # Handle other exceptions
        return False, f"Error executing query: {str(e)}"
    finally:
        # Clean up resources
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

def display_results(success, result, sql_query):
    """
    Print the outcome of execute_query() in a user-friendly format.

    On failure `result` is the error message. On success it is a
    (DataFrame-or-None, stats) pair: a None frame means a statement without a
    result set, an empty frame prints a notice plus suggestions, and a
    populated frame prints up to 20 rows followed by summary statistics for
    its numeric columns.
    """
    # Failure: show the error and some hints
    if not success:
        print(f"\n❌ Error: {result}")
        suggest_alternatives(sql_query)
        return

    df, stats = result

    # Statement without a result set
    if df is None:
        print("\n✅ Command executed successfully.")
        print(f"Rows affected: {stats['rows_affected']}")
        print(f"Execution time: {stats['execution_time']}")
        return

    # Result set with no rows
    if df.empty:
        print("\n⚠️ Query executed successfully but returned no results.")
        print(f"Query type: {stats['query_type']}")
        print(f"Execution time: {stats['execution_time']}")
        suggest_alternatives(sql_query)
        return

    print("\n✅ Query Results:")
    # Cap the preview so large results do not flood the output
    is_truncated = len(df) > 20
    preview = df.head(20) if is_truncated else df
    try:
        print(tabulate(preview, headers=preview.columns, tablefmt="grid", showindex=False))
    except ImportError:
        print(preview)
    if is_truncated:
        print(f"\n(Showing 20 of {len(df)} rows)")

    print(f"\nTotal rows: {stats['rows_returned']}")
    print(f"Execution time: {stats['execution_time']}")

    # Summary statistics for numeric columns, when there are any
    numeric_cols = df.select_dtypes(include=['number']).columns
    if len(numeric_cols) == 0:
        return
    print("\nNumeric column statistics:")
    described = df[numeric_cols].describe().transpose()
    stats_df = described[['count', 'mean', 'min', 'max']]
    print(tabulate(stats_df, headers=stats_df.columns, tablefmt="simple", floatfmt=".2f"))

def suggest_alternatives(failed_query):
    """Print follow-up queries the user can try after a failed or empty query.

    A third hint is added when the query is longer than 50 characters or
    contains more than one ``where`` (case-insensitive).
    """
    hints = [
        "\nSuggestions:",
        "1. Try a simpler query to check if the table exists:",
        "   SELECT TOP 5 * FROM BrandData",
        "2. Check available brands with:",
        "   SELECT DISTINCT brand FROM BrandData",
    ]

    # Long or multi-condition queries get an extra hint
    where_clauses = failed_query.lower().count('where')
    looks_complex = len(failed_query) > 50 or where_clauses > 1
    if looks_complex:
        hints.append("3. Break down your complex query into simpler parts")

    for hint in hints:
        print(hint)

def get_connection_string():
    """Build the ODBC connection string for SQL Server.

    Each setting is read from an environment variable (DB_SERVER, DB_NAME,
    DB_USER, DB_PASSWORD) and falls back to a built-in default when unset.
    """
    settings = (
        ("SERVER", os.getenv("DB_SERVER", "localhost,1433")),
        ("DATABASE", os.getenv("DB_NAME", "SampleDB")),
        ("UID", os.getenv("DB_USER", "sa")),
        ("PWD", os.getenv("DB_PASSWORD", "dockerStrongPwd123")),
    )

    segments = ["DRIVER={ODBC Driver 17 for SQL Server}"]
    segments.extend(f"{key}={value}" for key, value in settings)
    # Every segment, including the last one, is terminated by a semicolon
    return ";".join(segments) + ";"

def main():
    """Run one question-to-answer round of the SQL chatbot.

    Steps: configure the OpenAI client, check the database connection, read a
    question, generate and validate SQL, then execute it and print the
    results. If validation fails and the user declines to proceed, the query
    is not executed. A closing message is always printed.
    """
    print("\n=== SQL Chatbot with AI-Based SQL Validation ===")
    print("Ask questions about the brand data in natural language.")

    try:
        # 1. Load API key (exits if not found) and hand it to the client, so
        # this cell does not rely on a key assigned by an earlier cell.
        openai.api_key = load_api_key()

        # Database schema info
        # The last column is spelled `quarter`, matching the column name the
        # database reports in query results; a misspelled name here makes the
        # model generate SQL that fails with an invalid-column error.
        schema_info = """
        Table:
          - BrandData(brand VARCHAR(50), count VARCHAR(255), percentage FLOAT, state VARCHAR(50), year VARCHAR(255), quarter INT)
          
        Sample data:
          - Contains brand information across different states and quarters
          - Brands include various companies
          - Percentage represents market share
        """

        # Get connection string
        connection_string = get_connection_string()

        # Test connection at startup
        connection_ok, message = test_connection(connection_string)
        if not connection_ok:
            print(f"Warning: {message}")
            print("Continuing anyway, but queries may fail.")

        # Single question/answer round (re-run the cell to ask again)

        # 2. Get user input
        user_query = get_user_query()

        # 3. Generate the SQL query from the user's request
        print("\nGenerating SQL from natural language...")
        sql_query = generate_sql_query(user_query, schema_info)
        print(f"Generated SQL query:\n{sql_query}\n")

        # 4. Validate SQL using AI
        print("Validating SQL syntax...")
        valid, validation_message = validate_sql_with_ai(sql_query, schema_info)
        if valid:
            print(f"✅ Validation passed: {validation_message}")
        else:
            print(f"⚠️ Validation warning: {validation_message}")
            correction = input("Would you like to proceed anyway? (y/n): ").strip().lower()
            if correction != 'y':
                print("Skipping execution. Try rephrasing your question.")
                print("\n" + "-" * 50)
                # Stop here: without this return the query was executed even
                # though the user declined.
                return

        # 5. Execute the SQL query
        print("Executing query...")
        success, result = execute_query(sql_query, connection_string)

        # 6. Display results
        display_results(success, result, sql_query)

        print("\n" + "-" * 50)  # Separator after the results

    except KeyboardInterrupt:
        print("\n\nProgram interrupted. Exiting gracefully...")
    except Exception as e:
        print(f"\nUnexpected error: {str(e)}")
        print("Exiting program.")
    finally:
        print("\nThank you for using the SQL Chatbot!")

if __name__ == "__main__":
    main()


=== SQL Chatbot with AI-Based SQL Validation ===
Ask questions about the brand data in natural language.

Generating SQL from natural language...
Generated SQL query:
SELECT count, SUM(percentage) AS total_percentage
FROM BrandData
GROUP BY count

Validating SQL syntax...
✅ Validation passed: AI validation inconclusive, proceeding with query.
Executing query...

✅ Query Results:
+------------+--------------------+
| count      |   total_percentage |
| Infinix    |        5.14287e+06 |
+------------+--------------------+
| Samsung    |        6.71604e+08 |
+------------+--------------------+
| Oppo       |        4.2025e+08  |
+------------+--------------------+
| Xiaomi     |        8.69563e+08 |
+------------+--------------------+
| COOLPAD    |       10           |
+------------+--------------------+
| Lava       |        1.53011e+06 |
+------------+--------------------+
| Motorola   |        7.33407e+07 |
+------------+--------------------+
| Others     |        2.8295e+08  |
+----