In [1]:
#import
import numpy as np
import torch.utils.data
import torch.cuda
import json
from transformers import BertTokenizer,BertModel
from torch import nn
from sklearn import metrics
from torch.optim.adam import Adam
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from typing import List



In [6]:
#utils
def read_train_datas(path):
    """
    :return: [[question, sel_col, conds:[col, op, start, end], conn_op],...]
    """
    with open(path, 'r', encoding='utf-8') as f:
        data_list = []
        for line in f:
            item = json.loads(line)
            question = item['question']
            sel = item['sql']['sel'][0]
            cond_conn_op = item['sql']['cond_conn_op']
            if item['sql'].get('conds') is not None:
                conds = item['sql']['conds']
                for i, cond in enumerate(conds):
                    value = cond[2]
                    start, end = value_start_end(question, value)
                    cond[2] = start
                    cond.append(end)
            else:
                conds = None
            data_list.append([question, sel, conds, cond_conn_op])
    return data_list


def read_predict_datas(path):
    """
    :param path: 预测数据路径
    :return: 预测数据
    """
    questions = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            item = json.loads(line)
            question = item['question']
            questions.append([question])
    return questions


def value_start_end(question, value):
    """
    get the start and end index of the value in the question
    """
    question_length = len(question)
    value_length = len(value)
    for i in range(question_length - value_length + 1):
        if question[i:value_length + i] == value:
            return i, i + value_length - 1
    return 0, 0


def get_columns():
    columns = ['基金代码', '基金名称', '成立时间', '基金类型', '基金规模', '销售状态', '是否可销售', '风险等级',
               '基金公司名称', '分红方式',
               '赎回状态', '是否支持定投', '净值同步日期', '净值', '成立以来涨跌幅', '昨日涨跌幅', '近一周涨跌幅',
               '近一个月涨跌幅', '近三个月涨跌幅', '近六个月涨跌幅',
               '近一年涨跌幅', '基金经理', '主题/概念', '一个月夏普率', '一年夏普率', '三个月夏普率', '六个月夏普率',
               '成立以来夏普率', '投资市场', '板块', '行业',
               '晨星三年评级', '管理费率', '销售服务费率', '托管费率', '认购费率', '申购费率', '赎回费率', '分红年度',
               '权益登记日',
               '除息日', '派息日', '红利再投日', '每十份收益单位派息', '主投资产类型', '基金投资风格描述', '估值',
               '是否主动管理型基金', '投资', '跟踪指数',
               '是否新发', '重仓', '无']
    return columns


def get_cond_op_dict():
    cond_op_dict = {'>': 0, '<': 1, '==': 2, '!=': 3, 'like': 4, '>=': 5, '<=': 6, 'none': 7}
    return cond_op_dict


def get_conn_op_dict():
    conn_op_dict = {'none': 0, 'and': 1, 'or': 2}
    return conn_op_dict

In [7]:
#dataset
# sql查询条件
class Conditions(object):
    def __init__(self, cond_col=None, cond_op=None, cond_value=None):
        """
        example [cond_col cond_op cond_value]
        :param cond_col: sql的查询条件列
        :param cond_op: sql的查询条件操作符
        :param cond_value: sql的查询条件值
        """
        self.cond_col = cond_col
        self.cond_op = cond_op
        self.cond_value = cond_value


# label
class Label(object):
    def __init__(self, label_sel_col=None, label_conn_op=None, label_cond: List[Conditions] = None):
        """
        example [select label_sel_col from table where label_condition[0] label_conn_op label_condition[1] label_conn_op ...]
        :param label_sel_col:
        :param label_conn_op:
        :param label_cond:
        """
        self.label_sel_col = label_sel_col
        self.label_conn_op = label_conn_op
        self.label_cond = label_cond


class InputFeatures(object):
    def __init__(self, model_path=None, question_length=128, max_length=512, input_ids=None, attention_mask=None,
                 token_type_ids=None, cls_idx=None, label=None):
        if model_path is not None:
            self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.question_length = question_length
        self.max_length = max_length
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.cls_idx = cls_idx
        self.label = label

    def encode_columns(self, columns: List):
        """
        列编码
        :param columns: 列
        :return: 编码后的列，及序列号（用于列与列之间的区分）
        """
        columns_encode = []
        segment_ids = []
        i = 1
        for column in columns:
            encod = self.tokenizer.encode(column)
            seg = [i] * len(encod)
            columns_encode.extend(encod)
            segment_ids.extend(seg)
            i = 1 - i  # 切换 0 和 1
        return torch.tensor(columns_encode), torch.tensor(segment_ids)

    def get_cls_idx(self, columns):
        """
        获取列标记符的位置
        :param columns: 列
        :return:
        """
        cls_idx = []
        start = self.question_length
        for i in range(len(columns)):
            cls_idx.append(int(start))
            # 加上特殊标记的长度（例如 [CLS] 和 [SEP]）
            start += len(columns[i]) + 2
        return cls_idx

    def encode_question_with_columns(self, que_length, max_length, question, columns_encode, columns_segment_id):
        """
        编码
        :param que_length: 问题长度
        :param max_length: text长度
        :param question:  问题
        :param columns_encode:  编码的列
        :param columns_segment_id 编码的列的序列
        :return: 编码后的text
        """

        # 编码问题，需要填充，否则会出现长度不一致异常
        question_encoding = self.tokenizer.encode(question, add_special_tokens=True, padding='max_length',
                                                  max_length=que_length, truncation=True)

        # 合并编码后的张量，保证张量类型(dtype)为int或long, bert的embedding的要求
        input_ids = torch.cat([torch.tensor(question_encoding), columns_encode], dim=0)
        token_type_ids = torch.cat([torch.zeros(que_length, dtype=torch.long), columns_segment_id], dim=0)
        padding_length = max_length - len(input_ids)
        attention_mask = torch.cat([torch.ones(len(input_ids)), torch.zeros(padding_length)], dim=0)
        input_ids = torch.cat([input_ids, torch.zeros(padding_length, dtype=torch.long)], dim=0)
        token_type_ids = torch.cat([token_type_ids, torch.zeros(padding_length, dtype=torch.long)], dim=0)

        return input_ids, attention_mask, token_type_ids

    def list_features(self, datas):
        """
        输入特征
        :param datas: 数据
        :param que_length: 问题长度
        :param max_length: text长度
        :return: 特征信息
        """
        list_features = []
        columns = get_columns()
        cls_idx = self.get_cls_idx(columns)
        columns_encode, columns_segment_id = self.encode_columns(get_columns())
        for data in datas:
            # if contain label data
            label = None
            if len(data) > 1:
                label = Label(label_sel_col=[data[1]], label_conn_op=[data[3]],
                              label_cond=[Conditions(cond[0], cond[1], cond[2:4]) for cond in data[2]] if data[
                                                                                                              2] is not None else None)
            question = data[0]
            # 编码(question+columns)
            input_ids, attention_mask, token_type_ids = self.encode_question_with_columns(self.question_length,
                                                                                          self.max_length,
                                                                                          question, columns_encode,
                                                                                          columns_segment_id)
            list_features.append(
                InputFeatures(question_length=self.question_length, max_length=self.max_length, input_ids=input_ids,
                              attention_mask=attention_mask, token_type_ids=token_type_ids, cls_idx=cls_idx,
                              label=label))
        return list_features


class Dataset(torch.utils.data.Dataset):
    def __init__(self, features: List[InputFeatures]):
        self.features = features

    def __len__(self):
        return len(self.features)

    def __getitem__(self, item):
        feature = self.features[item]
        input_ids = np.array(feature.input_ids)
        attention_mask = np.array(feature.attention_mask)
        token_type_ids = np.array(feature.token_type_ids)
        cls_idx = np.array(feature.cls_idx)
        if feature.label is not None:
            label: Label = feature.label
            label_sel_col = np.array(label.label_sel_col)
            label_conn_op = np.array(label.label_conn_op)
            label_cond = np.array(label.label_cond)
            if label_cond.any() is None or label_cond.size == 0:
                # 初始化一维数组，保证纬度一致
                # 52对应‘无’这一列
                label_cond_col = np.array([len(get_columns()) - 1], dtype=np.int32)
                # 7对应‘none’操作符
                label_cond_op = np.array([len(get_cond_op_dict()) - 1], dtype=np.int32)
                label_cond_value = np.array([[0, 0]], dtype=np.int32)
            else:
                # 转化成一维数组，保证纬度一致
                label_cond_col = np.array([[item.cond_col] for item in label_cond]).ravel()[:np.prod(1)].reshape(
                    1)
                label_cond_op = np.array([[item.cond_op] for item in label_cond]).ravel()[:np.prod(1)].reshape(
                    1)
                label_cond_value = np.array([item.cond_value for item in label_cond]).ravel()[:np.prod((1, 2))].reshape(
                    (1, 2))
            # 打印样本信息
#             print(f"Sample {item}:")
#             print(f"input_ids shape: {input_ids.shape}")
#             print(f"attention_mask shape: {attention_mask.shape}")
#             print(f"token_type_ids shape: {token_type_ids.shape}")
#             print(f"label_sel_col shape: {label_sel_col.shape}")
#             print(f"label_conn_op shape: {label_conn_op.shape}")
#             print(f"label_cond_col shape: {label_cond_col.shape}")
#             print(f"label_cond_op shape: {label_cond_op.shape}")
#             print(f"label_cond_value shape: {label_cond_value.shape}")
            return input_ids, attention_mask, token_type_ids, cls_idx, label_sel_col, label_conn_op, label_cond_col, label_cond_op, label_cond_value
        else:
            return input_ids, attention_mask, token_type_ids, cls_idx



In [4]:
#model
class ColClassifierModel(nn.Module):
    def __init__(self, model_path, hidden_size, cond_op_length, dropout=0.5):
        super(ColClassifierModel, self).__init__()
        self.bert = BertModel.from_pretrained(model_path)
        self.dropout = nn.Dropout(dropout)
        # todo 可以不止一列
        self.sel_col_classifier = nn.Linear(hidden_size, 1)
        # todo 条件不止一列
        self.cond_col_classifier = nn.Linear(hidden_size, 1)
        # out classes需要纬度必须大于label中size(classes)，否则会出现Assertion `t >= 0 && t < n_classes` failed.
        self.cond_op_classifier = nn.Linear(hidden_size, cond_op_length)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, cls_idx=None):
        # 输出最后一层隐藏状态以及池化层
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        dropout_output = self.dropout(outputs.pooler_output)
        dropout_hidden_state = self.dropout(outputs.last_hidden_state)

        """
        提取列特征信息，从dim=1即第二维中（列标记符号索引所在纬度）提取dropout_hidden_state对应该纬度的信息。
        前提需要将cls_idx张量shape扩展成与dropout_hidden_state一致
        """
        # cls_cols = dropout_hidden_state.gather(dim=1, index=cls_idx.unsqueeze(-1).expand(
        #     dropout_hidden_state.shape[0], -1, dropout_hidden_state.shape[-1]))
        # 简化写法
        cls_cols = dropout_hidden_state[:, cls_idx[0], :]

        out_sel_col = self.sel_col_classifier(cls_cols).squeeze(-1)
        out_cond_col = self.cond_col_classifier(cls_cols).squeeze(-1)

        out_cond_op = self.cond_op_classifier(dropout_output)

        return out_sel_col, out_cond_col, out_cond_op


class ValueClassifierModel(nn.Module):
    def __init__(self, model_path, hidden_size, conn_op_length, question_length, cond_value_length=2, dropout=0.5):
        super(ValueClassifierModel, self).__init__()
        self.bert = BertModel.from_pretrained(model_path)
        self.dropout = nn.Dropout(dropout)
        # todo 最大条件值数量
        self.cond_values_classifier = nn.Linear(hidden_size, cond_value_length)
        self.conn_op_classifier = nn.Linear(hidden_size, conn_op_length)
        self.question_length = question_length

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        # 输出最后一层隐藏状态以及池化层
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        dropout_output = self.dropout(outputs.pooler_output)
        dropout_hidden_state = self.dropout(outputs.last_hidden_state)

        out_conn_op = self.conn_op_classifier(dropout_output)

        # 提取问题特征信息
        cond_values = dropout_hidden_state[:, 1:int(self.question_length), :]

        out_cond_values = self.cond_values_classifier(cond_values)

        return out_conn_op, out_cond_values

In [None]:
#train
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
def train(model: ColClassifierModel or ValueClassifierModel, model_save_path, train_dataset: Dataset,
          val_dataset: Dataset, batch_size, lr, epochs):
    # DataLoader根据batch_size获取数据，训练时选择打乱样本
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    # 是否使用gpu
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optim = Adam(model.parameters(), lr=lr)
    if use_cuda:
        model = model.to(device)
        criterion = criterion.to(device)
    best_val_avg_acc = 0
    for epoch in range(epochs):
        total_loss_train = 0
        model.train()
        # 训练进度
        for input_ids, attention_mask, token_type_ids, cls_idx, label_sel_col, label_conn_op, label_cond_col, label_cond_op, label_cond_value in tqdm(
                train_loader):
            # model要求输入的矩阵(hidden_size,sequence_size),需要把第二纬度去除.squeeze(1)
            input_ids = input_ids.squeeze(1).to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.squeeze(1).to(device)
            if type(model) is ColClassifierModel:
                label_sel_col = label_sel_col.squeeze(-1).to(device)
                label_cond_col = label_cond_col.squeeze(-1).to(device)
                label_cond_op = label_cond_op.squeeze(-1).to(device)
                # 模型输出
                out_sel_col, out_cond_col, out_cond_op = model(input_ids, attention_mask, token_type_ids, cls_idx)
                # 计算损失
                loss_sel_col = criterion(out_sel_col, label_sel_col)
                loss_cond_col = criterion(out_cond_col, label_cond_col)
                loss_cond_op = criterion(out_cond_op, label_cond_op)
                # todo 损失比例
                total_loss_train = loss_sel_col + loss_cond_col + loss_cond_op
            if type(model) is ValueClassifierModel:
                label_conn_op = label_conn_op.squeeze(-1).to(device)
                label_cond_values = label_cond_value.squeeze(1).to(device)
                # 模型输出
                out_conn_op, out_cond_values = model(input_ids, attention_mask, token_type_ids)
                # 计算损失
                lost_conn_op = criterion(out_conn_op, label_conn_op)
                lost_cond_values = criterion(out_cond_values, label_cond_values)
                # todo 损失比例
                total_loss_train = lost_conn_op + lost_cond_values
            # 模型更新
            model.zero_grad()
            optim.zero_grad()
            total_loss_train.backward()
            optim.step()
        # 模型验证
        val_avg_acc = 0
        out_all_sel_col = []
        out_all_cond_col = []
        out_all_cond_op = []
        label_all_sel_col = []
        label_all_cond_col = []
        label_all_cond_op = []
        out_all_conn_op = []
        out_all_cond_values = []
        label_all_conn_op = []
        label_all_cond_values = []
        # 验证无需梯度计算
        model.eval()
        with torch.no_grad():
            # 使用当前epoch训练好的模型验证
            for input_ids, attention_mask, token_type_ids, cls_idx, label_sel_col, label_conn_op, label_cond_col, label_cond_op, label_cond_value in val_loader:
                input_ids = input_ids.squeeze(1).to(device)
                attention_mask = attention_mask.to(device)
                token_type_ids = token_type_ids.squeeze(1).to(device)
                if type(model) is ColClassifierModel:
                    label_sel_col = label_sel_col.squeeze(-1).to(device)
                    label_cond_col = label_cond_col.squeeze(-1).to(device)
                    label_cond_op = label_cond_op.squeeze(-1).to(device)
                    # 模型输出
                    out_sel_col, out_cond_col, out_cond_op = model(input_ids, attention_mask, token_type_ids, cls_idx)
                    out_all_sel_col.append(out_sel_col.argmax(dim=1).cpu().numpy())
                    out_all_cond_col.append(out_cond_col.argmax(dim=1).cpu().numpy())
                    out_all_cond_op.append(out_cond_op.argmax(dim=1).cpu().numpy())
                    label_all_sel_col.append(label_sel_col.cpu().numpy())
                    label_all_cond_col.append(label_cond_col.cpu().numpy())
                    label_all_cond_op.append(label_cond_op.cpu().numpy())
                if type(model) is ValueClassifierModel:
                    label_conn_op = label_conn_op.squeeze(-1).to(device)
                    # reshape(-1)需要转成一维数组才能计算准确率
                    label_cond_values = label_cond_value.squeeze(1).to(device).reshape(-1)
                    # 模型输出
                    out_conn_op, out_cond_values = model(input_ids, attention_mask, token_type_ids)
                    out_all_conn_op.append(out_conn_op.argmax(dim=1).cpu().numpy())
                    out_all_cond_values.append(out_cond_values.argmax(dim=1).reshape(-1).cpu().numpy())
                    label_all_conn_op.append(label_conn_op.cpu().numpy())
                    label_all_cond_values.append(label_cond_values.cpu().numpy())

        if type(model) is ColClassifierModel:
            val_sel_col_acc = metrics.accuracy_score(np.concatenate(out_all_sel_col), np.concatenate(label_all_sel_col))
            val_cond_col_acc = metrics.accuracy_score(np.concatenate(out_all_cond_col),
                                                      np.concatenate(label_all_cond_col))
            val_cond_op_acc = metrics.accuracy_score(np.concatenate(out_all_cond_op), np.concatenate(label_all_cond_op))
            # todo 准确率计算逻辑
            val_avg_acc = (val_sel_col_acc + val_cond_col_acc + val_cond_op_acc) / 3
        if type(model) is ValueClassifierModel:
            val_conn_op_acc = metrics.accuracy_score(np.concatenate(out_all_conn_op), np.concatenate(label_all_conn_op))
            val_cond_values_acc = metrics.accuracy_score(np.concatenate(out_all_cond_values),
                                                         np.concatenate(label_all_cond_values))
            val_avg_acc = (val_conn_op_acc + val_cond_values_acc) / 2
        # save model
        if val_avg_acc > best_val_avg_acc:
            best_val_avg_acc = val_avg_acc
            torch.save(model.state_dict(), model_save_path)
            print(f'''best model | Val Accuracy: {best_val_avg_acc: .3f}''')
        print(
            f'''Epochs: {epoch + 1} 
              | Train Loss: {total_loss_train: .3f} ]
              | Val Accuracy: {val_avg_acc: .3f}''')


def test(model, model_save_path, test_dataset, batch_size):
    # 加载最佳模型权重
    model.load_state_dict(torch.load(model_save_path))
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    if use_cuda:
        model = model.to(device)

    total_acc_test = 0
    model.eval()
    with torch.no_grad():
        for test_input, test_label in test_dataloader:
            test_label = test_label.to(device)
            attention_mask = test_input['attention_mask'].to(device)
            input_ids = test_input['input_ids'].squeeze(1).to(device)
            output = model(input_ids, attention_mask)
            acc = (output.argmax(dim=1) == test_label).sum().item()
            total_acc_test += acc
    print(f'Test Accuracy: {total_acc_test / len(test_dataset): .3f}')


if __name__ == '__main__':
    hidden_size = 768
    batch_size = 64
    learn_rate = 2e-5
    epochs = 3
    question_length = 128
    max_length = 512
    train_data_path = '/kaggle/input/bert-nl2sql-datas/waic_nl2sql_train.jsonl'
    pretrain_model_path = '/kaggle/input/bert-nl2sql-chinese-model-hgd'
    save_column_model_path = '/kaggle/working/classifier-column-model.pkl'
    save_value_model_path = '/kaggle/working/classifier-value-model.pkl'
    # 加载数据
    label_datas = read_train_datas(train_data_path)
    # 提取特征数据
    list_input_features = InputFeatures(pretrain_model_path, question_length, max_length).list_features(label_datas)
    # 初始化dataset
    dateset = Dataset(list_input_features)
    # 创建模型
    colModel = ColClassifierModel(pretrain_model_path, hidden_size, len(get_cond_op_dict()))
    valueModel = ValueClassifierModel(pretrain_model_path, hidden_size, len(get_conn_op_dict()), question_length)
    # 分割数据集
    total_size = len(label_datas)
    train_size = int(0.8 * total_size)
    val_size = int(0.1 * total_size)
    test_size = total_size - train_size - val_size
    # 分割数据集
    train_dataset, val_dataset, test_dataset = random_split(dateset, [train_size, val_size, test_size])
    print('train column model begin')
    train(colModel, save_column_model_path, train_dataset, val_dataset, batch_size, learn_rate, epochs)
    print('train column model finish')
    print('train value model begin')
    train(valueModel, save_value_model_path, train_dataset, val_dataset, batch_size, learn_rate,
          epochs)
    print('train value model finish')

train column model begin


  2%|▏         | 18/982 [00:17<13:25,  1.20it/s]

In [9]:
#predict
def predict(questions, pretrain_model_path, column_model_path, value_model_path, hidden_size, batch_size,
            question_length, max_length):
    # 创建模型
    col_model = ColClassifierModel(pretrain_model_path, hidden_size, len(get_cond_op_dict()))
    value_model = ValueClassifierModel(pretrain_model_path, hidden_size, len(get_conn_op_dict()), question_length)
    # 提取特征数据（不含label的数据）
    input_features = InputFeatures(pretrain_model_path, question_length, max_length).list_features(questions)
    dataset = Dataset(input_features)
    # 是否使用gpu
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    if use_cuda:
        col_model = col_model.to(device)
        col_model.load_state_dict(torch.load(column_model_path, map_location=torch.device(device)))
        value_model.load_state_dict(torch.load(column_model_path, map_location=torch.device(device)))
    # 预测不用打乱顺序shuffle=False
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # 预测
    pre_all_sel_col = []
    pre_all_cond_col = []
    pre_all_cond_op = []
    pre_all_conn_op = []
    pre_all_cond_values = []
    for input_ids, attention_mask, token_type_ids, cls_idx in tqdm(dataloader):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        out_sel_col, out_cond_col, out_cond_op = col_model(input_ids, attention_mask, token_type_ids, cls_idx)

        # 取预测结果最大值，torch.argmax找到指定纬度最大值所对应的索引（是索引，不是值）
        pre_sel_col = torch.argmax(out_sel_col.data, dim=1).cpu().numpy()
        pre_cond_col = torch.argmax(out_cond_col.data, dim=1).cpu().numpy()
        pre_cond_op = torch.argmax(out_cond_op.data, dim=1).cpu().numpy()

        pre_all_sel_col.append(pre_sel_col)
        pre_all_cond_col.append(pre_cond_col)
        pre_all_cond_op.append(pre_cond_op)

    for input_ids, attention_mask, token_type_ids, _ in tqdm(dataloader):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        out_conn_op, out_cond_values = value_model(input_ids, attention_mask, token_type_ids)

        # 取预测结果最大值，torch.argmax找到指定维度最大值所对应的索引（是索引，不是值）
        pre_conn_op = torch.argmax(out_conn_op.data, dim=1).cpu().numpy()
        pre_cond_values = torch.argmax(out_cond_values.data, dim=1).cpu().numpy()

        pre_all_conn_op.append(pre_conn_op)
        pre_all_cond_values.append(pre_cond_values)

    print("pre_all_sel_col data:", pre_all_sel_col)
    print("pre_all_cond_col data:", pre_all_cond_col)
    print("pre_all_cond_op data:", pre_all_cond_op)
    print("pre_all_cond_op data:", pre_all_conn_op)
    print("pre_all_cond_op data:", pre_all_cond_values)


if __name__ == '__main__':
    hidden_size = 768
    batch_size = 24
    question_length = 128
    max_length = 512
    predict_question_path = '/kaggle/input/bert-nl2sql-predict-datas/waic_nl2sql_testa_public.jsonl'
    pretrain_model_path = '/kaggle/input/bert-nl2sql-chinese-model-hgd'
    column_model_path = '/kaggle/input/bert-nl2sql-result-model/classifier-column-model.pkl'
    questions = read_predict_datas(predict_question_path)
    predict(questions, pretrain_model_path, column_model_path, hidden_size, batch_size, question_length, max_length)

100%|██████████| 497/497 [00:53<00:00,  9.26it/s]


pre_all_sel_col data: [array([13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
       13, 13, 13, 13, 13, 13, 13]), array([13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
       13, 13, 13, 13, 13, 13, 13]), array([13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
       13, 13, 13, 13, 13, 13, 13]), array([13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 25, 25, 25,
       25, 25, 25, 25, 25, 25, 25]), array([25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
       25, 25, 25, 25, 25, 25, 25]), array([25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
       25, 25, 25, 25, 25, 25, 25]), array([25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
       25, 25, 25, 25, 25, 25, 25]), array([25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 10, 10, 10, 10, 10, 10,
       10, 10, 10, 10, 10, 10, 10]), array([10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
      