Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
SQL Coverage script (#1750)
Browse files Browse the repository at this point in the history
Adding this now and tweaking a little of the pre-processing to break up the PR stream for these parsers - obviously this won't be the final grammar, because I need to modify it depending on the table, but it will be good to keep around a grammar which parses the raw dataset, because I _do not_ want to have to debug that thing twice.
  • Loading branch information
DeNeutoy committed Sep 18, 2018
1 parent 0b44bce commit 6039ac0
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 18 deletions.
4 changes: 2 additions & 2 deletions allennlp/data/dataset_readers/dataset_utils/text2sql_utils.py
Expand Up @@ -57,12 +57,12 @@ def replace_variables(sentence: List[str],

def clean_and_split_sql(sql: str) -> List[str]:
"""
Cleans up and unifies a SQL query. This involves removing unnecessary quotes
Cleans up and unifies a SQL query. This involves unifying quoted strings
and splitting brackets which aren't formatted consistently in the data.
"""
sql_tokens = []
for token in sql.strip().split():
token = token.replace('"', "").replace("'", "").replace("%", "")
token = token.replace('"', "'").replace("%", "")
if token.endswith("(") and len(token) > 1:
sql_tokens.append(token[:-1])
sql_tokens.append(token[-1])
Expand Down
Expand Up @@ -23,8 +23,8 @@ def test_process_sql_data_blob(self):
assert sql_data.text_with_variables == ['how', 'many', 'name0', 'are', 'there', 'in', 'city_name0', '?']
assert sql_data.sql == ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', 'AS', 'LOCATIONalias0', ',',
'RESTAURANT', 'AS', 'RESTAURANTalias0', 'WHERE', 'LOCATIONalias0.CITY_NAME', '=',
'city_name0', 'AND', 'RESTAURANTalias0.ID', '=', 'LOCATIONalias0.RESTAURANT_ID',
'AND', 'RESTAURANTalias0.NAME', '=', 'name0', ';']
'\'city_name0\'', 'AND', 'RESTAURANTalias0.ID', '=', 'LOCATIONalias0.RESTAURANT_ID',
'AND', 'RESTAURANTalias0.NAME', '=', '\'name0\'', ';']
assert sql_data.text_variables == {'city_name0': 'san francisco', 'name0': 'buttercup kitchen'}
assert sql_data.sql_variables == {'city_name0': 'san francisco', 'name0': 'buttercup kitchen'}

Expand All @@ -44,8 +44,8 @@ def test_process_sql_data_blob(self):
for i, sql_data in enumerate(dataset):
assert sql_data.sql == ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'GEOGRAPHIC', 'AS', 'GEOGRAPHICalias0',
',', 'RESTAURANT', 'AS', 'RESTAURANTalias0', 'WHERE', 'GEOGRAPHICalias0.REGION',
'=', 'region0', 'AND', 'RESTAURANTalias0.CITY_NAME', '=', 'GEOGRAPHICalias0.CITY_NAME',
'AND', 'RESTAURANTalias0.FOOD_TYPE', '=', 'food_type0', ';']
'=', '\'region0\'', 'AND', 'RESTAURANTalias0.CITY_NAME', '=', 'GEOGRAPHICalias0.CITY_NAME',
'AND', 'RESTAURANTalias0.FOOD_TYPE', '=', '\'food_type0\'', ';']
assert sql_data.text_variables == {'region0': 'bay area', 'food_type0': 'chinese'}
assert sql_data.sql_variables == {'region0': 'bay area', 'food_type0': 'chinese'}
assert sql_data.text == correct_text[i][0]
Expand All @@ -71,5 +71,5 @@ def test_clean_and_split_sql(self):
cleaned = text2sql_utils.clean_and_split_sql(sql)
assert cleaned == ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', 'AS', 'LOCATIONalias0', ',',
'RESTAURANT', 'AS', 'RESTAURANTalias0', 'WHERE', 'LOCATIONalias0.CITY_NAME', '=',
'city_name0', 'AND', 'RESTAURANTalias0.ID', '=', 'LOCATIONalias0.RESTAURANT_ID',
'AND', 'RESTAURANTalias0.NAME', '=', 'name0', ';']
'\'city_name0\'', 'AND', 'RESTAURANTalias0.ID', '=', 'LOCATIONalias0.RESTAURANT_ID',
'AND', 'RESTAURANTalias0.NAME', '=', '\'name0\'', ';']
Expand Up @@ -19,45 +19,45 @@ def test_reader(self):
assert tokens == ['how', 'many', 'buttercup', 'kitchen', 'are', 'there', 'in', 'san', 'francisco', '?']
assert tags == ['O', 'O', 'name0', 'name0', 'O', 'O', 'O', 'city_name0', 'city_name0', 'O']
assert fields["template"].label == "SELECT COUNT ( * ) FROM LOCATION AS LOCATIONalias0 , RESTAURANT " \
"AS RESTAURANTalias0 WHERE LOCATIONalias0.CITY_NAME = city_name0 " \
"AS RESTAURANTalias0 WHERE LOCATIONalias0.CITY_NAME = 'city_name0' " \
"AND RESTAURANTalias0.ID = LOCATIONalias0.RESTAURANT_ID AND " \
"RESTAURANTalias0.NAME = name0 ;"
"RESTAURANTalias0.NAME = 'name0' ;"

fields = instances[1].fields
tokens = [t.text for t in fields["tokens"].tokens]
tags = fields["slot_tags"].labels
assert tokens == ['how', 'many', 'chinese', 'restaurants', 'are', 'there', 'in', 'the', 'bay', 'area', '?']
assert tags == ['O', 'O', 'food_type0', 'O', 'O', 'O', 'O', 'O', 'region0', 'region0', 'O']
assert fields["template"].label == "SELECT COUNT ( * ) FROM GEOGRAPHIC AS GEOGRAPHICalias0 , RESTAURANT AS "\
"RESTAURANTalias0 WHERE GEOGRAPHICalias0.REGION = region0 AND "\
"RESTAURANTalias0 WHERE GEOGRAPHICalias0.REGION = 'region0' AND "\
"RESTAURANTalias0.CITY_NAME = GEOGRAPHICalias0.CITY_NAME AND "\
"RESTAURANTalias0.FOOD_TYPE = food_type0 ;"
"RESTAURANTalias0.FOOD_TYPE = 'food_type0' ;"

fields = instances[2].fields
tokens = [t.text for t in fields["tokens"].tokens]
tags = fields["slot_tags"].labels
assert tokens == ['how', 'many', 'places', 'for', 'chinese', 'food', 'are', 'there', 'in', 'the', 'bay', 'area', '?']
assert tags == ['O', 'O', 'O', 'O', 'food_type0', 'O', 'O', 'O', 'O', 'O', 'region0', 'region0', 'O']
assert fields["template"].label == "SELECT COUNT ( * ) FROM GEOGRAPHIC AS GEOGRAPHICalias0 , RESTAURANT AS "\
"RESTAURANTalias0 WHERE GEOGRAPHICalias0.REGION = region0 AND "\
"RESTAURANTalias0 WHERE GEOGRAPHICalias0.REGION = 'region0' AND "\
"RESTAURANTalias0.CITY_NAME = GEOGRAPHICalias0.CITY_NAME AND "\
"RESTAURANTalias0.FOOD_TYPE = food_type0 ;"
"RESTAURANTalias0.FOOD_TYPE = 'food_type0' ;"

fields = instances[3].fields
tokens = [t.text for t in fields["tokens"].tokens]
tags = fields["slot_tags"].labels
assert tokens == ['how', 'many', 'chinese', 'places', 'are', 'there', 'in', 'the', 'bay', 'area', '?']
assert tags == ['O', 'O', 'food_type0', 'O', 'O', 'O', 'O', 'O', 'region0', 'region0', 'O']
assert fields["template"].label == "SELECT COUNT ( * ) FROM GEOGRAPHIC AS GEOGRAPHICalias0 , RESTAURANT AS "\
"RESTAURANTalias0 WHERE GEOGRAPHICalias0.REGION = region0 AND "\
"RESTAURANTalias0 WHERE GEOGRAPHICalias0.REGION = 'region0' AND "\
"RESTAURANTalias0.CITY_NAME = GEOGRAPHICalias0.CITY_NAME AND "\
"RESTAURANTalias0.FOOD_TYPE = food_type0 ;"
"RESTAURANTalias0.FOOD_TYPE = 'food_type0' ;"
fields = instances[4].fields
tokens = [t.text for t in fields["tokens"].tokens]
tags = fields["slot_tags"].labels
assert tokens == ['how', 'many', 'places', 'for', 'chinese', 'are', 'there', 'in', 'the', 'bay', 'area', '?']
assert tags == ['O', 'O', 'O', 'O', 'food_type0', 'O', 'O', 'O', 'O', 'region0', 'region0', 'O']
assert fields["template"].label == "SELECT COUNT ( * ) FROM GEOGRAPHIC AS GEOGRAPHICalias0 , RESTAURANT AS "\
"RESTAURANTalias0 WHERE GEOGRAPHICalias0.REGION = region0 AND "\
"RESTAURANTalias0 WHERE GEOGRAPHICalias0.REGION = 'region0' AND "\
"RESTAURANTalias0.CITY_NAME = GEOGRAPHICalias0.CITY_NAME AND "\
"RESTAURANTalias0.FOOD_TYPE = food_type0 ;"
"RESTAURANTalias0.FOOD_TYPE = 'food_type0' ;"
214 changes: 214 additions & 0 deletions scripts/examine_sql_coverage.py
@@ -0,0 +1,214 @@
import json
import os
import sys
import argparse

sys.path.insert(0, os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))))

from allennlp.data.dataset_readers.dataset_utils.text2sql_utils import process_sql_data, SqlData
from allennlp.semparse.contexts.sql_table_context import SqlVisitor
from parsimonious.grammar import Grammar

# still TODO:
# JOIN, seems hard.
# Added query to pos_value - check this, very unclear if it is the correct way to handle this.
# not all functions can take * as an argument.
# Check whether LIKE can take non string arguments (example in scholar dataset)

# A PEG grammar (parsimonious syntax) covering the raw text2sql SQL queries.
# NOTE: PEG choices ("/") are ordered, so any alternative which is a prefix of
# another (e.g. ">" vs ">=") must come after the longer one.
SQL_GRAMMAR2 = Grammar(
    r"""
    stmt = (query ws ";") / (query ws)
    query = (ws select_core groupby_clause ws orderby_clause ws limit) /
            (ws select_core groupby_clause ws orderby_clause) /
            (ws select_core groupby_clause ws limit) /
            (ws select_core orderby_clause ws limit) /
            (ws select_core groupby_clause) /
            (ws select_core orderby_clause) /
            (ws select_core)
    select_core = (select_with_distinct select_results from_clause where_clause) /
                  (select_with_distinct select_results from_clause) /
                  (select_with_distinct select_results where_clause) /
                  (select_with_distinct select_results)
    select_with_distinct = (SELECT DISTINCT) / SELECT
    select_results = (ws select_result ws "," ws select_results) / (ws select_result)
    select_result = sel_res_all_star / sel_res_tab_star / sel_res_val / sel_res_col
    sel_res_tab_star = name ".*"
    sel_res_all_star = "*"
    sel_res_val = (expr AS wsp name) / expr
    sel_res_col = col_ref (AS wsp name)
    from_clause = FROM source
    source = (ws single_source ws "," ws source) / (ws single_source)
    single_source = source_table / source_subq
    source_table = table_name (AS wsp name)
    source_subq = ("(" ws query ws ")" AS ws name) / ("(" ws query ws ")")
    where_clause = (WHERE wsp expr where_conj) / (WHERE wsp expr)
    where_conj = (AND wsp expr where_conj) / (AND wsp expr)
    groupby_clause = (GROUP BY group_clause having_clause) / (GROUP BY group_clause)
    group_clause = (ws expr ws "," group_clause) / (ws expr)
    having_clause = HAVING ws expr
    orderby_clause = ORDER BY order_clause
    order_clause = (ordering_term ws "," order_clause) / ordering_term
    ordering_term = (ws expr ordering) / (ws expr)
    ordering = ASC / DESC
    limit = LIMIT ws number
    col_ref = (table_name "." column_name) / column_name
    table_name = name
    column_name = name
    ws = ~"\s*"i
    wsp = ~"\s+"i
    name = ~"[a-zA-Z]\w*"i
    expr = in_expr / like_expr / between_expr / binary_expr /
           unary_expr / null_check_expr / source_subq / value
    like_expr = value wsp LIKE ws string
    in_expr = (value wsp NOT IN wsp string_set) /
              (value wsp IN wsp string_set) /
              (value wsp NOT IN wsp expr) /
              (value wsp IN wsp expr)
    between_expr = value BETWEEN wsp value AND wsp value
    binary_expr = value ws binaryop wsp expr
    unary_expr = unaryop expr
    null_check_expr = (col_ref IS NOT NULL) / (col_ref IS NULL)
    value = parenval / datetime / number / boolean / function / col_ref / string
    datetime = "YEAR(CURDATE())"
    parenval = "(" ws expr ws ")"
    function = (fname ws "(" ws DISTINCT ws arg_list_or_star ws ")") /
               (fname ws "(" ws arg_list_or_star ws ")")
    arg_list_or_star = arg_list / "*"
    arg_list = (expr ws "," ws arg_list) / expr
    number = ~"\d*\.?\d+"i
    string_set = ws "(" ws string_set_vals ws ")"
    string_set_vals = (string ws "," ws string_set_vals) / string
    string = ~"\'.*?\'"i
    fname = "COUNT" / "SUM" / "MAX" / "MIN" / "AVG" / "ALL"
    boolean = "true" / "false"
    binaryop = "+" / "-" / "*" / "/" / "=" / "<>" / ">=" / "<=" / ">" / "<" / AND / OR / LIKE
    binaryop_no_andor = "+" / "-" / "*" / "/" / "=" / "<>" / ">=" / "<=" / ">" / "<"
    unaryop = "+" / "-" / "not" / "NOT"
    SELECT = ws "SELECT"
    FROM = ws "FROM"
    WHERE = ws "WHERE"
    AS = ws "AS"
    LIKE = ws "LIKE"
    AND = ws "AND"
    OR = ws "OR"
    DISTINCT = ws "DISTINCT"
    GROUP = ws "GROUP"
    ORDER = ws "ORDER"
    BY = ws "BY"
    ASC = ws "ASC"
    DESC = ws "DESC"
    BETWEEN = ws "BETWEEN"
    IN = ws "IN"
    IS = ws "IS"
    NOT = ws "NOT"
    NULL = ws "NULL"
    HAVING = ws "HAVING"
    LIMIT = ws "LIMIT"
    """
)


def parse_dataset(filename: str, filter_by: str = None, verbose: bool = False):
    """
    Attempt to parse every SQL query in a preprocessed text2sql JSON file
    with ``SQL_GRAMMAR2`` and count how many parse successfully.

    Parameters
    ----------
    filename : str, required.
        The path to a preprocessed text2sql JSON file.
    filter_by : str, optional.
        Count how many parse errors contain this string; errors containing
        it are also suppressed from the verbose output.
    verbose : bool, optional.
        Whether to print information about queries which fail to parse.

    Returns
    -------
    A tuple of (num_parsed, num_queries, filtered_errors).
    """
    # Sentinel that never appears in a query, so "filter_by in sql_string"
    # is always False when no filter was requested.
    filter_by = filter_by or "13754332dvmklfdsaf-3543543"
    # Use a context manager so the file handle is not leaked.
    with open(filename) as data_file:
        data = json.load(data_file)
    num_queries = 0
    num_parsed = 0
    filtered_errors = 0

    for i, sql_data in enumerate(process_sql_data(data)):
        sql_visitor = SqlVisitor(SQL_GRAMMAR2)

        # BUG FIX: the original condition was any([x[:7] == "DERIVED"] for ...),
        # a generator of single-element lists, which is always truthy.
        if any(token.startswith("DERIVED") for token in sql_data.sql):
            # NOTE: DATA hack alert - the geography dataset doesn't alias derived tables
            # consistently, so we fix the data a bit here instead of completely
            # re-working the grammar.
            sql_to_use = []
            for j, token in enumerate(sql_data.sql):
                if token.startswith("DERIVED") and sql_data.sql[j - 1] == ")":
                    sql_to_use.append("AS")
                sql_to_use.append(token)

            sql_string = " ".join(sql_to_use)
        else:
            sql_string = " ".join(sql_data.sql)
        num_queries += 1
        try:
            # parse() raises if the query does not match the grammar; the
            # production rules themselves are not needed here.
            sql_visitor.parse(sql_string)
            num_parsed += 1
        except Exception as e:
            if filter_by in sql_string:
                filtered_errors += 1

            if verbose and filter_by not in sql_string:
                print()
                print(e)
                print(" ".join(sql_data.text))
                print(sql_data.sql)
                try:
                    # Pretty-print the failing query if sqlparse is available.
                    import sqlparse
                    print(sqlparse.format(sql_string, reindent=True))
                except Exception:
                    print(sql_string)

        if (i + 1) % 500 == 0:
            print(f"\tProcessed {i + 1} queries.")

    return num_parsed, num_queries, filtered_errors

def main(data_directory: str, dataset: str = None, filter_by: str = None, verbose: bool = False) -> None:
    """
    Parameters
    ----------
    data_directory : str, required.
        The path to the data directory of https://github.com/jkkummerfeld/text2sql-data
        which has been preprocessed using scripts/reformat_text2sql_data.py.
    dataset : str, optional.
        The dataset to parse. By default all are parsed.
    filter_by : str, optional
        Compute statistics about a particular error and only print errors which don't contain this string.
    verbose : bool, optional.
        Whether to print information about incorrectly parsed SQL.
    """
    directory_dict = {path: files for path, names, files in os.walk(data_directory) if files}

    for directory, data_files in directory_dict.items():
        if "query_split" in directory or (dataset is not None and dataset not in directory):
            continue

        print(f"Parsing dataset at {directory}")
        parsed = 0
        total = 0
        total_filtered_errors = 0
        for json_file in data_files:
            print(f"\tParsing split at {json_file}")
            file_path = os.path.join(directory, json_file)
            num_parsed, num_queries, filtered_errors = parse_dataset(file_path, filter_by, verbose)

            parsed += num_parsed
            total += num_queries
            # BUG FIX: accumulate over every split; previously only the last
            # split's filtered_errors was reported in the summary below.
            total_filtered_errors += filtered_errors

        # Guard against a directory whose splits contained no queries.
        coverage = parsed / total if total else 0.0
        print(f"\tParsed {parsed} out of {total} queries, coverage {coverage}")
        if filter_by is not None:
            print(f"\tOf {total - parsed} errors, {total_filtered_errors / (total - parsed + 1e-13)} contain {filter_by}")

if __name__ == "__main__":
    # Command line entry point: walk the preprocessed text2sql data directory
    # and report grammar coverage for each (optionally filtered) dataset.
    argument_parser = argparse.ArgumentParser(
        description="Check the coverage of a SQL Grammar on the text2sql datasets.")
    argument_parser.add_argument('--data', type=str, help='The path to the text2sql data directory.')
    argument_parser.add_argument('--dataset', type=str, default=None,
                                 help='The dataset to check coverage for. Defaults to all datasets.')
    argument_parser.add_argument('--filter', type=str, default=None, help='A string to filter by.')
    argument_parser.add_argument('--verbose', help='Verbose output.', action='store_true')
    parsed_args = argument_parser.parse_args()
    main(parsed_args.data, parsed_args.dataset, parsed_args.filter, parsed_args.verbose)

0 comments on commit 6039ac0

Please sign in to comment.