In [37]:
from google import genai
from dotenv import load_dotenv
import os
import sqlite3

load_dotenv()

True

In [115]:
# ------------------------------------------------------------------
# 1. SQL-ONLY PROMPT
# ------------------------------------------------------------------
SQL_PROMPT = """
You are a SQL-only assistant for the HDB data mart (Singapore context).
Your ONLY task is to emit **exactly one valid SQLite statement** (or a single CTE) that answers the user’s question by reading from the tables below.

Schema
------
bto_prices(
    _id INTEGER PRIMARY KEY AUTOINCREMENT,
    financial_year TEXT,
    room_type TEXT,
    town TEXT,
    min_selling_price REAL,
    max_selling_price REAL,
    min_selling_price_less_ahg_shg REAL,
    max_selling_price_less_ahg_shg REAL
)

Rules
1. Return ONLY the SQL statement—no explanations, no markdown fences.  
2. Use standard SQLite syntax (CTEs allowed).  
3. Parameterise with $town, $room_type, etc. when possible.  
4. Aggregate or filter as needed to answer the question; do NOT predict future prices.  
5. If the question is ambiguous, choose the most reasonable interpretation and proceed.  
6. The SQLite query must always be executable.

Examples
--------
Q: List all towns with BTO launches in 2018.
A:
SELECT DISTINCT town
FROM bto_prices
WHERE financial_year = '2018';

Q: <user query>
"""


ANALYST_PROMPT = """
You will be given the following data and context:  
1. A user query  
2. An SQL query that extracts relevant information from the database  
3. The corresponding SQL output  

Your task is to response to the user’s query using the information provided. 
Ensure you always aim to respond to the user's query, but only use information provided.

Schema
------
bto_prices(
    _id INTEGER PRIMARY KEY AUTOINCREMENT,
    financial_year TEXT,
    room_type TEXT,
    town TEXT,
    min_selling_price REAL,
    max_selling_price REAL,
    min_selling_price_less_ahg_shg REAL,
    max_selling_price_less_ahg_shg REAL
)

Q: <user query>

SQL: <SQL query>

SQL: <Output>
"""

In [116]:
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
    raise ValueError("GOOGLE_API_KEY not found in environment variables")

client = genai.Client(api_key=api_key)

In [117]:
def _sql_is_valid(sql, conn):
    try:
        conn.execute(f"EXPLAIN {sql}")
        return True
    except sqlite3.Error as e:
        print("SQLite says:", e)
        return False
    

def generate_sql_based_on_query(query):
    prompt = SQL_PROMPT.replace("<user query>", query)

    response = client.models.generate_content(
        model="gemini-2.5-flash-lite",
        contents=prompt,
    )

    sql = response.text.strip()

    return sql

def generate_valid_sql(query, db_path):
    """
    Generates a valid SQLite query from a natural language query.
    """
    with sqlite3.connect(db_path) as conn:
        for attempt in range(3):
            sql = generate_sql_based_on_query(query)
            print(sql)
            if _sql_is_valid(sql, conn):
                return sql
            print(f"[Attempt {attempt}] Generated SQL was invalid; retrying…")
    raise RuntimeError("Failed to generate a valid SQL query after 3 attempts.")

def execute_sql_query(sql, conn):
    try:
        cursor = conn.execute(sql)
        results = cursor.fetchall()
        return results
    except sqlite3.Error as e:
        print(f"Error executing SQL query: {e}")
        raise

def execute_and_display_query(sql, conn):
    try:
        cursor = conn.execute(sql)
        columns = [description[0] for description in cursor.description] if cursor.description else []
        
        print("Columns:", columns)
        print("-" * 50)
        
        results = cursor.fetchall()
        if not results:
            print("No results found.")
        else:
            # Print header
            print(" | ".join(f"{col:<15}" for col in columns))
            print("-" * (len(" | ".join(f"{col:<15}" for col in columns))))
            
            # Print rows
            for row in results:
                print(" | ".join(f"{str(val):<15}" for val in row))
                
        return results
        
    except sqlite3.Error as e:
        print(f"Error executing SQL query: {e}")
        raise

def explain_query(user_query: str, sql: str, results, model="gemini-2.5-flash-lite"):
    result_str = str(results) if results else "No rows returned."

    prompt = (
        ANALYST_PROMPT
        .replace("<user_query>", user_query)
        .replace("<sql_query>", sql.strip())
        .replace("<output>", result_str)
    )

    response = client.models.generate_content(model=model, contents=prompt)
    return response.text.strip()


def ask_and_explain(user_query: str, db_path: str):
    sql = generate_valid_sql(user_query, db_path)
    with sqlite3.connect(db_path) as conn:
        results = execute_and_display_query(sql, conn)
        explanation = explain_query(user_query, sql, results)
        print("\n=== Analyst Explanation ===")
        print(explanation)
    return sql, results, explanation

In [118]:
if __name__ == "__main__":
    ask_and_explain(
        "Which town had the cheapest 4-room flats in 2022?",
        "data/hdb_prices.db"
    )

SELECT town
FROM bto_prices
WHERE financial_year = '2022' AND room_type = '4-room'
ORDER BY min_selling_price
LIMIT 1;
Columns: ['town']
--------------------------------------------------
town           
---------------
Yishun         

=== Analyst Explanation ===
Here are the median selling prices for 3-room HDB flats in Woodlands for the financial year 2023:

*   **Minimum Selling Price:** $326,000
*   **Maximum Selling Price:** $431,000


In [None]:
import sqlite3
import pandas as pd

# Connect to your SQLite database
conn = sqlite3.connect("data/hdb_prices.db")

# Option 1: Get a quick snapshot of first 10 rows
df = pd.read_sql_query("SELECT * FROM bto_prices LIMIT 2;", conn)
print(df)

# Option 2: Get distinct values for overview
years = pd.read_sql_query("SELECT DISTINCT financial_year FROM bto_prices;", conn)
room_types = pd.read_sql_query("SELECT DISTINCT room_type FROM bto_prices;", conn)
towns = pd.read_sql_query("SELECT DISTINCT town FROM bto_prices;", conn)

print("Years:\n", years)
print("Room types:\n", room_types)
print("Towns:\n", towns)

conn.close()


   _id financial_year room_type     town  min_selling_price  \
0    1           2008    2-room  Punggol            82000.0   
1    2           2008    3-room  Punggol           135000.0   

   max_selling_price  min_selling_price_less_ahg_shg  \
0           107000.0                             0.0   
1           211000.0                             0.0   

   max_selling_price_less_ahg_shg  
0                             0.0  
1                             0.0  
Years:
    financial_year
0            2008
1            2009
2            2010
3            2011
4            2012
5            2013
6            2014
7            2015
8            2016
9            2017
10           2018
11           2019
12           2020
13           2021
14           2022
15           2023
Room types:
   room_type
0    2-room
1    3-room
2    4-room
3    5-room
4   2-room 
Towns:
              town
0         Punggol
1     Jurong West
2   Bukit Panjang
3       Woodlands
4        Sengkang
5   Choa Chu Kang
