In [1]:
#import
import numpy as np
import torch.utils.data
import torch.cuda
import json
from transformers import BertTokenizer,BertModel
from torch import nn
from sklearn import metrics
from torch.optim.adam import Adam
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from typing import List



In [2]:
#util
def read_train_datas(path, question_length):
    """
    :param path 数据路径
    :param question_length 问题长度
    :return: [[question, agg, conn_op, cond_ops, cond_vals],...], cond_vals:[[val_start_idx,val_end_idx],...]
    """
    column_length = len(get_columns())
    with open(path, 'r', encoding='utf-8') as f:
        data_list = []
        for line in f:
            item = json.loads(line)
            # question
            question = item['question']
            # agg
            sel = item['sql']['sel']
            agg_op = item['sql']['agg']
            agg = [get_agg_dict()['none']] * column_length
            for i in range(len(sel)):
                sel_col_item = sel[i]
                agg_op_item = agg_op[i]
                agg[sel_col_item] = agg_op_item
            # conn_op
            conn_op = item['sql']['cond_conn_op']
            # cond_ops & cond_vals
            cond_ops = [get_cond_op_dict()['none']] * column_length
            cond_vals = [0] * question_length
            if item['sql'].get('conds') is not None:
                conds = item['sql']['conds']
                for i, cond in enumerate(conds):
                    cond_col_item = cond[0]
                    cond_op_item = cond[1]
                    cond_ops[cond_col_item] = cond_op_item
                    value = cond[2]
                    cond_vals = fill_value_start_end(cond_vals, question, value)
            data_list.append([question, agg, conn_op, cond_ops, cond_vals])
    return data_list


def read_predict_datas(path):
    """
    :param path: 预测数据路径
    :return: 预测数据
    """
    questions = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            item = json.loads(line)
            question = item['question']
            questions.append([question])
    return questions


def fill_value_start_end(cond_vals, question, value):
    """
    fill [1] by the value in the question
    """
    question_length = len(question)
    value_length = len(value)
    for i in range(question_length - value_length + 1):
        if question[i:value_length + i] == value:
            cond_vals[i: value_length + i] = [1] * value_length
    return cond_vals


def get_columns():
    columns = ['基金代码', '基金名称', '成立时间', '基金类型', '基金规模', '销售状态', '是否可销售', '风险等级',
               '基金公司名称', '分红方式',
               '赎回状态', '是否支持定投', '净值同步日期', '净值', '成立以来涨跌幅', '昨日涨跌幅', '近一周涨跌幅',
               '近一个月涨跌幅', '近三个月涨跌幅', '近六个月涨跌幅',
               '近一年涨跌幅', '基金经理', '主题/概念', '一个月夏普率', '一年夏普率', '三个月夏普率', '六个月夏普率',
               '成立以来夏普率', '投资市场', '板块', '行业',
               '晨星三年评级', '管理费率', '销售服务费率', '托管费率', '认购费率', '申购费率', '赎回费率', '分红年度',
               '权益登记日',
               '除息日', '派息日', '红利再投日', '每十份收益单位派息', '主投资产类型', '基金投资风格描述', '估值',
               '是否主动管理型基金', '投资', '跟踪指数',
               '是否新发', '重仓', '无']
    return columns


def get_cond_op_dict():
    cond_op_dict = {'>': 0, '<': 1, '==': 2, '!=': 3, 'like': 4, '>=': 5, '<=': 6, 'none': 7}
    return cond_op_dict


def get_conn_op_dict():
    conn_op_dict = {'none': 0, 'and': 1, 'or': 2}
    return conn_op_dict


def get_agg_dict():
    agg_dict = {'': 0, 'AVG': 1, 'MAX': 2, 'MIN': 3, 'COUNT': 4, 'SUM': 5, 'none': 6}
    return agg_dict


def get_key(dict, value):
    """
    根据字典的value获取key
    :param dict: 字典
    :param value: 值
    :return: key
    """
    return [k for k, v in dict.items() if v == value]


def get_values_name(question, cond_vals):
    """
    cond_vals的值如[0,1,1,1,1,0,0,0,1,1,1,0,0,0]所示
    根据cond_vals中为1的值找到question对应下标的内容
    返回找到的内容列表，连续为1的内容作为返回列表的一个元素
    """
    question = question
    result = []
    cur_start_idx = 0
    valid = False

    for idx, val in enumerate(cond_vals):
        if val == 1:
            if not valid:
                cur_start_idx = idx
                valid = True
        else:
            if valid:
                valid = False
                if idx > cur_start_idx:
                    vals = question[cur_start_idx:idx]
                    result.append(vals)

    return result

In [3]:
#dataset
# label
class Label(object):
    def __init__(self, label_agg: List = None, label_conn_op=None, label_cond_ops: List = None,
                 label_cond_vals: List = None):
        """
        训练标签信息
        :param label_agg: 聚合函数
        :param label_conn_op: 连接操作符
        :param label_cond_ops: 条件操作符
        :param label_cond_vals: 条件值
        """
        self.label_agg = label_agg
        self.label_conn_op = label_conn_op
        self.label_cond_ops = label_cond_ops
        self.label_cond_vals = label_cond_vals


class InputFeatures(object):
    def __init__(self, model_path=None, question_length=128, max_length=512, input_ids=None, attention_mask=None,
                 token_type_ids=None, cls_idx=None, label: Label = None):
        if model_path is not None:
            self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.question_length = question_length
        self.max_length = max_length
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.cls_idx = cls_idx
        self.label = label

    def encode_expression(self, expressions: List):
        """
        表达式编码
        :param expressions: 表达式（列名或条件表达式）
        :return: 编码后的列，及序列号（用于列与列之间的区分）
        """
        encodings = self.tokenizer.batch_encode_plus(expressions)
        expressions_encode = encodings["input_ids"]
        segment_ids = encodings["token_type_ids"]
        segment_ids = [[elem if j % 2 == 0 else 1 for elem in row] for j, row in enumerate(segment_ids)]
        expressions_encode = [item for sublist in expressions_encode for item in sublist]
        segment_ids = [item for sublist in segment_ids for item in sublist]
        return torch.tensor(expressions_encode), torch.tensor(segment_ids)

    def get_cls_idx(self, expressions):
        """
        获取表达式标记符的位置
        :param expressions: 表达式
        :return:
        """
        cls_idx = []
        start = self.question_length
        for i in range(len(expressions)):
            cls_idx.append(int(start))
            # 加上特殊标记的长度（例如 [CLS] 和 [SEP]）
            start += len(expressions[i]) + 2
        return cls_idx

    def encode_question_with_expressions(self, que_length, max_length, question, expressions_encode,
                                         expressions_segment_id):
        """
        编码
        :param que_length: 问题长度
        :param max_length: text长度
        :param question:  问题
        :param expressions_encode:  编码的列
        :param expressions_segment_id 编码的列的序列
        :return: 编码后的text
        """

        # 编码问题，需要填充，否则会出现长度不一致异常
        question_encoding = self.tokenizer.encode(question, add_special_tokens=True, padding='max_length',
                                                  max_length=que_length, truncation=True)

        # 合并编码后的张量，保证张量类型(dtype)为int或long, bert的embedding的要求
        input_ids = torch.cat([torch.tensor(question_encoding), expressions_encode], dim=0)
        token_type_ids = torch.cat([torch.zeros(que_length, dtype=torch.long), expressions_segment_id], dim=0)
        padding_length = max_length - len(input_ids)
        attention_mask = torch.cat([torch.ones(len(input_ids)), torch.zeros(padding_length)], dim=0)
        input_ids = torch.cat([input_ids, torch.zeros(padding_length, dtype=torch.long)], dim=0)
        token_type_ids = torch.cat([token_type_ids, torch.zeros(padding_length, dtype=torch.long)], dim=0)

        return input_ids, attention_mask, token_type_ids

    def list_features(self, datas):
        """
        输入特征
        :param datas: 数据
        :return: 特征信息
        """
        list_features = []
        columns = get_columns()
        cls_idx = self.get_cls_idx(columns)
        expressions_encode, expressions_segment_id = self.encode_expression(columns)
        for data in datas:
            question = data[0]
            # if contain label data
            label = None
            if len(data) > 1:
                label = Label(label_agg=data[1], label_conn_op=data[2], label_cond_ops=data[3], label_cond_vals=data[4])
            # 编码(question+expressions)
            input_ids, attention_mask, token_type_ids = self.encode_question_with_expressions(self.question_length,
                                                                                              self.max_length,
                                                                                              question,
                                                                                              expressions_encode,
                                                                                              expressions_segment_id)
            list_features.append(
                InputFeatures(question_length=self.question_length, max_length=self.max_length, input_ids=input_ids,
                              attention_mask=attention_mask, token_type_ids=token_type_ids, cls_idx=cls_idx,
                              label=label))
        return list_features


class Dataset(torch.utils.data.Dataset):
    def __init__(self, features: List[InputFeatures]):
        self.features = features

    def __len__(self):
        return len(self.features)

    def __getitem__(self, item):
        feature = self.features[item]
        input_ids = np.array(feature.input_ids)
        attention_mask = np.array(feature.attention_mask)
        token_type_ids = np.array(feature.token_type_ids)
        cls_idx = np.array(feature.cls_idx)
        if feature.label is not None:
            label: Label = feature.label
            label_agg = np.array(label.label_agg)
            label_conn_op = np.array(label.label_conn_op)
            label_cond_ops = np.array(label.label_cond_ops)
            label_cond_vals = np.array([np.array(val) for val in label.label_cond_vals])
            return input_ids, attention_mask, token_type_ids, cls_idx, label_agg, label_conn_op, label_cond_ops, label_cond_vals
        else:
            return input_ids, attention_mask, token_type_ids, cls_idx



In [4]:
#model
class ColClassifierModel(nn.Module):
    def __init__(self, model_path, hidden_size, agg_length, conn_op_length, cond_op_length, dropout=0.5):
        super(ColClassifierModel, self).__init__()
        self.bert = BertModel.from_pretrained(model_path)
        self.dropout = nn.Dropout(dropout)
        # out classes需要纬度必须大于label中size(classes)，否则会出现Assertion `t >= 0 && t < n_classes` failed.
        self.agg_classifier = nn.Linear(hidden_size, agg_length)
        self.cond_ops_classifier = nn.Linear(hidden_size, cond_op_length)
        self.conn_op_classifier = nn.Linear(hidden_size, conn_op_length)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, cls_idx=None):
        # 输出最后一层隐藏状态以及池化层
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        dropout_output = self.dropout(outputs.pooler_output)
        dropout_hidden_state = self.dropout(outputs.last_hidden_state)

        """
        提取列特征信息，从dim=1即第二维中（列标记符号索引所在纬度）提取dropout_hidden_state对应该纬度的信息。
        前提需要将cls_idx张量shape扩展成与dropout_hidden_state一致
        """
        # cls_cols = dropout_hidden_state.gather(dim=1, index=cls_idx.unsqueeze(-1).expand(
        #     dropout_hidden_state.shape[0], -1, dropout_hidden_state.shape[-1]))
        # 简化写法
        cls_cols = dropout_hidden_state[:, cls_idx[0], :]

        out_agg = self.agg_classifier(cls_cols)
        out_cond_ops = self.cond_ops_classifier(cls_cols)

        out_conn_op = self.conn_op_classifier(dropout_output)

        return out_agg, out_cond_ops, out_conn_op


class ValueClassifierModel(nn.Module):
    def __init__(self, model_path, hidden_size, question_length, cond_value_length=2, dropout=0.5):
        super(ValueClassifierModel, self).__init__()
        self.bert = BertModel.from_pretrained(model_path)
        self.dropout = nn.Dropout(dropout)
        self.cond_vals_classifier = nn.Linear(hidden_size, cond_value_length)
        self.question_length = question_length

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        # 输出最后一层隐藏状态
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        hidden_state = outputs.last_hidden_state

        # 提取问题特征信息
        cond_values = hidden_state[:, 1:self.question_length + 1, :]

        out_cond_vals = self.cond_vals_classifier(cond_values)

        return out_cond_vals


In [None]:
#train
def train(model: ColClassifierModel or ValueClassifierModel, model_save_path, train_dataset: Dataset,
          val_dataset: Dataset, batch_size, lr, epochs):
    # DataLoader根据batch_size获取数据，训练时选择打乱样本
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    # 是否使用gpu
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optim = Adam(model.parameters(), lr=lr)
    if use_cuda:
        model = model.to(device)
        criterion = criterion.to(device)
    best_val_avg_acc = 0
    for epoch in range(epochs):
        total_loss_train = 0
        model.train()
        # 训练进度
        for input_ids, attention_mask, token_type_ids, cls_idx, label_agg, label_conn_op, label_cond_ops, label_cond_vals in tqdm(
                train_loader):
            # model要求输入的矩阵(hidden_size,sequence_size),需要把第二纬度去除.squeeze(1)
            input_ids = input_ids.squeeze(1).to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.squeeze(1).to(device)
            if type(model) is ColClassifierModel:
                # reshape(-1)合并一二纬度
                label_agg = label_agg.to(device).reshape(-1)
                label_conn_op = label_conn_op.to(device)
                label_cond_ops = label_cond_ops.to(device).reshape(-1)
                # 模型输出
                out_agg, out_cond_ops, out_conn_op = model(input_ids, attention_mask, token_type_ids, cls_idx)
                out_agg = out_agg.to(device).reshape(-1, out_agg.size(2))
                out_cond_ops = out_cond_ops.to(device).reshape(-1, out_cond_ops.size(2))
                out_conn_op = out_conn_op.to(device)
                # 计算损失
                loss_agg = criterion(out_agg, label_agg)
                loss_conn_op = criterion(out_conn_op, label_conn_op)
                loss_cond_ops = criterion(out_cond_ops, label_cond_ops)
                # 损失比例
                total_loss_train = loss_agg + loss_conn_op + loss_cond_ops

            if type(model) is ValueClassifierModel:
                label_cond_vals = label_cond_vals.to(device).reshape(-1)
                # 模型输出
                out_cond_vals = model(input_ids, attention_mask, token_type_ids)
                # 计算损失
                out_cond_vals = out_cond_vals.reshape(-1, out_cond_vals.size(2))
                lost_cond_vals = criterion(out_cond_vals, label_cond_vals)
                total_loss_train = lost_cond_vals

            # 模型更新
            model.zero_grad()
            optim.zero_grad()
            total_loss_train.backward()
            optim.step()
        # 模型验证
        val_avg_acc = 0
        out_all_agg = []
        out_all_conn_op = []
        out_all_cond_ops = []
        out_all_cond_vals = []
        label_all_agg = []
        label_all_conn_op = []
        label_all_cond_ops = []
        label_all_cond_vals = []
        # 验证无需梯度计算
        model.eval()
        with torch.no_grad():
            # 使用当前epoch训练好的模型验证
            for input_ids, attention_mask, token_type_ids, cls_idx, label_agg, label_conn_op, label_cond_ops, label_cond_vals in val_loader:
                input_ids = input_ids.squeeze(1).to(device)
                attention_mask = attention_mask.to(device)
                token_type_ids = token_type_ids.squeeze(1).to(device)
                if type(model) is ColClassifierModel:
                    label_agg = label_agg.to(device).reshape(-1)
                    label_conn_op = label_conn_op.to(device)
                    label_cond_ops = label_cond_ops.to(device).reshape(-1)
                    # 模型输出
                    out_agg, out_cond_ops, out_conn_op = model(input_ids, attention_mask, token_type_ids, cls_idx)
                    out_agg = out_agg.argmax(dim=2).to(device).reshape(-1)
                    out_cond_ops = out_cond_ops.argmax(dim=2).to(device).reshape(-1)
                    out_conn_op = out_conn_op.argmax(dim=1).to(device)
                    out_all_agg.append(out_agg.cpu().numpy())
                    out_all_conn_op.append(out_conn_op.cpu().numpy())
                    out_all_cond_ops.append(out_cond_ops.cpu().numpy())
                    label_all_agg.append(label_agg.cpu().numpy())
                    label_all_conn_op.append(label_conn_op.cpu().numpy())
                    label_all_cond_ops.append(label_cond_ops.cpu().numpy())
                if type(model) is ValueClassifierModel:
                    label_cond_vals = label_cond_vals.to(device).reshape(-1)
                    # 模型输出
                    out_cond_vals = model(input_ids, attention_mask, token_type_ids)
                    out_cond_vals = out_cond_vals.argmax(dim=2).to(device).reshape(-1)
                    out_all_cond_vals.append(out_cond_vals.cpu().numpy())
                    label_all_cond_vals.append(label_cond_vals.cpu().numpy())

        if type(model) is ColClassifierModel:
            val_agg_acc = metrics.accuracy_score(np.concatenate(out_all_agg, axis=0),
                                                 np.concatenate(label_all_agg, axis=0))
            val_conn_op_acc = metrics.accuracy_score(np.concatenate(out_all_conn_op, axis=0),
                                                     np.concatenate(label_all_conn_op, axis=0))
            val_cond_ops_acc = metrics.accuracy_score(np.concatenate(out_all_cond_ops, axis=0),
                                                      np.concatenate(label_all_cond_ops, axis=0))
            # 准确率计算逻辑
            val_avg_acc = (val_agg_acc + val_conn_op_acc + val_cond_ops_acc) / 3
        if type(model) is ValueClassifierModel:
            val_cond_vals_acc = metrics.accuracy_score(np.concatenate(out_all_cond_vals, axis=0),
                                                       np.concatenate(label_all_cond_vals, axis=0))

            val_avg_acc = val_cond_vals_acc
        # save model
        if val_avg_acc > best_val_avg_acc:
            best_val_avg_acc = val_avg_acc
            torch.save(model.state_dict(), model_save_path)
            print(f'''best model | Val Accuracy: {best_val_avg_acc: .3f}''')
        print(
            f'''Epochs: {epoch + 1} 
              | Train Loss: {total_loss_train.item(): .3f} 
              | Val Accuracy: {val_avg_acc: .3f}''')


if __name__ == '__main__':
    hidden_size = 768
    batch_size = 48
    learn_rate = 2e-5
    epochs = 3
    question_length = 128
    max_length = 512
    train_data_path = '/kaggle/input/bert-nl2sql-datas/waic_nl2sql_train.jsonl'
    pretrain_model_path = '/kaggle/input/bert-nl2sql-chinese-model-hgd'
    save_column_model_path = '/kaggle/working/classifier-column-model.pkl'
    save_value_model_path = '/kaggle/working/classifier-value-model.pkl'
    # 加载数据
    label_datas = label_datas = read_train_datas(train_data_path, question_length)
    # 提取特征数据
    col_model_features = InputFeatures(pretrain_model_path, question_length, max_length).list_features(label_datas)
    # 初始化dataset
    col_model_dateset = Dataset(col_model_features)
    # 创建模型
    col_model = ColClassifierModel(pretrain_model_path, hidden_size, len(get_agg_dict()), len(get_conn_op_dict()),
                                   len(get_cond_op_dict()))
    # 分割数据集
    total_size = len(label_datas)
    train_size = int(0.8 * total_size)
    val_size = int(0.1 * total_size)
    test_size = total_size - train_size - val_size
    # 分割数据集
    col_model_train_dataset, col_model_val_dataset, col_model_test_dataset = random_split(col_model_dateset,
                                                                                          [train_size, val_size,
                                                                                           test_size])
    print('train column model begin')
    train(col_model, save_column_model_path, col_model_train_dataset, col_model_val_dataset, batch_size, learn_rate,
          epochs)
    print('train column model finish')
    value_model_features = InputFeatures(pretrain_model_path, question_length, max_length).list_features(label_datas)
    value_model_dateset = Dataset(value_model_features)
    value_model = ValueClassifierModel(pretrain_model_path, hidden_size, question_length)
    value_model_train_dataset, value_model_val_dataset, value_model_test_dataset = random_split(value_model_dateset,
                                                                                                [train_size, val_size,
                                                                                                 test_size])
    print('train value model begin')
    train(value_model, save_value_model_path, value_model_train_dataset, value_model_val_dataset, batch_size,
          learn_rate,
          epochs)
    print('train value model finish')

In [5]:
#predict
def predict(questions, predict_result_path, pretrain_model_path, column_model_path, value_model_path, hidden_size,
            batch_size, question_length, max_length, table_name='table_name'):
    # 创建模型
    col_model = ColClassifierModel(pretrain_model_path, hidden_size, len(get_agg_dict()), len(get_conn_op_dict()),
                                   len(get_cond_op_dict()))
    value_model = ValueClassifierModel(pretrain_model_path, hidden_size, question_length)
    # 提取特征数据（不含label的数据）
    input_features = InputFeatures(pretrain_model_path, question_length, max_length).list_features(questions)
    dataset = Dataset(input_features)
    # 预测不用打乱顺序shuffle=False
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # 是否使用gpu
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    if use_cuda:
        col_model = col_model.to(device)
        value_model = value_model.to(device)
        col_model.load_state_dict(torch.load(column_model_path, map_location=torch.device(device)))
        value_model.load_state_dict(torch.load(value_model_path, map_location=torch.device(device)))
    # 预测
    pre_all_agg = []
    pre_all_conn_op = []
    pre_all_cond_ops = []
    pre_all_cond_vals = []
    for input_ids, attention_mask, token_type_ids, cls_idx in tqdm(dataloader):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        out_agg, out_cond_ops, out_conn_op = col_model(input_ids, attention_mask, token_type_ids, cls_idx)

        # 取预测结果最大值，torch.argmax找到指定纬度最大值所对应的索引（是索引，不是值）
        pre_agg = torch.argmax(out_agg, dim=2).cpu().numpy()
        pre_cond_ops = torch.argmax(out_cond_ops, dim=2).cpu().numpy()
        pre_conn_op = torch.argmax(out_conn_op, dim=1).cpu().numpy()

        pre_all_agg.extend(pre_agg)
        pre_all_cond_ops.extend(pre_cond_ops)
        pre_all_conn_op.extend(pre_conn_op)
    for input_ids, attention_mask, token_type_ids, cls_idx in tqdm(dataloader):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        out_cond_vals = value_model(input_ids, attention_mask, token_type_ids)
        pre_cond_vals = torch.argmax(out_cond_vals, dim=2).cpu().numpy()
        pre_all_cond_vals.extend(pre_cond_vals)

    with open(predict_result_path, 'w', encoding='utf-8') as wf:
        for question, agg, conn_op, cond_ops, cond_vals in zip(questions, pre_all_agg, pre_all_conn_op,
                                                               pre_all_cond_ops, pre_all_cond_vals):
            sel_col = np.where(np.array(agg) != get_agg_dict()['none'])[0]
            agg = agg[agg != get_agg_dict()['none']]
            cond_col = np.where(np.array(cond_ops) != get_cond_op_dict()['none'])[0]
            cond_op = cond_ops[cond_ops != get_cond_op_dict()['none']]
            sel_col_name = [get_columns()[idx_col] for idx_col in sel_col]
            cond_vals_name = get_values_name(question[0], cond_vals)
            conds = [[int(item_cond_col), int(item_cond_op), item_cond_val_name] for
                     item_cond_col, item_cond_op, item_cond_val_name in zip(cond_col, cond_op, cond_vals_name)]
            sql_dict = {"question": question, "table_id": table_name,
                        "sql": {"sel": list(map(int, sel_col)),
                                "agg": list(map(int, agg)),
                                "limit": 0,
                                "orderby": [],
                                "asc_desc": 0,
                                "cond_conn_op": int(conn_op),
                                'conds': conds},
                        "keywords": {"sel_cols": sel_col_name, "values": cond_vals_name}}
            sql_json = json.dumps(sql_dict, ensure_ascii=False)
            wf.write(sql_json + '\n')


if __name__ == '__main__':
    hidden_size = 768
    batch_size = 24
    question_length = 128
    max_length = 512
    predict_question_path = '/kaggle/input/bert-nl2sql-predict-datas/waic_nl2sql_testa_public.jsonl'
    predict_result_path = '/kaggle/working/predict.jsonl'
    pretrain_model_path = '/kaggle/input/bert-nl2sql-chinese-model-hgd'
    column_model_path = '/kaggle/input/bert-nl2sql-result-model/classifier-column-model.pkl'
    value_model_path = '/kaggle/input/bert-nl2sql-result-model/classifier-value-model.pkl'
    questions = read_predict_datas(predict_question_path)
    predict(questions, predict_result_path, pretrain_model_path, column_model_path, value_model_path, hidden_size,
            batch_size, question_length, max_length)

100%|██████████| 497/497 [00:54<00:00,  9.16it/s]
100%|██████████| 497/497 [00:53<00:00,  9.26it/s]
