In [35]:
from typing import Dict, List, Tuple, Optional, Union, Any
import sqlite3
import pandas as pd
import numpy as np
import os
import json
import re
import logging
import threading
from pathlib import Path
from dataclasses import dataclass
from contextlib import contextmanager

In [36]:
from llama_cpp import Llama
os.environ["LLAMA_CPP_LOG_LEVEL"] = "ERROR"

In [37]:
@dataclass
class TableInfo:
    name: str
    columns: List[Dict[str, Any]]
    relationships: List[Dict[str, str]]

In [38]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [39]:
class DatabaseSchemaManager:
    """
    Manages database schema operations and caching.
    
    Attributes:
        db_path (str): Path to the SQLite database file
        connection (Optional[sqlite3.Connection]): Database connection object
        schema_cache (Optional[Dict]): Cached schema information
    """
    
    def __init__(self, db_path: str):
        self.db_path = db_path
        self.connection: Optional[sqlite3.Connection] = None
        self.schema_cache: Optional[Dict] = None
        self._lock = threading.Lock()  # Thread safety for schema operations

    @contextmanager
    def get_connection(self):
        """Context manager for database connections."""
        try:
            if not self.connection:
                self.connect()
            yield self.connection
        except sqlite3.Error as e:
            logger.error(f"Database connection error: {e}")
            raise
        finally:
            if self.connection:
                self.connection.close()
                self.connection = None

    def connect(self) -> None:
        """Establishes a connection to the database."""
        try:
            self.connection = sqlite3.connect(self.db_path)
            self.connection.row_factory = sqlite3.Row  # Enable row factory for better access
            logger.info(f"Connected to database at {self.db_path}")
        except sqlite3.Error as e:
            logger.error(f"Error connecting to database: {e}")
            raise

    def close(self) -> None:
        """Closes the database connection."""
        if self.connection:
            self.connection.close()
            self.connection = None
            logger.info("Database connection closed")

    def get_schema(self, refresh: bool = False) -> Dict:
        """
        Retrieves the database schema with caching.
        
        Args:
            refresh (bool): Force refresh of schema cache
            
        Returns:
            Dict: Database schema information
        """
        with self._lock:
            if self.schema_cache is not None and not refresh:
                return self.schema_cache

            if not self.connection:
                self.connect()

            schema = {"tables": [], "relationships": []}
            
            try:
                with self.get_connection() as conn:
                    cursor = conn.cursor()
                    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';")
                    tables = cursor.fetchall()

                    for table in tables:
                        table_name = table[0]
                        table_info = {"name": table_name, "columns": []}
                        
                        # Get column information
                        cursor.execute(f"PRAGMA table_info('{table_name}');")
                        columns = cursor.fetchall()
                        for col in columns:
                            table_info["columns"].append({
                                "name": col["name"],
                                "type": col["type"],
                                "is_primary_key": bool(col["pk"]),
                                "not_null": bool(col["notnull"]),
                                "default": col["dflt_value"]
                            })
                        schema["tables"].append(table_info)

                        # Get foreign key information
                        cursor.execute(f"PRAGMA foreign_key_list('{table_name}');")
                        foreign_keys = cursor.fetchall()
                        for fk in foreign_keys:
                            schema["relationships"].append({
                                "table": table_name,
                                "column": fk["from"],
                                "references_table": fk["table"],
                                "references_column": fk["to"]
                            })

                self.schema_cache = schema
                return schema
                
            except sqlite3.Error as e:
                logger.error(f"Error retrieving schema: {e}")
                raise

    def format_schema_for_llm(self) -> str:
        """
        Formats the database schema for LLM consumption.
        
        Returns:
            str: Formatted schema string
        """
        schema = self.get_schema()
        formatted = "### DATABASE SCHEMA\n\n"
        
        for table in schema["tables"]:
            column_defs = []
            for col in table["columns"]:
                flags = []
                if col["is_primary_key"]:
                    flags.append("PK")
                if col["not_null"]:
                    flags.append("NN")
                flag_str = " ".join(flags)
                column_defs.append(f"{col['name']} {col['type']} {flag_str}".strip())
            formatted += f"{table['name']}({', '.join(column_defs)});\n"
            
        return formatted

    def filter_relevant_tables(self, user_query: str) -> str:
        """
        Filters and returns schema information for tables relevant to the query.
        
        Args:
            user_query (str): User's natural language query
            
        Returns:
            str: Filtered schema information
        """
        schema = self.get_schema()
        query_lower = user_query.lower()

        # Step 1: Match tables directly mentioned in the query
        relevant_tables = set()
        for table in schema["tables"]:
            if table["name"].lower() in query_lower:
                relevant_tables.add(table["name"])

        # Step 2: Fallback to frequently used tables if no matches found
        if not relevant_tables:
            fallback_tables = ["customer", "countries", "prospect", "visit"]
            relevant_tables.update([t for t in fallback_tables if t in [tbl["name"] for tbl in schema["tables"]]])

        # Step 3: Build Compact Schema for relevant tables only
        formatted = "### DATABASE SCHEMA\n\n"
        for table in schema["tables"]:
            if table["name"] in relevant_tables:
                column_defs = []
                for col in table["columns"]:
                    flags = []
                    if col["is_primary_key"]:
                        flags.append("PK")
                    if col["not_null"]:
                        flags.append("NN")
                    flag_str = " ".join(flags)
                    column_defs.append(f"{col['name']} {col['type']} {flag_str}".strip())
                formatted += f"{table['name']}({', '.join(column_defs)});\n"
                
        return formatted

    def get_sample_data(self, limit: int = 3) -> Dict[str, pd.DataFrame]:
        """
        Retrieves sample data from all tables.
        
        Args:
            limit (int): Number of rows to retrieve per table
            
        Returns:
            Dict[str, pd.DataFrame]: Dictionary of table names to sample data
        """
        schema = self.get_schema()
        samples = {}
        
        try:
            with self.get_connection() as conn:
                for table in schema["tables"]:
                    try:
                        query = f"SELECT * FROM {table['name']} LIMIT {limit};"
                        samples[table['name']] = pd.read_sql_query(query, conn)
                    except sqlite3.Error as e:
                        logger.warning(f"Error retrieving sample data for table {table['name']}: {e}")
                        samples[table['name']] = pd.DataFrame()
                        
            return samples
            
        except sqlite3.Error as e:
            logger.error(f"Error retrieving sample data: {e}")
            raise

    def format_sample_data_for_llm(self, limit: int = 3) -> str:
        """
        Formats sample data for LLM consumption.
        
        Args:
            limit (int): Number of rows to retrieve per table
            
        Returns:
            str: Formatted sample data
        """
        samples = self.get_sample_data(limit)
        formatted = "SAMPLE DATA:\n\n"
        
        for table, df in samples.items():
            formatted += f"Table: {table}\n"
            if df.empty:
                formatted += "  (No data available)\n\n"
            else:
                table_str = df.to_string(index=False)
                formatted += "\n".join("  " + line for line in table_str.split("\n")) + "\n\n"
                
        return formatted


    def get_sample_data(self, limit: int = 3):
        if not self.connection:
            self.connect()
        schema = self.get_schema()
        samples = {}
        for table in schema["tables"]:
            try:
                query = f"SELECT * FROM {table['name']} LIMIT {limit};"
                samples[table['name']] = pd.read_sql_query(query, self.connection)
            except sqlite3.Error:
                samples[table['name']] = pd.DataFrame()
        return samples

    def format_sample_data_for_llm(self, limit: int = 3) -> str:
        samples = self.get_sample_data(limit)
        formatted = "SAMPLE DATA:\n\n"
        for table, df in samples.items():
            formatted += f"Table: {table}\n"
            if df.empty:
                formatted += "  (No data available)\n\n"
            else:
                table_str = df.to_string(index=False)
                formatted += "\n".join("  " + line for line in table_str.split("\n")) + "\n\n"
        return formatted

In [40]:
class LocalLLMClient:
    def __init__(self, model_path: str):
        self.llm = None
        self.is_loaded = False
        self.load_thread = threading.Thread(target=self._load_model, args=(model_path,), daemon=True)
        self.load_thread.start()

    def _load_model(self, model_path):
        try:
            self.llm = Llama(model_path=model_path, n_ctx=16384, n_gpu_layers=0)
            self.is_loaded = True
        except Exception as e:
            logger.error(f"Model load error: {e}")

    def wait_for_model(self, timeout=30):
        self.load_thread.join(timeout)
        return self.is_loaded

    def generate_sql(self, prompt: str) -> str:
        if not self.is_loaded and not self.wait_for_model():
            return "Model not ready."

        # Wrap prompt with instruction template
        wrapped_prompt = f"""You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.
    ### Instruction:
    {prompt}
    ### Response:"""

        # Tokenize and truncate if needed
        encoded = self.llm.tokenize(wrapped_prompt.encode("utf-8"))
        print(f"Encoded prompt: {len(encoded)}")
        max_context_tokens = 16384
        max_response_tokens = 1024
        max_prompt_tokens = max_context_tokens - max_response_tokens

        if len(encoded) > max_prompt_tokens:
            logger.warning(f"Prompt too long ({len(encoded)} tokens), truncating to {max_prompt_tokens} tokens.")
            encoded = encoded[:max_prompt_tokens]
            wrapped_prompt = self.llm.detokenize(encoded).decode("utf-8", errors="ignore")

        try:
            output = self.llm(wrapped_prompt, max_tokens=max_response_tokens, stop=["### Instruction:", "### Response:"], echo=False)
            return output["choices"][0]["text"].strip()
        except Exception as e:
            logger.error(f"LLM error: {e}")
            return "Error generating SQL."


In [41]:
class QueryProcessor:
    def __init__(self, llm_client, schema_manager):
        self.llm_client = llm_client
        self.schema_manager = schema_manager

    def process_query(self, user_query: str):
        # Get schema chunks
        schema_info = self.schema_manager.filter_relevant_tables(user_query)
        # sample_data = self.schema_manager.format_sample_data_for_llm(limit=1)
        
        prompt = self._create_sql_generation_prompt(user_query, schema_info)
        # logger.info(f"Prompt size estimate: {len(prompt.split())} tokens")
        # n_tokens = len(self.llm_client.llm.tokenize(prompt.encode('utf-8')))
        # print(f"Prompt size : {n_tokens} tokens")
        print(schema_info)
        
        response = self.llm_client.generate_sql(prompt)
        return {"sql": self._parse_sql(response)}

    def _create_sql_generation_prompt(self, user_query: str, schema_info: str) -> str:
        prompt = f"""
    ### USER QUESTION START
    {user_query}
    ### USER QUESTION END

    {schema_info}

    ### INSTRUCTIONS
    Provide only the SQL query inside triple backticks (```). Don't include anything else in your response.
    Strictly use the table names provided in the schema.
    If the required tables are not in this schema, respond with "TABLES_NOT_IN_CHUNK".
        """
        return prompt.strip()

    def _parse_sql(self, response: str) -> str:
        match = re.search(r"```sql\s*(.*?)\s*```", response, re.DOTALL)
        return match.group(1).strip() if match else response.strip()

    def validate_sql(self, sql: str):
        if not self.schema_manager.connection:
            self.schema_manager.connect()
        try:
            self.schema_manager.connection.execute(f"EXPLAIN QUERY PLAN {sql}")
            return True, ""
        except sqlite3.Error as e:
            return False, str(e)

In [42]:
class QueryExecutor:
    def __init__(self, schema_manager):
        self.schema_manager = schema_manager

    def execute_query(self, sql: str):
        if not self.schema_manager.connection:
            self.schema_manager.connect()
        try:
            df = pd.read_sql_query(sql, self.schema_manager.connection)
            return {"status": "success", "results": df, "row_count": len(df)}
        except sqlite3.Error as e:
            return {"status": "error", "error_message": str(e), "results": None}

    def format_results(self, execution_result: Dict) -> Dict:
        if execution_result["status"] == "error":
            return {
                "status": "error",
                "message": f"Query execution failed: {execution_result['error_message']}",
                "data": None,
                "row_count": 0,
                "columns": [],
                "summary": {}
            }

        df = execution_result["results"]
        records = df.to_dict(orient="records")
        summary = {
            col: {
                "min": float(df[col].min()),
                "max": float(df[col].max()),
                "mean": float(df[col].mean()),
                "median": float(df[col].median())
            } for col in df.select_dtypes(include=np.number).columns
        }

        return {
            "status": "success",
            "message": f"Query returned {len(records)} rows",
            "data": records,
            "row_count": len(records),
            "columns": list(df.columns),
            "summary": summary
        }

In [43]:
class OfflineRAGAgent:
    def __init__(self, db_path: str, model_path: str):
        self.schema_manager = DatabaseSchemaManager(db_path)
        self.llm_client = LocalLLMClient(model_path)
        self.query_processor = QueryProcessor(self.llm_client, self.schema_manager)
        self.query_executor = QueryExecutor(self.schema_manager)
        self.schema_manager.connect()
        self.schema_manager.get_schema()

    def process_query(self, user_query: str) -> Dict:
        if not self.llm_client.is_loaded and not self.llm_client.wait_for_model(timeout=5):
            return {"status": "pending", "message": "Model loading...", "user_query": user_query}

        query_result = self.query_processor.process_query(user_query)
        sql = query_result["sql"]
        print(f"SQL: {sql}")  
        is_valid, error = self.query_processor.validate_sql(sql)

        if not is_valid:
            return {
                "status": "error",
                "message": f"Invalid SQL: {error}",
                "user_query": user_query,
                "sql": sql,
                "results": None
            }

        execution_result = self.query_executor.execute_query(sql)
        formatted_results = self.query_executor.format_results(execution_result)

        return {
            "status": formatted_results["status"],
            "message": formatted_results["message"],
            "user_query": user_query,
            "sql": sql,
            "results": formatted_results["data"],
            "columns": formatted_results.get("columns", []),
            "summary": formatted_results.get("summary", {})
        }

    def close(self):
        self.schema_manager.close()

In [47]:
if __name__ == "__main__":
    db_path = "cargill.db"
    model_path = "deepseek-coder-1.3b-instruct.Q4_K_M.gguf"
    agent = OfflineRAGAgent(db_path, model_path)
    result = agent.process_query("Find the total number of customers as total_customers in the United States")
    if result["status"] == "success":
        print("✅ Generated SQL:\n", result["sql"])
        print("📊 SQL Results:\n", pd.DataFrame(result["results"]))
    elif result["status"] == "pending":
        print("⏳", result["message"])
    else:
        print("❌ Error:", result["message"])
    agent.close()

2025-05-14 14:29:36,054 - INFO - Connected to database at cargill.db
llama_model_loader: loaded meta data with 22 key-value pairs and 219 tensors from deepseek-coder-1.3b-instruct.Q4_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = deepseek-ai_deepseek-coder-1.3b-instruct
llama_model_loader: - kv   2:                       llama.context_length u32              = 16384
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 2048
llama_model_loader: - kv   4:                          llama.block_count u32              = 24
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 5504
llama_model_loader: - kv   6:                 llama.rope.dimensio

⏳ Model loading...


load_tensors:  CPU_AARCH64 model buffer size =   479.25 MiB
load_tensors:   CPU_Mapped model buffer size =   831.88 MiB
repack: repack tensor blk.0.attn_q.weight with q4_K_8x8
repack: repack tensor blk.0.attn_k.weight with q4_K_8x8
repack: repack tensor blk.0.attn_output.weight with q4_K_8x8
repack: repack tensor blk.0.ffn_gate.weight with q4_K_8x8
.repack: repack tensor blk.0.ffn_up.weight with q4_K_8x8
.repack: repack tensor blk.1.attn_q.weight with q4_K_8x8
repack: repack tensor blk.1.attn_k.weight with q4_K_8x8
repack: repack tensor blk.1.attn_output.weight with q4_K_8x8
.repack: repack tensor blk.1.ffn_gate.weight with q4_K_8x8
repack: repack tensor blk.1.ffn_up.weight with q4_K_8x8
.repack: repack tensor blk.2.attn_q.weight with q4_K_8x8
repack: repack tensor blk.2.attn_k.weight with q4_K_8x8
.repack: repack tensor blk.2.attn_output.weight with q4_K_8x8
repack: repack tensor blk.2.ffn_gate.weight with q4_K_8x8
.repack: repack tensor blk.2.ffn_up.weight with q4_K_8x8
repack: repac