In [4]:
import time
import numpy as np
import random

def write_table2sql(table, engine, sql=None, max_attempts=100):
    """Randomly generate one SQL query, and its templated question, for ``table``.

    The query is built in three steps:

    1. Sample (column, operator, cell value) triples until ``engine`` reports
       that the single condition matches at least one row (or the lookup took
       more than a second, in which case the condition is kept as-is).
    2. Try to add up to three further conditions, each sampled from the rows
       matched so far, stopping as soon as a condition is unusable or does not
       change the result.
    3. Pick a select column and an aggregation operator.

    :param table: dict with the keys ``id``, ``header``, ``types`` and ``rows``.
    :param engine: object exposing
        ``execute(table_id, select_index, aggregation_index, conditions, ret_rows=True)``
        that returns the matching rows.
    :param sql: ignored; kept only so existing call sites keep working.
    :param max_attempts: upper bound on the number of samples drawn in step 1.
        Without it the search never ends on a table whose cells are all empty
        or overly long.
    :return: a one-element list holding a dict with the keys ``table_id``,
        ``sql`` and ``question``, or ``[]`` when no usable condition exists.
    """
    header = table['header']
    col_types = table['types']
    all_rows = table['rows']
    col_num = len(header)

    # Nothing can be sampled from a table without columns or without rows.
    if col_num == 0 or len(all_rows) == 0:
        return []

    def is_usable(value):
        """A condition value must be non-empty and at most 20 whitespace-separated tokens."""
        as_text = str(value)
        return as_text != '' and len(as_text.split()) <= 20

    def select_cond_op(col_type):
        """Text columns only get operator 0; other columns get 0, 1 or 2 at random."""
        if col_type == 'text':
            return 0
        return random.randint(0, 2)

    def select_col_agg(mask):
        """
        select col agg pair
        :return: (column index, aggregation index)
        """
        # Masked columns carry weight -1, so any unmasked column wins the argmax.
        sel_idx = np.argmax(np.random.rand(col_num) * mask)
        if col_types[sel_idx] == 'text':
            sel_agg = random.choice([0, 3])
        else:
            sel_agg = random.choice([0, 1, 2, 3, 4, 5])
        return sel_idx, sel_agg

    conds = []
    # +1 = column still available, -1 = column already used in a condition.
    mask = np.asarray([1] * col_num)
    ret = []

    # make sure at least one condition
    found = False
    for _ in range(max_attempts):
        sel_idx = np.argmax(np.random.rand(col_num))
        cond_op = select_cond_op(col_types[sel_idx])
        select_row = random.randint(0, len(all_rows) - 1)
        cond_value = all_rows[select_row][sel_idx]
        if not is_usable(cond_value):
            continue
        conds.append([sel_idx, cond_op, cond_value])

        start = time.time()
        ret = engine.execute(table['id'], 0, 0, conds, ret_rows=True)

        # A slow lookup is accepted without retrying; otherwise the condition
        # must match at least one row.
        if time.time() - start > 1 or len(ret) != 0:
            mask[sel_idx] = -1
            found = True
            break

        conds.pop()

    if not found:
        return []

    if len(ret) != 0:
        for _ in range(min(3, len(ret[0]) - 1)):
            sel_idx = np.argmax(np.random.rand(col_num) * mask)
            cond_op = select_cond_op(col_types[sel_idx])
            row_num = len(ret)
            select_row = random.randint(0, row_num - 1)
            cond_value = list(ret[select_row])[sel_idx]

            # An unusable value ends the refinement; checking it first avoids
            # running a query whose condition would be discarded anyway.
            if not is_usable(cond_value):
                break

            conds.append([sel_idx, cond_op, cond_value])
            ret = engine.execute(table['id'], 0, 0, conds, ret_rows=True)
            # result doesn't change
            if len(ret) == row_num:
                conds.pop()
                break

            mask[sel_idx] = -1
            # The condition is kept even when it filters out every row; there
            # is nothing left to sample further conditions from.
            if len(ret) == 0:
                break

    sel_idx, sel_agg = select_col_agg(mask)
    query = {'agg': sel_agg, 'sel': sel_idx, 'conds': conds}
    data = {
        'table_id': table['id'],
        'sql': query,
        'question': sql2qst(query, table),
    }
    return [data]

In [5]:
# SQL spelling of each condition operator / aggregation index.
op_sql_dict = {0: "=", 1: ">", 2: "<", 3: "OP"}
agg_sql_dict = {0: "", 1: "MAX", 2: "MIN", 3: "COUNT", 4: "SUM", 5: "AVG"}

# English templates for the same indices. Every aggregation prefix ends with a
# space because the column header is appended directly after it.
agg_str_dict = {0: "What is ", 1: "What is the maximum of ", 2: "What is the minimum of ", 3: "What is the number of ", 4: "What is the sum of ", 5: "What is the average of "}
op_str_dict = {0: "is", 1: "is more than", 2: "is less than", 3: ""}

def sql2qst(sql, table):
    """Render a SQL dict as a templated English question.

    :param sql: dict with ``sel`` (column index), ``agg`` (key of
        ``agg_str_dict``) and ``conds`` (iterable of
        ``(column index, operator index, value)`` triples).
    :param table: dict whose ``header`` list maps column indices to names.
    :return: ``"<aggregation prefix><column> that <cond> and <cond> ..."``; when
        there are no conditions only the select part is returned, so the
        question does not end in a dangling ``" that "``.
    """
    headers = table['header']

    # select part
    select_part = '{}{}'.format(agg_str_dict[sql['agg']], headers[sql['sel']])

    # where part
    where_part = [
        '{} {} {}'.format(headers[col_index], op_str_dict[op], val)
        for col_index, op, val in sql['conds']
    ]
    if not where_part:
        return select_part
    return "{} that {}".format(select_part, ' and '.join(where_part))

In [6]:
import re, time

import records
from babel.numbers import parse_decimal, NumberFormatError
from sqlalchemy import *
from sqlalchemy import text


# Captures everything between the first '(' and the last ')' of a string,
# e.g. the column list of a CREATE TABLE statement.
schema_re = re.compile(r'\((.+)\)')
# Matches an optionally signed decimal or integer: -34.34, .4543, -34, 12.
# The alternatives are grouped so the sign applies to both of them; without
# the group "-34" would be matched as "34".
num_re = re.compile(r'[-+]?(?:\d*\.\d+|\d+)')

# Index -> SQL keyword tables used when assembling queries.
agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
cond_ops = ['=', '>', '<', 'OP']

class DBEngine:
    """Runs select/aggregate/where queries against tables in a SQLite file.

    Tables are addressed as ``table_<id>`` and their columns as ``col<N>``.
    The column list of the most recently used table is cached on the instance
    in ``schema_str`` so repeated queries on one table skip the catalogue
    lookup.
    """

    def __init__(self, fdb):
        """Open a connection to the SQLite database stored at path ``fdb``."""
        self.db = create_engine('sqlite:///{}'.format(fdb))
        self.conn = self.db.connect()
        self.table_id = ''
        self.schema_str = ''

    def execute_query(self, table_id, query, *args, **kwargs):
        """Run ``query`` (an object with ``sel_index``, ``agg_index`` and ``conditions``) via :meth:`execute`."""
        return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs)

    @staticmethod
    def _normalize_table_id(table_id):
        """Turn a raw id such as ``1-100-3`` into the table name ``table_1_100_3``."""
        if not table_id.startswith('table'):
            table_id = 'table_{}'.format(table_id.replace('-', '_'))
        return table_id

    def _run(self, query, params=None):
        """Execute a SQL string with named ``:param`` binds on the open connection."""
        # text() with a parameter dict is accepted by both legacy and current
        # SQLAlchemy releases, unlike a bare string with keyword arguments.
        return self.conn.execute(text(query), params or {})

    def _load_schema(self, table_id):
        """Return a ``{column name: column type}`` dict for ``table_id``.

        Raises ``IndexError`` when the table is not present in ``sqlite_master``.
        """
        if table_id != self.table_id:
            found = self._run('SELECT sql from sqlite_master WHERE tbl_name = :name', {'name': table_id}).fetchall()
            table_info = found[0].sql.replace('\n', '')
            self.schema_str = schema_re.findall(table_info)[0]
            # Only mark the cache as valid once the lookup has succeeded.
            self.table_id = table_id

        schema = {}
        for tup in self.schema_str.split(', '):
            c, t = tup.split()
            schema[c] = t
        return schema

    @staticmethod
    def _select_expr(select_index, aggregation_index):
        """Build the select expression, e.g. ``col2`` or ``MAX(col2)``."""
        select = 'col{}'.format(select_index)
        agg = agg_ops[aggregation_index]
        if agg:
            select = '{}({})'.format(agg, select)
        return select

    @staticmethod
    def _coerce_value(val, col_type, lower):
        """Lower-case string values and convert values for ``real`` columns to float when possible."""
        if lower and isinstance(val, str):
            val = val.lower()
        if col_type == 'real' and not isinstance(val, (int, float)):
            try:
                val = float(parse_decimal(val, locale='en_US'))
            except NumberFormatError:
                try:
                    # Fall back to the first numeric token embedded in the text.
                    val = float(num_re.findall(val)[0])
                except (IndexError, TypeError, ValueError):
                    # Although column is of number, selected one is not number. Do nothing in this case.
                    pass
        return val

    def _build_where(self, conditions, schema, lower):
        """Return ``(where_str, params)`` for ``(column index, operator index, value)`` triples."""
        where_clause = []
        where_map = {}
        for col_index, op, val in conditions:
            column = 'col{}'.format(col_index)
            val = self._coerce_value(val, schema[column], lower)
            # A second condition on the same column needs its own bind name,
            # otherwise its value would silently replace the first one.
            param = column
            if param in where_map:
                param = '{}_{}'.format(column, len(where_clause))
            where_clause.append('{} {} :{}'.format(column, cond_ops[op], param))
            where_map[param] = val
        where_str = ''
        if where_clause:
            where_str = 'WHERE ' + ' AND '.join(where_clause)
        return where_str, where_map

    def execute(self, table_id, select_index, aggregation_index, conditions, lower=True, ret_rows=False):
        """Run a query and return its result.

        :param table_id: raw table id or full ``table_...`` name.
        :param select_index: index of the selected column.
        :param aggregation_index: index into ``agg_ops``.
        :param conditions: iterable of ``(column index, operator index, value)``.
        :param lower: lower-case string condition values before binding them.
        :param ret_rows: when ``True`` select ``*`` and return whole rows;
            when ``False`` return only the first column of every row.
        :return: list of values, or list of rows when ``ret_rows`` is not ``False``.
        """
        table_id = self._normalize_table_id(table_id)
        schema = self._load_schema(table_id)

        select = self._select_expr(select_index, aggregation_index)
        if ret_rows is True:
            select = '*'
        where_str, where_map = self._build_where(conditions, schema, lower)
        # NOTE(review): the table name is formatted into the SQL text because
        # identifiers cannot be bound; only pass trusted table ids.
        query = 'SELECT {} FROM {} {}'.format(select, table_id, where_str)

        out = self._run(query, where_map)

        if ret_rows is False:
            return [o[0] for o in out]
        return [o for o in out]

    def execute_return_query(self, table_id, select_index, aggregation_index, conditions, lower=True):
        """Like :meth:`execute`, but also return the SQL text that was run.

        :return: ``(values, query)`` where ``values`` holds the first column of
            every result row and ``query`` is the SQL string with its named
            placeholders.
        """
        table_id = self._normalize_table_id(table_id)
        schema = self._load_schema(table_id)

        select = self._select_expr(select_index, aggregation_index)
        where_str, where_map = self._build_where(conditions, schema, lower)
        query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str)
        out = self._run(query, where_map)

        return [o[0] for o in out], query

    def show_table(self, table_id):
        """Print the column names of ``table_id`` followed by every row."""
        table_id = self._normalize_table_id(table_id)
        result = self._run('select * from ' + table_id)
        print(list(result.keys()))
        for row in result:
            print(tuple(row))

In [7]:
import json

class NpEncoder(json.JSONEncoder):
    """JSON encoder that serialises NumPy scalars and arrays as built-in Python values."""

    def default(self, obj):
        """Convert ``obj`` to a JSON-serialisable value.

        NumPy arrays become lists, NumPy floating scalars become ``float`` and
        NumPy integer scalars become ``int``. Anything else is passed on to
        ``json.JSONEncoder.default``, which raises ``TypeError``.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        return super().default(obj)

import numpy as np
# Load the table definitions, skipping any table that has an unnamed column.
with open('train.tables.jsonl', 'r') as fr:
    tables = []
    for line in fr:
        table = json.loads(line)
        if '' in table['header']:
            continue
        tables.append(table)

# Token spellings used when linearising a generated query; indexed by the
# aggregation / operator index stored in the sql dict.
agg_str = ['', 'max ', 'min ', 'count ', 'sum ', 'avg ']
op_str = ['=', '>', '<']

engine = DBEngine('train.db')
with open("train_augment.jsonl","w") as f:
    for i in range(1000):
        # Pick one table uniformly at random.
        probs = np.random.rand(len(tables))
        table_i = tables[np.argmax(probs)]
        data = write_table2sql(table_i, engine)
        if len(data) == 0:
            print('couldnt find a valid sql!')
            # Nothing was generated for this draw; skip it instead of writing
            # the record left over from an earlier iteration.
            continue

        headers = [name.lower() for name in table_i['header']]
        for js in data:
            # The first character of the table id is stored as the phase.
            js["phase"] = js["table_id"][0]
            query = js['sql']

            # Linearise the query as "select <agg><column> where <c> <op> <v> and ...".
            cond_strs = [
                headers[cond[0]] + ' ' + op_str[cond[1]] + ' ' + str(cond[2]).lower()
                for cond in query['conds']
            ]
            sql_str = (
                'select '
                + agg_str[query['agg']]
                + headers[query['sel']] + ' '
                + 'where '
                + ' and '.join(cond_strs)
            )
            src = sql_str.split(' ')

            # Tokenise the question, dropping empty trailing tokens and
            # splitting a final question mark into its own token.
            trg = js['question'].lower().split(' ')
            while trg and trg[-1] == '':
                trg.pop()
            if trg and trg[-1].endswith('?'):
                trg[-1] = trg[-1][:-1]
                trg += ['?']

            js['src'] = src
            js['trg'] = trg
            f.write(json.dumps(js, cls=NpEncoder) + '\n')

    print('finished!')

finished!
