In [None]:
import json, sqlparse, httpx
from tqdm import tqdm
import pandas as pd

In [None]:
# sqlparse.sql node classes gathered under one name.
# NOTE(review): no visible cell references this set — confirm it is still needed.
sql_extract_token_type = {
    sqlparse.sql.IdentifierList,
    sqlparse.sql.Where,
    sqlparse.sql.Having,
    sqlparse.sql.Comparison,
    sqlparse.sql.Function,
    sqlparse.sql.Parenthesis,
    sqlparse.sql.Operation,
    sqlparse.sql.Case,
}

def columns_from_query(sql_query):
    """Collect the lower-cased real names of all identifiers in a SQL query.

    Args:
        sql_query: Either a SQL string (only the first parsed statement is
            inspected) or an iterable of sqlparse tokens, which is the form
            used for the recursive calls.

    Returns:
        A list of lower-cased names, in traversal order, duplicates included.
        The names are whatever ``Identifier.get_real_name()`` reports, so they
        may be tables as well as columns; callers filter against the schema.
        An empty list is returned when the string contains no statement.
    """
    if isinstance(sql_query, str):
        statements = sqlparse.parse(sql_query)
        if not statements:
            # Empty / whitespace-only input parses to no statements at all.
            return []
        sql_query = statements[0]
    columns = []
    for token in sql_query:
        if isinstance(token, sqlparse.sql.Identifier):
            real_name = token.get_real_name()
            # get_real_name() yields None for identifiers without a plain
            # name; skip those instead of failing on None.lower().
            if real_name is not None:
                columns.append(real_name.lower())
        elif hasattr(token, "tokens"):
            # Grouping node: descend into its children.
            columns.extend(columns_from_query(token.tokens))
    return columns

def columns_by_split(sql_query:str, all_columns:list):
    """Find known column names among the whitespace-separated words of a query.

    Each word has at most one trailing comma removed before it is looked up
    in ``all_columns``; matching is exact and case-sensitive.

    Args:
        sql_query: SQL text to scan.
        all_columns: Container of known column names (anything supporting ``in``).

    Returns:
        The matching words in order of appearance, duplicates included.
    """
    words = (
        word[:-1] if word.endswith(",") else word
        for word in sql_query.split()
    )
    return [word for word in words if word in all_columns]

In [None]:
# Known column names: the keys of the "COLUMNS" mapping in the datatype JSON.
# The encoding is pinned so the read does not depend on the platform's locale default.
with open("../filtering-schema/src/schemas/column-datatypes/pointx_fbs_rpt_dly_datatype.json", encoding="utf-8") as f:
    all_columns = set(json.load(f)['COLUMNS'])
# Reference question/SQL pairs; only these two columns of the workbook are kept.
exp_df = pd.read_excel("../src/pointx/PointX - text2sql pair.xlsx")[['Question', 'SQL']]
exp_df.head()

In [None]:
# Collect every schema column that appears in at least one reference SQL query.
found_cols = set()
skipped_rows = []  # (row index, reason) for rows where sqlparse extraction was skipped

for i, row in exp_df.iterrows():
    sql = row['SQL']
    if not isinstance(sql, str):
        # e.g. an empty spreadsheet cell; neither extractor can handle it.
        skipped_rows.append((i, f"SQL cell is {type(sql).__name__}, not str"))
        continue
    try:
        found_cols.update(c for c in columns_from_query(sql) if c in all_columns)
    except Exception as exc:
        # Best effort: remember the failure but keep going with the other rows.
        skipped_rows.append((i, f"{type(exc).__name__}: {exc}"))
    # The whitespace-split fallback is independent of the parser, so it runs
    # even when the parser-based extraction failed for this row.
    found_cols.update(columns_by_split(sql, all_columns))

for i, reason in skipped_rows:
    print(f"row {i}: sqlparse extraction skipped ({reason})")

# Sorted so the column order is the same on every run.
used_cols = sorted(found_cols)
used_cols

In [None]:
# Read the CSV extract and keep only the columns listed in used_cols.
rpt_dly_csv_path = "../filtering-schema/src/data/pointx_fbs_rpt_dly.csv"
pointx_rpt_dly_df = pd.read_csv(rpt_dly_csv_path).loc[:, used_cols]
pointx_rpt_dly_df

In [None]:
def predict_sql(question, timeout=60, url="http://0.0.0.0:8000/nlq"):
    """Ask the NLQ service to turn a natural-language question into SQL.

    Args:
        question: Text sent as ``input.text`` in the JSON payload.
        timeout: Forwarded to ``httpx.post`` as its ``timeout`` argument.
        url: Endpoint to POST to; the default is the address this notebook
            has always used.

    Returns:
        ``response.json()['data'][0]['text']`` on success. ``None`` after
        printing a message when the request fails, times out, comes back
        with an error status, or the body does not have that shape.
    """
    payload = {"input": {"text": question}}
    try:
        response = httpx.post(url, json=payload, timeout=timeout)
        response.raise_for_status()  # Raise an HTTPStatusError for 4xx/5xx responses
        return response.json()['data'][0]['text']
    # TimeoutException is a subclass of RequestError, so it has to be
    # handled first or its branch can never be reached.
    except httpx.TimeoutException as e:
        print(f"Request timed out: {e}")
    except httpx.RequestError as e:
        print(f"Request error: {e}")
    except httpx.HTTPStatusError as e:
        print(f"HTTP error: {e}")
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # Body was not JSON, or did not contain data[0].text.
        print(f"Unexpected response format: {e!r}")
    return None

In [None]:
# Send each reference question to the NLQ service and print the question,
# the returned SQL (or None on failure) and a separator line.
for question in tqdm(exp_df['Question']):
    print(f"\n{question}")
    generated_sql = predict_sql(question)
    print(generated_sql, '--------------', sep="\n")
    