In [1]:
import sqlite3, os, json, sqlparse, re
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer, util
from nltk.stem import WordNetLemmatizer
from sql_metadata import Parser
import matplotlib.pyplot as plt

## Get the SQLite Database

In [37]:
folder_path = "src/spider/database"
select_db = ['musical',
             'farm', 
             'hospital_1', 
             'tvshow', 
             'cinema', 
             'restaurants', 
             'company_employee', 
             'company_1', 
             'company_offic', 
             'singer', 
             'coffee_shop']

db = []

if os.path.exists(folder_path) and os.path.isdir(folder_path):
    files = os.listdir(folder_path)
    for file in files:
        # if file in select_db:
        db_path = os.path.join(folder_path, file)
        sqlite_db = [os.path.join(db_path, sql) for sql in os.listdir(db_path) if ".sqlite" in sql]
        db.append(*sqlite_db)

db[:5]

['src/spider/database/browser_web/browser_web.sqlite',
 'src/spider/database/musical/musical.sqlite',
 'src/spider/database/farm/farm.sqlite',
 'src/spider/database/voter_1/voter_1.sqlite',
 'src/spider/database/game_injury/game_injury.sqlite']

In [38]:
def get_schema(sqlite_db):
    connection = sqlite3.connect(sqlite_db)
    cursor = connection.cursor()

    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()

    for table in tables:
        table_name = table[0]
        print(f"Table: {table_name}")

        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = cursor.fetchall()

        for column in columns:
            column_name = column[1]
            print(f"  Column: {column_name}")

        print()
    
    cursor.close()
    connection.close()

In [39]:
for table in db[:2]:
    # if table in exists_table : continue
    get_schema(table)
    # exists_table.append(table)
    print('---------------------------------')

Table: Web_client_accelerator
  Column: id
  Column: name
  Column: Operating_system
  Column: Client
  Column: Connection

Table: browser
  Column: id
  Column: name
  Column: market_share

Table: accelerator_compatible_browser
  Column: accelerator_id
  Column: browser_id
  Column: compatible_since_year

---------------------------------
Table: musical
  Column: Musical_ID
  Column: Name
  Column: Year
  Column: Award
  Column: Category
  Column: Nominee
  Column: Result

Table: actor
  Column: Actor_ID
  Column: Name
  Column: Musical_ID
  Column: Character
  Column: Duration
  Column: age

---------------------------------


## Embedding description of tables and columns

In [2]:
src_folder = "src"
schema_description_file = "mockup_schema_description.json"
with open(os.path.join(src_folder, schema_description_file)) as f:
    dbs = json.load(f)
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
lemmanizer = WordNetLemmatizer()

In [9]:
# description_emb = []

# for db in dbs:
#     schema_emb = {}
#     table_name = db['table']
#     table_description = db['description']
#     schema_emb[table_name] = model.encode(table_description).tolist()
#     columns = list(db['columns'].keys())
#     for col in columns:
#         column_description = db['columns'][col]
#         schema_emb[col] = model.encode(column_description).tolist()
#     description_emb.append(schema_emb)

# schema_vector_file = "mockup_schema_description_vector.json"
# with open(os.path.join(src_folder, schema_vector_file), "w") as f:
#     json.dump(description_emb,f)

In [3]:
schema_vector_file = "mockup_schema_description_vector.json"
with open(os.path.join(src_folder, schema_vector_file)) as f:
    schema_vector = json.load(f)

## Filtering Columns by Question

Question --> Similarity description-base --> String matching (score weight up to match condition) --> Column

### String Matching

In [4]:
# string mathing , max_score = weight of string match

def column_from_question(question,used_table_col = {}, default_score=0.6):
    # question_tokens = [token.lower() for token in tokenizer.tokenize(question)]
    question_tokens = [lemmanizer.lemmatize(token.lower()) for token in question.split()]
    # print(question_tokens)
    for table in schema_vector:
        max_score = default_score
        for token in question_tokens:
            cols = [ key.lower() for key in table.keys()]
            table_name = cols.pop(0)
            if token == table_name: 
                max_score = 1.0
                # plus the score of columns in exact match with table name
                if used_table_col.get(token) is not None: 
                    for key in used_table_col[token]: 
                        used_table_col[token].update({key: used_table_col[token][key] + 0.1})
            # exact match table and column
            if token in cols: 
                used_table_col.setdefault(table_name, {}).update({token : max_score})

            

    return used_table_col

### Similarity description-base score

In [5]:
# filter_table : filter by table before 
def filter_tables_by_description(question, column_threshold = 0.4, table_threshold = 0.2, filter_tables = True):
    question_emb = model.encode(question)
    used_schema = {}
    for i in range(len(schema_vector)):
        table_name = list(schema_vector[i].keys())[0]

        table_description_vector = schema_vector[i][table_name]
        if filter_tables and util.cos_sim(table_description_vector, question_emb) < table_threshold: continue
        
        used_col = {}
        for col, vec in schema_vector[i].items():
            if col == table_name: continue
            score = round(float(util.cos_sim(vec, question_emb)),2)
            if score > column_threshold:
                # column_description = [dbs[i]['columns'][col] for i in range(len(dbs)) if dbs[i]['table'] == table_name][0]
                # print(f"{table_name} - {col} : {score}\nDescription : {column_description}\n")
                used_col.update({col: score})
        if len(used_col) > 0: used_schema[table_name] = used_col
    return used_schema

### Question --> Similarity description-base --> String matching (score weight up to match condition) --> Column

In [6]:
question = "Who direct the mickey mouse"

In [7]:
selected_table_column = filter_tables_by_description(question, column_threshold = 0.3, filter_tables = False)
selected_table_column = column_from_question(question, used_table_col=selected_table_column)
selected_table_column

{'Cartoon': {'id': 0.31,
  'Title': 0.35,
  'Directed_by': 0.54,
  'Written_by': 0.44,
  'Original_air_date': 0.32,
  'Production_code': 0.39},
 'film': {'Directed_by': 0.36}}

### Generate full prompt (many tables)

In [8]:
full_sql = ""
for table in selected_table_column:
    sql = f"CREATE TABLE {table} ("
    for column in selected_table_column.get(table).keys():
        sql += f'"{column}" datatype, '

    sql = sql[:-2] + " )\n\n"
    full_sql += sql

print(full_sql)

CREATE TABLE Cartoon ("id" datatype, "Title" datatype, "Directed_by" datatype, "Written_by" datatype, "Original_air_date" datatype, "Production_code" datatype )

CREATE TABLE film ("Directed_by" datatype )




## Filter Spider train dataset for expirement testing

In [20]:
# spider_sql = []
# df_data = {
#     'Question' : [],
#     'Table' : [],
#     'SQL' : []
# }

# with open("src/NSText2SQL/train.jsonl") as f:
#     for line in f:
#         data = json.loads(line)
#         if data['source'] == 'spider': 
#             spider_sql.append(data)
#             df_data['Question'].append(data['instruction'].split('--')[-1].strip())
#             df_data['Table'].append(data['instruction'].split('--')[0].strip())
#             df_data['SQL'].append(data['output'])

# df = pd.DataFrame(df_data)
# df.to_csv('src/NSText2SQL/train_spider.csv', index=False)
# df.head()

In [9]:
df = pd.read_csv('src/NSText2SQL/train_spider.csv')
print(df.shape)
df.head()

(6994, 3)


Unnamed: 0,Question,Table,SQL
0,"What are the first names, office locations of ...","CREATE TABLE course (\n crs_code text,\n ...","SELECT T2.emp_fname, T4.prof_office, T3.crs_de..."
1,Please show the songs that have result 'nomina...,"CREATE TABLE artist (\n artist_id number,\n...",SELECT T2.song FROM music_festival AS T1 JOIN ...
2,Which teams had more than 3 eliminations?,CREATE TABLE elimination (\n elimination_id...,SELECT team FROM elimination GROUP BY team HAV...
3,"Show the names of people, and dates and venues...","CREATE TABLE people (\n people_id number,\n...","SELECT T3.name, T2.date, T2.venue FROM debate_..."
4,Tell me the the date when the first claim was ...,CREATE TABLE settlements (\n settlement_id ...,SELECT date_claim_made FROM claims ORDER BY da...


## Get the column and table name from SQL query

by String matching, SQLParse library

In [10]:
sql_extract_token_type = {
            sqlparse.sql.IdentifierList, sqlparse.sql.Where,
            sqlparse.sql.Having, sqlparse.sql.Comparison, sqlparse.sql.Function,
            sqlparse.sql.Parenthesis, sqlparse.sql.Operation, sqlparse.sql.Case
        }

def columns_from_query(sql_query):
    # identifiers contain table name and column name
    if type(sql_query) == str:
        sql_query = sqlparse.parse(sql_query)[0]
    columns = []
    for token in sql_query:
        if isinstance(token, sqlparse.sql.Identifier):
            columns.append(token.get_real_name())
        elif hasattr(token, "tokens"):
            columns.extend(columns_from_query(token.tokens))
    return columns

def columns_by_split(sql_query:str, all_columns:list):
    columns = []
    for token in sql_query.split():
        if token[-1] == ",": token = token[:-1]
        if token in all_columns:
            columns.append(token)
    return columns

# Split the SQL query into lines
def table_column_of_create_table(query):
    lines = query.splitlines()

    # Initialize a list to store column names
    columns = []
    table_names = []

    # Look for "CREATE TABLE" and start capturing columns
    capture = False

    for line in lines:
        if "CREATE TABLE" in line:
            capture = True
            table_names.append(line.split()[-2].lower())
        elif line.strip().endswith(')') or line.strip().endswith(');'):
            capture = False
        elif capture:
            column_name = line.strip().split()[0]
            if column_name in ["CONSTRAINT", "PRIMARY"]: continue
            columns.append(column_name)

    # print("Table Name:", table_names)
    # print("Columns:", columns)
    return table_names, columns

### Map table to db

In [11]:
spider_path = 'src/spider/database'
map_table_db = {}

for folder in os.listdir(spider_path):
    schema_path = os.path.join(spider_path, folder, 'schema.sql')
    if os.path.exists(schema_path):
        with open(schema_path, 'r') as sql_file:
            sql_script = sql_file.read()
            table_names = table_column_of_create_table(sql_script)[0]
            for table in table_names: 
                table = re.sub(r'[^a-zA-Z_]', '', table)
                map_table_db[table] = folder

# with open("src/spider/table_database_map.json", "w") as f:
#     json.dump(map_table_db, f, indent=4)

In [16]:
def safe_divide(numerator, denominator):
    try:
        result = round(numerator / denominator,2)
    except :
        result = "ZeroDivisionError"
    return result

### Map db to table

In [13]:
db_to_table_map = {}
for table, db in map_table_db.items():
    if db in db_to_table_map:
        db_to_table_map[db].append(table)
    else:
        db_to_table_map[db] = [table]


## Expirement test by threshold score to dataframe
automate join column (each treshold score)

In [29]:
def expirement_test(threshold_score:list, dbs=dbs, verbose=False):

    full_result_df = pd.DataFrame()
    exists_table = [i['table'].lower() for i in dbs]
    for i,row in df.iterrows():
        table_of_query = row['Table']
        tables, all_columns = table_column_of_create_table(table_of_query)
        is_present = np.all(np.isin(np.array(tables), np.array(exists_table)))
        # all table from db is in exist tables (same database)
        if is_present:
            try:
                question = row['Question']

                # Got the expect column and table with string matching, parser and SQLParse lib
                columns = Parser(row['SQL']).columns
                expect_cols = []
                expect_table = []
                for col in columns:
                    # found join function (Table1.column1)
                    if "." in col:
                        table_name, column_name = col.split('.') 
                        expect_cols.append(column_name)
                        expect_table.append(table_name)
                    elif col in all_columns:
                        expect_cols.append(col)
                
                expect_cols.extend([c for c in columns_from_query(row['SQL']) if c in all_columns])
                expect_cols.extend(columns_by_split(row['SQL'], all_columns))
                expect_table.extend(Parser(row['SQL']).tables)

                expect_table = list(set(expect_table))
                expect_cols = list(set(expect_cols))

                # filtering table name in columns list
                expect_cols = [c for c in expect_cols if c not in expect_table]

                db = map_table_db[expect_table[0]]
                table_in_db = db_to_table_map[db]

                if verbose:
                    print(question)
                    print(row['SQL'])
                    print("DATABASE:", db)
                    print("EXPECT TABLE:", expect_table)
                    print("EXPECT COLUMNS:",expect_cols)
                    print()

                # dataframe for merge to threshold dataframe
                threshold_result_df = pd.DataFrame({'query': [question],'actual col' : [expect_cols]})

                for score in threshold_score:

                    name_col_selected   = f"T{score} selected col"
                    name_col_correct    = f"T{score} # correct"
                    name_col_recall     = f"T{score} recall"
                    name_col_precision  = f"T{score} precision"
                    name_col_f1         = f"T{score} F1"


                    # filtering schema from question (similarity description-base score --> string matching)
                    result = filter_tables_by_description(row['Question'], column_threshold = score, filter_tables = False)
                    result = column_from_question(row['Question'], used_table_col=result)
                    
                    result_tables = []
                    result_columns = []
                    for t in result:
                        if t in table_in_db: 
                            result_tables.append(t)
                            result_columns.extend(list(result[t].keys()))

                    result_columns = list(set([c.lower() for c in result_columns]))

                    # calculate accuracy

                    table_TP = len(set(expect_table) & set(result_tables))
                    table_FP = len(set(result_tables) - set(expect_table))
                    table_FN = len(set(expect_table) - set(result_tables))
                    col_TP = len(set(expect_cols) & set(result_columns))
                    col_FP = len(set(result_columns) - set(expect_cols))
                    col_FN = len(set(expect_cols) - set(result_columns))
                    
                    table_recall, table_precision, col_recall, col_precision, table_f1, col_f1 = (None, ) * 6

                    # Calculate table_recall, table_precision, col_recall, and col_precision using safe_divide
                    if table_TP is not None and table_FN is not None and table_FP is not None:
                        table_recall = safe_divide(table_TP, table_TP + table_FN)
                        table_precision = safe_divide(table_TP, table_TP + table_FP)

                    if col_TP is not None and col_FN is not None and col_FP is not None:
                        col_recall = safe_divide(col_TP, col_TP + col_FN)
                        col_precision = safe_divide(col_TP, col_TP + col_FP)

                    # Calculate table_f1 and col_f1 using safe_divide
                    if table_precision is not None and table_recall is not None and table_recall != "ZeroDivisionError" and table_precision != "ZeroDivisionError":
                        table_f1 = 2 * safe_divide(table_precision * table_recall, table_precision + table_recall)

                    if col_precision is not None and col_recall is not None and col_recall != "ZeroDivisionError" and col_precision != "ZeroDivisionError":
                        col_f1 = 2 * safe_divide(col_precision * col_recall, col_precision + col_recall)

                    # Check for "division error" and set appropriate values
                    if table_f1 == "ZeroDivisionErrorZeroDivisionError":table_f1 = "ZeroDivisionError"
                    if col_f1 == "ZeroDivisionErrorZeroDivisionError": col_f1 = "ZeroDivisionError"
                        

                    if verbose:
                        print("THRESHOLD:", score)
                        print("PREDICT TABLE:", result_tables)
                        print("PREDICT COLUMNS:", result_columns)
                        print("TABLE RECALL:", table_recall, "\tCOLUMNS RECALL:", col_recall)
                        print("TABLE PRECISION:", table_precision, "\tCOLUMNS PRECISION:", col_precision)
                        print("TABLE F1 SCORE:", table_f1, "\tCOLUMNS F1 SCORE:", col_f1)
                        print()

                    result_data = {
                        'query' : [question],
                        name_col_selected : [result_columns],
                        name_col_correct : [np.sum(np.isin(np.array(result_columns), np.array(expect_cols)))],
                        name_col_recall : [col_recall],
                        name_col_precision : [col_precision],
                        name_col_f1 : [col_f1]
                    }
                    
                    # dataframe for merge
                    result_df = pd.DataFrame(result_data)
                    threshold_result_df = pd.merge(threshold_result_df, result_df, on='query', how='outer')
                
                # append (concat) the row of full dataframe
                full_result_df = pd.concat([full_result_df, threshold_result_df], ignore_index=True)

            except KeyError: pass
            print('------------------------------------------')
    return full_result_df

In [30]:
result_df = expirement_test([0.6, 0.4, 0.2], verbose=True)

Find the number of members living in each address.
SELECT COUNT(*), address FROM member GROUP BY address
DATABASE: coffee_shop
EXPECT TABLE: ['member']
EXPECT COLUMNS: ['address']
THRESHOLD: 0.6
PREDICT TABLE: []
PREDICT COLUMNS: []
TABLE RECALL: 0.0 	COLUMNS RECALL: 0.0
TABLE PRECISION: ZeroDivisionError 	COLUMNS PRECISION: ZeroDivisionError
TABLE F1 SCORE: None 	COLUMNS F1 SCORE: None

THRESHOLD: 0.4
PREDICT TABLE: ['member']
PREDICT COLUMNS: ['address']
TABLE RECALL: 1.0 	COLUMNS RECALL: 1.0
TABLE PRECISION: 1.0 	COLUMNS PRECISION: 1.0
TABLE F1 SCORE: 1.0 	COLUMNS F1 SCORE: 1.0

THRESHOLD: 0.2
PREDICT TABLE: ['shop', 'member', 'happy_hour', 'happy_hour_member']
PREDICT COLUMNS: ['num_of_staff', 'num_of_staff_in_charge', 'address', 'member_id', 'name', 'age', 'level_of_membership', 'membership_card', 'total_amount']
TABLE RECALL: 1.0 	COLUMNS RECALL: 1.0
TABLE PRECISION: 0.25 	COLUMNS PRECISION: 0.11
TABLE F1 SCORE: 0.4 	COLUMNS F1 SCORE: 0.2

----------------------------------------

In [31]:
result_df.head()

Unnamed: 0,query,actual col,T0.6 selected col,T0.6 # correct,T0.6 recall,T0.6 precision,T0.6 F1,T0.4 selected col,T0.4 # correct,T0.4 recall,T0.4 precision,T0.4 F1,T0.2 selected col,T0.2 # correct,T0.2 recall,T0.2 precision,T0.2 F1
0,Find the number of members living in each addr...,[address],[],0,0.0,ZeroDivisionError,,[address],1,1.0,1.0,1.0,"[num_of_staff, num_of_staff_in_charge, address...",1,1.0,0.11,0.2
1,Count the number of cinemas.,[],"[cinema_id, capacity]",0,ZeroDivisionError,0.0,,"[show_times_per_day, name, cinema_id, capacity...",0,ZeroDivisionError,0.0,,"[show_times_per_day, name, cinema_id, capacity...",0,ZeroDivisionError,0.0,
2,How many rooms does each block floor have?,"[blockfloor, blockcode]",[room],0,0.0,0.0,ZeroDivisionError,[room],0,0.0,0.0,ZeroDivisionError,[room],0,0.0,0.0,ZeroDivisionError
3,What procedures cost less than 5000 and have J...,"[name, employeeid, cost, treatment, code]",[cost],1,0.2,1.0,0.34,[cost],1,0.2,1.0,0.34,[cost],1,0.2,1.0,0.34
4,What is the location with the most cinemas ope...,"[openning_year, location]","[openning_year, location]",2,1.0,1.0,1.0,"[show_times_per_day, name, cinema_id, capacity...",2,1.0,0.33,0.5,"[show_times_per_day, name, date, cinema_id, ca...",2,1.0,0.22,0.36


In [34]:
# result_df.to_csv("expirement_filtering_columns.csv", index=False)

In [36]:
# error_values = ["ZeroDivisionError"]
# result_df.replace(error_values, float('nan'), inplace=True)


# thresholds = ['T0.6', 'T0.4', 'T0.2']
# metrics = ['recall', 'precision', 'F1']

# plt.figure(figsize=(10, 6))

# for metric in metrics:
#     for threshold in thresholds:
#         column_name = f"{threshold} {metric}"
#         plt.scatter(result_df.index, result_df[column_name], label=f"{metric} ({threshold})")
# # plt.scatter(result_df.index, result_df['T0.6 recll'])
# plt.title('Recall, Precision, and F1 Score over Similarity Score Thresholds')
# plt.xlabel('Queries')
# plt.ylabel('Score')
# plt.legend()
# plt.show()