In [1]:
from openai import OpenAI
from sqlglot import parse_one

from APIKEY import APIKEY
from systemidentity import identity

import sqlite3
import csv

In [2]:
# Shared OpenAI client used by text_tosql below; the key is imported from the local APIKEY module.
client = OpenAI(api_key=APIKEY)

In [3]:
# SQLite database files to process. Each one lives under sqlite_db/ and its
# base name is what get_db_id later extracts as the db_id.
db_names = [
    f'sqlite_db/{stem}.sqlite'
    for stem in ('books', 'disney', 'genes', 'movie_platform', 'shipping')
]

### Step 1 : Extract column names from database

In [4]:
def get_column_mapping(db_file):
    """Return the column names of every table in a SQLite database.

    Args:
        db_file: Path to the SQLite database file.

    Returns:
        dict mapping each table name listed in ``sqlite_master`` to the list
        of its column names, in the order reported by ``PRAGMA table_info``.
    """
    conn = sqlite3.connect(db_file)
    try:
        cursor = conn.cursor()

        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        column_mapping = {}
        for table_name in table_names:
            # The table name is embedded as a double-quoted identifier, so any
            # double quote inside the name has to be doubled to stay valid SQL.
            quoted_name = table_name.replace('"', '""')
            cursor.execute(f'PRAGMA table_info("{quoted_name}");')
            # Index 1 of each table_info row is the column name.
            column_mapping[table_name] = [column[1] for column in cursor.fetchall()]

        return column_mapping
    finally:
        # Release the connection even when one of the queries fails.
        conn.close()


### Step 2 : Generate the SQL Query using OpenAI

In [5]:
def text_tosql(question: str, column_mapping) -> str:
    """Ask the OpenAI chat model to translate a question into an SQL query.

    Args:
        question: Natural-language question to translate.
        column_mapping: Dict mapping each table name to the list of its column
            names; it is rendered into the prompt as schema context.

    Returns:
        The model's reply with surrounding whitespace and one enclosing
        Markdown code fence (if any) removed. An empty string is returned
        when the reply has no text content.
    """
    schema_context = "".join(
        f"Table: {table}\nColumns: {', '.join(columns)}\n"
        for table, columns in column_mapping.items()
    )

    prompt = f"""
    Based on the following database schema: {schema_context}
    Translate this into an SQL query: {question}
    """

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": identity},
            {"role": "user", "content": prompt},
        ],
    )
    # The message content is optional in the API response; fall back to an
    # empty string instead of failing on None.strip().
    reply = response.choices[0].message.content or ""
    return _strip_code_fence(reply)


def _strip_code_fence(text: str) -> str:
    """Strip whitespace and one enclosing Markdown code fence from ``text``."""
    text = text.strip()
    fence = "```"
    is_fenced = (
        len(text) >= 2 * len(fence)
        and text.startswith(fence)
        and text.endswith(fence)
    )
    if not is_fenced:
        return text
    inner = text[len(fence):-len(fence)]
    first_line, newline, rest = inner.partition("\n")
    # The opening fence may be followed by a language tag on the same line.
    if newline and first_line.strip().lower() in ("", "sql", "sqlite"):
        inner = rest
    return inner.strip()

### Step 3 : Validate with sqlglot

In [6]:
def validate_sql(sql_query):
    """Parse ``sql_query`` with sqlglot's ``parse_one`` and report the outcome.

    Returns:
        A ``(message, sql)`` tuple. On success the message starts with
        "VALID QUERY -->" and ``sql`` is the SQL text regenerated from the
        parsed expression. If parsing or rendering raises, the message holds
        the error text and ``sql`` is None.
    """
    try:
        expression = parse_one(sql_query)
        rendered = expression.sql()
    except Exception as exc:
        return f'Error  {exc}', None
    return f"VALID QUERY --> {rendered}", rendered

### Step 4 : Run the pipeline for every database and question

In [7]:
# Function to extract db_id from the file path
def get_db_id(file_path):
    """Return the last '/'-separated component of ``file_path``, cut at its first '.'."""
    base_name = file_path.rsplit('/', 1)[-1]
    db_id, _, _ = base_name.partition('.')
    return db_id

In [8]:
## Function to add and update the LLM_SQL column in the CSV

def update_csv_with_llm_sql(output_file, question, generated_sql):
    """Store ``generated_sql`` in the LLM_SQL column of the matching CSV rows.

    The file is read completely, every row whose ``question`` value equals
    ``question`` (ignoring surrounding whitespace) gets its ``LLM_SQL`` value
    set, and the file is then rewritten in place. The ``LLM_SQL`` column is
    appended to the header when it is not already present.

    Args:
        output_file: Path of the CSV file to update; rows are looked up by
            their ``question`` column.
        question: Question text identifying the row(s) to update.
        generated_sql: SQL string to store for the matching row(s).

    Returns:
        None. A file without a header row (for example an empty file) is left
        untouched.
    """
    target_question = question.strip()
    updated_rows = []

    with open(output_file, mode='r', newline='') as file:
        reader = csv.DictReader(file)
        fieldnames = reader.fieldnames

        if fieldnames is None:
            # No header row: there is no column to extend and no row to update.
            return

        if 'LLM_SQL' not in fieldnames:
            fieldnames.append('LLM_SQL')

        for row in reader:
            if row['question'].strip() == target_question:
                row['LLM_SQL'] = generated_sql
            updated_rows.append(row)

    with open(output_file, mode='w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(updated_rows)

In [9]:
# Read the CSV data and create a dictionary of questions for each db_id
questions_by_db = {}
file_csv = 'output_30.csv' #30 random questions for testing

# newline='' matches how update_csv_with_llm_sql opens the same file and lets
# the csv module do its own line-ending handling for quoted fields.
with open(file_csv, 'r', newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        # Group the questions under their database id, keeping file order.
        questions_by_db.setdefault(row['db_id'], []).append(row['question'])

In [10]:
# Iterate over the database files
for dbfile in db_names:
    print(f"For {dbfile}")
    column_mapping = get_column_mapping(db_file=dbfile) #Step 1
    db_id = get_db_id(dbfile)

    if db_id in questions_by_db:
        print(f"For {db_id}")
        questions = questions_by_db[db_id]
        for question in questions:
            print(f"For {question}")
            sql_query = text_tosql(question, column_mapping)  # Step 2
            validation_result, validsql = validate_sql(sql_query=sql_query)  # Step 3
            # validate_sql returns None as its second value on failure. Test
            # that rather than searching the message text: an error message
            # may itself contain "VALID QUERY" (e.g. inside "INVALID QUERY"),
            # which would store None in the CSV.
            if validsql is not None:
                update_csv_with_llm_sql(file_csv, question, validsql)

For sqlite_db/books.sqlite
For books
For Among the books published by publisher ID 1929, how many of them have over 500 pages?
For Which customer has made the most orders? Show his/her full name.
For What is the average number of pages of David Coward's books?
For Among Daisey Lamball's orders, how many were shipped via International shipping?
For What is the full name of the customer who ordered the most books of all time?
For In books published by Ace Book, what is the percentage of English books published?
For sqlite_db/disney.sqlite
For disney
For How much more total box office gross did the Walt Disney Company have in revenue in 1998 than in 1997?
For Which song is associated with the most popular Disney movie in 1970s?
For Name the first movie released by Disney.
For How many movies were released by Disney between 2010 and 2016?
For Name the top 5 highest-grossing Disney movies adjusted for inflation. Identify the percentage they contributed to the total of Disney's current gross