In [1]:
import sqlite3, json, warnings
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
tokenizer_350M = AutoTokenizer.from_pretrained("../models/nsql-350M")
model_350M = AutoModelForCausalLM.from_pretrained("../models/nsql-350M")

tokenizer_2B = AutoTokenizer.from_pretrained("../models/nsql-2B")
model_2B = AutoModelForCausalLM.from_pretrained("../models/nsql-2B")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
with open("../src/spider/table_database_map.json") as f:
    table_map_db = json.load(f)

In [4]:
df = pd.read_csv('../src/NSText2SQL/train_spider.csv')
df.head()

Unnamed: 0,Question,Table,SQL
0,"What are the first names, office locations of ...","CREATE TABLE course (\n crs_code text,\n ...","SELECT T2.emp_fname, T4.prof_office, T3.crs_de..."
1,Please show the songs that have result 'nomina...,"CREATE TABLE artist (\n artist_id number,\n...",SELECT T2.song FROM music_festival AS T1 JOIN ...
2,Which teams had more than 3 eliminations?,CREATE TABLE elimination (\n elimination_id...,SELECT team FROM elimination GROUP BY team HAV...
3,"Show the names of people, and dates and venues...","CREATE TABLE people (\n people_id number,\n...","SELECT T3.name, T2.date, T2.venue FROM debate_..."
4,Tell me the the date when the first claim was ...,CREATE TABLE settlements (\n settlement_id ...,SELECT date_claim_made FROM claims ORDER BY da...


In [5]:
def table_column_of_create_table(query):
    lines = query.splitlines()
    columns = []
    table_names = []

    # Look for "CREATE TABLE" and start capturing columns
    capture = False
    for line in lines:
        if "CREATE TABLE" in line:
            capture = True
            table_names.append(line.split()[-2].lower())
        elif line.strip().endswith(')') or line.strip().endswith(');'):
            capture = False
        elif capture:
            column_name = line.strip().split()[0]
            if column_name in ["CONSTRAINT", "PRIMARY"]: continue
            columns.append(column_name)
    return table_names, columns

In [6]:
def create_prompt(question, schema):
    full_prompt = ""
    full_prompt += f"{str(schema)}\n\n"
    full_prompt += "-- Using valid SQLite, answer the following questions for the tables provided above.\n\n"
    full_prompt += f"--{question}\n\nSELECT"
    return full_prompt


In [7]:
def pred_sql(prompt, model_name):

    if model_name == "nsql-350M":
        input_ids = tokenizer_350M(prompt, return_tensors="pt").input_ids
        generated_ids = model_350M.generate(input_ids, max_length=500)
        return tokenizer_350M.decode(generated_ids[0], skip_special_tokens=True).split('\n')[-1]
    elif model_name == "nsql-2B":
        input_ids = tokenizer_2B(prompt, return_tensors="pt").input_ids
        generated_ids = model_2B.generate(input_ids, max_length=500)
        return tokenizer_2B.decode(generated_ids[0], skip_special_tokens=True).split('\n')[-1]

## Query by SQL

In [8]:
def query_db(sql_query, db_name):
    try:
        conn = sqlite3.connect(f'../src/spider/database/{db_name}/{db_name}.sqlite')
        cursor = conn.cursor()
    except:
        return "CANNOT CONNECT DATABASE"
    try:
        cursor.execute(sql_query)
        results = cursor.fetchall()
    except:
        return "CANNOT FETCHING DATA"
    conn.close()
    return results


In [9]:
df.shape

(6994, 3)

In [10]:
# expect_query_results = []
# nsql350M_query_results = []
# nsql350M_query = []
# nsql2B_query_results = []
# nsql2B_query = []

# with warnings.catch_warnings():
#     warnings.simplefilter("ignore")
#     for i, row in df.iterrows():
#         question = row['Question']
#         print(question)
#         schema = row['Table']
#         expect_query = row['SQL']
#         schema_tables = table_column_of_create_table(schema)[0]
#         try:
#             db_of_table = table_map_db[schema_tables[0].lower()]
#         except KeyError:
#             error_occur = "TABLE NOT MATCH"
#             expect_query_results.append(error_occur)
#             nsql350M_query.append(error_occur)
#             nsql350M_query_results.append(error_occur)
#             # nsql2B_query.append(error_occur)
#             # nsql2B_query_results.append(error_occur)
#             continue

#         expect_result = query_db(expect_query, db_of_table)
#         full_prompt = create_prompt(question, schema)
#         # try:
#         #     pred_sql_query_350M = pred_sql(full_prompt, "nsql-350M")
#         # except:
#         #     pred_sql_query_350M = "GEN QUERY ERROR"
#         pred_sql_query_350M = "test"
#         print(pred_sql_query_350M)
#         pred_result_350M = query_db(pred_sql_query_350M, db_of_table)
        
#         # try:
#         #     pred_sql_query_2B = pred_sql(full_prompt, "nsql-2B")
#         # except:
#         #     pred_sql_query_2B = "GEN QUERY ERROR"
#         # print(pred_sql_query_2B)
#         # pred_result_2B = query_db(pred_sql_query_2B, db_of_table)

#         expect_query_results.append(expect_result)
#         nsql350M_query.append(pred_sql_query_350M)
#         nsql350M_query_results.append(pred_result_350M)
#         # nsql2B_query.append(pred_sql_query_2B)
#         # nsql2B_query_results.append(pred_result_2B)
#         print(i)

In [11]:
# df['nsql-350M-query'] = nsql350M_query
# # df['nsql-2B-query'] = nsql2B_query
# df['expect_result'] = expect_query_results
# df['nsql-350M-result'] = nsql350M_query_results
# # df['nsql-2B-result'] = nsql2B_query_results

# df.to_csv("model-expirements.csv", index=False)
# df.head()

In [12]:
df = pd.read_csv("../src/temp-model-expirements.csv")
exclude_values = ["CANNOT FETCHING DATA", "TABLE NOT MATCH", None]
filtered_df = df[~df['Expect-result'].isin(exclude_values)]
print(filtered_df.shape)
filtered_df.head()

(2695, 7)


Unnamed: 0,Question,Expect-query,NSQL-350M-query,NSQL-2B-query,Expect-result,NSQL-350M-result,NSQL-2B-result
2,Which teams had more than 3 eliminations?,SELECT team FROM elimination GROUP BY team HAV...,SELECT team FROM elimination GROUP BY team HAV...,SELECT team FROM elimination GROUP BY team HAV...,"[('Team Batista',)]","[('Team Batista',)]","[('Team Batista',)]"
7,What are the reigns and days held of all wrest...,"SELECT reign, days_held FROM wrestler","SELECT reign, days_held FROM wrestler;","SELECT reign, days_held FROM wrestler;","[('1', '344'), ('1', '113'), ('1', '1285'), ('...","[('1', '344'), ('1', '113'), ('1', '1285'), ('...","[('1', '344'), ('1', '113'), ('1', '1285'), ('..."
10,List roles that have more than one employee. L...,"SELECT roles.role_description, COUNT(employees...","SELECT T1.role_description, COUNT(*) FROM empl...","SELECT T1.role_description, COUNT(*) FROM role...","[('Editor', 2), ('Photo', 2)]",CANNOT FETCHING DATA,"[('Editor', 2), ('Photo', 2)]"
11,What are the vocal types used in song 'Le Pop'?,SELECT type FROM vocals AS T1 JOIN songs AS T2...,SELECT type FROM vocals AS T1 JOIN songs AS T2...,SELECT type FROM vocals AS T1 JOIN songs AS T2...,[],[],[]
13,How many faculty members did the university th...,SELECT T2.faculty FROM campuses AS T1 JOIN fac...,SELECT T2.faculty FROM campuses AS T1 JOIN fac...,SELECT T1.faculty FROM faculty AS T1 JOIN degr...,"[(1555.7,)]","[(1555.7,)]","[(1555.7,)]"


In [13]:
filtered_df.to_excel("model-exp.xlsx", index=False)

## Predict SQL result by unseen spider dataset

In [14]:
seen_questions = df['Question'].to_list()
print("SPIDER for training",df.shape)
df.head()

SPIDER for training (5201, 7)


Unnamed: 0,Question,Expect-query,NSQL-350M-query,NSQL-2B-query,Expect-result,NSQL-350M-result,NSQL-2B-result
0,"What are the first names, office locations of ...","SELECT T2.emp_fname, T4.prof_office, T3.crs_de...","SELECT T1.emp_fname, T2.prof_office FROM emplo...","SELECT T2.emp_fname, T1.prof_office FROM profe...",CANNOT FETCHING DATA,CANNOT FETCHING DATA,CANNOT FETCHING DATA
1,Please show the songs that have result 'nomina...,SELECT T2.song FROM music_festival AS T1 JOIN ...,SELECT song FROM music_festival WHERE result =...,SELECT T2.song FROM music_festival AS T1 JOIN ...,CANNOT FETCHING DATA,CANNOT FETCHING DATA,CANNOT FETCHING DATA
2,Which teams had more than 3 eliminations?,SELECT team FROM elimination GROUP BY team HAV...,SELECT team FROM elimination GROUP BY team HAV...,SELECT team FROM elimination GROUP BY team HAV...,"[('Team Batista',)]","[('Team Batista',)]","[('Team Batista',)]"
3,"Show the names of people, and dates and venues...","SELECT T3.name, T2.date, T2.venue FROM debate_...","SELECT T1.name, T2.date, T3.venue FROM debate_...","SELECT T2.name, T2.date, T2.venue FROM debate_...",CANNOT FETCHING DATA,CANNOT FETCHING DATA,CANNOT FETCHING DATA
4,Tell me the the date when the first claim was ...,SELECT date_claim_made FROM claims ORDER BY da...,SELECT date_claim_made FROM claims ORDER BY da...,SELECT date_claim_made FROM claims ORDER BY da...,CANNOT FETCHING DATA,CANNOT FETCHING DATA,CANNOT FETCHING DATA


In [15]:
from datasets import load_dataset

dataset = load_dataset("spider")
spider_df = pd.concat([pd.DataFrame(dataset['train']), pd.DataFrame(dataset['validation'])], ignore_index=True)[['question','query' ,'db_id']]
print(spider_df.shape)
spider_df.head()

Found cached dataset spider (/Users/thanawatthongpia/.cache/huggingface/datasets/spider/spider/1.0.0/4e5143d825a3895451569c8b9b55432b91a4bc2d04d390376c950837f4680daa)


  0%|          | 0/2 [00:00<?, ?it/s]

(8034, 3)


Unnamed: 0,question,query,db_id
0,How many heads of the departments are older th...,SELECT count(*) FROM head WHERE age > 56,department_management
1,"List the name, born state and age of the heads...","SELECT name , born_state , age FROM head ORD...",department_management
2,"List the creation year, name and budget of eac...","SELECT creation , name , budget_in_billions ...",department_management
3,What are the maximum and minimum budget of the...,"SELECT max(budget_in_billions) , min(budget_i...",department_management
4,What is the average number of employees of the...,SELECT avg(num_employees) FROM department WHER...,department_management


In [16]:
unseen_df = spider_df[~spider_df['question'].isin(seen_questions)]
unseen_df = unseen_df.reset_index(drop=True)
unseen_df.shape

(3308, 3)

In [17]:
unseen_df.head()

Unnamed: 0,question,query,db_id
0,"List the name, born state and age of the heads...","SELECT name , born_state , age FROM head ORD...",department_management
1,"List the creation year, name and budget of eac...","SELECT creation , name , budget_in_billions ...",department_management
2,What are the distinct creation years of the de...,SELECT DISTINCT T1.creation FROM department AS...,department_management
3,In which year were most departments established?,SELECT creation FROM department GROUP BY creat...,department_management
4,Show the name and number of employees for the ...,"SELECT T1.name , T1.num_employees FROM depart...",department_management


In [18]:
import os

folder_path = "../src/spider/database"

db = dict()

if os.path.exists(folder_path) and os.path.isdir(folder_path):
    files = os.listdir(folder_path)
    for db_id in files:
        db_path = os.path.join(folder_path, db_id)
        sqlite_db = [os.path.join(db_path, sql) for sql in os.listdir(db_path) if ".sqlite" in sql]
        assert len(sqlite_db) == 1
        db[db_id] = sqlite_db[0]


In [19]:
def get_schema(sqlite_db_path):
    connection = sqlite3.connect(sqlite_db_path)
    cursor = connection.cursor()

    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    full_sql = ""
    for table in tables:
        table_name = table[0]
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = cursor.fetchall()

        sql = f"CREATE TABLE {table_name} ("
        for column in columns:
            column_name = column[1]
            column_datatype = column[2].lower()
            sql += f"{column_name} {column_datatype}, "
        sql = sql[:-2] + ");"
        full_sql += sql
    
    cursor.close()
    connection.close()
    return full_sql

In [20]:
# for db_id, database_path in db.items():
#     print(get_schema(database_path))
#     # exists_table.append(table)
#     print('---------------------------------')

In [21]:
questions = []
expect_querys = []
expect_query_results = []
nsql350M_query_results = []
nsql350M_query = []
nsql2B_query_results = []
nsql2B_query = []
predict_result = []

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for i, row in unseen_df.iterrows():
        question = row['question']
        schema = get_schema(db[row['db_id']])
        expect_query = row['query']
        print("EXPECT QUERY\n", expect_query)
        expect_result = query_db(expect_query, row['db_id'])
        full_prompt = create_prompt(question, schema)
        try:
            pred_sql_query_350M = pred_sql(full_prompt, "nsql-350M")
        except:
            pred_sql_query_350M = "GEN QUERY ERROR"
        print("NSQL-350M QUERY\n", pred_sql_query_350M)
        pred_result_350M = query_db(pred_sql_query_350M, row['db_id'])
        
        try:
            pred_sql_query_2B = pred_sql(full_prompt, "nsql-2B")
        except:
            pred_sql_query_2B = "GEN QUERY ERROR"
        print("NSQL-2B QUERY\n", pred_sql_query_2B)
        pred_result_2B = query_db(pred_sql_query_2B, row['db_id'])

        questions.append(question)
        expect_querys.append(expect_query)
        expect_query_results.append(expect_result)
        nsql350M_query.append(pred_sql_query_350M)
        nsql350M_query_results.append(pred_result_350M)
        nsql2B_query.append(pred_sql_query_2B)
        nsql2B_query_results.append(pred_result_2B)
        print(f"COMPLETE {i} FROM {unseen_df.shape[0]}")
        print()
        if i > 5: 
            data = {
                "Question": questions,
                "Expect-query": expect_querys,
                "NSQL-350M-query": nsql350M_query,
                "NSQL-2B-query": nsql2B_query,
                "Expect-result": expect_query_results,
                "NSQL-350M-result": nsql350M_query_results,
                "NSQL-2B-result": nsql2B_query_results
            }
            temp_df = pd.DataFrame(data)
            temp_df.to_csv("temp-model-expirements.csv", index=False)
            print("******** WRITED TEMP DATAFRAME ********")
            print()
            break
        if not i % 100:
            data = {
                "Question": questions,
                "Expect-query": expect_querys,
                "NSQL-350M-query": nsql350M_query,
                "NSQL-2B-query": nsql2B_query,
                "Expect-result": expect_query_results,
                "NSQL-350M-result": nsql350M_query_results,
                "NSQL-2B-result": nsql2B_query_results
            }
            temp_df = pd.DataFrame(data)
            temp_df.to_csv("temp-model-expirements.csv", index=False)
            print("******** WRITED TEMP DATAFRAME ********")
            print()


EXPECT QUERY
 SELECT name ,  born_state ,  age FROM head ORDER BY age
NSQL-350M QUERY
 SELECT name, born_state, age FROM head ORDER BY age;
NSQL-2B QUERY
 GEN QUERY ERROR
COMPLETE 0 FROM 3308

******** WRITED TEMP DATAFRAME ********

EXPECT QUERY
 SELECT creation ,  name ,  budget_in_billions FROM department
NSQL-350M QUERY
 SELECT Creation, Name, Budget_in_Billions FROM department;


In [None]:
df['nsql-350M-query'] = nsql350M_query
df['nsql-2B-query'] = nsql2B_query
df['expect_result'] = expect_query_results
df['nsql-350M-result'] = nsql350M_query_results
df['nsql-2B-result'] = nsql2B_query_results

df.to_csv("model-unseen-expirements.csv", index=False)
df.head()