In [None]:
import json
import os
import sqlite3

from tqdm import tqdm
from openai import OpenAI
from dotenv import load_dotenv
from pathlib import Path

from utils.spider import load_tables

# Load variables from a local .env file (if one exists) into the process environment,
# so that the OPENAI_API_KEY lookup further down can find a key stored there.
load_dotenv()

In [4]:
class Arguments:
    """Static run configuration, accessed as plain attributes (``args.<name>``).

    Acts as a notebook stand-in for parsed command-line options: every value
    is a class-level constant, so no constructor arguments are needed.
    """

    # --- dataset locations -------------------------------------------------
    # Root directory of the dataset; `tables` and `db` are joined onto it.
    dir: str = "datasets/spider_data"
    # JSON file holding the natural-language questions to translate.
    input: str = "datasets/spider_data/dev.json"
    # Schema description file, relative to `dir`.
    tables: str = "tables.json"
    # Sub-directory of `dir` holding one `<db_id>/<db_id>.sqlite` file per database.
    db: str = "database"
    # Destination for predicted SQL (not used in the cells shown here — presumably
    # written by a later cell; TODO confirm).
    output: str = "results/predicted.sql"

    # --- generation settings -----------------------------------------------
    # Passed straight through to the chat-completion request.
    model: str = "gpt-4o"
    temperature: float = 0.0
    max_tokens: int = 1000


# Single shared configuration instance used by the rest of the notebook.
args = Arguments()

In [None]:
# Read the API key from the environment (which load_dotenv() has already
# extended with any values from a local .env file). Fail fast with a clear
# message instead of falling back to a placeholder string: a placeholder
# invites pasting a real key into the notebook, and otherwise only surfaces
# as an authentication error on the first request.
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Export it in your shell or add it to a .env file."
    )
client = OpenAI(api_key=api_key)

# System instructions sent with every request: constrain the model to answer
# with a single-line raw SQL query and nothing else.
system_prompt = """You are an expert SQL assistant specialized in converting natural language queries into accurate SQL statements.
When given a question, you will convert it to a valid SQL query based on the provided database schema.
Only output the raw SQL query in one line without any markdown formatting, code blocks, or additional text."""

# Load the questions from the JSON file configured in args.input.
# The encoding is pinned to UTF-8 so the file is decoded the same way on every
# platform instead of depending on the locale's default encoding.
print("Reading questions from ", args.input)
with open(args.input, 'r', encoding="utf-8") as f:
    examples = json.load(f)
print("Total number of questions: ", len(examples))

# Load the schemas for all the DBs.
# load_tables is a project helper; its first return value is used below as a
# mapping of db_id -> schema object, the second return value is ignored.
print("Loading schemas...")
schemas, _ = load_tables([os.path.join(args.dir, args.tables)])

# Copy every on-disk database into its own in-memory database and attach the
# live in-memory connection to the corresponding schema object.
print("Loading DB connections...")
for db_id, schema in tqdm(schemas.items(), desc="DB connections"):
    sqlite_path = Path(args.dir) / args.db / db_id / f"{db_id}.sqlite"
    # NOTE: using a sqlite3 connection as a context manager only commits or
    # rolls back the transaction; it does NOT close the connection. Close the
    # file-backed source explicitly so one file handle per database is not leaked.
    source = sqlite3.connect(str(sqlite_path))
    dest = sqlite3.connect(':memory:')
    try:
        source.backup(dest)
    except Exception:
        # Do not leave a half-initialised in-memory database behind.
        dest.close()
        raise
    finally:
        source.close()
    # Rows fetched through this connection support access by index and by column name.
    dest.row_factory = sqlite3.Row
    schema.connection = dest

# Collect the CREATE TABLE statements of every database, keyed by db_id.
# These strings are the schema representation later embedded in the prompt.
db_schemas = {}
for db_id, schema in schemas.items():
    connection = schema.connection
    cursor = connection.cursor()
    # Query the sqlite_master table for the CREATE statements of user tables.
    # Tables whose names start with "sqlite_" are SQLite-internal bookkeeping
    # (e.g. sqlite_sequence for AUTOINCREMENT columns) and are excluded so they
    # do not pollute the schema shown to the model.
    cursor.execute("""
        SELECT sql
        FROM sqlite_master
        WHERE type='table' AND sql IS NOT NULL AND name NOT LIKE 'sqlite_%'
    """)
    # Each row holds a single column (the statement text); join them into one
    # newline-separated string per database.
    create_statements = '\n'.join(row[0] for row in cursor.fetchall())
    db_schemas[db_id] = create_statements

In [None]:
# Sanity check: show the CREATE statements collected for the "academic" database.
# print() is used so the newline-joined statements render one per line.
print (db_schemas["academic"])

In [None]:
# Build the user prompt for the first question as a smoke test: the schema of
# the question's database, then the question, then an open "SQL Query" section
# for the model to complete.
example = examples[0]
question, db_id = example["question"], example["db_id"]
code_representation = db_schemas[db_id]

user_prompt = (
    "### Database Schema:\n"
    f"{code_representation}\n"
    "### Question:\n"
    f"{question}\n"
    "### SQL Query:\n"
)

print(user_prompt)

In [None]:
# Conversation sent to the model: the fixed system instructions followed by
# the schema + question prompt built above.
chat_messages = [
    dict(role="system", content=system_prompt),
    dict(role="user", content=user_prompt),
]

# Request a completion using the generation settings from `args`.
response = client.chat.completions.create(
    messages=chat_messages,
    model=args.model,
    max_tokens=args.max_tokens,
    temperature=args.temperature,
)

# Only the first returned choice is used; its message text is the predicted SQL.
first_choice = response.choices[0]
predicted_sql = first_choice.message.content
print(predicted_sql)