In [1]:
import sqlite3, json, warnings
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
tokenizer_350M = AutoTokenizer.from_pretrained("../models/nsql-350M")
model_350M = AutoModelForCausalLM.from_pretrained("../models/nsql-350M")

# tokenizer_2B = AutoTokenizer.from_pretrained("../models/nsql-2B")
# model_2B = AutoModelForCausalLM.from_pretrained("../models/nsql-2B")

In [3]:
with open("../src/spider/table_database_map.json") as f:
    table_map_db = json.load(f)

In [4]:
df = pd.read_csv('../src/NSText2SQL/train_spider.csv')
df.head()

Unnamed: 0,Question,Table,SQL
0,"What are the first names, office locations of ...","CREATE TABLE course (\n crs_code text,\n ...","SELECT T2.emp_fname, T4.prof_office, T3.crs_de..."
1,Please show the songs that have result 'nomina...,"CREATE TABLE artist (\n artist_id number,\n...",SELECT T2.song FROM music_festival AS T1 JOIN ...
2,Which teams had more than 3 eliminations?,CREATE TABLE elimination (\n elimination_id...,SELECT team FROM elimination GROUP BY team HAV...
3,"Show the names of people, and dates and venues...","CREATE TABLE people (\n people_id number,\n...","SELECT T3.name, T2.date, T2.venue FROM debate_..."
4,Tell me the the date when the first claim was ...,CREATE TABLE settlements (\n settlement_id ...,SELECT date_claim_made FROM claims ORDER BY da...


In [5]:
def table_column_of_create_table(query):
    lines = query.splitlines()
    columns = []
    table_names = []

    # Look for "CREATE TABLE" and start capturing columns
    capture = False
    for line in lines:
        if "CREATE TABLE" in line:
            capture = True
            table_names.append(line.split()[-2].lower())
        elif line.strip().endswith(')') or line.strip().endswith(');'):
            capture = False
        elif capture:
            column_name = line.strip().split()[0]
            if column_name in ["CONSTRAINT", "PRIMARY"]: continue
            columns.append(column_name)
    return table_names, columns

In [6]:
def create_prompt(question, schema):
    full_prompt = ""
    full_prompt += f"{str(schema)}\n\n"
    full_prompt += "-- Using valid SQLite, answer the following questions for the tables provided above.\n\n"
    full_prompt += f"--{question}\n\nSELECT"
    return full_prompt


In [7]:
def pred_sql(prompt, model_name):

    if model_name == "nsql-350M":
        input_ids = tokenizer_350M(prompt, return_tensors="pt").input_ids
        generated_ids = model_350M.generate(input_ids, max_length=500)
        return tokenizer_350M.decode(generated_ids[0], skip_special_tokens=True).split('\n')[-1]
    # elif model_name == "nsql-2B":
    #     input_ids = tokenizer_2B(prompt, return_tensors="pt").input_ids
    #     generated_ids = model_2B.generate(input_ids, max_length=500)
    #     return tokenizer_2B.decode(generated_ids[0], skip_special_tokens=True).split('\n')[-1]

## Query by SQL

In [2]:
def query_db(sql_query, db_name):
    try:
        conn = sqlite3.connect(f'../src/spider/database/{db_name}/{db_name}.sqlite')
        cursor = conn.cursor()
    except:
        return "CANNOT CONNECT DATABASE"
    try:
        cursor.execute(sql_query)
        results = cursor.fetchall()
    except:
        return "CANNOT FETCHING DATA"
    conn.close()
    return results


In [None]:
df.shape

In [None]:
# expect_query_results = []
# nsql350M_query_results = []
# nsql350M_query = []
# nsql2B_query_results = []
# nsql2B_query = []

# with warnings.catch_warnings():
#     warnings.simplefilter("ignore")
#     for i, row in df.iterrows():
#         question = row['Question']
#         print(question)
#         schema = row['Table']
#         expect_query = row['SQL']
#         schema_tables = table_column_of_create_table(schema)[0]
#         try:
#             db_of_table = table_map_db[schema_tables[0].lower()]
#         except KeyError:
#             error_occur = "TABLE NOT MATCH"
#             expect_query_results.append(error_occur)
#             nsql350M_query.append(error_occur)
#             nsql350M_query_results.append(error_occur)
#             # nsql2B_query.append(error_occur)
#             # nsql2B_query_results.append(error_occur)
#             continue

#         expect_result = query_db(expect_query, db_of_table)
#         full_prompt = create_prompt(question, schema)
#         # try:
#         #     pred_sql_query_350M = pred_sql(full_prompt, "nsql-350M")
#         # except:
#         #     pred_sql_query_350M = "GEN QUERY ERROR"
#         pred_sql_query_350M = "test"
#         print(pred_sql_query_350M)
#         pred_result_350M = query_db(pred_sql_query_350M, db_of_table)
        
#         # try:
#         #     pred_sql_query_2B = pred_sql(full_prompt, "nsql-2B")
#         # except:
#         #     pred_sql_query_2B = "GEN QUERY ERROR"
#         # print(pred_sql_query_2B)
#         # pred_result_2B = query_db(pred_sql_query_2B, db_of_table)

#         expect_query_results.append(expect_result)
#         nsql350M_query.append(pred_sql_query_350M)
#         nsql350M_query_results.append(pred_result_350M)
#         # nsql2B_query.append(pred_sql_query_2B)
#         # nsql2B_query_results.append(pred_result_2B)
#         print(i)

In [None]:
# df['nsql-350M-query'] = nsql350M_query
# # df['nsql-2B-query'] = nsql2B_query
# df['expect_result'] = expect_query_results
# df['nsql-350M-result'] = nsql350M_query_results
# # df['nsql-2B-result'] = nsql2B_query_results

# df.to_csv("model-expirements.csv", index=False)
# df.head()

In [None]:
df = pd.read_csv("../src/temp-model-expirements.csv")
exclude_values = ["CANNOT FETCHING DATA", "TABLE NOT MATCH", None]
filtered_df = df[~df['Expect-result'].isin(exclude_values)]
print(filtered_df.shape)
filtered_df.head()

In [None]:
filtered_df.to_excel("model-exp.xlsx", index=False)

## Predict SQL result by unseen spider dataset

In [None]:
seen_questions = df['Question'].to_list()
print("SPIDER for training",df.shape)
df.head()

In [None]:
from datasets import load_dataset

dataset = load_dataset("spider")
spider_df = pd.concat([pd.DataFrame(dataset['train']), pd.DataFrame(dataset['validation'])], ignore_index=True)[['question','query' ,'db_id']]
print(spider_df.shape)
spider_df.head()

In [None]:
unseen_df = spider_df[~spider_df['question'].isin(seen_questions)]
unseen_df = unseen_df.reset_index(drop=True)
unseen_df.shape

In [None]:
unseen_df.head()

In [None]:
import os

folder_path = "../src/spider/database"

db = dict()

if os.path.exists(folder_path) and os.path.isdir(folder_path):
    files = os.listdir(folder_path)
    for db_id in files:
        db_path = os.path.join(folder_path, db_id)
        sqlite_db = [os.path.join(db_path, sql) for sql in os.listdir(db_path) if ".sqlite" in sql]
        assert len(sqlite_db) == 1
        db[db_id] = sqlite_db[0]


In [None]:
def get_schema(sqlite_db_path):
    connection = sqlite3.connect(sqlite_db_path)
    cursor = connection.cursor()

    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    full_sql = ""
    for table in tables:
        table_name = table[0]
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = cursor.fetchall()

        sql = f"CREATE TABLE {table_name} ("
        for column in columns:
            column_name = column[1]
            column_datatype = column[2].lower()
            sql += f"{column_name} {column_datatype}, "
        sql = sql[:-2] + ");"
        full_sql += sql
    
    cursor.close()
    connection.close()
    return full_sql

In [None]:
# for db_id, database_path in db.items():
#     print(get_schema(database_path))
#     # exists_table.append(table)
#     print('---------------------------------')

In [None]:
questions = []
expect_querys = []
expect_query_results = []
nsql350M_query_results = []
nsql350M_query = []
nsql2B_query_results = []
nsql2B_query = []
predict_result = []

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for i, row in unseen_df.iterrows():
        question = row['question']
        schema = get_schema(db[row['db_id']])
        expect_query = row['query']
        print("EXPECT QUERY\n", expect_query)
        expect_result = query_db(expect_query, row['db_id'])
        full_prompt = create_prompt(question, schema)
        try:
            pred_sql_query_350M = pred_sql(full_prompt, "nsql-350M")
        except:
            pred_sql_query_350M = "GEN QUERY ERROR"
        print("NSQL-350M QUERY\n", pred_sql_query_350M)
        pred_result_350M = query_db(pred_sql_query_350M, row['db_id'])
        
        try:
            pred_sql_query_2B = pred_sql(full_prompt, "nsql-2B")
        except:
            pred_sql_query_2B = "GEN QUERY ERROR"
        print("NSQL-2B QUERY\n", pred_sql_query_2B)
        pred_result_2B = query_db(pred_sql_query_2B, row['db_id'])

        questions.append(question)
        expect_querys.append(expect_query)
        expect_query_results.append(expect_result)
        nsql350M_query.append(pred_sql_query_350M)
        nsql350M_query_results.append(pred_result_350M)
        nsql2B_query.append(pred_sql_query_2B)
        nsql2B_query_results.append(pred_result_2B)
        print(f"COMPLETE {i} FROM {unseen_df.shape[0]}")
        print()
        if i > 5: 
            data = {
                "Question": questions,
                "Expect-query": expect_querys,
                "NSQL-350M-query": nsql350M_query,
                "NSQL-2B-query": nsql2B_query,
                "Expect-result": expect_query_results,
                "NSQL-350M-result": nsql350M_query_results,
                "NSQL-2B-result": nsql2B_query_results
            }
            temp_df = pd.DataFrame(data)
            temp_df.to_csv("temp-model-expirements.csv", index=False)
            print("******** WRITED TEMP DATAFRAME ********")
            print()
            break
        if not i % 100:
            data = {
                "Question": questions,
                "Expect-query": expect_querys,
                "NSQL-350M-query": nsql350M_query,
                "NSQL-2B-query": nsql2B_query,
                "Expect-result": expect_query_results,
                "NSQL-350M-result": nsql350M_query_results,
                "NSQL-2B-result": nsql2B_query_results
            }
            temp_df = pd.DataFrame(data)
            temp_df.to_csv("temp-model-expirements.csv", index=False)
            print("******** WRITED TEMP DATAFRAME ********")
            print()


In [None]:
df['nsql-350M-query'] = nsql350M_query
df['nsql-2B-query'] = nsql2B_query
df['expect_result'] = expect_query_results
df['nsql-350M-result'] = nsql350M_query_results
df['nsql-2B-result'] = nsql2B_query_results

df.to_csv("model-unseen-expirements.csv", index=False)
df.head()

## PointX Keymatrix

In [18]:
import sqlparse, json
import pandas as pd

In [23]:
with open("../src/pointx/schemas/pointx_keymatrix_dly_columns_type.json") as f:
    pointx_keymatrix_schema = json.load(f)

pointx_keymatrix_columns = list(pointx_keymatrix_schema.keys())[1:]
pointx_keymatrix_columns[:5]

['month_id',
 'ntx_pointx_financial',
 'ntx_pointx_financial_out',
 'ncust_user',
 'ncust_pointx']

In [68]:
sql_extract_token_type = {
            sqlparse.sql.IdentifierList, sqlparse.sql.Where,
            sqlparse.sql.Having, sqlparse.sql.Comparison, sqlparse.sql.Function,
            sqlparse.sql.Parenthesis, sqlparse.sql.Operation, sqlparse.sql.Case
        }

def columns_from_query(sql_query:str, all_columns:list):
    # identifiers contain table name and column name
    if type(sql_query) == str:
        sql_query = sqlparse.parse(sql_query)[0]
    columns = []
    for token in sql_query:
        if isinstance(token, sqlparse.sql.Identifier) and token.get_real_name().lower() in all_columns:
            columns.append(token.get_real_name().lower())
        elif hasattr(token, "tokens"):
            columns.extend(columns_from_query(token.tokens, all_columns))
    return list(set(columns))

In [89]:
def query_pointx_db(sql_query):
    try:
        conn = sqlite3.connect(f'../src/pointx/database/pointx.db')
        cursor = conn.cursor()
    except:
        return "CANNOT CONNECT DATABASE"
    try:
        cursor.execute(sql_query)
        results = cursor.fetchall()
    except:
        return "CANNOT FETCHING DATA"
    conn.close()
    return results

In [69]:
pointx_keymatrix_pair = pd.read_excel("../src/pointx/PointX - NLQ training data set.xlsx", sheet_name="pointx_keymatrix")[['NLQ', 'NLQ with helper', 'SQL']].iloc[:39]
pointx_keymatrix_pair.head()

Unnamed: 0,NLQ,NLQ with helper,SQL
0,What is the total number of all financial tran...,What is the total number of ntx_pointx_financi...,"SELECT month_id, SUM(ntx_pointx_financial) FRO..."
1,What is the total amount of points generated b...,What is the total amount of amt_point_topup in...,SELECT SUM(amt_point_topup) FROM pointx_keymat...
2,What is the total amount of points generated b...,What is the total amount of amt_point_pay for ...,"SELECT month_id, SUM(amt_point_pay) FROM point..."
3,What is the average rate of released points fo...,What is the average rate_point_per_baht_pay?,SELECT AVG(rate_point_per_baht_pay) FROM point...
4,Can you determine the average number of custom...,Can you determine the average number of ncust_...,"SELECT month_id, AVG(ncust_visit) FROM pointx_..."


### Experiment of NSQL with F1 100%

In [91]:
def create_pointx_keymatrix_schema(list_cols:list):
    prompt = "CREATE TABLE pointx_keymatrix_dly ("
    for col in list_cols:
        prompt += f"{col} {pointx_keymatrix_schema[col]},"
    return prompt[:-1] + ")"

In [92]:
experiment_result = { 'Question' : [],
                      'Question with helper' : [],
                      'Expect SQL': [],
                      'Predict SQL': [],
                      'Predict SQL with helper': [],
                      'Expect result': [],
                      'Predict result': [],
                      'Predict result with helper': [],
    
}
for i, row in pointx_keymatrix_pair.iterrows():
    
    question = row['NLQ']
    question_helper = row['NLQ with helper']
    used_cols = columns_from_query(row['SQL'], pointx_keymatrix_columns)
    schema_prompt = create_pointx_keymatrix_schema(used_cols)
    full_prompt = create_prompt(question, schema_prompt)
    result = pred_sql(full_prompt, 'nsql-350M')

    full_prompt_helper = create_prompt(question_helper, schema_prompt)
    result_helper = pred_sql(full_prompt_helper, 'nsql-350M')

    experiment_result['Question'].append(question)
    experiment_result['Question with helper'].append(question_helper)
    experiment_result['Expect SQL'].append(row['SQL'])
    experiment_result['Predict SQL'].append(result)
    experiment_result['Predict SQL with helper'].append(result_helper)
    
    experiment_result['Expect result'].append(query_pointx_db(row['SQL']))
    experiment_result['Predict result'].append(query_pointx_db(result))
    experiment_result['Predict result with helper'].append(query_pointx_db(result_helper))

    print(question)
    print(used_cols)
    print("EXPECT SQL:",row['SQL'])
    print("RESULT:", result)
    print("RESULT WITH HELPER:", result_helper)
    print()

What is the total number of all financial transactions for each month?
['ntx_pointx_financial', 'month_id']
EXPECT SQL: SELECT month_id, SUM(ntx_pointx_financial) FROM pointx_keymatrix_dly  GROUP BY month_id; 
RESULT: SELECT COUNT(*), month_id FROM pointx_keymatrix_dly GROUP BY month_id;
RESULT WITH HELPER: SELECT COUNT(*) FROM pointx_keymatrix_dly;

What is the total amount of points generated by all top-up transactions in August 2022?
['month_id', 'amt_point_topup']
EXPECT SQL: SELECT SUM(amt_point_topup) FROM pointx_keymatrix_dly WHERE month_id = '2022-08'; 
RESULT: SELECT SUM(amt_point_topup) FROM pointx_keymatrix_dly WHERE month_id = "August 2022";
RESULT WITH HELPER: SELECT SUM(amt_point_topup) FROM pointx_keymatrix_dly WHERE month_id = "2022-08";

What is the total amount of points generated by all payment transactions for each month in 2022?
['amt_point_pay', 'month_id']
EXPECT SQL: SELECT month_id, SUM(amt_point_pay) FROM pointx_keymatrix_dly WHERE month_id LIKE '2022%' GROUP 

In [93]:
experiment_df = pd.DataFrame(experiment_result)
experiment_df.to_excel("experiment_pointx_keymatrix_f1_100.xlsx", index=False)
experiment_df.head()

Unnamed: 0,Question,Question with helper,Expect SQL,Predict SQL,Predict SQL with helper,Expect result,Predict result,Predict result with helper
0,What is the total number of all financial tran...,What is the total number of ntx_pointx_financi...,"SELECT month_id, SUM(ntx_pointx_financial) FRO...","SELECT COUNT(*), month_id FROM pointx_keymatri...",SELECT COUNT(*) FROM pointx_keymatrix_dly;,"[(2022-07, 447), (2022-08, 259)]","[(31, 2022-07), (19, 2022-08)]","[(50,)]"
1,What is the total amount of points generated b...,What is the total amount of amt_point_topup in...,SELECT SUM(amt_point_topup) FROM pointx_keymat...,SELECT SUM(amt_point_topup) FROM pointx_keymat...,SELECT SUM(amt_point_topup) FROM pointx_keymat...,"[(178992.0,)]","[(None,)]","[(178992.0,)]"
2,What is the total amount of points generated b...,What is the total amount of amt_point_pay for ...,"SELECT month_id, SUM(amt_point_pay) FROM point...",SELECT SUM(amt_point_pay) FROM pointx_keymatri...,SELECT COUNT(*) FROM pointx_keymatrix_dly WHER...,"[(2022-07, 30075.0), (2022-08, 30045.0)]","[(None,)]","[(0,)]"
3,What is the average rate of released points fo...,What is the average rate_point_per_baht_pay?,SELECT AVG(rate_point_per_baht_pay) FROM point...,SELECT AVG(rate_point_per_baht_pay) FROM point...,SELECT AVG(rate_point_per_baht_pay) FROM point...,"[(13.825421897546896,)]","[(13.825421897546896,)]","[(13.825421897546896,)]"
4,Can you determine the average number of custom...,Can you determine the average number of ncust_...,"SELECT month_id, AVG(ncust_visit) FROM pointx_...",SELECT AVG(ncust_visit) FROM pointx_keymatrix_...,"SELECT COUNT(*), month_id FROM pointx_keymatri...","[(2022-07, 47.61290322580645), (2022-08, 55.15...","[(50.48,)]","[(31, 2022-07), (19, 2022-08)]"


In [104]:
print(query_pointx_db("SELECT month_id, SUM(revenue_mobile_app) FROM pointx_keymatrix_dly GROUP BY month_id;"))

[('2022-07',), ('2022-08',)]


In [6]:
query_db("SELECT * FROM shop LIMIT 3", "coffee_shop")

[(1, '1200 Main Street', '13', 42.0, '2010'),
 (2, '1111 Main Street', '19', 38.0, '2008'),
 (3, '1330 Baltimore Street', '42', 36.0, '2010')]