In [83]:
import sqlite3
import pandas as pd
import numpy as np
import os
import json
import re
from typing import Dict, List, Tuple, Optional, Union
import logging

In [84]:
# For LLM interaction - using Anthropic's Claude API
# For Kaggle testing, we'll use an API-based approach
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

In [85]:
# Set up logging: a single root configuration for the notebook (INFO and above,
# "time - level - message"), plus the module logger that every class below uses.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

In [86]:
#####################################################
# Database Schema Management
#####################################################

class DatabaseSchemaManager:
    """Manages the extraction and representation of SQLite database schema.

    Holds one ``sqlite3`` connection, which the methods that need it open on demand.
    Caches the extracted schema so repeated prompt building does not re-read the
    catalog.
    """

    # pd.read_sql_query re-raises driver failures as pandas' own DatabaseError, which is
    # not a subclass of sqlite3.Error, so both must be caught around pandas calls.
    _QUERY_ERRORS = (sqlite3.Error, pd.errors.DatabaseError)

    def __init__(self, db_path: str, read_only: bool = True):
        """
        Initialize with path to SQLite database.

        Args:
            db_path: Path to the SQLite database file
            read_only: If True (default), ``connect`` enables ``PRAGMA query_only`` so SQL
                executed through this connection (including LLM-generated SQL) cannot
                CREATE, DROP, INSERT, UPDATE or DELETE.
        """
        self.db_path = db_path
        self.read_only = read_only
        self.connection: Optional[sqlite3.Connection] = None
        self.schema_cache: Optional[Dict] = None

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQLite identifier, doubling inner quotes.

        Needed for table names with spaces, quotes or reserved words (e.g. ``order``).
        Such names are syntax errors when interpolated into SQL unquoted.
        """
        return '"' + name.replace('"', '""') + '"'

    def connect(self) -> None:
        """Open the connection to ``db_path``. It is read-only unless disabled in ``__init__``.

        Raises:
            sqlite3.Error: If the database cannot be opened or configured.
        """
        try:
            self.connection = sqlite3.connect(self.db_path)
            if self.read_only:
                # Writes through this connection now fail with "attempt to write a
                # readonly database". EXPLAIN QUERY PLAN does not catch that, because it
                # never executes the statement.
                self.connection.execute("PRAGMA query_only = ON;")
            logger.info(f"Connected to database at {self.db_path}")
        except sqlite3.Error as e:
            logger.error(f"Error connecting to database: {e}")
            raise

    def close(self) -> None:
        """Close the database connection (no-op if it is not open)."""
        if self.connection:
            self.connection.close()
            self.connection = None
            logger.info("Database connection closed")

    def get_schema(self, refresh: bool = False) -> Dict:
        """
        Extract database schema information.

        Args:
            refresh: Whether to refresh the schema cache

        Returns:
            Dictionary with two lists:
              - "tables": one ``{"name", "columns"}`` entry per user table; each column is
                ``{"name", "type", "is_primary_key", "not_null", "default"}``
              - "relationships": one ``{"table", "column", "references_table",
                "references_column"}`` entry per foreign-key column

        Raises:
            sqlite3.Error: If the catalog cannot be read.
        """
        if self.schema_cache is not None and not refresh:
            return self.schema_cache

        if not self.connection:
            self.connect()

        schema = {
            "tables": [],
            "relationships": []
        }

        try:
            cursor = self.connection.cursor()
            # Skip SQLite's internal tables (sqlite_sequence, sqlite_stat1, ...). The "_" is
            # escaped because in LIKE it is a wildcard and would also hide "sqliteusers".
            cursor.execute(
                r"SELECT name FROM sqlite_master "
                r"WHERE type='table' AND name NOT LIKE 'sqlite\_%' ESCAPE '\';"
            )
            table_names = [row[0] for row in cursor.fetchall()]

            # Primary-key columns per table (lower-cased key, since SQLite names are
            # case-insensitive), ordered by position in the key. Used to resolve foreign
            # keys declared without a column ("REFERENCES parent").
            primary_keys: Dict[str, List[str]] = {}

            for table_name in table_names:
                table_info = {
                    "name": table_name,
                    "columns": []
                }
                key_columns = []

                cursor.execute(f"PRAGMA table_info({self._quote_identifier(table_name)});")
                for _cid, col_name, col_type, not_null, default_val, pk_position in cursor.fetchall():
                    table_info["columns"].append({
                        "name": col_name,
                        "type": col_type,
                        "is_primary_key": bool(pk_position),
                        "not_null": bool(not_null),
                        "default": default_val
                    })
                    if pk_position:
                        # table_info's pk field is the 1-based position within the key.
                        key_columns.append((pk_position, col_name))

                primary_keys[table_name.lower()] = [name for _, name in sorted(key_columns)]
                schema["tables"].append(table_info)

            for table_name in table_names:
                cursor.execute(f"PRAGMA foreign_key_list({self._quote_identifier(table_name)});")
                for _fk_id, seq, ref_table, from_col, to_col, *_actions in cursor.fetchall():
                    if to_col is None:
                        # No explicit parent column: the key targets the parent's primary
                        # key, and ``seq`` is this column's position within that key.
                        parent_key = primary_keys.get(ref_table.lower(), [])
                        to_col = parent_key[seq] if seq < len(parent_key) else None
                    schema["relationships"].append({
                        "table": table_name,
                        "column": from_col,
                        "references_table": ref_table,
                        "references_column": to_col
                    })

            self.schema_cache = schema
            logger.info(f"Extracted schema with {len(schema['tables'])} tables")
            return schema

        except sqlite3.Error as e:
            logger.error(f"Error extracting schema: {e}")
            raise

    def format_schema_for_llm(self) -> str:
        """
        Format the database schema in a way that's optimal for LLM understanding.

        Returns:
            String representation of the schema formatted for the LLM
        """
        schema = self.get_schema()
        formatted_schema = "DATABASE SCHEMA:\n\n"

        # Format tables and columns
        for table in schema["tables"]:
            formatted_schema += f"Table: {table['name']}\n"
            formatted_schema += "Columns:\n"

            for column in table["columns"]:
                pk_marker = " (PRIMARY KEY)" if column["is_primary_key"] else ""
                null_marker = " NOT NULL" if column["not_null"] else ""
                formatted_schema += f"  - {column['name']} ({column['type']}){pk_marker}{null_marker}\n"

            formatted_schema += "\n"

        # Format relationships
        if schema["relationships"]:
            formatted_schema += "Relationships:\n"
            for rel in schema["relationships"]:
                # None when get_schema could not resolve the parent column; show the
                # SQLite default target instead of a literal "None".
                target_column = rel["references_column"]
                if target_column is None:
                    target_column = "<primary key>"
                formatted_schema += f"  - {rel['table']}.{rel['column']} -> {rel['references_table']}.{target_column}\n"

        return formatted_schema

    def get_sample_data(self, limit: int = 3) -> Dict[str, pd.DataFrame]:
        """
        Get sample data from each table for better LLM understanding.

        Args:
            limit: Number of sample rows to fetch

        Returns:
            Dictionary mapping table names to DataFrames with sample data. A table whose
            sample query fails maps to an empty DataFrame, and a warning is logged.
        """
        if not self.connection:
            self.connect()

        schema = self.get_schema()
        samples = {}

        for table in schema["tables"]:
            table_name = table["name"]
            query = f"SELECT * FROM {self._quote_identifier(table_name)} LIMIT {int(limit)};"
            try:
                samples[table_name] = pd.read_sql_query(query, self.connection)
                logger.info(f"Fetched {len(samples[table_name])} sample rows from {table_name}")
            except self._QUERY_ERRORS as e:
                logger.warning(f"Could not fetch sample data from {table_name}: {e}")
                samples[table_name] = pd.DataFrame()

        return samples

    def format_sample_data_for_llm(self, limit: int = 3) -> str:
        """
        Format sample data in a way that's optimal for LLM understanding.

        Args:
            limit: Number of sample rows to fetch

        Returns:
            String representation of sample data formatted for the LLM
        """
        samples = self.get_sample_data(limit)
        formatted_samples = "SAMPLE DATA:\n\n"

        for table_name, data in samples.items():
            formatted_samples += f"Table: {table_name}\n"

            if data.empty:
                formatted_samples += "  (No data available)\n\n"
                continue

            # Render without the index and indent every line under the table heading.
            table_str = data.to_string(index=False)
            table_str = "\n".join("  " + line for line in table_str.split("\n"))
            formatted_samples += f"{table_str}\n\n"

        return formatted_samples

In [87]:
#####################################################
# Query Processing
#####################################################

class QueryProcessor:
    """Processes natural language queries and converts them to SQL."""

    # Fenced SQL block in an LLM reply. Accepts ```sql, ```SQL and ```sqlite fences.
    _SQL_FENCE = re.compile(r"```(?:sqlite|sql)\b\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
    # Start of an unfenced statement for the fallback parser: SELECT, or a WITH line that
    # looks like a CTE ("WITH" alone, or containing "AS") rather than English prose.
    _SQL_START = re.compile(r"SELECT\b|WITH\b(?:\s*$|.*\bAS\b)", re.IGNORECASE)
    # Whitespace and SQL comments that may precede a statement's first keyword.
    _LEADING_NOISE = re.compile(r"\A(?:\s+|--[^\n]*(?:\n|\Z)|/\*.*?\*/)*", re.DOTALL)
    # First keywords validate_sql accepts. The agent only answers questions, so anything
    # else (INSERT, DROP, PRAGMA, ...) is rejected before it can reach the executor.
    _READ_ONLY_KEYWORDS = ("SELECT", "WITH")

    def __init__(self, llm_client, schema_manager: "DatabaseSchemaManager"):
        """
        Initialize with an LLM client and schema manager.

        Args:
            llm_client: Client for LLM API interaction (must provide ``generate_sql(prompt)``)
            schema_manager: Database schema manager instance
        """
        self.llm_client = llm_client
        self.schema_manager = schema_manager

    def process_query(self, user_query: str) -> Dict:
        """
        Process a natural language query to generate SQL.

        Args:
            user_query: Natural language query from the user

        Returns:
            Dictionary with "sql" and "explanation" keys
        """
        # Get formatted schema for the LLM
        schema_info = self.schema_manager.format_schema_for_llm()
        sample_data = self.schema_manager.format_sample_data_for_llm(limit=2)

        # Create prompt for the LLM
        prompt = self._create_sql_generation_prompt(user_query, schema_info, sample_data)

        # Get response from LLM
        response = self.llm_client.generate_sql(prompt)

        # Extract SQL and explanation from LLM response
        result = self._parse_llm_response(response)
        logger.info(f"Generated SQL: {result['sql']}")

        return result

    def _create_sql_generation_prompt(self, user_query: str, schema_info: str, sample_data: str) -> str:
        """
        Create a prompt for SQL generation.

        Args:
            user_query: User's natural language query
            schema_info: Formatted database schema information
            sample_data: Formatted sample data

        Returns:
            Formatted prompt for the LLM
        """
        prompt = f"""
        You are an AI assistant that converts natural language queries into SQLite SQL queries.
        Given the following database schema and sample data, generate a SQL query that answers the user's question.
        
        {schema_info}
        
        {sample_data}
        
        USER QUESTION: {user_query}
        
        Please respond in the following format:
        ```sql
        -- Your SQL query here
        ```
        
        EXPLANATION:
        Explain your approach and how your SQL query answers the user's question.
        
        Make sure your SQL query:
        1. Is valid SQLite syntax
        2. References only tables and columns that exist in the schema
        3. Uses proper joins when needed
        4. Uses appropriate filtering and aggregation
        5. Is efficient and follows best practices
        
        SQL QUERY: """
        return prompt

    def _parse_llm_response(self, response: str) -> Dict:
        """
        Parse the LLM response to extract SQL and explanation.

        The first fenced ```sql block wins. Without one, lines are collected from the first
        line that starts a SELECT or WITH statement up to a line ending in ";".

        Args:
            response: Raw response from the LLM

        Returns:
            Dictionary with "sql" (possibly empty) and "explanation"
        """
        fenced = self._SQL_FENCE.findall(response)

        if fenced:
            sql = fenced[0].strip()
        else:
            sql_lines = []
            in_sql = False

            for line in response.split("\n"):
                stripped = line.strip()
                if in_sql or self._SQL_START.match(stripped):
                    # Keep collecting until a line terminates the statement.
                    in_sql = not stripped.endswith(";")
                    sql_lines.append(line)

            sql = "\n".join(sql_lines).strip()

        # Extract explanation (text after "EXPLANATION:" if it exists)
        explanation_pattern = r"EXPLANATION:(.*?)(?:$|SQL QUERY:)"
        explanation_matches = re.findall(explanation_pattern, response, re.DOTALL)

        if explanation_matches:
            explanation = explanation_matches[0].strip()
        elif sql:
            # No explicit explanation section: use everything except the SQL.
            explanation = response.replace(sql, "").strip()
        else:
            explanation = "No explanation provided."

        return {
            "sql": sql,
            "explanation": explanation
        }

    @classmethod
    def _leading_keyword(cls, sql: str) -> str:
        """Return the upper-cased first keyword of ``sql``, skipping whitespace and comments."""
        body = sql[cls._LEADING_NOISE.match(sql).end():]
        word = re.match(r"[A-Za-z]+", body)
        return word.group(0).upper() if word else ""

    def validate_sql(self, sql: str) -> Tuple[bool, str]:
        """
        Validate SQL query syntax without executing it.

        Rejects empty input and statements that do not start with SELECT/WITH. It then
        compiles the query with ``EXPLAIN QUERY PLAN``, which checks syntax and that the
        referenced tables/columns exist.

        Args:
            sql: SQL query string

        Returns:
            Tuple of (is_valid, error_message); error_message is "" when valid
        """
        if not sql or not sql.strip():
            return False, "No SQL query was generated."

        keyword = self._leading_keyword(sql)
        if keyword not in self._READ_ONLY_KEYWORDS:
            return False, f"Only read-only SELECT/WITH queries are allowed (got {keyword or 'no keyword'!r})."

        if not self.schema_manager.connection:
            self.schema_manager.connect()

        try:
            cursor = self.schema_manager.connection.cursor()
            cursor.execute(f"EXPLAIN QUERY PLAN {sql}")
            cursor.fetchall()
            return True, ""
        except (sqlite3.Error, sqlite3.Warning) as e:
            # sqlite3.Warning (not an sqlite3.Error subclass) is what Python <= 3.11 raises
            # for "You can only execute one statement at a time".
            return False, str(e)

In [88]:
#####################################################
# Query Execution
#####################################################

class QueryExecutor:
    """Executes SQL queries against the SQLite database and formats results."""

    # pd.read_sql_query re-raises any driver failure as pandas' DatabaseError, which is not
    # a subclass of sqlite3.Error. Python <= 3.11 raises sqlite3.Warning for
    # multi-statement strings. All of these mean "the query failed".
    _QUERY_ERRORS = (sqlite3.Error, sqlite3.Warning, pd.errors.DatabaseError)

    def __init__(self, schema_manager: "DatabaseSchemaManager"):
        """
        Initialize with a schema manager.

        Args:
            schema_manager: Database schema manager instance (provides ``connection``/``connect``)
        """
        self.schema_manager = schema_manager

    def execute_query(self, sql: str) -> Dict:
        """
        Execute a SQL query and return results.

        Args:
            sql: SQL query to execute

        Returns:
            On success: {"status": "success", "row_count", "results" (DataFrame), "query_plan"}.
            On failure: {"status": "error", "error_message", "results": None}.
            Query errors are reported in the dict rather than raised.
        """
        if not self.schema_manager.connection:
            self.schema_manager.connect()
        connection = self.schema_manager.connection

        try:
            result_df = pd.read_sql_query(sql, connection)

            # Query plan metadata for the same statement.
            cursor = connection.cursor()
            cursor.execute("EXPLAIN QUERY PLAN " + sql)
            query_plan = cursor.fetchall()
        except self._QUERY_ERRORS as e:
            logger.error(f"Error executing query: {e}")
            return {
                "status": "error",
                "error_message": str(e),
                "results": None
            }

        return {
            "status": "success",
            "row_count": len(result_df),
            "results": result_df,
            "query_plan": query_plan
        }

    def format_results(self, execution_result: Dict) -> Dict:
        """
        Format query execution results for presentation.

        Args:
            execution_result: Result from execute_query

        Returns:
            Dictionary with "status", "message", "data" (list of row dicts or None),
            "row_count", "columns" and "summary". "summary" holds min/max/mean/median
            per numeric column.
        """
        if execution_result["status"] == "error":
            return {
                "status": "error",
                "message": f"Query execution failed: {execution_result['error_message']}",
                "data": None,
                "row_count": 0,
                "columns": [],
                "summary": {}
            }

        result_df = execution_result["results"]
        records = result_df.to_dict(orient="records")

        summary = {}
        for column in result_df.columns:
            values = result_df[column]
            if np.issubdtype(values.dtype, np.number):
                summary[column] = {
                    "min": float(values.min()),
                    "max": float(values.max()),
                    "mean": float(values.mean()),
                    "median": float(values.median())
                }

        return {
            "status": "success",
            "message": f"Query returned {len(records)} rows",
            "data": records,
            "row_count": len(records),
            "columns": list(result_df.columns),
            "summary": summary
        }

In [89]:
#####################################################
# LLM Client
#####################################################

class LLMClient:
    """Client for interacting with LLM APIs (Anthropic Messages endpoint by default)."""

    def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None,
                 model: str = "claude-3-opus-20240229", max_tokens: int = 2000,
                 timeout: float = 30):
        """
        Initialize LLM client.

        Args:
            api_key: API key for the LLM service; falls back to $ANTHROPIC_API_KEY
            api_base: Base URL for the LLM API
            model: Model name sent with every request
            max_tokens: Maximum tokens requested per completion
            timeout: Per-request timeout in seconds
        """
        self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
        self.api_base = api_base or "https://api.anthropic.com/v1/messages"
        self.model = model
        self.max_tokens = max_tokens
        self.timeout = timeout

        # Set up session with retries. urllib3 only retries idempotent methods by default,
        # so POST must be listed or status-based retries never fire. 429 = rate limited,
        # 529 = Anthropic "overloaded". (`allowed_methods` needs urllib3 >= 1.26.)
        self.session = requests.Session()
        retries = Retry(
            total=3,
            backoff_factor=0.5,
            status_forcelist=[429, 502, 503, 504, 529],
            allowed_methods=frozenset(["POST"]),
        )
        self.session.mount('https://', HTTPAdapter(max_retries=retries))

    def generate_sql(self, prompt: str) -> str:
        """
        Send ``prompt`` to the LLM and return its text reply.

        Despite the name, this is also used for result explanations. It falls back to
        ``_mock_response`` when no API key is configured, the request fails, or the
        response body does not have the expected shape.

        Args:
            prompt: Prompt to send

        Returns:
            The model's text reply (or the mock response)
        """
        if not self.api_key:
            logger.warning("No API key provided. Using mock response for testing.")
            return self._mock_response(prompt)

        headers = {
            "Content-Type": "application/json",
            "X-API-Key": self.api_key,
            "anthropic-version": "2023-06-01"
        }

        data = {
            "model": self.model,
            "max_tokens": self.max_tokens,
            "messages": [{"role": "user", "content": prompt}]
        }

        try:
            response = self.session.post(
                self.api_base,
                headers=headers,
                json=data,
                timeout=self.timeout
            )
            response.raise_for_status()
            result = response.json()

            # Extract the text response
            return result['content'][0]['text']

        except requests.exceptions.RequestException as e:
            logger.error(f"Error calling LLM API: {e}")
            return self._mock_response(prompt)
        except (ValueError, KeyError, IndexError, TypeError) as e:
            # Body was not JSON or lacked content[0].text. Degrade the same way as for
            # transport errors instead of crashing the pipeline.
            logger.error(f"Unexpected LLM API response format: {e!r}")
            return self._mock_response(prompt)

    def _mock_response(self, prompt: str) -> str:
        """
        Generate a mock response for testing when API is not available.

        Args:
            prompt: Input prompt; the question is read from its "USER QUESTION:" line and
                the first "Table: <name>" line picks the table.

        Returns:
            Mock response containing a ```sql block and an EXPLANATION section
        """
        question_match = re.search(r"USER QUESTION: (.*?)(?:\n|$)", prompt)
        if not question_match:
            return "Could not parse the question."

        question = question_match.group(1).lower()

        # Extract table names from the schema section
        table_matches = re.findall(r"Table: (\w+)", prompt)

        # Generate a simple mock SQL response based on keywords in the question
        if "count" in question:
            table = table_matches[0] if table_matches else "users"
            mock_sql = f"SELECT COUNT(*) FROM {table};"
        elif "average" in question or "avg" in question:
            table = table_matches[0] if table_matches else "data"
            mock_sql = f"SELECT AVG(value) FROM {table};"
        else:
            table = table_matches[0] if table_matches else "items"
            mock_sql = f"SELECT * FROM {table} LIMIT 10;"

        return f"""```sql
{mock_sql}
```

EXPLANATION:
This is a mock SQL query generated for testing purposes when no API key is provided.
The query is based on simple pattern matching from your question.
In production, this would be replaced with actual LLM-generated SQL."""

In [90]:
#####################################################
# Result Explainer
#####################################################

class ResultExplainer:
    """Explains query results in natural language."""

    def __init__(self, llm_client):
        """
        Initialize with an LLM client.

        Args:
            llm_client: Client for LLM API interaction (must provide ``generate_sql(prompt)``)
        """
        self.llm_client = llm_client

    def explain_results(self, user_query: str, sql: str, formatted_results: Dict) -> str:
        """
        Generate natural language explanation of query results.

        Args:
            user_query: Original user query
            sql: SQL query that was executed
            formatted_results: Output of ``QueryExecutor.format_results``

        Returns:
            The LLM's explanation. If the LLM call raises, a plain-text summary built
            locally is returned instead.
        """
        if formatted_results["status"] == "error":
            return f"The query failed with the following error: {formatted_results['message']}"

        prompt = self._create_explanation_prompt(user_query, sql, formatted_results)

        try:
            return self.llm_client.generate_sql(prompt)
        except Exception as e:
            logger.error(f"Error generating explanation: {e}")
            return self._generate_basic_explanation(formatted_results)

    def _create_explanation_prompt(self, user_query: str, sql: str, formatted_results: Dict) -> str:
        """
        Create a prompt for result explanation.

        Args:
            user_query: Original user query
            sql: SQL query that was executed
            formatted_results: Formatted query results

        Returns:
            Formatted prompt for the LLM, including at most the first 5 result rows
        """
        max_rows = 5
        # "data" may be None (error results); treat that like an empty result set.
        data_sample = (formatted_results.get("data") or [])[:max_rows]
        row_count = formatted_results.get("row_count", 0)
        # default=str: rows can contain values json cannot encode natively (e.g. bytes from
        # BLOB columns). Without it json.dumps raises TypeError here, outside the
        # try/except in explain_results, and the whole pipeline crashes.
        sample_json = json.dumps(data_sample, indent=2, default=str)
        summary_json = json.dumps(formatted_results.get("summary", {}), indent=2, default=str)

        prompt = f"""You are an AI assistant that explains database query results in natural language.
    Please explain the following query results based on the user's original question.
    
    USER QUESTION: {user_query}
    
    SQL QUERY:
    ```sql
    {sql}
    ```

    QUERY RESULTS:
    The query returned {row_count} rows.
    
    Sample of the results (first {len(data_sample)} rows):
    {sample_json}
    
    Column statistics:
    {summary_json}
    
    Please provide a clear, concise explanation of these results in relation to the user's question.
    Focus on key insights, patterns, and directly answering the user's question.
    If there are interesting statistics or trends in the data, highlight them.
    """
        return prompt

    def _generate_basic_explanation(self, formatted_results: Dict) -> str:
        """
        Generate a basic explanation when LLM is not available.

        Args:
            formatted_results: Formatted query results

        Returns:
            Row count, column names and per-column numeric statistics as plain text
        """
        explanation = f"The query returned {formatted_results['row_count']} rows with the following columns: "
        explanation += ", ".join(formatted_results["columns"])

        if formatted_results["summary"]:
            explanation += "\n\nHere are some statistics from the numeric columns:\n"
            for col, stats in formatted_results["summary"].items():
                explanation += f"\n{col}:"
                explanation += f"\n  - Minimum: {stats['min']}"
                explanation += f"\n  - Maximum: {stats['max']}"
                explanation += f"\n  - Mean: {stats['mean']:.2f}"
                explanation += f"\n  - Median: {stats['median']:.2f}"

        return explanation

In [91]:
#####################################################
# RAG Agent
#####################################################

class RAGAgent:
    """Main RAG agent that coordinates all components.

    Can be used as a context manager (``with RAGAgent(path) as agent: ...``) so the
    database connection is closed even when processing raises.
    """

    def __init__(self, db_path: str, api_key: Optional[str] = None):
        """
        Initialize the RAG agent: build the components, connect, and cache the schema.

        Args:
            db_path: Path to the SQLite database
            api_key: API key for the LLM service (None lets the LLM client decide)
        """
        self.schema_manager = DatabaseSchemaManager(db_path)
        self.llm_client = LLMClient(api_key)
        self.query_processor = QueryProcessor(self.llm_client, self.schema_manager)
        self.query_executor = QueryExecutor(self.schema_manager)
        self.result_explainer = ResultExplainer(self.llm_client)

        self.schema_manager.connect()
        try:
            # Cache the schema now so the first query does not pay for it.
            self.schema_manager.get_schema()
        except Exception:
            # The caller never receives the agent, so it could not close the connection.
            self.schema_manager.close()
            raise

        logger.info("RAG Agent initialized successfully")

    def __enter__(self) -> "RAGAgent":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def process_query(self, user_query: str) -> Dict:
        """
        Process a user query from natural language to results.

        Args:
            user_query: Natural language query from the user

        Returns:
            Dictionary with keys "status", "message", "user_query", "sql", "sql_explanation",
            "results", "columns", "summary" and "result_explanation". The same key set is
            returned on success and on every failure path.
        """
        logger.info(f"Processing user query: {user_query}")

        # Step 1: Generate SQL from natural language
        query_result = self.query_processor.process_query(user_query)
        sql = query_result["sql"]
        sql_explanation = query_result["explanation"]

        # Step 2: Validate SQL
        is_valid, error = self.query_processor.validate_sql(sql)

        if not is_valid:
            logger.warning(f"Invalid SQL: {error}")
            return {
                "status": "error",
                "message": f"Generated SQL is invalid: {error}",
                "user_query": user_query,
                "sql": sql,
                "sql_explanation": sql_explanation,
                "results": None,
                "columns": [],
                "summary": {},
                "result_explanation": None
            }

        # Step 3: Execute SQL
        execution_result = self.query_executor.execute_query(sql)

        # Step 4: Format results
        formatted_results = self.query_executor.format_results(execution_result)
        succeeded = formatted_results["status"] == "success"

        # Step 5: Explain results
        if succeeded:
            result_explanation = self.result_explainer.explain_results(
                user_query, sql, formatted_results
            )
        else:
            result_explanation = formatted_results["message"]

        return {
            "status": formatted_results["status"],
            "message": formatted_results["message"],
            "user_query": user_query,
            "sql": sql,
            "sql_explanation": sql_explanation,
            "results": formatted_results["data"] if succeeded else None,
            "columns": formatted_results.get("columns", []),
            "summary": formatted_results.get("summary", {}),
            "result_explanation": result_explanation
        }

    def close(self):
        """Clean up resources (closes the database connection)."""
        self.schema_manager.close()
        logger.info("RAG Agent resources cleaned up")

In [92]:
#####################################################
# Example Usage
#####################################################

def example_usage(db_path: str = "/kaggle/input/cargill/cargill.db",
                  user_query: str = "Which countries have the most customers?") -> None:
    """Run one natural-language question through the RAG agent and print each stage.

    Args:
        db_path: Path to the SQLite database (defaults to the Kaggle dataset path).
        user_query: Question to ask.

    The Anthropic API key is read from the ``ANTHROPIC_API_KEY`` environment variable,
    never from source code.
    """
    # NOTE(review): a real-looking Anthropic key was previously hard-coded on this line.
    # It is now in the notebook's history, so revoke it. Store the key as a secret and
    # export it as ANTHROPIC_API_KEY before running.
    agent = RAGAgent(db_path, api_key=os.environ.get("ANTHROPIC_API_KEY"))
    try:
        result = agent.process_query(user_query)

        print(f"User Query: {result['user_query']}\n")
        print(f"Generated SQL:\n{result['sql']}\n")
        print(f"SQL Explanation:\n{result['sql_explanation']}\n")

        if result['status'] == 'success':
            print(f"Results:\n{pd.DataFrame(result['results'])}\n")
            print(f"Result Explanation:\n{result['result_explanation']}\n")
        else:
            print(f"Error: {result['message']}\n")
    finally:
        # Release the SQLite connection even if the pipeline raises.
        agent.close()

In [93]:
# Run the demo when this code is executed directly, as in this notebook; its captured
# output follows below. The demo is skipped when the code is imported as a module.
if __name__ == "__main__":
    example_usage()

  has_large_values = (abs_vals > 1e6).any()
  has_small_values = ((abs_vals < 10 ** (-self.digits)) & (abs_vals > 0)).any()
  has_small_values = ((abs_vals < 10 ** (-self.digits)) & (abs_vals > 0)).any()


User Query: Which countries have the most customers?

Generated SQL:
SELECT c.country, COUNT(*) AS num_customers
FROM customer c
GROUP BY c.country
ORDER BY num_customers DESC
LIMIT 3;

SQL Explanation:
To find the countries with the most customers, we need to:

1. Select from the `customer` table since that contains the country for each customer
2. Group the results by the `country` column to get a count for each distinct country value 
3. SELECT the `country` and use COUNT(*) to count the number of customers for each country
4. ORDER BY the counted `num_customers` in DESCending order to put the countries with the most customers first
5. Use LIMIT 3 to return just the top 3 countries with the most customers

This query efficiently aggregates the data by country, counts the customers per country, orders it to put the highest counts first, and limits the result to the top 3 countries.

The key aspects are:
- Selecting from the correct `customer` table that has the country data
- Groupin