# PrintManager

In [1]:
from colorama import init, Fore, Back, Style
from datetime import datetime

# Initialize Colorama for this session. autoreset=True is requested so that
# styling does not have to be switched off by hand after every print
# (NOTE(review): per-write reset is Colorama's documented behavior - confirm for the installed version).
init(autoreset=True)

class PrintManager:
    """Colour-coded console output helpers built on Colorama.

    Every public method is static, so the helpers can be called either on the
    class itself or on an instance. All of them write to stdout via ``print``
    except ``timestamp``, which only returns a string.
    """

    @staticmethod
    def _emit(colour, text):
        """Print ``text`` in ``colour`` and append an explicit style reset."""
        print(f"{colour}{text}{Style.RESET_ALL}")

    @staticmethod
    def section_header(title):
        """Print a cyan banner: an 80-character '=' rule, the title, and a closing rule."""
        rule = '=' * 80
        print(f"\n{Fore.CYAN}{rule}")
        print(f"{Fore.CYAN}== {title}")
        PrintManager._emit(Fore.CYAN, f"{rule}\n")

    @staticmethod
    def subsection(title):
        """Print a blue banner: a 40-character '-' rule, the title, and a closing rule."""
        rule = '-' * 40
        print(f"\n{Fore.BLUE}{rule}")
        print(f"{Fore.BLUE}-- {title}")
        PrintManager._emit(Fore.BLUE, f"{rule}\n")

    @staticmethod
    def success(message):
        """Print ``message`` in green, prefixed with a check mark."""
        PrintManager._emit(Fore.GREEN, f"✓ {message}")

    @staticmethod
    def error(message):
        """Print ``message`` in red, prefixed with '✗ ERROR:'."""
        PrintManager._emit(Fore.RED, f"✗ ERROR: {message}")

    @staticmethod
    def warning(message):
        """Print ``message`` in yellow, prefixed with '⚠ WARNING:'."""
        PrintManager._emit(Fore.YELLOW, f"⚠ WARNING: {message}")

    @staticmethod
    def info(message):
        """Print ``message`` in white, prefixed with an info symbol."""
        PrintManager._emit(Fore.WHITE, f"ℹ {message}")

    @staticmethod
    def security(message, is_safe=True):
        """Print a security notice: green closed lock when ``is_safe``, red open lock otherwise."""
        colour, icon = (Fore.GREEN, "🔒") if is_safe else (Fore.RED, "🔓")
        PrintManager._emit(colour, f"{icon} {message}")

    @staticmethod
    def query_result(result_text):
        """Print a query result block in cyan."""
        PrintManager._emit(Fore.CYAN, result_text)

    @staticmethod
    def performance(metrics):
        """Print a magenta heading followed by one indented 'key: value' line per metric.

        ``metrics`` must provide ``.items()`` (e.g. a dict).
        """
        print(f"\n{Fore.MAGENTA}📊 Performance Metrics:")
        for label, reading in metrics.items():
            PrintManager._emit(Fore.MAGENTA, f"   {label}: {reading}")

    @staticmethod
    def timestamp():
        """Return (not print) the current local time formatted as '[HH:MM:SS]'."""
        current = datetime.now()
        return f"[{current:%H:%M:%S}]"

# MAIN APP

In [2]:
import os
import time
import logging
from datetime import datetime
from pathlib import Path
import sqlite3
import re

from llama_index.core import Settings
from llama_index.llms.gemini import Gemini
from llama_index.core.workflow import (
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)

# Shared console printer used by the workflow steps defined below.
pm = PrintManager()

# Placeholder credential: replace it with a real key before running. The value is
# exported to the process environment under GOOGLE_API_KEY.
# NOTE(review): presumably the Gemini client reads this variable - confirm. Prefer
# loading the key from the environment or a secret store instead of hard-coding it.
GOOGLE_API_KEY = "YOUR API KEY" 
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

class IntentAnalysisEvent(Event):
    """Event carrying an intent classification together with a message.

    NOTE(review): no workflow step visible in this file emits or consumes this
    event; IntentAnalyzer returns its result through StopEvent instead.
    """
    # Intent label - presumably "sql" or "chat", mirroring IntentAnalyzer; confirm.
    intent: str
    # Text accompanying the classification - presumably a human-readable reason; confirm.
    message: str
    
class SQLGenerationEvent(Event):
    """Event emitted by SQLAnalysisAgent.generate_sql and consumed by execute_sql."""
    # SQL text to execute: either the validated generated query or a constant
    # "SELECT '<reason>' as message" statement when a security check failed.
    sql_query: str

class SQLExecutionEvent(Event):
    """Event emitted by SQLAnalysisAgent.execute_sql and consumed by collect_feedback."""
    # Formatted result table on success, or an error message when execution failed.
    execution_result: str
    # Elapsed time of the query in seconds (0 when execution failed).
    execution_time: float
    # Number of rows returned by the query (0 when execution failed).
    row_count: int

class FeedbackEvent(Event):
    """Event carrying feedback text and a success flag.

    NOTE(review): not emitted or consumed by any workflow step visible in this
    file; collect_feedback finishes with StopEvent directly.
    """
    # Feedback text - presumably the LLM evaluation of a query; confirm.
    feedback: str
    # Success flag - presumably whether the evaluated query succeeded; confirm.
    success: bool
    

    
class IntentAnalyzer(Workflow):
    """Single-step workflow that labels a user prompt as "sql" or "chat".

    Keyword patterns are tried first; the LLM is consulted only when none of
    them matches.
    """

    def __init__(self):
        """Create the Gemini client, register it as the global LLM and set up keyword patterns."""
        super().__init__()
        self.llm = Gemini()
        Settings.llm = self.llm

        # Keywords that indicate a database request
        self.sql_patterns = [
            r'(?i)(show|list|find|search|sort)',
            r'(?i)(products|stock|price)',
            r'(?i)(how many|total|average)',
            r'(?i)(sql|query|database)',
            r'(?i)(highest|lowest|maximum|minimum)',
        ]

        # Keywords that indicate small talk
        self.chat_patterns = [
            r'(?i)(hello|hi|how are you)',
            r'(?i)(chat|talk|conversation)',
            r'(?i)(what are you doing|who are you)',
            r'(?i)(thank you|thanks)',
        ]

    async def analyze_intent(self, prompt: str) -> tuple[str, str]:
        """Classify ``prompt`` and return ``(intent, explanation)``.

        ``intent`` is "sql" or "chat". SQL keywords take precedence over chat
        keywords. Without any keyword hit the LLM decides, and any reply other
        than exactly "SQL" (after trimming and upper-casing) counts as chat.
        """
        # Cheap keyword screening, in priority order
        keyword_checks = (
            (self.sql_patterns, ("sql", "SQL query detected")),
            (self.chat_patterns, ("chat", "Chat intent detected")),
        )
        for patterns, verdict in keyword_checks:
            for pattern in patterns:
                if re.search(pattern, prompt):
                    return verdict

        # No keyword matched: ask the LLM
        analysis_prompt = f"""
        Please analyze the purpose of the following user message:
        "{prompt}"
        
        There are only two options:
        1. SQL: The user wants to perform a database query
        2. CHAT: The user wants to chat
        
        Only write "SQL" or "CHAT".
        """

        completion = await self.llm.acomplete(analysis_prompt)
        llm_label = str(completion).strip().upper()

        if llm_label != "SQL":
            return "chat", "LLM analysis: Chat intent detected"
        return "sql", "LLM analysis: SQL query detected"

    @step
    async def determine_intent(self, ev: StartEvent) -> StopEvent:
        """Workflow step: classify ``ev.topic`` and stop with an ``{"intent", "message"}`` dict."""
        intent, message = await self.analyze_intent(ev.topic)
        payload = {"intent": intent, "message": message}
        return StopEvent(result=payload)


class SQLAnalysisAgent(Workflow):
    def __init__(self):
        """Set up the LLM client, file logging, the SQLite connection and the security rule tables.

        Side effects: registers the Gemini client as the global ``Settings.llm``,
        creates the ``logs`` directory, configures root logging to a dated file
        in it, creates the parent directory of ``data/database.db`` if missing
        and opens a connection to that database.
        """
        super().__init__()
        self.llm = Gemini()
        Settings.llm = self.llm
        
        # Logging settings: one log file per calendar day
        log_file = f'logs/sql_agent_{datetime.now().strftime("%Y%m%d")}.log'
        os.makedirs('logs', exist_ok=True)
        logging.basicConfig(
            filename=log_file,
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        
        # SQLite connection
        db_path = Path('data/database.db')
        if not db_path.parent.exists():
            db_path.parent.mkdir(parents=True)
        # check_same_thread=False lets the connection be used from a thread other
        # than the one that created it.
        self.db_connection = sqlite3.connect(db_path, check_same_thread=False)
        self.cursor = self.db_connection.cursor()
        
        # Safe SQL patterns: a query must fully match one of these shapes to be executed.
        # The SELECT column list is a single character class. The earlier form,
        # (?:[\w\s,.()*]|\s)+, accepted exactly the same strings, but whitespace
        # matched both alternatives, which made the regex backtrack exponentially
        # on queries with long whitespace runs and no FROM clause.
        self.safe_patterns = {
            'SELECT': r'^SELECT\s+[\w\s,.()*]+\s+FROM\s+[\w]+(?:\s+WHERE\s+[\w\s><=]+)?(?:\s+ORDER\s+BY\s+[\w\s,]+)?(?:\s+LIMIT\s+\d+)?$',
            'COUNT': r'^SELECT\s+COUNT\s*\(\s*\*\s*\)\s+FROM\s+[\w]+(?:\s+WHERE\s+[\w\s><=]+)?$',
            'AVG': r'^SELECT\s+AVG\s*\(\s*[\w]+\s*\)\s+FROM\s+[\w]+(?:\s+WHERE\s+[\w\s><=]+)?$'
        }
        
        # Dangerous characters and patterns: any match anywhere in the query rejects it
        self.dangerous_patterns = [
            r';.*',                    # Multiple queries
            r'--.*',                   # SQL comments
            r'/\*.*?\*/',             # Multiline comments
            r'xp_.*',                 # System stored procedures
            r'exec.*',                # Execute commands
            r'UNION.*',               # UNION attacks
            r'DROP.*',                # DROP commands
            r'DELETE.*',              # DELETE commands
            r'UPDATE.*',              # UPDATE commands
            r'ALTER.*',               # ALTER commands
            r'TRUNCATE.*',            # TRUNCATE commands
            r'INSERT.*',              # INSERT commands
            r'GRANT.*',               # GRANT commands
            r'REVOKE.*',              # REVOKE commands
            r'SYSTEM.*',              # System commands
            r'INTO\s+(?:OUTFILE|DUMPFILE).*', # File operations
        ]
        
        # Allowed tables and columns
        self.allowed_tables = {'products'}
        self.allowed_columns = {
            'products': {'id', 'name', 'price', 'stock'}
        }
        
        # Malicious prompt patterns: applied to the natural-language prompt
        self.malicious_prompt_patterns = [
            r'(?i)(drop|delete|truncate|alter)\s+table',  # Dropping/modifying tables
            r'(?i)system\s+command',                      # System commands
            r'(?i)(hack|exploit|attack)',                 # Malicious words
            r'(?i)(union\s+select|join\s+select)',        # SQL injection
            r'(?i)(--|;|/\*|\*/)',                        # SQL comments and separators
            r'(?i)(xp_cmdshell|exec\s+sp)',               # Stored procedures
            r'(?i)(insert\s+into|update\s+set)',          # Data modification
            r'(?i)password|username|credential',           # Sensitive data
            r'(?i)grant|revoke|permission',               # Permission changes
            r'(?i)backup|restore|dump',                   # Backup operations
        ]
        
        # Safe prompt patterns: a prompt must match at least one of these
        self.safe_prompt_patterns = [
            r'(?i)(show|list|find|search|sort)',
            r'(?i)(products|stock|price)',
            r'(?i)(how many|total|average)',
            r'(?i)(highest|lowest|maximum|minimum)',
        ]

    def analyze_prompt_safety(self, prompt: str) -> tuple[bool, str]:
        """Analyze the user's prompt and check its safety"""
        if not prompt or not prompt.strip():
            return False, "Empty query"
            
        for pattern in self.malicious_prompt_patterns:
            if re.search(pattern, prompt):
                return False, f"Malicious content detected: {pattern}"
        
        safe_pattern_found = any(re.search(pattern, prompt) for pattern in self.safe_prompt_patterns)
        if not safe_pattern_found:
            return False, "Query does not contain safe patterns"
            
        if len(prompt) > 500:
            return False, "Query is too long"
            
        return True, "Safe query"

    async def verify_prompt_with_llm(self, prompt: str) -> tuple[bool, str]:
        """Verify the safety of the prompt using LLM"""
        verification_prompt = f"""
        Please analyze the safety of the following user query:
        "{prompt}"
        
        Check the following:
        1. Is there an attempt of SQL injection?
        2. Does it contain malicious commands?
        3. Are there any statements that threaten system security?
        4. Is it a query solely for data reading purposes?
        
        Only write "SAFE" or "UNSAFE" and briefly state the reason.
        """
        
        response = await self.llm.acomplete(verification_prompt)
        result = str(response).strip().upper()
        
        is_safe = result.startswith("SAFE")
        message = result.replace("SAFE", "").replace("UNSAFE", "").strip()
        
        return is_safe, message

    def sanitize_input(self, value: str) -> str:
        """Return ``value`` with quotes doubled and ';' / '--' removed.

        ``None`` is passed through unchanged. The substitutions are applied one
        after another in the order listed below.
        """
        if value is None:
            return None

        cleaned = value
        substitutions = (("'", "''"), (";", ""), ("--", ""))
        for needle, substitute in substitutions:
            cleaned = cleaned.replace(needle, substitute)
        return cleaned

    def validate_sql_safety(self, sql_query: str) -> tuple[bool, str]:
        """Check the safety of the SQL query"""
        if not sql_query:
            return False, "Empty query"

        sql_upper = sql_query.upper()
        
        if not sql_upper.strip().startswith('SELECT'):
            return False, "Only SELECT queries are allowed"

        for pattern in self.dangerous_patterns:
            if re.search(pattern, sql_upper, re.IGNORECASE):
                return False, f"Dangerous pattern detected: {pattern}"

        tables = re.findall(r'FROM\s+(\w+)', sql_query, re.IGNORECASE)
        for table in tables:
            if table.lower() not in self.allowed_tables:
                return False, f"Unauthorized table: {table}"

        for pattern in self.safe_patterns.values():
            if re.match(pattern, sql_query, re.IGNORECASE):
                return True, "Safe query"

        return False, "Query format is not safe"

    def format_results(self, results, description):
        """Render query rows as a text table.

        ``results`` is a sequence of row tuples and ``description`` a sequence
        whose items carry the column name at index 0 (as in a DB-API cursor
        description). Every cell is left-aligned in a 15-character column and
        columns are separated by ' | '. Returns "No results found" when
        ``results`` is empty.
        """
        if not results:
            return "No results found"

        rule = "-" * 80
        header = " | ".join(f"{column[0]:15}" for column in description)

        lines = ["", "Query Results:", rule, header, rule]
        lines.extend(
            " | ".join(f"{str(cell):15}" for cell in row)
            for row in results
        )
        lines.append(rule)

        # Every line, including the last rule, ends with a newline
        return "\n".join(lines) + "\n"

    def learn_from_history(self, natural_query: str) -> str:
        """Learn from past queries"""
        self.cursor.execute("""
            SELECT natural_query, generated_sql, execution_result 
            FROM query_history 
            WHERE natural_query LIKE ? 
            AND execution_result NOT LIKE '%error%'
            ORDER BY created_at DESC 
            LIMIT 1
        """, (f"%{natural_query}%",))
        
        similar_query = self.cursor.fetchone()
        if similar_query:
            pm.info(f"\nSimilar successful query found:\n{similar_query}")
            return similar_query[1]
        return None

    def log_error(self, error_msg: str, query: str):
        """Log security breaches and errors"""
        logging.error(f"Security Breach - Query: {query}\nError: {error_msg}")
        try:
            self.cursor.execute("""
                INSERT INTO error_stats (error_type, query, message)
                VALUES (?, ?, ?)
            """, ('SECURITY_VIOLATION', query, error_msg))
            self.db_connection.commit()
        except Exception as e:
            logging.error(f"Log entry error: {str(e)}")

    @step
    async def generate_sql(self, ev: StartEvent) -> SQLGenerationEvent:
        """Workflow step: turn the natural-language request in ``ev.topic`` into validated SQL.

        The prompt passes a regex safety screen and then an LLM safety check.
        It is sanitized and translated to SQL by the LLM, and the generated SQL
        is validated. If any check fails, a constant
        ``SELECT '<reason>' as message`` statement is returned instead, so the
        next step still receives something harmless to run. Accepted queries
        are stored in ``query_history``.

        SQL found in the history for a similar request is validated and, when
        safe, given to the LLM as a reference; it is never executed directly.
        """
        pm.section_header("SQL Query Generation")
        prompt = ev.topic
        
        # Security checks
        pm.subsection("Security Analysis")
        is_safe, message = self.analyze_prompt_safety(prompt)
        if not is_safe:
            pm.security(f"Prompt is not safe: {message}", False)
            return SQLGenerationEvent(sql_query="SELECT 'Failed security check' as message")
        
        is_safe, message = await self.verify_prompt_with_llm(prompt)
        if not is_safe:
            pm.security(f"LLM security check failed: {message}", False)
            return SQLGenerationEvent(sql_query="SELECT 'LLM security check failed' as message")
        
        pm.security("Security checks passed", True)
        
        # Generate SQL
        pm.subsection("SQL Generation")
        query = self.sanitize_input(prompt)
        
        # Get schema information
        self.cursor.execute("SELECT table_name, columns_info FROM tables_info")
        schema_info = self.cursor.fetchall()
        
        # Learn from history
        learned_sql = self.learn_from_history(query)
        if learned_sql:
            pm.info(f"SQL learned from history: {learned_sql}")
            is_safe, message = self.validate_sql_safety(learned_sql)
            if not is_safe:
                learned_sql = None
                pm.warning(f"Learned query is not safe: {message}")
        
        # Previously the learned SQL was looked up and validated but then never
        # used. It is now offered to the model as a reference only: the history
        # lookup is a substring match, so it may belong to a different request.
        # With no safe learned SQL the prompt is unchanged.
        history_hint = ""
        if learned_sql:
            history_hint = (
                "\n        \n"
                "        A similar earlier request was translated into this SQL "
                "(use it only as a reference if it fits):\n"
                f"        {learned_sql}"
            )
        
        # Generate SQL with LLM
        sql_prompt = f"""
        Database schema:
        {schema_info}
        
        Please translate the following natural language query into an SQL query:
        {query}{history_hint}
        
        IMPORTANT RULES:
        1. Only SELECT queries are allowed
        2. Only access the 'products' table
        3. Allowed columns: id, name, price, stock
        4. No multiple queries, comments, or special characters
        5. No complex queries like UNION, JOIN
        
        Only return the SQL query.
        """
        
        response = await self.llm.acomplete(sql_prompt)
        # Strip Markdown code fences the model may wrap around the statement
        sql_query = str(response).strip().replace('```sql', '').replace('```', '').strip()
        
        # Validate SQL query
        is_safe, message = self.validate_sql_safety(sql_query)
        if not is_safe:
            pm.security(f"Generated SQL is not safe: {message}", False)
            return SQLGenerationEvent(sql_query="SELECT 'Security breach detected' as message")
        
        pm.success("SQL query generated successfully")
        pm.info(f"Original Query: {query}")
        pm.info(f"Generated SQL: {sql_query}")
        
        # Save SQL query
        self.cursor.execute(
            "INSERT INTO query_history (natural_query, generated_sql) VALUES (?, ?)",
            (query, sql_query)
        )
        self.db_connection.commit()
        
        return SQLGenerationEvent(sql_query=sql_query)

    @step
    async def execute_sql(self, ev: SQLGenerationEvent) -> SQLExecutionEvent:
        """Workflow step: run ``ev.sql_query`` and report result, timing and row count.

        The query plan is printed first (``EXPLAIN QUERY PLAN``). The query is
        then executed and its rows are formatted and printed with simple
        performance metrics. Any exception is printed, recorded through
        ``log_error`` and turned into an event whose result is the error
        message, with zero time and zero rows.
        """
        pm.section_header("Executing SQL Query")
        sql_query = ev.sql_query
        
        try:
            # Query plan
            pm.subsection("Query Plan Analysis")
            self.cursor.execute("EXPLAIN QUERY PLAN " + sql_query)
            query_plan = self.cursor.fetchall()
            # Not named `step`, which would shadow the imported workflow decorator
            for plan_row in query_plan:
                pm.info(f"Plan Step: {plan_row}")
            
            # Execute query
            pm.subsection("Executing Query")
            # perf_counter is monotonic and high-resolution; wall-clock time.time()
            # can jump and is too coarse for very short queries.
            start_time = time.perf_counter()
            self.cursor.execute(sql_query)
            result = self.cursor.fetchall()
            execution_time = time.perf_counter() - start_time
            row_count = len(result)
            
            # Format results
            formatted_result = self.format_results(result, self.cursor.description)
            pm.query_result(formatted_result)
            
            # Performance metrics (per-row average is 0 when no rows came back)
            average_time = execution_time / row_count if row_count > 0 else 0
            metrics = {
                "Execution Time": f"{execution_time:.4f} seconds",
                "Rows Returned": row_count,
                "Average Processing Time": f"{average_time:.6f} seconds"
            }
            pm.performance(metrics)
            
            return SQLExecutionEvent(
                execution_result=formatted_result,
                execution_time=execution_time,
                row_count=row_count
            )
            
        except Exception as e:
            # Broad catch is deliberate: this step is the boundary that turns any
            # execution failure into a reportable event instead of aborting the workflow.
            error_message = f"Error executing SQL query: {str(e)}"
            pm.error(error_message)
            self.log_error(str(e), sql_query)
            return SQLExecutionEvent(
                execution_result=error_message,
                execution_time=0.0,
                row_count=0
            )

    def create_feedback_prompt(self, ev: SQLExecutionEvent) -> str:
        """Create a feedback prompt for the LLM"""
        return f"""
        Please briefly evaluate the result of this query:
        
        Query Metrics:
        - Execution Time: {ev.execution_time:.4f} seconds
        - Rows Returned: {ev.row_count}
        
        Query Result:
        {ev.execution_result}
        
        Please provide a SHORT and CLEAR evaluation based on the following criteria:
        1. Was the query successful? (Yes/No)
        2. Is the performance adequate? (Yes/No)
        3. If any, what are the improvement suggestions?
        """

    @step
    async def collect_feedback(self, ev: SQLExecutionEvent) -> StopEvent:
        """Workflow step: print an LLM evaluation of the executed query and finish.

        The evaluation is only printed; the workflow result is the execution
        result carried by ``ev``, converted to ``str``.
        """
        pm.section_header("Query Evaluation")

        evaluation_prompt = self.create_feedback_prompt(ev)
        llm_evaluation = await self.llm.acomplete(evaluation_prompt)

        pm.subsection("LLM Evaluation")
        pm.info(str(llm_evaluation))

        final_result = str(ev.execution_result)
        return StopEvent(result=final_result)

    def __del__(self):
        """Cleanup operations"""
        try:
            if hasattr(self, 'cursor'):
                self.cursor.close()
            if hasattr(self, 'db_connection'):
                self.db_connection.close()
        except Exception as e:
            logging.error(f"Cleanup error: {str(e)}")

async def run_sql_agent(natural_query: str) -> str:
    """Classify ``natural_query`` and, if it is a database request, run the SQL agent.

    Chat-like input is not processed: a warning is printed and a fixed hint is
    returned. Otherwise a new ``SQLAnalysisAgent`` runs the query and its
    result is returned as a string.
    """
    # Step 1: intent analysis (the workflow result is a dict with an "intent" key)
    classification = await IntentAnalyzer().run(topic=natural_query)

    # Step 2: anything that is not chat goes through the SQL workflow
    if classification["intent"] != "chat":
        sql_agent = SQLAnalysisAgent()
        outcome = await sql_agent.run(topic=natural_query)
        return str(outcome)

    pm.warning("I am an SQL assistant and can only help with database queries. I am not designed for chatting.")
    return "Please ask a question related to a database query."

In [4]:
# Example invocation. The bare top-level `await` relies on a notebook/IPython
# cell; a plain script would need an event loop runner such as asyncio.run().
natural_query = "Give me the price of the most expensive product"
result = await run_sql_agent(natural_query)


== SQL Query Generation


----------------------------------------
-- Security Analysis
----------------------------------------

🔒 Security checks passed

----------------------------------------
-- SQL Generation
----------------------------------------

✓ SQL query generated successfully
ℹ Original Query: Give me the price of the most expensive product
ℹ Generated SQL: SELECT MAX(price) FROM products

== Executing SQL Query


----------------------------------------
-- Query Plan Analysis
----------------------------------------

ℹ Plan Step: (3, 0, 0, 'SEARCH products')

----------------------------------------
-- Executing Query
----------------------------------------


Query Results:
--------------------------------------------------------------------------------
MAX(price)     
--------------------------------------------------------------------------------
999.99         
--------------------------------------------------------------------------------


📊 Performance Metrics: