In [1]:
import argparse
from openai import AzureOpenAI
import openai
import os

# Anchor relative paths (prompt/metadata/questions files) to this file's folder.
# The previous code passed the *string* "__file__" to abspath, which always
# resolved against the current working directory, so the chdir was a no-op.
try:
    # Defined when this code runs as an exported .py script.
    current_directory = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # Inside a notebook kernel `__file__` is undefined; keep the kernel's
    # working directory (presumably the notebook's folder — confirm for your setup).
    current_directory = os.getcwd()
os.chdir(current_directory)

In [2]:
# Azure OpenAI client. Credentials come from the environment instead of being
# hard-coded, so secrets never end up in the saved notebook or in version control.
# Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY before running this cell;
# a missing variable raises KeyError naming it.
client = AzureOpenAI(
    api_version="2023-05-15",
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
)

In [None]:
# Input/output file names — swap these for each assessment and SQL database type.
prompt_file = "./../prompt.md"        # prompt template to fill in
metadata_file = "./../metadata.sql"   # database metadata (table definitions) to embed in the prompt
output_queries = ""                   # output file for the AI-generated SQL queries (must be set before running)

# Example for the first PostgreSQL assessment:
#   prompt_file = "./postgres_assessment/post_prompt_baseline.md"
#   metadata_file = "./postgres_assessment/post_metadata_baseline.sql"
#   output_queries = "./postgres_assessment/postgres_1_queries.sql"

In [3]:
def generate_prompt(question, prompt_file, metadata_file):
    """
    Generates a customized prompt by filling the placeholders of a template file.

    The template is rendered with ``str.format`` using the named placeholders
    ``{user_question}`` and ``{table_metadata_string}``. Any literal braces in the
    template must therefore be doubled (``{{`` / ``}}``). The inserted values
    themselves are not re-parsed, so braces inside the question or metadata are safe.

    Args:
        question (str): The user's question to be included in the prompt.
        prompt_file (str): Path to a UTF-8 file containing the template prompt.
        metadata_file (str): Path to a UTF-8 file containing metadata (e.g., database table information).

    Returns:
        str: The generated prompt with placeholders replaced by actual data.

    Raises:
        KeyError: If the template uses a named placeholder other than the two above.
        IndexError: If the template contains a positional placeholder such as ``{}``.
        ValueError: If the template contains unbalanced single braces.
    """
    # Explicit UTF-8 so the result does not depend on the platform's default
    # encoding (e.g. cp1252 on Windows), which would garble non-ASCII schema text.
    with open(prompt_file, "r", encoding="utf-8") as f:
        template = f.read()

    with open(metadata_file, "r", encoding="utf-8") as f:
        table_metadata_string = f.read()

    return template.format(
        user_question=question, table_metadata_string=table_metadata_string
    )

In [4]:
def SQL_query_generation(question, prompt_file, metadata_file, model="chat", temp=None):
    """
    Generates an SQL query response by sending a templated prompt to the chat model.

    The filled-in prompt (built by ``generate_prompt``) must contain a
    ``### Input:`` marker followed later by a ``### Response:`` marker. It is split
    into three chat messages:

    * system    - everything before the first ``### Input:``
    * user      - everything between that marker and the next ``### Response:``
    * assistant - everything after that ``### Response:``

    Args:
        question (str): The user's question or query.
        prompt_file (str): Path to the prompt template file.
        metadata_file (str): Path to a file containing relevant metadata (e.g., table information).
        model (str): Azure OpenAI deployment name to call. Defaults to "chat".
        temp (float | None): Temperature controlling response randomness. When None
            (the default) no temperature is sent and the service default applies.

    Returns:
        str | None: The raw message content from the model. It may include markdown
        code fences; it is None if the service returned no content.

    Raises:
        ValueError: If the prompt lacks the ``### Input:`` / ``### Response:`` markers in that order.
    """
    prompt = generate_prompt(question, prompt_file, metadata_file)

    # partition() never raises, so an explicit check replaces the old bare
    # `except:`, which also swallowed unrelated errors such as KeyboardInterrupt.
    sys_prompt, input_marker, remainder = prompt.partition("### Input:")
    user_prompt, response_marker, assistant_prompt = remainder.partition("### Response:")
    if not input_marker or not response_marker:
        raise ValueError(
            f"Invalid prompt file {prompt_file!r}: it must contain a '### Input:' "
            "section followed by a '### Response:' section."
        )

    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": assistant_prompt},
    ]

    # Only send a temperature when one was requested, so the default call is
    # identical to the original request.
    request_options = {}
    if temp is not None:
        request_options["temperature"] = temp

    response = client.chat.completions.create(
        model=model,
        messages=messages,
        **request_options,
    )

    return response.choices[0].message.content

In [5]:
def _extract_sql(response_text):
    """Return the SQL statement from a model response, terminated with ';'.

    If the response has text before the first ``` fence, that text is used, as
    before. If the response instead starts with a fenced block (e.g.
    "```sql\\nSELECT ...\\n```"), the content of that first block is used, minus a
    language tag on the fence line. If nothing usable is found, a bare ';' is
    returned so every question still gets exactly one statement in the output.
    """
    parts = (response_text or "").split("```")
    sql = parts[0].strip()
    if not sql and len(parts) > 1:
        fenced = parts[1]
        tag, newline, rest = fenced.partition("\n")
        if newline and tag.strip().isalnum():  # drop a language tag like "sql"
            fenced = rest
        sql = fenced.strip()
    return sql if sql.endswith(";") else f"{sql};"


# Fail early with a clear message instead of open('') raising FileNotFoundError.
if not output_queries:
    raise ValueError("output_queries is empty: set the output file name in the configuration cell.")

with open('../questions.txt', 'r', encoding="utf-8") as f:
    # splitlines() also handles Windows line endings. Blank lines (such as the
    # trailing newline at end of file) are skipped instead of being sent as empty
    # questions. Note: blank lines in the middle of the file no longer take a number.
    questions = [line.strip() for line in f.read().splitlines() if line.strip()]

with open(output_queries, "w", encoding="utf-8") as file:
    for i, question in enumerate(questions, start=1):
        file.write(f"--{i} {question}\n")
        response_text = SQL_query_generation(question, prompt_file, metadata_file)
        file.write(f"{_extract_sql(response_text)}\n\n")