In [28]:
import os, sqlite3, json, time, warnings, re
from openai import OpenAI
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
from filtering_schema.Description_base_linking import SchemaLinking

In [29]:
load_dotenv()

# Publish the run configuration as environment variables (every value is a string).
# NOTE(review): presumably SchemaLinking reads some of these keys itself — confirm.
base_dir = "filtering_schema"
os.environ.update({
    'nsql_model_path': os.path.join(base_dir, 'models', 'nsql-350M'),
    'sentence_emb_model_path': os.path.join(base_dir, 'models', 'all-MiniLM-L6-v2'),
    'schema_description_folder_path': os.path.join(base_dir, 'src', 'schemas', 'coffeeshop-descriptions'),
    'schema_data_types_folder_path': os.path.join(base_dir, 'src', 'schemas', 'coffeeshop-datatypes'),
    'column_threshold': '0.2',
    'table_threshold': '0.2',
    'max_select_column': '10',
    'filter_table': 'False',
    'verbose': 'False',
})

# Schema linker, pointed at the coffee-shop description / datatype folders set above.
schema_link = SchemaLinking()
schema_link.selected_domain(
    schema_description_folder_path=os.environ['schema_description_folder_path'],
    schema_data_types_folder_path=os.environ['schema_data_types_folder_path'],
)

# Tokenizer and causal LM are both loaded from the local NSQL model folder.
tokenizer = AutoTokenizer.from_pretrained(os.environ['nsql_model_path'])
model = AutoModelForCausalLM.from_pretrained(os.environ['nsql_model_path'])

## LLM example

In [30]:
# Few-shot demo: ask the chat model to fill the [MASK] slots of a SQL template.
# NOTE(review): OpenAI() is given no explicit key here — presumably it picks the
# credentials up from the environment populated by load_dotenv(); confirm.
client = OpenAI()
llm_model_name = 'gpt-3.5-turbo'
# Stop sequence handed to the API: a blank line ends the completion.
stop = ['\n\n']
# Prompt layout: task instructions, one table description with sample rows,
# three solved question/input/output examples, then a final question whose
# "output:" is left empty for the model to complete.
prompt = """You are a SQL query assistant.
I have some SQL where the [MASK] column and the conditional value of [MASK] are formatted. And I want you to respond to the output populating the [MASK] column and the SQL input conditional values, followed by the question, the schema description (name - description) and the 5 row example values as provided.
If you don't know which column or value to enter, Don't include columns you created yourself. And only columns and values defined from the schema and sample values must be used.
Don't use columns from another table or schema. It must also be used from the same table defined in the input.


table :     cat - this table contain cat information
column :    id - number for identify cat
            name - name of cat
            age - age of cat
            gender - gender of cat
example 5 rows:
|   id | name   |   age | gender   |
|-----:|:-------|------:|:---------|
|    5 | Pussy  |     2 | Male     |
|    1 | Sam    |     1 | Male     |
|    9 | Peter  |     2 | Female   |
|    3 | Jack   |     4 | Male     |
|    4 | Ponica |     3 | Female   |

question: show me the name of cat.
input: SELECT [MASK] FROM cat;
output: SELECT name FROM cat;

question: show me the name and age of cat.
input: SELECT [MASK], [MASK] FROM cat;
output: SELECT name, age FROM cat;

question: Count number of cate each gender.
input: SELECT [MASK], COUNT([MASK]) FROM cat GROUP BY [MASK];
output: SELECT gender, COUNT(*) FROM cat GROUP BY gender;

question: show me the name and age of woman cat.
input: SELECT [MASK], [MASK] FROM cat WHERE [MASK] = [MASK];
output:
"""

# The system message asks for the bare answer only; no temperature is passed in
# this demo call, so the API default applies.
response = client.chat.completions.create(
    model=llm_model_name,
    messages=[
            {"role": "system",
                "content": "I will give you some x-y examples followed by a x, you need to give me the y, and no other content."},
            {"role": "user", "content": prompt},
            ],
    stop=stop
)
# Display the text of the first returned choice.
response.choices[0].message.content

"SELECT name, age FROM cat WHERE gender = 'Female';"

### Schema link example

In [31]:
# Demo: run schema filtering for one natural-language question; the displayed
# result maps each table name to its selected columns with a score.
schema_link.filter_schema("Which shop opening from 2022")

Table string match  ----> shop


{'happy_hour': {},
 'happy_hour_member': {},
 'member': {},
 'shop': {'Shop_ID': 0.394,
  'Address': 0.605,
  'Num_of_staff': 0.408,
  'Score': 0.357,
  'Open_Year': 0.643}}

In [32]:
# Notebook-level settings that llm_response() reads as globals.
client = OpenAI()
stop = ['\n\n']      # stop sequence passed to the chat-completions call
temperature = 0      # sampling temperature passed to the chat-completions call

def llm_response(prompt, model='gpt-3.5-turbo'):
    """Send a few-shot prompt to the chat-completions API and return the reply text.

    Reads the notebook-level ``client``, ``stop`` and ``temperature`` globals.

    Args:
        prompt: User message holding the solved examples followed by the final
            unsolved input.
        model: Name of the chat model passed to the API.

    Returns:
        The message content of the first choice. An empty string is returned
        when the API reports no text content (``content`` is ``None``), so
        callers can always treat the result as ``str``.
    """
    messages = [
        {"role": "system",
         "content": "I will give you some x-y examples followed by a x, you need to give me the y, and no other content."},
        {"role": "user", "content": prompt},
    ]
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        stop=stop,
        temperature=temperature,
    )
    content = response.choices[0].message.content
    # `content` is optional in the API response; normalise it so string
    # operations in the caller (e.g. `'\n' in response`) cannot raise TypeError.
    return content if content is not None else ""

In [33]:
def create_prompt(question:str, used_schema):
    """Build the NSQL prompt: one ``CREATE TABLE`` per selected table plus the question.

    Reads column types and join keys from the notebook-level
    ``schema_link.schema_datatypes`` mapping, indexed as
    ``[table]["COLUMNS"][column]``, ``[table]["JOIN_KEY"]["PK"]`` (list) and
    ``[table]["JOIN_KEY"]["FK"]`` (mapping of column -> referenced table).

    Args:
        question: Natural-language question appended as a SQL comment.
        used_schema: Mapping of table name -> selected columns. Tables with no
            selected column are skipped.

    Returns:
        The prompt string, ending with ``SELECT`` so the model continues the query.

    Raises:
        KeyError: If a table is missing from the datatype mapping, or a join key
            that was not selected has no entry under ``COLUMNS``.
    """
    statements = []
    for table, columns in used_schema.items():
        if not len(columns):
            continue  # nothing selected for this table

        table_meta = schema_link.schema_datatypes[table]
        column_types = table_meta["COLUMNS"]
        primary_keys = list(table_meta["JOIN_KEY"]["PK"])
        foreign_keys = table_meta["JOIN_KEY"]["FK"]

        definitions = []
        seen = set()
        for column in columns:
            seen.add(column)
            try:
                definitions.append(f'{column} {column_types[column]}')
            except KeyError:
                # Selected column without a known datatype: report it and move on.
                print(f"KeyError :{column}")

        # Always expose the join keys, but define each column only once even
        # when it is both a primary and a foreign key.
        for key in primary_keys + list(foreign_keys):
            if key in seen:
                continue
            seen.add(key)
            definitions.append(f'{key} {column_types[key]}')

        # Table constraints are separate, comma-delimited entries of the
        # definition list; stripping the comma in front of them yields invalid DDL.
        if primary_keys:
            quoted_keys = ", ".join(f'"{pk}"' for pk in primary_keys)
            definitions.append(f'PRIMARY KEY ({quoted_keys})')
        for fk, ref_table in foreign_keys.items():
            # The referenced column is written with the same name as the local one.
            definitions.append(f'FOREIGN KEY ("{fk}") REFERENCES "{ref_table}" ("{fk}")')

        if not definitions:
            continue  # avoid emitting a malformed, empty CREATE TABLE
        statements.append(f"CREATE TABLE {table} ( {', '.join(definitions)} )\n\n")

    prompt = "".join(statements) + "-- Using valid SQLite, answer the following questions for the tables provided above."
    prompt = prompt + '\n' + '-- ' + question
    prompt = prompt + '\n' + "SELECT"
    return prompt

def gen_sql(prompt:str):
    """Generate SQL for ``prompt`` with the notebook-level NSQL tokenizer and model.

    The decoded output is the prompt followed by the continuation; only its
    last line is returned. Warnings raised while generating are suppressed.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        encoded = tokenizer(prompt, return_tensors="pt")
        output_ids = model.generate(encoded.input_ids, max_length=1000)
        decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # Keep everything after the final newline.
        last_line = decoded.rsplit('\n', 1)[-1]
    return last_line

In [34]:
def query_n_rows(table_name:str, columns:list,  n_rows:int = 5, src_dir="filtering_schema/src/data"):
    """Return the first ``n_rows`` rows of ``columns`` from a table's CSV file.

    The file is the first entry of ``src_dir`` (in sorted order, so the choice
    is deterministic) whose name starts with ``table_name``.

    Args:
        table_name: Prefix of the CSV file name to read.
        columns: Column names to keep.
        n_rows: Number of leading rows to return.
        src_dir: Directory that holds the CSV files.

    Returns:
        A DataFrame with the requested columns and at most ``n_rows`` rows.

    Raises:
        FileNotFoundError: If no file in ``src_dir`` starts with ``table_name``.
            A missing match used to fall through to whichever file was listed
            last, or to a NameError on an empty directory.
    """
    candidates = sorted(name for name in os.listdir(src_dir) if name.startswith(table_name))
    if not candidates:
        raise FileNotFoundError(
            f"No data file starting with '{table_name}' found in '{src_dir}'"
        )
    query_df = pd.read_csv(os.path.join(src_dir, candidates[0]))[columns].iloc[:n_rows]
    return query_df


In [35]:
# Demo: preview two columns of the pointx_keymatrix_dly data file (default 5 rows).
query_n_rows("pointx_keymatrix_dly", ['_date', "month_id"])

Unnamed: 0,_date,month_id
0,2022-07-06,2022-07
1,2022-07-01,2022-07
2,2022-07-08,2022-07
3,2022-07-05,2022-07
4,2022-07-14,2022-07


In [36]:
def query_pointx_db(sql_query, db_path='src/pointx/database/pointx.db'):
    """Run ``sql_query`` against the PointX SQLite database and return all rows.

    Args:
        sql_query: SQL text to execute.
        db_path: Path of the SQLite database file.

    Returns:
        The list of row tuples from ``fetchall()``, or the sentinel string
        ``"CANNOT FETCHING DATA"`` when the query cannot be executed (invalid
        SQL, non-string input, several statements at once, ...).
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(sql_query)
        return cursor.fetchall()
    except Exception:
        # Best effort on purpose: model-generated SQL is frequently invalid.
        # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through.
        return "CANNOT FETCHING DATA"
    finally:
        # Close on every path; the early return used to leak the connection.
        conn.close()

In [37]:
# Load the evaluation sheet, keeping only the question columns ('NLQ',
# 'NLQ with helper') and the 'SQL' column that the experiment loop uses as the
# actual query.
keymatrix_df = pd.read_excel("src/pointx/PointX - NLQ training data set.xlsx", sheet_name="pointx_keymatrix")[['NLQ', 'NLQ with helper', 'SQL']]
keymatrix_df.head()

Unnamed: 0,NLQ,NLQ with helper,SQL
0,What is the total number of all financial tran...,"SELECT month_id, SUM(ntx_pointx_financial) FRO...","SELECT month_id, SUM(ntx_pointx_financial) FRO..."
1,What is the total amount of points generated b...,SELECT SUM(amt_point_topup) FROM pointx_keymat...,SELECT SUM(amt_point_topup) FROM pointx_keymat...
2,What is the total amount of points generated b...,"SELECT month_id, SUM(amt_point_pay) FROM point...","SELECT month_id, SUM(amt_point_pay) FROM point..."
3,What is the average rate of released points fo...,SELECT AVG(rate_point_per_baht_pay) FROM point...,SELECT AVG(rate_point_per_baht_pay) FROM point...
4,Can you determine the average number of custom...,"SELECT month_id, AVG(ncust_visit) FROM pointx_...","SELECT month_id, AVG(ncust_visit) FROM pointx_..."


In [38]:
# Load the table's description file. The experiment loop reads the table text
# from keymatrix_schema['description'] and per-column text from
# keymatrix_columns_desc[<column name>].
with open("src/pointx/schemas/pointx_keymatrix_dly_schema_description.json") as f:
    keymatrix_schema = json.load(f)
    keymatrix_columns_desc = keymatrix_schema['columns']

In [39]:
# Rebind the global schema_link to the PointX schema folders. create_prompt()
# reads this global, so from this cell on prompts are built from the PointX
# datatypes instead of the coffee-shop ones.
schema_link = SchemaLinking()
schema_link.selected_domain(schema_description_folder_path="src/pointx/schemas/descriptions",
                            schema_data_types_folder_path="src/pointx/schemas/datatype")

#### Test NSQL SQL prediction

In [40]:
# Smoke test of the NSQL path: filter the schema with a lower column threshold
# (0.1) than the experiment uses, build the prompt and print the generated SQL.
# NOTE(review): "tractions" looks like a typo for "transactions" — confirm before fixing,
# since the string is model input.
question = "How many unique tractions occur?"
temp_schema = schema_link.filter_schema(question, column_threshold=0.1)
print(gen_sql(create_prompt(question, temp_schema)))

SELECT COUNT(*) FROM pointx_keymatrix_dly;


In [41]:
# Inspect which columns (with their scores) schema filtering kept for the question above.
temp_schema

{'pointx_keymatrix_dly': {'mtd1_n_topup_point_extnl': 0.246,
  'mtd1_n_topup_point_auto_wealth': 0.241,
  'n_topup_point': 0.232,
  'amt_point_topup': 0.226,
  'amt_point_topup_onetime': 0.223}}

## Experiment

In [None]:
def _mask_sql(sql, columns):
    """Replace ``*``, the given column names and literal values in ``sql`` with ``[MASK]``.

    Column names are matched as whole identifiers only, so a short name that is
    a substring of a longer identifier (for example ``n_topup_point`` inside
    ``mtd1_n_topup_point_extnl``) no longer corrupts it.
    """
    masked = sql.replace('*', "[MASK]")
    for col in columns:
        masked = re.sub(rf'(?<!\w){re.escape(col)}(?!\w)', "[MASK]", masked)
    # Quoted strings, decimals and integers are treated as condition values.
    pattern = r'(["\'])(.*?)\1|\b[-+]?\d+\.\d+\b|\b[-+]?\d+\b'
    return re.sub(pattern, "[MASK]", masked)


# One list per output column; a row is appended for every evaluated question.
df_data = {"Question": [],
           "Actual SQL": [],
           "Predict MASK SQL": [],
           "Actual result": [],
           "Predict result": []}

# Instructions plus two solved examples; the per-question section is appended
# to a fresh copy of this text on every iteration.
short_learning_prompt = """You are a SQL query assistant.
I have some SQL where the [MASK] column and the conditional value of [MASK] are formatted. And I want you to respond to the output populating the [MASK] column and the SQL input conditional values, followed by the question, the schema description (name - description) and the 5 row example values as provided.
If you don't know which column or value to enter, Don't include columns you created yourself. And only columns and values defined from the schema and sample values must be used.
Don't use columns from another table or schema. It must also be used from the same table defined in the input.

#################

table :     cat - this table contain cat information
column :    id - number for identify cat
            name - name of cat
            age - age of cat
            gender - gender of cat
example 5 rows:
|   id | name   |   age | gender   |
|-----:|:-------|------:|:---------|
|    5 | Pussy  |     2 | Male     |
|    1 | Sam    |     1 | Male     |
|    9 | Peter  |     2 | Female   |
|    3 | Jack   |     4 | Male     |
|    4 | Ponica |     3 | Female   |
question: Count number of cate each gender.
input: SELECT [MASK], COUNT([MASK]) FROM cat GROUP BY [MASK];
output: SELECT gender, COUNT(*) FROM cat GROUP BY gender;

#################

table :     cat - this table contain cat information
column :    id - number for identify cat
            name - name of cat
            age - age of cat
            gender - gender of cat
example 5 rows:
|   id | name   |   age | gender   |
|-----:|:-------|------:|:---------|
|    5 | Pussy  |     2 | Male     |
|    1 | Sam    |     1 | Male     |
|    9 | Peter  |     2 | Female   |
|    3 | Jack   |     4 | Male     |
|    4 | Ponica |     3 | Female   |
question: show me the name and age of woman cat.
input: SELECT [MASK], [MASK] FROM cat WHERE [MASK] = [MASK];
output: SELECT name, age FROM cat WHERE gender = 'Female';

#################
"""

table_name = "pointx_keymatrix_dly"

for i,row in keymatrix_df.iterrows():
    strat_time = time.time()
    question = row['NLQ']
    actual_sql = row['SQL']

    # Step 1: pick candidate columns, let NSQL draft the query, then mask the
    # columns and literal values so only the query skeleton remains.
    used_schema = schema_link.filter_schema(question, column_threshold=0.2, max_select_columns=10)
    table_columns = list(used_schema[table_name].keys())
    sql = _mask_sql(gen_sql(create_prompt(question, used_schema)), table_columns)

    # Step 2: describe the table, its candidate columns and sample rows.
    full_prompt = short_learning_prompt
    full_prompt += f"\ntable: {table_name} - {keymatrix_schema['description']}\ncolumn: "
    for col in table_columns:
        full_prompt += f"{col} - {keymatrix_columns_desc[col]}\n"
    full_prompt += f"example 5 rows:\n{query_n_rows(table_name, table_columns).to_markdown(index=False)}\n"
    full_prompt += f"question: {question}\n"
    full_prompt += f"input: {sql}\noutput:"
    print(full_prompt)

    # Step 3: let the chat model fill the masks. A missing reply is treated as
    # an empty string, and the answer is flattened to one line.
    response = (llm_response(full_prompt) or "").replace('\n', ' ')

    df_data['Question'].append(question)
    df_data['Actual SQL'].append(actual_sql)
    df_data['Predict MASK SQL'].append(response)
    df_data['Actual result'].append(query_pointx_db(actual_sql))
    df_data['Predict result'].append(query_pointx_db(response))

    print("\nQuestion:",question)
    print("ACTUAL SQL:",actual_sql)
    print("MASKED SQL:",sql)
    print("Response:",response)
    print(f"Time taken: {time.time() - strat_time} seconds")
    print()

In [None]:
# Collect the per-question records gathered by the experiment loop into a
# DataFrame, write it to an Excel file (without the index column) and preview it.
# NOTE(review): assumes the "Experiments" folder already exists — confirm.
result_df = pd.DataFrame(df_data)
result_df.to_excel("Experiments/NSQL-LLM-predict-MASK-keymatrix.xlsx", index=False)
result_df.head()