OpenAI assistant with access to a SQL database

In [8]:
def read_SQL(file:str) -> list:
    """Read a SQL script from a file and split it into individual statements.

    Usage: read_SQL("filename").

    The text is split on every ";" character, so the script must not contain
    semicolons inside string literals or comments.

    Args:
        file: Path of the SQL script to read.

    Returns:
        List of statements in file order, each without its terminating ";"
        and without surrounding whitespace. Blank fragments (such as the text
        after the final ";") are omitted, so an empty file gives [].

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(file, "r") as f:
        text = f.read()
    # Drop whitespace-only fragments so callers never receive an empty statement.
    stripped = (fragment.strip() for fragment in text.split(sep=";"))
    return [statement for statement in stripped if statement]

In [9]:
"""connect or create SQLite db then create a cursor object"""
import sqlite3
import os
conn = sqlite3.connect('Products.db')
cursor = conn.cursor()

# Create tables

for item in read_SQL("create_tables.SQL"):
    if len(item) > 1:
        cursor.execute(item)

# insert test data
cursor.execute("INSERT INTO products VALUES ('1','AWG812','1','Basic range washing machine with 8kg payload', '14', '777')")
cursor.execute("INSERT INTO products VALUES ('2','AWH914','1','Top range washing machine with 9kg payload', '12', '999')")
cursor.execute("INSERT INTO products VALUES ('3','AWZ8CD','2','Condensing dryer with 8kg payload', '8', '888')")
cursor.execute("INSERT INTO products VALUES ('4','AWZ8CD','2','Heat pump dryer with 9kg payload', '8', '977')")
cursor.execute("INSERT INTO products VALUES ('5','AWZ8CD','3','Steam ironer 2500 W', '8', '222')")
cursor.execute("INSERT INTO products VALUES ('6','AWZ8CD','3','Flatwork ironer 30kg/h', '8', '2599')")

cursor.execute("INSERT INTO product_groups VALUES ('1','W/M','Washing machines')")
cursor.execute("INSERT INTO product_groups VALUES ('2','DRY','Dryers')")
cursor.execute("INSERT INTO product_groups VALUES ('3','IRN','Ironers')")

# commit and close connection
conn.commit()
conn.close()

In [10]:
# Check seeded data: print every row of both tables.
import sqlite3

conn = sqlite3.connect('products.db')
try:
    for select_all in ("SELECT * FROM products", "SELECT * FROM product_groups"):
        cursor = conn.execute(select_all)
        results = cursor.fetchall()
        for item in results:
            print(item)
finally:
    # Nothing was written, so no commit is needed; always release the
    # connection, even if one of the SELECTs fails.
    conn.close()

(1, 'AWG812', 1, 'Basic range washing machine with 8kg payload', 14, 777.0)
(2, 'AWH914', 1, 'Top range washing machine with 9kg payload', 12, 999.0)
(3, 'AWZ8CD', 2, 'Condensing dryer with 8kg payload', 8, 888.0)
(4, 'AWZ8CD', 2, 'Heat pump dryer with 9kg payload', 8, 977.0)
(5, 'AWZ8CD', 3, 'Steam ironer 2500 W', 8, 222.0)
(6, 'AWZ8CD', 3, 'Flatwork ironer 30kg/h', 8, 2599.0)
(1, 'W/M', 'Washing machines')
(2, 'DRY', 'Dryers')
(3, 'IRN', 'Ironers')


In [15]:
# Natural-language question; it is sent to the model as the "user" message
# and echoed back when the generated SQL is run.
user_query = "What is the model of the cheapest product from washing machines? "

In [16]:
"""
It will uses Azure AI, so do not forget implement Azure OpenAI deployment first
 and create .env file with credentials. .env file format:
AZURE_OPENAI_API_KEY="......"
AZURE_OPENAI_ENDPOINT="......."
AZURE_OPENAI_API_VERSION="......"
AZURE_OPENAI_DEPLOYMENT_NAME="........"
"""

import os
from dotenv import load_dotenv
from openai import AzureOpenAI

# Load credentials as environment variables from .env file
load_dotenv()

# Preparing prompt. Show to AI database tables structure from create_tables.SQL file
system_msg = "Given the following SQL tables, your job is to write queries given a user’s request. Return SQL query only. No other words. No explanations.\n"
for item in read_SQL("create_tables.SQL"):
    if len(item) > 1:
        system_msg = system_msg + "\n" + item

# Connect to AzureOpenAI
client = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),  
    api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT"),
    )
deployment_name=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")

# Send a completion call to generate an answer
response = client.chat.completions.create(
    model=deployment_name,
    #prompt=(prompt_pre + db_description + prompt_post + question),
    messages=[
        {
            "role": "system",
            "content":system_msg
        },
        {
            "role": "user",
            "content": user_query
        }
    ],
    max_tokens=64,
    temperature = 0)
generated_sql_query = str(response.choices[0].message.content).strip('`')
# Print generated SQL query
print(generated_sql_query)


SELECT model_code FROM products 
WHERE product_group_id = (SELECT id FROM product_groups WHERE group_code = 'W/M') 
ORDER BY price ASC LIMIT 1;


In [17]:
# Query database with generated SQL string and print the answer
print(f"Your query: {user_query}")
print("Answer:")
# The SQL text was written by a language model, so open the database
# read-only: a generated statement that tries to modify data raises an error
# instead of changing the tables. A missing database file is also reported
# rather than silently created.
conn = sqlite3.connect('file:products.db?mode=ro', uri=True)
try:
    cursor = conn.execute(generated_sql_query)
    results = cursor.fetchall()
    for item in results:
        print(item)
finally:
    # Nothing is written, so there is nothing to commit; always release the file.
    conn.close()

Your query: What is the model of the cheapest product from washing machines? 
Answer:
('AWG812',)
