In [15]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain_core.language_models import LLM
from langchain_core.callbacks import CallbackManagerForLLMRun
from typing import List, Optional, Any
from sqlalchemy import create_engine
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import create_sql_agent
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.exceptions import OutputParserException
import logging
import sys
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

# Set up logging for production
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

class LocalHFLLM(LLM):
    device: str = "cpu"  # Change to "cuda" if GPU is available
    max_new_tokens: int = 256
    temperature: float = 0.5  # Lowered for deterministic outputs
    top_p: float = 0.9

    def __init__(self, model_path: str, **kwargs):
        super().__init__(**kwargs)
        try:
            logger.info(f"Loading model from {model_path}")
            self._tokenizer = AutoTokenizer.from_pretrained(model_path)
            self._model = AutoModelForCausalLM.from_pretrained(model_path).to(self.device)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs
    ) -> str:
        try:
            inputs = self._tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.device)
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                do_sample=True,
                pad_token_id=self._tokenizer.pad_token_id,
                eos_token_id=self._tokenizer.eos_token_id,
                num_return_sequences=1
            )
            response = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            if stop:
                for stop_token in stop:
                    response = response.split(stop_token)[0]
            return response.strip()
        except Exception as e:
            logger.error(f"Error during model inference: {str(e)}")
            return "Error during model inference"

    @property
    def _llm_type(self) -> str:
        return "custom-local-hf"

# Custom output parser to handle malformed actions
class CustomSQLAgentOutputParser(StrOutputParser):
    def parse(self, text: str) -> str:
        try:
            valid_tools = ["sql_db_query", "sql_db_schema", "sql_db_list_tables", "sql_db_query_checker"]
            for tool in valid_tools:
                if f"Action: {tool}" in text:
                    return text
            logger.warning(f"Invalid action detected in output: {text}")
            return (
                "Thought: The previous action was invalid. I must select one tool from "
                "[sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker]. "
                "Since the query may be ambiguous, I'll start by listing available tables.\n"
                "Action: sql_db_list_tables\n"
                "Action Input: "
            )
        except Exception as e:
            logger.error(f"Error parsing output: {str(e)}")
            raise OutputParserException(f"Failed to parse output: {text}")

# Define custom prompt template for SQL agent
SQL_PROMPT = PromptTemplate(
    input_variables=["input", "agent_scratchpad", "tool_names", "tools", "table_info"],
    template="""
You are an expert SQL agent designed to interact with a SQLite database using the following tools: {tool_names}.
Available tools: {tools}

Table information (fetched dynamically):
{table_info}

Your task is to generate syntactically correct SQLite queries based on the user's input and the provided table information. Follow these steps:

1. Select ONE tool per action: [sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker].
2. If the table is not specified (e.g., "table" is ambiguous), use `sql_db_list_tables` to get table names, then use the first table.
3. Use `sql_db_schema` to fetch schema details for the selected table if needed.
4. Always validate queries with `sql_db_query_checker` before executing with `sql_db_query`.
5. For "first row" queries, limit results to 1 row; otherwise, limit to 10 rows.
6. Only select relevant columns based on the schema; never use SELECT *.
7. Do not execute DML statements (INSERT, UPDATE, DELETE, DROP, etc.).
8. If the question is unrelated to the database, return "I don't know".
9. Use {agent_scratchpad} to review previous steps and avoid repeating invalid actions.
10. If table information is empty, start with `sql_db_list_tables`.

Format your response strictly as:
Question: {input}
Thought: [Your reasoning]
Action: [sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker]
Action Input: [Input for the action]
Observation: [Result of the action]
... (repeat as needed)
Thought: I now know the final answer
Final Answer: [The final answer]

Begin:

Question: {input}
Thought: {agent_scratchpad}
Table Info: {table_info}
"""
)

def fetch_table_info(db: SQLDatabase) -> str:
    """Fetch table names and schemas from the database."""
    try:
        # Get table names
        tables = db.get_usable_table_names()
        if not tables:
            return "No tables found in the database."
        
        table_info = "Available tables and their schemas:\n"
        for table in tables:
            try:
                # Fetch schema and sample rows for each table
                schema_info = db.run(f"PRAGMA table_info({table})")
                table_info += f"- {table}: columns {schema_info}\n"
            except Exception as e:
                logger.error(f"Error fetching schema for table {table}: {str(e)}")
                table_info += f"- {table}: schema unavailable (error: {str(e)})\n"
        return table_info
    except Exception as e:
        logger.error(f"Error fetching table information: {str(e)}")
        return "Error fetching table information."

@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(Exception)
)
def run_agent_query(agent_executor, query: str, table_info: str) -> str:
    """Execute the agent query with retry logic."""
    try:
        logger.info(f"Executing query: {query}")
        response = agent_executor.run({"input": query, "table_info": table_info})
        logger.info(f"Query executed successfully: {response}")
        return response
    except Exception as e:
        logger.error(f"Error executing query '{query}': {str(e)}")
        raise

def main():
    # Model and database paths
    model_path = "/Users/abhishek/Downloads/responsible_ai/models/phi4-mini-reasoning"
    db_path = "sqlite:////Users/abhishek/Downloads/responsible_ai/db_files/data.db"

    # Initialize LLM
    try:
        llm = LocalHFLLM(model_path=model_path)
        logger.info("LLM initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize LLM: {str(e)}")
        return

    # Initialize database
    try:
        engine = create_engine(db_path)
        db = SQLDatabase(engine)
        logger.info("Database connection established")
    except Exception as e:
        logger.error(f"Failed to connect to database: {str(e)}")
        return

    # Fetch table information
    try:
        table_info = fetch_table_info(db)
        logger.info(f"Table information fetched: {table_info}")
    except Exception as e:
        logger.error(f"Failed to fetch table information: {str(e)}")
        print(f"Error: Failed to fetch table information: {str(e)}")
        return

    # Initialize SQL toolkit and agent
    try:
        toolkit = SQLDatabaseToolkit(db=db, llm=llm)
        agent_executor = create_sql_agent(
            llm=llm,
            toolkit=toolkit,
            verbose=True,
            handle_parsing_errors=True,
            prompt=SQL_PROMPT,
            agent_executor_kwargs={"output_parser": CustomSQLAgentOutputParser()}
        )
        logger.info("SQL agent initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize SQL agent: {str(e)}")
        print(f"Error: Failed to initialize SQL agent: {str(e)}")
        return

    # Example query
    query = "return first row data of table?"
    try:
        response = run_agent_query(agent_executor, query, table_info)
        print(f"Final Response: {response}")
    except Exception as e:
        print(f"Error: Failed to execute query after retries: {str(e)}")

if __name__ == "__main__":
    main()

2025-05-19 23:47:48,410 - INFO - Loading model from /Users/abhishek/Downloads/responsible_ai/models/phi4-mini-reasoning


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 456.67it/s]

2025-05-19 23:47:48,713 - INFO - Model loaded successfully
2025-05-19 23:47:48,713 - INFO - LLM initialized successfully
2025-05-19 23:47:48,716 - INFO - Database connection established
2025-05-19 23:47:48,717 - INFO - Table information fetched: Available tables and their schemas:
- UsData: columns [(0, 'first_name', 'TEXT', 0, None, 0), (1, 'last_name', 'TEXT', 0, None, 0), (2, 'company_name', 'TEXT', 0, None, 0), (3, 'address', 'TEXT', 0, None, 0), (4, 'city', 'TEXT', 0, None, 0), (5, 'county', 'TEXT', 0, None, 0), (6, 'state', 'TEXT', 0, None, 0), (7, 'zip', 'BIGINT', 0, None, 0), (8, 'phone1', 'TEXT', 0, None, 0), (9, 'phone2', 'TEXT', 0, None, 0), (10, 'email', 'TEXT', 0, None, 0), (11, 'web', 'TEXT', 0, None, 0)]

2025-05-19 23:47:48,719 - INFO - SQL agent initialized successfully
2025-05-19 23:47:48,720 - INFO - Executing query: return first row data of table?







[1m> Entering new SQL Agent Executor chain...[0m
2025-05-19 23:48:28,242 - ERROR - Error executing query 'return first row data of table?': An output parsing error occurred. In order to pass this error back to the agent and have it try again, pass `handle_parsing_errors=True` to the AgentExecutor. This is the error: Could not parse LLM output: `You are an expert SQL agent designed to interact with a SQLite database using the following tools: sql_db_query, sql_db_list_tables, sql_db_query_checker.
Available tools: sql_db_query - Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.
sql_db_list_tables - Input is an empty string, output is a comma-separated list of tables in the database

In [9]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain_core.language_models import LLM
from langchain_core.callbacks import CallbackManagerForLLMRun
from typing import List, Optional, Any
from sqlalchemy import create_engine
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import create_sql_agent
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
import logging
import sys
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import warnings

# Suppress TqdmWarning
warnings.filterwarnings("ignore", category=UserWarning, module="tqdm")

# Set up logging for production
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

class LocalHFLLM(LLM):
    device: str = "cpu"  # Change to "cuda" if GPU is available
    max_new_tokens: int = 256
    temperature: float = 0.1  # Low for maximum determinism
    top_p: float = 0.9

    def __init__(self, model_path: str, **kwargs):
        super().__init__(**kwargs)
        try:
            logger.info(f"Loading model from {model_path}")
            self._tokenizer = AutoTokenizer.from_pretrained(model_path)
            self._model = AutoModelForCausalLM.from_pretrained(model_path).to(self.device)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs
    ) -> str:
        try:
            inputs = self._tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.device)
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                do_sample=True,
                pad_token_id=self._tokenizer.pad_token_id,
                eos_token_id=self._tokenizer.eos_token_id,
                num_return_sequences=1
            )
            response = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            if stop:
                for stop_token in stop:
                    response = response.split(stop_token)[0]
            return response.strip()
        except Exception as e:
            logger.error(f"Error during model inference: {str(e)}")
            return "Error during model inference"

    @property
    def _llm_type(self) -> str:
        return "custom-local-hf"

# Custom output parser to handle malformed outputs
class CustomSQLAgentOutputParser(StrOutputParser):
    def __init__(self, table_name: str, columns: List[str]):
        super().__init__()
        self._table_name = table_name
        self._columns = columns

    @property
    def table_name(self) -> str:
        return self._table_name

    @property
    def columns(self) -> List[str]:
        return self._columns

    def parse(self, text: str) -> str:
        query = f"SELECT {', '.join(self.columns)} FROM {self.table_name} LIMIT 1"
        try:
            valid_tools = ["sql_db_query", "sql_db_query_checker"]
            if any(f"Action: {tool}" in text for tool in valid_tools):
                return text
            logger.warning(f"Invalid output detected: {text}")
            return (
                f"Thought: The output was invalid. Using table {self.table_name} to select the first row.\n"
                f"Action: sql_db_query_checker\n"
                f"Action Input: {query}"
            )
        except Exception as e:
            logger.error(f"Error parsing output: {str(e)}")
            return (
                f"Thought: Failed to parse output. Defaulting to selecting first row from {self.table_name}.\n"
                f"Action: sql_db_query_checker\n"
                f"Action Input: {query}"
            )

# Define prompt template with all required variables
SQL_PROMPT = PromptTemplate(
    input_variables=["input", "agent_scratchpad", "tool_names", "tools", "table_info"],
    template="""
You are an SQL agent for a SQLite database with ONE table. Generate a query based on the input.

Tools: {tool_names}
Tool Descriptions: {tools}

Table: {table_info}

Instructions:
1. For queries asking for the "first row", generate: SELECT column1, column2, ... FROM table_name LIMIT 1
2. Use the table name and columns from the table information.
3. Validate the query with `sql_db_query_checker` before executing with `sql_db_query`.
4. List all columns explicitly, do not use SELECT *.
5. Do not execute DML statements (INSERT, UPDATE, DELETE, DROP, etc.).
6. Use {agent_scratchpad} to review previous steps if needed.
7. Output only the ReAct format below.

ReAct Format:
Question: {input}
Thought: [Your reasoning]
Action: sql_db_query_checker
Action Input: SELECT [columns] FROM [table_name] LIMIT 1
Observation: [Result of action]
Thought: I now know the final answer
Final Answer: [Final result]

Begin:

Question: {input}
Thought: {agent_scratchpad}
Table Info: {table_info}
"""
)

def fetch_table_info(db: SQLDatabase) -> tuple[str, str, List[str]]:
    """Fetch table name and schema from the database."""
    try:
        tables = db.get_usable_table_names()
        if not tables:
            raise ValueError("No tables found in the database.")
        if len(tables) > 1:
            logger.warning(f"Multiple tables found: {tables}. Using the first table: {tables[0]}")
        table_name = tables[0]
        
        schema_info = db.run(f"PRAGMA table_info({table_name})")
        columns = [col[1] for col in eval(schema_info)]
        table_info = (
            f"Table: {table_name}\n"
            f"Columns: {', '.join(f'{col[1]} ({col[2]})' for col in eval(schema_info))}"
        )
        return table_name, table_info, columns
    except Exception as e:
        logger.error(f"Error fetching table information: {str(e)}")
        raise

@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(Exception)
)
def run_agent_query(agent_executor, query: str, table_info: str) -> str:
    """Execute the agent query with retry logic."""
    try:
        logger.info(f"Executing query: {query}")
        response = agent_executor.invoke({"input": query, "table_info": table_info})["output"]
        logger.info(f"Query executed successfully: {response}")
        return response
    except Exception as e:
        logger.error(f"Error executing query '{query}': {str(e)}")
        raise

def main():
    # Model and database paths
    model_path = "/Users/abhishek/Downloads/responsible_ai/models/phi4-mini-reasoning"
    db_path = "sqlite:////Users/abhishek/Downloads/responsible_ai/db_files/data.db"

    # Initialize LLM
    try:
        llm = LocalHFLLM(model_path=model_path)
        logger.info("LLM initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize LLM: {str(e)}")
        print(f"Error: Failed to initialize LLM: {str(e)}")
        return

    # Initialize database
    try:
        engine = create_engine(db_path)
        db = SQLDatabase(engine)
        logger.info("Database connection established")
    except Exception as e:
        logger.error(f"Failed to connect to database: {str(e)}")
        print(f"Error: Failed to connect to database: {str(e)}")
        return

    # Fetch table information
    try:
        table_name, table_info, columns = fetch_table_info(db)
        logger.info(f"Table information fetched: {table_info}")
    except Exception as e:
        logger.error(f"Failed to fetch table information: {str(e)}")
        print(f"Error: Failed to fetch table information: {str(e)}")
        return

    # Initialize SQL toolkit and agent
    try:
        toolkit = SQLDatabaseToolkit(db=db, llm=llm)
        output_parser = CustomSQLAgentOutputParser(table_name=table_name, columns=columns)
        agent_executor = create_sql_agent(
            llm=llm,
            toolkit=toolkit,
            verbose=True,
            handle_parsing_errors=True,
            prompt=SQL_PROMPT,
            agent_executor_kwargs={"output_parser": output_parser}
        )
        logger.info("SQL agent initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize SQL agent: {str(e)}")
        print(f"Error: Failed to initialize SQL agent: {str(e)}")
        return

    # Example query
    query = "return first row data of table?"
    try:
        response = run_agent_query(agent_executor, query, table_info)
        print(f"Final Response: {response}")
    except Exception as e:
        print(f"Error: Failed to execute query after retries: {str(e)}")

if __name__ == "__main__":
    main()

2025-05-20 00:15:35,832 - INFO - Loading model from /Users/abhishek/Downloads/responsible_ai/models/phi4-mini-reasoning


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 441.51it/s]

2025-05-20 00:15:36,122 - INFO - Model loaded successfully
2025-05-20 00:15:36,122 - INFO - LLM initialized successfully
2025-05-20 00:15:36,124 - INFO - Database connection established
2025-05-20 00:15:36,125 - INFO - Table information fetched: Table: UsData
Columns: first_name (TEXT), last_name (TEXT), company_name (TEXT), address (TEXT), city (TEXT), county (TEXT), state (TEXT), zip (BIGINT), phone1 (TEXT), phone2 (TEXT), email (TEXT), web (TEXT)
2025-05-20 00:15:36,126 - INFO - SQL agent initialized successfully
2025-05-20 00:15:36,126 - INFO - Executing query: return first row data of table?







[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mYou are an SQL agent for a SQLite database with ONE table. Generate a query based on the input.

Tools: sql_db_query, sql_db_list_tables, sql_db_query_checker
Tool Descriptions: sql_db_query - Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.
sql_db_list_tables - Input is an empty string, output is a comma-separated list of tables in the database.
sql_db_query_checker - Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!

Table: Table: UsData
Columns: first_name (TEXT), last_name (TEXT), company_name (TEXT), address (TEXT), city (TE

KeyboardInterrupt: 

In [11]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain_core.language_models import LLM
from langchain_core.callbacks import CallbackManagerForLLMRun
from typing import List, Optional, Any
from sqlalchemy import create_engine
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import create_sql_agent
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
import logging
import sys
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import warnings

# Suppress TqdmWarning
warnings.filterwarnings("ignore", category=UserWarning, module="tqdm")

# Set up logging for production
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

class LocalHFLLM(LLM):
    device: str = "cpu"  # Change to "cuda" if GPU is available
    max_new_tokens: int = 256
    temperature: float = 0.1  # Low for maximum determinism
    top_p: float = 0.9

    def __init__(self, model_path: str, **kwargs):
        super().__init__(**kwargs)
        try:
            logger.info(f"Loading model from {model_path}")
            self._tokenizer = AutoTokenizer.from_pretrained(model_path)
            self._model = AutoModelForCausalLM.from_pretrained(model_path).to(self.device)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs
    ) -> str:
        try:
            inputs = self._tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.device)
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                do_sample=True,
                pad_token_id=self._tokenizer.pad_token_id,
                eos_token_id=self._tokenizer.eos_token_id,
                num_return_sequences=1
            )
            response = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            if stop:
                for stop_token in stop:
                    response = response.split(stop_token)[0]
            return response.strip()
        except Exception as e:
            logger.error(f"Error during model inference: {str(e)}")
            return "Error during model inference. Please try again."

    @property
    def _llm_type(self) -> str:
        return "custom-local-hf"

# Custom output parser to handle malformed outputs and retry
class CustomSQLAgentOutputParser(StrOutputParser):
    def __init__(self, table_name: str, columns: List[str]):
        super().__init__()
        self._table_name = table_name
        self._columns = columns

    def parse(self, text: str) -> str:
        query = f"SELECT {', '.join(self._columns)} FROM {self._table_name} LIMIT 1"
        try:
            valid_tools = ["sql_db_query", "sql_db_query_checker"]
            if any(f"Action: {tool}" in text for tool in valid_tools):
                return text
            logger.warning(f"Invalid output detected: {text}")
            return (
                f"Thought: The output was invalid. Using table {self._table_name} to select the first row.\n"
                f"Action: sql_db_query_checker\n"
                f"Action Input: {query}"
            )
        except Exception as e:
            logger.error(f"Error parsing output: {str(e)}")
            return (
                f"Thought: Failed to parse output. Defaulting to selecting first row from {self._table_name}.\n"
                f"Action: sql_db_query_checker\n"
                f"Action Input: {query}"
            )

# Define prompt template with dynamic feedback
SQL_PROMPT = PromptTemplate(
    input_variables=["input", "agent_scratchpad", "tool_names", "tools", "table_info"],
    template="""
You are an SQL agent for a SQLite database with ONE table. Generate a query based on the input.

Tools: {tool_names}
Tool Descriptions: {tools}

Table: {table_info}

Instructions:
1. Use the table name and columns from the table information to generate valid SQL queries.
2. Validate the query with `sql_db_query_checker` before executing with `sql_db_query`.
3. List all columns explicitly; do not use SELECT *.
4. If no query can be generated, respond with: "I could not generate a query for this input. Please try rephrasing your prompt."
5. Do not execute DML statements (INSERT, UPDATE, DELETE, DROP, etc.).
6. Use {agent_scratchpad} to review previous steps and incorporate feedback.
7. Retry intelligently if errors occur, refining your reasoning and queries based on observations.
8. Output only the ReAct format below.

ReAct Format:
Question: {input}
Thought: [Your reasoning]
Action: sql_db_query_checker
Action Input: SELECT [columns] FROM [table_name] WHERE [conditions] LIMIT 1
Observation: [Result of action]
Thought: I now know the final answer
Final Answer: [Final result]

Begin:

Question: {input}
Thought: {agent_scratchpad}
Table Info: {table_info}
"""
)

def fetch_table_info(db: SQLDatabase) -> tuple[str, str, List[str]]:
    """Fetch table name and schema from the database."""
    try:
        tables = db.get_usable_table_names()
        if not tables:
            raise ValueError("No tables found in the database.")
        table_name = tables[0]
        schema_info = db.run(f"PRAGMA table_info({table_name})")
        columns = [col[1] for col in eval(schema_info)]
        table_info = (
            f"Table: {table_name}\n"
            f"Columns: {', '.join(f'{col[1]} ({col[2]})' for col in eval(schema_info))}"
        )
        return table_name, table_info, columns
    except Exception as e:
        logger.error(f"Error fetching table information: {str(e)}")
        raise

@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(Exception)
)
def run_agent_query(agent_executor, query: str, table_info: str) -> str:
    """Execute the agent query with retry logic."""
    feedback = ""
    for attempt in range(3):  # Retry up to 3 times
        try:
            logger.info(f"Executing query attempt {attempt + 1}: {query}")
            response = agent_executor.invoke({"input": query, "table_info": table_info, "agent_scratchpad": feedback})["output"]
            logger.info(f"Query executed successfully: {response}")
            return response
        except Exception as e:
            feedback = f"Error in attempt {attempt + 1}: {str(e)}. Refine the query and retry."
            logger.warning(feedback)
    logger.error("All retry attempts failed.")
    return "I could not generate a valid query. Please try rephrasing your prompt."

def main():
    model_path = "/Users/abhishek/Downloads/responsible_ai/models/phi4-mini-reasoning"
    db_path = "sqlite:////Users/abhishek/Downloads/responsible_ai/db_files/data.db"

    try:
        llm = LocalHFLLM(model_path=model_path)
        logger.info("LLM initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize LLM: {str(e)}")
        return

    try:
        engine = create_engine(db_path)
        db = SQLDatabase(engine)
        logger.info("Database connection established")
    except Exception as e:
        logger.error(f"Failed to connect to database: {str(e)}")
        return

    try:
        table_name, table_info, columns = fetch_table_info(db)
        logger.info(f"Table information fetched: {table_info}")
    except Exception as e:
        logger.error(f"Failed to fetch table information: {str(e)}")
        return

    try:
        toolkit = SQLDatabaseToolkit(db=db, llm=llm)
        output_parser = CustomSQLAgentOutputParser(table_name=table_name, columns=columns)
        agent_executor = create_sql_agent(
            llm=llm,
            toolkit=toolkit,
            verbose=True,
            handle_parsing_errors=True,
            prompt=SQL_PROMPT,
            agent_executor_kwargs={"output_parser": output_parser}
        )
        logger.info("SQL agent initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize SQL agent: {str(e)}")
        return

    query = "Get the first name and email of all users in New York."
    response = run_agent_query(agent_executor, query, table_info)
    print(f"Final Response: {response}")

if __name__ == "__main__":
    main()

2025-05-20 00:34:08,045 - INFO - Loading model from /Users/abhishek/Downloads/responsible_ai/models/phi4-mini-reasoning


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 445.68it/s]

2025-05-20 00:34:08,334 - INFO - Model loaded successfully
2025-05-20 00:34:08,334 - INFO - LLM initialized successfully
2025-05-20 00:34:08,337 - INFO - Database connection established
2025-05-20 00:34:08,338 - INFO - Table information fetched: Table: UsData
Columns: first_name (TEXT), last_name (TEXT), company_name (TEXT), address (TEXT), city (TEXT), county (TEXT), state (TEXT), zip (BIGINT), phone1 (TEXT), phone2 (TEXT), email (TEXT), web (TEXT)
2025-05-20 00:34:08,339 - INFO - SQL agent initialized successfully
2025-05-20 00:34:08,340 - INFO - Executing query attempt 1: Get the first name and email of all users in New York.







[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mYou are an SQL agent for a SQLite database with ONE table. Generate a query based on the input.

Tools: sql_db_query, sql_db_list_tables, sql_db_query_checker
Tool Descriptions: sql_db_query - Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.
sql_db_list_tables - Input is an empty string, output is a comma-separated list of tables in the database.
sql_db_query_checker - Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!

Table: Table: UsData
Columns: first_name (TEXT), last_name (TEXT), company_name (TEXT), address (TEXT), city (TE

In [14]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sqlalchemy import create_engine, text
from langchain.sql_database import SQLDatabase
import logging
import sys
import re
import warnings

# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning, module="tqdm")

# Set up logging for production
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

class LocalHFLLM:
    def __init__(self, model_path: str, device: str = "cpu", max_new_tokens: int = 256, temperature: float = 0.1, top_p: float = 0.9):
        self.device = device
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        try:
            logger.info(f"Loading model from {model_path}")
            self._tokenizer = AutoTokenizer.from_pretrained(model_path)
            self._model = AutoModelForCausalLM.from_pretrained(model_path).to(self.device)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def generate(self, prompt: str) -> str:
        try:
            inputs = self._tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.device)
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                do_sample=True,
                pad_token_id=self._tokenizer.pad_token_id,
                eos_token_id=self._tokenizer.eos_token_id,
                num_return_sequences=1
            )
            response = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            return response.strip()
        except Exception as e:
            logger.error(f"Error during model inference: {str(e)}")
            return ""

def fetch_table_info(db: SQLDatabase):
    try:
        tables = db.get_usable_table_names()
        if not tables:
            raise ValueError("No tables found in the database.")
        table_name = tables[0]
        schema_info = db.run(f"PRAGMA table_info({table_name})")
        columns = [col[1] for col in eval(schema_info)]
        table_info = (
            f"Table: {table_name}\n"
            f"Columns: {', '.join(f'{col[1]} ({col[2]})' for col in eval(schema_info))}"
        )
        return table_name, table_info, columns
    except Exception as e:
        logger.error(f"Error fetching table information: {str(e)}")
        raise

def build_sql_prompt(user_query, table_name, columns):
    columns_str = ", ".join(columns)
    prompt = (
        f"You are given a SQLite table '{table_name}' with columns: {columns_str}.\n"
        f"Write a valid SQL SELECT query (do not use SELECT *) to answer the following question:\n"
        f"Question: {user_query}\n"
        "Only output the SQL query. Do not include explanations or any other text."
    )
    return prompt

def extract_sql(text):
    # Extract the first SQL query from the text
    matches = re.findall(r"SELECT\s+.+?;", text, re.IGNORECASE | re.DOTALL)
    if matches:
        return matches[0]
    # fallback: try to find any select statement
    matches = re.findall(r"SELECT\s+.+", text, re.IGNORECASE)
    if matches:
        return matches[0].split("\n")[0].strip()
    return ""

def validate_sql(query, table_name, allowed_columns):
    # Simple validation: must start with SELECT and only use allowed columns/table
    if not query.strip().upper().startswith("SELECT"):
        return False
    if table_name not in query:
        return False
    # crude check: check each column is in allowed columns or *
    for col in allowed_columns:
        if col in query:
            return True
    return False

def run_sql(engine, sql):
    try:
        with engine.connect() as conn:
            result = conn.execute(text(sql))
            rows = result.fetchall()
            if not rows:
                return "Query ran successfully but returned no results."
            return str(rows)
    except Exception as e:
        logger.error(f"SQL execution error: {e}")
        return f"SQL execution error: {e}"

def main():
    model_path = "/Users/abhishek/Downloads/responsible_ai/models/tinyllama"
    db_path = "sqlite:////Users/abhishek/Downloads/responsible_ai/db_files/data.db"

    # Load model
    try:
        llm = LocalHFLLM(model_path=model_path)
        logger.info("LLM initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize LLM: {str(e)}")
        return

    # Connect DB
    try:
        engine = create_engine(db_path)
        db = SQLDatabase(engine)
        logger.info("Database connection established")
    except Exception as e:
        logger.error(f"Failed to connect to database: {str(e)}")
        return

    # Table info
    try:
        table_name, table_info, columns = fetch_table_info(db)
        logger.info(f"Table information fetched: {table_info}")
    except Exception as e:
        logger.error(f"Failed to fetch table information: {str(e)}")
        return

    # Get user query
    user_query = input("Enter your question: ")

    # Build prompt
    prompt = build_sql_prompt(user_query, table_name, columns)

    # Try up to 3 times
    for attempt in range(3):
        logger.info(f"Attempt {attempt+1}: Generating SQL for query: {user_query}")
        model_output = llm.generate(prompt)
        sql = extract_sql(model_output)
        logger.info(f"Model raw output: {model_output}")
        logger.info(f"Extracted SQL: {sql}")
        if validate_sql(sql, table_name, columns):
            logger.info("Generated SQL validated successfully.")
            result = run_sql(engine, sql)
            print(f"\nGenerated SQL:\n{sql}")
            print(f"\nQuery Result:\n{result}")
            return
        else:
            logger.warning("Model did not generate a valid SQL, retrying...")

    print("I could not generate a valid query for your input. Please try rephrasing your question or use a more specific prompt.")

if __name__ == "__main__":
    main()

2025-05-20 01:32:35,512 - INFO - Loading model from /Users/abhishek/Downloads/responsible_ai/models/tinyllama
2025-05-20 01:32:35,674 - INFO - Model loaded successfully
2025-05-20 01:32:35,674 - INFO - LLM initialized successfully
2025-05-20 01:32:35,677 - INFO - Database connection established
2025-05-20 01:32:35,678 - INFO - Table information fetched: Table: UsData
Columns: first_name (TEXT), last_name (TEXT), company_name (TEXT), address (TEXT), city (TEXT), county (TEXT), state (TEXT), zip (BIGINT), phone1 (TEXT), phone2 (TEXT), email (TEXT), web (TEXT)
2025-05-20 01:32:56,902 - INFO - Attempt 1: Generating SQL for query: return the first row data from the table.
2025-05-20 01:32:59,102 - INFO - Model raw output: You are given a SQLite table 'UsData' with columns: first_name, last_name, company_name, address, city, county, state, zip, phone1, phone2, email, web.
Write a valid SQL SELECT query (do not use SELECT *) to answer the following question:
Question: return the first row dat

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sqlalchemy import create_engine, text
from langchain.sql_database import SQLDatabase
import logging
import sys
import re
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="tqdm")

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

class LocalHFLLM:
    def __init__(self, model_path: str, device: str = "cpu", max_new_tokens: int = 256, temperature: float = 0.1, top_p: float = 0.9):
        self.device = device
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        try:
            logger.info(f"Loading model from {model_path}")
            self._tokenizer = AutoTokenizer.from_pretrained(model_path)
            self._model = AutoModelForCausalLM.from_pretrained(
                model_path, device_map="auto", load_in_4bit=True
            )
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def generate(self, prompt: str) -> str:
        try:
            inputs = self._tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.device)
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                do_sample=True,
                pad_token_id=self._tokenizer.pad_token_id,
                eos_token_id=self._tokenizer.eos_token_id,
                num_return_sequences=1
            )
            response = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            return response.strip()
        except Exception as e:
            logger.error(f"Error during model inference: {str(e)}")
            return ""

def fetch_table_info(db: SQLDatabase):
    try:
        tables = db.get_usable_table_names()
        if not tables:
            raise ValueError("No tables found in the database.")
        table_name = tables[0]
        schema_info = db.run(f"PRAGMA table_info({table_name})")
        columns = [col[1] for col in eval(schema_info)]
        table_info = (
            f"Table: {table_name}\n"
            f"Columns: {', '.join(f'{col[1]} ({col[2]})' for col in eval(schema_info))}"
        )
        return table_name, table_info, columns
    except Exception as e:
        logger.error(f"Error fetching table information: {str(e)}")
        raise

def build_sql_prompt(user_query, table_name, columns):
    columns_str = ", ".join(columns)
    prompt = (
        f"You are given a SQLite table '{table_name}' with columns: {columns_str}.\n"
        f"Write a valid SQL SELECT query (do not use SELECT *) to answer the following question:\n"
        f"Question: {user_query}\n"
        "Only output the SQL query. Do not include explanations or any other text."
    )
    return prompt

def extract_sql(text):
    # Try to extract SQL query from output
    for line in text.splitlines():
        if line.strip().upper().startswith("SELECT"):
            if ";" in line:
                return line.strip().split(";")[0] + ";"
            return line.strip()
    return ""

def validate_sql(query, table_name, allowed_columns):
    # Must start with SELECT and only use allowed columns/table
    if not query.strip().upper().startswith("SELECT"):
        return False
    if table_name not in query:
        return False
    for col in allowed_columns:
        if col in query:
            return True
    return False

def run_sql(engine, sql):
    try:
        with engine.connect() as conn:
            result = conn.execute(text(sql))
            rows = result.fetchall()
            if not rows:
                return "Query ran successfully but returned no results."
            return str(rows)
    except Exception as e:
        logger.error(f"SQL execution error: {e}")
        return f"SQL execution error: {e}"

def main():
    # Choose one of these models (see above for links):
    #model_path = "cognitivecomputations/TinyLlama-1.1B-Chat-v1.0"
    #model_path = "microsoft/phi-2"
    model_path = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"  # <8GB quantized model

    db_path = "sqlite:////Users/abhishek/Downloads/responsible_ai/db_files/data.db"

    # Load model
    try:
        llm = LocalHFLLM(model_path=model_path)
        logger.info("LLM initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize LLM: {str(e)}")
        return

    # Connect DB
    try:
        engine = create_engine(db_path)
        db = SQLDatabase(engine)
        logger.info("Database connection established")
    except Exception as e:
        logger.error(f"Failed to connect to database: {str(e)}")
        return

    # Table info
    try:
        table_name, table_info, columns = fetch_table_info(db)
        logger.info(f"Table information fetched: {table_info}")
    except Exception as e:
        logger.error(f"Failed to fetch table information: {str(e)}")
        return

    # Get user query
    user_query = input("Enter your question: ")

    prompt = build_sql_prompt(user_query, table_name, columns)
    for attempt in range(3):
        logger.info(f"Attempt {attempt+1}: Generating SQL for query: {user_query}")
        model_output = llm.generate(prompt)
        sql = extract_sql(model_output)
        logger.info(f"Model raw output: {model_output}")
        logger.info(f"Extracted SQL: {sql}")
        if validate_sql(sql, table_name, columns):
            logger.info("Generated SQL validated successfully.")
            result = run_sql(engine, sql)
            print(f"\nGenerated SQL:\n{sql}")
            print(f"\nQuery Result:\n{result}")
            return
        else:
            logger.warning("Model did not generate a valid SQL, retrying...")

    print("I could not generate a valid query for your input. Please try rephrasing your question or use a more specific prompt.")

if __name__ == "__main__":
    main()