In [10]:
import json
import os
import sqlite3

from tqdm import tqdm
from openai import OpenAI
from dotenv import load_dotenv
from pathlib import Path

from utils.spider import load_tables

# Load variables from a .env file into the process environment (python-dotenv);
# OPENAI_API_KEY is read from os.environ when the OpenAI client is created.
load_dotenv()

True

In [11]:
class Arguments:
    """Run configuration for the notebook, stored as plain class attributes."""

    # --- dataset locations ---
    # JSON file holding the questions to translate
    input: str = "datasets/spider_data/dev.json"
    # NOTE(review): presumably where predicted SQL is written; not used in the cells shown here
    output: str = "results/predicted.sql"
    # Dataset root; `tables` and `db` are resolved relative to it
    dir: str = "datasets/spider_data"
    # Schema description file passed to load_tables
    tables: str = "tables.json"
    # Sub-directory containing <db_id>/<db_id>.sqlite files
    db: str = "database"

    # --- chat completion settings ---
    model: str = "gpt-4o-mini"
    temperature: float = 0.0
    max_tokens: int = 1000


# Single shared configuration object read by the cells below
args = Arguments()

In [17]:
# Build the API client. The key comes from the OPENAI_API_KEY environment
# variable; the placeholder string is only used when that variable is unset.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"),
)

# System message sent with every request: it fixes the assistant's role and
# asks for a bare SQL statement as the only output.
system_prompt = """You are an expert SQL assistant specialized in converting natural language queries into accurate SQL statements. You:
1. Understand database schemas and relationships
2. Generate standard SQL queries that follow best practices
3. Consider edge cases and data validation
4. Can handle complex joins, aggregations, and nested queries
When given a question, you will convert it to a valid SQL query based on the provided database schema.
Only output the raw SQL query without any markdown formatting, code blocks, or additional text. 
For example, output should look like: SELECT * FROM table;"""

# Load the questions to translate from the JSON file named in args.input
print("Reading questions from ", args.input)
# Decode explicitly as UTF-8: without `encoding`, open() falls back to the
# platform's locale encoding, which can fail or mis-decode non-ASCII text.
with open(args.input, 'r', encoding='utf-8') as f:
    examples = json.load(f)
print("Total number of questions: ", len(examples))

# Load the schema objects for all databases from the tables file.
# load_tables returns a pair; only the first element is used here.
print("Loading schemas...")
tables_path = os.path.join(args.dir, args.tables)
schemas, _ = load_tables([tables_path])

# Copy every on-disk database into an in-memory SQLite database and attach the
# in-memory connection to its schema object as `schema.connection`.
print("Loading DB connections...")
for db_id, schema in tqdm(schemas.items(), desc="DB connections"):
    sqlite_path = Path(args.dir) / args.db / db_id / f"{db_id}.sqlite"
    # A sqlite3 connection used as a context manager only commits/rolls back a
    # transaction; it does NOT close the connection. Close the file-backed
    # source explicitly so one handle per database is not leaked.
    source = sqlite3.connect(str(sqlite_path))
    try:
        dest = sqlite3.connect(':memory:')
        # Rows from the in-memory copy support access by column name and index
        dest.row_factory = sqlite3.Row
        source.backup(dest)
    finally:
        source.close()
    schema.connection = dest

# Build db_id -> text of all CREATE TABLE statements for that database
db_schemas = {}
for db_id, schema in schemas.items():
    # sqlite_master stores the original DDL of each table in its `sql` column
    ddl_rows = schema.connection.execute("""
        SELECT sql 
        FROM sqlite_master 
        WHERE type='table' AND sql IS NOT NULL
    """).fetchall()
    # One CREATE statement per row; join them into a single newline-separated string
    db_schemas[db_id] = '\n'.join(ddl_row[0] for ddl_row in ddl_rows)

Reading questions from  datasets/spider_data/dev.json
Total number of questions:  1034
Loading schemas...
Loading DB connections...


DB connections: 100%|██████████| 166/166 [00:01<00:00, 142.31it/s]


In [13]:
# Show the CREATE statements collected for the "academic" database
print(db_schemas["academic"])

CREATE TABLE "author" (
"aid" int,
"homepage" text,
"name" text,
"oid" int,
primary key("aid")
)
CREATE TABLE "conference" (
"cid" int,
"homepage" text,
"name" text,
primary key ("cid")
)
CREATE TABLE "domain" (
"did" int,
"name" text,
primary key ("did")
)
CREATE TABLE "domain_author" (
"aid" int, 
"did" int,
primary key ("did", "aid"),
foreign key("aid") references `author`("aid"),
foreign key("did") references `domain`("did")
)
CREATE TABLE "domain_conference" (
"cid" int,
"did" int,
primary key ("did", "cid"),
foreign key("cid") references `conference`("cid"),
foreign key("did") references `domain`("did")
)
CREATE TABLE "journal" (
"homepage" text,
"jid" int,
"name" text,
primary key("jid")
)
CREATE TABLE "domain_journal" (
"did" int,
"jid" int,
primary key ("did", "jid"),
foreign key("jid") references "journal"("jid"),
foreign key("did") references "domain"("did")
)
CREATE TABLE "keyword" (
"keyword" text,
"kid" int,
primary key("kid")
)
CREATE TABLE "domain_keyword" (
"did" int,


In [14]:
# Build the user prompt for the first example in `examples`
example = examples[0]
db_id = example["db_id"]
question = example["question"]
# The schema is presented to the model as the database's CREATE statements
code_representation = db_schemas[db_id]

# Three labelled sections; the trailing "### SQL Query:" line is left open
# for the model to complete.
user_prompt = (
    "### Database Schema:\n"
    f"{code_representation}\n"
    "### Question:\n"
    f"{question}\n"
    "### SQL Query:\n"
)

print(user_prompt)

### Database Schema:
CREATE TABLE "stadium" (
"Stadium_ID" int,
"Location" text,
"Name" text,
"Capacity" int,
"Highest" int,
"Lowest" int,
"Average" int,
PRIMARY KEY ("Stadium_ID")
)
CREATE TABLE "singer" (
"Singer_ID" int,
"Name" text,
"Country" text,
"Song_Name" text,
"Song_release_year" text,
"Age" int,
"Is_male" bool,
PRIMARY KEY ("Singer_ID")
)
CREATE TABLE "concert" (
"concert_ID" int,
"concert_Name" text,
"Theme" text,
"Stadium_ID" text,
"Year" text,
PRIMARY KEY ("concert_ID"),
FOREIGN KEY ("Stadium_ID") REFERENCES "stadium"("Stadium_ID")
)
CREATE TABLE "singer_in_concert" (
"concert_ID" int,
"Singer_ID" text,
PRIMARY KEY ("concert_ID","Singer_ID"),
FOREIGN KEY ("concert_ID") REFERENCES "concert"("concert_ID"),
FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
)
### Question:
How many singers do we have?
### SQL Query:



In [18]:
# Conversation sent to the model: fixed instructions, then the schema + question
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_prompt},
]

# Request a single completion using the settings from `args`
response = client.chat.completions.create(
    model=args.model,
    messages=messages,
    temperature=args.temperature,
    max_tokens=args.max_tokens,
)

# The text of the first choice is taken as the predicted SQL
first_choice = response.choices[0]
predicted_sql = first_choice.message.content
print(predicted_sql)

SELECT COUNT(*) FROM singer;
