In [1]:
# Imports
from sql_ast_dataset.ast_processing.factory import Factory
from sql_ast_dataset.ast_processing.ast_diff_types import ASTDiffInput
from typing import List, Any, Dict

In [2]:
# Helper function to visualize the output strings
def print_training_sample(query:str, query_subword_indices: List[Any], query_subword_labels: List[Any]) -> str:
    """Return ``query`` with its labelled characters wrapped in ANSI colour codes.

    Despite the name, nothing is printed; the coloured string is returned so
    the caller can decide how to display it.

    Args:
        query: The text to colourise.
        query_subword_indices: One iterable of character positions in
            ``query`` per subword.
        query_subword_labels: One numeric label per subword, paired
            positionally with ``query_subword_indices``. Each label is rounded
            to the nearest integer; 0 selects ANSI colour 31 (red) and 1
            selects ANSI colour 32 (green).

    Returns:
        ``query`` where every character referenced by an index is surrounded
        by ``\\x1b[0;<colour>m`` ... ``\\x1b[0m``. Characters that no subword
        references are left untouched. If a position is referenced more than
        once, the last label wins.

    Raises:
        IndexError: If an index lies outside ``query``.
    """
    # One slot per character. Using a list (instead of a dict keyed by
    # position) means a negative index recolours the character it refers to
    # rather than appending a stray duplicate at the end of the output.
    rendered = list(query)

    # zip() stops at the shorter input, so unmatched trailing labels or index
    # groups are ignored.
    for label, indices in zip(query_subword_labels, query_subword_indices):
        color = round(label)  # Expected to be either 0 or 1
        for index in indices:
            rendered[index] = f"\x1b[0;{31 + color}m{query[index]}\x1b[0m"

    return "".join(rendered)

In [3]:
# Sample data: one natural-language question, its reference ("gold") SQL, and
# a handful of incorrect candidate queries to compare against that reference.
utterance: str = "How many ships ended up being 'Captured'?"
gold_query: str = "SELECT COUNT(*) FROM ship WHERE disposition_of_ship = 'Captured'"

wrong_query_list: List[str] = [
    # Filters on a different column and uses a lower-case literal.
    "SELECT COUNT(*) FROM ship WHERE location = 'captured'",
    # Leading SQL comment, and an IS NULL predicate on another column.
    "/* 1 */ SELECT COUNT(*) FROM ship WHERE lost_in_battle IS NULL",
    # Same WHERE clause as the gold query, but selects name instead of COUNT(*).
    "SELECT name FROM ship WHERE disposition_of_ship = 'Captured'",
    # Identical to the gold query except for the case of the string literal.
    "SELECT COUNT(*) FROM ship WHERE disposition_of_ship = 'captured'",
    # Two tables, no WHERE clause, and a `ships` qualifier that matches neither table name.
    "SELECT ships.id, name FROM ship, death",
    # Table aliases, an EXISTS sub-query, a differently-cased column and another literal.
    "SELECT s.name FROM ship AS s WHERE EXISTS(SELECT * FROM death AS d WHERE s.id = d.kill) AND s.Disposition_of_ship = 'unstable'",
]

# Generate SQL AST classification data: every query below is processed
# together with the gold query it is compared against.
query_processor = Factory().build("QueryProcessor", {})
assert query_processor is not None

# The list starts with the gold query paired with itself, carrying label 1.
ast_diff_list: List[ASTDiffInput] = [
    query_processor.process(
        sql_query_1=gold_query,
        sql_query_2=gold_query,
        label=1,
    ),
]

# Each incorrect candidate follows, in list order, carrying label 0.
for wrong_query in wrong_query_list:
    ast_diff_list.append(
        query_processor.process(
            sql_query_1=wrong_query,  # the query under inspection
            sql_query_2=gold_query,   # the reference it is compared against
            label=0,
        )
    )

In [4]:
# Plain SQL queries: the utterance first, then one "<label>  <query>" line per
# sample, with no highlighting applied.
print("  ", utterance)
for sample in ast_diff_list:
    # The two-space separator lines the query up under the utterance when the
    # label is a single character wide.
    print(sample.label, sample.query, sep="  ")
    

   How many ships ended up being 'Captured'?
1  SELECT COUNT(*) FROM ship WHERE disposition_of_ship = 'Captured'
0  SELECT COUNT(*) FROM ship WHERE location = 'captured'
0  /* 1 */ SELECT COUNT(*) FROM ship WHERE lost_in_battle IS NULL
0  SELECT name FROM ship WHERE disposition_of_ship = 'Captured'
0  SELECT COUNT(*) FROM ship WHERE disposition_of_ship = 'captured'
0  SELECT ships.id, name FROM ship, death
0  SELECT s.name FROM ship AS s WHERE EXISTS(SELECT * FROM death AS d WHERE s.id = d.kill) AND s.Disposition_of_ship = 'unstable'


In [5]:
# Simple visualization: the utterance first, then one line per sample holding
# its label and its query as colourised by print_training_sample.
print("  ", utterance)
for sample in ast_diff_list:
    label_text = f"{sample.label}"
    colored_query = print_training_sample(
        sample.query,
        sample.query_subword_indices_as_list(),
        sample.get_labels(),
    )
    print(label_text, colored_query)
    

   How many ships ended up being 'Captured'?
1 [0;32mS[0m[0;32mE[0m[0;32mL[0m[0;32mE[0m[0;32mC[0m[0;32mT[0m[0;32m [0m[0;32mC[0m[0;32mO[0m[0;32mU[0m[0;32mN[0m[0;32mT[0m[0;32m([0m[0;32m*[0m[0;32m)[0m[0;32m [0m[0;32mF[0m[0;32mR[0m[0;32mO[0m[0;32mM[0m[0;32m [0m[0;32ms[0m[0;32mh[0m[0;32mi[0m[0;32mp[0m[0;32m [0m[0;32mW[0m[0;32mH[0m[0;32mE[0m[0;32mR[0m[0;32mE[0m[0;32m [0m[0;32md[0m[0;32mi[0m[0;32ms[0m[0;32mp[0m[0;32mo[0m[0;32ms[0m[0;32mi[0m[0;32mt[0m[0;32mi[0m[0;32mo[0m[0;32mn[0m[0;32m_[0m[0;32mo[0m[0;32mf[0m[0;32m_[0m[0;32ms[0m[0;32mh[0m[0;32mi[0m[0;32mp[0m[0;32m [0m[0;32m=[0m[0;32m [0m[0;32m'[0m[0;32mC[0m[0;32ma[0m[0;32mp[0m[0;32mt[0m[0;32mu[0m[0;32mr[0m[0;32me[0m[0;32md[0m[0;32m'[0m
0 [0;32mS[0m[0;32mE[0m[0;32mL[0m[0;32mE[0m[0;32mC[0m[0;32mT[0m[0;32m [0m[0;32mC[0m[0;32mO[0m[0;32mU[0m[0;32mN[0m[0;32mT[0m[0;32m([0m[0;32m*[0m[0;32m)[0m[