In [25]:
import sqlite3, os, json
from sentence_transformers import SentenceTransformer, util

In [2]:
folder_path = "src/spider/database"
select_db = ['musical',
             'farm', 
             'hospital_1', 
             'tvshow', 
             'cinema', 
             'restaurants', 
             'company_employee', 
             'company_1', 
             'company_offic', 
             'singer', 
             'coffee_shop']

db = []

if os.path.exists(folder_path) and os.path.isdir(folder_path):
    files = os.listdir(folder_path)
    for file in files:
        if file in select_db:
            db_path = os.path.join(folder_path, file)
            sqlite_db = [os.path.join(db_path, sql) for sql in os.listdir(db_path) if ".sqlite" in sql]
            db.append(*sqlite_db)

In [3]:
# db

['spider/database/musical/musical.sqlite',
 'spider/database/farm/farm.sqlite',
 'spider/database/hospital_1/hospital_1.sqlite',
 'spider/database/tvshow/tvshow.sqlite',
 'spider/database/cinema/cinema.sqlite',
 'spider/database/restaurants/restaurants.sqlite',
 'spider/database/company_employee/company_employee.sqlite',
 'spider/database/company_1/company_1.sqlite',
 'spider/database/coffee_shop/coffee_shop.sqlite',
 'spider/database/singer/singer.sqlite']

In [4]:
def get_schema(sqlite_db):
    connection = sqlite3.connect(sqlite_db)
    cursor = connection.cursor()

    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()

    for table in tables:
        table_name = table[0]
        print(f"Table: {table_name}")

        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = cursor.fetchall()

        for column in columns:
            column_name = column[1]
            print(f"  Column: {column_name}")

        print()
    
    cursor.close()
    connection.close()

In [8]:
# for table in db:
#     get_schema(table)
#     print('---------------------------------')

In [55]:
src_folder = "src"
schema_description_file = "mockup_schema_description.json"
with open(os.path.join(src_folder, schema_description_file)) as f:
    dbs = json.load(f)
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

In [56]:
# description_emb = []

# for db in dbs:
#     schema_emb = {}
#     table_name = db['table']
#     table_description = db['description']
#     schema_emb[table_name] = model.encode(table_description).tolist()
#     columns = list(db['columns'].keys())
#     for col in columns:
#         column_description = db['columns'][col]
#         schema_emb[col] = model.encode(column_description).tolist()
#     description_emb.append(schema_emb)

# schema_vector_file = "mockup_schema_description_vector.json"
# with open(os.path.join(src_folder, schema_vector_file), "w") as f:
#     json.dump(description_emb,f)

In [99]:
schema_vector_file = "mockup_schema_description_vector.json"
with open(os.path.join(src_folder, schema_vector_file)) as f:
    schema_vector = json.load(f)

question = "Count singers play concert"
question_emb = model.encode(question)

In [116]:
# mostsim = None
used_schema = {}
threshold = 0.4
print("-----  Table - Column  -----")
for i in range(len(schema_vector)):
    table_name = list(schema_vector[i].keys())[0]
    table_description_vector = schema_vector[i][table_name]
    used_col = []
    for col, vec in schema_vector[i].items():
        if col == table_name: continue
        if util.cos_sim(vec, question_emb) > threshold:
            column_description = [dbs[i]['columns'][col] for i in range(len(dbs)) if dbs[i]['table'] == table_name][0]
            print(f"{table_name} - {col} \nDescription : {column_description}\n")
            used_col.append(col)
    if len(used_col) > 0: used_schema[table_name] = used_col

-----  Table - Column  -----
musical - Name 
Description : Name of the musical

musical - Year 
Description : Year the musical was produced

musical - Nominee 
Description : Name of the nominee associated with the musical

singer - Name 
Description : Name of the singer

song - Singer_ID 
Description : Identifier of the singer who performed the song



In [117]:
used_schema

{'musical': ['Name', 'Year', 'Nominee'],
 'singer': ['Name'],
 'song': ['Singer_ID']}

In [124]:
def filter_tables(question, column_threshold = 0.4, table_threshold = 0.3):
    question_emb = model.encode(question)
    used_schema = {}
    for i in range(len(schema_vector)):
        table_name = list(schema_vector[i].keys())[0]

        table_description_vector = schema_vector[i][table_name]
        if util.cos_sim(table_description_vector, question_emb) < table_threshold: continue
        
        used_col = []
        for col, vec in schema_vector[i].items():
            if col == table_name: continue
            if util.cos_sim(vec, question_emb) > column_threshold:
                column_description = [dbs[i]['columns'][col] for i in range(len(dbs)) if dbs[i]['table'] == table_name][0]
                print(f"{table_name} - {col} \nDescription : {column_description}\n")
                used_col.append(col)
        if len(used_col) > 0: used_schema[table_name] = used_col
    return used_schema

In [125]:
filter_tables(question)

musical - Name 
Description : Name of the musical

musical - Year 
Description : Year the musical was produced

musical - Nominee 
Description : Name of the nominee associated with the musical

singer - Name 
Description : Name of the singer

song - Singer_ID 
Description : Identifier of the singer who performed the song



{'musical': ['Name', 'Year', 'Nominee'],
 'singer': ['Name'],
 'song': ['Singer_ID']}