In [1]:
import json

# Parallel lists, one entry per (dialogue history -> next turn) example.
queries, labels, sqls, db_ids = [], [], [], []

with open("sparc/train.json","r") as f:
    interactions = json.load(f)

# Every turn after the first yields one example: the utterances of all
# previous turns joined with " <s> " form the query, the current turn's
# utterance is the label, and the current turn's SQL is kept alongside.
for interaction in interactions:
    db_id = interaction["database_id"]
    turns = interaction["interaction"]
    if len(turns) < 2:
        # A single-turn interaction has no history to condition on.
        continue
    utterances = [turn["utterance"] for turn in turns]
    for i in range(1, len(turns)):
        queries.append(" <s> ".join(utterances[:i]))
        sqls.append(turns[i]["sql"])
        labels.append(utterances[i])
        db_ids.append(db_id)


In [2]:
import attr
import torch
from seq2struct.utils import registry

@attr.s
class PreprocessConfig:
    """Pair of values needed to evaluate the model's jsonnet config for preprocessing."""
    # Path of the jsonnet config file; the notebook passes it to `_jsonnet.evaluate_file`.
    config = attr.ib()
    # JSON-encoded string supplied as the `args` entry of `tla_codes` when evaluating `config`.
    config_args = attr.ib()


class Preprocessor:
    """Holds the evaluated config and the model's preprocessing object built from it."""

    def __init__(self, config):
        """Look up the `Preproc` class of the configured model and instantiate it.

        Args:
            config: evaluated experiment config; only its 'model' section is read here.
        """
        self.config = config
        model_section = config['model']
        preproc_cls = registry.lookup('model', model_section).Preproc
        self.model_preproc = registry.instantiate(preproc_cls, model_section)
@attr.s
class InferConfig:
    """Settings for one inference run; positional order matters for the constructor call."""
    # Path of the jsonnet model config file (evaluated with `_jsonnet.evaluate_file`).
    config = attr.ib()
    # JSON string passed as the `args` top-level argument; may be empty/None, in which
    # case the notebook evaluates `config` without `tla_codes`.
    config_args = attr.ib()
    # Checkpoint directory; the notebook appends config['model_name'] to it when present.
    logdir = attr.ib()
    # Filled from exp_config["eval_section"] — presumably the dataset split to run on; confirm.
    section = attr.ib()
    # Filled from exp_config["eval_beam_size"].
    beam_size = attr.ib()
    # Output path template; the notebook replaces '__LOGDIR__' in it with `logdir`.
    output = attr.ib()
    # Checkpoint step handed to `Inferer.load_model`.
    step = attr.ib()
    # Filled from exp_config["eval_use_heuristic"]; not read anywhere else in this notebook.
    use_heuristic = attr.ib(default=False)
    # NOTE(review): the next three fields are not read in this notebook; their meaning
    # should be confirmed against the seq2struct inference command.
    mode = attr.ib(default="infer")
    limit = attr.ib(default=None)
    output_history = attr.ib(default=False)


class Inferer:
    """Builds the model preprocessor and restores a trained model for inference."""

    def __init__(self, config):
        """Select the device and load the model's preprocessing state.

        Args:
            config: evaluated experiment config; its 'model' section is used to
                look up and instantiate the model's `Preproc` class.
        """
        self.config = config
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
            # CPU fallback: restrict torch to a single thread.
            torch.set_num_threads(1)

        # 0. Construct preprocessors
        self.model_preproc = registry.instantiate(
            registry.lookup('model', config['model']).Preproc,
            config['model'])
        self.model_preproc.load()

    def load_model(self, logdir, step):
        '''Load a model (identified by the config used for construction) and return it.

        Args:
            logdir: directory containing the saved checkpoints.
            step: checkpoint step passed through to the saver's `restore`.

        Returns:
            The constructed model, moved to `self.device` and switched to eval mode.

        Raises:
            Exception: if the saver reports no restored step (untrained model).
        '''
        # Imported here so the class does not depend on a later notebook cell
        # having already imported `saver_mod` into the global namespace.
        from seq2struct.utils import saver as saver_mod

        # 1. Construct model
        model = registry.construct('model', self.config['model'], preproc=self.model_preproc, device=self.device)
        model.to(self.device)
        model.eval()
        model.visualize_flag = False

        # 2. Restore its parameters
        saver = saver_mod.Saver({"model": model})
        last_step = saver.restore(logdir, step=step, map_location=self.device, item_keys=["model"])

        if not last_step:
            raise Exception('Attempting to infer on untrained model')
        return model




In [3]:
import _jsonnet
from seq2struct import datasets
from seq2struct import models
from seq2struct.utils import registry
from seq2struct.utils import vocab

# Experiment-level config: names the model config file and its arguments.
exp_config = json.loads(
    _jsonnet.evaluate_file("experiments/sparc-configs/gap-run.jsonnet")
)
model_config_file = exp_config["model_config"]
model_config_args = json.dumps(exp_config["model_config_args"])
preprocess_config = PreprocessConfig(
    config=model_config_file,
    config_args=model_config_args,
)

# Model-level config, evaluated with the experiment's arguments as `args`.
config = json.loads(
    _jsonnet.evaluate_file(
        preprocess_config.config,
        tla_codes={"args": preprocess_config.config_args},
    )
)

preprocessor = Preprocessor(config)

# Training split of the dataset; later cells read `data.schemas` from it.
data = registry.construct("dataset", config["data"]["train"])

# test = preprocessor.model_preproc.dec_preproc.grammar.parse(sqls[0],"train")




===========================

In [4]:
import sys
import os
from seq2struct import beam_search
from seq2struct import datasets
from seq2struct import models
from seq2struct import optimizers
from seq2struct.utils import registry
from seq2struct.utils import saver as saver_mod

from seq2struct.models.spider import spider_beam_search

# Checkpoint step to restore. It appears both in the inference output file name
# and in the saver call, so it is defined once to keep the two in sync.
INFER_STEP = 38100

exp_config = json.loads(_jsonnet.evaluate_file("experiments/sparc-configs/gap-run.jsonnet"))
model_config_file = exp_config["model_config"]
model_config_args = json.dumps(exp_config["model_config_args"])

infer_output_path = "{}/{}-step{}.infer".format(
                exp_config["eval_output"],
                exp_config["eval_name"],
                INFER_STEP)

infer_config = InferConfig(
                model_config_file,
                model_config_args,
                exp_config["logdir"],
                exp_config["eval_section"],
                exp_config["eval_beam_size"],
                infer_output_path,
                INFER_STEP,
                use_heuristic=exp_config["eval_use_heuristic"]
            )

# The model config takes top-level arguments only when some were provided.
if infer_config.config_args:
    config = json.loads(_jsonnet.evaluate_file(infer_config.config, tla_codes={'args': infer_config.config_args}))
else:
    config = json.loads(_jsonnet.evaluate_file(infer_config.config))

# Checkpoints live in a per-model subdirectory when the config names the model.
if 'model_name' in config:
    infer_config.logdir = os.path.join(infer_config.logdir, config['model_name'])

output_path = infer_config.output.replace('__LOGDIR__', infer_config.logdir)


inferer = Inferer(config)
model = inferer.load_model(infer_config.logdir, infer_config.step)


file data/sparc-bart-final/sparc,nl2code-1115,output_from=true,fs=2,emb=bart,cvlink/enc/config.json not found


data/sparc-bart-final/sparc,nl2code-1115,output_from=true,fs=2,emb=bart,cvlink/enc
Parameter containing:
tensor([[-0.0370,  0.1117,  0.1829,  ...,  0.2054,  0.0578, -0.0750],
        [ 0.0055, -0.0049, -0.0069,  ..., -0.0030,  0.0038,  0.0087],
        [-0.0448,  0.4604, -0.0604,  ...,  0.1073,  0.0310,  0.0477],
        ...,
        [-0.0138,  0.0278, -0.0467,  ...,  0.0455, -0.0265,  0.0125],
        [-0.0043,  0.0153, -0.0567,  ...,  0.0496,  0.0108, -0.0099],
        [ 0.0053,  0.0324, -0.0179,  ..., -0.0085,  0.0223, -0.0020]],
       requires_grad=True)
Updated the model with ./pretrained_checkpoint/pytorch_model.bin
Parameter containing:
tensor([[-3.8313e-02,  1.2050e-01,  1.7760e-01,  ...,  1.9729e-01,
          5.9443e-02, -6.9929e-02],
        [ 4.5650e-03, -2.3032e-03, -8.4326e-03,  ..., -3.5686e-03,
          4.7121e-03,  8.4110e-03],
        [-4.5997e-02,  4.6710e-01, -6.5000e-02,  ...,  1.0271e-01,
          2.5631e-02,  4.7501e-02],
        ...,
        [-2.5866e-03, -3.

In [5]:
# Rewrite the SQLs: parse each gold SQL and unparse it back to text with the model grammar
class TTT:
    """Bare wrapper exposing a `schema` attribute.

    The notebook passes an instance as the second argument of the grammar's
    `unparse` call.
    """

    def __init__(self, schema):
        # Store the schema unchanged; nothing else is attached to the wrapper.
        self.schema = schema
# One rewritten SQL string per training example, in the same order as `sqls`.
rewrite_sqls = []
for i, sql_ast in enumerate(sqls):
    # Parse with the preprocessor's grammar, then render with the model's grammar
    # against the schema of the example's database.
    root = preprocessor.model_preproc.dec_preproc.grammar.parse(sql_ast, "train")
    schema_holder = TTT(data.schemas[db_ids[i]])
    rewrite_sqls.append(model.decoder.preproc.grammar.unparse(root, schema_holder))


In [37]:
# One JSON object per line: "x" is the dialogue history followed by the rewritten
# SQL (separated by " </s> "), "label" is the utterance to be generated.
with open("coco/sparc/train.json","w") as f:
    for i in range(len(sqls)):
        record = {
            "x": " </s> ".join([queries[i], rewrite_sqls[i]]),
            "label": labels[i],
        }
        f.write(f"{json.dumps(record)}\n")

==============
Data augmentation (数据增强)





In [6]:
import copy
from random import choice
import sqlite3
from tqdm import tqdm

def do_change_select_agg(sql, schema):
    """Swap the aggregate id of eligible SELECT columns, mutating `sql` in place.

    A select item is eligible when the whole SELECT is not distinct, the
    column unit's own distinct flag equals False, the column is not id 0
    (the `*` column) and the schema reports its type as 'text'. Its aggregate
    id is replaced by a random pick from {0, 3} excluding the current id
    (presumably "no aggregate" and "count" in the Spider numbering — confirm).

    Args:
        sql: SQL AST dict; only `sql["select"]` is read and modified.
        schema: object whose `columns[column_id].type` gives the column type.

    Returns:
        True if at least one aggregate id was rewritten, else False.
    """
    select_items = sql["select"][1]
    select_is_distinct = sql["select"][0]
    changed = False
    for item in select_items:
        col_unit = item[1][1]
        if select_is_distinct:
            continue
        if not (col_unit[2] == False):
            # The column itself is marked distinct: leave it untouched.
            continue
        current_agg, column_id = col_unit[0], col_unit[1]
        if column_id == 0:
            # Column id 0 is `*`; skip it.
            continue
        if schema.columns[column_id].type != 'text':
            continue
        alternatives = [agg_id for agg_id in (0, 3) if agg_id != current_agg]
        col_unit[0] = choice(alternatives)
        changed = True
    return changed

def do_change_select_column(sql, schema):
    """Replace eligible SELECT columns with a different column, mutating `sql` in place.

    Only non-distinct select items are considered. For a regular column the
    replacement is a random other column of the same table with the same
    type. For column id 0 (`*`) the replacement is a random 'text' column of
    the last `table_unit` entry listed in `sql["from"]["table_units"]`.

    Args:
        sql: SQL AST dict; reads `sql["select"]` and, for `*`, `sql["from"]`.
        schema: object exposing `columns` (indexed by column id, each with
            `type` and `table.columns`) and `tables` (indexed by table id).

    Returns:
        True if at least one column id was rewritten, else False.
    """
    select_items = sql["select"][1]
    select_is_distinct = sql["select"][0]
    changed = False
    for item in select_items:
        col_unit = item[1][1]
        if select_is_distinct or not (col_unit[2] == False):
            continue

        column_id = col_unit[1]
        if column_id != 0:
            # Regular column: candidates are same-typed siblings in its table.
            wanted_type = schema.columns[column_id].type
            pool = schema.columns[column_id].table.columns
            excluded = (column_id, 0)
        else:
            # `*`: fall back to the text columns of the last plain table in FROM.
            table_ids = [unit[1] for unit in sql["from"]["table_units"]
                         if unit[0] == "table_unit"]
            table_id = table_ids[-1] if table_ids else None
            if table_id is None:
                continue
            wanted_type = "text"
            pool = schema.tables[table_id].columns
            excluded = (0,)

        candidates = [column.id for column in pool if wanted_type == column.type]
        for unwanted in excluded:
            if unwanted in candidates:
                candidates.remove(unwanted)
        if candidates:
            col_unit[1] = choice(candidates)
            changed = True
    return changed

def do_change_where_column(sql,schema):
    """Swap the column (and, for text columns, the value) of WHERE conditions in place.

    For every condition in `sql["where"]` (list entries; connector strings
    are skipped) whose column is of type 'number' or 'text', a different
    column of the same table and type is chosen at random:

    * number: only the column id is replaced.
    * text: up to two random values of the new column are sampled from the
      database; values equal to the original (unquoted) condition value are
      discarded, and one of the remaining values replaces the condition
      value together with the column id. Conditions whose value is a nested
      query (a dict) are left alone.

    Fixes relative to the previous version:
    * `vals = vals.remove(orig_val)` set `vals` to None (list.remove returns
      None), so a condition was never changed when the original value had
      been sampled; the value is now filtered out instead.
    * The number branch only acted when column id 0 was among the table's
      columns and did not exclude the current column; it now mirrors the
      text branch.
    * The sqlite connection is closed on every exit path.
    * Only `sqlite3.Error` from the sampling query is treated as "give up".

    Args:
        sql: SQL AST dict; `sql["where"]` is read and modified.
        schema: object with `db_id` and `columns` (indexed by column id; each
            column exposes `id`, `type`, `orig_name` and `table`, and each
            table exposes `columns` and `orig_name`).

    Returns:
        True if at least one condition was rewritten; False otherwise, and
        also False as soon as the sampling query fails or a condition value
        is None (same early exits as before).
    """
    # Open the example database, located relative to the working directory.
    conn = sqlite3.connect('sparc/database/{0}/{0}.sqlite'.format(schema.db_id))
    try:
        cursor = conn.cursor()
        is_change = False
        for where in sql["where"]:
            if not isinstance(where, list):
                # Non-list entries are the connectors between conditions.
                continue
            cond_unit = where
            column_id = cond_unit[2][1][1]
            column = schema.columns[column_id]
            tp = column.type
            if tp not in ("number", "text") or column.table is None:
                # Unsupported type, or a column without a table (e.g. `*`).
                continue

            # Same-typed columns of the same table, minus the current one and `*`.
            to_replaces = [
                ele.id for ele in column.table.columns
                if ele.type == tp and ele.id != column_id and ele.id != 0
            ]
            if not to_replaces:
                continue
            to_replace = choice(to_replaces)

            if tp == "number":
                cond_unit[2][1][1] = to_replace
                is_change = True
                continue

            # Text column: sample replacement values at random from the database.
            # NOTE(review): identifiers are interpolated unquoted, so names that
            # need quoting make the query fail and the function return False.
            try:
                cursor.execute("select {} from {} ORDER BY RANDOM() limit 2".format(schema.columns[to_replace].orig_name,schema.columns[to_replace].table.orig_name))
            except sqlite3.Error:
                return False
            vals = [row[0] for row in cursor.fetchall()]
            if not vals or isinstance(cond_unit[3], dict):
                continue

            orig_val = cond_unit[3]
            if orig_val is None:
                return False
            if isinstance(orig_val, str):
                # Strip one surrounding double quote on each side, if present.
                if len(orig_val) > 0 and orig_val[0] == "\"":
                    orig_val = orig_val[1:]
                if len(orig_val) > 0 and orig_val[-1] == "\"":
                    orig_val = orig_val[:-1]

            # Never "replace" the value with itself.
            vals = [val for val in vals if val != orig_val]
            if vals:
                cond_unit[2][1][1] = to_replace
                cond_unit[3] = choice(vals)
                is_change = True
        return is_change
    finally:
        conn.close()

def _render_sql(sql_ast, schema):
    """Parse a SQL AST with the preprocessor grammar and unparse it to text with the model grammar."""
    root = preprocessor.model_preproc.dec_preproc.grammar.parse(sql_ast, "dev")
    return model.decoder.preproc.grammar.unparse(root, TTT(schema))


def _collect_variants(sql_ast, schema, mutate, attempts, skip_render_errors=False):
    """Apply `mutate` to fresh deep copies of `sql_ast` and render every successful mutation.

    Args:
        sql_ast: original SQL AST; never modified.
        schema: database schema passed to `mutate` and used for rendering.
        mutate: function `(sql, schema) -> bool` that edits `sql` in place.
        attempts: number of independent mutation attempts.
        skip_render_errors: when True, a mutated AST that cannot be parsed or
            unparsed is dropped; when False the error propagates.

    Returns:
        `(rendered_sqls, mutated_asts)`, two lists of equal length.
    """
    rendered, asts = [], []
    for _ in range(attempts):
        candidate = copy.deepcopy(sql_ast)
        if not mutate(candidate, schema):
            continue
        try:
            text = _render_sql(candidate, schema)
        except Exception:
            # `Exception` rather than a bare except, so that KeyboardInterrupt
            # and SystemExit still stop the loop.
            if not skip_render_errors:
                raise
            continue
        rendered.append(text)
        asts.append(candidate)
    return rendered, asts


# One dict per training example: the gold SQL text plus its augmented variants.
# The select-aggregate augmentation (do_change_select_agg, 5 attempts, keys
# "select_agg" / "select_agg_ast") is currently disabled.
rewrite_sqls = []
for i in tqdm(range(len(sqls))):
    db_schema = data.schemas[db_ids[i]]
    _cur = {"gold": _render_sql(sqls[i], db_schema)}

    # 5 attempts at swapping a SELECT column; rendering errors propagate.
    _cur["select_column"], _cur["select_column_ast"] = _collect_variants(
        sqls[i], db_schema, do_change_select_column, attempts=5)

    # 10 attempts at swapping a WHERE column/value; variants that cannot be
    # rendered are skipped.
    _cur["where_column"], _cur["where_column_ast"] = _collect_variants(
        sqls[i], db_schema, do_change_where_column, attempts=10,
        skip_render_errors=True)

    rewrite_sqls.append(_cur)




100%|██████████| 5995/5995 [00:25<00:00, 236.78it/s]


In [7]:
# One JSON object per line. The "select_agg" / "select_agg_ast" entries are
# not exported because that augmentation is disabled.
with open("coco/sparc/train_aug.json","w") as f:
    for i in range(len(sqls)):
        entry = rewrite_sqls[i]
        exported = {
            key: entry[key]
            for key in ("select_column", "select_column_ast",
                        "where_column", "where_column_ast")
        }
        line = json.dumps({
            "db_id": db_ids[i],
            "query": queries[i],
            "gold_sql": entry["gold"],
            "rewrite_sqls": exported,
        })
        f.write(line + "\n")
