In [132]:
import en_core_web_md
nlp = en_core_web_md.load()

In [133]:
import json
import os
import pickle

In [134]:
raw_queries = None
raw_tables = None
with open("Data/spider/train_spider.json") as fp:
    raw_queries = json.load(fp)

with open("Data/spider/tables.json") as fp:
    raw_tables = json.load(fp)

In [135]:
db_map = {_["db_id"] : _ for _ in raw_tables}

In [136]:
import random
import typing
import re
import numpy as np
from enum import IntEnum
from dataclasses import dataclass

In [137]:
class AggregateFunction(IntEnum):
    NONE = 0
    MAX = 1
    MIN = 2
    COUNT = 3
    SUM = 4    
    AVG = 5

class WhereOperator(IntEnum):
    EQ = 2
    GT = 3
    LT = 4
    LTE = 5
    GTE = 6
    NEQ = 7

class ValueType(IntEnum):
    NUM = 0
    STR = 1

class QueryComplexity(IntEnum):
    SELECT_ONLY = 0
    SELECT_WHERE = 1
    
@dataclass
class Value:
    str_value: str
    value_type: ValueType
    def __str__(self):
        return f"{str(self.str_value)}"

@dataclass
class Column:
    col_name_actual: str
    col_name_simplified: str
    value_type: ValueType
    def __str__(self):
        return f"{str(self.col_name_actual)}"

@dataclass       
class SelectColumn:
    function: AggregateFunction
    column: Column
    def __str__(self):
        return f"{str(self.function)}({self.column})"
        
@dataclass
class Condition:
    column: Column
    op: WhereOperator
    value: Value
    def __str__(self):
        return f"{self.column} {str(self.op)} {self.value}"

@dataclass
class SelectClause:
    columns: typing.List[SelectColumn]
    def __str__(self):
        return "SELECT "+ ", ".join([str(_) for _ in self.columns])

@dataclass
class WhereClause:
    condition: Condition
    def __str__(self):
        return f"WHERE {self.condition}"

@dataclass
class SimpleQuery:
    complexity: QueryComplexity
    select: SelectClause
    where: WhereClause
    def __str__(self):
        if self.where is not None:
            return f"{self.select} FROM <table-name> {self.where}"
        else:
            return f"{self.select} FROM <table-name>"

@dataclass
class QuerySample:
    query: str    
    question: str
    question_toks: typing.List[str]
    columns: typing.List[Column]
    simple_query: SimpleQuery
    def __str__(self):
        return f"({len(self.columns)}) : {self.simple_query}"

In [138]:
def get_op(op_code):
    if op_code == 2:
        return WhereOperator.EQ
    elif op_code == 3:
        return WhereOperator.GT
    elif op_code == 4:
        return WhereOperator.LT
    elif op_code == 5:
        return WhereOperator.LTE
    elif op_code == 6:
        return WhereOperator.GTE
    elif op_code == 7:
        return WhereOperator.NEQ
    else:
        raise Exception("unknown operator: "+str(op_code))

def get_func(op_code):
    if op_code == 0:
        return AggregateFunction.NONE
    elif op_code == 1:
        return AggregateFunction.MAX
    elif op_code == 2:
        return AggregateFunction.MIN    
    elif op_code == 3:
        return AggregateFunction.COUNT
    elif op_code == 4:
        return AggregateFunction.SUM
    elif op_code == 5:
        return AggregateFunction.AVG
    else:
        raise Exception("unknown function: "+str(op_code))

def parse_where(condition_part, actual_columns, column_dict):
    value_train = condition_part[3]
    value = None
    if isinstance(value_train, str):
        value = Value(str(value_train).strip('"'), ValueType.STR)
    else:
        value = Value(str(value_train), ValueType.NUM)
    column_num = condition_part[2][1][1]
    raw_col_name = actual_columns[column_num]
    op_code = condition_part[1]
    column = column_dict[raw_col_name]
    condition = Condition(column, get_op(op_code), value)
    where = WhereClause(condition)
    return where

def parse_select(select_parts, actual_columns, column_dict):
    real_columns = []
    for select_part in select_parts:
        column_num = select_part[1][1][1]
        raw_col_name = actual_columns[column_num]        
        col_func = select_part[0]
        column = column_dict[raw_col_name]
        select_column = SelectColumn(get_func(col_func), column)
        real_columns.append(select_column)
    return SelectClause(real_columns)
        
not_required_parts = ['limit', 'intersect', 'union', 'except']
non_empty_parts = ['groupBy', 'having', 'orderBy']
def sample_simple_queries(count=10):
    idx = 0
    cnt = 0
    while idx < len(raw_queries) and cnt < count:
        query = raw_queries[idx]
        sql_part = query['sql']
        select_parts = sql_part['select'][1]
        where_part = sql_part['where']
        has_join =len(sql_part['from']['conds']) != 0
        has_simple_where_part = all([not op[0] for op in where_part])
        has_non_essential_parts = all([sql_part[op] is None for op in not_required_parts])
        has_non_empty_parts = all([len(sql_part[op]) == 0 for op in non_empty_parts])
        has_distinct = sql_part['select'][0]
        if has_simple_where_part and has_non_essential_parts and has_non_empty_parts and not has_distinct and not has_join:
            raw_column_names = db_map[query['db_id']]['column_names_original']
            column_names = db_map[query['db_id']]['column_names']
            value_types = [(ValueType.STR if t == "text" else ValueType.NUM)for t in db_map[query['db_id']]['column_types']]
            column_dict = {}
            columns = []
            actual_columns = [raw_column_names[i][1] for i in range(len(column_names))]
            for i in range(len(column_names)):
                if raw_column_names[i][1] not in column_dict:
                    col = Column(raw_column_names[i][1], column_names[i][1], value_types[i])
                    column_dict[raw_column_names[i][1]] = col
                    columns.append(col)
            actual_query = query['query']
            actual_question = query['question']
            question_toks = query['question_toks']
            if len(where_part) <= 1:
                query_sample = None
                if len(where_part) > 0:
                    if (all([conds[2][2] is None for conds in where_part])):
                        if all([not isinstance(conds[3], dict) for conds in where_part]): 
                            if not any([conds[1] == 9 or conds[1] == 1 for conds in where_part]):#For simple query
                                where = parse_where(where_part[0], actual_columns, column_dict)
                                select = parse_select(select_parts, actual_columns, column_dict)
                                simple_query = SimpleQuery(QueryComplexity.SELECT_WHERE, select, where)
                                query_sample = QuerySample(actual_query, actual_question, question_toks, columns, simple_query)
                else:
                    select = parse_select(select_parts, actual_columns, column_dict)
                    simple_query = SimpleQuery(QueryComplexity.SELECT_ONLY, select, None)
                    query_sample = QuerySample(actual_query, actual_question, question_toks, columns, simple_query)
                if query_sample is not None:
                    yield query_sample                
                cnt += 1      
        idx += 1

In [139]:
vector_map = {}
def init_vectors():
    global vector_map
    if len(vector_map) > 0:
        return
    if os.path.isfile('vector_map.pickle'):
        with open('vector_map.pickle', 'rb') as fp:
            vector_map = pickle.load(fp)
    else:
        d = vector_map
        for db_name in db_map:
            col_names = db_map[db_name]['column_names']
            for col in col_names:
                if col[1] not in d:
                    d[col[1]] = nlp(col[1]).vector
        for query in raw_queries:
            tokens = query['question_toks']
            for token in tokens:
                if token not in d:
                    d[token] = nlp(token).vector
        d["*"] = nlp("*").vector
        with open('vector_map.pickle', "wb") as fp:
            pickle.dump(vector_map, fp)

In [140]:
init_vectors()

In [141]:
all_data = list(sample_simple_queries(20000))
num = int(0.98 * len(all_data))
simple_queries_train = all_data[:num]
simple_queries_test = all_data[num:]

In [142]:
def encode_one_hot_func(op_code):
    encode = []
    for i in range(6):
        if i == op_code:
            encode.append(1.0)
        else:
            encode.append(0.0)
    return encode

def encode_one_hot_op(op_code):
    encode = []
    for i in range(6):
        if i == op_code - 2:
            encode.append(1.0)
        else:
            encode.append(0.0)
    return encode

def get_select_transform(select_match):
    if select_match is None:
        return (np.array([0.0]).reshape(1, 1), np.array(encode_one_hot_func(-1)).reshape(1, 6))
    else:
        return (np.array([1.0]).reshape(1, 1), np.array(encode_one_hot_func(select_match.function)).reshape(1, 6))
    
def get_where_transform(where_match):
    if where_match is None:
        return (np.array([0.0]).reshape(1, 1), np.array(encode_one_hot_op(-1)).reshape(1, 6))
    else:
        return (np.array([1.0]).reshape(1, 1), np.array(encode_one_hot_op(where_match.condition.op)).reshape(1, 6))

def get_select_reverse_transform(select_prob, function_one_hot):
    if select_prob >= 0.5:
        return True, AggregateFunction(tf.math.argmax(function_one_hot).numpy())
    else:
        return False, AggregateFunction.NONE

def get_where_reverse_transform(where_prob, operator_one_hot):
    if where_prob >= 0.5:
        return True, WhereOperator(tf.math.argmax(operator_one_hot).numpy() + 2)
    else:
        return False, WhereOperator.EQ    

def get_select_match(col, simple_query):
    for c in simple_query.select.columns:
        if col == c.column:
            return c

def get_where_match(col, simple_query):
    if simple_query.where is None:
        return None
    if col == simple_query.where.condition.column:
        return simple_query.where
    else:
        return None

In [143]:
class TrainingType(IntEnum):
    SELECT_FILTER_MODEL = 1
    WHERE_FILTER_MODEL = 2
    WHERE_COND_MODEL = 3
    SELECT_FUNC_MODEL = 4

def transform_input(samples: typing.List[QuerySample], sample_index, training_type, class_dist):
    p = 0
    for sample in samples:
        tokens = sample.question_toks
        x1 = np.array([vector_map[t] for t in tokens])
        l = len(tokens)
        x1 = x1.reshape(1, l, 300)        
        for col in sample.columns:
            column_name = col.col_name_simplified
            value_type = int(col.value_type) * 1.0
            column_name_vector = vector_map[column_name]
            select_match = get_select_match(col, sample.simple_query)
            where_match = get_where_match(col, sample.simple_query)
            x2 = np.append([[value_type]], column_name_vector).reshape(1, 301)
            select_y, select_func = get_select_transform(select_match)
            where_y, where_op = get_where_transform(where_match)
            out = None
            if training_type == TrainingType.SELECT_FILTER_MODEL:
                out = select_y
            elif training_type == TrainingType.WHERE_FILTER_MODEL:
                out = where_y
            elif training_type == TrainingType.WHERE_COND_MODEL:
                if where_y == 1.0:
                    out = where_op #Train only when a where clause is identified
                else:
                    out = None
            elif training_type == TrainingType.SELECT_FUNC_MODEL:
                if select_y == 1.0:
                    out = select_func #Train only when a where clause is identified
                else:
                    out = None               
            if out is None:
                continue # Skip if output was not generated
            ser = str(out)
            if ser in class_dist:
                class_dist[ser] += 1
            else:
                class_dist[ser] = 1
            yield [x1, x2], out
            sample_index.append((sample, col, out))               
            p += 1

def resample_scarce_class(queries, training_examples, current_count, new_queries, new_class_dist, factor=None):
    total = len(training_examples)
    class_count_map = {_ : current_count[_] for _ in current_count}
    assert sum([class_count_map[_] for _ in class_count_map]) == total
    if factor is None:
        factor = {_ : 1.0/len(class_count_map) for _ in class_count_map}
    else:
        assert sum((factor[f] for f in factor)) <= 1.0
    class_count_ordered_map = [(class_count_map[_], _) for _ in class_count_map]
    class_count_ordered_map.sort()
    class_count_ordered_map.reverse()
    sample_count_map = {}
    equivalence = None
    for cnt, class_ in class_count_ordered_map:
        ratio = factor[class_]
        if equivalence is None:
            equivalence = cnt/ratio
        new_cnt = int(equivalence * ratio)
        sample_count_map[class_] = (new_cnt, int(new_cnt/cnt))
    single_sample = {}
    for q, ex in zip(queries, training_examples):
        out = ex[1]
        ser = str(out)
        new_cnt, cnt_per_sample = sample_count_map[ser]
        for _ in range(cnt_per_sample):
            yield ex
            new_queries.append(q)
        if ser not in single_sample:
            single_sample[ser] = (q, ex)
        if ser in new_class_dist:
            new_class_dist[ser] += cnt_per_sample
        else:
            new_class_dist[ser] = cnt_per_sample
        new_cnt -= cnt_per_sample
        sample_count_map[ser] = (new_cnt, cnt_per_sample)
    # for ser in sample_count_map:
    #     cnt_remaining, _ = sample_count_map[ser]
    #     q, sample = single_sample[ser]
    #     for _ in range(cnt_remaining):
    #         yield sample
    #         new_queries.append(q)
    #     if ser in new_class_dist:
    #         new_class_dist[ser] += cnt_remaining
    #     else:
    #         new_class_dist[ser] = cnt_remaining

def transform_with_resampling(actual_query_samples, training_type: TrainingType):
    train_queries = []
    class_dist = {}
    new_class_dist = {}
    transformed_input = list(transform_input(actual_query_samples, train_queries, training_type, class_dist))
    balanced_queries = []
    balanced_input = list(resample_scarce_class(train_queries, transformed_input, class_dist, balanced_queries, new_class_dist))
    train_queries.clear()
    transformed_input.clear()
    train_queries = balanced_queries
    transformed_input = balanced_input
    print(f"{str(training_type)} Old class_dist", class_dist)
    print(f"{str(training_type)} New class_dist", new_class_dist)
    random.seed(0xf00)
    random.shuffle(train_queries)
    random.seed(0xf00)
    random.shuffle(transformed_input)
    return train_queries, transformed_input   

In [144]:
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [145]:
class CellType(IntEnum):
    GRU=1
    LSTM=2

def create_model(training_type: TrainingType, cell_type = CellType.GRU, dropout = 0, learning_rate = 1.5*10-4, dense_layer = [20, 7], stack_depth = 6, encoding_size = 128, activation_lstm = 'tanh', activation_dense = 'relu'):
    model_prefix = None
    input_dim = 300
    assert stack_depth >= 1
    
    if training_type == TrainingType.SELECT_FILTER_MODEL:
        model_prefix = "select"
    elif training_type == TrainingType.WHERE_FILTER_MODEL:
        model_prefix = "where"
    elif training_type == TrainingType.WHERE_COND_MODEL:
        model_prefix = "whereCondition"
    elif training_type == TrainingType.SELECT_FUNC_MODEL:
        model_prefix = "selectFunc"

    # input
    model_input1 = keras.Input(shape=(None, input_dim), name='sequentialSentence')
    model_input2 = keras.Input(shape=input_dim + 1, name='columnName')
    
    # encoder
    encoders = []
    if cell_type == CellType.GRU:
        encoders.append(layers.Bidirectional(layers.GRU(encoding_size, dropout=dropout, activation=activation_lstm, name=f'sequentialSentenceLstm0', return_sequences=stack_depth != 1))(model_input1))
    elif cell_type == CellType.LSTM:
        encoders.append(layers.Bidirectional(layers.LSTM(encoding_size, dropout=dropout, activation=activation_lstm, name=f'sequentialSentenceLstm0', return_sequences=stack_depth != 1))(model_input1))
    for _ in range(stack_depth - 1):
        if cell_type == CellType.GRU:
            encoders.append(layers.GRU(encoding_size, dropout=dropout, activation=activation_lstm, name=f'sequentialSentenceLstm{_ + 1}', return_sequences=_ != stack_depth - 2)(encoders[-1]))
        elif cell_type == CellType.LSTM:
            encoders.append(layers.LSTM(encoding_size, dropout=dropout, activation=activation_lstm, name=f'sequentialSentenceLstm{_ + 1}', return_sequences=_ != stack_depth - 2)(encoders[-1]))
    concat_layer = layers.Concatenate(axis=1, name='contactLstmOutAndColumnName')([encoders[-1], model_input2])
    
    # decoder
    decode_hidden1 = layers.Dense(
        dense_layer[0], 
        activation=activation_dense,
        name=f'sql{model_prefix}LayerHidden1'
    )(concat_layer)
    decode_hidden2 = layers.Dense(
        dense_layer[1], 
        activation=activation_dense,
        name=f'sql{model_prefix}LayerHidden2'
    )(decode_hidden1)    
    metrics = []
    loss = []
    model = None
    if training_type == TrainingType.SELECT_FILTER_MODEL or training_type == TrainingType.WHERE_FILTER_MODEL:
        metrics = [
            keras.metrics.TruePositives(name='tp'),
            keras.metrics.FalsePositives(name='fp'),
            keras.metrics.TrueNegatives(name='tn'),
            keras.metrics.FalseNegatives(name='fn'), 
            keras.metrics.BinaryAccuracy(name='accuracy'),
            keras.metrics.Precision(name='precision'),
            keras.metrics.Recall(name='recall'),
            keras.metrics.AUC(name='auc'),
            keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
        ]
        loss = [
            keras.losses.BinaryCrossentropy(from_logits=False)
        ]
        binary_filter = layers.Dense(
            1, 
            activation='sigmoid',
            name=f'sql{model_prefix}Layer'
            )(decode_hidden2)
        model = keras.Model(inputs = [model_input1, model_input2], outputs = binary_filter)  
    elif training_type == TrainingType.WHERE_COND_MODEL or training_type == TrainingType.SELECT_FUNC_MODEL:
        metrics = [
            tf.keras.metrics.CategoricalAccuracy(
                name='categorical accuracy'),
            tf.keras.metrics.Precision(name='precision 1', top_k=1),
            tf.keras.metrics.Precision(name='precision 3', top_k=3),
            tf.keras.metrics.Recall(name='recall 1', top_k=1),
            tf.keras.metrics.Recall(name='recall 2', top_k=3)
        ]
        loss = [
            keras.losses.CategoricalCrossentropy(from_logits=False)
        ]
        softmax_categorical = layers.Dense(
            6, 
            activation='softmax',
            name=f'sql{model_prefix}Layer'
            )(decode_hidden2)
        model = keras.Model(inputs = [model_input1, model_input2], outputs = softmax_categorical)
    #model.summary()
    model.compile(
    loss=loss, 
    optimizer=keras.optimizers.Adam(learning_rate=learning_rate), 
    metrics=metrics)
    return model    

In [146]:
def print_some_queries(queries, cnt = 10):
    for q, c, l in queries[:cnt]:
        print(f"({l} - {c.col_name_actual}) {q.query}")    

In [147]:
train_where_cond_queries, transformed_where_cond_input = transform_with_resampling(simple_queries_train, TrainingType.WHERE_COND_MODEL)
print_some_queries(train_where_cond_queries)

TrainingType.WHERE_COND_MODEL Old class_dist {'[[0. 1. 0. 0. 0. 0.]]': 65, '[[0. 0. 0. 0. 0. 1.]]': 58, '[[1. 0. 0. 0. 0. 0.]]': 499, '[[0. 0. 1. 0. 0. 0.]]': 35, '[[0. 0. 0. 1. 0. 0.]]': 12, '[[0. 0. 0. 0. 1. 0.]]': 2}
TrainingType.WHERE_COND_MODEL New class_dist {'[[0. 1. 0. 0. 0. 0.]]': 455, '[[0. 0. 0. 0. 0. 1.]]': 464, '[[1. 0. 0. 0. 0. 0.]]': 499, '[[0. 0. 1. 0. 0. 0.]]': 490, '[[0. 0. 0. 1. 0. 0.]]': 492, '[[0. 0. 0. 0. 1. 0.]]': 498}
([[0. 0. 1. 0. 0. 0.]] - HIRE_DATE) SELECT * FROM employees WHERE hire_date  <  '2002-06-21'
([[0. 0. 0. 1. 0. 0.]] - Price) SELECT count(*) FROM products WHERE price >= 180
([[0. 1. 0. 0. 0. 0.]] - followers) SELECT name ,  email FROM user_profiles WHERE followers  >  1000
([[1. 0. 0. 0. 0. 0.]] - formats) SELECT f_id FROM files WHERE formats  =  "mp3"
([[0. 0. 1. 0. 0. 0.]] - Opening_year) SELECT count(DISTINCT city) FROM stadium WHERE opening_year  <  2006
([[0. 0. 0. 0. 0. 1.]] - Country) SELECT name ,  year_join FROM artist WHERE country != 'U

In [148]:
train_select_func_queries, transformed_select_func_input = transform_with_resampling(simple_queries_train, TrainingType.SELECT_FUNC_MODEL)
print_some_queries(train_select_func_queries)

TrainingType.SELECT_FUNC_MODEL Old class_dist {'[[0. 0. 0. 1. 0. 0.]]': 556, '[[1. 0. 0. 0. 0. 0.]]': 1011, '[[0. 1. 0. 0. 0. 0.]]': 63, '[[0. 0. 0. 0. 0. 1.]]': 131, '[[0. 0. 0. 0. 1. 0.]]': 48, '[[0. 0. 1. 0. 0. 0.]]': 24}
TrainingType.SELECT_FUNC_MODEL New class_dist {'[[0. 0. 0. 1. 0. 0.]]': 556, '[[1. 0. 0. 0. 0. 0.]]': 1011, '[[0. 1. 0. 0. 0. 0.]]': 1008, '[[0. 0. 0. 0. 0. 1.]]': 917, '[[0. 0. 0. 0. 1. 0.]]': 1008, '[[0. 0. 1. 0. 0. 0.]]': 1008}
([[0. 0. 0. 0. 0. 1.]] - Floors) SELECT avg(floors) ,  max(floors) ,  min(floors) FROM building
([[0. 0. 1. 0. 0. 0.]] - Age) SELECT min(age) ,  avg(age) ,  max(age) FROM Student
([[0. 0. 1. 0. 0. 0.]] - Age) SELECT min(age) ,  avg(age) ,  max(age) FROM Student
([[0. 0. 1. 0. 0. 0.]] - Crime_rate) SELECT min(Crime_rate) ,  max(Crime_rate) FROM county_public_safety
([[0. 0. 0. 1. 0. 0.]] - *) SELECT count(*) FROM CHARACTERISTICS
([[0. 0. 0. 0. 1. 0.]] - student_capacity) SELECT sum(student_capacity) FROM dorm
([[0. 0. 0. 0. 0. 1.]] - price

In [149]:
train_where_filter_queries, transformed_where_filter_input = transform_with_resampling(simple_queries_train, TrainingType.WHERE_FILTER_MODEL)
print_some_queries(train_where_filter_queries)

TrainingType.WHERE_FILTER_MODEL Old class_dist {'[[0.]]': 31641, '[[1.]]': 671}
TrainingType.WHERE_FILTER_MODEL New class_dist {'[[0.]]': 31641, '[[1.]]': 31537}
([[1.]] - Address) SELECT count(*) FROM member WHERE address != 'Hartford'
([[1.]] - problem_id) SELECT problem_log_id ,  log_entry_date FROM problem_log WHERE problem_id = 10
([[1.]] - email_address) SELECT address_line_1 ,  address_line_2 FROM customers WHERE email_address  =  "vbogisich@example.org"
([[0.]] - beds) SELECT count(DISTINCT bedType) FROM Rooms;
([[1.]] - no_of_loans) SELECT state ,  acc_type ,  credit_score FROM customer WHERE no_of_loans  =  0
([[0.]] - product_type_description) SELECT count(*) FROM Accounts
([[0.]] - alt) SELECT name FROM races WHERE YEAR = 2017
([[1.]] - Founder) SELECT count(*) FROM manufacturers WHERE founder  =  'Andy'
([[0.]] - Client_ID) SELECT count(*) FROM BOOKINGS
([[1.]] - date_problem_reported) SELECT problem_id FROM problems WHERE date_problem_reported > "1978-06-26"


In [150]:
train_select_filter_queries, transformed_select_filter_input = transform_with_resampling(simple_queries_train, TrainingType.SELECT_FILTER_MODEL)
print_some_queries(train_select_filter_queries)

TrainingType.SELECT_FILTER_MODEL Old class_dist {'[[1.]]': 1833, '[[0.]]': 30479}
TrainingType.SELECT_FILTER_MODEL New class_dist {'[[1.]]': 29328, '[[0.]]': 30479}
([[1.]] - Other_Details) SELECT Name ,  Other_Details FROM Staff
([[0.]] - branch_ID) SELECT city ,  state FROM bank WHERE bname  =  'morningside'
([[1.]] - total) SELECT billing_state ,  COUNT(*) ,  SUM(total) FROM invoices WHERE billing_state  =  "CA";
([[1.]] - School_ID) SELECT count(DISTINCT school_id) FROM basketball_match
([[1.]] - *) SELECT count(*) FROM election
([[1.]] - salary) SELECT eid ,  salary FROM Employee WHERE name  =  'Mark Young'
([[1.]] - DEPT_ADDRESS) SELECT dept_address FROM department WHERE dept_name  =  'History'
([[0.]] - PROF_OFFICE) SELECT count(*) FROM employee
([[0.]] - MediaTypeId) SELECT FirstName ,  LastName FROM EMPLOYEE WHERE City  =  "Calgary"
([[0.]] - departure_date) SELECT name ,  distance FROM Aircraft WHERE aid  =  12


In [155]:
select_batch_size = 500
select_epochs = 400
select_steps_per_epoch = len(transformed_select_filter_input)/select_batch_size
select_filter_model = create_model(
    TrainingType.SELECT_FILTER_MODEL, 
    cell_type = CellType.GRU, 
    learning_rate = 0.0002, 
    stack_depth = 2, 
    encoding_size = 150, 
    dense_layer = [20, 7], 
    dropout = 0.1)
select_filter_model.fit((_ for _ in transformed_select_filter_input), steps_per_epoch=select_steps_per_epoch, epochs=select_epochs, verbose=1)

Epoch 1/400
Epoch 2/400
Epoch 3/400
Epoch 4/400
Epoch 5/400
Epoch 6/400
Epoch 7/400
Epoch 8/400
Epoch 9/400
Epoch 10/400
Epoch 11/400
Epoch 12/400
Epoch 13/400
Epoch 14/400
Epoch 15/400
Epoch 16/400
Epoch 17/400
Epoch 18/400
Epoch 19/400
Epoch 20/400
Epoch 21/400
Epoch 22/400
Epoch 23/400
Epoch 24/400
Epoch 25/400
Epoch 26/400
Epoch 27/400
Epoch 28/400
Epoch 29/400
Epoch 30/400
Epoch 31/400
Epoch 32/400
Epoch 33/400
Epoch 34/400
Epoch 35/400
Epoch 36/400
Epoch 37/400
Epoch 38/400
Epoch 39/400
Epoch 40/400
Epoch 41/400
Epoch 42/400
Epoch 43/400
Epoch 44/400
Epoch 45/400
Epoch 46/400
Epoch 47/400
Epoch 48/400
Epoch 49/400
Epoch 50/400
Epoch 51/400
Epoch 52/400
Epoch 53/400
Epoch 54/400
Epoch 55/400
Epoch 56/400
Epoch 57/400
Epoch 58/400
Epoch 59/400
Epoch 60/400
Epoch 61/400
Epoch 62/400
Epoch 63/400
Epoch 64/400
Epoch 65/400
Epoch 66/400
Epoch 67/400
Epoch 68/400
Epoch 69/400
Epoch 70/400
Epoch 71/400
Epoch 72/400
Epoch 73/400
Epoch 74/400
Epoch 75/400
Epoch 76/400
Epoch 77/400
Epoch 78

<tensorflow.python.keras.callbacks.History at 0x1a545a45640>

In [156]:
where_batch_size = 750
where_epochs = 400
where_steps_per_epoch = len(transformed_where_filter_input)/where_batch_size
where_filter_model = create_model(
    TrainingType.WHERE_FILTER_MODEL, 
    cell_type = CellType.GRU, 
    learning_rate = 0.0002, 
    stack_depth = 6, 
    encoding_size = 64, 
    dense_layer = [17, 7], 
    dropout = 0)
where_filter_model.fit((_ for _ in transformed_where_filter_input), steps_per_epoch=where_steps_per_epoch, epochs=where_epochs, verbose=1)

Epoch 1/400
Epoch 2/400
Epoch 3/400
Epoch 4/400
Epoch 5/400
Epoch 6/400
Epoch 7/400
Epoch 8/400
Epoch 9/400
Epoch 10/400
Epoch 11/400
Epoch 12/400
Epoch 13/400
Epoch 14/400
Epoch 15/400
Epoch 16/400
Epoch 17/400
Epoch 18/400
Epoch 19/400
Epoch 20/400
Epoch 21/400
Epoch 22/400
Epoch 23/400
Epoch 24/400
Epoch 25/400
Epoch 26/400
Epoch 27/400
Epoch 28/400
Epoch 29/400
Epoch 30/400
Epoch 31/400
Epoch 32/400
Epoch 33/400
Epoch 34/400
Epoch 35/400
Epoch 36/400
Epoch 37/400
Epoch 38/400
Epoch 39/400
Epoch 40/400
Epoch 41/400
Epoch 42/400
Epoch 43/400
Epoch 44/400
Epoch 45/400
Epoch 46/400
Epoch 47/400
Epoch 48/400
Epoch 49/400
Epoch 50/400
Epoch 51/400
Epoch 52/400
Epoch 53/400
Epoch 54/400
Epoch 55/400
Epoch 56/400
Epoch 57/400
Epoch 58/400
Epoch 59/400
Epoch 60/400
Epoch 61/400
Epoch 62/400
Epoch 63/400
Epoch 64/400
Epoch 65/400
Epoch 66/400
Epoch 67/400
Epoch 68/400
Epoch 69/400
Epoch 70/400
Epoch 71/400
Epoch 72/400
Epoch 73/400
Epoch 74/400
Epoch 75/400
Epoch 76/400
Epoch 77/400
Epoch 78

<tensorflow.python.keras.callbacks.History at 0x1a50ce2efd0>

In [163]:
select_func_batch_size = 750
select_func_epochs = 400
select_func_steps_per_epoch = len(transformed_select_func_input)/select_func_batch_size
select_func_model = create_model(
    TrainingType.SELECT_FUNC_MODEL, 
    cell_type = CellType.GRU, 
    learning_rate = 0.001, 
    stack_depth = 6, 
    encoding_size = 64, 
    dense_layer = [17, 7], 
    dropout = 0)
select_func_model.fit((_ for _ in transformed_select_func_input), steps_per_epoch = select_func_steps_per_epoch, epochs=select_func_epochs, verbose=1)

Epoch 1/400
Epoch 2/400
Epoch 3/400
Epoch 4/400
Epoch 5/400
Epoch 6/400
Epoch 7/400
Epoch 8/400
Epoch 9/400
Epoch 10/400
Epoch 11/400
Epoch 12/400
Epoch 13/400
Epoch 14/400
Epoch 15/400
Epoch 16/400
Epoch 17/400
Epoch 18/400
Epoch 19/400
Epoch 20/400
Epoch 21/400
Epoch 22/400
Epoch 23/400
Epoch 24/400
Epoch 25/400
Epoch 26/400
Epoch 27/400
Epoch 28/400
Epoch 29/400
Epoch 30/400
Epoch 31/400
Epoch 32/400
Epoch 33/400
Epoch 34/400
Epoch 35/400
Epoch 36/400
Epoch 37/400
Epoch 38/400
Epoch 39/400
Epoch 40/400
Epoch 41/400
Epoch 42/400
Epoch 43/400
Epoch 44/400
Epoch 45/400
Epoch 46/400
Epoch 47/400
Epoch 48/400
Epoch 49/400
Epoch 50/400
Epoch 51/400
Epoch 52/400
Epoch 53/400
Epoch 54/400
Epoch 55/400
Epoch 56/400
Epoch 57/400
Epoch 58/400
Epoch 59/400
Epoch 60/400
Epoch 61/400
Epoch 62/400
Epoch 63/400
Epoch 64/400
Epoch 65/400
Epoch 66/400
Epoch 67/400
Epoch 68/400
Epoch 69/400
Epoch 70/400
Epoch 71/400
Epoch 72/400
Epoch 73/400
Epoch 74/400
Epoch 75/400
Epoch 76/400
Epoch 77/400
Epoch 78

<tensorflow.python.keras.callbacks.History at 0x1a5b3373910>

In [164]:
where_cond_batch_size = 750
where_cond_epochs = 400
where_cond_steps_per_epoch = len(transformed_where_cond_input)/where_cond_batch_size
where_cond_model = create_model(
    TrainingType.WHERE_COND_MODEL, 
    cell_type = CellType.GRU, 
    learning_rate = 0.001, 
    stack_depth = 6, 
    encoding_size = 64, 
    dense_layer = [17, 7], 
    dropout = 0)
where_cond_model.fit((_ for _ in transformed_where_cond_input), steps_per_epoch = where_cond_steps_per_epoch, epochs=where_cond_epochs, verbose=1)

Epoch 1/400
Epoch 2/400
Epoch 3/400
Epoch 4/400
Epoch 5/400
Epoch 6/400
Epoch 7/400
Epoch 8/400
Epoch 9/400
Epoch 10/400
Epoch 11/400
Epoch 12/400
Epoch 13/400
Epoch 14/400
Epoch 15/400
Epoch 16/400
Epoch 17/400
Epoch 18/400
Epoch 19/400
Epoch 20/400
Epoch 21/400
Epoch 22/400
Epoch 23/400
Epoch 24/400
Epoch 25/400
Epoch 26/400
Epoch 27/400
Epoch 28/400
Epoch 29/400
Epoch 30/400
Epoch 31/400
Epoch 32/400
Epoch 33/400
Epoch 34/400
Epoch 35/400
Epoch 36/400
Epoch 37/400
Epoch 38/400
Epoch 39/400
Epoch 40/400
Epoch 41/400
Epoch 42/400
Epoch 43/400
Epoch 44/400
Epoch 45/400
Epoch 46/400
Epoch 47/400
Epoch 48/400
Epoch 49/400
Epoch 50/400
Epoch 51/400
Epoch 52/400
Epoch 53/400
Epoch 54/400
Epoch 55/400
Epoch 56/400
Epoch 57/400
Epoch 58/400
Epoch 59/400
Epoch 60/400
Epoch 61/400
Epoch 62/400
Epoch 63/400
Epoch 64/400
Epoch 65/400
Epoch 66/400
Epoch 67/400
Epoch 68/400
Epoch 69/400
Epoch 70/400
Epoch 71/400
Epoch 72/400
Epoch 73/400
Epoch 74/400
Epoch 75/400
Epoch 76/400
Epoch 77/400
Epoch 78

<tensorflow.python.keras.callbacks.History at 0x1a59ad57370>

In [159]:
def test_generator(data):
    for _ in data:
        yield _[0]

def test_ys(data):
    for _ in data:
        yield _[1]

In [160]:
def test_accuracy_binary(training_type: TrainingType, model):
    test_queries = []
    metric = []
    class_dist = {}
    if training_type == TrainingType.SELECT_FILTER_MODEL or training_type == TrainingType.WHERE_FILTER_MODEL:
        metrics = [
            keras.metrics.TruePositives(name='tp'),
            keras.metrics.FalsePositives(name='fp'),
            keras.metrics.TrueNegatives(name='tn'),
            keras.metrics.FalseNegatives(name='fn'), 
            keras.metrics.BinaryAccuracy(name='accuracy'),
            keras.metrics.Precision(name='precision'),
            keras.metrics.Recall(name='recall'),
            keras.metrics.AUC(name='auc'),
            keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
        ]
    elif training_type == TrainingType.WHERE_COND_MODEL or training_type == TrainingType.SELECT_FUNC_MODEL:
        metrics = [
            tf.keras.metrics.CategoricalAccuracy(
                name='categorical accuracy'),
            tf.keras.metrics.Precision(name='precision 1', top_k=1),
            tf.keras.metrics.Precision(name='precision 3', top_k=3),
            tf.keras.metrics.Recall(name='recall 1', top_k=1),
            tf.keras.metrics.Recall(name='recall 2', top_k=3)
        ]        
    test_data = list(transform_input(simple_queries_test, test_queries, training_type, class_dist))
    print(f"{str(training_type)} Test class_dist", class_dist)
    actual_prediction = model.predict(test_generator(test_data))
    predictions = [1.0 if _ >= 0.5 else 0 for _ in actual_prediction]
    y_test = list([_[0][0] for _ in test_ys(test_data)])
    for metric in metrics:
        metric.update_state(y_test, predictions)
        print(f"{metric.name} =", metric.result().numpy())
    print("Total =", len(y_test))
    t = zip(y_test, predictions)
    matches = [(i, test_queries[i][1].col_name_actual, test_queries[i][0].question, test_queries[i][0].query) for i, _ in enumerate(t) if _[0] < 0.5 and _[1] >= 0.5]
    print("False Positive matches =", len(matches))
    for _ in matches:
        print(_)

In [161]:
test_accuracy_binary(TrainingType.SELECT_FILTER_MODEL, select_filter_model)

TrainingType.SELECT_FILTER_MODEL Test class_dist {'[[0.]]': 530, '[[1.]]': 40}
tp = 36.0
fp = 60.0
tn = 470.0
fn = 4.0
accuracy = 0.8877193
precision = 0.375
recall = 0.9
auc = 0.89339614
prc = 0.35156822
Total = 570
False Positive matches = 60
(11, 'city', 'What are the names of the airports in the city of Goroka?', "SELECT name FROM airports WHERE city  =  'Goroka'")
(23, 'dst_ap', 'Find the name, city, country, and altitude (or elevation) of the airports in the city of New York.', "SELECT name ,  city ,  country ,  elevation FROM airports WHERE city  =  'New York'")
(25, 'src_ap', 'Find the name, city, country, and altitude (or elevation) of the airports in the city of New York.', "SELECT name ,  city ,  country ,  elevation FROM airports WHERE city  =  'New York'")
(38, 'callsign', 'Find the name, city, country, and altitude (or elevation) of the airports in the city of New York.', "SELECT name ,  city ,  country ,  elevation FROM airports WHERE city  =  'New York'")
(43, 'dst_ap',

In [162]:
test_accuracy_binary(TrainingType.WHERE_FILTER_MODEL, where_filter_model)

TrainingType.WHERE_FILTER_MODEL Test class_dist {'[[0.]]': 548, '[[1.]]': 22}
tp = 20.0
fp = 136.0
tn = 412.0
fn = 2.0
accuracy = 0.75789475
precision = 0.12820514
recall = 0.90909094
auc = 0.83045787
prc = 0.122465596
Total = 570
False Positive matches = 136
(3, 'dst_ap', 'What are the names of the airports in the city of Goroka?', "SELECT name FROM airports WHERE city  =  'Goroka'")
(7, 'airline', 'What are the names of the airports in the city of Goroka?', "SELECT name FROM airports WHERE city  =  'Goroka'")
(12, 'country', 'What are the names of the airports in the city of Goroka?', "SELECT name FROM airports WHERE city  =  'Goroka'")
(23, 'dst_ap', 'Find the name, city, country, and altitude (or elevation) of the airports in the city of New York.', "SELECT name ,  city ,  country ,  elevation FROM airports WHERE city  =  'New York'")
(25, 'src_ap', 'Find the name, city, country, and altitude (or elevation) of the airports in the city of New York.', "SELECT name ,  city ,  country 