In [1]:
import sqlite3, json, warnings
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer_350M = AutoTokenizer.from_pretrained("../models/nsql-350M")
model_350M = AutoModelForCausalLM.from_pretrained("../models/nsql-350M")

# tokenizer_2B = AutoTokenizer.from_pretrained("../models/nsql-2B")
# model_2B = AutoModelForCausalLM.from_pretrained("../models/nsql-2B")

In [2]:
with open("../src/spider/table_database_map.json") as f:
    table_map_db = json.load(f)

In [3]:
df = pd.read_csv('../src/NSText2SQL/train_spider.csv')
df.head()

Unnamed: 0,Question,Table,SQL
0,"What are the first names, office locations of ...","CREATE TABLE course (\n crs_code text,\n ...","SELECT T2.emp_fname, T4.prof_office, T3.crs_de..."
1,Please show the songs that have result 'nomina...,"CREATE TABLE artist (\n artist_id number,\n...",SELECT T2.song FROM music_festival AS T1 JOIN ...
2,Which teams had more than 3 eliminations?,CREATE TABLE elimination (\n elimination_id...,SELECT team FROM elimination GROUP BY team HAV...
3,"Show the names of people, and dates and venues...","CREATE TABLE people (\n people_id number,\n...","SELECT T3.name, T2.date, T2.venue FROM debate_..."
4,Tell me the the date when the first claim was ...,CREATE TABLE settlements (\n settlement_id ...,SELECT date_claim_made FROM claims ORDER BY da...


In [4]:
def table_column_of_create_table(query):
    lines = query.splitlines()
    columns = []
    table_names = []

    # Look for "CREATE TABLE" and start capturing columns
    capture = False
    for line in lines:
        if "CREATE TABLE" in line:
            capture = True
            table_names.append(line.split()[-2].lower())
        elif line.strip().endswith(')') or line.strip().endswith(');'):
            capture = False
        elif capture:
            column_name = line.strip().split()[0]
            if column_name in ["CONSTRAINT", "PRIMARY"]: continue
            columns.append(column_name)
    return table_names, columns

In [5]:
def create_prompt(question, schema):
    full_prompt = ""
    full_prompt += f"{str(schema)}\n\n"
    full_prompt += "-- Using valid SQLite, answer the following questions for the tables provided above.\n\n"
    full_prompt += f"--{question}\n\nSELECT"
    return full_prompt


In [6]:
def pred_sql(prompt, model_name):

    if model_name == "nsql-350M":
        input_ids = tokenizer_350M(prompt, return_tensors="pt").input_ids
        generated_ids = model_350M.generate(input_ids, max_length=500)
        return tokenizer_350M.decode(generated_ids[0], skip_special_tokens=True).split('\n')[-1]
    # elif model_name == "nsql-2B":
    #     input_ids = tokenizer_2B(prompt, return_tensors="pt").input_ids
    #     generated_ids = model_2B.generate(input_ids, max_length=500)
    #     return tokenizer_2B.decode(generated_ids[0], skip_special_tokens=True).split('\n')[-1]

## Query by SQL

In [7]:
def query_db(sql_query, db_name):
    try:
        conn = sqlite3.connect(f'../src/spider/database/{db_name}/{db_name}.sqlite')
        cursor = conn.cursor()
    except:
        return "CANNOT CONNECT DATABASE"
    try:
        cursor.execute(sql_query)
        results = cursor.fetchall()
    except:
        return "CANNOT FETCHING DATA"
    conn.close()
    return results


In [9]:
df.shape

(6994, 3)

In [12]:
expect_query_results = []
nsql350M_query_results = []
nsql350M_query = []
nsql2B_query_results = []
nsql2B_query = []

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for i, row in df.iterrows():
        question = row['Question']
        print(question)
        schema = row['Table']
        expect_query = row['SQL']
        schema_tables = table_column_of_create_table(schema)[0]
        try:
            db_of_table = table_map_db[schema_tables[0].lower()]
        except KeyError:
            error_occur = "TABLE NOT MATCH"
            expect_query_results.append(error_occur)
            nsql350M_query.append(error_occur)
            nsql350M_query_results.append(error_occur)
            # nsql2B_query.append(error_occur)
            # nsql2B_query_results.append(error_occur)
            continue

        expect_result = query_db(expect_query, db_of_table)
        full_prompt = create_prompt(question, schema)
        # try:
        #     pred_sql_query_350M = pred_sql(full_prompt, "nsql-350M")
        # except:
        #     pred_sql_query_350M = "GEN QUERY ERROR"
        pred_sql_query_350M = "test"
        print(pred_sql_query_350M)
        pred_result_350M = query_db(pred_sql_query_350M, db_of_table)
        
        # try:
        #     pred_sql_query_2B = pred_sql(full_prompt, "nsql-2B")
        # except:
        #     pred_sql_query_2B = "GEN QUERY ERROR"
        # print(pred_sql_query_2B)
        # pred_result_2B = query_db(pred_sql_query_2B, db_of_table)

        expect_query_results.append(expect_result)
        nsql350M_query.append(pred_sql_query_350M)
        nsql350M_query_results.append(pred_result_350M)
        # nsql2B_query.append(pred_sql_query_2B)
        # nsql2B_query_results.append(pred_result_2B)
        print(i)

What are the first names, office locations of all lecturers who have taught some course?
test
0
Please show the songs that have result 'nominated' at music festivals.
test
1
Which teams had more than 3 eliminations?
test
2
Show the names of people, and dates and venues of debates they are on the negative side, ordered in ascending alphabetical order of name.
test
3
Tell me the the date when the first claim was made.
test
4
What are the speeds of the longest roller coaster?
test
5
Give me the maximum low temperature and average precipitation at the Amersham station.
test
6
What are the reigns and days held of all wrestlers?
test
7
What are the different names of the colleges involved in the tryout in alphabetical order?
Find the title and star rating of the movie that got the least rating star for each reviewer.
List roles that have more than one employee. List the role description and number of employees.
test
10
What are the vocal types used in song 'Le Pop'?
test
11
What are the name

In [13]:
df['nsql-350M-query'] = nsql350M_query
# df['nsql-2B-query'] = nsql2B_query
df['expect_result'] = expect_query_results
df['nsql-350M-result'] = nsql350M_query_results
# df['nsql-2B-result'] = nsql2B_query_results

df.to_csv("model-expirements.csv", index=False)
df.head()

Unnamed: 0,Question,Table,SQL,nsql-350M-query,expect_result,nsql-350M-result
0,"What are the first names, office locations of ...","CREATE TABLE course (\n crs_code text,\n ...","SELECT T2.emp_fname, T4.prof_office, T3.crs_de...",test,CANNOT FETCHING DATA,CANNOT FETCHING DATA
1,Please show the songs that have result 'nomina...,"CREATE TABLE artist (\n artist_id number,\n...",SELECT T2.song FROM music_festival AS T1 JOIN ...,test,CANNOT FETCHING DATA,CANNOT FETCHING DATA
2,Which teams had more than 3 eliminations?,CREATE TABLE elimination (\n elimination_id...,SELECT team FROM elimination GROUP BY team HAV...,test,"[(Team Batista,)]",CANNOT FETCHING DATA
3,"Show the names of people, and dates and venues...","CREATE TABLE people (\n people_id number,\n...","SELECT T3.name, T2.date, T2.venue FROM debate_...",test,CANNOT FETCHING DATA,CANNOT FETCHING DATA
4,Tell me the the date when the first claim was ...,CREATE TABLE settlements (\n settlement_id ...,SELECT date_claim_made FROM claims ORDER BY da...,test,CANNOT FETCHING DATA,CANNOT FETCHING DATA
