In [1]:
import re
import sqlite3

from openai import OpenAI
from sqlglot import parse_one

from APIKEY import APIKEY
from systemidentity import identity

In [2]:
# Shared OpenAI client used by text_tosql(). The key comes from the local APIKEY
# module imported in the first cell; keep that module out of version control.
# NOTE(review): consider reading the key from an environment variable instead.
client = OpenAI(api_key=APIKEY)

In [3]:
# Natural-language request that the pipeline below translates into SQL.
command = (
    "Name movie titles released in year 1945. "
    "Sort the listing by the descending order of movie popularity"
)

In [4]:
# db_name = input('Enter the db_name or path to the db')

### Step 1: Extract column names from the database

In [5]:
def get_column_mapping(db_file):
    """Return ``{table_name: [column_name, ...]}`` for every table in a SQLite database.

    Table names come from ``sqlite_master`` and column names from
    ``PRAGMA table_info``, in the order SQLite reports them.

    Parameters
    ----------
    db_file : str or path-like
        Path to the SQLite database (anything ``sqlite3.connect`` accepts).
        Note that ``sqlite3.connect`` creates an empty database if the file does
        not exist, in which case an empty dict is returned.

    Returns
    -------
    dict[str, list[str]]
        Mapping of each table name to its column names.
    """
    conn = sqlite3.connect(db_file)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        column_mapping = {}
        for table_name in table_names:
            # PRAGMA arguments cannot be bound as parameters, so quote the name as an
            # SQL identifier (doubling any embedded double quotes). A bare name
            # breaks on tables whose names contain spaces, reserved words or punctuation.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_name});")
            # Each table_info row is (cid, name, type, notnull, dflt_value, pk).
            column_mapping[table_name] = [column[1] for column in cursor.fetchall()]
        return column_mapping
    finally:
        # Close even when a query fails so the database file is not left open.
        conn.close()

In [6]:
# get_column_mapping('movie_platform.sqlite')

### Step 2: Generate the SQL query using OpenAI

In [7]:
def text_tosql(command: str, column_mapping, model: str = "gpt-4o-mini") -> str:
    """Ask an OpenAI chat model to translate a natural-language request into SQL.

    Parameters
    ----------
    command : str
        The natural-language request, e.g. "Name movie titles released in 1945".
    column_mapping : dict[str, list[str]]
        ``{table: [columns]}`` as returned by ``get_column_mapping``. It is rendered
        into the prompt as one "Table: ..., Columns: ..." line per table.
    model : str, optional
        Chat model name; defaults to ``"gpt-4o-mini"``.

    Returns
    -------
    str
        The model's reply with surrounding whitespace removed. If the reply
        contains a Markdown code fence (such as a ```sql ... ``` block), only
        the fenced body is returned.

    Raises
    ------
    ValueError
        If the model's message has no text content (e.g. a refusal).
    """
    schema_context = "".join(
        f"Table: {table}, Columns: {', '.join(columns)}\n"
        for table, columns in column_mapping.items()
    )

    prompt = f"""
    Based on the following database schema: {schema_context}
    Translate this into an SQL query: {command}
    """

    # Uses the module-level `client` and the system prompt `identity` from the imports cell.
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": identity},
            {"role": "user", "content": prompt},
        ],
    )
    content = response.choices[0].message.content
    if content is None:
        raise ValueError("The model returned no text content (possibly a refusal).")

    sql_query = content.strip()
    # Chat models often wrap SQL in a Markdown fence. Keep only the fenced body so
    # that sqlglot and SQLite receive plain SQL.
    fenced = re.search(r"```[^\n]*\n(.*?)```", sql_query, flags=re.DOTALL)
    if fenced:
        sql_query = fenced.group(1).strip()
    return sql_query

In [8]:
# column_mapping = get_column_mapping(db_file='movie_platform.sqlite')  # Step 1
# sql_query = text_tosql(command, column_mapping)  # Step 2
# print(f"Generated Query: {sql_query}")

### Step 3: Validate the query with sqlglot

In [9]:
def validate_sql(sql_query, dialect=None):
    """Check that ``sql_query`` can be parsed by sqlglot and return normalized SQL.

    Parameters
    ----------
    sql_query : str
        SQL text to validate.
    dialect : str, optional
        sqlglot dialect used both to parse and to regenerate the SQL (e.g.
        ``"sqlite"``). ``None`` keeps sqlglot's default dialect, which matches
        the previous behaviour.

    Returns
    -------
    tuple[str, str | None]
        ``(message, sql)``. On success the message is ``"VALID QUERY --> <sql>"``
        and ``sql`` is the SQL regenerated by sqlglot. On failure the message is
        ``"Error  <reason>"`` and ``sql`` is ``None``. Both paths return a pair,
        so callers can always unpack ``message, sql = validate_sql(...)``.

    NOTE(review): ``parse_one`` keeps a single expression; if the model returns
    several statements, the extras may be dropped silently. Confirm this for the
    installed sqlglot version.
    """
    try:
        normalized = parse_one(sql_query, read=dialect).sql(dialect=dialect)
    except Exception as e:
        # Previously a bare string was returned here, which broke tuple unpacking in callers.
        return f'Error  {str(e)}', None
    return f"VALID QUERY --> {normalized}", normalized

### Step 4: Execute the generated query against the database

In [10]:
def sqlite_execute(db_name, sql_query):
    """Run ``sql_query`` against the SQLite database ``db_name`` and print each result row.

    Errors raised while executing or fetching are printed rather than raised.

    Parameters
    ----------
    db_name : str or path-like
        Database path passed to ``sqlite3.connect``.
    sql_query : str
        SQL statement to execute.

    Returns
    -------
    tuple[sqlite3.Connection, sqlite3.Cursor]
        The still-open connection and cursor; the caller is responsible for closing
        the connection.
    """
    connection = sqlite3.connect(db_name)
    cur = connection.cursor()
    try:
        # Cursor.execute returns the cursor itself, so the fetch can be chained.
        for record in cur.execute(sql_query).fetchall():
            print(record)
    except Exception as err:
        print(f"Error executing query: {err}")
    return connection, cur

In [12]:
# Database file used by every step of the pipeline below.
DB_FILE = "movie_platform.sqlite"

column_mapping = get_column_mapping(db_file=DB_FILE)  # Step 1: read the schema
sql_query = text_tosql(command, column_mapping)  # Step 2: natural language -> SQL
validation_result, validsql = validate_sql(sql_query=sql_query)  # Step 3: validate
print(validation_result)

if validation_result.startswith("VALID QUERY"):  # Step 4: execute
    # sqlite_execute prints the rows and returns the still-open connection; close it
    # so that re-running this cell does not leak connections.
    conn, _cursor = sqlite_execute(db_name=DB_FILE, sql_query=validsql)
    conn.close()

VALID QUERY --> SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC
('Brief Encounter',)
('Children of Paradise',)
('Rome, Open City',)
('Scarlet Street',)
('The Lost Weekend',)
('Spellbound',)
('Detour',)
('Mildred Pierce',)
('I Know Where I’m Going!',)
('Leave Her to Heaven',)
('Les dames du Bois de Boulogne',)
('Dead of Night',)
('The Southerner',)
('Le vampire',)
('A Tree Grows in Brooklyn',)
('The Clock',)
('The Picture of Dorian Gray',)
('The Body Snatcher',)
('They Were Expendable',)
("The Men Who Tread on the Tiger's Tail",)
('And Then There Were None',)
('Fallen Angel',)
('Isle of the Dead',)
('Objective, Burma!',)
('Odor-Able Kitty',)
("Brewster's Millions",)
('Nazi Concentration Camps',)
('Blithe Spirit',)
('Mouse in Manhattan',)
('A Diary for Timothy',)
('My Name Is Julia Ross',)
('Paris Frills',)
('Anchors Aweigh',)
('Hare Tonic',)
('Two People',)
('Sanshiro Sugata, Part Two',)
('The Seventh Veil',)
('Christmas in Connecticut',)
('