# Syntax Pattern Detection on Sentence Level

Imports

In [None]:
import stanza
import json

Set up WordNet and the English stopword list

In [None]:
import nltk
# Fetch the NLTK corpora used below: WordNet (synset lookup) and the stopword lists.
nltk.download('wordnet')
nltk.download('stopwords')
from nltk.corpus import wordnet, stopwords

# English stopwords as a set for fast membership tests; used to filter column-name
# fragments, database values and synonym candidates.
stopwords_en = set(stopwords.words('english'))

Define placeholders

In [None]:
from enum import Enum

class ChangedBy(Enum):
    """Provenance tag recording which pipeline step replaced a token."""

    DB_TABLE = 1   # token matched a table name
    DB_COLUMN = 2  # token matched a full or partial column name
    DB_VAL = 3     # token matched a value stored in the database
    NER = 4        # token is (part of) a named entity
    NUM = 5        # token is a numeral
    SYNONYM = 6    # token was replaced by its WordNet synset head word


# Placeholder strings that replace the matched tokens in the generalized query
TABLE_NAME = '{TABLE}'              # full table name
COLUMN_NAME = '{COLUMN}'            # full column name
COLUMN_PARTIAL_NAME = '{COLUMN_PART}'  # single word of a column name
DB_VALUE = '{VALUE}'                # value stored in the database
NUMERICAL = '{NUMBER}'              # numeral
NAMED_ENTITY = '{NE}'               # named entity

Specify the Spider database and the file paths derived from it

In [None]:
# Spider database whose NL-SQL pairs are processed, and the root of the Spider dataset
db_name = 'party_people'
spider_directory = 'spider_data'

# Files of the selected database
db_directory = '/'.join((spider_directory, 'database', db_name))
schema_file = db_directory + '/schema.sql'
db_file = db_directory + '/' + db_name + '.sqlite'

# Dataset-wide files and the output file for the generalized patterns
data_file = spider_directory + '/train_spider.json'
tables_file = spider_directory + '/tables.json'
export_trees_file = spider_directory + '/patterns.' + db_name + '.json'

Function: Load schema information from Spider's tables.json

In [None]:
def load_schema(db_id, tables_file_path, stop_words=None):
    """
    Load schema information for one database from Spider's tables.json.

    Table and column names are collected both in their normalized form
    (``table_names`` / ``column_names``) and their original form
    (``table_names_original`` / ``column_names_original``), all lower-cased.

    :param db_id: database id (the ``db_id`` field of a tables.json entry)
    :type db_id: str
    :param tables_file_path: path to tables.json
    :type tables_file_path: str
    :param stop_words: words excluded from the partial column names;
        defaults to the module-level ``stopwords_en``
    :type stop_words: Optional[Set[str]]
    :return: db schema entry, set of table names, set of column names,
        set of partial column names (single non-stopword words of column names)
    :rtype: Tuple[Dict[str, Any], Set[str], Set[str], Set[str]]
    :raises ValueError: if ``db_id`` does not occur in the file
    """
    if stop_words is None:
        stop_words = stopwords_en

    with open(tables_file_path, "r", encoding="utf-8") as tables_json:
        dbs = json.load(tables_json)

    for db_schema in dbs:
        if db_schema["db_id"] != db_id:
            continue

        # Entries of column_names* are [table_index, column_name] pairs
        column_names = {t[1].lower() for t in db_schema["column_names"]}
        column_names.update(t[1].lower() for t in db_schema["column_names_original"])

        # Lower-case table names too, consistent with the column names
        table_names = {name.lower() for name in db_schema["table_names"]}
        table_names.update(name.lower() for name in db_schema["table_names_original"])

        column_partial_names = {
            part
            for column_name in column_names
            for part in column_name.split()
            if part not in stop_words
        }
        return db_schema, table_names, column_names, column_partial_names

    # Previously this fell through and returned None, which made the caller's
    # tuple unpacking fail with an unrelated TypeError.
    raise ValueError(f"Database {db_id!r} not found in {tables_file_path}")

Function: Load database values from the INSERT statements in schema.sql

In [None]:
def load_db_values(schema_file_path, stop_words=None):
    """
    Load all values from the single-line INSERT statements of a schema file,
    skipping values that are stopwords.

    Each ``INSERT INTO ... VALUES (...);`` line is matched case-insensitively, and
    its value list is parsed as a JSON array. Double-quoted strings and numbers
    therefore work; single-quoted SQL strings or ``NULL`` do not.

    :param schema_file_path: database schema definition file
    :type schema_file_path: str
    :param stop_words: values contained in this set are skipped;
        defaults to the module-level ``stopwords_en``
    :type stop_words: Optional[Set[str]]
    :return: all values of the selected database
    :rtype: Set[Any]
    :raises ValueError: if an INSERT line has no parsable ``VALUES (...)`` part or
        its value list is not valid JSON (``json.JSONDecodeError`` is a ValueError)
    """
    import re

    if stop_words is None:
        stop_words = stopwords_en

    # Capture the parenthesised value list; ``.*`` is greedy so values that contain
    # parentheses (or the word VALUES) stay intact. Trailing ";" and whitespace are
    # optional, so a last line without a newline is handled as well.
    values_pattern = re.compile(r'\bVALUES\s*\((.*)\)\s*;?$', re.IGNORECASE)

    db_values = set()
    with open(schema_file_path, 'r', encoding='utf-8') as schema_file:
        for line in schema_file:
            statement = line.strip()
            if not statement.upper().startswith("INSERT INTO"):
                continue
            match = values_pattern.search(statement)
            if match is None:
                raise ValueError(f"Cannot parse INSERT statement: {statement!r}")
            values = json.loads(f'[{match.group(1)}]')
            db_values.update(v for v in values if v not in stop_words)

    return db_values

Load DB Schema & values

In [None]:
# Load the schema description and all stored values of the selected database, and show them
db_schema, db_tables, db_colums, db_partial_columns = load_schema(db_name, tables_file)
print(db_schema, db_tables, db_colums, sep='\n')
db_values = load_db_values(schema_file)
print(db_values)

Function: Load NL-SQL pairs for given Spider database

In [None]:
def load_db_data(db, db_data_file):
    """
    Load all data (NL-SQL) entries for the selected database.

    :param db: database id to filter on (the ``db_id`` field of each entry)
    :type db: str
    :param db_data_file: path of the JSON file containing NL-SQL entries
    :type db_data_file: str
    :return: NL-SQL entries whose ``db_id`` equals ``db``, in file order
    :rtype: List[Dict]
    """
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # can fail on non-ASCII questions.
    with open(db_data_file, "r", encoding="utf-8") as train_json:
        entries = json.load(train_json)
    return [entry for entry in entries if entry["db_id"] == db]

Load NL-SQL pairs for DB

In [None]:
# All NL-SQL training pairs that belong to the selected database
data = load_db_data(db=db_name, db_data_file=data_file)

Set up Stanza

In [None]:
# Download the English models and build the default English pipeline; the code
# below reads word texts, upos tags, morphological features and named entities.
stanza.download(lang='en')
nlp = stanza.Pipeline(lang='en')

Data structures for parse tree container and processed token list

In [None]:
class ParseTree:
    """Container for one processed NL query, its SQL query and its token list."""

    def __init__(self, orginal_nl_query, original_sql_query, tokens, raw_parsed):
        """
        Constructor

        :param orginal_nl_query: original NL query (parameter keeps its historical
            spelling so keyword callers keep working)
        :type orginal_nl_query: str
        :param original_sql_query: original SQL query
        :type original_sql_query: str
        :param tokens: list of tokens
        :type tokens: List[Token]
        :param raw_parsed: raw stanza sentence parse (must expose ``ents``)
        :type raw_parsed: stanza.Sentence
        """
        # The attribute keeps its historical misspelling because existing code reads
        # ``orginal_nl_query``; ``original_nl_query`` is the correctly spelled alias.
        self.orginal_nl_query = orginal_nl_query
        self.original_sql_query = original_sql_query
        self.tokens = tokens
        self.ents = raw_parsed.ents
        self.raw_parsed = raw_parsed
        self.log = []

    @property
    def original_nl_query(self):
        """Original NL query (correctly spelled alias of ``orginal_nl_query``)."""
        return self.orginal_nl_query

    @original_nl_query.setter
    def original_nl_query(self, value):
        self.orginal_nl_query = value

    def unchanged(self):
        """
        Get all unchanged tokens

        The generator is lazy, so tokens changed while iterating are
        already filtered out for the rest of the iteration.

        :return: all not-yet changed tokens of this tree
        :rtype: Generator[Token]
        """
        return (token for token in self.tokens if not token.changed)

    def __str__(self):
        return '|'.join(str(t) for t in self.tokens)

    def __repr__(self):
        return f"{type(self).__name__}({self.orginal_nl_query!r})"


class Token:
    """A (possibly multi-word) token of a parsed NL query, with replacement provenance."""

    NUM_UNKNOWN = 0
    NUM_SINGULAR = 1
    NUM_PLURAL = 2

    def __init__(self, parsed_token, is_compound=False, ent=None):
        """
        Constructor

        :param parsed_token: parsed stanza word; its ``text`` and ``feats`` are read here
        :type parsed_token: stanza.Word
        :param is_compound: whether further words may be merged into this token
        :type is_compound: bool
        :param ent: named entity this token belongs to, if any
        :type ent: Optional[stanza.Span]
        """
        self.text = parsed_token.text
        # Sentinel that shows up in __str__ if ``changed`` is set without calling change()
        self.original_text = "ERROR. TEXT WAS CHANGED WITHOUT USING CHANGE METHOD"
        self.text_lower = self.text.lower()
        self.changed = False
        self.changed_by = ''
        self.properties = parsed_token
        self.is_compound = is_compound
        self.properties_list = [self.properties]
        self.ent = ent
        # feats is a "Key=Value|Key=Value" string, or None when the word has no features
        if parsed_token.feats is not None:
            self.features = parsed_token.feats.split("|")
        else:
            self.features = []
        if "Number=Sing" in self.features:
            self.num = self.NUM_SINGULAR
        elif "Number=Plur" in self.features:
            self.num = self.NUM_PLURAL
        else:
            self.num = self.NUM_UNKNOWN

    def change(self, new_text, by):
        """
        Change token

        The first change records the token's original text. Later changes keep it,
        so provenance is not overwritten by an intermediate replacement.

        :param new_text: what to replace with
        :type new_text: str
        :param by: what pipeline part changed the token
        :type by: ChangedBy
        """
        if not self.changed:
            self.original_text = self.text
        self.text = new_text
        self.text_lower = self.text.lower()
        self.changed = True
        self.changed_by = by

    def add_compound_component(self, parsed_token):
        """
        Append another parsed word to this (compound) token, space-separated.

        :param parsed_token: parsed stanza word to merge into this token
        :type parsed_token: stanza.Word
        """
        self.text += f" {parsed_token.text}"
        self.text_lower = self.text.lower()
        self.properties_list.append(parsed_token)

    def __str__(self):
        if self.changed:
            return f"{self.text} [{self.original_text}]"
        return self.text

    def __repr__(self):
        return f"{type(self).__name__}({str(self)!r})"


Function: Generate Parse Tree incl. Provenance for given NL-SQL pair

In [None]:

def generate_parse_tree(nl_sql):
    """
    Generate the parsed form of a given NL-SQL entry.

    Words that belong to the same named entity are merged into one compound
    Token. Punctuation outside of entities is dropped.

    :param nl_sql: data entry containing a NL query (``question``) and its SQL (``query``)
    :type nl_sql: Dict[Str, Any]
    :return: parse tree
    :rtype: ParseTree
    :raises ValueError: if the question does not yield any sentence
    """
    doc = nlp(nl_sql["question"])
    if not doc.sentences:
        raise ValueError(f"No sentence found in question: {nl_sql['question']!r}")
    # Only the first sentence is used (there should always be only one).
    # NOTE(review): a multi-sentence question would lose its tail here — confirm.
    parsed = doc.sentences[0]

    ents_by_start = {e.start_char: e for e in parsed.ents}

    # Create a list of words (and make sure compounds are recombined)
    tokens = []
    current_ent = None
    current_token = None
    for pw in parsed.words:
        # Use the tokenizer's character offset of the word's token. A running
        # "len(text) + 1" counter drifts whenever a word is not preceded by exactly
        # one space (e.g. ",", "?", "'s"), which made entity starts go undetected.
        word_start = pw.parent.start_char
        ent = ents_by_start.get(word_start)
        # Start of an entity? Words of one multi-word token share the same offset,
        # so only the first of them opens the entity.
        if ent is not None and ent is not current_ent:
            current_ent = ent
            current_token = Token(pw, is_compound=True, ent=current_ent)
            tokens.append(current_token)
        # Inside of entity?
        elif current_ent is not None and word_start < current_ent.end_char:
            current_token.add_compound_component(pw)
        # Normal word
        else:
            current_ent = None
            if pw.upos != 'PUNCT':
                tokens.append(Token(pw))

    return ParseTree(nl_sql["question"], nl_sql["query"], tokens, parsed)

Generate Parse Tree for sample sentence (and display)

In [None]:
# Parse one sample question for inspection (data[42] is an alternative sample)
sample_tree = generate_parse_tree(nl_sql=data[27])
print(sample_tree)

Singular/plural handling

In [None]:
import inflect

# Shared inflect engine; used to derive the plural/singular form of a token before
# comparing it against table and column names.
inflector = inflect.engine()

Function: Replace all table/column names in the parse tree with a placeholder (and add provenance)

In [None]:
def replace_table_column_names(db_schema, db_tables, db_columns, db_partial_columns, parse_tree):
    """
    Replace all table/column names in the parse tree with a placeholder (and add provenance).

    Each unchanged token is tried as written and, when its grammatical number is
    known, also in the opposite number (singular <-> plural). Forms are tried in
    that order. For each form, a full column name wins over a table name, which
    wins over a column-name fragment. The first hit replaces the token.

    :param db_schema: db schema (not used by this function)
    :type db_schema: Dict[str, Any]
    :param db_tables: set of table names
    :type db_tables: set[str]
    :param db_columns: set of column names
    :type db_columns: set[str]
    :param db_partial_columns: set of single words taken from column names
    :type db_partial_columns: set[str]
    :param parse_tree: parse tree (modified in place)
    :type parse_tree: ParseTree
    """
    # (name set, placeholder, provenance), in priority order
    schema_lookups = (
        (db_columns, COLUMN_NAME, ChangedBy.DB_COLUMN),
        (db_tables, TABLE_NAME, ChangedBy.DB_TABLE),
        (db_partial_columns, COLUMN_PARTIAL_NAME, ChangedBy.DB_COLUMN),
    )

    for token in parse_tree.unchanged():
        # Each form is compared both as written and lower-cased
        forms = [(token.text, token.text_lower)]
        if token.num == Token.NUM_SINGULAR:
            inflected = inflector.plural(token.text)
        elif token.num == Token.NUM_PLURAL:
            # singular_noun returns False when the word is not a plural noun
            inflected = inflector.singular_noun(token.text)
        else:
            inflected = None
        if inflected:
            forms.append((inflected, inflected.lower()))

        hit = next(
            (
                (placeholder, provenance)
                for form, form_lower in forms
                if form != ''
                for names, placeholder, provenance in schema_lookups
                if form in names or form_lower in names
            ),
            None,
        )
        if hit is not None:
            token.change(*hit)

Function: Replace all tokens that match a value stored in the DB with placeholders (and add provenance)

In [None]:
def replace_db_values(db_values, parse_tree):
    """
    Replace every unchanged token whose exact text is a database value with the value placeholder.

    :param db_values: set of all values of the database
    :type db_values: Set[Any]
    :param parse_tree: parse tree (modified in place)
    :type parse_tree: ParseTree
    """
    # Exact, case-sensitive match of the token text against the stored values
    matching_words = (word for word in parse_tree.unchanged() if word.text in db_values)
    for word in matching_words:
        word.change(DB_VALUE, ChangedBy.DB_VAL)

Function: Replace all Named Entities in the parse tree with placeholders (and add provenance)

In [None]:
def replace_named_entities(parse_tree):
    """
    Replace every unchanged token that belongs to a named entity with the NE placeholder.

    Entities of type CARDINAL are left untouched here.

    :param parse_tree: parse tree (modified in place)
    :type parse_tree: ParseTree
    """
    for word in parse_tree.unchanged():
        entity = word.ent
        if entity is None or entity.type == "CARDINAL":
            continue
        word.change(NAMED_ENTITY, ChangedBy.NER)

Function: Replace all numerical values in the parse tree with placeholders (and add provenance)

In [None]:
def replace_numerical_values(parse_tree):
    """
    Replace every unchanged token tagged NUM by the parser with the number placeholder.

    For compound tokens only the first word's tag is inspected.

    :param parse_tree: parse tree (modified in place)
    :type parse_tree: ParseTree
    """
    numerals = filter(lambda w: w.properties.upos == "NUM", parse_tree.unchanged())
    for numeral in numerals:
        numeral.change(NUMERICAL, ChangedBy.NUM)

Insert placeholders in sample parse tree (and display)

In [None]:
# Show the sample tree before and after the placeholder passes, then its SQL query
print(sample_tree)
replace_table_column_names(
    db_schema=db_schema,
    db_tables=db_tables,
    db_columns=db_colums,
    db_partial_columns=db_partial_columns,
    parse_tree=sample_tree,
)
replace_db_values(db_values=db_values, parse_tree=sample_tree)
replace_named_entities(parse_tree=sample_tree)
replace_numerical_values(parse_tree=sample_tree)
print(sample_tree)
print(sample_tree.original_sql_query)

Function: Retrieve the head word of the first WordNet synset for a token

In [None]:
# Stanza (Universal POS) tag -> WordNet POS constant; tags not listed get no synset lookup
POS_TO_WN_POS = dict(
    NOUN=wordnet.NOUN,
    ADJ=wordnet.ADJ,
    VERB=wordnet.VERB,
    ADV=wordnet.ADV,
)

def retrieve_synset(word, pos_tag):
    """
    Get synonym for a given word using WordNet

    :param word: word to get synonym for
    :type word: str
    :param pos_tag: POS tag of given word
    :type pos_tag: str
    :return: synonym
    :rtype: str
    """
    if pos_tag in POS_TO_WN_POS:
        containing_synsets = [synset.name().split('.')[0] for synset in wordnet.synsets(word, pos=POS_TO_WN_POS[pos_tag])]
        if len(containing_synsets) > 0:
            return containing_synsets[0]
    return word

Function: Replace non-placeholder non-stopword tokens in parse tree with corresponding Synset name

In [None]:
def apply_synset_names(parse_tree):
    """
    Replace every unchanged, non-stopword token with the head word of its first WordNet synset.

    Tokens whose lookup yields their own text are left unchanged.

    :param parse_tree: parse tree (modified in place)
    :type parse_tree: ParseTree
    :return: the same parse tree
    :rtype: ParseTree
    """
    for token in parse_tree.unchanged():
        if token.text.lower() in stopwords_en:
            continue
        head_word = retrieve_synset(token.text, token.properties.upos)
        if head_word == token.text:
            continue
        token.change(head_word, ChangedBy.SYNONYM)
    return parse_tree

Apply synonym grouping to sample parse tree (and display)

In [None]:
# apply_synset_names modifies the tree in place and returns it
print(apply_synset_names(sample_tree))

Function: Create and process parse tree (incl. provenance) for given NL-SQL pair

In [None]:
def parse_nl_sql_pair(nl_sql):
    """
    Create and process the parse tree (incl. provenance) for one NL-SQL pair.

    The replacement passes run in a fixed order. Each only touches tokens that
    earlier passes left unchanged, so earlier passes take precedence.

    :param nl_sql: training data entry (reads ``question`` and ``query``)
    :type nl_sql: Dict[Str, Any]
    :return: parsed and processed tree
    :rtype: ParseTree
    """
    generalized = generate_parse_tree(nl_sql)
    # Schema names, stored values, named entities, numerals, then synonyms
    replace_table_column_names(db_schema, db_tables, db_colums, db_partial_columns, generalized)
    replace_db_values(db_values, generalized)
    replace_named_entities(generalized)
    replace_numerical_values(generalized)
    apply_synset_names(generalized)
    return generalized

Generate trees for every NL-SQL pair in this DB

In [None]:
# Parse and generalize every NL-SQL pair of the selected database
db_parse_trees = [parse_nl_sql_pair(nl_sql_pair) for nl_sql_pair in data]

Store the generalized queries (with their original NL and SQL) in a JSON file

In [None]:
# One record per query: original question, generalized token sequence, original SQL
patterns = [
    dict(
        original_nl=pt.orginal_nl_query,
        generalized_nl=str(pt),
        original_sql=pt.original_sql_query,
    )
    for pt in db_parse_trees
]

with open(export_trees_file, "w") as pattern_file:
    json.dump(patterns, pattern_file, indent=4)
    print(f"Generalized {len(patterns)} queries and stored them in '{export_trees_file}'")

Function: Compute Parse Tree similarity

In [None]:
def parse_tree_similarity(source_tree, target_tree):
    """
    Compute the similarity of two parse trees.

    Not implemented yet. Raising makes the gap explicit instead of silently
    returning ``None`` to a caller that expects a score.

    :param source_tree: first tree
    :type source_tree: ParseTree
    :param target_tree: second tree
    :type target_tree: ParseTree
    :return: similarity score (intended contract — TODO define the scale)
    :raises NotImplementedError: always, until implemented
    """
    raise NotImplementedError("parse_tree_similarity is not implemented yet")

Function: Cluster parse trees using parse tree similarity

In [None]:
def cluster_parse_tree(parse_trees):
    """
    Cluster parse trees using parse tree similarity.

    Not implemented yet. Returning ``None`` made the caller's
    ``for cluster in clusters`` fail later with an unrelated TypeError;
    raising reports the real cause at the call site.

    :param parse_trees: trees to cluster
    :type parse_trees: List[ParseTree]
    :return: iterable of clusters (intended contract — TODO define cluster type)
    :raises NotImplementedError: always, until implemented
    """
    raise NotImplementedError("cluster_parse_tree is not implemented yet")

Cluster all Parse Trees for this DB

In [None]:
# Group the generalized trees of this database into clusters
clusters = cluster_parse_tree(parse_trees=db_parse_trees)

Function: Display generic tree and NL-SQL pair for given cluster

In [None]:
def print_template(tree_cluster):
    """
    Display the generic tree and the NL-SQL pairs of one cluster.

    Not implemented yet. Raising makes the gap explicit instead of silently
    printing nothing.

    :param tree_cluster: one cluster as produced by ``cluster_parse_tree``
    :type tree_cluster: Any
    :raises NotImplementedError: always, until implemented
    """
    raise NotImplementedError("print_template is not implemented yet")

Display generic tree and NL-SQL pair for all clusters for this DB

In [None]:
# Print one template per cluster
for cluster in clusters:
    print_template(tree_cluster=cluster)