In [55]:
# Cell 1: Imports and Setup

import csv
import json
import os
import sqlite3
from contextlib import closing
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import google.generativeai as genai
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from tqdm.autonotebook import tqdm

# Load environment variables (e.g. GEMINI_API_KEY) from a local .env file.
load_dotenv()

# Configure the Gemini client from the environment so no key is hard-coded.
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
# Model used by every agent. Set GEMINI_MODEL to switch models without editing code;
# falls back to the original "gemini-pro".
MODEL_NAME = os.getenv("GEMINI_MODEL", "gemini-pro")

In [56]:
# Cell 2: Pydantic Models

class RelevantData(BaseModel):
    """Tables and columns the LLM selected as relevant to a user question."""

    relevant_tables: List[str] = Field(
        default=...,
        description="List of relevant table names",
    )
    relevant_columns: List[str] = Field(
        default=...,
        description="List of relevant column names in 'table.column' format",
    )


class SQLQuery(BaseModel):
    """Single SQL statement (or agent sentinel) returned by the LLM under ``sql_query``."""

    sql_query: str = Field(default=..., description="Generated SQL query")


class RelevantQueries(BaseModel):
    """Past memory-buffer entries the LLM judged useful as context."""

    relevant_queries: List[Dict] = Field(
        default=...,
        description="List of relevant past queries",
    )

In [57]:
# Cell 3: Base Agent Class

class BaseAgent:
    """Wrapper around a Gemini model that returns pydantic-validated JSON replies.

    Subclasses build task-specific prompts and call :meth:`call_llm` with the
    pydantic model describing the JSON object they expect back.
    """

    def __init__(self, model: str = MODEL_NAME):
        """Create the underlying ``genai.GenerativeModel``.

        Args:
            model: Gemini model name; defaults to the module-level ``MODEL_NAME``.
        """
        self.model = model
        self.gemini = genai.GenerativeModel(self.model)

    @staticmethod
    def _clean_response_text(raw_text: str) -> str:
        """Strip surrounding Markdown code fences and a leading ``json`` tag from LLM output."""
        text = raw_text.strip()
        if text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()
        # A fence language tag ("json", "JSON", ...) can never start valid JSON, so dropping it is safe.
        if text[:4].lower() == "json":
            text = text[4:]
        return text.strip()

    def call_llm(self, prompt: str, response_model: type[BaseModel]) -> BaseModel:
        """Send ``prompt`` to Gemini and parse the reply into ``response_model``.

        A short priming exchange asking for structured output precedes the prompt.

        Args:
            prompt: Task-specific instructions that ask for a JSON object.
            response_model: Pydantic model class the JSON object is validated against.

        Returns:
            An instance of ``response_model`` built from the reply.

        Raises:
            ValueError: If the reply is not valid JSON, is not a JSON object, or
                fails validation against ``response_model``.
        """
        messages = [
            {
                "role": "user",
                "parts": ["You are an assistant for structured outputs."],
            },
            {
                "role": "model",
                "parts": [
                    "Understood. I will provide responses in the requested format."
                ],
            },
            {"role": "user", "parts": [prompt]},
        ]

        response = self.gemini.generate_content(messages)
        cleaned_response_text = self._clean_response_text(response.text)

        try:
            response_json = json.loads(cleaned_response_text)
            if not isinstance(response_json, dict):
                # ``response_model(**x)`` needs a mapping; a bare list/scalar is a format error.
                raise TypeError(
                    f"expected a JSON object, got {type(response_json).__name__}"
                )
            return response_model(**response_json)
        except (json.JSONDecodeError, ValidationError, TypeError) as e:
            print(f"Error parsing or validating response: {e}")
            print(f"Cleaned response text: {cleaned_response_text}")
            raise ValueError("Failed to parse and validate LLM response.") from e

In [58]:
# Cell 4: RAG Agent Class

class RAGAgent(BaseAgent):
    """Agent that asks the LLM which tables/columns of a SQLite database matter for a question."""

    def __init__(self, db_file: str):
        """
        Args:
            db_file: Path to the SQLite database whose schema is shown to the LLM.
        """
        super().__init__()
        self.db_file = db_file

    def extract_relevant_data(
        self, query: str, relevant_queries: List[Dict], memory_results: List[Tuple]
    ) -> Tuple[List[str], List[str]]:
        """Ask the LLM which tables and columns are relevant to ``query``.

        Args:
            query: The user's natural-language question.
            relevant_queries: Past memory entries chosen as context; embedded in the prompt as-is.
            memory_results: Stored execution results for those entries; embedded in the prompt as-is.

        Returns:
            ``(relevant_tables, relevant_columns)`` as parsed from the LLM's JSON reply.

        Raises:
            ValueError: Propagated from ``call_llm`` when the reply cannot be parsed or validated.
        """
        schema = self.get_full_schema()
        prompt = (
            "You are a system that extracts schema information for SQL queries. **You MUST ONLY use the tables and columns provided in the schema below. Do NOT hallucinate or use any other tables or columns.**\n"
            f"Schema:\n{schema}\n"
            f"Relevant Queries: {relevant_queries}\n"
            f"Previous Results: {memory_results}\n"
            f"User Query: {query}\n"
            "Based on the user query, and considering the provided schema ONLY, provide a JSON object with:\n"
            "- 'relevant_tables': a list of the table names that are relevant to the query.\n"
            "- 'relevant_columns': a list of the column names (in 'table.column' format) that are relevant to the query.\n"
            "If no tables or columns are relevant, return an empty list for both.\n"
            "Example:\n"
            "Schema:\nTable: employees, Columns: employee_id (INTEGER), name (TEXT), department (TEXT), salary (REAL)\nTable: departments, Columns: department_id (INTEGER), department_name (TEXT)\n"
            "User Query: What is the average salary in the 'Sales' department?\n"
            "Response:\n"
            '{"relevant_tables": ["employees", "departments"], "relevant_columns": ["employees.salary", "departments.department_name"]}\n'
            "If the user asks a question that requires creating a new table or modifying data, return an empty list for both 'relevant_tables' and 'relevant_columns'."
        )
        structured_output: RelevantData = self.call_llm(prompt, RelevantData)
        return structured_output.relevant_tables, structured_output.relevant_columns

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQLite identifier (embedded quotes doubled)."""
        return '"' + name.replace('"', '""') + '"'

    def get_full_schema(self) -> str:
        """Describe every user table in the database, one line per table.

        Each line has the form ``Table: <name>, Columns: <col> (<type>), ...``.
        SQLite's internal ``sqlite_*`` tables are left out so the LLM never targets them.
        """
        schema_details = []
        # sqlite3's own context manager only commits/rolls back; closing() releases the connection.
        with closing(sqlite3.connect(self.db_file)) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = [
                row[0]
                for row in cursor.fetchall()
                if not row[0].lower().startswith("sqlite_")
            ]
            for table in tables:
                # Quote the name so tables with spaces, dashes or keyword names don't break the PRAGMA.
                cursor.execute(f"PRAGMA table_info({self._quote_identifier(table)});")
                # table_info rows: (cid, name, type, notnull, dflt_value, pk)
                columns = [f"{row[1]} ({row[2]})" for row in cursor.fetchall()]
                schema_details.append(f"Table: {table}, Columns: {', '.join(columns)}")
        return "\n".join(schema_details)

In [None]:
# Cell 5: SQL Agent Class

class SQLAgent(BaseAgent):
    """Agent that turns a question into SQL and iteratively refines it against real results.

    Candidate queries are executed on a read-only connection, so LLM-written SQL
    cannot modify the database during refinement.
    """

    # Values the LLM returns in ``sql_query`` instead of real SQL.
    _SENTINELS = ("NOT_APPLICABLE", "NO_FURTHER_REFINEMENT", "SATISFACTORY")

    def __init__(self, db_file: str):
        """
        Args:
            db_file: Path to the SQLite database to generate and test queries against.
        """
        super().__init__()
        self.db_file = db_file
        # Used for its get_full_schema() helper.
        self.rag_agent = RAGAgent(db_file)
        self.max_refinement_iterations = 3

    def generate_sql(
        self,
        tables: List[str],
        columns: List[str],
        query: str,
        relevant_queries: List[Dict],
        memory_results: List[Tuple],
    ) -> str:
        """Ask the LLM for an initial SQL query answering ``query``.

        Returns:
            The SQL text (whitespace-stripped), or the sentinel ``NOT_APPLICABLE``.

        Raises:
            ValueError: Propagated from ``call_llm`` when the reply cannot be parsed.
        """
        schema = self.rag_agent.get_full_schema()
        prompt = (
            "You are a SQL generation assistant. **You MUST ONLY use the tables and columns provided in the schema below. Do NOT hallucinate or use any other tables or columns.**\n"
            f"Schema:\n{schema}\n"
            f"Relevant Tables: {tables}\n"
            f"Relevant Columns: {columns}\n"
            f"Relevant Queries: {relevant_queries}\n"
            f"Previous Results: {memory_results}\n"
            f"User Query: {query}\n"
            "Based on the user query, relevant tables, relevant columns, and the provided schema ONLY, generate a single valid SQL query that answers the query.\n"
            "Provide a valid SQL query as a JSON object with the key 'sql_query' and the SQL query as the string value. Do not nest the SQL query within another object.\n"
            "Example: {\"sql_query\": \"SELECT * FROM table WHERE condition\"}\n"
            # Sentinels must still be wrapped in JSON: call_llm rejects anything that is not a JSON object.
            'If the user asks a question that requires creating a new table or modifying data, return {"sql_query": "NOT_APPLICABLE"}.\n'
            "If the user asks a question that does not require any table, select from 'Table_does_not_exist'."
        )
        structured_output: SQLQuery = self.call_llm(prompt, SQLQuery)
        return structured_output.sql_query.strip()

    def refine_sql(
        self,
        previous_sql_query: str,
        query_results: List[Tuple],
        user_query: str,
        relevant_queries: List[Dict],
        memory_results: List[Tuple],
        execution_error: Optional[str] = None,
    ) -> str:
        """Ask the LLM to improve ``previous_sql_query`` given what it returned.

        Args:
            previous_sql_query: The query that was just executed.
            query_results: Rows it returned (empty when it failed).
            user_query: The original natural-language question.
            relevant_queries: Past memory entries used as context.
            memory_results: Stored execution results for those entries.
            execution_error: SQLite error message if the query failed; shown to the
                LLM so it can actually correct the error.

        Returns:
            Refined SQL (whitespace-stripped), or ``SATISFACTORY`` / ``NO_FURTHER_REFINEMENT``.
        """
        schema = self.rag_agent.get_full_schema()
        error_line = f"Execution Error: {execution_error}\n" if execution_error else ""
        prompt = (
            "You are a SQL refinement assistant. You improve SQL queries iteratively based on feedback.\n"
            f"Schema:\n{schema}\n"
            f"Previous SQL Query: {previous_sql_query}\n"
            f"Query Results: {query_results}\n"
            f"{error_line}"
            f"Relevant Queries: {relevant_queries}\n"
            f"Memory Results: {memory_results}\n"
            f"Original User Query: {user_query}\n"
            "Analyze the previous SQL query and its results. If the results are not satisfactory or if they do not fully address the original user query, generate a refined SQL query.\n"
            "Consider the following when refining:\n"
            "- Add or modify WHERE clauses to filter the results further.\n"
            "- Adjust JOIN conditions if necessary.\n"
            "- Add or change aggregate functions (e.g., COUNT, SUM, AVG).\n"
            "- Modify the selected columns.\n"
            "- Correct any errors in the previous query.\n"
            'If the results are satisfactory, return {"sql_query": "SATISFACTORY"}.\n'
            'If you cannot determine how to refine the query further, return {"sql_query": "NO_FURTHER_REFINEMENT"}.\n'
            "Otherwise, provide the refined SQL query as a JSON object with the key 'sql_query'.\n"
            "Example:\n"
            "Previous SQL Query: SELECT department, AVG(salary) FROM employees GROUP BY department\n"
            "Query Results: [('Sales', 55000.0), ('HR', 60000.0)]\n"
            "Original User Query: What is the average salary in each department for employees hired after 2021?\n"
            "Response:\n"
            '{"sql_query": "SELECT department, AVG(salary) FROM employees WHERE hire_date > \'2021-12-31\' GROUP BY department"}\n'
            "Previous SQL Query: SELECT name FROM employees WHERE is_manager = 'yes'\n"
            "Query Results: [('John Doe',), ('Jane Smith',)]\n"
            "Original User Query: Who are the managers in the Sales department?\n"
            "Response:\n"
            '{"sql_query": "SELECT name FROM employees WHERE is_manager = \'yes\' AND department = \'Sales\'"}\n'
        )
        structured_output: SQLQuery = self.call_llm(prompt, SQLQuery)
        return structured_output.sql_query.strip()

    def _execute_read_only(self, sql_query: str) -> Tuple[List[Tuple], Optional[str]]:
        """Run ``sql_query`` on a read-only connection.

        Returns:
            ``(rows, None)`` on success, or ``([], error_message)`` if SQLite rejects the query.
        """
        # mode=ro: the SQL comes from the LLM, so it must not be able to write to the database.
        uri = Path(self.db_file).resolve().as_uri() + "?mode=ro"
        with closing(sqlite3.connect(uri, uri=True)) as conn:
            try:
                return conn.execute(sql_query).fetchall(), None
            except (sqlite3.Error, sqlite3.Warning) as e:
                # Before Python 3.12, multi-statement input raises sqlite3.Warning, which is not an sqlite3.Error.
                return [], str(e)

    def generate_and_refine_sql(
        self,
        tables: List[str],
        columns: List[str],
        query: str,
        relevant_queries: List[Dict],
        memory_results: List[Tuple],
    ) -> str:
        """Generate SQL, then execute and refine it until satisfactory or the iteration cap is hit.

        Returns:
            The last accepted SQL query, or ``NOT_APPLICABLE`` if the LLM declined to write SQL.
        """
        current_sql_query = self.generate_sql(
            tables, columns, query, relevant_queries, memory_results
        )
        print(f"    Initial SQL Query: {current_sql_query}")

        for iteration in range(1, self.max_refinement_iterations + 1):
            if current_sql_query in self._SENTINELS:
                print(f"    Final SQL Query (Iteration {iteration}): {current_sql_query}")
                return current_sql_query

            query_results, execution_error = self._execute_read_only(current_sql_query)
            if execution_error is not None:
                print(f"    SQL Execution Error (Iteration {iteration}): {execution_error}")
            print(f"    Query Results (Iteration {iteration}): {query_results}")

            new_sql_query = self.refine_sql(
                current_sql_query,
                query_results,
                query,
                relevant_queries,
                memory_results,
                execution_error=execution_error,
            )

            if new_sql_query == "SATISFACTORY":
                print(f"    SQL Refinement Satisfactory (Iteration {iteration})")
                return current_sql_query
            if new_sql_query == "NO_FURTHER_REFINEMENT":
                print(
                    f"    SQL Refinement Could Not Be Improved Further (Iteration {iteration})"
                )
                return current_sql_query
            current_sql_query = new_sql_query
            print(f"    Refined SQL Query (Iteration {iteration}): {current_sql_query}")

        print(
            f"    Max Refinement Iterations Reached ({self.max_refinement_iterations})"
        )
        return current_sql_query

In [59]:
# Cell 6: Conversational Agent Class

class ConversationalAgent(BaseAgent):
    """Agent that selects past queries from the CSV memory buffer relevant to a new question."""

    def __init__(self, memory_file: str):
        """
        Args:
            memory_file: Path to the CSV memory buffer (one row per past query).
        """
        super().__init__()
        self.memory_file = memory_file

    def get_chat_history(self, current_input: str) -> List[Dict]:
        """
        Retrieve relevant queries from memory for the current input.
        Considers queries relevant if they provide context or information
        that would be helpful in answering the current input.

        Returns an empty list without calling the LLM when the memory buffer is
        empty or the memory file does not exist yet.

        Raises:
            ValueError: Propagated from ``call_llm`` when the reply cannot be parsed.
        """
        try:
            with open(self.memory_file, mode="r", newline="") as file:
                memory_buffer = list(csv.DictReader(file))
        except FileNotFoundError:
            memory_buffer = []

        # Nothing to choose from: skip the LLM call (and any chance of it inventing history).
        if not memory_buffer:
            return []

        memory_json = json.dumps(memory_buffer)
        prompt = (
            "You are an assistant that retrieves relevant past queries from a memory buffer to provide context for answering the current input. **You MUST ONLY use the queries provided in the memory buffer. Do NOT hallucinate or use any other queries.**\n"
            f"Memory Buffer (JSON):\n{memory_json}\n"
            f"Current Input: {current_input}\n"
            "Analyze the current input and determine which past queries from the memory buffer are relevant. Consider the following when determining relevancy:\n"
            "- The past query provides definitions or context related to terms in the current input.\n"
            "- The past query asks a similar question to the current input, even if about a different entity.\n"
            "- The past query provides data that might be useful for comparison or contrast with the current input.\n"
            "- The past query was part of a sequence of questions leading up to the current input.\n"
            "A past query is considered relevant if it provides context or information that would be helpful in answering the current input.\n"
            "Return the relevant past queries as a JSON object under the key 'relevant_queries'.\n"
            "Example:\n"
            "Memory Buffer (JSON):\n"
            '[{"User Query": "What is the capital of France?", "Relevant Tables": [], "Relevant Columns": [], "Generated SQL Query": "NOT_APPLICABLE", "Execution Results": "[]"}, {"User Query": "How many departments are there?", "Relevant Tables": ["departments"], "Relevant Columns": ["departments.department_id"], "Generated SQL Query": "SELECT COUNT(department_id) FROM departments", "Execution Results": "[[5]]"}]\n'
            "Current Input: Which department has the most employees?\n"
            "Response:\n"
            '{"relevant_queries": [{"User Query": "How many departments are there?", "Relevant Tables": ["departments"], "Relevant Columns": ["departments.department_id"], "Generated SQL Query": "SELECT COUNT(department_id) FROM departments", "Execution Results": "[[5]]"}]}\n'
            # Must stay a JSON object: a bare [] cannot be validated as RelevantQueries.
            'If no relevant queries are found, return {"relevant_queries": []}.\n'
        )
        structured_output: RelevantQueries = self.call_llm(prompt, RelevantQueries)
        return structured_output.relevant_queries

In [60]:
# Cell 7: RAG Pipeline Class

class RAGPipeline:
    """Interactive question loop over a SQLite database.

    For each question it pulls relevant history from a CSV memory buffer, asks the
    agents for relevant schema and SQL, executes the SQL read-only, and appends the
    outcome to the memory buffer.
    """

    # Values the SQL agent returns instead of real SQL; they must never be executed.
    _SQL_SENTINELS = ("NOT_APPLICABLE", "NO_FURTHER_REFINEMENT", "SATISFACTORY")
    # Inputs (after strip + lower) that end the interactive session.
    _EXIT_COMMANDS = ("", "exit", "quit")

    def __init__(
        self, db_file: str, user_query: str, memory_file: str = "memory_buffer.csv"
    ):
        """
        Args:
            db_file: Path to the SQLite database to query.
            user_query: First question to answer when :meth:`run` starts.
            memory_file: CSV file used as the persistent memory buffer.
        """
        self.db_file = db_file
        self.user_query = user_query
        self.memory_file = memory_file
        self.conversational_agent = ConversationalAgent(self.memory_file)
        self.rag_agent = RAGAgent(db_file)
        self.sql_agent = SQLAgent(db_file)
        self.initialize_memory()

    def initialize_memory(self):
        """Create the memory buffer CSV with its header row if it does not exist yet."""
        if not os.path.exists(self.memory_file):
            with open(self.memory_file, mode="w", newline="") as file:
                writer = csv.writer(file)
                writer.writerow(
                    [
                        "User Query",
                        "Relevant Tables",
                        "Relevant Columns",
                        "Generated SQL Query",
                        "Execution Results",
                    ]
                )

    def update_memory(
        self,
        user_query: str,
        tables: List[str],
        columns: List[str],
        sql_query: str,
        results: List[Tuple],
    ):
        """Append one row describing the latest query to the memory buffer.

        List-valued fields are stored as JSON strings. Values JSON cannot encode
        (e.g. BLOB ``bytes``) are stored via ``str`` instead of crashing.
        """
        with open(self.memory_file, mode="a", newline="") as file:
            writer = csv.writer(file)
            writer.writerow(
                [
                    user_query,
                    json.dumps(tables),
                    json.dumps(columns),
                    sql_query,
                    json.dumps(results, default=str),
                ]
            )

    def fetch_relevant_from_memory(
        self, relevant_queries: List[Dict]
    ) -> Tuple[List[str], List[str], List[Tuple]]:
        """Collect tables, columns and results of memory rows whose question matches ``relevant_queries``.

        Items that are not dicts or lack a ``"User Query"`` key (possible in LLM output)
        are ignored.
        """
        wanted_queries = {
            rq.get("User Query") for rq in relevant_queries if isinstance(rq, dict)
        }
        wanted_queries.discard(None)
        if not wanted_queries:
            return [], [], []

        relevant_tables = set()
        relevant_columns = set()
        relevant_results = []
        with open(self.memory_file, mode="r", newline="") as file:
            for row in csv.DictReader(file):
                if row["User Query"] in wanted_queries:
                    relevant_tables.update(json.loads(row["Relevant Tables"]))
                    relevant_columns.update(json.loads(row["Relevant Columns"]))
                    relevant_results.append(json.loads(row["Execution Results"]))
        return list(relevant_tables), list(relevant_columns), relevant_results

    def _process_query(self, user_query: str) -> None:
        """Answer one question: gather context, build and run SQL, then record it in memory."""
        # Get relevant history for context
        relevant_queries = self.conversational_agent.get_chat_history(user_query)

        print("Processing Query:")
        print(f"  Relevant Queries for Context: {relevant_queries}")

        # Fetch relevant tables, columns, and results from memory
        memory_tables, memory_columns, memory_results = (
            self.fetch_relevant_from_memory(relevant_queries)
        )

        # Step 1: Extract relevant tables and columns
        rag_tables, rag_columns = self.rag_agent.extract_relevant_data(
            user_query, relevant_queries, memory_results
        )
        relevant_tables = list(set(memory_tables + rag_tables))
        relevant_columns = list(set(memory_columns + rag_columns))

        print(f"  Relevant Tables: {relevant_tables}")
        print(f"  Relevant Columns: {relevant_columns}")

        # Step 2: Generate and Refine SQL query
        final_sql_query = self.sql_agent.generate_and_refine_sql(
            relevant_tables,
            relevant_columns,
            user_query,
            relevant_queries,
            memory_results,
        )

        print(f"  Final SQL Query: {final_sql_query}")

        # Step 3: Execute SQL query
        results = self.execute_sql_with_validation(final_sql_query)
        print(f"  Query Results: {results}")

        # Update memory buffer
        self.update_memory(
            user_query,
            relevant_tables,
            relevant_columns,
            final_sql_query,
            results,
        )

    def run(self):
        """Answer ``self.user_query``, then keep prompting for new questions.

        An empty line, ``exit``, ``quit`` or end-of-input ends the session. A
        ``ValueError`` from the agents (unparsable LLM output) is reported and the
        loop moves on to the next question instead of crashing.
        """
        while True:
            try:
                self._process_query(self.user_query)
            except ValueError as e:
                print(f"  Could not answer query: {e}")

            # Step 4: Take new user input
            try:
                self.user_query = input("Enter your next question: ").strip()
            except EOFError:
                break
            if self.user_query.lower() in self._EXIT_COMMANDS:
                break

    def _run_read_only(self, sql_query: str) -> List[Tuple]:
        """Execute ``sql_query`` on a read-only connection and return all rows."""
        # mode=ro: the SQL comes from the LLM, so it must not be able to write to the database.
        uri = Path(self.db_file).resolve().as_uri() + "?mode=ro"
        with closing(sqlite3.connect(uri, uri=True)) as conn:
            return conn.execute(sql_query).fetchall()

    def execute_sql_with_validation(self, sql_query: str) -> List[Tuple]:
        """Execute ``sql_query``, regenerating it from scratch if SQLite rejects it.

        Makes up to 3 attempts. Agent sentinels such as ``NOT_APPLICABLE`` are never
        executed. Returns ``[]`` if no attempt produces rows.
        """
        max_attempts = 3
        for attempt in range(1, max_attempts + 1):
            if sql_query.strip() in self._SQL_SENTINELS:
                # The agent declined to write SQL; there is nothing to execute.
                return []
            try:
                return self._run_read_only(sql_query)
            except (sqlite3.Error, sqlite3.Warning) as e:
                # Before Python 3.12, multi-statement input raises sqlite3.Warning, which is not an sqlite3.Error.
                print(f"SQL Execution Error (Attempt {attempt}): {e}")
                if attempt < max_attempts:  # don't spend an LLM call on SQL nobody will run
                    sql_query = self.sql_agent.generate_sql(
                        [], [], self.user_query, [], []
                    )  # Regenerate SQL on error
        return []

In [62]:
# Cell 8: Entry Point and Execution

if __name__ == "__main__":
    # Interactive entry point: answer USER_QUERY against DB_FILE, then keep prompting.
    # NOTE(review): "01/05/2024" is ambiguous (DD/MM vs MM/DD); the LLM decides how to read it.
    DB_FILE, USER_QUERY = (
        "company_data.db",  # Replace with your database file
        "How many people have booked at least 50 hours before 01/05/2024?",  # Initial user query
    )
    pipeline = RAGPipeline(DB_FILE, USER_QUERY)
    pipeline.run()

Processing Query:
  Relevant Queries for Context: [{'User Query': 'How many people have booked at least 50 hours before 01/05/2024?', 'Relevant Tables': '["holiday_balance"]', 'Relevant Columns': '["holiday_balance.booked"]', 'Generated SQL Query': "SELECT\nCOUNT(DISTINCT colleague_id)\nFROM holiday_balance\nWHERE booked >= 50\nAND report_date <= '2024-04-30'", 'Execution Results': '[[0]]'}, {'User Query': 'How many people have booked at least 50 hours before 01/05/2024?', 'Relevant Tables': '["holiday_balance"]', 'Relevant Columns': '["holiday_balance.booked"]', 'Generated SQL Query': "SELECT\nCOUNT(DISTINCT colleague_id)\nFROM holiday_balance\nWHERE booked >= 50\nAND report_date < '2024-05-01'\nAND using_workday = 1", 'Execution Results': '[[0]]'}]
  Relevant Tables: ['holiday_balance']
  Relevant Columns: ['holiday_balance.report_date', 'holiday_balance.booked']
    Initial SQL Query: SELECT
COUNT(DISTINCT colleague_id)
FROM holiday_balance
WHERE booked >= 50
AND report_date < '2024

KeyboardInterrupt: 