In [1]:
import pandas as pd
import numpy as np
import sqlparse, time
from sentence_transformers import SentenceTransformer, util

description_df = pd.read_excel('src/New_query_Description.xlsx',header=1)
description_df = description_df[['Column','Description']].dropna().reset_index(drop=True)

In [2]:
descriptions = description_df['Description'].tolist()

In [3]:
pointx_cols = pd.read_csv('src/pointx_fbs_rpt_dly.csv').columns.to_list()

In [4]:
def get_col(sql):
    col = []
    ignore = {
        'over', 'extract', 'desc', 'datediff', 'dayofweek', 'cnt', 'dateadd',
        'max', 'min', 'sum', 'count', 'getdate', 'timestampdiff', 'weekday',
        'having', 'month', 'year', 'day', 'date', 'avg', 'team_tds.tds_intern.pointx_fbs_txn_rpt_dly'
    }
    include_types = {
        sqlparse.sql.IdentifierList, sqlparse.sql.Where,
        sqlparse.sql.Having, sqlparse.sql.Comparison, sqlparse.sql.Function,
        sqlparse.sql.Parenthesis, sqlparse.sql.Operation, sqlparse.sql.Case
    }
    
    for token in sql.tokens:
        if str(token).lower() in ignore :
            continue
        
        if isinstance(token, sqlparse.sql.Identifier):
            if len(str(token).lower().split('as')) > 1:
                col.extend(get_col(token))
            elif str(token).lower() in pointx_cols:
                col.append(str(token))
                
        elif isinstance(token, tuple(include_types)):
            col.extend(get_col(token))
    return col

In [5]:
compare_df = pd.read_csv('src/compare_result.csv')
questions = compare_df['Question'].to_list()
sql_queries = compare_df['SQL'].to_list()

In [6]:
model = SentenceTransformer('all-MiniLM-L6-v2')
description_embs = [model.encode(des) for des in descriptions]
result_df = pd.DataFrame(columns=['Question','SQL','Expect columns','Minimum Threshold','Choose columns','Time'])

def get_col_max_score(question,col_labels,description_embs=description_embs,description_df=description_df):
    q_emb = model.encode(question)
    scores = np.array([float(util.cos_sim(q_emb, des)) for des in description_embs])
    col_labels_index = description_df[description_df['Column'].isin(col_labels)].index.tolist()
    col_labels_score = scores[col_labels_index]
    min_threshold = np.min(col_labels_score)

    print("QUESTION:\t",question)
    print("EXPECT COLUMNS:\t",col_labels)
    print("MIN THRESHOLD:\t",min_threshold)
    # print("MAX THRESHOLD:\t",np.max(col_labels_score))
    # print("MAX SCORE COLUMN:\t",description_df.iloc[np.argmax(scores)]['Column'])
    # print("DESCIPTION:\t",description_df.iloc[np.argmax(scores)]['Description'])
    # print("SCORE:\t",np.max(scores))

    n_columns = len(description_df.iloc[np.where(scores >= min_threshold)])
    print(f"CHOOSE RELATE COLUMN WITHIN THRESHOLD (FROM {len(scores)} COLUMNS):",n_columns)
    print()
    
    return n_columns

In [7]:
n_cols = 0
start_time = time.time()
for i, q in enumerate(questions):
    sql_parse = sqlparse.parse(sql_queries[i])[0]
    col_labels = list(set(get_col(sql_parse)))
    print(col_labels)
    n_cols += get_col_max_score(q,col_labels)

print("AVERAGE NUMBER OF COLUMNS:\t",n_cols/len(questions))
print("AVG TIME PER QUESTION:\t",(time.time()-start_time)/len(questions))

['event_date', 'ga_session_id', 'engagement_time_msec', 'user_pseudo_id']
QUESTION:	 How many daily active users each day?
EXPECT COLUMNS:	 ['event_date', 'ga_session_id', 'engagement_time_msec', 'user_pseudo_id']
MIN THRESHOLD:	 0.15984684228897095
CHOOSE RELATE COLUMN WITHIN THRESHOLD (FROM 182 COLUMNS): 49

['ga_session_id', 'engagement_time_msec', 'event_month', 'user_pseudo_id']
QUESTION:	 How many monthly active users each month?
EXPECT COLUMNS:	 ['ga_session_id', 'engagement_time_msec', 'event_month', 'user_pseudo_id']
MIN THRESHOLD:	 0.1179906576871872
CHOOSE RELATE COLUMN WITHIN THRESHOLD (FROM 182 COLUMNS): 68

['event_date', 'ga_session_id', 'engagement_time_msec', 'user_pseudo_id']
QUESTION:	 What is the average number of daily active users last 7 days?
EXPECT COLUMNS:	 ['event_date', 'ga_session_id', 'engagement_time_msec', 'user_pseudo_id']
MIN THRESHOLD:	 0.11249055713415146
CHOOSE RELATE COLUMN WITHIN THRESHOLD (FROM 182 COLUMNS): 76

['event_date', 'ga_session_id', 'en

In [8]:
ques = ["How many users have not used the app for more than a month?"]
for i, q in enumerate(ques):
    sql_parse = sqlparse.parse(sql_queries[i])[0]
    col_labels = list(set(get_col(sql_parse)))
    n_cols += get_col_max_score(q,col_labels)

QUESTION:	 How many users have not used the app for more than a month?
EXPECT COLUMNS:	 ['event_date', 'ga_session_id', 'engagement_time_msec', 'user_pseudo_id']
MIN THRESHOLD:	 0.03984004259109497
CHOOSE RELATE COLUMN WITHIN THRESHOLD (FROM 182 COLUMNS): 143



In [None]:
#TESTING

In [121]:
def test_get_col(sql):
    col = []
    ignore = {
        'over', 'extract', 'desc', 'datediff', 'dayofweek', 'cnt', 'dateadd',
        'max', 'min', 'sum', 'count', 'getdate', 'timestampdiff', 'weekday',
        'having', 'month', 'year', 'day', 'date', 'avg', 'team_tds.tds_intern.pointx_fbs_txn_rpt_dly'
    }
    include_types = {
        sqlparse.sql.IdentifierList, sqlparse.sql.Where,
        sqlparse.sql.Having, sqlparse.sql.Comparison, sqlparse.sql.Function,
        sqlparse.sql.Parenthesis, sqlparse.sql.Operation, sqlparse.sql.Case
    }
    
    for token in sql.tokens:
        
        print(token,type(token))
        if str(token).lower() in ignore :
            continue
        
        if isinstance(token, sqlparse.sql.Identifier):
            if len(str(token).lower().split('as')) > 1:
                col.extend(test_get_col(token))
            elif str(token).lower() in pointx_cols:
                col.append(str(token))
                
        elif isinstance(token, tuple(include_types)):
            col.extend(test_get_col(token))
    return col

In [128]:
sql_parse = sqlparse.parse("""WITH users AS (
  SELECT user_pseudo_id, event_timestamp 
  FROM team_tds.tds_intern.pointx_fbs_txn_rpt_dly
  WHERE event_name = "session_start"
  AND event_timestamp >= DATEADD(DAY, -60, GETDATE())
),
users_start AS (
SELECT COUNT(DISTINCT user_pseudo_id) as user FROM users
),
users_end AS (
SELECT COUNT(DISTINCT user_pseudo_id) as user FROM users
WHERE event_timestamp >= DATEADD(DAY, -30, GETDATE())
)""")[0]
test_get_col(sql_parse)

KeyboardInterrupt: 

In [104]:
sql_parse.tokens[0].tokens

[<Parenthesis '(COUNT...' at 0x2A4BEE2D0>,
 <Whitespace ' ' at 0x2A4AD9360>,
 <Operator '*' at 0x2A4ADBBE0>,
 <Whitespace ' ' at 0x2A4ADB5E0>,
 <Integer '100' at 0x2A4ADBA00>]