In [166]:
import sqlite3, os, json, sqlparse, re
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer, util
from nltk.stem import WordNetLemmatizer

In [171]:
folder_path = "src/spider/database"
select_db = ['musical',
             'farm', 
             'hospital_1', 
             'tvshow', 
             'cinema', 
             'restaurants', 
             'company_employee', 
             'company_1', 
             'company_offic', 
             'singer', 
             'coffee_shop']

db = []

if os.path.exists(folder_path) and os.path.isdir(folder_path):
    files = os.listdir(folder_path)
    for file in files:
        # if file in select_db:
        db_path = os.path.join(folder_path, file)
        sqlite_db = [os.path.join(db_path, sql) for sql in os.listdir(db_path) if ".sqlite" in sql]
        db.append(*sqlite_db)

In [173]:
def get_schema(sqlite_db):
    connection = sqlite3.connect(sqlite_db)
    cursor = connection.cursor()

    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()

    for table in tables:
        table_name = table[0]
        print(f"Table: {table_name}")

        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = cursor.fetchall()

        for column in columns:
            column_name = column[1]
            print(f"  Column: {column_name}")

        print()
    
    cursor.close()
    connection.close()

In [179]:
for table in db:
    # if table in exists_table : continue
    get_schema(table)
    # exists_table.append(table)
    print('---------------------------------')

Table: Web_client_accelerator
  Column: id
  Column: name
  Column: Operating_system
  Column: Client
  Column: Connection

Table: browser
  Column: id
  Column: name
  Column: market_share

Table: accelerator_compatible_browser
  Column: accelerator_id
  Column: browser_id
  Column: compatible_since_year

---------------------------------
Table: musical
  Column: Musical_ID
  Column: Name
  Column: Year
  Column: Award
  Column: Category
  Column: Nominee
  Column: Result

Table: actor
  Column: Actor_ID
  Column: Name
  Column: Musical_ID
  Column: Character
  Column: Duration
  Column: age

---------------------------------
Table: city
  Column: City_ID
  Column: Official_Name
  Column: Status
  Column: Area_km_2
  Column: Population
  Column: Census_Ranking

Table: farm
  Column: Farm_ID
  Column: Year
  Column: Total_Horses
  Column: Working_Horses
  Column: Total_Cattle
  Column: Oxen
  Column: Bulls
  Column: Cows
  Column: Pigs
  Column: Sheep_and_Goats

Table: farm_competition

In [175]:
src_folder = "src"
schema_description_file = "mockup_schema_description.json"
with open(os.path.join(src_folder, schema_description_file)) as f:
    dbs = json.load(f)
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
lemmanizer = WordNetLemmatizer()

In [56]:
# description_emb = []

# for db in dbs:
#     schema_emb = {}
#     table_name = db['table']
#     table_description = db['description']
#     schema_emb[table_name] = model.encode(table_description).tolist()
#     columns = list(db['columns'].keys())
#     for col in columns:
#         column_description = db['columns'][col]
#         schema_emb[col] = model.encode(column_description).tolist()
#     description_emb.append(schema_emb)

# schema_vector_file = "mockup_schema_description_vector.json"
# with open(os.path.join(src_folder, schema_vector_file), "w") as f:
#     json.dump(description_emb,f)

In [180]:
schema_vector_file = "mockup_schema_description_vector.json"
with open(os.path.join(src_folder, schema_vector_file)) as f:
    schema_vector = json.load(f)

In [371]:
# string mathing , max_score = weight of string match

def column_from_question(question,used_table_col = {}, default_score=0.6):
    # question_tokens = [token.lower() for token in tokenizer.tokenize(question)]
    question_tokens = [lemmanizer.lemmatize(token.lower()) for token in question.split()]
    print(question_tokens)
    for table in schema_vector:
        max_score = default_score
        for token in question_tokens:
            cols = [ key.lower() for key in table.keys()]
            table_name = cols.pop(0)
            if token == table_name: 
                max_score = 1.0
                # plus the score of columns in exact match with table name
                if used_table_col.get(token) is not None: 
                    for key in used_table_col[token]: 
                        used_table_col[token].update({key: used_table_col[token][key] + 0.1})
            # exact match table and column
            if token in cols: 
                used_table_col.setdefault(table_name, {}).update({token : max_score})

            

    return used_table_col

In [372]:
# filter_table : filter by table before 
def filter_tables_by_description(question, column_threshold = 0.4, table_threshold = 0.2, filter_tables = True):
    question_emb = model.encode(question)
    used_schema = {}
    for i in range(len(schema_vector)):
        table_name = list(schema_vector[i].keys())[0]

        table_description_vector = schema_vector[i][table_name]
        if filter_tables and util.cos_sim(table_description_vector, question_emb) < table_threshold: continue
        
        used_col = {}
        for col, vec in schema_vector[i].items():
            if col == table_name: continue
            score = round(float(util.cos_sim(vec, question_emb)),2)
            if score > column_threshold:
                # column_description = [dbs[i]['columns'][col] for i in range(len(dbs)) if dbs[i]['table'] == table_name][0]
                # print(f"{table_name} - {col} : {score}\nDescription : {column_description}\n")
                used_col.update({col: score})
        if len(used_col) > 0: used_schema[table_name] = used_col
    return used_schema

In [383]:
question = "Who direct the mickey mouse"

In [384]:
selected_table_column = filter_tables_by_description(question, column_threshold = 0.3, filter_tables = False)
selected_table_column

{'Cartoon': {'id': 0.31,
  'Title': 0.35,
  'Directed_by': 0.54,
  'Written_by': 0.44,
  'Original_air_date': 0.32,
  'Production_code': 0.39},
 'film': {'Directed_by': 0.36}}

In [385]:
selected_table_column = column_from_question(question, used_table_col=selected_table_column)
selected_table_column

['who', 'direct', 'the', 'mickey', 'mouse']


{'Cartoon': {'id': 0.31,
  'Title': 0.35,
  'Directed_by': 0.54,
  'Written_by': 0.44,
  'Original_air_date': 0.32,
  'Production_code': 0.39},
 'film': {'Directed_by': 0.36}}

In [366]:
full_sql = ""
for table in selected_table_column:
    sql = f"CREATE TABLE {table} ("
    for column in selected_table_column.get(table).keys():
        sql += f'"{column}" datatype, '

    sql = sql[:-2] + " )\n\n"
    full_sql += sql
print(full_sql)

CREATE TABLE people ("Nationality" datatype )

CREATE TABLE singer ("Singer_ID" datatype, "Name" datatype, "Birth_Year" datatype, "Net_Worth_Millions" datatype, "Citizenship" datatype )

CREATE TABLE song ("Singer_ID" datatype )




In [27]:
spider_sql = []
df_data = {
    'Question' : [],
    'Table' : [],
    'SQL' : []
}

with open("src/NSText2SQL/train.jsonl") as f:
    for line in f:
        data = json.loads(line)
        if data['source'] == 'spider': 
            spider_sql.append(data)
            df_data['Question'].append(data['instruction'].split('--')[-1].strip())
            df_data['Table'].append(data['instruction'].split('--')[0].strip())
            df_data['SQL'].append(data['output'])

df = pd.DataFrame(df_data)
df.to_csv('src/NSText2SQL/train_spider.csv', index=False)
df.head()

Unnamed: 0,Question,Table,SQL
0,"What are the first names, office locations of ...","CREATE TABLE course (\n crs_code text,\n ...","SELECT T2.emp_fname, T4.prof_office, T3.crs_de..."
1,Please show the songs that have result 'nomina...,"CREATE TABLE artist (\n artist_id number,\n...",SELECT T2.song FROM music_festival AS T1 JOIN ...
2,Which teams had more than 3 eliminations?,CREATE TABLE elimination (\n elimination_id...,SELECT team FROM elimination GROUP BY team HAV...
3,"Show the names of people, and dates and venues...","CREATE TABLE people (\n people_id number,\n...","SELECT T3.name, T2.date, T2.venue FROM debate_..."
4,Tell me the the date when the first claim was ...,CREATE TABLE settlements (\n settlement_id ...,SELECT date_claim_made FROM claims ORDER BY da...


In [151]:
# Split the SQL query into lines
def table_column_of_create_table(query):
    lines = query.splitlines()

    # Initialize a list to store column names
    columns = []
    table_names = []

    # Look for "CREATE TABLE" and start capturing columns
    capture = False

    for line in lines:
        if "CREATE TABLE" in line:
            capture = True
            table_names.append(line.split()[-2])
        elif line.strip().endswith(')') or line.strip().endswith(');'):
            capture = False
        elif capture:
            column_name = line.strip().split()[0]
            if column_name.lower() in ["constraint", "primary"]: continue
            columns.append(column_name)

    # print("Table Name:", table_names)
    # print("Columns:", columns)
    return [''.join(filter(str.isalpha, s)) for s in table_names], [''.join(filter(str.isalpha, s)) for s in columns]


In [157]:
spider_path = 'src/spider/database'
map_table_db = {}

for folder in os.listdir(spider_path):
    schema_path = os.path.join(spider_path, folder, 'schema.sql')
    if os.path.exists(schema_path):
        with open(schema_path, 'r') as sql_file:
            sql_script = sql_file.read()
            table_names = table_column_of_create_table(sql_script)[0]
            for table in table_names: map_table_db[table] = folder

with open("src/spider/table_database_map.json", "w") as f:
    json.dump(map_table_db, f, indent=4)

In [267]:
sql_extract_token_type = {
            sqlparse.sql.IdentifierList, sqlparse.sql.Where,
            sqlparse.sql.Having, sqlparse.sql.Comparison, sqlparse.sql.Function,
            sqlparse.sql.Parenthesis, sqlparse.sql.Operation, sqlparse.sql.Case
        }

def columns_from_query(sql_query):
    # identifiers contain table name and column name
    if type(sql_query) == str:
        sql_query = sqlparse.parse(sql_query)[0]
    identifiers = []
    for token in sql_query:
        if isinstance(token, sqlparse.sql.Identifier):
            identifiers.append(token.get_real_name())
        elif hasattr(token, "tokens"):
            identifiers.extend(columns_from_query(token.tokens))
    return identifiers

In [273]:
exists_table = [i['table'] for i in dbs]
for i,row in df.iterrows():
    table_of_query = row['Table']
    tables, all_columns = table_column_of_create_table(table_of_query)
    if tables[0] in exists_table:
        try:
            # db_path = os.path.join(spider_path, map_table_db[tables[0]])
            # print(db_path)
            identifiers = columns_from_query(row['SQL'])
            expect_column = [identify for identify in identifiers if identify in all_columns]
            expect_table = [identify for identify in identifiers if identify in tables]
            print(row['SQL'])
            print(expect_table)
            print(expect_column)
            print()
        except KeyError: pass

SELECT T3.name, T2.date, T2.venue FROM debate_people AS T1 JOIN debate AS T2 ON T1.debate_id = T2.debate_id JOIN people AS T3 ON T1.negative = T3.people_id ORDER BY T3.name
['debate', 'people']
['name', 'date', 'venue', 'negative', 'name']

SELECT COUNT(*), address FROM member GROUP BY address
['member']
['address', 'address']

SELECT COUNT(*) FROM cinema
['cinema']
[]

SELECT AVG(money_requested) FROM entrepreneur
['entrepreneur']
[]

SELECT T1.name FROM category AS T1 JOIN film_category AS T2 ON T1.category_id = T2.category_id JOIN film AS T3 ON T2.film_id = T3.film_id WHERE T3.title = 'HUNGER ROOF'
['category', 'film']
['name', 'title']

SELECT COUNT(DISTINCT dept_address), school_code FROM department GROUP BY school_code
['department']
[]

SELECT location FROM cinema WHERE openning_year >= 2010 GROUP BY location ORDER BY COUNT(*) DESC LIMIT 1
['cinema']
[]

SELECT first_name FROM people ORDER BY first_name
['people']
[]

SELECT COUNT(*) FROM professor WHERE prof_high_degree = 'Ph.D

In [206]:
test = ['SELECT T3.name, T2.date, T2.venue FROM debate_people AS T1 JOIN debate AS T2 ON T1.debate_id = T2.debate_id JOIN people AS T3 ON T1.negative = T3.people_id ORDER BY T3.name',
'SELECT COUNT(*), address FROM member GROUP BY address',
'SELECT COUNT(*) FROM cinema',
'SELECT AVG(money_requested) FROM entrepreneur',]