In [1]:
#import
import numpy as np
import pandas as pd
import torch.utils.data
import torch.cuda
import json
from transformers import BertTokenizer,BertModel
from torch import nn
from sklearn import metrics
from torch.optim.adam import Adam
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from typing import List



In [2]:
#util
def read_train_datas(path, question_length, columns):
    """
    :param path 数据路径
    :param question_length 问题长度
    :param columns 列
    :return: [[question, agg, conn_op, cond_ops, cond_vals],...], cond_vals:[[val_start_idx,val_end_idx],...]
    """
    column_length = len(columns)
    with open(path, 'r', encoding='utf-8') as f:
        data_list = []
        for line in f:
            item = json.loads(line)
            # question
            question = item['question']
            # agg
            sel = item['sql']['sel']
            agg_op = item['sql']['agg']
            agg = [get_agg_dict()['none']] * column_length
            for i in range(len(sel)):
                sel_col_item = sel[i]
                agg_op_item = agg_op[i]
                agg[sel_col_item] = agg_op_item
            # conn_op
            conn_op = item['sql']['cond_conn_op']
            # cond_cols & cond_ops & cond_vals
            # +1 默认初始化为不存在的列, question_length需要大于column_length
            cond_cols = [column_length + 1] * question_length
            cond_ops = [get_cond_op_dict()['none']] * question_length
            cond_vals = [0] * question_length
            if item['sql'].get('conds') is not None:
                conds = item['sql']['conds']
                for idx, cond in enumerate(conds):
                    cond_cols[idx] = cond[0]
                    cond_ops[idx] = cond[1]
                    value = cond[2]
                    cond_vals = fill_value_start_end(cond_vals, question, value, idx)
            data_list.append([question, agg, conn_op, cond_cols, cond_ops, cond_vals])
    return data_list


def read_predict_datas(path):
    """
    :param path: 预测数据路径
    :return: 预测数据
    """
    questions = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            item = json.loads(line)
            question = item['question']
            questions.append([question])
    return questions


def get_columns(table_path):
    columns = pd.read_table(table_path, header=2)
    return columns.columns.__array__()


def get_cond_op_dict():
    cond_op_dict = {'>': 0, '<': 1, '==': 2, '!=': 3, 'like': 4, '>=': 5, '<=': 6, 'none': 7}
    return cond_op_dict


def get_conn_op_dict():
    conn_op_dict = {'none': 0, 'and': 1, 'or': 2}
    return conn_op_dict


def get_agg_dict():
    agg_dict = {'': 0, 'AVG': 1, 'MAX': 2, 'MIN': 3, 'COUNT': 4, 'SUM': 5, 'none': 6}
    return agg_dict


def get_key(dict, value):
    """
    根据字典的value获取key
    :param dict: 字典
    :param value: 值
    :return: key
    """
    return [k for k, v in dict.items() if v == value]


def fill_value_start_end(cond_vals, question, value, idx):
    """
    :param cond_vals 待填充待值
    :param question 问题
    :param value 待匹配待值
    :param 下标
    fill [1] by the value in the question
    结果类似[0,0,0,0,1,1,1,1,0,0,0,2,2,2,2,3,3,3,0,0,0,4,4,4,0,0,0,0,0]
    """
    question_length = len(question)
    value_length = len(value)
    for i in range(question_length - value_length + 1):
        if question[i:value_length + i] == value:
            cond_vals[i: value_length + i] = [idx + 1] * value_length
    return cond_vals


def count_values(cond_vals):
    """
   cond_vals的值如[0,0,0,0,1,1,1,1,0,0,0,2,2,2,2,3,3,3,0,0,0,4,4,4,0,0,0,0,0]所示
   统计出现>0的数量，连续的>0只统计一次
   """
    count = 0
    pre_value = None
    for idx, val in enumerate(cond_vals):
        if idx > 0:
            pre_value = cond_vals[idx - 1]
        if val > 0 and val != pre_value:
            count = count + 1
    return count


def get_values_name(question, cond_vals):
    """
    cond_vals的值如[0,0,0,0,1,1,1,1,0,0,0,2,2,2,2,3,3,3,0,0,0,4,4,4,0,0,0,0,0]所示
    根据cond_vals中为1的值找到question对应下标的内容
    返回找到的内容列表，连续为1的内容作为返回列表的一个元素
    """
    result = []
    start_of_segment = None
    for idx, current_value in enumerate(cond_vals):
        previous_value = cond_vals[idx - 1] if idx > 0 else None
        # 检测新的连续段的开始
        if current_value > 0 and current_value != previous_value and start_of_segment is None:
            start_of_segment = idx
        # 当段结束时，添加到结果并更新段的起始位置
        elif current_value != previous_value and start_of_segment is not None:
            segment = question[start_of_segment:idx]
            result.append(segment)
            start_of_segment = None if current_value == 0 else idx
    return result


In [3]:
#dataset
# label
class Label(object):
    def __init__(self, label_agg: List = None, label_conn_op=None, label_cond_cols: List = None,
                 label_cond_ops: List = None, label_cond_vals: List = None):
        """
        训练标签信息
        :param label_agg: 聚合函数
        :param label_conn_op: 连接操作符
        :param label_cond_cols: 条件操列
        :param label_cond_ops: 条件操作符
        :param label_cond_vals: 条件值
        """
        self.label_agg = label_agg
        self.label_conn_op = label_conn_op
        self.label_cond_ops = label_cond_ops
        self.label_cond_cols = label_cond_cols
        self.label_cond_vals = label_cond_vals


class InputFeatures(object):
    def __init__(self, model_path=None, question_length=128, max_length=512, input_ids=None, attention_mask=None,
                 token_type_ids=None, cls_idx=None, label: Label = None):
        if model_path is not None:
            self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.question_length = question_length
        self.max_length = max_length
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.cls_idx = cls_idx
        self.label = label

    def encode_expression(self, expressions: List):
        """
        表达式编码
        :param expressions: 表达式（列名或条件表达式）
        :return: 编码后的列，及序列号（用于列与列之间的区分）
        """
        encodings = self.tokenizer.batch_encode_plus(expressions)
        expressions_encode = encodings["input_ids"]
        segment_ids = encodings["token_type_ids"]
        segment_ids = [[elem if j % 2 == 0 else 1 for elem in row] for j, row in enumerate(segment_ids)]
        expressions_encode = [item for sublist in expressions_encode for item in sublist]
        segment_ids = [item for sublist in segment_ids for item in sublist]
        return torch.tensor(expressions_encode), torch.tensor(segment_ids)

    def get_cls_idx(self, expressions):
        """
        获取表达式标记符的位置
        :param expressions: 表达式
        :return:
        """
        cls_idx = []
        start = self.question_length
        for i in range(len(expressions)):
            cls_idx.append(int(start))
            # 加上特殊标记的长度（例如 [CLS] 和 [SEP]）
            start += len(expressions[i]) + 2
        return cls_idx

    def encode_question_with_expressions(self, que_length, max_length, question, expressions_encode,
                                         expressions_segment_id):
        """
        编码
        :param que_length: 问题长度
        :param max_length: text长度
        :param question:  问题
        :param expressions_encode:  编码的列
        :param expressions_segment_id 编码的列的序列
        :return: 编码后的text
        """

        # 编码问题，需要填充，否则会出现长度不一致异常
        question_encoding = self.tokenizer.encode(question, add_special_tokens=True, padding='max_length',
                                                  max_length=que_length, truncation=True)

        # 合并编码后的张量，保证张量类型(dtype)为int或long, bert的embedding的要求
        input_ids = torch.cat([torch.tensor(question_encoding), expressions_encode], dim=0)
        token_type_ids = torch.cat([torch.zeros(que_length, dtype=torch.long), expressions_segment_id], dim=0)
        padding_length = max_length - len(input_ids)
        attention_mask = torch.cat([torch.ones(len(input_ids)), torch.zeros(padding_length)], dim=0)
        input_ids = torch.cat([input_ids, torch.zeros(padding_length, dtype=torch.long)], dim=0)
        token_type_ids = torch.cat([token_type_ids, torch.zeros(padding_length, dtype=torch.long)], dim=0)

        return input_ids, attention_mask, token_type_ids

    def list_features(self, columns, datas):
        """
        输入特征
        :param columns 列
        :param datas: 数据
        :return: 特征信息
        """
        list_features = []
        cls_idx = self.get_cls_idx(columns)
        expressions_encode, expressions_segment_id = self.encode_expression(columns)
        for data in datas:
            question = data[0]
            # if contain label data
            label = None
            if len(data) > 1:
                label = Label(label_agg=data[1], label_conn_op=data[2], label_cond_cols=data[3], label_cond_ops=data[4],
                              label_cond_vals=data[5])
            # 编码(question+expressions)
            input_ids, attention_mask, token_type_ids = self.encode_question_with_expressions(self.question_length,
                                                                                              self.max_length,
                                                                                              question,
                                                                                              expressions_encode,
                                                                                              expressions_segment_id)
            list_features.append(
                InputFeatures(question_length=self.question_length, max_length=self.max_length, input_ids=input_ids,
                              attention_mask=attention_mask, token_type_ids=token_type_ids, cls_idx=cls_idx,
                              label=label))
        return list_features


class Dataset(torch.utils.data.Dataset):
    def __init__(self, features: List[InputFeatures]):
        self.features = features

    def __len__(self):
        return len(self.features)

    def __getitem__(self, item):
        feature = self.features[item]
        input_ids = np.array(feature.input_ids)
        attention_mask = np.array(feature.attention_mask)
        token_type_ids = np.array(feature.token_type_ids)
        cls_idx = np.array(feature.cls_idx)
        if feature.label is not None:
            label: Label = feature.label
            label_agg = np.array(label.label_agg)
            label_conn_op = np.array(label.label_conn_op)
            label_cond_cols = np.array(label.label_cond_cols)
            label_cond_ops = np.array(label.label_cond_ops)
            label_cond_vals = np.array([np.array(val) for val in label.label_cond_vals])
            return input_ids, attention_mask, token_type_ids, cls_idx, label_agg, label_conn_op, label_cond_cols, label_cond_ops, label_cond_vals
        else:
            return input_ids, attention_mask, token_type_ids, cls_idx



In [4]:
#model
class ColClassifierModel(nn.Module):
    def __init__(self, model_path, hidden_size, agg_length, conn_op_length, dropout=0.5):
        super(ColClassifierModel, self).__init__()
        self.bert = BertModel.from_pretrained(model_path)
        self.dropout = nn.Dropout(dropout)
        # out classes需要纬度必须大于label中size(classes)，否则会出现Assertion `t >= 0 && t < n_classes` failed.
        self.agg_classifier = nn.Linear(hidden_size, agg_length)
        self.conn_op_classifier = nn.Linear(hidden_size, conn_op_length)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, cls_idx=None):
        # 输出最后一层隐藏状态以及池化层
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        dropout_output = self.dropout(outputs.pooler_output)
        dropout_hidden_state = self.dropout(outputs.last_hidden_state)

        """
        提取列特征信息，从dim=1即第二维中（列标记符号索引所在纬度）提取dropout_hidden_state对应该纬度的信息。
        前提需要将cls_idx张量shape扩展成与dropout_hidden_state一致
        """
        # cls_cols = dropout_hidden_state.gather(dim=1, index=cls_idx.unsqueeze(-1).expand(
        #     dropout_hidden_state.shape[0], -1, dropout_hidden_state.shape[-1]))
        # 简化写法
        cls_cols = dropout_hidden_state[:, cls_idx[0], :]

        out_agg = self.agg_classifier(cls_cols)

        out_conn_op = self.conn_op_classifier(dropout_output)

        return out_agg, out_conn_op


class CondClassifierModel(nn.Module):
    def __init__(self, model_path, hidden_size, question_length, dropout=0.5):
        super(CondClassifierModel, self).__init__()
        self.bert = BertModel.from_pretrained(model_path)
        self.dropout = nn.Dropout(dropout)
        # question_length为条件最多个数
        self.cond_cols_classifier = nn.Linear(hidden_size, question_length)
        self.cond_ops_classifier = nn.Linear(hidden_size, question_length)
        self.cond_vals_classifier = nn.Linear(hidden_size, question_length)
        self.cond_count_classifier = nn.Linear(hidden_size, question_length)
        self.question_length = question_length

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        # 输出最后一层隐藏状态
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        dropout_output = self.dropout(outputs.pooler_output)
        hidden_state = outputs.last_hidden_state

        out_cond_count = self.cond_count_classifier(dropout_output)

        # 提取问题特征信息
        cond_values = hidden_state[:, 1:self.question_length + 1, :]

        out_cond_cols = self.cond_cols_classifier(cond_values)
        out_cond_ops = self.cond_ops_classifier(cond_values)
        out_cond_vals = self.cond_vals_classifier(cond_values)

        return out_cond_cols, out_cond_ops, out_cond_vals, out_cond_count


In [5]:
#train
def train(model: ColClassifierModel or CondClassifierModel, model_save_path, train_dataset: Dataset,
          val_dataset: Dataset, batch_size, lr, epochs):
    # DataLoader根据batch_size获取数据，训练时选择打乱样本
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    # 是否使用gpu
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optim = Adam(model.parameters(), lr=lr)
    if use_cuda:
        model = model.to(device)
        criterion = criterion.to(device)
    best_val_avg_acc = 0
    for epoch in range(epochs):
        total_loss_train = 0
        model.train()
        # 训练进度
        for input_ids, attention_mask, token_type_ids, cls_idx, label_agg, label_conn_op, label_cond_cols, label_cond_ops, label_cond_vals in tqdm(
                train_loader):
            # model要求输入的矩阵(hidden_size,sequence_size),需要把第二纬度去除.squeeze(1)
            input_ids = input_ids.squeeze(1).to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.squeeze(1).to(device)
            if type(model) is ColClassifierModel:
                # reshape(-1)合并一二纬度
                label_agg = label_agg.to(device).reshape(-1)
                label_conn_op = label_conn_op.to(device)
                # 模型输出
                out_agg, out_conn_op = model(input_ids, attention_mask, token_type_ids, cls_idx)
                out_agg = out_agg.to(device).reshape(-1, out_agg.size(2))
                out_conn_op = out_conn_op.to(device)
                # 计算损失
                loss_agg = criterion(out_agg, label_agg)
                loss_conn_op = criterion(out_conn_op, label_conn_op)
                # 损失比例
                total_loss_train = loss_agg + loss_conn_op

            if type(model) is CondClassifierModel:
                label_cond_cols = label_cond_cols.to(device).reshape(-1)
                label_cond_ops = label_cond_ops.to(device).reshape(-1)
                label_cond_vals = label_cond_vals.to(device)
                label_cond_count = [count_values(label_cond_val) for label_cond_val in label_cond_vals]
                label_cond_count = torch.tensor(label_cond_count).reshape(-1).to(device)
                label_cond_vals = label_cond_vals.reshape(-1)
                # 模型输出
                out_cond_cols, out_cond_ops, out_cond_vals, out_cond_count = model(input_ids, attention_mask,
                                                                                   token_type_ids)
                # 计算损失
                out_cond_cols = out_cond_cols.reshape(-1, out_cond_cols.size(2))
                out_cond_ops = out_cond_ops.reshape(-1, out_cond_ops.size(2))
                out_cond_vals = out_cond_vals.reshape(-1, out_cond_vals.size(2))
                lost_cond_cols = criterion(out_cond_cols, label_cond_cols)
                lost_cond_ops = criterion(out_cond_ops, label_cond_ops)
                lost_cond_vals = criterion(out_cond_vals, label_cond_vals)
                lost_cond_count = criterion(out_cond_count, label_cond_count)
                total_loss_train = ((lost_cond_cols + lost_cond_vals + lost_cond_ops) * 0.1 + lost_cond_count * 0.9)

            # 模型更新
            model.zero_grad()
            optim.zero_grad()
            total_loss_train.backward()
            optim.step()
        # 模型验证
        val_avg_acc = 0
        out_all_agg = []
        out_all_conn_op = []
        out_all_cond_cols = []
        out_all_cond_ops = []
        out_all_cond_vals = []
        out_all_cond_count = []
        label_all_agg = []
        label_all_conn_op = []
        label_all_cond_cols = []
        label_all_cond_ops = []
        label_all_cond_vals = []
        label_all_cond_count = []
        # 验证无需梯度计算
        model.eval()
        with torch.no_grad():
            # 使用当前epoch训练好的模型验证
            for input_ids, attention_mask, token_type_ids, cls_idx, label_agg, label_conn_op, label_cond_cols, label_cond_ops, label_cond_vals in val_loader:
                input_ids = input_ids.squeeze(1).to(device)
                attention_mask = attention_mask.to(device)
                token_type_ids = token_type_ids.squeeze(1).to(device)
                if type(model) is ColClassifierModel:
                    label_agg = label_agg.to(device).reshape(-1)
                    label_conn_op = label_conn_op.to(device)
                    # 模型输出
                    out_agg, out_conn_op = model(input_ids, attention_mask, token_type_ids, cls_idx)
                    out_agg = out_agg.argmax(dim=2).to(device).reshape(-1)
                    out_conn_op = out_conn_op.argmax(dim=1).to(device)
                    out_all_agg.append(out_agg.cpu().numpy())
                    out_all_conn_op.append(out_conn_op.cpu().numpy())
                    label_all_agg.append(label_agg.cpu().numpy())
                    label_all_conn_op.append(label_conn_op.cpu().numpy())
                if type(model) is CondClassifierModel:
                    label_cond_cols = label_cond_cols.to(device).reshape(-1)
                    label_cond_ops = label_cond_ops.to(device).reshape(-1)
                    label_cond_vals = label_cond_vals.to(device)
                    label_count_value = [count_values(label_cond_val) for label_cond_val in label_cond_vals]
                    label_cond_vals = label_cond_vals.reshape(-1)
                    # 模型输出
                    out_cond_cols, out_cond_ops, out_cond_vals, out_cond_count = model(input_ids, attention_mask,
                                                                                       token_type_ids)
                    out_cond_cols = out_cond_cols.argmax(dim=2).to(device)
                    out_cond_ops = out_cond_ops.argmax(dim=2).to(device)
                    out_cond_vals = out_cond_vals.argmax(dim=2).to(device)
                    out_cond_count = out_cond_count.argmax(dim=1).to(device)
                    out_cond_cols = out_cond_cols.reshape(-1)
                    out_cond_ops = out_cond_ops.reshape(-1)
                    out_cond_vals = out_cond_vals.reshape(-1)
                    out_all_cond_cols.append(out_cond_cols.cpu().numpy())
                    out_all_cond_ops.append(out_cond_ops.cpu().numpy())
                    out_all_cond_vals.append(out_cond_vals.cpu().numpy())
                    out_all_cond_count.extend(out_cond_count.cpu().numpy())
                    label_all_cond_cols.append(label_cond_cols.cpu().numpy())
                    label_all_cond_ops.append(label_cond_ops.cpu().numpy())
                    label_all_cond_vals.append(label_cond_vals.cpu().numpy())
                    label_all_cond_count.extend(label_count_value)

        if type(model) is ColClassifierModel:
            val_agg_acc = metrics.accuracy_score(np.concatenate(out_all_agg, axis=0),
                                                 np.concatenate(label_all_agg, axis=0))
            val_conn_op_acc = metrics.accuracy_score(np.concatenate(out_all_conn_op, axis=0),
                                                     np.concatenate(label_all_conn_op, axis=0))
            print(f'val_agg_acc: {val_agg_acc}')
            print(f'val_conn_op_acc: {val_conn_op_acc}')
            # 准确率计算逻辑
            val_avg_acc = (val_agg_acc + val_conn_op_acc) / 2
        if type(model) is CondClassifierModel:
            val_cond_cols_acc = metrics.accuracy_score(np.concatenate(out_all_cond_cols, axis=0),
                                                       np.concatenate(label_all_cond_cols, axis=0))
            val_cond_ops_acc = metrics.accuracy_score(np.concatenate(out_all_cond_ops, axis=0),
                                                      np.concatenate(label_all_cond_ops, axis=0))
            val_cond_vals_acc = metrics.accuracy_score(np.concatenate(out_all_cond_vals, axis=0),
                                                       np.concatenate(label_all_cond_vals, axis=0))
            val_cond_count_acc = metrics.accuracy_score(label_all_cond_count, out_all_cond_count)
            print(f'val_cond_cols_acc: {val_cond_cols_acc}')
            print(f'val_cond_ops_acc: {val_cond_ops_acc}')
            print(f'val_cond_vals_acc: {val_cond_vals_acc}')
            print(f'val_cond_count_acc: {val_cond_count_acc}')
            val_avg_acc = (val_cond_cols_acc + val_cond_ops_acc + val_cond_vals_acc) / 3 * 0.1 + val_cond_count_acc * 0.9
            # save model
        if val_avg_acc > best_val_avg_acc:
            best_val_avg_acc = val_avg_acc
            torch.save(model.state_dict(), model_save_path)
            print(f'''best model | Val Accuracy: {best_val_avg_acc: .4f}''')
        print(
            f'''Epochs: {epoch + 1} 
              | Train Loss: {total_loss_train.item(): .4f} 
              | Val Accuracy: {val_avg_acc: .4f}''')


if __name__ == '__main__':
    hidden_size = 768
    batch_size = 12
    learn_rate = 2e-5
    epochs = 50
    question_length = 128
    max_length = 512
    table_path = '/kaggle/input/bert-nl2sql-train-datas/table.xlsx'
    train_data_path = '/kaggle/input/bert-nl2sql-train-datas/train.jsonl'
    pretrain_model_path = '/kaggle/input/bert-nl2sql-chinese-model-hgd'
    save_column_model_path = '/kaggle/working/classifier-column-model.pkl'
    save_value_model_path = '/kaggle/working/classifier-value-model.pkl'
    # 读取列
    columns = get_columns(table_path)
    # 加载数据
    label_datas = read_train_datas(train_data_path, question_length, columns)
    # 提取特征数据
    model_features = InputFeatures(pretrain_model_path, question_length, max_length).list_features(columns, label_datas)
    # 初始化dataset
    model_dateset = Dataset(model_features)
    # 创建模型
    col_model = ColClassifierModel(pretrain_model_path, hidden_size, len(get_agg_dict()), len(get_conn_op_dict()))
    # 分割数据集
    total_size = len(label_datas)
    train_size = int(0.8 * total_size)
    val_size = int(0.1 * total_size)
    test_size = total_size - train_size - val_size
    # 分割数据集
    model_train_dataset, model_val_dataset, model_test_dataset = random_split(model_dateset,
                                                                              [train_size, val_size,
                                                                               test_size])
    print('train column model begin')
    train(col_model, save_column_model_path, model_train_dataset, model_val_dataset, batch_size, learn_rate,
          epochs)
    print('train column model finish')
    cond_model = CondClassifierModel(pretrain_model_path, hidden_size, question_length)
    print('train value model begin')
    train(cond_model, save_value_model_path, model_train_dataset, model_val_dataset, batch_size,
          learn_rate,
          epochs)
    print('train value model finish')

train column model begin


100%|██████████| 62/62 [00:10<00:00,  5.85it/s]


best model | Val Accuracy:  0.7845
Epochs: 1 
              | Train Loss:  0.5195 
              | Val Accuracy:  0.7845


100%|██████████| 62/62 [00:10<00:00,  6.10it/s]


best model | Val Accuracy:  0.9748
Epochs: 2 
              | Train Loss:  0.6749 
              | Val Accuracy:  0.9748


100%|██████████| 62/62 [00:10<00:00,  6.10it/s]


best model | Val Accuracy:  0.9856
Epochs: 3 
              | Train Loss:  0.2378 
              | Val Accuracy:  0.9856


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 4 
              | Train Loss:  0.1318 
              | Val Accuracy:  0.9856


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 5 
              | Train Loss:  0.1655 
              | Val Accuracy:  0.9856


100%|██████████| 62/62 [00:10<00:00,  6.04it/s]


Epochs: 6 
              | Train Loss:  0.1817 
              | Val Accuracy:  0.9856


100%|██████████| 62/62 [00:10<00:00,  6.04it/s]


Epochs: 7 
              | Train Loss:  0.1318 
              | Val Accuracy:  0.9856


100%|██████████| 62/62 [00:10<00:00,  6.04it/s]


Epochs: 8 
              | Train Loss:  0.1322 
              | Val Accuracy:  0.9856


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 9 
              | Train Loss:  0.1191 
              | Val Accuracy:  0.9856


100%|██████████| 62/62 [00:10<00:00,  6.04it/s]


Epochs: 10 
              | Train Loss:  0.1342 
              | Val Accuracy:  0.9856


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 11 
              | Train Loss:  0.1704 
              | Val Accuracy:  0.9856


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 12 
              | Train Loss:  0.1338 
              | Val Accuracy:  0.9856


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


best model | Val Accuracy:  0.9883
Epochs: 13 
              | Train Loss:  0.0967 
              | Val Accuracy:  0.9883


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


Epochs: 14 
              | Train Loss:  0.1007 
              | Val Accuracy:  0.9834


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


best model | Val Accuracy:  0.9896
Epochs: 15 
              | Train Loss:  0.1037 
              | Val Accuracy:  0.9896


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


best model | Val Accuracy:  0.9899
Epochs: 16 
              | Train Loss:  0.1007 
              | Val Accuracy:  0.9899


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


best model | Val Accuracy:  0.9916
Epochs: 17 
              | Train Loss:  0.0615 
              | Val Accuracy:  0.9916


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


best model | Val Accuracy:  0.9935
Epochs: 18 
              | Train Loss:  0.0420 
              | Val Accuracy:  0.9935


100%|██████████| 62/62 [00:10<00:00,  6.02it/s]


Epochs: 19 
              | Train Loss:  0.0425 
              | Val Accuracy:  0.9933


100%|██████████| 62/62 [00:10<00:00,  6.07it/s]


Epochs: 20 
              | Train Loss:  0.0451 
              | Val Accuracy:  0.9895


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


best model | Val Accuracy:  0.9963
Epochs: 21 
              | Train Loss:  0.0373 
              | Val Accuracy:  0.9963


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


best model | Val Accuracy:  0.9967
Epochs: 22 
              | Train Loss:  0.0254 
              | Val Accuracy:  0.9967


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


Epochs: 23 
              | Train Loss:  0.0146 
              | Val Accuracy:  0.9965


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


best model | Val Accuracy:  0.9969
Epochs: 24 
              | Train Loss:  0.0317 
              | Val Accuracy:  0.9969


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


best model | Val Accuracy:  0.9977
Epochs: 25 
              | Train Loss:  0.0428 
              | Val Accuracy:  0.9977


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


best model | Val Accuracy:  0.9984
Epochs: 26 
              | Train Loss:  0.0105 
              | Val Accuracy:  0.9984


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


best model | Val Accuracy:  0.9993
Epochs: 27 
              | Train Loss:  0.0088 
              | Val Accuracy:  0.9993


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


best model | Val Accuracy:  0.9994
Epochs: 28 
              | Train Loss:  0.0045 
              | Val Accuracy:  0.9994


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


Epochs: 29 
              | Train Loss:  0.0124 
              | Val Accuracy:  0.9994


100%|██████████| 62/62 [00:10<00:00,  6.04it/s]


Epochs: 30 
              | Train Loss:  0.0165 
              | Val Accuracy:  0.9993


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 31 
              | Train Loss:  0.0014 
              | Val Accuracy:  0.9994


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 32 
              | Train Loss:  0.0011 
              | Val Accuracy:  0.9994


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 33 
              | Train Loss:  0.0021 
              | Val Accuracy:  0.9993


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 34 
              | Train Loss:  0.0086 
              | Val Accuracy:  0.9994


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 35 
              | Train Loss:  0.0042 
              | Val Accuracy:  0.9994


100%|██████████| 62/62 [00:10<00:00,  6.04it/s]


Epochs: 36 
              | Train Loss:  0.0125 
              | Val Accuracy:  0.9994


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 37 
              | Train Loss:  0.0025 
              | Val Accuracy:  0.9994


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


Epochs: 38 
              | Train Loss:  0.0017 
              | Val Accuracy:  0.9993


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 39 
              | Train Loss:  0.0012 
              | Val Accuracy:  0.9994


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


best model | Val Accuracy:  0.9995
Epochs: 40 
              | Train Loss:  0.0018 
              | Val Accuracy:  0.9995


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 41 
              | Train Loss:  0.0006 
              | Val Accuracy:  0.9995


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


Epochs: 42 
              | Train Loss:  0.0005 
              | Val Accuracy:  0.9995


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


Epochs: 43 
              | Train Loss:  0.0004 
              | Val Accuracy:  0.9995


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 44 
              | Train Loss:  0.0013 
              | Val Accuracy:  0.9994


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 45 
              | Train Loss:  0.0005 
              | Val Accuracy:  0.9993


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


Epochs: 46 
              | Train Loss:  0.0031 
              | Val Accuracy:  0.9994


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 47 
              | Train Loss:  0.0004 
              | Val Accuracy:  0.9994


100%|██████████| 62/62 [00:10<00:00,  6.04it/s]


Epochs: 48 
              | Train Loss:  0.0123 
              | Val Accuracy:  0.9994


100%|██████████| 62/62 [00:10<00:00,  6.05it/s]


Epochs: 49 
              | Train Loss:  0.0006 
              | Val Accuracy:  0.9995


100%|██████████| 62/62 [00:10<00:00,  6.06it/s]


Epochs: 50 
              | Train Loss:  0.0004 
              | Val Accuracy:  0.9995
train column model finish
train value model begin


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9868376358695652
val_cond_ops_acc: 0.9834408967391305
val_cond_vals_acc: 0.9065896739130435
val_cond_count_acc: 0.5652173913043478
best model | Val Accuracy:  0.6046
Epochs: 1 
              | Train Loss:  1.0502 
              | Val Accuracy:  0.6046


100%|██████████| 62/62 [00:13<00:00,  4.61it/s]


val_cond_cols_acc: 0.9868376358695652
val_cond_ops_acc: 0.9868376358695652
val_cond_vals_acc: 0.9065047554347826
val_cond_count_acc: 0.7934782608695652
best model | Val Accuracy:  0.8101
Epochs: 2 
              | Train Loss:  0.6234 
              | Val Accuracy:  0.8101


100%|██████████| 62/62 [00:13<00:00,  4.61it/s]


val_cond_cols_acc: 0.9868376358695652
val_cond_ops_acc: 0.9868376358695652
val_cond_vals_acc: 0.9325747282608695
val_cond_count_acc: 0.967391304347826
best model | Val Accuracy:  0.9675
Epochs: 3 
              | Train Loss:  0.7179 
              | Val Accuracy:  0.9675


100%|██████████| 62/62 [00:13<00:00,  4.63it/s]


val_cond_cols_acc: 0.9868376358695652
val_cond_ops_acc: 0.9868376358695652
val_cond_vals_acc: 0.9412364130434783
val_cond_count_acc: 0.9782608695652174
best model | Val Accuracy:  0.9776
Epochs: 4 
              | Train Loss:  0.0944 
              | Val Accuracy:  0.9776


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9868376358695652
val_cond_ops_acc: 0.9868376358695652
val_cond_vals_acc: 0.9462466032608695
val_cond_count_acc: 0.967391304347826
Epochs: 5 
              | Train Loss:  0.0928 
              | Val Accuracy:  0.9680


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9868376358695652
val_cond_ops_acc: 0.9868376358695652
val_cond_vals_acc: 0.9487092391304348
val_cond_count_acc: 0.967391304347826
Epochs: 6 
              | Train Loss:  0.1434 
              | Val Accuracy:  0.9681


100%|██████████| 62/62 [00:13<00:00,  4.61it/s]


val_cond_cols_acc: 0.9868376358695652
val_cond_ops_acc: 0.9868376358695652
val_cond_vals_acc: 0.9496433423913043
val_cond_count_acc: 0.967391304347826
Epochs: 7 
              | Train Loss:  0.0573 
              | Val Accuracy:  0.9681


100%|██████████| 62/62 [00:13<00:00,  4.63it/s]


val_cond_cols_acc: 0.9868376358695652
val_cond_ops_acc: 0.9868376358695652
val_cond_vals_acc: 0.9500679347826086
val_cond_count_acc: 0.967391304347826
Epochs: 8 
              | Train Loss:  0.0531 
              | Val Accuracy:  0.9681


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9873471467391305
val_cond_ops_acc: 0.9868376358695652
val_cond_vals_acc: 0.9519361413043478
val_cond_count_acc: 0.967391304347826
Epochs: 9 
              | Train Loss:  0.0814 
              | Val Accuracy:  0.9682


100%|██████████| 62/62 [00:13<00:00,  4.63it/s]


val_cond_cols_acc: 0.9920176630434783
val_cond_ops_acc: 0.9880264945652174
val_cond_vals_acc: 0.9546535326086957
val_cond_count_acc: 0.967391304347826
Epochs: 10 
              | Train Loss:  0.0592 
              | Val Accuracy:  0.9685


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9926970108695652
val_cond_ops_acc: 0.9901494565217391
val_cond_vals_acc: 0.9579653532608695
val_cond_count_acc: 0.9782608695652174
best model | Val Accuracy:  0.9785
Epochs: 11 
              | Train Loss:  0.0301 
              | Val Accuracy:  0.9785


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9926120923913043
val_cond_ops_acc: 0.9911684782608695
val_cond_vals_acc: 0.9619565217391305
val_cond_count_acc: 0.967391304347826
Epochs: 12 
              | Train Loss:  0.0511 
              | Val Accuracy:  0.9688


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9932914402173914
val_cond_ops_acc: 0.991593070652174
val_cond_vals_acc: 0.9650135869565217
val_cond_count_acc: 0.967391304347826
Epochs: 13 
              | Train Loss:  0.0180 
              | Val Accuracy:  0.9690


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9938009510869565
val_cond_ops_acc: 0.9918478260869565
val_cond_vals_acc: 0.9675611413043478
val_cond_count_acc: 0.9782608695652174
best model | Val Accuracy:  0.9789
Epochs: 14 
              | Train Loss:  0.0203 
              | Val Accuracy:  0.9789


100%|██████████| 62/62 [00:13<00:00,  4.63it/s]


val_cond_cols_acc: 0.9943953804347826
val_cond_ops_acc: 0.9922724184782609
val_cond_vals_acc: 0.969344429347826
val_cond_count_acc: 0.967391304347826
Epochs: 15 
              | Train Loss:  0.0180 
              | Val Accuracy:  0.9692


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9955842391304348
val_cond_ops_acc: 0.9928668478260869
val_cond_vals_acc: 0.9710427989130435
val_cond_count_acc: 0.967391304347826
Epochs: 16 
              | Train Loss:  0.0164 
              | Val Accuracy:  0.9693


100%|██████████| 62/62 [00:13<00:00,  4.63it/s]


val_cond_cols_acc: 0.9957540760869565
val_cond_ops_acc: 0.9937160326086957
val_cond_vals_acc: 0.9724864130434783
val_cond_count_acc: 0.967391304347826
Epochs: 18 
              | Train Loss:  0.0237 
              | Val Accuracy:  0.9694


100%|██████████| 62/62 [00:13<00:00,  4.65it/s]


val_cond_cols_acc: 0.996688179347826
val_cond_ops_acc: 0.9943953804347826
val_cond_vals_acc: 0.9748641304347826
val_cond_count_acc: 0.9782608695652174
best model | Val Accuracy:  0.9793
Epochs: 19 
              | Train Loss:  0.0147 
              | Val Accuracy:  0.9793


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9966032608695652
val_cond_ops_acc: 0.9942255434782609
val_cond_vals_acc: 0.9762228260869565
val_cond_count_acc: 0.967391304347826
Epochs: 20 
              | Train Loss:  0.0338 
              | Val Accuracy:  0.9696


 85%|████████▌ | 53/62 [00:11<00:01,  4.56it/s]

val_cond_cols_acc: 0.9971127717391305
val_cond_ops_acc: 0.9949898097826086
val_cond_vals_acc: 0.9739300271739131
val_cond_count_acc: 0.9782608695652174
best model | Val Accuracy:  0.9793
Epochs: 21 
              | Train Loss:  0.0329 
              | Val Accuracy:  0.9793


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9969429347826086
val_cond_ops_acc: 0.9946501358695652
val_cond_vals_acc: 0.9738451086956522
val_cond_count_acc: 0.967391304347826
Epochs: 22 
              | Train Loss:  0.1015 
              | Val Accuracy:  0.9695


100%|██████████| 62/62 [00:13<00:00,  4.64it/s]


val_cond_cols_acc: 0.9970278532608695
val_cond_ops_acc: 0.9949898097826086
val_cond_vals_acc: 0.9794497282608695
val_cond_count_acc: 0.967391304347826
Epochs: 23 
              | Train Loss:  0.0127 
              | Val Accuracy:  0.9697


100%|██████████| 62/62 [00:13<00:00,  4.64it/s]


val_cond_cols_acc: 0.9971976902173914
val_cond_ops_acc: 0.9949898097826086
val_cond_vals_acc: 0.9807235054347826
val_cond_count_acc: 0.967391304347826
Epochs: 24 
              | Train Loss:  0.0138 
              | Val Accuracy:  0.9697


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9971976902173914
val_cond_ops_acc: 0.9949898097826086
val_cond_vals_acc: 0.9819123641304348
val_cond_count_acc: 0.967391304347826
Epochs: 25 
              | Train Loss:  0.0123 
              | Val Accuracy:  0.9698


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9972826086956522
val_cond_ops_acc: 0.9949898097826086
val_cond_vals_acc: 0.9831012228260869
val_cond_count_acc: 0.967391304347826
Epochs: 26 
              | Train Loss:  0.0112 
              | Val Accuracy:  0.9698


100%|██████████| 62/62 [00:13<00:00,  4.60it/s]


val_cond_cols_acc: 0.9972826086956522
val_cond_ops_acc: 0.9949898097826086
val_cond_vals_acc: 0.9842900815217391
val_cond_count_acc: 0.967391304347826
Epochs: 27 
              | Train Loss:  0.0134 
              | Val Accuracy:  0.9699


100%|██████████| 62/62 [00:13<00:00,  4.61it/s]


val_cond_cols_acc: 0.9976222826086957
val_cond_ops_acc: 0.9949898097826086
val_cond_vals_acc: 0.9840353260869565
val_cond_count_acc: 0.9782608695652174
best model | Val Accuracy:  0.9797
Epochs: 28 
              | Train Loss:  0.0109 
              | Val Accuracy:  0.9797


100%|██████████| 62/62 [00:13<00:00,  4.60it/s]


val_cond_cols_acc: 0.9972826086956522
val_cond_ops_acc: 0.9949898097826086
val_cond_vals_acc: 0.9866677989130435
val_cond_count_acc: 0.967391304347826
Epochs: 29 
              | Train Loss:  0.0142 
              | Val Accuracy:  0.9700


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9975373641304348
val_cond_ops_acc: 0.9949898097826086
val_cond_vals_acc: 0.9866677989130435
val_cond_count_acc: 0.967391304347826
Epochs: 30 
              | Train Loss:  0.0049 
              | Val Accuracy:  0.9700


100%|██████████| 62/62 [00:13<00:00,  4.58it/s]


val_cond_cols_acc: 0.9976222826086957
val_cond_ops_acc: 0.9949898097826086
val_cond_vals_acc: 0.9885360054347826
val_cond_count_acc: 0.967391304347826
Epochs: 31 
              | Train Loss:  0.0127 
              | Val Accuracy:  0.9700


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9976222826086957
val_cond_ops_acc: 0.9949898097826086
val_cond_vals_acc: 0.9885360054347826
val_cond_count_acc: 0.967391304347826
Epochs: 32 
              | Train Loss:  0.0205 
              | Val Accuracy:  0.9700


100%|██████████| 62/62 [00:13<00:00,  4.61it/s]


val_cond_cols_acc: 0.9977921195652174
val_cond_ops_acc: 0.9949898097826086
val_cond_vals_acc: 0.9893851902173914
val_cond_count_acc: 0.967391304347826
Epochs: 33 
              | Train Loss:  0.1037 
              | Val Accuracy:  0.9701


100%|██████████| 62/62 [00:13<00:00,  4.63it/s]


val_cond_cols_acc: 0.997452445652174
val_cond_ops_acc: 0.9946501358695652
val_cond_vals_acc: 0.9898947010869565
val_cond_count_acc: 0.9782608695652174
best model | Val Accuracy:  0.9798
Epochs: 34 
              | Train Loss:  0.0050 
              | Val Accuracy:  0.9798


100%|██████████| 62/62 [00:13<00:00,  4.61it/s]


val_cond_cols_acc: 0.9973675271739131
val_cond_ops_acc: 0.9950747282608695
val_cond_vals_acc: 0.9903192934782609
val_cond_count_acc: 0.967391304347826
Epochs: 35 
              | Train Loss:  0.0168 
              | Val Accuracy:  0.9701


100%|██████████| 62/62 [00:13<00:00,  4.61it/s]


val_cond_cols_acc: 0.998046875
val_cond_ops_acc: 0.9950747282608695
val_cond_vals_acc: 0.9897248641304348
val_cond_count_acc: 0.9782608695652174
best model | Val Accuracy:  0.9799
Epochs: 36 
              | Train Loss:  0.0115 
              | Val Accuracy:  0.9799


100%|██████████| 62/62 [00:13<00:00,  4.61it/s]


val_cond_cols_acc: 0.9983865489130435
val_cond_ops_acc: 0.9950747282608695
val_cond_vals_acc: 0.9907438858695652
val_cond_count_acc: 0.9782608695652174
best model | Val Accuracy:  0.9799
Epochs: 37 
              | Train Loss:  0.0118 
              | Val Accuracy:  0.9799


100%|██████████| 62/62 [00:13<00:00,  4.61it/s]


val_cond_cols_acc: 0.9979619565217391
val_cond_ops_acc: 0.9951596467391305
val_cond_vals_acc: 0.9912533967391305
val_cond_count_acc: 0.9782608695652174
best model | Val Accuracy:  0.9799
Epochs: 38 
              | Train Loss:  0.0062 
              | Val Accuracy:  0.9799


 87%|████████▋ | 54/62 [00:11<00:01,  4.52it/s]

val_cond_cols_acc: 0.9978770380434783
val_cond_ops_acc: 0.9950747282608695
val_cond_vals_acc: 0.9914232336956522
val_cond_count_acc: 0.967391304347826
Epochs: 39 
              | Train Loss:  0.0130 
              | Val Accuracy:  0.9701


100%|██████████| 62/62 [00:13<00:00,  4.64it/s]


val_cond_cols_acc: 0.9977072010869565
val_cond_ops_acc: 0.9949898097826086
val_cond_vals_acc: 0.9901494565217391
val_cond_count_acc: 0.9782608695652174
Epochs: 40 
              | Train Loss:  0.0062 
              | Val Accuracy:  0.9799


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9979619565217391
val_cond_ops_acc: 0.9946501358695652
val_cond_vals_acc: 0.9833559782608695
val_cond_count_acc: 0.9782608695652174
Epochs: 41 
              | Train Loss:  0.0120 
              | Val Accuracy:  0.9796


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.9978770380434783
val_cond_ops_acc: 0.9950747282608695
val_cond_vals_acc: 0.9904042119565217
val_cond_count_acc: 0.9782608695652174
Epochs: 42 
              | Train Loss:  0.0097 
              | Val Accuracy:  0.9799


100%|██████████| 62/62 [00:13<00:00,  4.63it/s]


val_cond_cols_acc: 0.9978770380434783
val_cond_ops_acc: 0.9952445652173914
val_cond_vals_acc: 0.9912533967391305
val_cond_count_acc: 0.967391304347826
Epochs: 43 
              | Train Loss:  0.0024 
              | Val Accuracy:  0.9701


100%|██████████| 62/62 [00:13<00:00,  4.63it/s]


val_cond_cols_acc: 0.9979619565217391
val_cond_ops_acc: 0.9951596467391305
val_cond_vals_acc: 0.9918478260869565
val_cond_count_acc: 0.9782608695652174
best model | Val Accuracy:  0.9799
Epochs: 44 
              | Train Loss:  0.0055 
              | Val Accuracy:  0.9799


100%|██████████| 62/62 [00:13<00:00,  4.62it/s]


val_cond_cols_acc: 0.998046875
val_cond_ops_acc: 0.994735054347826
val_cond_vals_acc: 0.9922724184782609
val_cond_count_acc: 0.9782608695652174
best model | Val Accuracy:  0.9799
Epochs: 45 
              | Train Loss:  0.0030 
              | Val Accuracy:  0.9799


100%|██████████| 62/62 [00:13<00:00,  4.61it/s]


val_cond_cols_acc: 0.998046875
val_cond_ops_acc: 0.9949048913043478
val_cond_vals_acc: 0.9920176630434783
val_cond_count_acc: 0.9782608695652174
Epochs: 46 
              | Train Loss:  0.0032 
              | Val Accuracy:  0.9799


100%|██████████| 62/62 [00:13<00:00,  4.63it/s]


val_cond_cols_acc: 0.9981317934782609
val_cond_ops_acc: 0.9950747282608695
val_cond_vals_acc: 0.9917629076086957
val_cond_count_acc: 0.9782608695652174
Epochs: 47 
              | Train Loss:  0.0060 
              | Val Accuracy:  0.9799


100%|██████████| 62/62 [00:13<00:00,  4.63it/s]


val_cond_cols_acc: 0.9979619565217391
val_cond_ops_acc: 0.9946501358695652
val_cond_vals_acc: 0.9926970108695652
val_cond_count_acc: 0.9782608695652174
best model | Val Accuracy:  0.9799
Epochs: 48 
              | Train Loss:  0.0294 
              | Val Accuracy:  0.9799


100%|██████████| 62/62 [00:13<00:00,  4.61it/s]


val_cond_cols_acc: 0.9978770380434783
val_cond_ops_acc: 0.9953294836956522
val_cond_vals_acc: 0.9925271739130435
val_cond_count_acc: 0.9782608695652174
best model | Val Accuracy:  0.9800
Epochs: 49 
              | Train Loss:  0.0037 
              | Val Accuracy:  0.9800


100%|██████████| 62/62 [00:13<00:00,  4.64it/s]


val_cond_cols_acc: 0.9979619565217391
val_cond_ops_acc: 0.9948199728260869
val_cond_vals_acc: 0.9921875
val_cond_count_acc: 0.9782608695652174
Epochs: 50 
              | Train Loss:  0.0072 
              | Val Accuracy:  0.9799
train value model finish


In [11]:
#predict
def predict(columns, questions, predict_result_path, pretrain_model_path, column_model_path, value_model_path,
            hidden_size, batch_size, question_length, max_length, table_name='table_name'):
    # 创建模型
    col_model = ColClassifierModel(pretrain_model_path, hidden_size, len(get_agg_dict()), len(get_conn_op_dict()))
    cond_model = CondClassifierModel(pretrain_model_path, hidden_size, question_length)
    # 提取特征数据（不含label的数据）
    input_features = InputFeatures(pretrain_model_path, question_length, max_length).list_features(columns, questions)
    dataset = Dataset(input_features)
    # 预测不用打乱顺序shuffle=False
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # 是否使用gpu
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    if use_cuda:
        col_model = col_model.to(device)
        cond_model = cond_model.to(device)
        col_model.load_state_dict(torch.load(column_model_path, map_location=torch.device(device)))
        cond_model.load_state_dict(torch.load(value_model_path, map_location=torch.device(device)))
    # 预测
    pre_all_agg = []
    pre_all_conn_op = []
    pre_all_cond_cols = []
    pre_all_cond_ops = []
    pre_all_cond_vals = []
    pre_all_cond_counts = []
    for input_ids, attention_mask, token_type_ids, cls_idx in tqdm(dataloader):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        out_agg, out_conn_op = col_model(input_ids, attention_mask, token_type_ids, cls_idx)
        # 取预测结果最大值，torch.argmax找到指定纬度最大值所对应的索引（是索引，不是值）
        pre_agg = torch.argmax(out_agg, dim=2).cpu().numpy()
        pre_conn_op = torch.argmax(out_conn_op, dim=1).cpu().numpy()
        pre_all_agg.extend(pre_agg)
        pre_all_conn_op.extend(pre_conn_op)
    for input_ids, attention_mask, token_type_ids, cls_idx in tqdm(dataloader):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        out_cond_cols, out_cond_ops, out_cond_vals, out_cond_count = cond_model(input_ids, attention_mask,
                                                                                token_type_ids)
        pre_cond_cols = torch.argmax(out_cond_cols, dim=2).cpu().numpy()
        pre_cond_ops = torch.argmax(out_cond_ops, dim=2).cpu().numpy()
        pre_cond_vals = torch.argmax(out_cond_vals, dim=2).cpu().numpy()
        pre_cond_count = torch.argmax(out_cond_count, dim=1).cpu().numpy()
        pre_all_cond_cols.extend(pre_cond_cols)
        pre_all_cond_ops.extend(pre_cond_ops)
        pre_all_cond_vals.extend(pre_cond_vals)
        pre_all_cond_counts.extend(pre_cond_count)

    with open(predict_result_path, 'w', encoding='utf-8') as wf:
        for question, agg, conn_op, cond_cols, cond_ops, cond_vals, cond_counts in zip(questions, pre_all_agg,
                                                                                       pre_all_conn_op,
                                                                                       pre_all_cond_cols,
                                                                                       pre_all_cond_ops,
                                                                                       pre_all_cond_vals,
                                                                                       pre_all_cond_counts):
            sel_col = np.where(np.array(agg) != get_agg_dict()['none'])[0]
            agg = agg[agg != get_agg_dict()['none']]
            cond_col = cond_cols[cond_cols <= len(columns)]
            cond_op = cond_ops[cond_ops != get_cond_op_dict()['none']]
            sel_col_name = [columns[idx_col] for idx_col in sel_col]
            cond_vals_name = get_values_name(question[0], cond_vals)
            print(f'cond_col: {cond_col}')
            print(f'cond_op: {cond_op}')
            print(f'cond_vals_name: {cond_vals_name}')
            print(f'cond_counts: {cond_counts}')
            conds = [[int(cond_col[idx]), int(cond_op[idx]), cond_vals_name[idx]] for
                     idx in range(cond_counts)]
            sql_dict = {"question": question, "table_id": table_name,
                        "sql": {"sel": list(map(int, sel_col)),
                                "agg": list(map(int, agg)),
                                "limit": 0,
                                "orderby": [],
                                "asc_desc": 0,
                                "cond_conn_op": int(conn_op),
                                'conds': conds},
                        "keywords": {"sel_cols": sel_col_name, "values": cond_vals_name}}
            sql_json = json.dumps(sql_dict, ensure_ascii=False)
            wf.write(sql_json + '\n')


if __name__ == '__main__':
    hidden_size = 768
    batch_size = 12
    question_length = 128
    max_length = 512
    table_path = '/kaggle/input/bert-nl2sql-train-datas/table.xlsx'
    predict_question_path = '/kaggle/input/bert-nl2sql-train-datas/train_test.jsonl'
    predict_result_path = '/kaggle/working/predict.jsonl'
    pretrain_model_path = '/kaggle/input/bert-nl2sql-chinese-model-hgd'
    column_model_path = '/kaggle/input/bert-nl2sql-result-model/classifier-column-model.pkl'
    value_model_path = '/kaggle/input/bert-nl2sql-result-model/classifier-value-model.pkl'
    columns = get_columns(table_path)
    questions = read_predict_datas(predict_question_path)
    predict(columns, questions, predict_result_path, pretrain_model_path, column_model_path, value_model_path,
            hidden_size, batch_size, question_length, max_length)

100%|██████████| 1/1 [00:00<00:00, 37.13it/s]
100%|██████████| 1/1 [00:00<00:00, 36.77it/s]

cond_col: [0 3]
cond_op: [4 4]
cond_vals_name: ['2024-01-13', '机组']
cond_counts: 2
cond_col: [0 3]
cond_op: [4 4]
cond_vals_name: ['2024-01-13', '机组#']
cond_counts: 2
cond_col: [0 3]
cond_op: [4 4]
cond_vals_name: ['2024-01-13', '全部机', '的']
cond_counts: 2
cond_col: [0 0 3 3]
cond_op: [4 6 4 4]
cond_vals_name: ['2023-01-03', '2023-01-30', '清能#1机组']
cond_counts: 3
cond_col: [0 0 3 3]
cond_op: [4 6 4]
cond_vals_name: ['2023-01-03', '2', '023-01-30', '机组#1的出']
cond_counts: 3



