In [6]:
import argparse
import glob
import os
import re
import subprocess
import time

import openai
import psycopg2
from func_timeout import FunctionTimedOut, func_timeout
from openai import AzureOpenAI

# Notebooks have no real __file__, so the *string* "__file__" is used on purpose:
# abspath() prefixes it with the current working directory and splitting the
# last component off again leaves exactly that directory, which we then pin.
_notebook_anchor = os.path.abspath("__file__")
current_directory = os.path.split(_notebook_anchor)[0]
os.chdir(current_directory)

In [7]:
# Azure OpenAI client used by SQL_query_generation.
# The endpoint and API key are read from the environment (AZURE_OPENAI_ENDPOINT /
# AZURE_OPENAI_API_KEY) instead of being typed into the notebook, where they would
# be saved with the .ipynb and leak into version control or shared copies.
# NOTE(review): when a variable is unset the constructor receives None; confirm
# the SDK then fails with a clear "missing credentials" error in your version.
client = AzureOpenAI(
    api_version="2023-05-15",
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
)

In [8]:
def generate_prompt(question, prompt_file, metadata_file):
    """Fill the prompt template with a question and the table metadata.

    The template in `prompt_file` is rendered with ``str.format`` using two
    placeholders: ``{user_question}`` and ``{table_metadata_string}`` (the full
    text of `metadata_file`). Literal braces in the template must therefore be
    doubled (``{{`` / ``}}``).

    Args:
        question: natural-language question inserted as ``user_question``.
        prompt_file: path of the prompt template.
        metadata_file: path of the file whose contents become
            ``table_metadata_string``.

    Returns:
        The rendered prompt string.
    """
    with open(prompt_file, "r") as template_handle:
        template = template_handle.read()
    with open(metadata_file, "r") as metadata_handle:
        metadata_text = metadata_handle.read()
    return template.format(user_question=question,
                           table_metadata_string=metadata_text)

def _split_prompt_sections(prompt):
    """Split a rendered prompt into (system, user, assistant) message texts.

    The prompt must contain the markers "### Input:" and "### Response:".
    System text is everything before the first "### Input:"; user text is the
    segment after it, cut at the next "### Response:"; assistant text is the
    segment after the first "### Response:" (up to any second occurrence).

    Raises:
        ValueError: if either marker is missing.
    """
    input_parts = prompt.split("### Input:")
    response_parts = prompt.split("### Response:")
    if len(input_parts) < 2 or len(response_parts) < 2:
        raise ValueError("Invalid prompt file. Please use prompt_openai.md")
    sys_prompt = input_parts[0]
    user_prompt = input_parts[1].split("### Response:")[0]
    assistant_prompt = response_parts[1]
    return sys_prompt, user_prompt, assistant_prompt


def SQL_query_generation(question, prompt_file, metadata_file, model="chat"):
    """Ask the Azure OpenAI chat model to write SQL for `question`.

    The prompt is rendered with generate_prompt() and split into a system,
    user and assistant message. The assistant message goes last, presumably
    as a partial answer the model continues (e.g. an opening code fence).
    TODO confirm against the prompt template. Sampling uses temperature 0.

    Args:
        question: natural-language question to translate into SQL.
        prompt_file: prompt template containing "### Input:" and
            "### Response:" markers.
        metadata_file: table metadata inserted into the template.
        model: value passed as ``model`` to the chat completions call;
            presumably the Azure deployment name. Defaults to "chat".

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        ValueError: if the rendered prompt lacks the required markers.
    """
    prompt = generate_prompt(question, prompt_file, metadata_file)
    sys_prompt, user_prompt, assistant_prompt = _split_prompt_sections(prompt)

    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": assistant_prompt},
    ]

    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0,
    )

    return response.model_dump()["choices"][0]["message"]["content"]

# functions for database query submission

def krbauth(username: str, password: str) -> bool:
    """Obtain a Kerberos ticket for `username` by running ``kinit``.

    The password is written to kinit's standard input.

    Args:
        username: user/principal name passed to kinit.
        password: password for that user, fed to kinit on stdin.

    Returns:
        True when kinit exits with status 0. Failures raise, so a normal
        return is always True.

    Raises:
        subprocess.CalledProcessError: kinit exited with a non-zero status.
            Its output and exit code are printed before the original error
            is re-raised.
        FileNotFoundError: the ``kinit`` executable could not be found.
    """
    cmd = ["kinit", username]
    try:
        completed = subprocess.run(
            cmd, input=password.encode(), check=True, capture_output=True
        )
    except subprocess.CalledProcessError as e:
        print(f"The error output for krbauth: {e.output}")
        print(f"The error code: {e.returncode}")
        print(f"The stderr: {e.stderr}")
        print(f"The stdout: {e.stdout}")
        # Re-raise the original error. Rebuilding one from the locals of the
        # failed call is impossible: `subprocess.run` raised before anything
        # was assigned.
        raise
    # Exit code 0 means success; flip it into a truthy return value.
    return completed.returncode == 0


def read_sql_commands_from_file(file_path):
    """
    Reads [question, query] pairs from an SQL file.

    Expected layout: sections separated by a blank line (whitespace-only lines
    count as blank). Each section starts with a ``--`` comment line holding the
    question, followed by the SQL on one or more lines::

        --1 How many samples are there?
        SELECT count(*)
        FROM samples;

    Leading/trailing dashes and spaces are stripped from the question. Query
    lines are joined with single spaces, so every query ends up on one line.

    A section that does NOT start with ``--`` is treated as the continuation
    of the previous query. This happens when a generated query itself contains
    a blank line. Without it, the tail would be dropped or misread as a new
    question, shifting every later answer. A section that only opens with a
    plain (non-``--``) line is still read the old way (first line = question)
    when it is the first section. Pairs whose query stays empty are dropped.

    Args:
        file_path (str): Path to the SQL file.

    Returns:
        list: List of [question, query] lists, in file order. Empty (after
        printing a message) when the file does not exist.
    """
    try:
        with open(file_path, 'r') as sql_file:
            content = sql_file.read()
    except FileNotFoundError:
        print(f"File '{file_path}' not found.")
        return []

    query_question_pairs = []
    for section in re.split(r'\n\s*\n', content):
        lines = section.strip().split('\n')
        header = lines[0]
        if not header:
            continue  # whitespace-only section (e.g. trailing newlines)
        if header.startswith('--'):
            # New question; its query may still be empty if the SQL began
            # after a blank line. A continuation section then fills it in.
            query_question_pairs.append(
                [header.strip('--').strip(), ' '.join(lines[1:]).strip()]
            )
        elif query_question_pairs:
            # No question header: rest of the previous query.
            previous = query_question_pairs[-1]
            previous[1] = ' '.join([previous[1], *lines]).strip()
        elif len(lines) >= 2:
            # Header-less first section: keep the original reading.
            query_question_pairs.append(
                [header.strip('--').strip(), ' '.join(lines[1:]).strip()]
            )
    return [pair for pair in query_question_pairs if pair[1]]



def execute_query_and_measure_time(query_and_question_pair, output_file_path,
                                   hostname='', database='datalake',
                                   username='', port_id=5434):
    """Run one query against PostgreSQL and append one tab-separated row.

    The appended line has four columns::

        question <TAB> query <TAB> outcome <TAB> seconds

    On success, ``outcome`` is ``str()`` of ``cur.fetchone()``. On failure, it
    is ``"Error executing query: <first line of the error>"``. ``seconds``
    covers ``execute`` + ``fetchone``; it is 0 when the query did not finish.

    Args:
        query_and_question_pair: (question, query) pair, as produced by
            read_sql_commands_from_file.
        output_file_path: file the row is appended to.
        hostname, database, username, port_id: settings passed to
            ``psycopg2.connect``. The defaults are the previously hard-coded
            values (fill in hostname/username).
    """
    # Unpack before connecting so the error path can always name the
    # question/query, even when the connection itself fails.
    question, query = query_and_question_pair
    conn = None
    cur = None
    execution_time = 0.0

    try:
        conn = psycopg2.connect(host=hostname, dbname=database,
                                user=username, port=port_id)
        cur = conn.cursor()
        start_time = time.perf_counter()  # monotonic clock for durations
        cur.execute(query)
        result = cur.fetchone()
        execution_time = time.perf_counter() - start_time
        outcome = str(result)
    except Exception as error:  # includes psycopg2.Error
        # Driver messages can span several lines (DETAIL/HINT); keep only the
        # first so the row stays on a single line of the results file.
        outcome = f"Error executing query: {str(error).split(chr(10))[0]}"
    finally:
        if cur is not None:
            cur.close()
        if conn is not None:
            conn.close()

    with open(output_file_path, 'a') as output_file:
        output_file.write(f"{question}\t{query}\t{outcome}\t{execution_time:.4f}\n")


In [15]:
# Pipeline inputs (benchmark questions, prompt template and table metadata of
# the baseline experiment) and outputs (generated SQL and query results, both
# relative to the working directory).
_ASSESSMENT_DIR = "../GPT_pipeline/postgreSQL_database_assessment"
_BASELINE_DIR = f"{_ASSESSMENT_DIR}/1.baseline_first_attempt"

questions_path = f"{_ASSESSMENT_DIR}/questions.txt"
prompt_file_path = f"{_BASELINE_DIR}/post_prompt_baseline.md"
metadata_file_path = f"{_BASELINE_DIR}/post_metadata_baseline.sql"
AIgenerated_SQL_queries_path = "./postgres_1_queries.sql"
output_results_path = "./postgres_1_results.txt"

In [12]:
# Load the benchmark questions, one per line. Blank lines are skipped,
# including the empty string that split('\n') would yield after a trailing
# newline. This way no empty question is sent to the model and the numbering
# stays 1..N.
with open(questions_path, 'r') as f:
    questions = [line for line in f.read().splitlines() if line.strip()]


def _extract_sql(model_reply):
    """Return the SQL statement in a model reply, always ending with ';'.

    The text before the first ``` fence is used. If that text is empty (the
    reply opened its own ```sql fence), the first non-empty fenced chunk is
    used instead, with a leading "sql" language tag removed. A None reply
    yields just ";", so the file keeps one query per question.
    """
    chunks = (model_reply or "").split("```")
    sql = chunks[0].strip()
    if not sql:
        for chunk in chunks[1:]:
            candidate = re.sub(r"^sql\b", "", chunk.strip(), flags=re.IGNORECASE).strip()
            if candidate:
                sql = candidate
                break
    return sql if sql.endswith(';') else f"{sql};"


# Generate one query per question. Each is saved as a "--<n> <question>" line,
# then the SQL, then a blank line: the layout read_sql_commands_from_file parses.
with open(AIgenerated_SQL_queries_path, "w") as file:
    for i, question in enumerate(questions):
        file.write(f"--{i+1} {question}\n")
        model_reply = SQL_query_generation(question, prompt_file=prompt_file_path,
                                           metadata_file=metadata_file_path)
        file.write(f"{_extract_sql(model_reply)}\n\n")
            
# Wall-clock limit (seconds) for one query; func_timeout aborts the call after it.
QUERY_TIMEOUT_SECONDS = 4000

# execute_query_and_measure_time appends to the results file. Start from an
# empty file so a re-run of this cell does not stack a second run under the
# first; otherwise the comparison further down would read the stale rows.
open(output_results_path, 'w').close()

queries_to_execute = read_sql_commands_from_file(AIgenerated_SQL_queries_path)
counter = 0
for counter, query_and_question_pair in enumerate(queries_to_execute, start=1):
    question, query = query_and_question_pair
    try:
        func_timeout(QUERY_TIMEOUT_SECONDS, execute_query_and_measure_time,
                     args=(query_and_question_pair, output_results_path))
        print(f"{counter}-Search completed successfully")
    except FunctionTimedOut:
        # Use the same 4-column layout as every other row. The results parser
        # keeps column 3 and skips shorter rows, which would shift every later
        # answer against correct_answers.
        with open(output_results_path, 'a') as output_file:
            output_file.write(
                f"{question}\t{query}\tError executing query: Search could not complete "
                f"within {QUERY_TIMEOUT_SECONDS} seconds and was terminated.\t{QUERY_TIMEOUT_SECONDS}\n"
            )
        print(f"{counter}-Search timed out")
    except Exception as e:
        # Handle any exceptions that function might raise here; still record
        # a row so the results stay aligned with the questions.
        print(f"An error occurred: {e}")
        with open(output_results_path, 'a') as output_file:
            first_line = str(e).split("\n")[0]
            output_file.write(f"{question}\t{query}\tError executing query: {first_line}\t0.0\n")

1-Search completed successfully
2-Search completed successfully
3-Search completed successfully
4-Search completed successfully
5-Search completed successfully
6-Search completed successfully
7-Search completed successfully
8-Search completed successfully
9-Search completed successfully
10-Search completed successfully
11-Search completed successfully
12-Search completed successfully
13-Search completed successfully
14-Search completed successfully
15-Search completed successfully
16-Search completed successfully
17-Search completed successfully
18-Search completed successfully
19-Search completed successfully
20-Search completed successfully


In [16]:
results = []

# Rows are tab-separated; keep the third column (index 2) of every row that has
# at least three columns, in file order. Shorter rows are skipped.
with open(output_results_path, "r") as file:
    split_rows = (row.strip().split("\t") for row in file)
    results.extend(columns[2] for columns in split_rows if len(columns) >= 3)

In [17]:
# Expected outcome for each benchmark question, in question order (index 0 is
# question 1). Each entry is the text a correct run leaves in the results
# file's outcome column, i.e. str() of the first row fetched.
correct_answers = [
    '(5.146803216038514,)',                 # Q1
    '(4.749735057239966,)',                 # Q2
    '(171,)',                               # Q3
    '(14389,)',                             # Q4
    '(96, 3)',                              # Q5
    "(Decimal('75.7743072598277632'),)",    # Q6
    '(40.95298165067672,)',                 # Q7
    '(0.0, 0.69456)',                       # Q8
    "('Bacteroides vulgatus', 63250.523)",  # Q9
    '(0.10249751144261383,)',               # Q10
    "('Escherichia coli', 37212.05)",       # Q11
    '(21.65521388303867,)',                 # Q12
    '(1023,)',                              # Q13
    '(1,)',                                 # Q14
    "('Firmicutes', 535922.44)",            # Q15
    '(0.11537600260030743,)',               # Q16
    '(8022, 54)',                           # Q17
    '(2967,)',                              # Q18
    '(0.005528159512230005,)',              # Q19
    "('Bacteroidaceae', 282596.5)",         # Q20
]

def compare_results(results_list, expected_results):
    """Compare results to expected answers position by position.

    Args:
        results_list: actual outcomes, in question order.
        expected_results: expected outcomes, in the same order.

    Returns:
        A list of booleans: ``True`` where the result equals the expected
        value at the same position. Pairs are formed with ``zip``, so the
        list is only as long as the shorter input.
    """
    return [actual == expected
            for actual, expected in zip(results_list, expected_results)]


# Score the run. compare_results pairs answers by position and stops at the
# shorter list, so expected answers with no parsed result are reported and
# counted as incorrect here instead of silently dropping out of the score.
comparison_results = compare_results(results, correct_answers)
total_expected = len(correct_answers)

count = 0
for position, is_correct in enumerate(comparison_results):
    if not is_correct:
        count += 1
        print(position + 1, '  ', is_correct, results[position], '\n')
for position in range(len(comparison_results), total_expected):
    count += 1
    print(position + 1, '  ', False, '<no result recorded>', '\n')
if len(results) > total_expected:
    print(f"Warning: {len(results)} results for {total_expected} expected answers; "
          "the extra rows were not scored.")
print(count, f'/{total_expected} are incorrect!')

3    False Error executing query: column reference "species" is ambiguous 

6    False (Decimal('2055.0000000000000000'), 0.04859567151210226) 

7    False (0.057196901764991016,) 

10    False (None,) 

11    False None 

12    False (None,) 

13    False (0,) 

15    False ('Firmicutes', 535920.9) 

19    False (None,) 

9 /20 are incorrect!
