In [1]:
from SchemaLinking import SchemaLinking
import pandas as pd
import os, json, re

In [2]:
# Benchmark pairs: target table, natural-language question and its reference ("Actual") SQL.
q_pair_df = pd.read_csv("../src/pointx/PointX_questionpair.csv").loc[:, ['Table', 'Question', 'Actual SQL']]
print(q_pair_df.shape)
q_pair_df.head()

(124, 3)


Unnamed: 0,Table,Question,Actual SQL
0,pointx_keymatrix_dly,What is the total amout of all financial trans...,"SELECT month_id, SUM(ntx_pointx_financial) FRO..."
1,pointx_keymatrix_dly,What is the total amount of points generated b...,SELECT SUM(amt_point_topup) FROM pointx_keymat...
2,pointx_keymatrix_dly,What is the total amount of points generated b...,"SELECT month_id, SUM(amt_point_pay) FROM point..."
3,pointx_keymatrix_dly,What is the average rate of released points fo...,SELECT AVG(rate_point_per_baht_pay) FROM point...
4,pointx_keymatrix_dly,Can you determine the average number of custom...,"SELECT month_id, AVG(ncust_visit) FROM pointx_..."


In [5]:
# Domain/schema data loaded from embedded_data.json and handed to SchemaLinking.
# Explicit UTF-8: without it open() uses the platform locale encoding, which can fail on
# or garble non-ASCII content in the JSON.
# NOTE(review): this path is under src/src_dev/ while the question pairs are read from
# src/pointx/ — confirm both refer to the same dataset version.
with open("../src/src_dev/pointx/embedded_data.json", "r", encoding="utf-8") as f:
    domain = json.load(f)

schema_link = SchemaLinking(domain)

# Example question/table used by the filter_schema comparison cells.
question = "Display the user id and revenue of user who has the highest total transactions id"
tables = ['pointx_fbs_rpt_dly']

In [14]:
def table_col_of_sql(sql_query: str, schema_link=schema_link) -> dict:
    """
    Extract the schema tables and columns referenced by a SQL query.

    The query is tokenised with ``re.split(schema_link.split_pattern, sql_query)``; a table
    or column counts as referenced when some token equals its name exactly (case-sensitive).
    Columns are only looked up for tables that are themselves referenced.

    Parameters:
    sql_query (str): The SQL query to inspect.
    schema_link: Object providing ``split_pattern`` (regex used to tokenise the query) and
        ``schema_datatypes`` (``{table: {'COLUMNS': {column: ...}}}``). The default is the
        ``schema_link`` that existed when this cell was executed, so re-run this cell after
        rebuilding ``schema_link``.

    Returns:
    dict: ``{table: [columns]}`` for every referenced table. Tables and columns follow the
    iteration order of ``schema_link.schema_datatypes``; tables absent from the query are
    omitted (so indexing the result with such a table raises ``KeyError``).

    Example:
    table_col_of_sql("SELECT column1, column2 FROM table1 WHERE column3 = 'value'")
    {'table1': ['column1', 'column2', 'column3']}
    """
    # Tokenise once into a set: each membership check below becomes O(1) instead of a
    # scan over every token of the query.
    tokens = set(re.split(schema_link.split_pattern, sql_query))

    selected_schema = {}
    for table in schema_link.schema_datatypes.keys():
        if table not in tokens:
            continue
        columns = schema_link.schema_datatypes[table]['COLUMNS'].keys()
        selected_schema[table] = [col for col in columns if col in tokens]

    return selected_schema

In [21]:
def safe_divide(numerator, denominator, ndigits=2, default=float("nan")):
    """
    Divide two numbers, returning ``default`` instead of raising on a zero denominator.

    Parameters:
    numerator (float): Dividend.
    denominator (float): Divisor.
    ndigits (int | None): Decimal places to round the quotient to; ``None`` disables rounding.
    default: Value returned when the division raises ``ZeroDivisionError``. NaN by default so
        the result stays numeric: later arithmetic does not raise and pandas aggregations
        such as ``mean`` skip it (a string sentinel would break both).

    Returns:
    float: The (optionally rounded) quotient, or ``default``.
    """
    try:
        quotient = numerator / denominator
    except ZeroDivisionError:  # only the expected failure; anything else should surface
        return default
    return quotient if ndigits is None else round(quotient, ndigits)


def classification_accuracy(actual_cols: list, selected_cols: list, ndigits: int = 2):
    """
    Score a column selection against the gold columns, ignoring letter case.

    Parameters:
    actual_cols (list[str]): Gold columns (e.g. those referenced by the reference SQL).
    selected_cols (list[str]): Columns chosen by a schema-linking method.
    ndigits (int | None): Decimal places for the returned scores; ``None`` keeps full
        precision, which is preferable when the scores are averaged afterwards.

    Returns:
    tuple[float, float, float]: ``(precision, recall, f1)``. Precision is NaN when nothing
    was selected, recall is NaN when there are no gold columns, and F1 is NaN only when both
    lists are empty.
    """
    actual = {col.lower() for col in actual_cols}
    selected = {col.lower() for col in selected_cols}
    true_pos = len(actual & selected)
    false_pos = len(selected - actual)
    false_neg = len(actual - selected)

    precision = safe_divide(true_pos, true_pos + false_pos, ndigits=ndigits)
    recall = safe_divide(true_pos, true_pos + false_neg, ndigits=ndigits)
    # 2TP / (2TP + FP + FN) equals the harmonic mean of precision and recall, but is taken
    # from the raw counts so that it is not skewed by the rounding above and is 0 (not
    # undefined) when the two sets do not overlap at all.
    f1_score = safe_divide(2 * true_pos, 2 * true_pos + false_pos + false_neg, ndigits=ndigits)

    return precision, recall, f1_score

In [15]:
# Sanity check: gold tables/columns referenced by the first question's reference SQL.
table_col_of_sql(q_pair_df['Actual SQL'].iloc[0])

{'pointx_keymatrix_dly': ['month_id', 'ntx_pointx_financial']}

In [8]:
# Column names the current filter_schema keeps for the example question (max_n=10).
[*schema_link.filter_schema(question, tables, max_n=10)[tables[0]].keys()]

String matching ['id', 'id']


['user_id',
 'user_ltv_revenue',
 'user_ltv_currency',
 'traffic_source_medium',
 'ecommerce',
 'ecoupon_rank',
 'id',
 'merchant_id',
 'order_id',
 'place_id',
 'product_id',
 'total_amount',
 'transaction_id',
 'transaction_status']

In [9]:
# Same example through old_filter_schema, for a side-by-side comparison with filter_schema.
[*schema_link.old_filter_schema(question, tables, max_n=10)[tables[0]].keys()]

Column string match  ----> id
Column string match  ----> id


['id',
 'transaction_id',
 'user_ltv_revenue',
 'merchant_id',
 'total_amount',
 'transaction_status',
 'user_ltv_currency',
 'order_id',
 'ecommerce',
 'traffic_source_medium']

In [35]:
def schema_link_exp(max_n: int):
    """
    Compare the old and new schema-linking column selection on every question pair.

    For each row of the global ``q_pair_df``, the gold columns come from
    ``table_col_of_sql(row['Actual SQL'])[row['Table']]``. They are scored against the
    columns returned by ``schema_link.old_filter_schema`` ("Old") and
    ``schema_link.filter_schema`` ("New"), each called with ``max_n=max_n`` for the row's
    table.

    Parameters:
    max_n (int): Passed as ``max_n`` to both filter methods; also embedded in column names.

    Returns:
    pd.DataFrame: One row per ``q_pair_df`` row, in the same order. Columns are 'Question',
    followed, for "Old" and then "New", by '<Variant> <max_n> Selected cols',
    '... # correct', '... recall', '... precision' and '... F1'.

    Raises:
    KeyError: if a row's table is not found in its own 'Actual SQL'
        (``table_col_of_sql`` omits tables absent from the query).
    """
    variants = [
        ("Old", schema_link.old_filter_schema),
        ("New", schema_link.filter_schema),
    ]
    metric_names = ["Selected cols", "# correct", "recall", "precision", "F1"]
    columns = ["Question"] + [
        f"{label} {max_n} {metric}" for label, _ in variants for metric in metric_names
    ]

    records = []
    for _, row in q_pair_df.iterrows():
        table = row['Table']
        question = row['Question']
        actual_cols = table_col_of_sql(row['Actual SQL'])[table]
        actual_lower = {col.lower() for col in actual_cols}

        record = {"Question": question}
        for label, filter_fn in variants:
            selected_cols = list(filter_fn(question, [table], max_n=max_n)[table].keys())
            precision, recall, f1 = classification_accuracy(actual_cols, selected_cols)

            prefix = f"{label} {max_n}"
            record[f"{prefix} Selected cols"] = selected_cols
            # Case-insensitive like classification_accuracy, so it matches that function's TP.
            record[f"{prefix} # correct"] = len(actual_lower & {col.lower() for col in selected_cols})
            record[f"{prefix} recall"] = recall
            record[f"{prefix} precision"] = precision
            record[f"{prefix} F1"] = f1
        records.append(record)

    # Explicit columns keep the schema (and order) even when q_pair_df is empty.
    return pd.DataFrame(records, columns=columns)

In [51]:
# max_n values to evaluate; schema_link_exp runs once per value.
list_max_n = [200]

# Gold columns per question: the columns of its table referenced in the reference SQL.
q_pair_df['Actual Cols'] = q_pair_df.apply(lambda row: table_col_of_sql(row['Actual SQL'])[row['Table']], axis=1)
full_df = pd.DataFrame(q_pair_df['Question'])
full_df['Actual cols'] = q_pair_df['Actual Cols']
full_df['# Actual cols'] = full_df['Actual cols'].apply(len)

for n in list_max_n:
    result_df = schema_link_exp(n)
    # schema_link_exp yields one row per q_pair_df row in the same order, so join by
    # position. Merging on the free-text 'Question' would multiply rows for repeated
    # questions, and an outer merge re-sorts rows lexicographically in recent pandas.
    assert result_df['Question'].tolist() == full_df['Question'].tolist(), "row order mismatch"
    full_df = pd.concat(
        [full_df, result_df.drop(columns='Question').set_axis(full_df.index)],
        axis=1,
    )

os.makedirs("results", exist_ok=True)  # to_excel does not create missing directories
full_df.to_excel("results/schemalink_pointx_exp.xlsx", index=False)

full_df.head()

Column string match  ----> points
String matching ['points']
Column string match  ----> points
String matching ['points']
Column string match  ----> points
String matching ['points']
Column string match  ----> points
String matching ['points']
Column string match  ----> points
String matching ['points']
Column string match  ----> points
String matching ['points']
Column string match  ----> points
String matching ['points']
Column string match  ----> points
String matching ['points']
Column string match  ----> Platform
String matching ['Platform']
Column string match  ----> points
String matching ['points']
Column string match  ----> points
String matching ['points']
Column string match  ----> points
String matching ['points']
Column string match  ----> engagement_time_msec
String matching ['engagement_time_msec']
Column string match  ----> id
Column string match  ----> id
String matching ['id', 'id']
Column string match  ----> platform
String matching ['platform']
Column string match  

Unnamed: 0,Question,Actual cols,# Actual cols,Old 200 Selected cols,Old 200 # correct,Old 200 recall,Old 200 precision,Old 200 F1,New 200 Selected cols,New 200 # correct,New 200 recall,New 200 precision,New 200 F1
0,What is the total amout of all financial trans...,"[month_id, ntx_pointx_financial]",2,"[month_id, ntx_pointx_financial, ntx_pointx_fi...",2,1.0,0.01,0.02,"[month_id, ntx_pointx_financial, ntx_pointx_fi...",2,1.0,0.02,0.04
1,What is the total amount of points generated b...,"[month_id, amt_point_topup]",2,"[month_id, ntx_pointx_financial, ntx_pointx_fi...",2,1.0,0.01,0.02,"[month_id, ntx_pointx_financial, ncust_user, n...",2,1.0,0.02,0.04
2,What is the total amount of points generated b...,"[month_id, amt_point_pay]",2,"[month_id, ntx_pointx_financial, ntx_pointx_fi...",2,1.0,0.01,0.02,"[month_id, ntx_pointx_financial_out, ncust_use...",2,1.0,0.02,0.04
3,What is the average rate of released points fo...,[rate_point_per_baht_pay],1,"[month_id, ntx_pointx_financial, ntx_pointx_fi...",1,1.0,0.01,0.02,"[month_id, ntx_pointx_financial_out, ncust_use...",1,1.0,0.01,0.02
4,Can you determine the average number of custom...,"[month_id, ncust_visit]",2,"[month_id, ntx_pointx_financial, ntx_pointx_fi...",2,1.0,0.01,0.02,"[month_id, ncust_user, ncust_pointx, ncust_vis...",2,1.0,0.02,0.04


In [55]:
# Mean per-question F1 of the new filter_schema at max_n=200.
full_df.loc[:, 'New 200 F1'].mean()

0.03096774193548387

In [56]:
# Mean per-question F1 of old_filter_schema at max_n=200, for comparison with the new one.
full_df.loc[:, 'Old 200 F1'].mean()

0.03919354838709677