In [None]:
import pandas as pd
from googletrans import Translator
from tqdm import tqdm
import json, os

# Single googletrans client, reused by the translation calls in the data-preparation cells below.
translator = Translator()
# Root folder for the PointX source documents and the generated schema files.
folder_path = "src/pointx"

# Collapses the data-dictionary type names into the coarse "text" / "number" labels.
change_type = {
    "string": "text",
    **dict.fromkeys(("int", "bigint", "decimal(27,2)", "double"), "number"),
    "timestamp": "text",
    "date": "text",
}

# Excel workbook with the ETL mapping / data dictionary sheets.
schema_desc_path = os.path.join(folder_path, "ETL Mapping & Data Dict - PointX (1).xlsx")
# dotenv_path = Path('.env')
# load_dotenv(dotenv_path=dotenv_path)

# Preparing data

## pointx_keymatrix_dly Table

In [None]:
# In sheet '14' the parsed row at position 17 holds the column titles;
# everything from position 18 onwards is the data dictionary itself.
df = pd.read_excel(schema_desc_path, sheet_name='14')
df = (
    df.iloc[18:, :]
    .set_axis(df.iloc[17, :].tolist(), axis=1)
    .reset_index(drop=True)
)
df.head()

In [None]:
# Per-column outputs for the key-matrix table:
#   col_types: column name -> coarse type label looked up in change_type
#   col_descs: column name -> description translated to English
col_types = {}
col_descs = {}
table_name = df['Table'].unique().tolist()[0]
table_desc = """The Key Matrix Dashboard Design table provides a detailed overview of dashboard-related database columns, 
including data types, status indicators, descriptions, conditions, business logic, and sample data, 
enabling a comprehensive understanding of the data structure for effective dashboard design."""

# iterrows() is a generator without a length, so pass total= to get a real progress bar.
for _, row in tqdm(df.iterrows(), total=len(df)):
    col_name = row['Column']

    # Normalise before the lookup so padded or mixed-case cells still match,
    # and fail with a message that names the offending column and type.
    raw_type = str(row['Data Type']).strip().lower()
    if raw_type not in change_type:
        raise KeyError(
            f"Unmapped data type {row['Data Type']!r} for column {col_name!r}; "
            "add it to change_type"
        )
    col_types[col_name] = change_type[raw_type]

    # Only real, non-blank text is sent to the translator; a missing description
    # is recorded as an empty string instead of being translated.
    raw_desc = row['Description']
    if isinstance(raw_desc, str) and raw_desc.strip():
        col_descs[col_name] = translator.translate(raw_desc, dest='en').text
    else:
        col_descs[col_name] = ""


In [None]:
# Schema description payload: table name, table summary and per-column descriptions.
schema_desc = dict(
    table=table_name,
    description=table_desc,
    columns=col_descs,
)

# with open(os.path.join(folder_path, "pointx_keymatrix_dly_schema_description.json"),'w') as f:
#     json.dump(schema_desc, f, indent=4)

# with open(os.path.join(folder_path, "pointx_keymatrix_dly_columns_type.json"),'w') as f:
#     json.dump(col_types, f, indent=4)

In [None]:
# with open(os.path.join(folder_path, "pointx_keymatrix_dly_schema_description.json"),'r') as f:
#     col_descs = json.load(f)

## pointx_cust_mly Table

In [None]:
# Load the pointx_cust_mly type mapping; its keys are collected as the set of column names.
with open("src/pointx/schema/pointx_cust_mly_type.json") as f:
    col_type = json.load(f)
col_names = {*col_type.keys()}

In [None]:
df = pd.read_excel("src/pointx/Business Glossary 1.xlsx")
df = df[['col_name', 'descriptions']]

# Keep only the rows where BOTH cells are non-blank strings.
# Column-wise Series.map is used instead of DataFrame.applymap, which is
# deprecated since pandas 2.1; the row selection is the same as the old
# "mask with NaN, then dropna()" approach.
df = df[
    df.apply(lambda col: col.map(lambda x: isinstance(x, str) and x.strip() != '')).all(axis=1)
].copy()  # explicit copy so the column assignment below targets an independent frame

# Translate every remaining description to English (one translator call per row).
df['descriptions'] = df['descriptions'].apply(lambda desc: translator.translate(desc, dest='en').text)


In [None]:
# Name and free-text summary of the table; they are stored under the "table" and
# "description" keys of the schema description JSON written further down.
table_name = "pointx_cust_mly"
table_desc = """The table provides a comprehensive monthly overview of customer engagement within the app, 
capturing data related to accumulated points, usage patterns, and relevant metrics, 
facilitating in-depth analysis of user behavior and app performance."""

In [None]:
# Column name -> English description, restricted to the columns listed in the
# table's type file (col_names).
# The previous loop tested `col not in col_descs` while iterating col_descs, so it
# never removed anything; filtering with a comprehension also avoids deleting
# from a dict while iterating over it.
col_descs = {
    col: desc
    for col, desc in df.set_index('col_name')['descriptions'].to_dict().items()
    if col in col_names
}

schema_desc = {
    "table": table_name,
    "description": table_desc,
    "columns": col_descs
}
with open(os.path.join(folder_path, "schema/pointx_cust_mly_schema_description.json"),'w') as f:
    json.dump(schema_desc, f, indent=4)

## pointx_fbs_rpt_dly Table

In [None]:
table_name = "pointx_fbs_rpt_dly"
# Same two-line summary as before, written with explicit "\n" and trailing spaces.
table_desc = (
    "Table records user interactions with the PointX app daily, capturing events such as app opens and deletions, \n"
    "providing key insights into user behavior, app version usage, and device characteristics "
)

# Column descriptions are read from a CSV with 'Column' and 'Description' columns.
df = pd.read_csv("src/pointx/pointx_fbs_rpt_dly_description.csv")
col_descs = df.set_index('Column').loc[:, 'Description'].to_dict()

In [None]:
# Schema description payload for the table defined in the previous cell.
schema_desc = dict(
    table=table_name,
    description=table_desc,
    columns=col_descs,
)

# with open(os.path.join(folder_path, "pointx_fbs_rpt_dly_description.json"),'w') as f:
#     json.dump(schema_desc, f, indent=4)

# Let's filter

In [None]:
import json, warnings, time
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM

# Models are loaded from local folders under models/:
#   - sentence encoder used to embed descriptions and questions
#   - tokenizer + causal LM used to generate the SQL
sentence_emb_model = SentenceTransformer('models/all-MiniLM-L6-v2')
tokenizer = AutoTokenizer.from_pretrained("models/nsql-350M")
model = AutoModelForCausalLM.from_pretrained("models/nsql-350M")

# In-memory registries filled by join_schema(), all keyed by table name:
#   table_desc_vectors   { table: vector }
#   schema_desc_vectors  { table: { column: vector } }
#   schema_datatypes     { table: { column: datatype } }
table_desc_vectors, schema_desc_vectors, schema_datatypes = {}, {}, {}

In [None]:
def join_schema(schema_description_path:str, schema_datatype_path:str):
    """Register one table in the in-memory schema registries.

    Args:
        schema_description_path: JSON file with the keys "table", "description"
            and "columns" (a mapping of column name to description text).
        schema_datatype_path: JSON file whose parsed content is stored as-is
            under the table name in ``schema_datatypes``.

    Side effects:
        Writes the table entry into ``table_desc_vectors``, ``schema_datatypes``
        and ``schema_desc_vectors``; an existing entry for the same table is
        overwritten. Returns None.
    """
    def _read_json(path):
        with open(path) as handle:
            return json.load(handle)

    description = _read_json(schema_description_path)
    datatypes = _read_json(schema_datatype_path)

    name = description['table']
    # The table-level summary is embedded first, then every column description.
    table_desc_vectors[name] = sentence_emb_model.encode(description['description'])
    schema_datatypes[name] = datatypes
    schema_desc_vectors[name] = {
        column: sentence_emb_model.encode(text)
        for column, text in description["columns"].items()
    }

In [None]:
def remove_table(table_name):
    """Delete a table from all three schema registries.

    The registries are purged in the order ``table_desc_vectors``,
    ``schema_desc_vectors``, ``schema_datatypes``.

    Returns:
        True once the table has been removed from every registry.

    Raises:
        KeyError: if the table is missing from one of the registries
            (registries earlier in the order have already been purged).
    """
    for registry in (table_desc_vectors, schema_desc_vectors, schema_datatypes):
        del registry[table_name]
    return True

In [None]:
def filter_schema(question:str, column_threshold:float = 0.4, table_threshold:float = 0.2,
                  max_select_columns:int = 5, filter_tables:bool = True):
    """Select the tables and columns that look relevant to a question.

    Args:
        question: natural-language question; it is embedded with
            ``sentence_emb_model`` and also split on whitespace for exact
            (case-sensitive) matching against table and column names.
        column_threshold: minimum cosine similarity for a column to be kept.
            Lowered by 0.1 for tables whose name appears verbatim in the question.
        table_threshold: minimum cosine similarity between the table description
            and the question; only applied when ``filter_tables`` is true.
        max_select_columns: keep at most this many columns per table, chosen by
            highest score. A falsy value (0 / None / False) disables the cap.
        filter_tables: when false, every registered table is considered.

    Returns:
        ``{table: {column: score}}`` with scores rounded to 3 decimals. A table
        that passed the table filter but has no matching column maps to ``{}``.
    """
    question_emb = sentence_emb_model.encode(question)
    used_schemas = {}
    found_table = []
    # Collected across ALL tokens. It used to be re-created for every token, so only
    # matches of the last token survived, and an empty question left it undefined.
    found_columns = set()

    # string matching between the question tokens and table / column names
    for token in question.split():
        if token in schema_desc_vectors:
            print("Table string match  ---->", token)
            found_table.append(token)
        for columns in schema_desc_vectors.values():
            if token in columns:
                found_columns.add(token)
                print("Column matching  --->",token)

    if filter_tables:
        # keep only the tables whose description is close enough to the question
        used_tables = [
            table_name
            for table_name, table_vector in table_desc_vectors.items()
            if util.cos_sim(table_vector, question_emb) >= table_threshold
        ]
    else:
        used_tables = list(table_desc_vectors.keys())

    for table in used_tables:
        # columns of a table named in the question get a more lenient threshold
        table_offset = 0.1 if table in found_table else 0
        selected = {}
        for column, column_vector in schema_desc_vectors[table].items():
            sim_score = util.cos_sim(column_vector, question_emb)
            if sim_score >= (column_threshold - table_offset) or column in found_columns:
                selected[column] = round(float(sim_score), 3)
        if max_select_columns and len(selected) > max_select_columns:
            # keep the top-k columns by score (the cap looks at the score only)
            selected = dict(
                sorted(selected.items(), key=lambda item: item[1], reverse=True)[:max_select_columns]
            )
        used_schemas[table] = selected

    return used_schemas

In [None]:
def _create_table_statement(table, columns):
    """Render one single-line CREATE TABLE statement for ``table``, or "" if nothing can be rendered."""
    datatypes = schema_datatypes[table]
    # Tables without join metadata are treated as having no keys.
    join_keys = datatypes.get("JOIN_KEY", {})
    primary_keys = list(join_keys.get("PK", []))
    foreign_keys = dict(join_keys.get("FK", {}))

    parts = []
    rendered = set()
    for column in columns:
        if column in rendered:
            continue
        try:
            parts.append(f'{column} {datatypes[column]}')
            rendered.add(column)
        except KeyError:
            # selected column without a known datatype: report it and leave it out
            print(f"KeyError :{column}")

    # Join keys are always part of the table definition, even when they were not
    # selected; each key is emitted once, also when it is both a PK and an FK.
    for column in primary_keys + list(foreign_keys):
        if column not in rendered:
            parts.append(f'{column} {datatypes[column]}')
            rendered.add(column)

    if not parts:
        return ""

    if primary_keys:
        parts.append("PRIMARY KEY (" + ", ".join(f'"{pk}"' for pk in primary_keys) + ")")
    for fk, ref_table in foreign_keys.items():
        parts.append(f'FOREIGN KEY ("{fk}") REFERENCES "{ref_table}" ("{fk}")')

    # Joining the clauses keeps commas and parentheses balanced for every
    # PK / FK combination, instead of trimming the last character afterwards.
    return f"CREATE TABLE {table} ( " + ", ".join(parts) + " )\n\n"


def create_prompt(question, used_schema):
    """Build the text-to-SQL prompt for a question.

    Args:
        question: natural-language question, appended as a ``-- `` comment line.
        used_schema: ``{table: columns}`` where iterating ``columns`` yields the
            selected column names (e.g. the result of ``filter_schema``). Tables
            with no selected column are skipped.

    Returns:
        One CREATE TABLE statement per table (selected columns, then the table's
        join keys, then PRIMARY KEY / FOREIGN KEY clauses), followed by the
        instruction line, the question and a trailing ``SELECT``.

    Raises:
        KeyError: if a table is unknown to ``schema_datatypes`` or one of its
            join keys has no datatype entry.
    """
    full_sql = ""
    for table, columns in used_schema.items():
        if not len(columns): continue       # skip tables without any selected column
        full_sql += _create_table_statement(table, columns)

    promp = full_sql + "-- Using valid SQLite, answer the following questions for the tables provided above."
    promp = promp + '\n' + '-- ' + question
    promp = promp + '\n' + "SELECT"

    return promp

In [None]:
# join_schema("src/pointx/schemas/pointx_fbs_rpt_dly_schema_description.json",
#             "src/pointx/schemas/pointx_fbs_rpt_dly_columns_type.json")

# join_schema("src/pointx/schemas/pointx_cust_mly_schema_description.json",
#             "src/pointx/schemas/pointx_cust_mly_columns_type.json")

# join_schema("src/pointx/schemas/pointx_keymatrix_dly_schema_description.json",
#             "src/pointx/schemas/pointx_keymatrix_dly_columns_type.json")

In [None]:
# remove_table("pointx_keymatrix_dly")

### Spider dataset

In [None]:
# Rebind the registries to empty dicts so that only the tables loaded below are registered.
table_desc_vectors = {}     # { table1: vector , ...}
schema_desc_vectors = {}    # { table1: { column1: vector, ...}}
schema_datatypes = {}       # { table1: { column1: datatype, ...}}

# Each table has a <name>_desc.json / <name>_datatype.json pair in the same folder.
for spider_table in ("happy_hour", "happy_hour_member", "member", "shop"):
    join_schema(f"src/spider/cofee_shop/{spider_table}_desc.json",
                f"src/spider/cofee_shop/{spider_table}_datatype.json")

## Let's ask a question!

In [None]:
# Natural-language question that the inference cell below passes to filter_schema() and create_prompt().
question = "shop address with happy hour in April"

In [None]:
# question = "How many unique user use this app"

In [None]:
# Select the schema subset for the question (all tables, no column cap) and render the prompt.
result = filter_schema(
    question,
    column_threshold=0.3,
    table_threshold=0.2,
    filter_tables=False,
    max_select_columns=False,
)
prompt = create_prompt(question, result)
print(prompt)

start_time = time.time()
with warnings.catch_warnings():
    # silence library warnings raised while tokenizing / generating
    warnings.simplefilter("ignore")
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids, max_length=1000)
    # only the text after the last newline of the decoded output is kept as the query
    sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True).rsplit('\n', 1)[-1]

    print()
    print("SQL :",sql)
    print(f"TAKE {time.time()-start_time} seconds")

# Test query

In [None]:
import sqlite3
from contextlib import closing

# The statement in `sql` is model-generated, so the database is opened read-only
# (SQLite URI with mode=ro): a generated INSERT/UPDATE/DROP cannot alter the file.
# closing() guarantees the connection is released even when execute() raises,
# e.g. on invalid generated SQL.
with closing(sqlite3.connect('file:src/spider/cofee_shop/coffee_shop.sqlite?mode=ro', uri=True)) as conn:
    cursor = conn.cursor()
    cursor.execute(sql)
    results = cursor.fetchall()

print(sql)
print()
for row in results:
    print(row)