In [14]:

import openai
from openai import OpenAI
import os
from dotenv import load_dotenv
import json
import google.generativeai as genai
import os
import time
import sqlite3
import logging
import mysql.connector
from collections import Counter
import math
from sentence_transformers import SentenceTransformer, util

In [None]:

# Load variables from the local .env file, then build the OpenAI client from
# CHATGPT_API_KEY (os.getenv returns None when the variable is not set).
load_dotenv('.env')
api_key_gpt= os.getenv('CHATGPT_API_KEY')
client = OpenAI(
  api_key=api_key_gpt,
)

def create_batch_file(prompts, max_tokens=50, input_file="input_prompt_file.jsonl", model="gpt-4o-mini"):
    """Write one chat-completion batch request per prompt to a JSONL file.

    Each line of the output file is a JSON object with a ``custom_id`` of the
    form ``request-<n>`` (1-based), the ``POST /v1/chat/completions`` target
    and a body holding the model, a fixed system message, the prompt as the
    user message, and ``max_tokens``.

    Args:
        prompts: Iterable of prompt strings.
        max_tokens: Value written to every request body's ``max_tokens``.
        input_file: Path of the JSONL file to (over)write.
        model: Model name written to every request body.

    Returns:
        The path that was written (``input_file``).
    """
    with open(input_file, "w") as file:
        for request_number, prompt in enumerate(prompts, start=1):
            custom_request = {
                "custom_id": f"request-{request_number}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": model,
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": prompt}
                    ],
                    "max_tokens": max_tokens
                }
            }
            file.write(json.dumps(custom_request))
            file.write("\n")
    print("Batch input file created successfully.")
    return input_file

def process_batch_file(input_file, poll_interval=10):
    """Upload a batch input file, run it as a batch job and return the output text.

    The function blocks, polling the batch every ``poll_interval`` seconds
    until it reaches a terminal status.

    Args:
        input_file: Path of the JSONL request file to upload.
        poll_interval: Seconds to sleep between status checks.

    Returns:
        The decoded (UTF-8) content of the batch output file, or ``None`` when
        the batch did not complete, produced no output file, or an OpenAI
        error was raised while polling / downloading.
    """
    # Close the upload handle as soon as the file has been sent.
    with open(input_file, "rb") as upload_handle:
        input_file_id = client.files.create(
            file=upload_handle,
            purpose="batch"
        ).id

    batch = client.batches.create(
        input_file_id=input_file_id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={
            "description": "nightly eval job"
        }
    )

    batch_id = batch.id  # Capture the batch ID

    # "cancelled" is terminal too; without it a cancelled batch would be polled forever.
    terminal_statuses = ("completed", "failed", "expired", "cancelled")

    try:
        while True:
            batch_status = client.batches.retrieve(batch_id)
            print(f"Batch file processing status: {batch_status.status}")

            if batch_status.status in terminal_statuses:
                break
            time.sleep(poll_interval)  # Wait before checking again

        if batch_status.status != "completed":
            print(f"Batch file processing ended with status: {batch_status.status}")
            return None

        print("Batch file processing completed.")

        # The Batch object exposes the results through `output_file_id`;
        # it can be unset when no request produced output.
        if batch_status.output_file_id is None:
            print("Batch completed but no output file is available.")
            return None

        # Retrieve the file content
        file_response = client.files.content(batch_status.output_file_id)
        return file_response.read().decode('utf-8')
    except openai.OpenAIError as e:
        # openai>=1.0 exposes the base error class at the package top level.
        print(f"An error occurred while retrieving the batch status: {e}")
        return None

# Example usage: build a two-prompt batch file, submit it, and pretty-print
# every result record. process_batch_file blocks while it polls the batch,
# so running this cell performs real API calls.
prompts = ["What is the capital of France?", "What is the capital of Italy?"]
create_batch_file(prompts)
response_content = process_batch_file("input_prompt_file.jsonl")
if response_content:
    # The returned text holds one JSON document per line.
    responses = response_content.splitlines()
    for response in responses:
        response_json = json.loads(response)
        print(json.dumps(response_json, indent=2))

In [2]:
# Set up the LLM backends (gpt, gemini, llama) used by printResponse.

load_dotenv()

# gpt: OpenAI client built from CHATGPT_API_KEY
api_key_gpt=os.getenv('CHATGPT_API_KEY')
client = OpenAI(
  api_key=api_key_gpt,
)

# gemini: GOOGLE_API_KEY is read with os.environ[...], so a missing key raises KeyError here
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
modelGemini=genai.GenerativeModel('gemini-1.5-pro-latest')

# llama: OpenAI-compatible client pointed at the llama-api.com base URL.
# NOTE(review): this rebinds `client`, replacing the gpt client created above;
# after this cell every use of `client` goes to the llama endpoint. Confirm
# this is intended, or give the two clients distinct names.
api_key_llama=os.getenv("LLAMA_API_KEY")
client = OpenAI(
    api_key = api_key_llama,
    base_url = "https://api.llama-api.com"
)

In [3]:
def printResponse(prompt,llm):
    """Send ``prompt`` to the selected LLM backend and return the reply text.

    Args:
        prompt: The full prompt string.
        llm: Backend selector: 'gpt', 'gemini' or 'llama'.

    Returns:
        The text of the model's reply.

    Raises:
        ValueError: If ``llm`` is not one of the supported selectors. (Before,
            an unknown selector silently returned ``None``.)
    """
    # NOTE(review): the 'gpt' and 'llama' branches both use the module-level
    # `client`; verify it is bound to the endpoint matching the chosen backend.
    if llm=='gpt':
        chat_completion = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ]
        )
        return chat_completion.choices[0].message.content
    if llm=='gemini':
        time.sleep(18)      # gemini response limit mitigation
        gemini_reply = modelGemini.generate_content(prompt)
        return gemini_reply.text
    if llm=='llama':
        llama_reply = client.chat.completions.create(
            model="llama-13b-chat",
            messages=[
                {"role": "system", "content": "Assistant is a large language model trained by OpenAI."},
                {"role": "user", "content": prompt}
            ]
        )
        return llama_reply.choices[0].message.content
    raise ValueError(f"Unknown llm {llm!r}; expected 'gpt', 'gemini' or 'llama'.")

In [4]:

# Reload the environment and rebind the module-level `client` to an OpenAI
# client built from CHATGPT_API_KEY (no custom base_url).
load_dotenv()

#gpt
api_key_gpt=os.getenv('CHATGPT_API_KEY')
client = OpenAI(
  api_key=api_key_gpt,
)

def printResponse(prompt,llm):
    """Send ``prompt`` to the selected LLM backend and return the reply text.

    This definition shadows the earlier ``printResponse`` in the notebook; the
    'gpt' branch here targets the ``gpt-4o-mini`` model.

    Args:
        prompt: The full prompt string.
        llm: Backend selector: 'gpt', 'gemini' or 'llama'.

    Returns:
        The text of the model's reply.

    Raises:
        ValueError: If ``llm`` is not one of the supported selectors. (Before,
            an unknown selector silently returned ``None``.)
    """
    # NOTE(review): the 'gpt' and 'llama' branches both use the module-level
    # `client`; verify it is bound to the endpoint matching the chosen backend.
    if llm=='gpt':
        chat_completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ]
        )
        return chat_completion.choices[0].message.content
    if llm=='gemini':
        time.sleep(18)      # gemini response limit mitigation
        gemini_reply = modelGemini.generate_content(prompt)
        return gemini_reply.text
    if llm=='llama':
        llama_reply = client.chat.completions.create(
            model="llama-13b-chat",
            messages=[
                {"role": "system", "content": "Assistant is a large language model trained by OpenAI."},
                {"role": "user", "content": prompt}
            ]
        )
        return llama_reply.choices[0].message.content
    raise ValueError(f"Unknown llm {llm!r}; expected 'gpt', 'gemini' or 'llama'.")
    
def batch_response(batch_input_file):
    """Create a batch job for an uploaded file and print its output if ready.

    The status is checked once, immediately after creation; the function does
    not wait for the batch to finish.

    Args:
        batch_input_file: Uploaded file object; only its ``id`` attribute is read.

    Returns:
        None. The output file text (or a "not yet completed" notice) is printed.
    """
    batch = client.batches.create(
        input_file_id=batch_input_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={
        "description": "nightly eval job"
        }
    )

    # Look up the batch that was just created rather than a placeholder id.
    batch_status = client.batches.retrieve(batch.id)

    # The Batch object is read through attributes, not subscripting.
    if batch_status.status != 'completed':
        print("Batch is not yet completed. Please try again later.")
        return

    # A completed batch may still have no output file.
    if batch_status.output_file_id is None:
        print("Batch completed but no output file is available.")
        return

    # Retrieve the file content
    file_response = client.files.content(batch_status.output_file_id)
    print(file_response.text)

    



In [5]:
def getDbSchemaMapping(dbFolderPath):
    """Load the JSON schema stored in each database sub-folder.

    Every directory directly under ``dbFolderPath`` is scanned for ``*.json``
    files and the parsed content is stored under the directory name. When a
    directory holds several JSON files, the one listed last by ``os.listdir``
    wins. Directories without any JSON file are printed, and the total number
    of entries found in ``dbFolderPath`` is printed at the end.

    Args:
        dbFolderPath: Folder containing one sub-folder per database.

    Returns:
        dict mapping sub-folder name -> parsed JSON schema.
    """
    count=0
    schema_array = {}
    for folder in os.listdir(dbFolderPath):
        count+=1
        # Resolve against the folder that was listed, not a hard-coded path.
        folder_path = os.path.join(dbFolderPath, folder)
        # Plain files (e.g. stray OS metadata) cannot be listed; skip them.
        if not os.path.isdir(folder_path):
            continue
        found_schema = False
        for file_name in os.listdir(folder_path):
            if file_name.endswith('.json'):
                found_schema = True
                json_file_path = os.path.join(folder_path, file_name)
                with open(json_file_path, 'r', encoding='utf-8') as file:
                    schema_array[folder] = json.load(file)
        if not found_schema:
            print(folder_path)
    print(count)
    return schema_array

In [6]:
# Load the selected-tables JSON into `nlqex_data`. It is indexed elsewhere in
# this notebook as nlqex_data[db_id][str(question_index)].
nqlex_data_path="selectedTables/selected_tables.json"
with open(nqlex_data_path,"r") as f:
        nlqex_data=json.load(f)

In [None]:
# Spot check: show the entry stored for question '0' of the 'activity_1' database.
nlqex_data['activity_1']['0']

In [8]:
def get_data(tableJsonPath,cleanDataPath,nlqex_tables=None):
    """Group the question/query records of a JSON-lines file by database.

    Args:
        tableJsonPath: JSON file holding a list of objects with a ``db_id``
            and a ``table_names_original`` list.
        cleanDataPath: JSON-lines file; each line is an object with
            ``db_id``, ``query``, ``question`` and ``query_toks``.
        nlqex_tables: Mapping ``db_id -> {str(position): tables}``. Records
            whose ``db_id`` is not a key are skipped. Defaults to the
            module-level ``nlqex_data``.

    Returns:
        dict mapping ``db_id`` to parallel lists under the keys ``query``,
        ``question``, ``query_toks``, ``tables`` (lower-cased table names
        found among the query tokens, in first-seen order, without
        duplicates) and ``nlqex_tables`` (the ``nlqex_tables`` entry for the
        record's position within its database).

    Raises:
        KeyError: If ``nlqex_tables`` has no entry for a record's position.
    """
    if nlqex_tables is None:
        nlqex_tables = nlqex_data

    with open(tableJsonPath,"r") as f:
        tables_data=json.load(f)

    # Build the lower-cased table-name lookup once instead of once per record.
    table_names_by_db = {}
    for item in tables_data:
        names = table_names_by_db.setdefault(item['db_id'], set())
        names.update(name.lower() for name in item.get("table_names_original", []))

    database = {}

    with open(cleanDataPath, 'r') as f:
        for line in f:
            data = json.loads(line)
            db_id = data.get('db_id')
            query = data.get('query')
            question = data.get('question')
            query_toks=data.get('query_toks')

            if db_id not in nlqex_tables: continue

            known_tables = table_names_by_db.get(db_id, set())
            entry = database.setdefault(
                db_id,
                {'query': [], 'question': [], 'query_toks': [], 'tables': [], 'nlqex_tables': []}
            )

            table_list=[]
            for query_tok in query_toks:
                token = query_tok.lower()
                if token in known_tables and token not in table_list:
                    table_list.append(token)

            # Position of this record within its own database. Deriving it from
            # the stored list keeps it right even when records of different
            # databases are interleaved in the file.
            index = len(entry['query'])

            entry['query'].append(query)
            entry['question'].append(question)
            entry['query_toks'].append(query_toks)
            entry['tables'].append(table_list)
            entry['nlqex_tables'].append(nlqex_tables[db_id][str(index)])

    return database

In [None]:
# Build the per-database schema mapping and the per-database question/query data.
# NOTE(review): 'spider/databasee' (double "e") differs from the 'spider/database'
# path used elsewhere in this notebook; looks like a typo, so confirm which folder is intended.
schema_array = getDbSchemaMapping('spider/databasee')
databases = get_data('spider/tables.json','spider/train_spider_main_data.json')

In [None]:
# Spot check: number of tables detected for the first 'activity_1' question.
print(len(databases['activity_1']['tables'][0]))

In [None]:
# Compare how many databases each of the three structures covers.
print(len(databases))
print(len(nlqex_data))
print(len(schema_array))
# databases

In [25]:
# Root folder of the databases; compare_queries looks for
# <db_folder_path>/<db_name>/<db_name>.sqlite underneath it.
db_folder_path = 'spider/database'

def convert_to_list(results):
    """Normalise result rows so they can be compared regardless of value order.

    Each row becomes a tuple of its values sorted by their string form.
    A falsy ``results`` (``None`` or empty) yields an empty list.
    """
    if not results:
        return []
    return [tuple(sorted(row, key=str)) for row in results]

def execute_query(database_path, query):
    """Run ``query`` against the SQLite database at ``database_path``.

    Args:
        database_path: Path of the SQLite database file.
        query: SQL text to execute.

    Returns:
        The list of rows from ``fetchall()``, or ``None`` when sqlite3 rejects
        the query. In that case the error is printed and appended to
        ``./incorrectGeminiLog.txt``.
    """
    conn = None
    try:
        conn = sqlite3.connect(database_path)
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchall()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # Model-written SQL can fail in more ways than OperationalError
        # (for example, several statements in one string), and none of them
        # should abort the whole evaluation run.
        print(f"An error occurred: {e}")
        # NOTE(review): other cells log to 'logs/incorrectGeminiLog.txt'; confirm
        # whether this separate file in the working directory is intended.
        with open('./incorrectGeminiLog.txt', 'a') as f:
            f.write(f"\nError that occured:\n{e}\n\n")
        return None
    finally:
        # Always release the connection, including when the query fails.
        if conn is not None:
            conn.close()

def compare_queries(db_name, db_folder_path, generated_query, original_query):
    """Check whether two queries return the same rows on one database.

    Rows are compared as multisets after ``convert_to_list`` normalisation,
    so neither row order nor the order of values inside a row matters.

    Args:
        db_name: Database name; the file is
            ``<db_folder_path>/<db_name>/<db_name>.sqlite``.
        db_folder_path: Root folder of the databases.
        generated_query: SQL produced by the model.
        original_query: Reference SQL.

    Returns:
        True when both queries ran and returned the same rows. False when the
        database file is missing, the generated query is empty, either query
        failed to execute, or the rows differ.
    """
    db_path = os.path.join(db_folder_path, db_name, db_name + '.sqlite')
    if not os.path.exists(db_path):
        print(f"Database {db_name} not found.")
        return False

    # No extracted SQL can never count as a correct answer.
    if not generated_query or not generated_query.strip():
        return False

    generated_results = execute_query(db_path, generated_query)
    original_results = execute_query(db_path, original_query)

    # execute_query returns None on failure. Without this guard a failed
    # query collapses to [] and would "match" any reference with no rows.
    if generated_results is None or original_results is None:
        return False

    gen_list = convert_to_list(generated_results)
    orig_list = convert_to_list(original_results)

    print(gen_list)
    print(orig_list)

    return Counter(gen_list) == Counter(orig_list)

def extract_sql_query(response):
    """Return the SQL inside the last fenced code block of an LLM response.

    Blocks are delimited by triple backticks. The last block whose opening
    fence is tagged ``sql`` or ``sqlite`` (any letter case) is preferred; if
    there is none, the last block with no language tag is used. A final block
    that was never closed is still considered.

    Args:
        response: Full text returned by the model (``None`` or empty allowed).

    Returns:
        The stripped query text, or "" when no suitable block exists.
    """
    if not response:
        return ""

    # Splitting on the fence leaves fenced content at the odd positions.
    fenced_blocks = response.split("```")[1::2]

    untagged_fallback = None
    for block in reversed(fenced_blocks):
        # Test "sqlite" before "sql" so the longer tag is removed whole.
        if block[:6].lower() == "sqlite":
            return block[6:].strip()
        if block[:3].lower() == "sql":
            return block[3:].strip()
        if untagged_fallback is None:
            info_line, newline, body = block.partition("\n")
            if newline and not info_line.strip():
                untagged_fallback = body.strip()

    return untagged_fallback or ""

In [19]:
def cosine_similarity(v1, v2):
    """Cosine similarity of two equal-length numeric vectors.

    The norms use the sum of squares, so the result is correct for any
    numeric vectors; for 0/1 vectors it equals the previous
    ``sqrt(sum(v))`` formulation.

    Args:
        v1: First vector (sequence of numbers).
        v2: Second vector, indexed over ``range(len(v1))``.

    Returns:
        The cosine of the angle between the vectors, or -1 when either
        vector is all zeros.
    """
    squared_norm_1 = sum(x*x for x in v1)
    squared_norm_2 = sum(x*x for x in v2)
    if squared_norm_1==0 or squared_norm_2==0:
        return -1
    dot_product = sum(v1[i]*v2[i] for i in range(len(v1)))
    # Divide in two steps, as before, so 0/1 vectors give bit-identical scores.
    return dot_product/math.sqrt(squared_norm_1)/math.sqrt(squared_norm_2)

In [None]:
# Collect [number of tables, number of questions] for every database that has
# both a non-empty schema and at least one question, then print the pairs
# ordered by table count.
db_info=[]

for dbName, dbSchema in schema_array.items():
    if dbName not in databases or 'tables' not in dbSchema:
        continue
    num_questions=len(databases[dbName]['question'])
    num_tables=len(dbSchema['tables'])
    if num_tables==0 or num_questions==0:
        continue
    db_info.append([num_tables,num_questions])

    # optional filter for heavy databases (disabled):
    # if num_tables<8 or num_questions<25: continue

sorted_list=sorted(db_info, key=lambda pair: pair[0])
for i in sorted_list:
    print(i)

In [31]:
# Initialise the cross-prompt variables (DANGER: DO NOT PRESS).
# Re-running this cell zeroes the counters the evaluation loop uses to resume,
# and truncates every log file it writes.
totalQueries=0
correctAns=0
notCorrectAns=0
accuracy_info={}

# Make sure the log folder exists so a fresh checkout can run this cell.
os.makedirs('logs', exist_ok=True)
for file_path in ('logs/incorrectGeminiLog.txt', 'logs/geminiLog.txt'):
    with open(file_path, 'w') as file:
        pass

db_folder_path='spider/database'

# Truncate the per-database error log of every database that will be evaluated.
for dbName, dbSchema in schema_array.items():

    if dbName not in databases: continue
    if 'tables' not in dbSchema: continue
    num_questions=len(databases[dbName]['question'])
    num_tables=len(dbSchema['tables'])
    if num_questions==0: continue
    if num_tables==0: continue

    errorLog=os.path.join(db_folder_path, dbName,'errors.txt')

    with open(errorLog, 'w') as f:
        pass

In [None]:
# Summarise accuracy_info: per-database counts, plus accuracy lists for "light"
# databases (at most 4 tables) and "heavy" ones (more than 4 tables).
# print(accuracy_info)
totIncor=0
totLight=0
totHeavy=0
lightAcc=[]
heavyAcc=[]
avgl=0
avgh=0
for dbName, dbSchema in schema_array.items():
    if dbName not in databases: continue
    if 'tables' not in dbSchema: continue
    num_questions=len(databases[dbName]['question'])
    num_tables=len(dbSchema['tables'])
    if num_questions==0: continue
    if num_tables==0: continue

    if dbName not in accuracy_info: continue

    correct, incorrect = accuracy_info[dbName]
    print(f"dbName: {dbName}, Tables: {num_tables}, Correct: {correct}, Incorrect: {incorrect}")
    totIncor+=incorrect

    # An entry can still be [0, 0] (initialised but nothing evaluated yet);
    # it has no accuracy, so leave it out instead of dividing by zero.
    attempted=correct+incorrect
    if attempted==0: continue

    accuracy=correct/attempted
    if (num_tables<=4):
        totLight+=1
        lightAcc.append(accuracy)
        avgl+=accuracy
    else:
        totHeavy+=1
        heavyAcc.append(accuracy)
        avgh+=accuracy

print(totIncor)
print(totLight)
print(totHeavy)
print(lightAcc)
print(heavyAcc)
# print(avgl/len(lightAcc))
# print(avgh/len(heavyAcc))


In [None]:
# Main evaluation loop: for every question of every usable database, build a
# prompt from the schema of the tables used by the reference query plus
# dynamically chosen example shots, ask the LLM, extract the SQL from its reply
# and compare its result rows with those of the reference query.
# Relies on state from earlier cells: schema_array, databases, accuracy_info,
# totalQueries, correctAns, notCorrectAns, db_folder_path and the helper functions.
model_used='gpt'
newInstTot=0
final_accuracy = {}

for dbName, dbSchema in schema_array.items():
    # Per-database tallies; they only count questions evaluated in this run.
    correct_single_table = 0
    correct_multi_table = 0
    single_table = 0
    multi_table = 0

    if dbName not in databases: continue
    if 'tables' not in dbSchema: continue
    num_questions=len(databases[dbName]['question'])
    num_tables=len(dbSchema['tables'])
    if num_questions==0: continue
    if num_tables==0: continue

    # #filtering out heavy databases
    # if num_tables<8 or num_questions<25: continue

    print(f"Database: {dbName}")

    # initialise the entry of the database in accuracy_info as [correct, incorrect]
    if dbName not in accuracy_info:
        accuracy_info[dbName]=[0,0]

    # Lower-cased table names; a table's position is its component in the similarity vectors.
    dbTables=[]
    for table in dbSchema["tables"]:
        dbTables.append(table["name"].lower())

    for i in range(len(databases[dbName]['question'])):
        # print(f"{totalQueries}, {newInstTot}")
        
            
        # Resume support: newInstTot counts questions seen in this run, while
        # totalQueries persists across runs of this cell, so questions that
        # were already evaluated are skipped.
        newInstTot+=1
        if (totalQueries>=(newInstTot)):
            continue

        # print(f"{totalQueries}, {newInstTot}")


        tables_used=""
        nlqex_tables_used=""


        # Schema JSON of every table referenced by the reference query

        for tname in databases[dbName]['tables'][i]:
            for tables in dbSchema["tables"]:
                if tables["name"].lower() ==tname.lower() :
                    tables_used += json.dumps(tables) + "\n"



        # Schema JSON of the nlq-extracted tables.
        # NOTE(review): only used by nlqTablesPrompt, which is not part of the active prompt below.

        for tname in databases[dbName]['nlqex_tables'][i]:
            for tables in dbSchema["tables"]:
                if tables["name"].lower() ==tname.lower() :
                    nlqex_tables_used += json.dumps(tables) + "\n"


        #shots chosen dynamically
        
        # 0/1 vector marking which tables the current reference query uses
        queryVector=[0]*len(dbTables)
        for tname in databases[dbName]['tables'][i]:
            if tname.lower() in dbTables:
                queryVector[dbTables.index(tname.lower())]=1
        
        shots=[]


        # Candidate shots are the other questions of this database at least
        # three positions away from the current one.
        # NOTE(review): presumably to keep near-duplicate neighbouring questions
        # out of the examples; confirm.
        for j in range(len(databases[dbName]['question'])):
            if abs(j-i)>2:
                shotVector=[0]*len(dbTables)
                for tname in databases[dbName]['tables'][j]:
                    if tname.lower() in dbTables:
                        shotVector[dbTables.index(tname.lower())]=1
                shots.append([cosine_similarity(queryVector,shotVector),databases[dbName]['question'][j],databases[dbName]['query'][j]])

        # Highest similarity first; equal scores are ordered by question text, then query text.
        shots.sort(reverse=True)

        dynamicPrompt="\n\nHere are some more examples:\n"

        numberOfShots=3
        minSimilarity=0.3
        maxSimilarity=1

        final_shots=[]

        # Keep the first numberOfShots candidates whose similarity lies in
        # [minSimilarity, maxSimilarity]; the header above is sent even when none qualify.
        count=0
        for j in range(len(shots)):
            if count==numberOfShots: break
            if shots[j][0]<minSimilarity: continue
            if shots[j][0]>maxSimilarity: continue
            count+=1
            final_shots.append(shots[j])
            dynamicPrompt+=f"Example {count}:\n\nQuestion:\n{shots[j][1]}\n\nSQL Query:\n{shots[j][2]}\n\n"


        # Print the similarity score of each selected shot
        for j in final_shots:
            print(j[0])

        #prompts used
        # tables_used=json.dumps(schema_array[dbName])
        initialPrompt=f"I will feed you schema of a database, store it and then use it to answer the question I will ask related to the database.\n"
        tablesPrompt = "Following is the schema of tables you can use to write the SQL Query." + "\n" + tables_used + "\n"
        nlqTablesPrompt = "Following are the tables that are highly likely to be used for the given query." + "\n" + nlqex_tables_used + "\n"
        infoPrompt = "we are using sqlite as the tool for running the queries. Keep syntax related to it in mind, and provide code accordingly. Dont forget the semicolons wherever needed."
        columnNamePrompt = "Be careful with the names of columns. Use precise column names, since changing them (for example making uppercase to lowercase) may lead to error. Also enclose the column names in double quotes to avoid any clashes with reserved keywords.\n"
        examplePrompt = '''
        \n\n\n
        Some guidelines to write SQL queries are as follows:

        Avoid using = if possible, in situations like in this query:
            SELECT
            m.Location,
            a.Aircraft
            FROM `match` AS m
            JOIN aircraft AS a
            ON m.Winning_Aircraft = a.Aircraft_ID
            WHERE
            m.Round = (
                SELECT
                `Round`
                FROM `match`
                ORDER BY
                Fastest_Qualifying
                LIMIT 1
            );
        Here since = is used, m.Round is limited to one row which is not correct.
        Also, always check datatypes to ensure that u are understanding the information a column represents correctly. For example, sometimes a column with say name pilot may actually represent its id and not its name.
        To help with this, first state your understanding of what the columns actually represent, then walk through the logic for how you are writing the particular query (a stepwise explanation), and then print the final query.
        
        You can also use joins to extract information. For example:
        Question: What are the last names and ages of the students who are allergic to milk and cat?
        Tables: Student, Has_Allergy
        SQL Query:
            SELECT s.LName, s.Age
            FROM Student s
            JOIN Has_Allergy ha1 ON s.StuID = ha1.StuID
            JOIN Has_Allergy ha2 ON s.StuID = ha2.StuID
            WHERE ha1.Allergy = 'Milk' AND ha2.Allergy = 'Cat';

        Notice how I have joined 3 tables, 2 of which are same. Essentially any table can be joined with any other under suitable constraints.

        Also take care of string type columns. Don't confuse singular with plural, for example, egg with eggs. Use the precise string in queries.

        When using AND and OR operators together, use parentheses to ensure the correct logical grouping of conditions. For example:
        Question: How many female students are allergic to milk or eggs?
        Tables: Student, Has_Allergy
        Incorrect SQL Query:
            SELECT count(*) 
            FROM has_allergy AS T1 
            JOIN Student AS T2 ON T1.StuID = T2.StuID 
            WHERE T2.sex = "F" AND T1.allergy = "Milk" OR T1.allergy = "Eggs";

        Correct SQL Query:
            SELECT count(*) 
            FROM has_allergy AS T1 
            JOIN Student AS T2 ON T1.StuID = T2.StuID 
            WHERE T2.sex = "F" AND (T1.allergy = "Milk" OR T1.allergy = "Eggs");

        Notice the use of parentheses to group the OR conditions properly. Always use parentheses to avoid logical errors when combining AND and OR operators.

        Also, when there is a column name, say X, present in more than one tables, say both T1 and T2, be very specific of the column you are using in the sql query. So, if u want to use the column from T1,
        be sure to use T1.X instead of simply X. This helps avoid the ambiguous column error.  

        For eg. 

        Incorrect SQL Query:
            SELECT "BlockFloor", COUNT(*) AS "NumberOfRooms" 
            FROM "Room" 
            JOIN "Block" ON "Room"."BlockFloor" = "Block"."BlockFloor" AND "Room"."BlockCode" = "Block"."BlockCode" 
            GROUP BY "BlockFloor";         
        Here "BlockFloor" column is ambiguous, thus giving error.

        Correct Query:
            SELECT count(*) ,  T1.blockfloor FROM BLOCK AS T1 JOIN room AS T2 ON T1.blockfloor  =  T2.blockfloor AND T1.blockcode  =  T2.blockcode GROUP BY T1.blockfloor;
        The above query corrected ambiguous error by using "T1.blockfloor" instead of simply blockfloor, which specified that we meant that the specified column belongs to T1.

        If one asks you, say, list people with some specific criteria, then you should try to print names , and print id of the name only if asked specifially.

        If, for example, someone asks list of people's names, then use distinct to make sure that a particular person is listed not more than once. For example:
        Question: Find the names of nurses who are on call.
        Correct Query:
            SELECT DISTINCT T1.name FROM nurse AS T1 JOIN on_call AS T2 ON T1.EmployeeID  =  T2.nurse
        Notice how using distinct here makes sure that a nurse is not included more than once, since we need only names of the nurses and not other information.

        End of guidelines.
        \n\n\n
        '''
        chainOfThoughtPrompt="Understand what each column and table mean. State what you understand about the tabel and their relatons. Also state logical steps in between as to how you are constructing the final SQL query."
        questionPrompt="\n\nHere is the question part, i.e., the query on the database above, explained in english, for which the corresponding SQL query code is needed.: \n\n\n" + databases[dbName]['question'][i] + "\n\n\n" +  "\n\n" + "Provide the SQL query at the end of the response.\n"

        #with nlqTables
        # prompt=initialPrompt+tablesPrompt+nlqTablesPrompt+infoPrompt+columnNamePrompt+examplePrompt+dynamicPrompt+chainOfThoughtPrompt+questionPrompt

        #without nlqTables (this is the prompt actually sent)
        prompt=initialPrompt+tablesPrompt+infoPrompt+columnNamePrompt+examplePrompt+dynamicPrompt+chainOfThoughtPrompt+questionPrompt

        #prompt without shots (built but not sent anywhere in this loop)
        promptNoShots=initialPrompt+tablesPrompt+infoPrompt+columnNamePrompt+examplePrompt+chainOfThoughtPrompt+questionPrompt

        #get response from LLM
        response=printResponse(prompt,model_used)

        #get og and gen queries
        generated_query = extract_sql_query(response)
        original_query = databases[dbName]['query'][i]
        if not original_query.endswith(';'):
            original_query+=';'

        #compare og and gen queries
        isSame=compare_queries(dbName, db_folder_path, generated_query, original_query)
        # A question is "single table" when its reference query uses exactly one table
        is_single_table = False
        if (len(databases[dbName]['tables'][i])==1):
            is_single_table=True
            single_table+=1
        else:
            multi_table+=1
        if(isSame==True):
            if is_single_table==True:
                correct_single_table+=1
            else:
                correct_multi_table+=1

        #logging
        print(f"Q{totalQueries+1}:\n") 
        if isSame==False:
            print("The queries do not match.\n")
            notCorrectAns+=1
            accuracy_info[dbName][1]+=1
            #logging in incorrectGeminiLog.txt (file name is fixed; it does not follow model_used)
            with open('logs/incorrectGeminiLog.txt', 'a') as f:
                f.write(f"Q{totalQueries+1}:\n")
                f.write(f"Prompt Tables:\n{tables_used}\n")
                f.write(f"LLM_Input:\n{prompt}\n")
                # f.write(f"\n\n\ncosine_similarity: {shots[0][0]} {shots[1][0]} {shots[2][0]}\n\n\n")
                f.write(f"LLM_response:\n{response}\n")
                f.write(f"Question:\n\n{databases[dbName]['question'][i]}\n\n")
                f.write(f"Original_query:\n\n{original_query}\n\n")
                f.write(f"generated_query:\n\n{generated_query}\n\n")
                f.write(f"\n\n\n")

            # Same record again, appended to the per-database errors.txt next to the database files
            errorLog=os.path.join(db_folder_path, dbName,'errors.txt')
            with open(errorLog, 'a') as f:
                f.write(f"Q{totalQueries+1}:\n")
                f.write(f"Prompt Tables:\n{tables_used}\n")
                f.write(f"LLM_Input:\n{prompt}\n")
                # f.write(f"\n\n\ncosine_similarity: {shots[0][0]} {shots[1][0]} {shots[2][0]}\n\n\n")
                f.write(f"LLM_response:\n{response}\n")
                f.write(f"Question:\n\n{databases[dbName]['question'][i]}\n\n")
                f.write(f"Original_query:\n\n{original_query}\n\n")
                f.write(f"generated_query:\n\n{generated_query}\n\n")
                f.write(f"\n\n\n")

        else:
            print("The queries match.\n")
            correctAns+=1
            accuracy_info[dbName][0]+=1
        totalQueries+=1

        #logging in geminiLog.txt (every question, with the running totals)
        with open('logs/geminiLog.txt', 'a') as f:
            f.write(f"Q{totalQueries}:\n")
            f.write(f"Prompt Tables:\n{tables_used}\n")
            f.write(f"LLM_Input:\n{prompt}\n")
            # f.write(f"\n\n\ncosine_similarity: {shots[0][0]} {shots[1][0]} {shots[2][0]}\n\n\n")
            f.write(f"LLM_response:\n{response}\n")
            f.write(f"Question:\n\n{databases[dbName]['question'][i]}\n\n")
            f.write(f"Original_query:\n\n{original_query}\n\n")
            f.write(f"generated_query:\n\n{generated_query}\n\n")
            f.write(f"Total Queries: {totalQueries}, Correct Answers: {correctAns}, Incorrect Answers: {notCorrectAns}")
            f.write(f"\n\n\n")
        
    # Store this database's single-/multi-table tallies once all its questions are done
    final_accuracy[dbName] = {
        'correct_single_table': correct_single_table,
        'correct_multi_table': correct_multi_table,
        'single_table': single_table,
        'multi_table': multi_table
    }


In [None]:
# Print the single-/multi-table tallies recorded for each database.
for dbName, values in final_accuracy.items():
    print(f"Database: {dbName}")
    print(f"Tables present: {len(schema_array[dbName]['tables'])}")
    for label, key in (
        ("Correct single table", 'correct_single_table'),
        ("Correct multi table", 'correct_multi_table'),
        ("Single table", 'single_table'),
        ("Multi table", 'multi_table'),
    ):
        print(f"{label}: {values[key]}")
    print("\n\n")

In [None]:
# Spot check: name of the first table in the 'browser_web' schema.
print(schema_array['browser_web']['tables'][0]['name'])

In [None]:
# Counters for the run without golden tables (the "NG" variables).
# The loop that follows increments notCorrectAnsNG, so that is the name that
# must be initialised here; the previous `notCorrectAns=0` left it undefined
# and reset the counter of the golden-table run instead.
totalQueriesNG=0
correctAnsNG=0
notCorrectAnsNG=0

In [None]:
model_used='gemini'
newInstTot=0
for dbName, dbSchema in schema_array.items():
    
    if dbName not in databases:
        continue

    print(dbName)

    for i in range(len(databases[dbName]['question'])):
        # print(f"{totalQueriesNG}, {newInstTot}")
        newInstTot+=1
        if (totalQueriesNG>=(newInstTot)):
            continue

        # print(f"{totalQueriesNG}, {newInstTot}")
        tables_used=""

        # Iterate through dbNames

        for tname in databases[dbName]['tables'][i]:
            for tables in dbSchema["tables"]:
                if tables["name"].lower() ==tname.lower() :
                    tables_used += json.dumps(tables) + "\n"

        

        #prompts used
        tables_used=json.dumps(schema_array[dbName])
        initialPrompt=f"I will feed you schema of a database, store it and then use it to answer the question I will ask related to the database.\n"
        tablesPrompt = "Following is the schema of tables you can use to write the SQL Query." + "\n\n" +tables_used + "\n\n"
        infoPrompt = "we are using sqlite as the tool for running the queries. Keep syntax related to it in mind, and provide code accordingly. Dont forget the semicolons wherever needed."
        columnNamePrompt = "Be careful with the names of columns. Use precise column names, since changing them (for example making uppercase to lowercase) may lead to error. Also enclose the column names in double quotes to avoid any clashes with reserved keywords.\n"
        examplePrompt = '''
        Avoid using = if possible, in situations like in this query:
                            SELECT
                            m.Location,
                            a.Aircraft
                            FROM `match` AS m
                            JOIN aircraft AS a
                            ON m.Winning_Aircraft = a.Aircraft_ID
                            WHERE
                            m.Round = (
                                SELECT
                                `Round`
                                FROM `match`
                                ORDER BY
                                Fastest_Qualifying
                                LIMIT 1
                            );
                            Here since = is used, m.Round is limited to one row which is not correct.
                            Also, always check datatypes to ensure that u are understanding the information a column represents correctly. For example, sometimes a column with say name pilot may actually represent its id and not its name.
                            To help with this, first state your understanding of what the columns actually represent, then walk through the logic for how you are writing the particular query (a stepwise explanation), and then print the final query.
                            
                            You can also use joins to extract information. For example:
                            Question: What are the last names and ages of the students who are allergic to milk and cat?
                            Tables: Student, Has_Allergy
                            SQL Query:
                                SELECT s.LName, s.Age
                                FROM Student s
                                JOIN Has_Allergy ha1 ON s.StuID = ha1.StuID
                                JOIN Has_Allergy ha2 ON s.StuID = ha2.StuID
                                WHERE ha1.Allergy = 'Milk' AND ha2.Allergy = 'Cat';
                            Notice how I have joined 3 tables, 2 of which are same. Essentially any table can be joined with any other under suitable constraints.

        Also take care of string type columns. Don't confuse singular with plural, for example, egg with eggs. Use the precise string in queries.

        When using AND and OR operators together, use parentheses to ensure the correct logical grouping of conditions. For example:
        Question: How many female students are allergic to milk or eggs?
        Tables: Student, Has_Allergy
        Incorrect SQL Query:
            SELECT count(*) 
            FROM has_allergy AS T1 
            JOIN Student AS T2 ON T1.StuID = T2.StuID 
            WHERE T2.sex = "F" AND T1.allergy = "Milk" OR T1.allergy = "Eggs";

        Correct SQL Query:
            SELECT count(*) 
            FROM has_allergy AS T1 
            JOIN Student AS T2 ON T1.StuID = T2.StuID 
            WHERE T2.sex = "F" AND (T1.allergy = "Milk" OR T1.allergy = "Eggs");

        Notice the use of parentheses to group the OR conditions properly. Always use parentheses to avoid logical errors when combining AND and OR operators.

        Also, when there is a column name, say X, present in more than one tables, say both T1 and T2, be very specific of the column you are using in the sql query. So, if u want to use the column from T1,
        be sure to use T1.X instead of simply X. This helps avoid the ambiguous column error.   

        '''
        chainOfThoughtPrompt="Understand what each column and table mean. State what you understand about the tabel and their relatons. Also state logical steps in between as to how you are constructing the final SQL query."
        questionPrompt="\nHere is the question part, i.e., the query on the database above, explained in english, for which the corresponding SQL query code is needed.: \n\n" + databases[dbName]['question'][i] + "\n\n" +  "\n" + "Provide the SQL query at the end of the response.\n"

        prompt=initialPrompt+tablesPrompt+infoPrompt+columnNamePrompt+examplePrompt+chainOfThoughtPrompt+questionPrompt

        #get response from LLM
        response=printResponse(prompt,model_used)

        #get og and gen queries
        generated_query = extract_sql_query(response)
        original_query = databases[dbName]['query'][i]
        if not original_query.endswith(';'):
            original_query+=';'

        #compare og and gen queries
        isSame=compare_queries(dbName, db_folder_path, generated_query, original_query)

        #logging
        print(f"Q{totalQueriesNG+1}:\n") 
        if isSame==False:
            print("The queries do not match.\n")
            notCorrectAnsNG+=1

            #logging in incorrectGeminiLog.txt
            with open('logs/incorrectGeminiLogNoGolden.txt', 'a') as f:
                f.write(f"Q{totalQueriesNG+1}:\n")
                f.write(f"Prompt Tables:\n{tables_used}\n")
                f.write(f"LLM_response:\n{response}\n")
                f.write(f"Question:\n\n{databases[dbName]['question'][i]}\n\n")
                f.write(f"Original_query:\n\n{original_query}\n\n")
                f.write(f"generated_query:\n\n{generated_query}\n\n")
                f.write(f"\n\n\n")

        else:
            print("The queries match.\n")
            correctAnsNG+=1
        totalQueriesNG+=1

        #logging in geminiLog.txt
        with open('logs/geminiLogNoGolden.txt', 'a') as f:
            f.write(f"Q{totalQueriesNG}:\n")
            f.write(f"Prompt Tables:\n{tables_used}\n")
            f.write(f"LLM_response:\n{response}\n")
            f.write(f"Question:\n\n{databases[dbName]['question'][i]}\n\n")
            f.write(f"Original_query:\n\n{original_query}\n\n")
            f.write(f"generated_query:\n\n{generated_query}\n\n")
            f.write(f"Total Queries: {totalQueriesNG}, Correct Answers: {correctAnsNG}, Incorrect Answers: {notCorrectAnsNG}")
            f.write(f"\n\n\n")        


In [None]:
# Table names, presumably those of the hospital_1 database -- TODO confirm against the schema.
tables_list = [
    'physician',
    'department',
    'affiliated_with',
    'procedures',
    'trained_in',
    'patient',
    'nurse',
    'appointment',
    'medication',
    'prescribes',
    'block',
    'room',
    'on_call',
    'stay',
    'undergoes',
]

# Prompt asking for a one-to-two-line description of every table, followed by
# the stringified hospital_1 schema and a blank line.
STprompt = (
    "I will feed you the schema of a database, give me a description of the database using the tabluar information. "
    "Describe each table present in the schema in one or two lines. Here is the schema \n\n"
    f"{schema_array['hospital_1']!s}"
    "\n\n"
)

# Send the prompt with 'gemini' as the second argument and show whatever comes back.
print(printResponse(STprompt, 'gemini'))

In [None]:
# Print every hospital_1 question on its own line, wrapped in double quotes and
# followed by a comma (handy for pasting the output into a list literal).
for q in databases['hospital_1']['question']:
    print(f'"{q}",')

In [None]:
# Prompt requesting a single-paragraph description of each table; the expected
# reply format is spelled out in the final sentence. Adjacent literals are
# concatenated at compile time, so the resulting string is one line ending in "\n".
table_extraction_prompt2 = (
    "Give a detailed description of each table in the database. "
    "I have provided you the schema which contains the column names. "
    "I have also provided sample data, which is essentially the first row of the table, for better understanding. "
    "Use all the information to understand the table deeply, and provid information about it. "
    "The information should include what each column of the table contains, as well. "
    "We will be using the information you provide for extracting queries that use this table. "
    "Also, all information should be in the fomr of a single paragraph. "
    "The response should be # Table table_name table_description ! .\n"
)