In [1]:
import sqlite3
import numpy as np
import pandas as pd
import tqdm
import glob
import os

In [None]:
path = '..'
os.chdir(path)
print(os.getcwd())

In [3]:
def find_db_schema_filename(db_id):
    filename = glob.glob(f'spider_data/spider_data/database/{db_id}/*.sql')[0]
    return filename

In [5]:
def extract_schema(db_id):
    with open(find_db_schema_filename(db_id), 'r', encoding="utf-8") as file:
        clauses = file.read().split(';')
        clauses = [clause[clause.lower().find("create table"):] for clause in clauses if "create table" in clause.lower()]
    schema = "\n".join(clauses).strip()
    return schema

In [7]:
# model_name = 'base'
# model_name = 'LR1e4'
model_name = 'LR5e5'

In [9]:
answers = np.load(f'answers/llama3-8b-{model_name}-T2S-answers.npy')

In [11]:
answers[0]

'SELECT count(*) FROM singer'

In [13]:
names = ['db_id', 'query', 'question']
test_df = pd.read_json('spider_data/spider_data/dev.json')
test_df = test_df[names]
test_df.head()

Unnamed: 0,db_id,query,question
0,concert_singer,SELECT count(*) FROM singer,How many singers do we have?
1,concert_singer,SELECT count(*) FROM singer,What is the total number of singers?
2,concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere..."
3,concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev..."
4,concert_singer,"SELECT avg(age) , min(age) , max(age) FROM s...","What is the average, minimum, and maximum age ..."


In [15]:
test_missing_schema = []
for i in tqdm.tqdm(range(len(test_df))):
    db_id = test_df.iloc[i, 0]
    query = test_df.iloc[i, 1]
    question = test_df.iloc[i, 2]
    
    try:
        schema = extract_schema(db_id)
    except IndexError:
        if(db_id not in test_missing_schema):
            test_missing_schema.append(db_id)

100%|████████████████████████████████████████████████████████████████████████████| 1034/1034 [00:00<00:00, 2842.52it/s]


In [16]:
test_missing_schema

['voter_1', 'world_1']

In [19]:
test_df = test_df[~test_df['db_id'].isin(test_missing_schema)]
test_df.head()

Unnamed: 0,db_id,query,question
0,concert_singer,SELECT count(*) FROM singer,How many singers do we have?
1,concert_singer,SELECT count(*) FROM singer,What is the total number of singers?
2,concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere..."
3,concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev..."
4,concert_singer,"SELECT avg(age) , min(age) , max(age) FROM s...","What is the average, minimum, and maximum age ..."


In [21]:
test_df = test_df.iloc[:len(answers),:]
test_df['predicted_query'] = answers

In [23]:
test_df.head()

Unnamed: 0,db_id,query,question,predicted_query
0,concert_singer,SELECT count(*) FROM singer,How many singers do we have?,SELECT count(*) FROM singer
1,concert_singer,SELECT count(*) FROM singer,What is the total number of singers?,SELECT COUNT(*) FROM singer
2,concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere...","SELECT singer_Name, singer_Country, singer_Age..."
3,concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev...","SELECT Singer_Name, Country, Age \nFROM singer..."
4,concert_singer,"SELECT avg(age) , min(age) , max(age) FROM s...","What is the average, minimum, and maximum age ...","SELECT AVG(T1.Age), MIN(T1.Age), MAX(T1.Age)\n..."


In [25]:
test_df.to_csv(f'answers/dev_queries_LLaMa-3-8b-{model_name}.csv', index=False)

In [128]:
test_df = pd.read_csv(f'answers/dev_queries_LLaMa-3-8b-{model_name}-clean.csv', quotechar='"')
test_df.head()

Unnamed: 0,db_id,query,question,predicted_query
0,concert_singer,SELECT count(*) FROM singer,How many singers do we have?,SELECT COUNT(*) FROM singer;
1,concert_singer,SELECT count(*) FROM singer,What is the total number of singers?,SELECT COUNT(*) FROM singer;
2,concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere...","\r\n\r\nSELECT s.Name, s.Country, s.Age\r\nFRO..."
3,concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev...","\r\n\r\nSELECT s.Name, s.Country, s.Age\r\nFRO..."
4,concert_singer,"SELECT avg(age) , min(age) , max(age) FROM s...","What is the average, minimum, and maximum age ...","\r\n\r\nSELECT AVG(Age) AS Average_Age, MIN(Ag..."


In [130]:
success_query = 0

for i in range(len(test_df)):
    db_id = test_df.iloc[i, 0]
    predicted_query = test_df.iloc[i, 3]
    
    conn = sqlite3.connect(f"spider_data/spider_data/database/{db_id}/{db_id}.sqlite")
    cursor = conn.cursor()

    try:
        cursor.execute(predicted_query)
        success_query += 1
    except:
        # print(predicted_query)
        pass
        
    rows = cursor.fetchall()

    conn.close()

print(f'Number of successful queries: {success_query}\nNumber of failed queries: {len(test_df) - success_query}')

Number of successful queries: 785
Number of failed queries: 113
