In [5]:
import json
from utils.schema import *
from utils.reverse_logic import *
from utils.reverse_sql import *
from utils.reverse_middle import *
from utils.parse import *
import os
import random
import copy

In [6]:
with open('raw_data/tables.json','r') as f:
    db = json.load(f)
db_dict = {}
schemas, db_names, tables = get_schemas_from_json('raw_data/tables.json')
for item in db:
    db_dict[item['db_id']] = item

db_reverse = {}
db_c_ori2pre = {}
db_t_ori2pre = {}
for k,v in db_dict.items():
    db_reverse[k] = {}
    column = []
    table = []
    db_c_ori2pre[k] = {}
    db_t_ori2pre[k] = {}
    for o,c in zip(v['column_names'],v['column_names_original']):
        column.append([o[0],o[1],c[1]])
        db_c_ori2pre[k][c[1]] = o[1]
    db_reverse[k]['column'] = column
    for o,c in zip(v['table_names'],v['table_names_original']):
        table.append([o,c])
        db_t_ori2pre[k][c] = o
    db_reverse[k]['table'] = table
db_aug = {}
for k,v in db_dict.items():
    col_temp = {}
    for index, item in enumerate(v['column_names']):
        if item[0] not in col_temp.keys():
            if item[1] in v['primary_keys']:
                col_temp[item[0]] = [(index,item[1],v['column_types'][index],'primary')]
            else:
                col_temp[item[0]] = [(index,item[1],v['column_types'][index],'None')]
        else:
            if item[1] in v['primary_keys']:
                col_temp[item[0]].append((index,item[1],v['column_types'][index],'primary'))
            else:
                col_temp[item[0]].append((index,item[1],v['column_types'][index],'None'))
    db_aug[k] = col_temp

In [7]:
def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        return False
db_value = {}
for k,v in db_dict.items():
    col_value = {}
    db_id = v['db_id']
    db_file = os.path.join('raw_data/database', db_id, db_id + '.sqlite')
    if not os.path.exists(db_file):
        raise ValueError('[ERROR]: database file %s not found ...' % (db_file))
    conn = sqlite3.connect(db_file)
    conn.text_factory = lambda b: b.decode(errors='ignore')
    conn.execute('pragma foreign_keys=ON')
    for i, (tab_id, col_name) in enumerate(db_dict[db_id]['column_names_original']):
        if i == 0 : # ignore * and special token 'id'
            continue
        tab_name = db_dict[db_id]['table_names_original'][tab_id]
        cursor = conn.execute("SELECT DISTINCT \"%s\" FROM \"%s\";" % (col_name, tab_name))
        cell_values = cursor.fetchall()
        cell_values = [str(each[0]) for each in cell_values]
        cell_values = [[str(float(each))] if is_number(each) else each.lower().split() for each in cell_values]
        col_value[i] = cell_values
    db_value[k] = col_value

In [8]:
random.seed(999)
string_no_agg = ['min','max','avg','sum']
value_no_agg = ['count']
time_no_agg = ['avg','sum','count']
def replace_select(middle,db_id,num = 1,agg = [''],distinct = [False]):
    table_idxs = middle['from']['table']
    col_ori = []
    col_aug = db_aug[db_id]
    for item in middle['select']:
        col_ori.append((item[0],item[2]))
    replaced = []
    result = []
    step = 0
    for index in range(num):
        flag = True
        while flag:
            if step >= 20:
                return (middle ,replaced, False)
            tab_index = random.randint(0,len(table_idxs)-1)
            col_idxs = col_aug[table_idxs[tab_index]]
            col_index = random.randint(0,len(col_idxs)-1)
            if num > 1:
                if col_index == 0:
                    step += 1
                    continue
            if col_index == 0:
                distinct[index] = False
            if (col_idxs[col_index][0],agg[index]) in col_ori:
                step += 1
                continue
            if agg[index] != '':
                if agg[index] in string_no_agg and col_idxs[col_index][2] == 'text':
                    step += 1
                    continue
                if agg[index] in value_no_agg and col_idxs[col_index][2] == 'number' :
                    step += 1
                    continue
                if agg[index] in time_no_agg and col_idxs[col_index][2] == 'time' :
                    step += 1
                    continue
            flag = False
            result.append([col_idxs[col_index][0],table_idxs[tab_index],agg[index],distinct[index]])
            replaced.append(col_idxs[col_index][1])
            col_ori.append((col_idxs[col_index][0],agg[index]))
    middle['select'] = result
    return (middle ,replaced,True)
def add_select(middle,db_id,num = 1,agg = [''],distinct = [False]):
    table_idxs = middle['from']['table']
    col_ori = []
    col_aug = db_aug[db_id]
    for item in middle['select']:
        col_ori.append((item[0],item[2]))
    added = []
    step = 0
    result = []
    for index in range(num):
        flag = True
        while flag:
            if step >= 20:
                return (middle ,added, False)
            tab_index = random.randint(0,len(table_idxs)-1)
            col_idxs = col_aug[table_idxs[tab_index]]
            col_index = random.randint(0,len(col_idxs)-1)
            if num > 1:
                if col_index == 0:
                    step += 1
                    continue
            if col_index == 0:
                distinct[index] = False
            if (col_idxs[col_index][0],agg[index]) in col_ori:
                step += 1
                continue
            if agg[index] != '':
                if agg[index] in string_no_agg and col_idxs[col_index][2] == 'text':
                    step += 1
                    continue
                if agg[index] in value_no_agg and col_idxs[col_index][2] == 'number' :
                    step += 1
                    continue
                if agg[index] in time_no_agg and col_idxs[col_index][2] == 'time' :
                    step += 1
                    continue
            flag = False
            result.append([col_idxs[col_index][0],table_idxs[tab_index],agg[index],distinct[index]])
            added.append(col_idxs[col_index][1])
            col_ori.append((col_idxs[col_index][1],agg[index]))
    middle['select'].extend(result)   
    return (middle ,added, True)
def replace_count(middle,db_id,col = False):
    table_idxs = middle['from']['table']
    col_ori = []
    col_aug = db_aug[db_id]
    for item in middle['select']:
        col_ori.append((item[0],item[2]))
    added = []
    result = []
    index = 1
    flag = True
    step = 0
    
    while flag:
        if step >= 20:
            return (middle ,added, False)
        if col:
            tab_index = random.randint(0,len(table_idxs)-1)
            col_idxs = col_aug[table_idxs[tab_index]]
            col_index = random.randint(0,len(col_idxs)-1)
            if col_idxs[col_index][2] != 'text':
                step += 1
                continue
            if col_index == 0:
                step += 1
                continue
            if (col_idxs[col_index][0],'count') in col_ori:
                step += 1
                continue

            flag = False
            result.append([col_idxs[col_index][0],table_idxs[tab_index],'count',False])
            added.append(col_idxs[col_index][1])
        else:
            if (0,'count') in col_ori:
                return (middle,added,False)
            result.append([0,-1,'count',False])
            flag = False
    middle['select'].extend(result)
    return (middle ,added, True)
def replace_agg(middle,db_id,num = 1,theagg = ['']):
    table_idxs = middle['from']['table']
    col_ori = []
    col_db = db_dict[db_id]
    for item in middle['select']:
        if item[0] != 0:
            col_ori.append((item[0],item[2]))
    
    replaced = []
    result = copy.deepcopy(middle['select'])
    index = 0
    step = 0
    if len(col_ori) == 0 or len(col_ori) < num:
        return (middle,replaced,False)
    for index in range(num):
        flag = True
        while flag:
            if step >= 20:
                return (middle ,replaced, False)
            col_index = random.randint(0,len(result)-1)
            if (result[col_index][0],result[col_index][2]) not in col_ori:
                step += 1
                continue
            ttype = col_db['column_types'][result[col_index][0]]
            if ttype == 'text':
                agg = random.sample(['count'],1)[0]
            elif ttype == 'number':
                agg = random.sample(['sum','max','min','avg'],1)[0]
            else:
                agg = random.sample(['max','min'],1)[0]
            if (result[col_index][0],agg) in col_ori:
                step += 1
                continue
            if index <= len(theagg)-1 and theagg[0] != '':
                agg = theagg[index]
                index += 1
            flag = False
            result[col_index][2] = str(agg)
            replaced.append(agg)
    middle['select'] = result
    return (middle,replaced,True)
#[10, 1, '', 'one', '!=', '"Banking"', 'None']
def replace_where(middle,db_id,num = 1):
    table_idxs = middle['from']['table']
    where_ori = []
    col_db = db_dict[db_id]
    col_value = db_value[db_id]
    col_aug = db_aug[db_id]
    for item in middle['where']:
        where_ori.append((item[0],item[4]))
    result = []
    replaced = []
    index = 0
    step = 0
    for index in range(num):
        flag = True
        while flag:
            if step >= 20:
                return (middle ,replaced, False)
            tab_index = random.randint(0,len(table_idxs)-1)
            col_idxs = col_aug[table_idxs[tab_index]]
            col_index = random.randint(0,len(col_idxs)-1)
            if col_idxs[col_index][0] == 0:
                step += 1
                continue
            ttype = col_idxs[col_index][2]
            if ttype == 'text':
                op = random.sample(['=','!=','LIKE'],1)[0]
            elif ttype == 'number':
                op = random.sample(['=','!=','>','<','>=','<='],1)[0]
            else:
                op = random.sample(['=','!=','>','<','>=','<='],1)[0]
            value = ' '.join(random.sample(col_value[col_idxs[col_index][0]],1)[0])
            if (col_idxs[col_index][0],op) in where_ori:
                step += 1
                continue
            flag = False
            AND = 'one'
            if index > 0:
                AND = 'AND'
            if '"' not in value:
                value = '"' + value + '"'
            where_ori.append((col_idxs[col_index][0],op))
            result.append([col_idxs[col_index][0],table_idxs[tab_index],'',AND,op,value,'None'])
            replaced.append((col_idxs[col_index][1],op,value))
    middle['where'] = result
    return (middle,replaced,True)
def add_where(middle,db_id,num = 1):
    table_idxs = middle['from']['table']
    where_ori = []
    col_db = db_dict[db_id]
    col_value = db_value[db_id]
    col_aug = db_aug[db_id]
    
    if middle['where'] != None:
        for item in middle['where']:
            where_ori.append((item[0],item[4]))
        result = copy.deepcopy(middle['where'])
    else:
        result = []
    added = []
    index = len(where_ori)
    
    step = 0
    for _ in range(num):
        flag = True
        while flag:
            if step >= 20:
                return (middle ,added, False)
            tab_index = random.randint(0,len(table_idxs)-1)
            col_idxs = col_aug[table_idxs[tab_index]]
            col_index = random.randint(0,len(col_idxs)-1)
            
            if col_idxs[col_index][0] == 0:
                step += 1
                continue
            ttype = col_idxs[col_index][2]
            if ttype == 'text':
                op = random.sample(['=','!=','LIKE'],1)[0]
            elif ttype == 'number':
                op = random.sample(['=','!=','>','<','>=','<='],1)[0]
            else:
                op = random.sample(['=','!=','>','<','>=','<='],1)[0]
            value = ' '.join(random.sample(col_value[col_idxs[col_index][0]],1)[0])
            if (col_idxs[col_index][0],op) in where_ori:
                step += 1
                continue
            flag = False
            AND = 'one'
            if index > 0:
                AND = 'AND'
            index += 1
            if '"' not in value:
                value = '"' + value + '"'
            where_ori.append((col_idxs[col_index][0],op))
            result.append([col_idxs[col_index][0],table_idxs[tab_index],'',AND,op,value,'None'])
            added.append((col_idxs[col_index][1],op,value))
    middle['where'] = result
    return (middle,added,True)
#[10, 1, '', 'one', '!=', '"Banking"', 'None']
def replace_op(middle,db_id,num = 1):
    table_idxs = middle['from']['table']
    where_ori = []
    col_db = db_dict[db_id]
    col_value = db_value[db_id]
    col_aug = db_aug[db_id]
    replaced = []
    if middle['where'] == None:
        return (middle ,replaced, False)
    for item in middle['where']:
        where_ori.append((item[0],item[4]))
    if len(where_ori) < num:
        return (middle ,replaced, False)
    result = copy.deepcopy(middle['where'])
    index = 0
    step = 0
    for index in range(num):
        flag = True
        while flag:
            if step >= 20:
                return (middle ,replaced, False)
            col = random.sample(result,1)[0]
            example_index = result.index(col)
            col_index,tab_index = col[0],col[1]
            if col_index == 0:
                step += 1
                continue
            ttype = col_db['column_types'][col_index]
            if ttype == 'text':
                op = random.sample(['=','!=','LIKE'],1)[0]
            elif ttype == 'number':
                op = random.sample(['=','!=','>','<','>=','<='],1)[0]
            else:
                op = random.sample(['=','!=','>','<','>=','<='],1)[0]
            value = ' '.join(random.sample(col_value[col_index],1)[0])
            if (col_index,op) in where_ori:
                step += 1
                continue
            flag = False
            where_ori.append((col_index,op))
            result[example_index][4] = op
            if '"' not in value:
                value = '"' + value + '"'
            result[example_index][5] = value
            replaced.append((op,value))
    middle['where'] = result
    return (middle,replaced,True)
def add_orderby(middle,db_id,asc = 'ASC',limit = ''):
    table_idxs = middle['from']['table']
    where_ori = []
    col_db = db_dict[db_id]
    col_value = db_value[db_id]
    col_aug = db_aug[db_id]
    added = []
    if middle['orderBy'] != None:
        return (middle ,added, False)
    if len(middle['select']) == 1:
        if middle['select'][0][0] == 0 and middle['select'][0][2] == 'count':
            return (middle ,added, False)
    flag = True
    step = 0
    while flag:
        if step >= 20:
            return (middle ,added, False)
        tab_index = random.randint(0,len(table_idxs)-1)
        col_idxs = col_aug[table_idxs[tab_index]]
        col_index = random.randint(0,len(col_idxs)-1)
        if col_idxs[col_index][0] == 0:
            step += 1
            continue
        ttype = col_idxs[col_index][2]
        if ttype == 'text':
            step += 1
            continue
        
        flag = False
        middle['orderBy'] = [[col_idxs[col_index][0],table_idxs[tab_index],'',limit,asc]]
        added.append((col_idxs[col_index][1]))
    return (middle,added,True)
def add_having(middle,db_id,op,value):
    table_idxs = middle['from']['table']
    where_ori = []
    col_db = db_dict[db_id]
    col_value = db_value[db_id]
    col_aug = db_aug[db_id]
    added = []
    if middle['having'] != None:
        return (middle ,added, False)
    if middle['groupBy'] == None:
        return (middle ,added, False)
    flag = True
    step = 0
    while flag:
        if step >= 20:
            return (middle ,added, False)
        
        flag = False
        middle['having'] = [[0, -1, 'count', 'one', op, value, 'None']]
    return (middle,added,True)
def replace_asc(middle,db_id,asc = 'ASC'):
    table_idxs = middle['from']['table']
    where_ori = []
    col_db = db_dict[db_id]
    col_value = db_value[db_id]
    col_aug = db_aug[db_id]
    added = []
    if middle['orderBy'] == None:
        return (middle ,added, False)
    if middle['orderBy'][0][4] == asc:
        return (middle ,added, False)
    
    middle['orderBy'][0][4] = asc
    return (middle,added,True)
def add_groupby(middle,db_id):
    table_idxs = middle['from']['table']
    where_ori = []
    col_db = db_dict[db_id]
    col_value = db_value[db_id]
    col_aug = db_aug[db_id]
    added = []
    if middle['groupBy'] != None:
        return (middle ,added, False)
    flag = True
    step = 0
    while flag:
        if step >= 20:
            return (middle ,added, False)
        tab_index = random.randint(0,len(table_idxs)-1)
        col_idxs = col_aug[table_idxs[tab_index]]
        col_index = random.randint(0,len(col_idxs)-1)
        if col_idxs[col_index][0] == 0:
            step += 1
            continue
        ttype = col_idxs[col_index][2]
        if ttype != 'text':
            step += 1
            continue
        
        flag = False
        middle['groupBy'] = [[col_idxs[col_index][0],table_idxs[tab_index]]]
        added.append((col_idxs[col_index][1]))
    return (middle,added,True)

In [10]:
with open('preprocessed/context_ready.txt','r') as f:
    template_ori = f.readlines()
with open('preprocessed/alldata.json','r') as f:
    data_all = json.load(f)
# question_ori = [''] * len(data_all)
with open('preprocessed/final_generation.json','r') as f:
    question_ori = json.load(f)
template = []
index = 0
for item in template_ori:
    temp = {}
    line = item.strip()
    if len(line.split(' ||| ')) == 2:
        question,opt = line.split(' ||| ')
        constrain = None
    if len(line.split(' ||| ')) == 3:
        question,opt,constrain = line.split(' ||| ')
    temp['question'] = question
    temp['opt'] = eval(opt)
    if constrain != None:
        temp['constrain'] = eval(constrain)
    else:
        temp['constrain'] = None
    template.append(temp)
    index += 1

In [11]:
data_single = []
for index,item in enumerate(data_all):
    temp = {}
    temp['question'] = [question_ori[index]]
    if item['question'].count('[CLS]') == 0:
        temp['question'].append(item['question'])
    temp['middle'] = item['middle']
    temp['query'] = item['query']
    temp['db_id'] = item['db_id']
    data_single.append(temp)
data_db = {}
for item in data_single:
    if item['db_id'] not in data_db.keys():
        data_db[item['db_id']] = [item]
    else:
        data_db[item['db_id']].append(item)

In [12]:
def check_constrain(cur,template):
    if template['constrain'] != None:
        if 'compent' in template['constrain'].keys():
            for item in template['constrain']['compent']:
                if cur['middle'][item] == None:
                    return False
        if 'no_compent' in template['constrain'].keys():
            for item in template['constrain']['no_compent']:
                 if cur['middle'][item] != None:
                    return False
        if 'num_select' in template['constrain'].keys():
            if len(cur['middle']['select']) != template['constrain']['num_select']:
                return False
        if 'num_where' in template['constrain'].keys():
            if cur['middle']['where'] == None:
                return False
            if len(cur['middle']['where']) != template['constrain']['num_where']:
                return False

    return True

In [13]:
col_replace = ['{column0}','{column1}','{column2}']
agg_replace = ['{agg0}','{agg1}','{agg2}']
op_replace = ['{op0}','{op1}','{op2}']
value_replace = ['{value0}','{value1}','{value2}']
agg_trans = {'max':["maximum","maximum number","maximum amount","the largest","the highest"],\
 'min':["minimum","minimum number","minimum number","the smallest","the lowest",],\
 'count':["the number"],\
 'avg':["the average","the average number","the mean"],\
 'sum':["the total number","total count","sum of","in total","the number of all","the total amount"]}
op_trans = {
    "=": ["is", "was", "are", "equal"],
    ">": ["more than", "higher than", "above", "bigger than", "over", "above"],
    "<": ["less than", "lower than", "below", "small than", "under", "below"],
    ">=": ["no less than", "no lower than", "no below", "no small than", 
     "no under", "at least", "or more", "no below", "not less than", "not lower than", "not below",
      "not small than",  "not before", "not under", "not below"],
    "<=": ["no more than", "no higher than", "no above", "no bigger than",
     "no over", "no above", "at most", "or less", "not more than", "not higher than", "not above",
      "not bigger than",  "not after", "not over", "not above"],
    "!=": ["does not equal to", "is not", "that are not", "were not", "that was not", "is non", "excluding", "not including",
      "except for", "ignore", "outside", "other than"],
    "LIKE": ["has the substring", "contains the word", "contains", "containing", "have substring", "starts with", "have the letter", "includes the substring"]
}
def generation_one(cur,template):
    one = {}
    one['db_id'] = cur['db_id']
    flag = False
    col_index = 0
    agg_index = 0
    op_index = 0
    value_index = 0
    db_id = cur['db_id']
    one['question'] = template['question']
    if check_constrain(cur,template):
        middle = cur['middle']
        
        if 'replace_select' in template['opt'].keys():
            num = 1
            agg = ['']
            distinct = False
            if template['opt']['replace_select'] != None:
                num = len(template['opt']['replace_select'])
                agg = template['opt']['replace_select']
                
                distinct = [False] * num
            middle,col,isok = replace_select(middle,db_id,num ,agg,distinct)

            if isok == False:
                return (one,False)
            for idx in range(num):
                one['question'] = one['question'].replace(col_replace[col_index],col[idx])
                col_index += 1
                    
        if 'add_select' in template['opt'].keys():
            num = 1
            agg = ['']
            distinct = False
            if template['opt']['add_select'] != None:
                num = len(template['opt']['add_select'])
                agg = template['opt']['add_select']
                
                distinct = [False] * num
            middle,col,isok = add_select(middle,db_id,num ,agg,distinct)

            if isok == False:
                return (one,False)
            for idx in range(num):
                one['question'] = one['question'].replace(col_replace[col_index],col[idx])
                col_index += 1

        if 'replace_agg' in template['opt'].keys():
            num = 1
            if template['opt']['replace_agg'] != None:
                #todo
                pass
                
            middle,col,isok = replace_agg(middle,db_id,num)
            if isok == False:
                return (one,False)
            for _ in range(num):
                one['question'] = one['question'].replace(agg_replace[agg_index],random.sample(agg_trans[col[agg_index]],1)[0])
                agg_index += 1
        
        if 'replace_count' in template['opt'].keys():
            needcol = False
            num = 1
            if template['opt']['replace_count'] != None:
                needcol = True
            middle,col,isok = replace_count(middle,db_id,needcol)
            if isok == False:
                return (one,False)
            if needcol:
                for idx in range(num):
                    one['question'] = one['question'].replace(col_replace[col_index],col[idx])
                    col_index += 1

                
        if 'replace_where' in template['opt'].keys():
            num = 1
            agg = ['']
            distinct = False
            if template['opt']['replace_where'] != None:
                #todo
                pass
 
            middle,col,isok = replace_where(middle,db_id,num)
            if isok == False:
                return (one,False)
            for _ in range(num):
                one['question'] = one['question'].replace(col_replace[col_index],col[0][0])
                one['question'] = one['question'].replace(op_replace[col_index],random.sample(op_trans[col[0][1]],1)[0])
                one['question'] = one['question'].replace(value_replace[value_index],col[0][2])
                col_index += 1
                op_index += 1
                value_index += 1
                
        if 'add_where' in template['opt'].keys():
            num = 1
            if template['opt']['add_where'] != None:
                #todo
                pass
 
            middle,col,isok = add_where(middle,db_id,num)
            if isok == False:
                return (one,False)
            for _ in range(num):
                one['question'] = one['question'].replace(col_replace[col_index],col[0][0])
                one['question'] = one['question'].replace(op_replace[col_index],random.sample(op_trans[col[0][1]],1)[0])
                one['question'] = one['question'].replace(value_replace[value_index],col[0][2])
                col_index += 1
                op_index += 1
                value_index += 1
        
        if 'replace_op' in template['opt'].keys():
            num = 1
            if template['opt']['replace_op'] != None:
                #todo
                pass
 
            middle,col,isok = replace_op(middle,db_id,num)
            if isok == False:
                return (one,False)
            for _ in range(num):
                one['question'] = one['question'].replace(op_replace[col_index],random.sample(op_trans[col[0][0]],1)[0])
                one['question'] = one['question'].replace(value_replace[value_index],col[0][1])
                op_index += 1
                value_index += 1
       
        
        if 'add_orderby' in template['opt'].keys():
            asc = 'ASC'
            limit = ''
            if template['opt']['add_orderby'] != None:
                asc = template['opt']['add_orderby'][0][0]
                limit = template['opt']['add_orderby'][0][1]
 
            middle,col,isok = add_orderby(middle,db_id,asc,limit)
            if isok == False:
                return (one,False)
            one['question'] = one['question'].replace(col_replace[col_index],col[0])
            col_index += 1
        
        
        
        if 'replace_asc' in template['opt'].keys():
            asc = 'ASC'
            limit = ''
            if template['opt']['replace_asc'] != None:
                asc = template['opt']['replace_asc']

            middle,col,isok = replace_asc(middle,db_id,asc)
            if isok == False:
                return (one,False)
            
        if 'add_groupby' in template['opt'].keys():
            if template['opt']['add_groupby'] != None:
                #todo
                pass
 
            middle,col,isok = add_groupby(middle,db_id)
            if isok == False:
                return (one,False)
            one['question'] = one['question'].replace(col_replace[col_index],col[0])
            col_index += 1
        
        if 'add_having' in template['opt'].keys():
            asc = 'ASC'
            limit = ''
            if template['opt']['add_having'] != None:
                op = template['opt']['add_having'][0][0]
                value = template['opt']['add_having'][0][1]
 
            middle,col,isok = add_having(middle,db_id,op,value)
            if isok == False:
                return (one,False)
            
    else:
        return (one,False)
#     one['middle'] = middle
    one['middle'] = middle
    return (one,True)

In [14]:
def generation_multi(single,num,turn = 5):
    data_multi = []
#     first = random.sample(single,1)[0]
    first = copy.deepcopy(single)
    temp = copy.deepcopy(first)
    temp['relate'] = 0
    temp['question'] = random.sample(temp['question'],1)[0]
    data_multi.append(temp)
    index = 1
    index1 = 1
    total = 0
    his = [first['query']]
    relate = 1
    therelate = 0
    while total < turn:
        if random.random() > 0.4 and therelate < 3:
            therelate += 1
            flag = True
            while flag:
                if  index > num:
                    return (data_multi,False)
                cur_temp = random.sample(template,1)[0]
                one,isok = generation_one(first,cur_temp)
                
                if not isok:
                    index += 1
                    continue
                one['query'] = translate_sql(one['middle'],0,one['db_id'],db_reverse)
                if one['query'] in his:
#                 print(get_struct_similarity(one['query'], his))
                # if 1 in get_struct_similarity(one['query'], his):
                    # print('chongfu')
                    index += 1
                    continue

#                 print(one)
                temp = copy.deepcopy(one)
                temp['relate'] = relate
                data_multi.append(temp)
                his.append(one['query'])
#                 print(total)
                total += 1
                flag = False
                if random.random() > 0.05:
                    first = one
                    relate = 1
                else:
                    relate = relate + 1
        else:
            therelate = 0
            flag = True
            while flag:
                if  index1 > num:
                    flag = False
                second = copy.deepcopy(random.sample(data_db[first['db_id']],1)[0])
#                 print(get_struct_similarity(one['query'], his))
                # if second['query'] == first['query']:
                if second['query'] in his:
                    # print('chongfu')
                    index1 += 1
                    continue
                temp = copy.deepcopy(second)
                temp['relate'] = 0
                temp['question'] = random.sample(temp['question'],1)[0]
                data_multi.append(temp)
                his.append(second['query'])
                total += 1
                flag = False
                if random.random() > 0.05:
                    first = second
                    relate = 1
                else:
                    relate = relate + 1
    return (data_multi,True)

In [15]:
from tqdm import tqdm
session = []
for _ in range(10):
    for item in tqdm(data_single):
        second = copy.deepcopy(item)
        try:
            one,_ = generation_multi(second,1000,4)
        except:
    #         print(item)
            pass
        session.append(one)

100%|██████████| 15281/15281 [00:09<00:00, 1561.61it/s]
100%|██████████| 15281/15281 [00:09<00:00, 1620.37it/s]
100%|██████████| 15281/15281 [00:09<00:00, 1580.90it/s]
100%|██████████| 15281/15281 [00:09<00:00, 1542.60it/s]
100%|██████████| 15281/15281 [00:08<00:00, 1738.99it/s]
100%|██████████| 15281/15281 [00:10<00:00, 1425.53it/s]
100%|██████████| 15281/15281 [00:08<00:00, 1742.75it/s]
100%|██████████| 15281/15281 [00:10<00:00, 1430.48it/s]
100%|██████████| 15281/15281 [00:08<00:00, 1740.96it/s]
100%|██████████| 15281/15281 [00:11<00:00, 1383.69it/s]


In [16]:
with open('raw_data/tables.json','r') as f:
    table = json.load(f)
db = {}
for item in table:
    temp = {}
    temp['column_names'] = item['column_names']
    temp['column_types'] = item['column_types']
    temp['table_names'] = item['table_names']
    temp['column_names_original'] = item['column_names_original']
    temp['table_names_original'] = item['table_names_original']
    
    table2columns = [[] for _ in range(len(item['table_names']))] # from table id to column ids list
    for col_id, col in enumerate(item['column_names']):
        if col_id == 0: continue
        table2columns[col[0]].append(col_id)
    for col,typ in zip(item['column_names'],item['column_types']):
        col.append(typ)
    temp['table2columns'] = table2columns
    db[item['db_id']] = temp
db_reverse = {}
db_c_ori2pre = {}
db_t_ori2pre = {}
for k,v in db.items():
    db_reverse[k] = {}
    column = []
    table = []
    db_c_ori2pre[k] = {}
    db_t_ori2pre[k] = {}
    for o,c in zip(v['column_names'],v['column_names_original']):
        column.append([o[0],o[1],c[1]])
        db_c_ori2pre[k][c[1]] = o[1]
    db_reverse[k]['column'] = column
    for o,c in zip(v['table_names'],v['table_names_original']):
        table.append([o,c])
        db_t_ori2pre[k][c] = o
    db_reverse[k]['table'] = table

In [18]:
from grakel.kernels import WeisfeilerLehman, VertexHistogram
import sqlparse
from sqlparse.tokens import Token

def similarity_compute_(s1,s2):
    s1_ = s1.split(" ")
    s2_ = s2.split(" ")
    l_min = min(len(s1_),len(s2_))
    l_max = max(len(s1_),len(s2_))
    return sum([1 if s1_[i] in s2_ else 0 for i in range(l_min)]) / l_max


def similarity_compute(l1,l2):
    assert len(l1)==len(l2)
    l = sum([0 if l1[i]=='' and l2[i]=='' else 1 for i in range(len(l1))])
    return sum([similarity_compute_(l1[i],l2[i]) if l1[i]!='' and l2[i]!='' else 0 for i in range(len(l1))]) / l


def get_wl_score(samples_list):
    gk = WeisfeilerLehman(n_iter=5, base_graph_kernel=VertexHistogram, normalize=True)
    K_samples = gk.fit_transform(samples_list)
    return K_samples[0][1:]

def _pprint_tree(sql_node_label, sql_depth_label, sql_graph, tokens, max_depth=None, depth=0, f=None, _pre=''):  # Deepth First Search
    """Pretty-print the object tree."""
    token_count = len(tokens)
    for idx, token in enumerate(tokens):
        if token.ttype in [Token.Text.Whitespace, Token.Name]: #, Token.Punctuation, Token.Name, Token.Wildcard]:
            continue
        if token.value not in sql_node_label.keys():
            sql_node_label[token.value] = len(sql_node_label.items())
        sql_graph.append((sql_node_label[sql_depth_label[depth-1]], sql_node_label[token.value]))
        sql_depth_label[depth] = token.value
        cls = token._get_repr_name()
        value = token._get_repr_value()

        last = idx == (token_count - 1)
        pre = '`- ' if last else '|- '

        q = '"' if value.startswith("'") and value.endswith("'") else "'"
        # print("{_pre}{pre}{idx} {cls} {q}{value}{q}"
        #         .format(**locals()), file=f)
        # print("{_pre}{pre}{depth} {value}"
        #         .format(**locals()))
        if token.is_group and (max_depth is None or depth < max_depth):
            parent_pre = '   ' if last else '|  '
            _pprint_tree(sql_node_label, sql_depth_label, sql_graph, token.tokens,max_depth, depth + 1, f, _pre + parent_pre)

    return sql_node_label, sql_depth_label, sql_graph
    
def get_struct_similarity(curt_sql, hist_sql):
    sql_list = [curt_sql] + hist_sql
    sql_node_label = {'root':0}
    samples_list_ = []
    samples_list = []
    for sl in sql_list:
        sample = []
        sql_depth_label = {-1:'root'}
        sql_graph = []
        sql_node_label, sql_depth_label, sql_graph = _pprint_tree(sql_node_label, sql_depth_label, sql_graph, sqlparse.parse(sl)[0].tokens)
        sample.append(set(sql_graph))

        sql_edge_label = dict([(v,0) for v in set(sql_graph)])
        sample.append(sql_edge_label)
        samples_list_.append(sample)
    for sl in samples_list_:
        sl.insert(1,  dict([(idx,sql_node_label[v]) for idx, v in enumerate(sql_node_label.keys())]))
        # sl.insert(1,  dict([(idx, 0) for idx, v in enumerate(sql_node_label.keys())]))
        samples_list.append(sl)
    # print(sql_depth_label)
    # print(samples_list)
    # print(sql_node_label)
    wl_score = get_wl_score(samples_list)
    return list(wl_score)


# score = get_struct_similarity(sql, hist_sql)
# print(score)
def compute_content(cur_label,his_label):
    result = []
    for item in his_label:
        result.append(similarity_compute(cur_label,item))
    return result

In [19]:
schemas, db_names, tables = get_schemas_from_json('raw_data/tables.json')
from tqdm import tqdm
data_final = []
for idx,item in enumerate(tqdm(session)):
    last = []
    last_sql = []
    last_label = []
    last_clabel = []
    last_tlabel = []
    relate = []
    theturn = []
    flag = True
    for turn in item:
        if flag == False:
            continue
        temp = {}
        last.append(turn['question'])
        last_sql.append(turn['query'])
        relate.append(turn['relate'])
        
        db_id = turn['db_id']
        schema = schemas[db_id]
        table = tables[db_id]
        schema = Schema(schema, table)
        try:
            sql_label = get_sql(schema, turn['query'])
            col_label = get_label(sql_label,len(table['column_names_original']))
            table_label = get_table_label(sql_label,len(table['table_names']))

            label = col_label + table_label
            last_label.append(label)
            last_clabel.append(col_label)
            last_tlabel.append(table_label)

            temp['col_label'] = last_clabel[::-1][:5]
            temp['table_label'] = last_tlabel[::-1][:5]
            temp['sql'] = sql_label
            temp['db_id'] = turn['db_id']
            temp['query'] = turn['query']
            temp['question'] = ' [CLS] '.join(last[::-1][:5])
            temp['struct'] = []
            for sql in last_sql[::-1][:5]:
                temp['struct'].append(get_struct_similarity(temp['query'], [sql])[0])

            temp['content'] = compute_content(col_label, last_clabel[::-1][:5])
            temp['sim'] = [(temp['struct'][i]+temp['content'][i])/2 for i in range(len(temp['content']))]
            temp['relate'] = relate[::-1][:5]
            assert len(temp['content']) == temp['question'].count('[CLS]')+1
            theturn.append(temp)
        except:
            flag = False
            print(idx)
    if flag:
        data_final.append(theturn)

  0%|          | 243/152810 [00:17<2:32:37, 16.66it/s]

239


  1%|          | 1371/152810 [01:41<1:54:50, 21.98it/s]

1367


  1%|          | 1395/152810 [01:43<2:08:13, 19.68it/s]

1391


  2%|▏         | 2593/152810 [03:10<3:27:46, 12.05it/s]

2590


  2%|▏         | 2599/152810 [03:11<3:03:08, 13.67it/s]

2597


  2%|▏         | 2655/152810 [03:15<3:15:11, 12.82it/s]

2653


  2%|▏         | 2712/152810 [03:19<2:24:45, 17.28it/s]

2710


  2%|▏         | 2748/152810 [03:22<2:49:37, 14.74it/s]

2745


  2%|▏         | 2868/152810 [03:31<2:41:33, 15.47it/s]

2866


  2%|▏         | 2960/152810 [03:38<2:37:22, 15.87it/s]

2957


  3%|▎         | 3941/152810 [05:15<2:27:57, 16.77it/s]

3939


  4%|▍         | 5850/152810 [07:31<2:22:51, 17.15it/s]

5846


  4%|▍         | 6417/152810 [08:13<2:20:35, 17.35it/s]

6415


  5%|▍         | 7325/152810 [09:19<2:29:39, 16.20it/s]

7323


  5%|▍         | 7354/152810 [09:21<2:40:13, 15.13it/s]

7351


  5%|▌         | 8033/152810 [10:11<2:31:53, 15.89it/s]

8030


  5%|▌         | 8077/152810 [10:15<2:42:06, 14.88it/s]

8073


  6%|▌         | 8418/152810 [10:40<1:56:17, 20.69it/s]

8412
8414
8415


  6%|▌         | 8865/152810 [11:13<2:56:42, 13.58it/s]

8862


  6%|▌         | 9088/152810 [11:30<2:15:23, 17.69it/s]

9085
9086


  7%|▋         | 11228/152810 [14:03<1:48:33, 21.74it/s]

11223
11226
11227


  7%|▋         | 11237/152810 [14:04<2:06:02, 18.72it/s]

11232


  7%|▋         | 11243/152810 [14:04<1:56:52, 20.19it/s]

11239


  9%|▉         | 13645/152810 [17:00<2:16:47, 16.96it/s] 

13642
13643


  9%|▉         | 13651/152810 [17:01<2:26:35, 15.82it/s]

13649


  9%|▉         | 13658/152810 [17:01<2:17:37, 16.85it/s]

13654


  9%|▉         | 13665/152810 [17:01<2:10:08, 17.82it/s]

13661
13663


  9%|▉         | 14267/152810 [17:44<2:55:22, 13.17it/s]

14265


  9%|▉         | 14316/152810 [17:48<2:27:01, 15.70it/s]

14312
14313


  9%|▉         | 14448/152810 [17:57<2:02:19, 18.85it/s]

14444


 10%|▉         | 15094/152810 [18:39<2:08:57, 17.80it/s]

15091
15092


 11%|█         | 16656/152810 [20:33<1:39:09, 22.89it/s]

16653


 12%|█▏        | 18028/152810 [22:12<2:47:52, 13.38it/s]

18026


 12%|█▏        | 18121/152810 [22:18<2:03:29, 18.18it/s]

18116


 13%|█▎        | 20026/152810 [25:01<2:08:50, 17.18it/s]

20023


 14%|█▍        | 21141/152810 [26:20<1:58:48, 18.47it/s]

21139


 14%|█▍        | 21162/152810 [26:21<2:30:37, 14.57it/s]

21160


 15%|█▍        | 22434/152810 [27:51<2:02:56, 17.67it/s]

22432


 15%|█▍        | 22609/152810 [28:04<1:55:47, 18.74it/s]

22604


 15%|█▌        | 23137/152810 [28:42<2:35:33, 13.89it/s]

23134


 15%|█▌        | 23315/152810 [28:55<2:08:13, 16.83it/s]

23311


 15%|█▌        | 23357/152810 [28:58<2:25:41, 14.81it/s]

23355


 16%|█▌        | 23695/152810 [29:22<2:09:52, 16.57it/s]

23693


 16%|█▌        | 24146/152810 [29:56<2:15:19, 15.85it/s]

24143


 17%|█▋        | 25941/152810 [32:03<2:32:12, 13.89it/s]

25938


 17%|█▋        | 26512/152810 [32:45<2:16:20, 15.44it/s]

26508


 17%|█▋        | 26528/152810 [32:46<1:59:27, 17.62it/s]

26523


 19%|█▊        | 28418/152810 [35:01<1:39:57, 20.74it/s]

28414


 19%|█▉        | 28918/152810 [35:39<1:57:08, 17.63it/s]

28915
28916


 19%|█▉        | 28929/152810 [35:39<1:47:09, 19.27it/s]

28926
28929


 19%|█▉        | 28938/152810 [35:40<1:53:51, 18.13it/s]

28936


 19%|█▉        | 29564/152810 [36:24<2:41:32, 12.72it/s]

29562


 19%|█▉        | 29592/152810 [36:26<2:49:43, 12.10it/s]

29589


 19%|█▉        | 29596/152810 [36:26<2:22:05, 14.45it/s]

29593


 20%|█▉        | 29873/152810 [36:44<1:45:34, 19.41it/s]

29870


 20%|█▉        | 30374/152810 [37:16<1:45:45, 19.30it/s]

30373


 21%|██▏       | 32600/152810 [40:00<2:16:25, 14.68it/s] 

32596


 22%|██▏       | 33275/152810 [40:51<2:13:39, 14.91it/s]

33272


 22%|██▏       | 33308/152810 [40:54<2:11:40, 15.13it/s]

33307


 22%|██▏       | 33399/152810 [41:01<2:08:04, 15.54it/s]

33397


 23%|██▎       | 35591/152810 [44:06<1:52:42, 17.33it/s]

35587


 23%|██▎       | 35631/152810 [44:08<1:53:22, 17.22it/s]

35628


 24%|██▍       | 36338/152810 [44:59<1:49:03, 17.80it/s]

36335


 24%|██▍       | 36423/152810 [45:05<1:55:41, 16.77it/s]

36419
36420


 25%|██▍       | 37599/152810 [46:30<1:42:55, 18.66it/s]

37596


 25%|██▌       | 38596/152810 [47:43<1:56:03, 16.40it/s]

38591


 26%|██▌       | 38979/152810 [48:11<2:04:13, 15.27it/s]

38977


 26%|██▌       | 39425/152810 [48:44<2:10:32, 14.48it/s]

39422


 26%|██▌       | 39740/152810 [49:09<1:57:08, 16.09it/s]

39736
39737


 27%|██▋       | 41789/152810 [51:36<1:39:15, 18.64it/s]

41786
41787
41789


 27%|██▋       | 41809/152810 [51:38<1:42:33, 18.04it/s]

41805


 27%|██▋       | 41867/152810 [51:42<2:28:24, 12.46it/s]

41866


 29%|██▉       | 44202/152810 [54:28<1:39:59, 18.10it/s]

44197


 29%|██▉       | 44218/152810 [54:30<1:44:30, 17.32it/s]

44213
44217


 29%|██▉       | 44876/152810 [55:15<2:08:19, 14.02it/s]

44874


 30%|██▉       | 45154/152810 [55:34<1:36:00, 18.69it/s]

45151


 30%|███       | 46010/152810 [56:31<1:44:33, 17.02it/s]

46008


 31%|███▏      | 47852/152810 [58:46<1:56:34, 15.01it/s]

47850


 31%|███▏      | 47879/152810 [58:48<1:52:38, 15.52it/s]

47877


 32%|███▏      | 48499/152810 [59:34<1:49:49, 15.83it/s]

48496


 32%|███▏      | 48556/152810 [59:39<1:47:12, 16.21it/s]

48553


 32%|███▏      | 48682/152810 [59:48<1:40:27, 17.28it/s]

48678


 32%|███▏      | 48801/152810 [59:57<1:36:55, 17.88it/s]

48800


 32%|███▏      | 49508/152810 [1:01:16<1:49:48, 15.68it/s]

49504


 33%|███▎      | 50248/152810 [1:02:09<2:03:55, 13.79it/s]

50246


 34%|███▎      | 51205/152810 [1:03:17<1:53:21, 14.94it/s]

51203


 34%|███▍      | 51703/152810 [1:03:51<1:21:43, 20.62it/s]

51701


 34%|███▍      | 51830/152810 [1:04:00<2:01:29, 13.85it/s]

51827


 34%|███▍      | 51997/152810 [1:04:12<1:55:35, 14.54it/s]

51993


 35%|███▍      | 52879/152810 [1:05:15<1:47:22, 15.51it/s]

52876


 35%|███▍      | 53169/152810 [1:05:36<1:34:26, 17.58it/s]

53166


 35%|███▍      | 53213/152810 [1:05:39<1:53:20, 14.65it/s]

53211


 35%|███▌      | 53917/152810 [1:06:30<1:37:10, 16.96it/s]

53915


 36%|███▌      | 54257/152810 [1:07:01<1:49:16, 15.03it/s] 

54255


 36%|███▌      | 54708/152810 [1:07:34<1:39:27, 16.44it/s]

54704


 36%|███▌      | 55021/152810 [1:07:58<1:39:19, 16.41it/s]

55017
55018
55019


 36%|███▌      | 55059/152810 [1:08:01<1:45:46, 15.40it/s]

55056


 37%|███▋      | 57075/152810 [1:10:24<1:25:43, 18.61it/s]

57070
57071
57074


 37%|███▋      | 57083/152810 [1:10:25<1:26:50, 18.37it/s]

57081
57084


 37%|███▋      | 57157/152810 [1:10:30<2:04:48, 12.77it/s]

57156


 39%|███▉      | 59480/152810 [1:13:16<1:37:55, 15.88it/s]

59477


 39%|███▉      | 59485/152810 [1:13:17<1:26:38, 17.95it/s]

59482
59485


 39%|███▉      | 59500/152810 [1:13:18<1:47:08, 14.52it/s]

59497


 39%|███▉      | 59505/152810 [1:13:18<1:32:36, 16.79it/s]

59501


 39%|███▉      | 60119/152810 [1:14:00<2:15:56, 11.36it/s]

60118


 39%|███▉      | 60129/152810 [1:14:01<2:04:42, 12.39it/s]

60125
60126


 40%|███▉      | 60432/152810 [1:14:22<1:13:54, 20.83it/s]

60427


 40%|███▉      | 60489/152810 [1:14:25<1:16:23, 20.14it/s]

60486


 40%|███▉      | 60938/152810 [1:14:55<1:24:54, 18.04it/s]

60934


 40%|████      | 61349/152810 [1:15:24<1:46:07, 14.36it/s]

61346


 41%|████      | 62506/152810 [1:16:50<1:02:53, 23.93it/s]

62504


 41%|████      | 62590/152810 [1:16:55<1:28:24, 17.01it/s]

62586


 42%|████▏     | 63872/152810 [1:18:29<1:47:41, 13.76it/s]

63869


 42%|████▏     | 63896/152810 [1:18:31<1:55:19, 12.85it/s]

63894


 43%|████▎     | 66194/152810 [1:21:43<1:25:49, 16.82it/s]

66190


 44%|████▍     | 67837/152810 [1:23:43<1:33:36, 15.13it/s]

67834


 45%|████▍     | 68160/152810 [1:24:05<1:25:00, 16.60it/s]

68157
68159


 45%|████▌     | 69201/152810 [1:25:20<1:11:33, 19.47it/s]

69197
69198


 46%|████▌     | 69539/152810 [1:25:45<1:35:20, 14.56it/s]

69536


 46%|████▌     | 69760/152810 [1:26:01<1:35:18, 14.52it/s]

69757


 46%|████▌     | 69988/152810 [1:26:18<1:23:33, 16.52it/s]

69984


 47%|████▋     | 72352/152810 [1:29:09<1:08:01, 19.71it/s]

72348
72351


 47%|████▋     | 72359/152810 [1:29:10<1:08:58, 19.44it/s]

72355
72356


 47%|████▋     | 72369/152810 [1:29:10<1:10:22, 19.05it/s]

72364


 48%|████▊     | 72961/152810 [1:29:54<1:05:09, 20.43it/s]

72957


 49%|████▉     | 74761/152810 [1:32:03<1:08:47, 18.91it/s]

74757
74759


 49%|████▉     | 74776/152810 [1:32:03<1:17:35, 16.76it/s]

74773


 49%|████▉     | 75408/152810 [1:32:48<1:36:27, 13.37it/s]

75406
75407


 49%|████▉     | 75418/152810 [1:32:49<1:46:14, 12.14it/s]

75415


 49%|████▉     | 75429/152810 [1:32:50<1:35:22, 13.52it/s]

75427
75429


 50%|████▉     | 76223/152810 [1:33:41<1:06:34, 19.17it/s]

76219


 51%|█████     | 78261/152810 [1:36:09<1:19:59, 15.53it/s]

78257
78258


 52%|█████▏    | 79117/152810 [1:37:13<1:14:52, 16.41it/s]

79115


 52%|█████▏    | 79214/152810 [1:37:20<1:18:52, 15.55it/s]

79210


 52%|█████▏    | 79274/152810 [1:37:24<1:17:40, 15.78it/s]

79271


 52%|█████▏    | 79977/152810 [1:38:42<2:07:26,  9.53it/s]

79974
79975


 54%|█████▎    | 81767/152810 [1:40:57<1:26:07, 13.75it/s] 

81764


 54%|█████▍    | 82267/152810 [1:41:34<1:04:42, 18.17it/s]

82263


 55%|█████▍    | 83559/152810 [1:43:06<1:17:43, 14.85it/s]

83555
83557


 56%|█████▌    | 84821/152810 [1:44:37<56:06, 20.19it/s]  

84816
84817
84818
84819
84820


 56%|█████▌    | 85273/152810 [1:45:11<1:11:43, 15.69it/s]

85269


 56%|█████▌    | 85494/152810 [1:45:29<1:28:53, 12.62it/s]

85491


 56%|█████▌    | 85583/152810 [1:45:35<1:14:59, 14.94it/s]

85579
85580


 57%|█████▋    | 87632/152810 [1:48:01<51:57, 20.91it/s]  

87628
87630


 59%|█████▉    | 90046/152810 [1:50:55<53:52, 19.42it/s]  

90042


 59%|█████▉    | 90051/152810 [1:50:55<59:03, 17.71it/s]  

90048


 59%|█████▉    | 90055/152810 [1:50:56<1:02:04, 16.85it/s]

90051
90054


 59%|█████▉    | 90852/152810 [1:51:52<48:51, 21.13it/s]  

90847


 60%|█████▉    | 90992/152810 [1:52:01<45:52, 22.46it/s]  

90989


 60%|█████▉    | 90998/152810 [1:52:01<50:43, 20.31it/s]

90996


 60%|█████▉    | 91014/152810 [1:52:02<50:30, 20.39it/s]  

91011


 60%|█████▉    | 91300/152810 [1:52:20<54:30, 18.81it/s]  

91296


 60%|██████    | 91929/152810 [1:53:03<1:05:33, 15.48it/s]

91925


 61%|██████    | 93059/152810 [1:54:27<45:19, 21.98it/s]  

93053


 61%|██████    | 93068/152810 [1:54:28<42:15, 23.56it/s]

93063


 62%|██████▏   | 94351/152810 [1:55:59<56:41, 17.19it/s]  

94349


 62%|██████▏   | 94645/152810 [1:56:21<1:06:03, 14.67it/s]

94643


 63%|██████▎   | 95628/152810 [1:57:58<57:16, 16.64it/s]  

95625
95626


 64%|██████▍   | 97547/152810 [2:00:14<48:03, 19.17it/s]  

97544


 64%|██████▍   | 98080/152810 [2:00:52<57:11, 15.95it/s]  

98076
98077


 64%|██████▍   | 98398/152810 [2:01:15<1:08:48, 13.18it/s]

98395


 65%|██████▍   | 98722/152810 [2:01:37<51:40, 17.44it/s]  

98719


 65%|██████▍   | 99013/152810 [2:01:58<53:15, 16.83it/s]  

99009


 65%|██████▌   | 99718/152810 [2:02:50<1:02:19, 14.20it/s]

99715


 65%|██████▌   | 99980/152810 [2:03:09<1:05:54, 13.36it/s]

99976


 65%|██████▌   | 100064/152810 [2:03:15<59:09, 14.86it/s]  

100062


 66%|██████▌   | 100104/152810 [2:03:18<48:41, 18.04it/s]  

100100
100101
100104


 66%|██████▌   | 100607/152810 [2:03:56<56:53, 15.29it/s]  

100603
100604


 66%|██████▌   | 100755/152810 [2:04:07<1:07:04, 12.93it/s]

100752


 67%|██████▋   | 102918/152810 [2:06:42<51:57, 16.01it/s]  

102915


 67%|██████▋   | 102923/152810 [2:06:42<47:15, 17.59it/s]

102920


 67%|██████▋   | 102928/152810 [2:06:43<47:49, 17.38it/s]

102926


 67%|██████▋   | 102992/152810 [2:06:47<1:07:00, 12.39it/s]

102990


 69%|██████▉   | 105323/152810 [2:09:34<54:12, 14.60it/s]  

105321
105324


 69%|██████▉   | 105969/152810 [2:10:20<1:03:17, 12.33it/s]

105967


 69%|██████▉   | 106126/152810 [2:10:31<41:11, 18.89it/s]  

106123


 69%|██████▉   | 106132/152810 [2:10:31<39:52, 19.51it/s]

106129


 69%|██████▉   | 106142/152810 [2:10:32<44:19, 17.55it/s]

106140


 70%|██████▉   | 106278/152810 [2:10:40<39:45, 19.50it/s]  

106276


 70%|██████▉   | 106779/152810 [2:11:13<49:42, 15.43it/s]  

106776
106779


 70%|███████   | 107210/152810 [2:11:43<52:38, 14.44it/s]  

107206


 70%|███████   | 107356/152810 [2:11:54<53:02, 14.28it/s]  

107354


 70%|███████   | 107417/152810 [2:11:58<47:28, 15.94it/s]  

107414


 71%|███████   | 108343/152810 [2:13:08<36:25, 20.35it/s]  

108338


 71%|███████   | 108349/152810 [2:13:09<33:57, 21.82it/s]

108345


 71%|███████   | 108362/152810 [2:13:09<38:47, 19.09it/s]

108359


 71%|███████   | 108823/152810 [2:13:40<46:54, 15.63it/s]  

108819


 71%|███████   | 108857/152810 [2:13:43<52:53, 13.85it/s]  

108854


 71%|███████▏  | 109248/152810 [2:14:12<50:45, 14.30it/s]  

109246


 72%|███████▏  | 109295/152810 [2:14:15<51:49, 13.99it/s]  

109294


 72%|███████▏  | 109623/152810 [2:14:40<53:03, 13.56it/s]  

109620


 72%|███████▏  | 109680/152810 [2:14:45<47:12, 15.22it/s]  

109677


 72%|███████▏  | 109798/152810 [2:14:54<53:01, 13.52it/s]  

109796


 72%|███████▏  | 109836/152810 [2:14:56<50:26, 14.20it/s]

109833


 72%|███████▏  | 110631/152810 [2:16:22<46:24, 15.15it/s]  

110628


 74%|███████▍  | 112828/152810 [2:19:09<36:52, 18.07it/s]   

112825


 74%|███████▍  | 112956/152810 [2:19:18<39:57, 16.62it/s]

112951


 75%|███████▍  | 114121/152810 [2:20:43<38:49, 16.61it/s]  

114118


 75%|███████▍  | 114295/152810 [2:20:56<40:43, 15.76it/s]  

114291


 75%|███████▌  | 115043/152810 [2:21:50<40:41, 15.47it/s]  

115041


 75%|███████▌  | 115259/152810 [2:22:06<47:01, 13.31it/s]  

115256


 76%|███████▌  | 115385/152810 [2:22:15<37:20, 16.70it/s]  

115382


 76%|███████▌  | 115832/152810 [2:22:49<38:02, 16.20it/s]  

115828


 77%|███████▋  | 118192/152810 [2:25:40<29:40, 19.45it/s]  

118190


 77%|███████▋  | 118210/152810 [2:25:41<28:14, 20.42it/s]

118207


 77%|███████▋  | 118293/152810 [2:25:47<42:16, 13.61it/s]

118290


 79%|███████▉  | 120609/152810 [2:28:36<32:57, 16.28it/s]  

120605


 79%|███████▉  | 120620/152810 [2:28:36<33:49, 15.86it/s]

120617


 79%|███████▉  | 120627/152810 [2:28:37<27:27, 19.53it/s]

120625


 79%|███████▉  | 121254/152810 [2:29:21<41:39, 12.63it/s]

121250


 79%|███████▉  | 121281/152810 [2:29:23<39:34, 13.28it/s]

121279


 79%|███████▉  | 121421/152810 [2:29:32<28:56, 18.08it/s]

121419
121420


 80%|███████▉  | 121688/152810 [2:29:50<32:09, 16.13it/s]

121685


 80%|███████▉  | 122062/152810 [2:30:15<34:11, 14.99it/s]

122059


 80%|████████  | 122415/152810 [2:30:39<32:18, 15.68it/s]

122413


 81%|████████  | 123312/152810 [2:31:45<34:47, 14.13it/s]  

123309


 81%|████████  | 124104/152810 [2:32:43<31:58, 14.96it/s]  

124100


 82%|████████▏ | 124661/152810 [2:33:24<31:38, 14.83it/s]

124658


 82%|████████▏ | 124904/152810 [2:33:43<37:09, 12.52it/s]

124901


 82%|████████▏ | 124995/152810 [2:33:50<34:51, 13.30it/s]

124993


 82%|████████▏ | 125207/152810 [2:34:06<30:25, 15.12it/s]

125205


 83%|████████▎ | 127180/152810 [2:36:54<22:16, 19.18it/s]  

127178


 84%|████████▍ | 128109/152810 [2:38:00<22:22, 18.40it/s]

128106


 85%|████████▍ | 129403/152810 [2:39:34<23:16, 16.76it/s]

129399
129400


 85%|████████▍ | 129576/152810 [2:39:47<24:50, 15.59it/s]

129572


 85%|████████▌ | 130163/152810 [2:40:30<29:24, 12.83it/s]

130161


 85%|████████▌ | 130324/152810 [2:40:41<23:50, 15.72it/s]

130321


 85%|████████▌ | 130543/152810 [2:40:58<27:12, 13.64it/s]

130539


 86%|████████▌ | 130667/152810 [2:41:06<19:23, 19.04it/s]

130662


 86%|████████▌ | 130696/152810 [2:41:08<25:26, 14.49it/s]

130693


 86%|████████▌ | 131427/152810 [2:42:03<22:41, 15.70it/s]

131423


 87%|████████▋ | 133475/152810 [2:44:31<17:21, 18.56it/s]

133471
133472


 87%|████████▋ | 133484/152810 [2:44:32<17:35, 18.31it/s]

133480


 87%|████████▋ | 133491/152810 [2:44:32<17:27, 18.44it/s]

133488
133489


 89%|████████▉ | 135885/152810 [2:47:24<17:46, 15.88it/s]

135881
135883


 89%|████████▉ | 135893/152810 [2:47:24<17:36, 16.02it/s]

135890
135893


 89%|████████▉ | 135907/152810 [2:47:25<16:55, 16.65it/s]

135904


 89%|████████▉ | 135914/152810 [2:47:26<14:49, 18.99it/s]

135911


 89%|████████▉ | 136688/152810 [2:48:20<14:00, 19.18it/s]

136685


 90%|████████▉ | 136840/152810 [2:48:29<13:26, 19.80it/s]

136837


 91%|█████████ | 139007/152810 [2:51:00<11:37, 19.79it/s]

139005


 91%|█████████ | 139384/152810 [2:51:27<13:59, 15.99it/s]

139381


 92%|█████████▏| 140242/152810 [2:52:31<12:34, 16.67it/s]

140239


 92%|█████████▏| 140367/152810 [2:52:40<11:40, 17.77it/s]

140364


 92%|█████████▏| 140489/152810 [2:52:49<14:31, 14.14it/s]

140486


 92%|█████████▏| 141341/152810 [2:54:17<10:37, 17.98it/s]

141337


 95%|█████████▍| 144564/152810 [2:58:05<08:40, 15.83it/s]

144562


 95%|█████████▍| 144682/152810 [2:58:14<09:09, 14.78it/s]

144680


 95%|█████████▍| 144857/152810 [2:58:27<08:16, 16.01it/s]

144853


 95%|█████████▌| 145606/152810 [2:59:22<06:23, 18.80it/s]

145603


 96%|█████████▌| 146706/152810 [3:00:44<06:46, 15.01it/s]

146703


 96%|█████████▌| 146813/152810 [3:00:51<07:36, 13.15it/s]

146811


 97%|█████████▋| 148188/152810 [3:02:39<04:38, 16.63it/s]  

148184


 97%|█████████▋| 148755/152810 [3:03:21<03:19, 20.38it/s]

148752


 97%|█████████▋| 148760/152810 [3:03:22<03:50, 17.55it/s]

148758


 99%|█████████▉| 151169/152810 [3:06:14<01:40, 16.35it/s]

151166


 99%|█████████▉| 151174/152810 [3:06:15<01:40, 16.28it/s]

151171


 99%|█████████▉| 151181/152810 [3:06:15<01:59, 13.67it/s]

151178


 99%|█████████▉| 151814/152810 [3:07:00<01:25, 11.67it/s]

151811


 99%|█████████▉| 151845/152810 [3:07:02<01:02, 15.32it/s]

151841
151842


 99%|█████████▉| 151970/152810 [3:07:11<00:42, 19.68it/s]

151967


 99%|█████████▉| 151992/152810 [3:07:12<00:47, 17.05it/s]

151988


100%|█████████▉| 152252/152810 [3:07:27<00:28, 19.85it/s]

152249


100%|█████████▉| 152596/152810 [3:07:50<00:15, 13.88it/s]

152593


100%|██████████| 152810/152810 [3:08:05<00:00, 13.54it/s]


In [20]:
total = 0
temp = {}
final = []
for turn in data_final:
    flag = True
    for idx,item in enumerate(turn):
        if item['question'] in temp.keys() and idx != 0:
            flag = False
            continue
        temp[item['question']] = total
        total += 1
    if flag:
        final.append(turn)

In [21]:
import numpy as np 
for item in final:
    temp = []
    for iitem in item:
        temp.append(iitem['sim'][::-1][:-1] + [0]*(5-len(iitem['sim'][:-1])))
    temp = (np.array(temp).T + np.array(temp) + np.eye(5)).tolist()
    for idx,iitem in enumerate(item):
        iitem['final_sim'] = temp[idx]

In [28]:
label_dict = {'':0}
for item in final:
    for iitem in item:
        for col in iitem['col_label'][0]:
            if col not in label_dict.keys():
                label_dict[col] = 1
            else:
                label_dict[col] += 1
count = 0
thelabel = []
finallabel = {}
convert = {}
index = 0
label_dict = dict(sorted(label_dict.items(), key=lambda x: x[1], reverse=True))
for k,v in label_dict.items():
    count += 1
    thelabel.append(k)
    finallabel[k] = v
    convert[k] = index
    index += 1
    if count == 383:
        break
theall = []
for item in final:
    flag = True
    for iitem in item:
        for col in iitem['col_label'][0]:
            if col not in thelabel:
                flag = False
    if flag:
        theall.append(item)

In [29]:
len(theall)

64609

In [32]:
count

323

In [33]:
convert

{'': 0,
 'SELECT COUNT': 1,
 'GROUP_BY': 2,
 'SELECT': 3,
 'GROUP_BY ORDER_BY ASC LIMIT': 4,
 'GROUP_BY ORDER_BY DESC LIMIT': 5,
 'WHERE OR =': 6,
 'ORDER_BY COUNT DESC LIMIT': 7,
 'WHERE =': 8,
 'SELECT WHERE >=': 9,
 'SELECT COUNT HAVING COUNT >=': 10,
 'SELECT GROUP_BY': 11,
 'SELECT GROUP_BY ORDER_BY DESC LIMIT': 12,
 'SELECT WHERE =': 13,
 'SELECT AVG': 14,
 'WHERE >=': 15,
 'HAVING COUNT >': 16,
 'SELECT COUNT GROUP_BY': 17,
 'SELECT SUM GROUP_BY': 18,
 'SELECT MIN': 19,
 'SELECT MIN SELECT': 20,
 'SELECT SELECT COUNT': 21,
 'WHERE LIKE': 22,
 'SELECT EXCEPT SELECT': 23,
 'EXCEPT WHERE =': 24,
 'ORDER_BY DESC LIMIT': 25,
 'SELECT ORDER_BY DESC': 26,
 'SELECT ORDER_BY DESC LIMIT': 27,
 'WHERE IN': 28,
 'OP_SEL SELECT': 29,
 'WHERE !=': 30,
 'SELECT ORDER_BY COUNT DESC LIMIT': 31,
 'HAVING COUNT >=': 32,
 'SELECT COUNT ORDER_BY DESC LIMIT': 33,
 'EXCEPT SELECT': 34,
 'SELECT COUNT ORDER_BY COUNT DESC LIMIT': 35,
 'SELECT SELECT': 36,
 'ORDER_BY ASC LIMIT': 37,
 'SELECT ORDER_BY ASC

In [41]:
def get_position_ids(ex, shuffle=True):
    # cluster columns with their corresponding table and randomly shuffle tables and columns
    # [CLS] q1 q2 ... [SEP] * t1 c1 c2 c3 t2 c4 c5 ... [SEP]
    db, table_word_len, column_word_len = ex['db'], ex['table_word_len'], ex['column_word_len']
    table_num, column_num = len(db['table_names']), len(db['column_names'])
    question_position_id = list(range(len(ex['question_id'])))
    start = len(question_position_id)
    table_position_id, column_position_id = [None] * table_num, [None] * column_num
    column_position_id[0] = list(range(start, start + column_word_len[0]))
    start += column_word_len[0] # special symbol * first
    table_idxs = list(range(table_num))
    if shuffle:
        random.shuffle(table_idxs)
    for idx in table_idxs:
        col_idxs = db['table2columns'][idx]
        table_position_id[idx] = list(range(start, start + table_word_len[idx]))
        start += table_word_len[idx]
        if shuffle:
            random.shuffle(col_idxs)
        for col_id in col_idxs:
            column_position_id[col_id] = list(range(start, start + column_word_len[col_id]))
            start += column_word_len[col_id]
    position_id = question_position_id + list(chain.from_iterable(table_position_id)) + \
        list(chain.from_iterable(column_position_id)) + [start]
    assert len(position_id) == len(ex['input_id'])
    return position_id

In [43]:
from transformers import AutoTokenizer
from itertools import chain
import numpy as np
t = AutoTokenizer.from_pretrained("google/electra-large-discriminator")
def get_input(item, db, t):
    s = re.sub('([.,!?()])', r' \1 ', item['question'])
    s = re.sub('\s{2,}', ' ', s)
    input_token = '[CLS]'
    question = [q.lower() for q in s.split()]
    question = ['[CLS]' if q == '[cls]' else q for q in question]
    sim_label = item['final_sim']
    sim_mask = [True if item != -1 else False for item in item['final_sim']]
    rtd_label = []
    ssf_label = []
    context_label = []
    question_id = [t.cls_token_id] # map token to id
    question_mask_plm = []
    question_subword_len = []
    input_token += ' ' + ' '.join(question)
    for w in question:
        toks = t.convert_tokens_to_ids(t.tokenize(w))
        question_id.extend(toks)
        question_subword_len.append(len(toks))
    question_mask_plm = [0] + [1] * (len(question_id) - 1) + [0]


    question_id.append(t.sep_token_id)
    input_token += ' [SEP]'
    rtd_label = [0]*len(question_id)

    table = [['table'] + t.lower().split() for t in db['table_names']]
    table_id, table_mask_plm, table_subword_len = [], [], []
    table_word_len = []
    for tab_ids, s in enumerate(table):
        l = 0
        for w in s:
            input_token += ' ' + w
            toks = t.convert_tokens_to_ids(t.tokenize(w))
            table_id.extend(toks)
            table_subword_len.append(len(toks))
            l += len(toks)
        if len(item['table_label'][0][tab_ids]) == 0:
            rtd_label.extend([1]*l)
        else:
            rtd_label.extend([0]*l)
        table_word_len.append(l)
    table_mask_plm = [1] * len(table_id)

    if len(item['col_label']) == 1:
        column = [[db['column_types'][idx].lower()] + c.lower().split() for idx, (_, c, _) in enumerate(db['column_names'])]
    else:
        column = [[db['column_types'][idx].lower()] + c.lower().split() + item['col_label'][1][idx].split() for idx, (_, c, _) in enumerate(db['column_names'])]
    col_index = 0
    column_id, column_mask_plm, column_subword_len = [], [], []
    column_word_len = []
    for col_ids, s in enumerate(column):
        l = 0
        ssf_label.append(convert[item['col_label'][0][col_ids]])
        for w in s:
            input_token += ' ' + w
            toks = t.convert_tokens_to_ids(t.tokenize(w))
            column_id.extend(toks)
            column_subword_len.append(len(toks))
            l += len(toks)
        if len(item['col_label'][0][col_ids]) == 0:
            rtd_label.extend([1]*l)
        else:
            rtd_label.extend([0]*l)
        column_word_len.append(l)

    rtd_label.append(0)
    column_mask_plm = [1] * len(column_id) + [0]
    
    input_token += ' [SEP]'
    column_id.append(t.sep_token_id)
    
    question_mask_plm = question_mask_plm + [0] * (len(table_id) + len(column_id))
    table_mask_plm = [0] * len(question_id) + table_mask_plm + [0] * len(column_id)
    column_mask_plm = [0] * (len(question_id) + len(table_id)) + column_mask_plm

    input_id = question_id + table_id + column_id
    
   
    
    assert len(input_id) == len(question_mask_plm) == len(table_mask_plm) == len(column_mask_plm) == len(rtd_label)
    assert len(ssf_label) == len(item['col_label'][0])
    assert len(sim_label) == 5
    result = {}
    result['db_id'] = item['db_id']
    result['input_id'] = input_id
    result['query'] = item['query']
    result['sql'] = item['sql']
    result['content'] = item['content']
    result['db_id'] = item['db_id']
    result['sim_label'] = sim_label
    result['sim_mask'] = sim_mask
    result['ssf_label'] = ssf_label
    result['question_mask_plm'] = question_mask_plm
    result['rtd_label'] = rtd_label
    temp = {}
    temp['db'] = db
    temp['table_word_len'] = table_word_len
    temp['column_word_len'] = column_word_len
    temp['question_id'] = question_id
    temp['input_id'] = input_id

    result['column_mask_plm'] = column_mask_plm
    result['column_word_len'] = column_word_len
    result['position_ids'] = get_position_ids(temp, shuffle=True)
    result['input_token'] = input_token
    result['db_id'] = item['db_id']
    return result

Downloading:   0%|          | 0.00/27.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/668 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

In [44]:
from tqdm import tqdm
alltask = []
for item in tqdm(theall):
    temp = []
    for iitem in item:
        one = get_input(iitem,db[iitem['db_id']],t)
        temp.append(one)
    alltask.append(temp)

100%|██████████| 64609/64609 [46:16<00:00, 23.27it/s]  


In [45]:
from tqdm import tqdm
with open('final_data/electra_all.json','w') as f:
    json.dump(alltask,f)

In [46]:
final = []
for item in alltask:
    flag = True
    for iitem in item:
        if len(iitem['input_id']) > 256:
            flag = False
    if flag:
        final.append(item)

In [47]:
db_data = {}
for item in final:
    if item[0]['db_id'] not in db_data.keys():
        db_data[item[0]['db_id']] = [item]
    else:
        db_data[item[0]['db_id']].append(item)
keys = []
for key in db_data.keys():
    keys.append(key)
weight = [len(v) for k, v in db_data.items()]

In [48]:
final = []
for i in range(12000):
    thekeys = random.choices(keys,weights=weight,k=8)
    batch = []
    for key in thekeys:
        temp = random.sample(db_data[key],1)[0]
        batch.append(temp)
    for item in batch:
        for iitem in item:
            final.append(iitem)

In [49]:
from tqdm import tqdm
with open('final_data/alltask_final.txt','w') as f:
    for item in tqdm(final):
        f.write(json.dumps(item))
        f.write('\n')

100%|██████████| 480000/480000 [00:45<00:00, 10586.93it/s]


In [50]:
len(final)

480000