In [None]:

import openai
from openai import OpenAI
import os
from dotenv import load_dotenv
import json

import google.generativeai as genai
import os
import time
import sqlite3
import logging
import mysql.connector
from collections import Counter
import math

In [2]:
# Setting up the LLM back-ends used by printResponse().

load_dotenv('.env')

# GPT (OpenAI endpoint)
api_key_gpt = os.getenv('CHATGPT_API_KEY')
gpt_client = OpenAI(
    api_key=api_key_gpt,
)

# Gemini
genai.configure(api_key=os.environ["GOOGLE_API_KEY_FLASH"])
modelGemini = genai.GenerativeModel('gemini-1.5-flash')

# Llama (OpenAI-compatible endpoint)
api_key_llama = os.getenv("LLAMA_API_KEY")
llama_client = OpenAI(
    api_key=api_key_llama,
    base_url="https://api.llama-api.com",
)

# Both clients used to be bound to the single name `client`, so the GPT client
# was silently replaced by the Llama one. Each provider now has its own name;
# `client` keeps pointing at the Llama endpoint so code that already relies on
# it behaves exactly as before.
client = llama_client

In [5]:
# OpenAI-compatible clients, created on first use and keyed by the `llm` name.
_CHAT_CLIENTS = {}


def _chat_client(llm):
    """Return the OpenAI-compatible client for ``llm`` ('gpt' or 'llama').

    The notebook-level ``client`` variable is assigned twice during setup (GPT
    first, then Llama), so it always ends up pointing at the Llama endpoint.
    Building one client per provider here keeps 'gpt' requests from being sent
    to the Llama API.
    """
    if llm not in _CHAT_CLIENTS:
        if llm == 'gpt':
            _CHAT_CLIENTS[llm] = OpenAI(api_key=os.getenv('CHATGPT_API_KEY'))
        else:
            _CHAT_CLIENTS[llm] = OpenAI(
                api_key=os.getenv("LLAMA_API_KEY"),
                base_url="https://api.llama-api.com",
            )
    return _CHAT_CLIENTS[llm]


def printResponse(prompt, llm):
    """Send ``prompt`` to the selected model and return the reply text.

    Despite its name the function does not print anything.

    Args:
        prompt: Full prompt text, sent as a single user message.
        llm: One of 'gpt', 'gemini' or 'llama'.

    Returns:
        The text of the model's reply.

    Raises:
        ValueError: If ``llm`` is not one of the supported names (previously
            this case silently returned None).
    """
    if llm == 'gpt':
        chat_completion = _chat_client('gpt').chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model="gpt-3.5-turbo-0125",
        )
        return chat_completion.choices[0].message.content

    if llm == 'gemini':
        # NOTE: no rate-limit mitigation is implemented here; a quota error
        # from the API propagates to the caller.
        response = modelGemini.generate_content(prompt)
        return response.text

    if llm == 'llama':
        response = _chat_client('llama').chat.completions.create(
            model="llama-13b-chat",
            messages=[
                {"role": "system", "content": "Assistant is a large language model trained by OpenAI."},
                {"role": "user", "content": prompt}
            ]
        )
        return response.choices[0].message.content

    raise ValueError(f"Unknown llm {llm!r}; expected 'gpt', 'gemini' or 'llama'.")

In [None]:
# Smoke test: send one trivial question to Gemini and show the reply.
response = printResponse("What is the capital of India?",'gemini')
print(response)

In [7]:

def execute_query(database_path, query):
    """Run ``query`` against the SQLite file at ``database_path``.

    Args:
        database_path: Path of the SQLite database file.
        query: SQL text to execute.

    Returns:
        The list of result rows from ``fetchall()``, or ``None`` when SQLite
        rejects the query. Failures are printed and appended to
        'logs/incorrectGeminiLog.txt'.
    """
    try:
        conn = sqlite3.connect(database_path)
        try:
            cursor = conn.cursor()
            cursor.execute(query)
            return cursor.fetchall()
        finally:
            # Close the connection even when the query raises.
            conn.close()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # Queries may be model-generated, so any SQLite failure (not only
        # OperationalError) is reported as "no result" instead of aborting
        # the surrounding evaluation loop.
        print(f"An error occurred: {e}")
        log_path = 'logs/incorrectGeminiLog.txt'
        # The log directory may not exist yet; without this the handler itself
        # would fail with FileNotFoundError.
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        with open(log_path, 'a') as f:
            f.write(f"\nError that occured:\n{e}\n\n")
        return None

def get_first_row_with_columns(database_path, table_name):
    """Return the column names and the first row of one table.

    Args:
        database_path: Path of the SQLite database file.
        table_name: Name of the table to inspect.

    Returns:
        ``(columns, first_row)`` where ``columns`` is the list of column names
        reported by ``PRAGMA table_info`` and ``first_row`` is the first row as
        a tuple (``None`` for an empty table). On an SQLite error the failure
        is printed, appended to 'logs/incorrectGeminiLog.txt' and
        ``(None, None)`` is returned.
    """
    # Quote the identifier so table names that are reserved words (e.g.
    # "order") or contain unusual characters can still be queried.
    quoted_name = '"' + str(table_name).replace('"', '""') + '"'
    try:
        conn = sqlite3.connect(database_path)
        try:
            cursor = conn.cursor()

            # Get column names (index 1 of each table_info row is the name)
            cursor.execute(f"PRAGMA table_info({quoted_name})")
            columns = [info[1] for info in cursor.fetchall()]

            # Get the first row
            cursor.execute(f"SELECT * FROM {quoted_name} LIMIT 1")
            first_row = cursor.fetchone()
        finally:
            # Close the connection even when a statement raises.
            conn.close()

        return columns, first_row
    except sqlite3.Error as e:
        print(f"An error occurred: {e}")
        log_path = 'logs/incorrectGeminiLog.txt'
        # The log directory may not exist yet.
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        with open(log_path, 'a') as f:
            f.write(f"\nError that occurred:\n{e}\n\n")
        return None, None

In [8]:

def getDbSchemaMapping(dbFolderPath):
    """Collect the schema JSON and a sample row per table for every database folder.

    Each sub-folder of ``dbFolderPath`` is treated as one database. The JSON
    file found in the folder is loaded as the schema (if there are several,
    the last one that parses wins), and for every entry of its ``'tables'``
    list the column names and first row are read from
    ``<folder>/<folder>.sqlite``.

    Args:
        dbFolderPath: Directory containing one sub-folder per database.

    Returns:
        Dict mapping folder name to
        ``{"schema": <parsed JSON>, "table_info": [(table_name, columns, first_row), ...]}``.
        Folders without a usable JSON file are omitted; when the JSON has no
        ``'tables'`` key, ``table_info`` is an empty list.
    """
    count = 0
    schema_array = {}
    for folder in os.listdir(dbFolderPath):
        folder_path = os.path.join(dbFolderPath, folder)
        # listdir() also returns plain files; only directories are databases.
        # (The previous os.path.exists() check was always true and a stray
        # file made the inner listdir() raise.)
        if not os.path.isdir(folder_path):
            continue
        count += 1

        json_data = None
        for file_name in os.listdir(folder_path):
            if not file_name.endswith('.json'):
                continue
            json_file_path = os.path.join(folder_path, file_name)
            with open(json_file_path, 'r', encoding='utf-8') as file:
                try:
                    json_data = json.load(file)
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON in file {json_file_path}: {e}")

        if not json_data:
            print(f"JSON file not found for folder: {folder_path}")
            continue

        table_info = []
        if 'tables' in json_data:
            db_file_path = os.path.join(folder_path, f"{folder}.sqlite")
            for table in json_data['tables']:
                table_name = table['name']
                columns, first_row = get_first_row_with_columns(db_file_path, table_name)
                table_info.append((table_name, columns, first_row))
        else:
            print(f"'tables' key not found in JSON data for folder: {folder_path}")

        schema_array[folder] = {
            "schema": json_data,
            "table_info": table_info
        }
    # Number of database folders visited.
    print(count)
    return schema_array

In [9]:
def get_data(tableJsonPath,cleanDataPath):
    """Load the questions/queries per database and the tables each query touches.

    Args:
        tableJsonPath: JSON file holding a list of objects with ``db_id`` and
            ``table_names_original``.
        cleanDataPath: File with one JSON object per line, each providing
            ``db_id``, ``query``, ``question`` and ``query_toks``.

    Returns:
        Dict mapping ``db_id`` to
        ``{'query': [...], 'question': [...], 'query_toks': [...], 'tables': [...]}``.
        The four lists are index-aligned; ``tables[i]`` holds the lower-cased
        table names of that database that occur among ``query_toks[i]``, in
        first-seen order and without duplicates.
    """
    with open(tableJsonPath, "r", encoding="utf-8") as f:
        tables_data = json.load(f)

    # Index the table names once instead of rescanning tables_data for every
    # question line.
    tables_by_db = {}
    for item in tables_data:
        known = tables_by_db.setdefault(item['db_id'], set())
        known.update(name.lower() for name in item.get("table_names_original", []))

    database = {}

    with open(cleanDataPath, 'r', encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            data = json.loads(line)
            db_id = data.get('db_id')
            query = data.get('query')
            question = data.get('question')
            query_toks = data.get('query_toks')

            known_tables = tables_by_db.get(db_id, set())
            lowered = (tok.lower() for tok in (query_toks or []))
            # dict.fromkeys keeps first-seen order while dropping repeats.
            table_list = list(dict.fromkeys(tok for tok in lowered if tok in known_tables))

            entry = database.setdefault(
                db_id, {'query': [], 'question': [], 'query_toks': [], 'tables': []}
            )
            entry['query'].append(query)
            entry['question'].append(question)
            entry['query_toks'].append(query_toks)
            entry['tables'].append(table_list)

    return database

In [None]:
# Build the per-database schema mapping and load the training questions.
# NOTE(review): absolute, machine-specific paths; adjust before running elsewhere.
schema_array = getDbSchemaMapping('F:/OneDrive/Desktop/Study/NLP_ResearchProject/Project/spider/database')
databases = get_data('F:/OneDrive/Desktop/Study/NLP_ResearchProject/Project/spider/tables.json','F:/OneDrive/Desktop/Study/NLP_ResearchProject/Project/spider/train_spider_main_data.json')

In [11]:
# Split the databases by size: six or more tables counts as "hard",
# anything smaller as "easy".
hard_dataset, easy_dataset = [], []

for folder, data in schema_array.items():
    schema, table_info = data['schema'], data['table_info']
    length = len(table_info)
    (hard_dataset if length >= 6 else easy_dataset).append(folder)



In [None]:
# Inspect the (table name, columns, first row) tuples collected for one database.
print(schema_array['school_bus']['table_info'])

In [None]:
# Print the schema JSON plus the columns and first row of every table,
# one database at a time.
for folder, data in schema_array.items():
    schema = data.get('schema', {})
    table_info = data.get('table_info', [])

    print(f"Folder: {folder}")
    print(f"Schema: {schema}")

    for table_name, columns, first_row in table_info:
        print(f"  Table Name: {table_name}")
        print(f"  Columns: {columns}")
        print(f"  First Row: {first_row}")

In [None]:
# Sanity check: number of databases that have questions vs. number that have
# a schema, followed by the full question data.
print(len(databases))
print(len(schema_array))
databases

In [None]:
# Inspect the full entry (schema JSON and table_info) of one database.
print(schema_array['insurance_fnol'])

In [12]:
# Database folder path; compare_queries() expects <db_name>/<db_name>.sqlite below it.
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
db_folder_path = 'F:/OneDrive/Desktop/Study/NLP_ResearchProject/Project/spider/database'
import re

def extract_table_descriptions(response):
    """Parse ``# Table <name>`` sections out of an LLM response.

    A section is a header line ``# Table <name>`` (the space after ``#`` is
    optional, so ``#Table <name>`` is accepted too) followed by the
    description; only the first line of the description is captured.

    Args:
        response: Raw text returned by the model.

    Returns:
        Dict mapping table name to its stripped description. If a table is
        described more than once, the last description wins.
    """
    # `(?:\n|$)` lets the final description match even when the response does
    # not end with a newline.
    pattern = r"#[ \t]*Table[ \t]+(\w+)\s*\n(.+?)(?:\n|$)"
    matches = re.findall(pattern, response, re.DOTALL)

    return {table_name: description.strip() for table_name, description in matches}


def convert_to_list(results):
    """Normalise query results so rows can be compared regardless of column order.

    Each row becomes a tuple whose values are sorted by their string form.
    ``None`` or an empty result yields an empty list.
    """
    if not results:
        return []
    return [tuple(sorted(row, key=str)) for row in results]


def compare_queries(db_name, db_folder_path, generated_query, original_query):
    """Check whether two queries return the same rows on one database.

    Both queries are executed against ``<db_folder_path>/<db_name>/<db_name>.sqlite``
    and their results compared as multisets of normalised rows, so row order
    and column order do not matter.

    Args:
        db_name: Database (and folder) name.
        db_folder_path: Directory containing the database folders.
        generated_query: SQL produced by the model.
        original_query: Reference SQL.

    Returns:
        True when both queries ran and produced the same rows. False when the
        database file is missing, the generated query is empty, either query
        failed, or the results differ.
    """
    db_path = os.path.join(db_folder_path, db_name, db_name + '.sqlite')
    if not os.path.exists(db_path):
        print(f"Database {db_name} not found.")
        return False

    # An empty extraction executes without error and returns no rows, which
    # would otherwise "match" any reference query with an empty result.
    if not generated_query or not generated_query.strip():
        return False

    generated_results = execute_query(db_path, generated_query)
    original_results = execute_query(db_path, original_query)

    # execute_query() signals failure with None; a failed query must not be
    # treated as an empty result set.
    if generated_results is None or original_results is None:
        return False

    gen_list = convert_to_list(generated_results)
    orig_list = convert_to_list(original_results)

    print(gen_list)
    print(orig_list)

    return Counter(gen_list) == Counter(orig_list)

def extract_sql_query(response):
    """Return the SQL inside the first fenced code block of ``response``.

    A fence tagged ``sqlite`` anywhere in the text is preferred; otherwise the
    first fence tagged ``sql`` is used. The empty string is returned when no
    such fence exists or the fence is never closed.
    """
    for fence in ("```sqlite", "```sql"):
        start = response.find(fence)
        if start != -1:
            break
    else:
        return ""

    end = response.find("```", start + 1)
    if end == -1:
        return ""

    # Skip the opening fence and its language tag.
    return response[start + len(fence):end].strip()

def table_desc_extractor(table_info, table_description):
    """Combine per-table metadata with the generated descriptions.

    Args:
        table_info: Iterable of ``(table_name, columns, first_row)`` entries.
        table_description: Dict mapping table name to description text.

    Returns:
        Dict mapping table name to a dict with keys ``table_description``,
        ``table_columns`` and ``table_first_row``; columns and first row are
        stored as their ``str()`` form. Tables whose description is missing or
        None get "No description".
    """
    extract = {}
    for entry in table_info:
        name = entry[0]
        description = table_description.get(name)
        if description is None:
            description = "No description"
        extract[name] = {
            "table_description": description,
            "table_columns": str(entry[1]),
            "table_first_row": str(entry[2]),
        }
    return extract




In [13]:
# Initialise the cross-prompt counters and empty both log files.
# DANGER: re-running this cell discards the evaluation progress and the logs.
# NOTE(review): the evaluation cells append to './geminiLog.txt' and
# './incorrectGeminiLog.txt'; those are the files truncated here only when the
# working directory is the project folder below. Confirm.
totalQueries=0
correctAns=0
notCorrectAns=0
file_path = 'F:/OneDrive/Desktop/Study/NLP_ResearchProject/Project/incorrectGeminiLog.txt'

# Opening in 'w' mode without writing anything truncates the file.
with open(file_path, 'w') as file:
    pass
file_path = 'F:/OneDrive/Desktop/Study/NLP_ResearchProject/Project/geminiLog.txt'
with open(file_path, 'w') as file:
    pass

In [24]:
#totalQueries setter
# totalQueries=24

In [None]:
# Show how many questions have been evaluated so far.
print(totalQueries)

In [74]:
# Manually set the counters (123 + 68 = 191); the evaluation loop skips the
# first `totalQueries` questions.
# NOTE(review): hard-coded values from an earlier session; running this on a
# fresh kernel makes the evaluation skip questions that were never processed.
totalQueries=191
correctAns=123
notCorrectAns=68

In [87]:
def cosine_similarity(v1, v2):
    """Cosine similarity of two equal-length numeric vectors.

    Args:
        v1: Sequence of numbers.
        v2: Sequence of numbers, at least as long as ``v1``.

    Returns:
        The dot product divided by the product of the Euclidean norms, or -1
        when either vector has zero length (all components zero).
    """
    dot = sum(v1[i] * v2[i] for i in range(len(v1)))
    # The norm is the square root of the sum of *squares*. The previous code
    # used sum(v), which is only equivalent for 0/1 vectors.
    squared_norm_1 = sum(x * x for x in v1)
    squared_norm_2 = sum(x * x for x in v2)
    if squared_norm_1 == 0 or squared_norm_2 == 0:
        return -1
    return dot / math.sqrt(squared_norm_1) / math.sqrt(squared_norm_2)

In [None]:
# Count the questions that belong to "heavy" databases
# (at least 8 tables and at least 25 questions).
total=0

for dbName, dbSchema in schema_array.items():
    if dbName not in databases: continue
    # getDbSchemaMapping() nests the parsed schema JSON under 'schema', so the
    # table list lives at dbSchema['schema']['tables']. A top-level 'tables'
    # key is still honoured for mappings built with that shape. (Checking only
    # the top level skipped every database.)
    schemaJson = dbSchema.get('schema')
    tableDefs = dbSchema.get('tables') or (
        schemaJson.get('tables') if isinstance(schemaJson, dict) else None
    )
    if not tableDefs: continue
    num_questions=len(databases[dbName]['question'])
    num_tables=len(tableDefs)
    if num_tables<8 or num_questions<25: continue
    total+=num_questions

print(total)

In [None]:
# Ask the LLM for a description of every table in a hand-picked list of
# databases, and build the per-table extract from its answer.
database_table_extract = {}
count = 0
model_used = 'gemini'
database_table_descrition = {}
for folder, data in schema_array.items():
    if folder in database_table_extract:
        continue 
    # if folder != 'tracking_orders': continue
    # Only databases that have questions are described.
    if folder not in databases: continue
    if folder not in ['world_1', 'voter_1', 'twitter_1', 'soccer_1', 'small_bank_1', 'school_player', 'icfp', 'gymnast', 'flight_4', 'epinions_1', 'company_1']:
        continue
    # if 'tables' not in data: continue
    print(f"Folder: {folder}")
    num_questions=len(databases[folder]['question'])
    num_tables=len(data['table_info'])
    schema = data.get('schema', {})
    # print(f"Schema: {schema}")
    
    table_info = data.get('table_info', [])
    # NOTE(review): the prompt below says sample data (first rows) is provided,
    # but only `schema` is serialised into it; `table_info` is not. Confirm
    # whether the schema JSON already carries sample rows.
    tables_used=json.dumps(schema)
    table_extraction_prompt = "Here is the schema of the database" + tables_used + "\n"
    table_extraction_prompt2 = '''
    Give a detailed description of each table in the database. I have provided you the schema which contains the column names. I have also provided sample data, which is essentially the first row of the table, for better understanding. Use all the information to understand the table deeply, and provid information about it. The information should include what each column of the table contains, as well. Describe in english what each column is storing according to you. We will be using the information you provide for extracting queries that use this table. Also, all information should be in the fomr of a single paragraph. The response should be # Table table_name table_description ! . A sample response may look like: 
    #Table employees
    Stores infomation regarding all the employees in the company. The column e_id is the employee ID of each employee, which is unique to each employee.
    !
    \n
    '''
    table_prompt = table_extraction_prompt + table_extraction_prompt2
    response = printResponse(table_prompt,model_used)
    # Keep the raw model output for later inspection.
    with open('./geminiLog.txt', 'a') as f:
        f.write(f"\nResponse:\n{response}\n\n")
    table_description = extract_table_descriptions(response)
    print(table_description)

    # Raw descriptions per database (this is what gets saved to disk) ...
    database_table_descrition[folder] = table_description
    table_extract = table_desc_extractor(table_info, table_description)

    # ... and descriptions combined with the columns and first row.
    database_table_extract[folder] = table_extract

    # for table in table_info:
    #     table_name, columns, first_row = table
    #     print(f"  Table Name: {table_name}")
    #     print(f"  Columns: {columns}")
    #     print(f"  First Row: {first_row}")
    

In [None]:
# Number of databases that received table descriptions.
print(len(database_table_extract))

In [None]:
# Persist the generated descriptions. They are merged into the existing file
# rather than appended to it: appending writes a second JSON document into the
# same file, which json.load() cannot read back.
_extract_path = 'database_table_extract.json'
_saved_descriptions = {}
if os.path.exists(_extract_path) and os.path.getsize(_extract_path) > 0:
    with open(_extract_path, 'r') as json_file:
        _saved_descriptions = json.load(json_file)
_saved_descriptions.update(database_table_descrition)
with open(_extract_path, 'w') as json_file:
    json_file.write(json.dumps(_saved_descriptions, indent=4))


In [None]:
# Same description-generation loop as the earlier cell, but for every database
# that has not been described yet.
# NOTE(review): `past_table_description` is only assigned by a later cell that
# reads database_table_extract.json, so this cell raises NameError on a fresh
# top-to-bottom run.
database_table_extract = {}
count = 0
model_used = 'gemini'
database_table_descrition = {}
for folder, data in schema_array.items():
    # 'soccer_1' is skipped explicitly.
    if folder == 'soccer_1':
        continue
    # Skip databases whose descriptions were already saved to disk.
    if folder in past_table_description:
        continue
    if folder in database_table_extract:
        continue 
    # if folder != 'tracking_orders': continue
    # Only databases that have questions are described.
    if folder not in databases: continue
    # if 'tables' not in data: continue
    print(f"Folder: {folder}")
    num_questions=len(databases[folder]['question'])
    num_tables=len(data['table_info'])
    schema = data.get('schema', {})
    # print(f"Schema: {schema}")
    
    table_info = data.get('table_info', [])
    # NOTE(review): the prompt below says sample data (first rows) is provided,
    # but only `schema` is serialised into it; `table_info` is not. Confirm
    # whether the schema JSON already carries sample rows.
    tables_used=json.dumps(schema)
    table_extraction_prompt = "Here is the schema of the database" + tables_used + "\n"
    table_extraction_prompt2 = '''
    Give a detailed description of each table in the database. I have provided you the schema which contains the column names. I have also provided sample data, which is essentially the first row of the table, for better understanding. Use all the information to understand the table deeply, and provid information about it. The information should include what each column of the table contains, as well. Describe in english what each column is storing according to you. We will be using the information you provide for extracting queries that use this table. Also, all information should be in the fomr of a single paragraph. The response should be # Table table_name table_description ! . A sample response may look like: 
    #Table employees
    Stores infomation regarding all the employees in the company. The column e_id is the employee ID of each employee, which is unique to each employee.
    !
    \n
    '''
    table_prompt = table_extraction_prompt + table_extraction_prompt2
    response = printResponse(table_prompt,model_used)
    # Keep the raw model output for later inspection.
    with open('./geminiLog.txt', 'a') as f:
        f.write(f"\nResponse:\n{response}\n\n")
    table_description = extract_table_descriptions(response)
    print(table_description)

    # Raw descriptions per database (this is what gets saved to disk) ...
    database_table_descrition[folder] = table_description
    table_extract = table_desc_extractor(table_info, table_description)

    # ... and descriptions combined with the columns and first row.
    database_table_extract[folder] = table_extract


In [47]:
# Persist the generated descriptions. They are merged into the existing file
# rather than appended to it: appending writes a second JSON document into the
# same file, which json.load() cannot read back.
_extract_path = 'database_table_extract.json'
_saved_descriptions = {}
if os.path.exists(_extract_path) and os.path.getsize(_extract_path) > 0:
    with open(_extract_path, 'r') as json_file:
        _saved_descriptions = json.load(json_file)
_saved_descriptions.update(database_table_descrition)
with open(_extract_path, 'w') as json_file:
    json_file.write(json.dumps(_saved_descriptions, indent=4))

In [51]:
# Reload the descriptions saved by earlier runs; json.load() requires the file
# to contain exactly one JSON document.
with open('database_table_extract.json', 'r') as json_file:
    past_table_description = json.load(json_file)

In [None]:
# Number of databases with saved descriptions.
print(len(past_table_description))

In [53]:
# Rebuild the per-table extracts from the descriptions loaded from disk,
# without calling the LLM again.
database_table_extract = {}

for folder, data in schema_array.items():
    if folder in past_table_description:
        table_description = past_table_description[folder]
        table_info = data.get('table_info', [])
        table_extract = table_desc_extractor(table_info, table_description)
        database_table_extract[folder] = table_extract
    

In [None]:
# Print every table extract, grouped by database.
for folder, values in database_table_extract.items():
    print(f"Folder: {folder}")
    for table_name, table_info in values.items():
        print(f"  Table Name: {table_name}")
        print(f"  Table Info: {table_info}")
    # NOTE(review): this sits inside the outer loop, so the total is printed
    # once per database. Confirm that is intended.
    print(len(database_table_extract))

In [None]:
# Count the databases in database_table_extract (same value as
# len(database_table_extract)).
count = 0
for folder, tables in database_table_extract.items():
    count += 1
print(count)

In [None]:
# Unsupervised learning
# Table selection by embedding similarity: for every question, score each
# table description against the question and keep the tables that score above
# the average (all tables when the database has four or fewer). A question
# counts as correct when every table used by its reference query was kept.
model_used = 'gemini'
threshold = 0.5  # NOTE(review): not used by the selection rule below.

from sentence_transformers import SentenceTransformer, util
embedder = SentenceTransformer("all-MiniLM-L6-v2")

db_table_extraction_list = {}
total_num_correct = 0
total_num_incorrect = 0

for folder, tables in database_table_extract.items():
    # database_table_extract may contain databases without questions; skipping
    # them avoids a KeyError and a division by zero further down.
    if folder not in databases:
        continue
    num_questions = len(databases[folder]['question'])
    if num_questions == 0:
        continue

    print(f"Folder: {folder}")
    num_correct = 0
    num_incorrect = 0

    # The description embeddings do not depend on the question, so compute
    # them once per database instead of once per question.
    table_embeddings = {
        table_name.lower(): embedder.encode(table_info['table_description'], convert_to_tensor=True)
        for table_name, table_info in tables.items()
    }

    for i in range(num_questions):
        question = databases[folder]['question'][i]
        question_embedding = embedder.encode(question, convert_to_tensor=True)
        table_similarities = {
            table_name: util.pytorch_cos_sim(question_embedding, table_embedding).item()
            for table_name, table_embedding in table_embeddings.items()
        }

        print(f"Question: {question}")
        print(f"Table Similarities: {table_similarities}")

        if table_similarities:
            average_score = sum(table_similarities.values()) / len(table_similarities)
        else:
            average_score = 0

        if len(table_similarities) > 4:
            selected_tables = [
                table_name
                for table_name, score in table_similarities.items()
                if score > average_score
            ]
        else:
            # Small databases: keep every table.
            selected_tables = list(table_similarities)
        print(f"Selected Tables: {selected_tables}")

        table_list = databases[folder]['tables'][i]
        query = databases[folder]['query'][i]
        print(f"Table List: {table_list}")
        print(f"Query: {query}")

        # Check if all tables in table_list are present in selected_tables
        if all(table in selected_tables for table in table_list):
            print("Correct")
            num_correct += 1
            total_num_correct += 1
        else:
            print("Incorrect")
            num_incorrect += 1
            total_num_incorrect += 1
        print("\n")

    percent = (num_correct / num_questions) * 100
    print(f"{folder} has this much accuracy {percent}")
    db_table_extraction_list[folder] = {
        "Accuracy": percent,
        "Correct": num_correct,
        "Incorrect": num_incorrect
    }
    print("\n")

total_evaluated = total_num_correct + total_num_incorrect
# Guard against an empty run.
total_percent = (total_num_correct / total_evaluated) * 100 if total_evaluated else 0

print(f"Total Percent Accuracy: {total_percent}")

    
    
        
# Select the tables with similarity scores above the average score / 75% quartile
        

In [None]:
# Print the table-selection accuracy (in percent) of every database.
for folder, accuracy in db_table_extraction_list.items():
    print(f"Folder: {folder}")
    print(f"Accuracy: {accuracy['Accuracy']}")

In [59]:
# Save the per-database accuracy / correct / incorrect figures.
with open("table_extraction_percentAccuracy.json", "w") as f:
    json.dump(db_table_extraction_list, f, indent=4)

In [65]:
# Split the per-database table-selection results into the "hard" and "easy"
# groups, save each group, and write an overall summary.
hard_table_extraction_percent_accuracy = {}
easy_table_extraction_percent_accuracy = {}

for folder, accuracy in db_table_extraction_list.items():
    if folder in hard_dataset:
        hard_table_extraction_percent_accuracy[folder] = accuracy
    else:
        easy_table_extraction_percent_accuracy[folder] = accuracy

with open("hard_table_extraction_percent_accuracy.json", "w") as f:
    json.dump(hard_table_extraction_percent_accuracy, f, indent=4)
with open("easy_table_extraction_percent_accuracy.json", "w") as f:
    json.dump(easy_table_extraction_percent_accuracy, f, indent=4)

hard_num_correct = 0
hard_num_incorrect = 0
hard_percent = 0

for folder, accuracy in hard_table_extraction_percent_accuracy.items():
    hard_num_correct += accuracy['Correct']
    hard_num_incorrect += accuracy['Incorrect']
# An empty group keeps 0 instead of raising ZeroDivisionError.
if hard_num_correct + hard_num_incorrect:
    hard_percent = (hard_num_correct / (hard_num_correct + hard_num_incorrect)) * 100

easy_num_correct = 0
easy_num_incorrect = 0
easy_percent = 0

for folder, accuracy in easy_table_extraction_percent_accuracy.items():
    easy_num_correct += accuracy['Correct']
    easy_num_incorrect += accuracy['Incorrect']
if easy_num_correct + easy_num_incorrect:
    easy_percent = (easy_num_correct / (easy_num_correct + easy_num_incorrect)) * 100

hard_size = len(hard_table_extraction_percent_accuracy)
easy_size = len(easy_table_extraction_percent_accuracy)
# `total_percent` comes from the table-selection evaluation cell.
percent = {
    "Hard": hard_percent,
    "Easy": easy_percent,
    "total": total_percent,
    "Hard Size": hard_size,
    "Easy Size": easy_size,
    "hard_num_correct": hard_num_correct,
    "hard_num_incorrect": hard_num_incorrect,
    "easy_num_correct": easy_num_correct,
    "easy_num_incorrect": easy_num_incorrect
}

with open("table_extraction_percent.json", "w") as f:
    json.dump(percent, f, indent=4)

In [None]:
# Text-to-SQL evaluation with dynamically chosen few-shot examples.
# For every question of the large databases, the three most similar other
# questions of the same database (by sentence-embedding similarity) are added
# to the prompt as examples. The generated SQL is then executed and compared
# with the reference query. Uses and updates the notebook-level counters
# totalQueries / correctAns / notCorrectAns, so an interrupted run can resume.
model_used='gemini'
newInstTot=0

from sentence_transformers import SentenceTransformer, util
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Prompt fragments that do not depend on the database or the question.
initialPrompt="I will feed you schema of a database, store it and then use it to answer the question I will ask related to the database.\n"
table_extraction_prompt = "Give a one line description of each table in the database. Start the description with a '#', and end it with a '!'.\n"
infoPrompt = "we are using sqlite as the tool for running the queries. Keep syntax related to it in mind, and provide code accordingly. Dont forget the semicolons wherever needed."
columnNamePrompt = "Be careful with the names of columns. Use precise column names, since changing them (for example making uppercase to lowercase) may lead to error. Also enclose the column names in double quotes to avoid any clashes with reserved keywords.\n"
examplePrompt = '''
        Avoid using = if possible, in situations like in this query:
                            SELECT
                            m.Location,
                            a.Aircraft
                            FROM `match` AS m
                            JOIN aircraft AS a
                            ON m.Winning_Aircraft = a.Aircraft_ID
                            WHERE
                            m.Round = (
                                SELECT
                                `Round`
                                FROM `match`
                                ORDER BY
                                Fastest_Qualifying
                                LIMIT 1
                            );
                            Here since = is used, m.Round is limited to one row which is not correct.
                            Also, always check datatypes to ensure that u are understanding the information a column represents correctly. For example, sometimes a column with say name pilot may actually represent its id and not its name.
                            To help with this, first state your understanding of what the columns actually represent, then walk through the logic for how you are writing the particular query (a stepwise explanation), and then print the final query.
                            
                            You can also use joins to extract information. For example:
                            Question: What are the last names and ages of the students who are allergic to milk and cat?
                            Tables: Student, Has_Allergy
                            SQL Query:
                                SELECT s.LName, s.Age
                                FROM Student s
                                JOIN Has_Allergy ha1 ON s.StuID = ha1.StuID
                                JOIN Has_Allergy ha2 ON s.StuID = ha2.StuID
                                WHERE ha1.Allergy = 'Milk' AND ha2.Allergy = 'Cat';
                            Notice how I have joined 3 tables, 2 of which are same. Essentially any table can be joined with any other under suitable constraints.

        Also take care of string type columns. Don't confuse singular with plural, for example, egg with eggs. Use the precise string in queries.

        When using AND and OR operators together, use parentheses to ensure the correct logical grouping of conditions. For example:
        Question: How many female students are allergic to milk or eggs?
        Tables: Student, Has_Allergy
        Incorrect SQL Query:
            SELECT count(*) 
            FROM has_allergy AS T1 
            JOIN Student AS T2 ON T1.StuID = T2.StuID 
            WHERE T2.sex = "F" AND T1.allergy = "Milk" OR T1.allergy = "Eggs";

        Correct SQL Query:
            SELECT count(*) 
            FROM has_allergy AS T1 
            JOIN Student AS T2 ON T1.StuID = T2.StuID 
            WHERE T2.sex = "F" AND (T1.allergy = "Milk" OR T1.allergy = "Eggs");

        Notice the use of parentheses to group the OR conditions properly. Always use parentheses to avoid logical errors when combining AND and OR operators.

        Also, when there is a column name, say X, present in more than one tables, say both T1 and T2, be very specific of the column you are using in the sql query. So, if u want to use the column from T1,
        be sure to use T1.X instead of simply X. This helps avoid the ambiguous column error.   

        '''
chainOfThoughtPrompt="Understand what each column and table mean. State what you understand about the tabel and their relatons. Also state logical steps in between as to how you are constructing the final SQL query."

for dbName, dbSchema in schema_array.items():

    if dbName not in databases: continue

    # getDbSchemaMapping() nests the parsed schema JSON under 'schema', so the
    # table list lives at dbSchema['schema']['tables']. A top-level 'tables'
    # key is still honoured for mappings built with that shape. (Checking only
    # the top level skipped every database.)
    schemaJson = dbSchema.get('schema')
    dbTableDefs = dbSchema.get('tables') or (
        schemaJson.get('tables') if isinstance(schemaJson, dict) else None
    )
    if not dbTableDefs: continue

    dbQuestions = databases[dbName]['question']
    num_questions=len(dbQuestions)
    num_tables=len(dbTableDefs)

    # Only large databases are evaluated: at least 8 tables and 25 questions.
    if num_tables<8 or num_questions<25: continue

    print(dbName)

    dbTables=[table["name"].lower() for table in dbTableDefs]

    # The schema part of the prompt is the same for every question of this
    # database. default=str keeps json.dumps from failing on sample-row values
    # that are not JSON-serialisable (e.g. bytes).
    tables_used=json.dumps(schema_array[dbName], default=str)
    tablesPrompt = "Following is the schema of tables you can use to write the SQL Query." + "\n" + tables_used + "\n"

    # Embeddings of all questions of this database; computed on first use so
    # fully processed databases cost nothing when resuming.
    question_embeddings = None

    for i in range(num_questions):
        newInstTot+=1
        # Resume support: skip questions handled by an earlier run.
        if (totalQueries>=(newInstTot)):
            continue

        if question_embeddings is None:
            question_embeddings = embedder.encode(dbQuestions, convert_to_tensor=True)

        # Indicator vector of the tables used by the reference query
        # (printed for inspection only).
        queryVector=[0]*len(dbTables)
        for tname in databases[dbName]['tables'][i]:
            if tname.lower() in dbTables:
                queryVector[dbTables.index(tname.lower())]=1

        print(dbTables)
        print(dbQuestions[i])
        print(queryVector)

        #shots chosen dynamically
        # Similarity of question i to every question of the database, as plain
        # floats. Each candidate is scored with its own embedding; previously
        # the embedding was only refreshed inside the per-table loop and could
        # be stale or undefined.
        similarities = embedder.similarity(question_embeddings[i], question_embeddings)[0].tolist()

        shots=[]
        for j in range(num_questions):
            # Exclude the question itself and its direct neighbours.
            if abs(j - i) > 1:
                shots.append([similarities[j], dbQuestions[j], databases[dbName]['query'][j]])

        shots.sort(key=lambda shot: shot[0], reverse=True)
        top_shots = shots[:3]

        for shot_score, shot_question, shot_query in top_shots:
            print(shot_score)
            print(shot_question)
            print(shot_query)

        dynamicPrompt="Here are some examples:\n"

        for j, (shot_score, shot_question, shot_query) in enumerate(top_shots):
            dynamicPrompt+=f"Example {j+1}:\n\nQuestion:\n{shot_question}\n\nSQL Query:\n{shot_query}\n\n"

        #prompts used
        questionPrompt="\nHere is the question part, i.e., the query on the database above, explained in english, for which the corresponding SQL query code is needed.: \n\n\n" + dbQuestions[i] + "\n\n\n" +  "\n" + "Provide the SQL query at the end of the response.\n"

        prompt=initialPrompt+tablesPrompt+table_extraction_prompt+infoPrompt+columnNamePrompt+examplePrompt+dynamicPrompt+chainOfThoughtPrompt+questionPrompt

        #get response from LLM
        response=printResponse(prompt,model_used)

        #get og and gen queries
        generated_query = extract_sql_query(response)
        original_query = databases[dbName]['query'][i]
        if not original_query.endswith(';'):
            original_query+=';'

        #compare og and gen queries
        isSame=compare_queries(dbName, db_folder_path, generated_query, original_query)

        #logging
        totalQueries+=1
        print(f"Q{totalQueries}:\n")
        if isSame:
            print("The queries match.\n")
            correctAns+=1
        else:
            print("The queries do not match.\n")
            notCorrectAns+=1

        # Part of the log entry shared by both log files.
        top_scores = " ".join(str(shot[0]) for shot in top_shots)
        log_entry = (
            f"Q{totalQueries}:\n"
            f"Prompt Tables:\n{tables_used}\n"
            f"\n\n\ncosine_similarity: {top_scores}\n\n\n"
            f"LLM_response:\n{response}\n"
            f"Question:\n\n{dbQuestions[i]}\n\n"
            f"Original_query:\n\n{original_query}\n\n"
            f"generated_query:\n\n{generated_query}\n\n"
        )

        # utf-8 so model output with non-ASCII characters cannot abort the run
        # on platforms whose default encoding is narrower.
        if not isSame:
            #logging in incorrectGeminiLog.txt
            with open('./incorrectGeminiLog.txt', 'a', encoding='utf-8') as f:
                f.write(log_entry)
                f.write(f"\n\n\n")

        #logging in geminiLog.txt
        with open('./geminiLog.txt', 'a', encoding='utf-8') as f:
            f.write(log_entry)
            f.write(f"Total Queries: {totalQueries}, Correct Answers: {correctAns}, Incorrect Answers: {notCorrectAns}")
            f.write(f"\n\n\n")


In [96]:
# Counters for the "NG" evaluation run, kept separate from
# totalQueries / correctAns / notCorrectAns.
totalQueriesNG=0
correctAnsNG=0
notCorrectAnsNG=0

In [None]:
# Sanity check: print the name of the first table listed under the
# 'browser_web' entry of schema_array.
print(schema_array['browser_web']['tables'][0]['name'])

In [101]:
# Reset the tallies used by the "no golden tables" evaluation loop below.
# All three names carry the NG suffix: the loop increments notCorrectAnsNG,
# so that is the counter that has to be zeroed here (not notCorrectAns,
# which belongs to the other evaluation loop).
totalQueriesNG=0
correctAnsNG=0
notCorrectAnsNG=0

In [None]:
def _append_query_log(log_path, query_number, tables_used, response, question,
                      original_query, generated_query, summary=""):
    """Append one evaluation record to the text log at ``log_path``.

    Args:
        log_path: File to append to (opened in 'a' mode, created if missing).
        query_number: Number written in the ``Q<n>:`` header line.
        tables_used: Schema text that was placed in the prompt.
        response: Raw LLM response text.
        question: Natural-language question that was asked.
        original_query: Reference SQL query.
        generated_query: SQL query extracted from the LLM response.
        summary: Optional text written just before the trailing blank lines
            (used for the running totals line in the full log).
    """
    with open(log_path, 'a') as f:
        f.write(f"Q{query_number}:\n")
        f.write(f"Prompt Tables:\n{tables_used}\n")
        f.write(f"LLM_response:\n{response}\n")
        f.write(f"Question:\n\n{question}\n\n")
        f.write(f"Original_query:\n\n{original_query}\n\n")
        f.write(f"generated_query:\n\n{generated_query}\n\n")
        f.write(summary)
        f.write("\n\n\n")


model_used='gemini'
newInstTot=0

# Prompt fragments that do not depend on the database or the question are
# built once here instead of on every loop iteration.
initialPrompt="I will feed you schema of a database, store it and then use it to answer the question I will ask related to the database.\n"
infoPrompt = "we are using sqlite as the tool for running the queries. Keep syntax related to it in mind, and provide code accordingly. Dont forget the semicolons wherever needed."
columnNamePrompt = "Be careful with the names of columns. Use precise column names, since changing them (for example making uppercase to lowercase) may lead to error. Also enclose the column names in double quotes to avoid any clashes with reserved keywords.\n"
# NOTE: the whitespace inside this literal is part of the prompt text sent to
# the model; keep it exactly as is.
examplePrompt = '''
        Avoid using = if possible, in situations like in this query:
                            SELECT
                            m.Location,
                            a.Aircraft
                            FROM `match` AS m
                            JOIN aircraft AS a
                            ON m.Winning_Aircraft = a.Aircraft_ID
                            WHERE
                            m.Round = (
                                SELECT
                                `Round`
                                FROM `match`
                                ORDER BY
                                Fastest_Qualifying
                                LIMIT 1
                            );
                            Here since = is used, m.Round is limited to one row which is not correct.
                            Also, always check datatypes to ensure that u are understanding the information a column represents correctly. For example, sometimes a column with say name pilot may actually represent its id and not its name.
                            To help with this, first state your understanding of what the columns actually represent, then walk through the logic for how you are writing the particular query (a stepwise explanation), and then print the final query.
                            
                            You can also use joins to extract information. For example:
                            Question: What are the last names and ages of the students who are allergic to milk and cat?
                            Tables: Student, Has_Allergy
                            SQL Query:
                                SELECT s.LName, s.Age
                                FROM Student s
                                JOIN Has_Allergy ha1 ON s.StuID = ha1.StuID
                                JOIN Has_Allergy ha2 ON s.StuID = ha2.StuID
                                WHERE ha1.Allergy = 'Milk' AND ha2.Allergy = 'Cat';
                            Notice how I have joined 3 tables, 2 of which are same. Essentially any table can be joined with any other under suitable constraints.

        Also take care of string type columns. Don't confuse singular with plural, for example, egg with eggs. Use the precise string in queries.

        When using AND and OR operators together, use parentheses to ensure the correct logical grouping of conditions. For example:
        Question: How many female students are allergic to milk or eggs?
        Tables: Student, Has_Allergy
        Incorrect SQL Query:
            SELECT count(*) 
            FROM has_allergy AS T1 
            JOIN Student AS T2 ON T1.StuID = T2.StuID 
            WHERE T2.sex = "F" AND T1.allergy = "Milk" OR T1.allergy = "Eggs";

        Correct SQL Query:
            SELECT count(*) 
            FROM has_allergy AS T1 
            JOIN Student AS T2 ON T1.StuID = T2.StuID 
            WHERE T2.sex = "F" AND (T1.allergy = "Milk" OR T1.allergy = "Eggs");

        Notice the use of parentheses to group the OR conditions properly. Always use parentheses to avoid logical errors when combining AND and OR operators.

        Also, when there is a column name, say X, present in more than one tables, say both T1 and T2, be very specific of the column you are using in the sql query. So, if u want to use the column from T1,
        be sure to use T1.X instead of simply X. This helps avoid the ambiguous column error.   

        '''
chainOfThoughtPrompt="Understand what each column and table mean. State what you understand about the tabel and their relatons. Also state logical steps in between as to how you are constructing the final SQL query."

for dbName, dbSchema in schema_array.items():

    # Only evaluate databases for which questions/queries are available.
    if dbName not in databases:
        continue

    print(dbName)

    # In this variant the whole schema of the database is placed in the prompt
    # (not only the tables the reference query touches), so the schema text and
    # the prompt fragment built from it are the same for every question of
    # this database.
    tables_used=json.dumps(dbSchema)
    tablesPrompt = "Following is the schema of tables you can use to write the SQL Query." + "\n\n" +tables_used + "\n\n"

    for i, question in enumerate(databases[dbName]['question']):
        # newInstTot is the 1-based running index over all questions;
        # totalQueriesNG is how many have already been processed. Entries that
        # were already handled are skipped, so re-running this cell without
        # resetting the counters continues after the last processed question.
        newInstTot+=1
        if totalQueriesNG>=newInstTot:
            continue

        questionPrompt="\nHere is the question part, i.e., the query on the database above, explained in english, for which the corresponding SQL query code is needed.: \n\n" + question + "\n\n" +  "\n" + "Provide the SQL query at the end of the response.\n"

        prompt=initialPrompt+tablesPrompt+infoPrompt+columnNamePrompt+examplePrompt+chainOfThoughtPrompt+questionPrompt

        #get response from LLM
        response=printResponse(prompt,model_used)

        #get og and gen queries
        generated_query = extract_sql_query(response)
        original_query = databases[dbName]['query'][i]
        if not original_query.endswith(';'):
            original_query+=';'

        #compare og and gen queries
        isSame=compare_queries(dbName, db_folder_path, generated_query, original_query)

        #logging
        print(f"Q{totalQueriesNG+1}:\n")
        # Deliberately `== False` rather than `not isSame`: only an explicit
        # False (or 0) result counts as a mismatch; any other return value of
        # compare_queries is tallied as a match, exactly as before.
        if isSame==False:
            print("The queries do not match.\n")
            notCorrectAnsNG+=1

            #logging in incorrectGeminiLogNoGolden.txt
            _append_query_log('./incorrectGeminiLogNoGolden.txt', totalQueriesNG+1,
                              tables_used, response, question,
                              original_query, generated_query)
        else:
            print("The queries match.\n")
            correctAnsNG+=1
        totalQueriesNG+=1

        #logging in geminiLogNoGolden.txt (every query, with running totals)
        _append_query_log('./geminiLogNoGolden.txt', totalQueriesNG,
                          tables_used, response, question,
                          original_query, generated_query,
                          summary=f"Total Queries: {totalQueriesNG}, Correct Answers: {correctAnsNG}, Incorrect Answers: {notCorrectAnsNG}")
