# Syntax Pattern Detection on Sentence Level (beyond word boundaries)

Imports

In [None]:
from collections import defaultdict
from json import JSONDecodeError
import stanza
import json

Setup WordNet & Stopword List

In [None]:
import nltk
# Fetch the WordNet and stopword corpora through NLTK's downloader.
nltk.download('wordnet')
nltk.download('stopwords')
# Imported after the download calls above.
from nltk.corpus import wordnet, stopwords

# English stopwords as a set; the cells below only use it for membership tests.
stopwords_en = set(stopwords.words('english'))

Singular/plural handling

In [None]:
import inflect

# Shared inflect engine; used below to derive singular/plural variants of schema names.
inflector = inflect.engine()

Graph libraries

In [None]:
# import igraph

Define placeholders

In [None]:
from enum import Enum

class ChangedBy(Enum):
    """Provenance marker: names the replacement step that produced a Token."""
    DB_TABLE = 1  # table name replaced (see replace_table_names)
    DB_COLUMN = 2  # full column name replaced (see replace_column_names)
    DB_VAL = 3  # database value replaced (see replace_db_values)
    NER = 4  # named entity (any type except CARDINAL) replaced
    NUM = 5  # CARDINAL entity or word tagged NUM replaced
    SYNONYM = 6  # word replaced by its WordNet synset name (see replace_synonyms)
    DB_PARTIAL_COLUMN = 7  # single word of a column name replaced


# Placeholder texts that take the place of the matched words in a generalized question.
TABLE_NAME = '{TABLE}'  # table name
COLUMN_NAME = '{COLUMN}'  # full column name
COLUMN_PARTIAL_NAME = '{COLUMN_PART}'  # single word taken from a column name
DB_VALUE = '{VALUE}'  # value found in the database dump
NUMERICAL = '{NUMBER}'  # number
NAMED_ENTITY = '{NE}'  # named entity

Specify Spider database

In [None]:
# Database used by the single-database cells below.
db_name = 'party_people'
# All other paths are derived from this directory.
spider_directory = 'spider_data'
db_directory = f'{spider_directory}/database/{db_name}'
# Target directory of export(), one JSON file per database.
patterns_directory = f'{spider_directory}/patterns_no_synonyms'
# SQL dump that load_db_values() scans for INSERT statements.
schema_file = f'{db_directory}/schema.sql'
db_file = f'{db_directory}/{db_name}.sqlite'
# NL-SQL pairs.
data_file = f'{spider_directory}/train_spider.json'
# Schema descriptions read by load_schema().
tables_file = f'{spider_directory}/tables.json'
# Target file of the single-database export cell.
export_trees_file = f'{spider_directory}/patterns.{db_name}.json'

Functions: Load schema information from Spider's tables.json

In [None]:
def build_set_from_spider_element_list(entries):
    """
    Expand db element names into a set of spelling variants

    For every entry the set receives the entry itself, its plural (unless the
    plural ends in "ss") and its singular (when inflect reports one), each of
    them both as given and lower-cased.

    :param entries: original entries
    :type entries: Iterable[str]
    :return: all collected variants
    :rtype: Set[str]
    """
    collected = set()
    for entry in entries:
        variants = [entry]

        plural_form = inflector.plural(entry)
        if plural_form and not plural_form.endswith("ss"):
            variants.append(plural_form)

        # singular_noun yields a falsy value when there is no singular form to add
        singular_form = inflector.singular_noun(entry)
        if singular_form:
            variants.append(singular_form)

        for variant in variants:
            collected.update((variant, variant.lower()))
    return collected


def load_schema(db_id, tables_file_path):
    """
    Load schema information for one database from tables.json

    Table and column names are expanded with lower-case and singular/plural
    variants. Partial column names are the individual non-stopword words of
    all column name variants. Each returned list is sorted by descending
    length, so a scan in list order tries longer names first.

    :param db_id: database
    :type db_id: str
    :param tables_file_path: path to tables.json
    :type tables_file_path: str
    :return: db schema entry, table names, column names, partial column names
    :rtype: Tuple[Dict[str, Any], List[str], List[str], List[str]]
    :raises LookupError: if tables.json contains no entry for db_id
    """
    with open(tables_file_path, "r") as tables_json:
        dbs = json.load(tables_json)

    db_schema = next((entry for entry in dbs if entry["db_id"] == db_id), None)
    if db_schema is None:
        # Fail here with a clear message instead of handing None to callers
        # that unpack four values.
        raise LookupError(f"No schema entry for database '{db_id}' in {tables_file_path}")

    # Entries of "column_names"/"column_names_original" are pairs; only the
    # second element (the name) is used.
    column_names = build_set_from_spider_element_list(c for (_, c) in db_schema["column_names"])
    column_names |= build_set_from_spider_element_list(c for (_, c) in db_schema["column_names_original"])

    column_partial_names = set()
    for column_name in column_names:
        column_partial_names.update(t for t in column_name.split() if t not in stopwords_en)

    table_names = build_set_from_spider_element_list(db_schema["table_names"])

    table_names_list = sorted(table_names, key=len, reverse=True)
    column_names_list = sorted(column_names, key=len, reverse=True)
    column_partial_names_list = sorted(column_partial_names, key=len, reverse=True)

    return db_schema, table_names_list, column_names_list, column_partial_names_list

Function: Load database values from file

In [None]:
# Paths of schema files that were missing or contained an INSERT line that could not be parsed.
problem_dbs = set()


def load_db_values(schema_file_path):
    """
    Load all database values from file that are not stopwords

    Only lines starting with "INSERT INTO" are considered. The text after the
    last "VALUES" is stripped of surrounding blanks, the trailing semicolon
    and the enclosing brackets, and then parsed as a JSON array. Lines that
    are not valid JSON that way (e.g. single-quoted SQL strings) are skipped
    and the file is recorded in problem_dbs.

    :param schema_file_path: data base schema definition file
    :type schema_file_path: str
    :return: All values of the selected database, converted to strings
    :rtype: Set[str]
    """
    db_values = set()
    try:
        with open(schema_file_path, 'r') as schema_sql:
            for line in schema_sql:
                if not line.startswith("INSERT INTO"):
                    continue

                # Strip by content rather than by fixed offsets, so trailing
                # blanks, a missing final newline or "VALUES(" without a blank
                # do not cut into the values.
                values_raw = line.split("VALUES")[-1].strip()
                values_raw = values_raw.rstrip(";").rstrip()
                if values_raw.startswith("(") and values_raw.endswith(")"):
                    values_raw = values_raw[1:-1]

                try:
                    parsed_values = json.loads(f'[{values_raw}]')
                except JSONDecodeError:
                    problem_dbs.add(schema_file_path)
                    continue
                db_values.update(str(v) for v in parsed_values if v not in stopwords_en)
    except FileNotFoundError:
        print(f"{schema_file_path} not found")
        problem_dbs.add(schema_file_path)
    return db_values

Load DB Schema & values

In [None]:
# Load schema names and database values for the configured database and print them for inspection.
db_schema, db_tables, db_colums, db_partial_columns = load_schema(db_name, tables_file)
print(db_schema)
print(db_tables)
print(db_colums)
db_values = load_db_values(schema_file)
print(db_values)

# Bundle that is handed to every replacement step as its last argument.
db_info = {
    "schema": db_schema,
    "tables": db_tables,
    "columns": db_colums,
    "partial_columns": db_partial_columns,
    "values": db_values
}

Function: Load NL-SQL pairs for given Spider database

In [None]:
def load_train_data_for_db(db, db_data_file):
    """
    Read the NL-SQL file and keep only the entries of one database

    :param db: database
    :type db: str
    :param db_data_file: path of file containing NL-SQL entries
    :type db_data_file: str
    :return: entries whose "db_id" equals db, in file order
    :rtype: List[Dict]
    """
    with open(db_data_file, "r") as train_json:
        all_entries = json.load(train_json)

    selected = []
    for record in all_entries:
        if record["db_id"] == db:
            selected.append(record)
    return selected

Load NL-SQL pairs for DB

In [None]:
data = load_train_data_for_db(db_name, data_file)

Setup Stanza

In [None]:
stanza.download('en') # download English model
# Pipeline object; parse_and_replace() calls it on each question.
nlp = stanza.Pipeline('en') # initialize English neural pipeline

Functions: Get start or end index of a given word in stanza parse format

In [None]:
import re
# Patterns that pull the character offsets out of a "start_char=..|end_char=.." string.
# Raw strings keep "\d" from being read as an (invalid) string escape sequence.
start_matcher = re.compile(r"start_char=(\d+)")
end_matcher = re.compile(r"end_char=(\d+)")

def get_word_start(misc_string):
    """
    Read the start offset out of a stanza word misc attribute string

    :param misc_string: stanza misc attribute string
    :type misc_string: str
    :return: start character of the given parsed word
    :rtype: int
    """
    found = start_matcher.search(misc_string)
    start_text = found.group(1)
    return int(start_text)


def get_word_end(misc_string):
    """
    Read the end offset out of a stanza word misc attribute string

    :param misc_string: stanza misc attribute string
    :type misc_string: str
    :return: end character of the given parsed word
    :rtype: int
    """
    found = end_matcher.search(misc_string)
    end_text = found.group(1)
    return int(end_text)

Data structures for parse tree container and processed token list

In [None]:
class ParseTree:
    """Container bundling an NL question, its SQL query, the processed tokens and the raw parse."""

    def __init__(self, orginal_nl_query, original_sql_query, tokens, raw_parsed, id=-1):
        """
        Constructor

        :param orginal_nl_query: original NL query
        :type orginal_nl_query: str
        :param original_sql_query: original SQL query
        :type original_sql_query: str
        :param tokens: list of tokens
        :type tokens: List[Token]
        :param raw_parsed: raw stanza sentence parse
        :type raw_parsed: stanza.Sentence
        :param id: identifier stored with the tree (-1 if none was given)
        :type id: int
        """
        self.orginal_nl_query = orginal_nl_query
        self.original_sql_query = original_sql_query
        self.tokens = tokens
        self.ents = raw_parsed.ents
        self.raw_parsed = raw_parsed
        self.id = id

    def to_tree(self):
        """
        Group the tokens by their head id

        :return: the tokens with head id 0 rendered and joined by blanks,
            and the mapping from head id to the tokens having that head
        :rtype: Tuple[str, DefaultDict[int, List[Token]]]
        """
        tokens_by_head = defaultdict(list)
        for token in self.tokens:
            tokens_by_head[token.head()].append(token)

        return " ".join(str(t) for t in tokens_by_head[0]), tokens_by_head

    def __str__(self):
        """Render all tokens in order, separated by '|'."""
        return f"{'|'.join(str(t) for t in self.tokens)}"


class Token:
    """One processed token: a text plus the parsed words it covers."""

    def __init__(self, text, parsed_tokens, changed=True, changed_by='', ent=None):
        """
        Constructor

        :param text: text of the token (a placeholder if it was replaced)
        :param parsed_tokens: parsed words covered by this token
        :param changed: whether the text differs from the original words
        :param changed_by: identifier of the step that created the token
        :param ent: entity the token was created from, if any
        """
        self.text = text
        self.parsed_tokens = parsed_tokens
        self.original_text = " ".join([part.text for part in parsed_tokens])
        self.is_compound = len(parsed_tokens) > 1
        self.changed = changed
        self.changed_by = changed_by
        self.ent = ent

    def ids(self):
        """
        Collect the ids of the contained parsed words

        :return: list of all ids
        :rtype: List[int]
        """
        word_ids = []
        for word in self.parsed_tokens:
            word_ids.append(word.id)
        return word_ids

    def head(self):
        """
        Get the smallest head id among the contained parsed words

        :return: head id
        :rtype: int
        """
        return min([word.head for word in self.parsed_tokens])

    def __str__(self):
        if not self.changed:
            return self.text
        return f"{self.text} [{self.original_text}]"


def get_words_in_range(words_by_start, current_char, next_char):
    """
    Select the words lying completely inside a character range

    A word qualifies if it starts at or after current_char and its start plus
    the length of its text does not exceed next_char.

    :param words_by_start: dictionary of words (by start)
    :type words_by_start: Dict[int, stanza.Word]
    :param current_char: start
    :type current_char: int
    :param next_char: end
    :type next_char: int
    :return: list of all words in given range, in dictionary order
    :rtype: List[stanza.Word]
    """
    selected = []
    for start, word in words_by_start.items():
        if start < current_char:
            continue
        if start + len(word.text) > next_char:
            continue
        selected.append(word)
    return selected


def create_unchanged(current_char, question_remainder, words_by_start, ents_by_start, db_data):
    """
    Wrap the word starting at the current position in a Token without altering it

    Only current_char and words_by_start are used; the remaining parameters
    exist so that the function has the same signature as the other pipeline steps.

    :param current_char: current position in question string
    :type current_char: int
    :param question_remainder: unprocessed question string (unused)
    :type question_remainder: str
    :param words_by_start: dictionary of words by their start character
    :type words_by_start: Dict[int, stanza.Word]
    :param ents_by_start: dictionary of entities by their start character (unused)
    :type ents_by_start: Dict[int, stanza.Entity]
    :param db_data: dictionary containing database information (unused)
    :type db_data: Dict[str, Any]
    :return: token for the word and the position right after the word's text
    :rtype: Token, int
    """
    word = words_by_start[current_char]
    untouched = Token(word.text, [word], changed=False)
    following_char = current_char + len(word.text)
    return untouched, following_char


Functions: Replace table and partial column names as well as db values in the parse tree with a placeholder (and add provenance)

In [None]:
def replace_db_element(current_char, question_remainder, words_by_start, db_elements, placeholder, changed_by):
    """
    Replace db elements with placeholder

    The elements are tried in iteration order. An element matches if the
    question remainder (as given or lower-cased) starts with it and the match
    ends exactly at the end of a parsed word; the first such element wins.

    :param current_char: current position in question string
    :type current_char: int
    :param question_remainder: unprocessed question string
    :type question_remainder: str
    :param words_by_start: dictionary of words by their start character
    :type words_by_start: Dict[int, stanza.Word]
    :param db_elements: elements of the database to search for
    :type db_elements: Iterable[str]
    :param placeholder: placeholder to replace with
    :type placeholder: str
    :param changed_by: identifier for replacing pipeline step
    :type changed_by: Enum[ChangedBy]
    :return: newly created token (if any) and next character to start processing with
    :rtype: Token, int
    """
    # Loop-invariant: lower-case the remainder once instead of once per element.
    qr_lower = question_remainder.lower()

    for element in db_elements:
        if not (question_remainder.startswith(element) or qr_lower.startswith(element)):
            continue

        next_char = current_char + len(element)
        affected_words = get_words_in_range(words_by_start, current_char, next_char)

        # Accept the match only if it ends on a word boundary, so an element
        # is never matched against the beginning of a longer word.
        if affected_words and get_word_end(affected_words[-1].misc) == next_char:
            return Token(placeholder, affected_words, changed_by=changed_by), next_char

    return None, current_char


def replace_column_names(current_char, question_remainder, words_by_start, ents_by_start, db_info):
    """
    Pipeline step: replace a full column name at the current position

    Delegates to replace_db_element with the "columns" entry of db_info.

    :param current_char: current position in question string
    :type current_char: int
    :param question_remainder: unprocessed question string
    :type question_remainder: str
    :param words_by_start: dictionary of words by their start character
    :type words_by_start: Dict[int, stanza.Word]
    :param ents_by_start: dictionary of entities by their start character (unused)
    :type ents_by_start: Dict[int, stanza.Entity]
    :param db_info: dictionary containing database information
    :type db_info: Dict[str, Any]
    :return: newly created token (if any) and next character to start processing with
    :rtype: Token, int
    """
    return replace_db_element(
        current_char,
        question_remainder,
        words_by_start,
        db_elements=db_info["columns"],
        placeholder=COLUMN_NAME,
        changed_by=ChangedBy.DB_COLUMN,
    )


def replace_table_names(current_char, question_remainder, words_by_start, ents_by_start, db_info):
    """
    Pipeline step: replace a table name at the current position

    Delegates to replace_db_element with the "tables" entry of db_info.

    :param current_char: current position in question string
    :type current_char: int
    :param question_remainder: unprocessed question string
    :type question_remainder: str
    :param words_by_start: dictionary of words by their start character
    :type words_by_start: Dict[int, stanza.Word]
    :param ents_by_start: dictionary of entities by their start character (unused)
    :type ents_by_start: Dict[int, stanza.Entity]
    :param db_info: dictionary containing database information
    :type db_info: Dict[str, Any]
    :return: newly created token (if any) and next character to start processing with
    :rtype: Token, int
    """
    table_candidates = db_info["tables"]
    return replace_db_element(
        current_char, question_remainder, words_by_start,
        table_candidates, TABLE_NAME, ChangedBy.DB_TABLE,
    )


def replace_column_partial_names(current_char, question_remainder, words_by_start, ents_by_start, db_info):
    """
    Pipeline step: replace a single word of a column name at the current position

    Delegates to replace_db_element with the "partial_columns" entry of db_info.

    :param current_char: current position in question string
    :type current_char: int
    :param question_remainder: unprocessed question string
    :type question_remainder: str
    :param words_by_start: dictionary of words by their start character
    :type words_by_start: Dict[int, stanza.Word]
    :param ents_by_start: dictionary of entities by their start character (unused)
    :type ents_by_start: Dict[int, stanza.Entity]
    :param db_info: dictionary containing database information
    :type db_info: Dict[str, Any]
    :return: newly created token (if any) and next character to start processing with
    :rtype: Token, int
    """
    step_arguments = (db_info["partial_columns"], COLUMN_PARTIAL_NAME, ChangedBy.DB_PARTIAL_COLUMN)
    return replace_db_element(current_char, question_remainder, words_by_start, *step_arguments)


def replace_db_values(current_char, question_remainder, words_by_start, ents_by_start, db_info):
    """
    Replace database values

    The values are tried longest first (ties in alphabetical order), so that
    a multi-word value wins over a shorter value it starts with and the
    result does not depend on the iteration order of the value collection.

    :param current_char: current position in question string
    :type current_char: int
    :param question_remainder: unprocessed question string
    :type question_remainder: str
    :param words_by_start: dictionary of words by their start character
    :type words_by_start: Dict[int, stanza.Word]
    :param ents_by_start: dictionary of entities by their start character
    :type ents_by_start: Dict[int, stanza.Entity]
    :param db_info: dictionary containing database information
    :type db_info: Dict[str, Any]
    :return: newly created token (if any) and next character to start processing with
    :rtype: Token, int
    """
    # load_db_values() returns a set, whose iteration order is arbitrary and
    # unlike the table/column lists not sorted by length. Sorting on every call
    # costs O(n log n); NOTE(review): pre-sort in db_info if this shows up in profiles.
    ordered_values = sorted(db_info["values"], key=lambda value: (-len(value), value))
    return replace_db_element(current_char, question_remainder, words_by_start, ordered_values, DB_VALUE, ChangedBy.DB_VAL)

Function: Replace all Named Entities and numerical values in the parse tree with placeholders (and add provenance)

In [None]:
def replace_named_entities_and_numbers(current_char, question_remainder, words_by_start, ents_by_start, db_info):
    """
    Pipeline step: replace a named entity or a number at the current position

    An entity starting here becomes a number placeholder if its type is
    "CARDINAL" and a named-entity placeholder otherwise. Without an entity, a
    word tagged "NUM" becomes a number placeholder. Anything else is left alone.

    :param current_char: current position in question string
    :type current_char: int
    :param question_remainder: unprocessed question string (unused)
    :type question_remainder: str
    :param words_by_start: dictionary of words by their start character
    :type words_by_start: Dict[int, stanza.Word]
    :param ents_by_start: dictionary of entities by their start character
    :type ents_by_start: Dict[int, stanza.Entity]
    :param db_info: dictionary containing database information (unused)
    :type db_info: Dict[str, Any]
    :return: newly created token (if any) and next character to start processing with
    :rtype: Token, int
    """
    word_here = words_by_start[current_char]

    if current_char not in ents_by_start:
        if word_here.upos != "NUM":
            return None, current_char
        word_end = current_char + len(word_here.text)
        covered = get_words_in_range(words_by_start, current_char, word_end)
        return Token(NUMERICAL, covered, changed_by=ChangedBy.NUM), word_end

    entity = ents_by_start[current_char]
    entity_end = current_char + len(entity.text)
    if entity.type == "CARDINAL":
        placeholder, origin = NUMERICAL, ChangedBy.NUM
    else:
        placeholder, origin = NAMED_ENTITY, ChangedBy.NER
    covered = get_words_in_range(words_by_start, current_char, entity_end)
    return Token(placeholder, covered, ent=entity, changed_by=origin), entity_end

Retrieve synset(s) for token

In [None]:
# Map between stanza and wordnet POS tag format
# Tags missing here are never looked up in WordNet (see retrieve_synset).
POS_TO_WN_POS = {"NOUN": wordnet.NOUN, "ADJ": wordnet.ADJ, "VERB": wordnet.VERB, "ADV": wordnet.ADV}

def retrieve_synset(word, pos_tag):
    """
    Look up the most frequent synset base name of a word in WordNet

    The base name is the part of a synset name before the first dot. The word
    is returned unchanged if its POS tag has no WordNet counterpart or if
    WordNet lists no synset for it.

    :param word: word to look up
    :type word: str
    :param pos_tag: POS tag of given word
    :type pos_tag: str
    :return: most frequent synset base name, or the word itself
    :rtype: str
    """
    if pos_tag not in POS_TO_WN_POS:
        return word

    base_names = []
    for synset in wordnet.synsets(word, pos=POS_TO_WN_POS[pos_tag]):
        base_names.append(synset.name().split('.')[0])

    if not base_names:
        return word
    return nltk.FreqDist(base_names).max()

Function: Replace non-placeholder non-stopword tokens in parse tree with corresponding Synset name

In [None]:
def replace_synonyms(current_char, question_remainder, words_by_start, ents_by_start, db_info):
    """
    Replace words with synonyms (synset-based)

    Stopwords are never replaced. Any other word is replaced if
    retrieve_synset returns something different from the word's text.

    :param current_char: current position in question string
    :type current_char: int
    :param question_remainder: unprocessed question string
    :type question_remainder: str
    :param words_by_start: dictionary of words by their start character
    :type words_by_start: Dict[int, stanza.Word]
    :param ents_by_start: dictionary of entities by their start character
    :type ents_by_start: Dict[int, stanza.Entity]
    :param db_info: dictionary containing database information
    :type db_info: Dict[str, Any]
    :return: newly created token (if any) and next character to start processing with
    :rtype: Token, int
    """
    current_word = words_by_start[current_char]
    word_text = current_word.text

    # The word must be absent from the stopword list both as written and
    # lower-cased. The previous condition "not a in S or b in S" bound the
    # "not" to the first test only and therefore let stopwords through.
    is_stopword = word_text in stopwords_en or word_text.lower() in stopwords_en
    if not is_stopword:
        synonym = retrieve_synset(word_text, current_word.upos)
        if synonym != word_text:
            return Token(synonym, [current_word], changed_by=ChangedBy.SYNONYM), current_char + len(word_text)

    return None, current_char

In [None]:
# Steps tried in this order at every position of the question; the first step
# that returns a token wins. create_unchanged comes last as the fallback that
# keeps the word as it is.
REPLACEMENT_PIPELINE = [
    replace_column_names,
    replace_table_names,
    replace_db_values,
    replace_column_partial_names,
    replace_named_entities_and_numbers,
    # replace_synonyms,
    create_unchanged
]


def parse_and_replace(nl_sql, db_info, id=-1):
    """
    Generated parsed form of given nl-sql entry

    The question is parsed and then walked through character by character.
    Whitespace and punctuation are skipped; at every other position the steps
    of REPLACEMENT_PIPELINE are tried in order and the first token produced is
    appended.

    :param nl_sql: data entry containing a NL query ("question") and a SQL query ("query")
    :type nl_sql: Dict[Str, Any]
    :param db_info: dictionary containing database information
    :type db_info: Dict[str, Any]
    :param id: identifier stored in the resulting parse tree
    :type id: int
    :return: parse tree
    :rtype: ParseTree
    :raises ValueError: if the question is not parsed into exactly one sentence
    :raises RuntimeError: if no pipeline step produces a token at some position
    """
    # Parse and return parsed form of first sentence (there should always be only one)
    question = nl_sql["question"]
    parsed_all = nlp(question)
    sentence_count = len(parsed_all.sentences)
    if sentence_count != 1:
        # Also covers an empty parse, which would otherwise fail on sentences[0].
        raise ValueError(f"Expected exactly one sentence but got {sentence_count}: {question!r}")
    parsed = parsed_all.sentences[0]

    ents_by_start = {e.start_char: e for e in parsed.ents}
    words_by_start = {get_word_start(w.misc): w for w in parsed.words}

    current_char = 0
    tokens = []

    while current_char < len(question):
        question_remainder = question[current_char:]

        if question_remainder.startswith(" "):
            current_char += 1
            continue

        current_word = words_by_start.get(current_char)
        if current_word is None:
            if question_remainder[0].isspace():
                # Tabs, line breaks etc. separate words just like blanks do.
                current_char += 1
                continue
            # No word starts here: dump the state for debugging and let the
            # pipeline try anyway.
            print(words_by_start)
            print(question_remainder)
        elif current_word.upos == "PUNCT":
            current_char += len(current_word.text)
            continue

        for pipeline_step in REPLACEMENT_PIPELINE:
            token, next_char = pipeline_step(current_char, question_remainder, words_by_start, ents_by_start, db_info)
            if token is not None:
                tokens.append(token)
                current_char = next_char
                break
        else:
            # Without a token the position would never advance.
            raise RuntimeError(f"No pipeline step produced a token at position {current_char} of {question!r}")

    return ParseTree(question, nl_sql["query"], tokens, parsed, id=id)

Apply parse and replace for some sample sentences

In [None]:
print(db_info["columns"])

# First sample: print the generalized tokens, the dependency information of
# every parsed word (head id 0 is shown as "root") and the grouping by head.
sample_tree = parse_and_replace(data[7], db_info)
print(sample_tree)
print(*[f'id: {word.id}\tword: {word.text}\thead id: {word.head}\thead: {sample_tree.raw_parsed.words[word.head-1].text if word.head > 0 else "root"}\tdeprel: {word.deprel}' for word in sample_tree.raw_parsed.words], sep='\n')
print(sample_tree.to_tree())

# Further samples; the indices are fixed and assume that data has more than 42 entries.
sample_tree = parse_and_replace(data[27], db_info)
print(sample_tree)
sample_tree = parse_and_replace(data[42], db_info)
print(sample_tree)
sample_tree = parse_and_replace(data[15], db_info)
print(sample_tree)

Generate trees for every NL-SQL pair in this DB

In [None]:
# One parse tree per NL-SQL pair of the configured database, in the order of data.
db_parse_trees = []
for nl_sql_pair in data:
    db_parse_trees.append(parse_and_replace(nl_sql_pair, db_info))

Store string transcoding results in a file

In [None]:
# One record per parse tree: original question, generalized question (with
# provenance, as rendered by str), the bare generalized token texts and the SQL query.
patterns = [
    {
        "original_nl": pt.orginal_nl_query,
        "generalized_nl": str(pt),
        "generalized_tokens": [t.text for t in pt.tokens],
        "original_sql": pt.original_sql_query
    } for pt in db_parse_trees]

# Write the records as indented JSON to the single-database export file.
with open(export_trees_file, "w") as pattern_file:
    json.dump(patterns, pattern_file, indent=4)
    print(f"Generalized {len(patterns)} queries and stored them in '{export_trees_file}'")

Parse all the databases!

In [None]:
def export(last_db_name, db_parse_trees):
    """
    Store the generalized queries of one database as JSON

    The file is written to '<patterns_directory>/<last_db_name>.json'; the
    directory is created if it does not exist yet.

    :param last_db_name: name of the database the parse trees belong to
    :type last_db_name: str
    :param db_parse_trees: parse trees to export
    :type db_parse_trees: List[ParseTree]
    """
    import os  # local import: only needed to create the target directory

    patterns = [
        {
            "original_nl": pt.orginal_nl_query,
            "generalized_nl": str(pt),
            "generalized_tokens": [t.text for t in pt.tokens],
            "original_sql": pt.original_sql_query,
            "id": pt.id
        } for pt in db_parse_trees]

    # open(..., "w") does not create missing directories.
    os.makedirs(patterns_directory, exist_ok=True)
    export_trees_file = f'{patterns_directory}/{last_db_name}.json'

    with open(export_trees_file, "w") as pattern_file:
        json.dump(patterns, pattern_file, indent=4)
        print(f"Generalized {len(patterns)} queries and stored them in '{export_trees_file}'")


def parse_all_and_store_patterns(nl_sql_file):
    """
    Generalize every NL-SQL entry of a file and export the results per database

    The entries are processed in file order. Whenever the "db_id" differs
    from the one of the previous entry, the patterns collected so far are
    exported and the schema and values of the new database are loaded.
    Entries for which parse_and_replace raises a ValueError are counted and
    skipped.

    NOTE(review): this relies on the entries of one database being stored
    consecutively; a database that shows up again later would overwrite its
    earlier export file.

    :param nl_sql_file: path of file containing NL-SQL entries
    :type nl_sql_file: str
    """
    # Read everything up front so the file is not kept open while parsing.
    with open(nl_sql_file, "r") as train_json:
        entries = json.load(train_json)

    last_db_name = ""
    db_parse_trees = []
    error_counter = 0
    success_counter = 0

    for i, entry in enumerate(entries):
        if entry["db_id"] != last_db_name:
            if last_db_name != "":
                # Store previous patterns if necessary
                export(last_db_name, db_parse_trees)

            last_db_name = entry["db_id"]
            print(last_db_name)

            # Reset/load new db info
            db_directory = f'{spider_directory}/database/{last_db_name}'
            schema_file = f'{db_directory}/schema.sql'

            db_schema, db_tables, db_colums, db_partial_columns = load_schema(last_db_name, tables_file)
            db_values = load_db_values(schema_file)

            db_info = {
                "schema": db_schema,
                "tables": db_tables,
                "columns": db_colums,
                "partial_columns": db_partial_columns,
                "values": db_values
            }

            db_parse_trees = []

        # Parse and replace the current nl_sql_pair. The position in the file
        # is passed on as id; otherwise every exported "id" would be -1.
        try:
            db_parse_trees.append(parse_and_replace(entry, db_info, id=i))
            success_counter += 1
        except ValueError:
            print("Multiple Sentences, did not export")
            error_counter += 1

    # Export the last database; nothing to export if the file had no entries.
    if last_db_name != "":
        export(last_db_name, db_parse_trees)

    print(f"Finished export. Exported {success_counter} entries, an eror occurred for {error_counter} entries.")


# Run the export over the whole training file, then list the schema files
# that were missing or had INSERT lines that could not be parsed.
parse_all_and_store_patterns(data_file)
print("Problem DBs:", problem_dbs)