In [1]:
import torch
from sqlnet.utils import *
from sqlnet.model.sqlbert import SQLBert, BertAdam, BertTokenizer
from torch.optim import Adam
from sqlnet.lookahead import Lookahead
import time

import argparse



In [21]:
def _build_arg_parser():
    """Return the parser describing the inference options (flags and paths)."""
    arg_parser = argparse.ArgumentParser()
    # Runtime switches.
    arg_parser.add_argument('--gpu', action='store_true', help='Whether use gpu')
    arg_parser.add_argument('--batch_size', type=int, default=12)
    # Input locations: data directory, pretrained BERT directory and the
    # trained checkpoint that nl2sql.model_init loads with torch.load.
    arg_parser.add_argument('--data_dir', type=str, default='../data/')
    arg_parser.add_argument('--bert_model_dir', type=str,
                            default='../model/chinese-bert_chinese_wwm_pytorch/')
    arg_parser.add_argument('--restore_model_path', type=str,
                            default='../model/saved_bert_model')
    # Output location / evaluation mode.
    arg_parser.add_argument('--result_path', type=str, default='./result.json',
                            help='Output path of prediction result')
    arg_parser.add_argument('--local_eval', action='store_true')
    return arg_parser


parser = _build_arg_parser()
# Parse an empty argument list: inside a notebook sys.argv belongs to the
# kernel, so every option simply takes its default value here.
args = parser.parse_args(args=[])

class nl2sql():
    """Inference wrapper that turns a natural-language question into SQL text.

    The tokenizer and ``SQLBert`` model are loaded once at construction;
    ``infer`` then predicts a structured query for one question/table pair
    and renders it as SQL strings.
    """

    def __init__(self,
                 bert_model_dir=None,
                 best_model_path=None, ):
        """Load the tokenizer and model.

        Args:
            bert_model_dir: directory passed to ``from_pretrained`` for both the
                tokenizer and the base model. Defaults to ``args.bert_model_dir``.
            best_model_path: path of the state dict loaded into the model.
                Defaults to ``args.restore_model_path``.
        """
        # Resolve defaults at call time instead of class-definition time, so
        # re-parsing ``args`` in a later cell takes effect without having to
        # re-run this class definition as well.
        if bert_model_dir is None:
            bert_model_dir = args.bert_model_dir
        if best_model_path is None:
            best_model_path = args.restore_model_path
        self.bert_model_dir = bert_model_dir
        self.best_model_path = best_model_path
        self.model_init()

    def model_init(self, ):
        """Load tokenizer, base model and saved weights; switch to eval mode."""
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_model_dir, do_lower_case=True)
        self.model = SQLBert.from_pretrained(self.bert_model_dir)
        print("Loading from %s" % self.bert_model_dir)
        # map_location='cpu' lets a checkpoint saved from a GPU run load on a
        # CPU-only machine; load_state_dict then copies the tensors into the
        # model's own parameters on whatever device those live.
        state_dict = torch.load(self.best_model_path, map_location='cpu')
        self.model.load_state_dict(state_dict)
        print("Loaded model from %s" % self.best_model_path)
        self.model.eval()

    def text2sql(self, sql_pred, table_id, table):
        """Render a structured prediction as SQL strings.

        Args:
            sql_pred: dict with keys 'agg', 'sel', 'cond_conn_op' and 'conds'.
                'agg' (aggregation ids) and 'sel' (column indices) are parallel
                lists; each entry of 'conds' is [column index, operator id, value].
            table_id: table name placed in the FROM clause.
            table: table dict; only ``table['header']`` (column names) is read.

        Returns:
            A list with one SQL string per (agg, sel) pair, all sharing the same
            WHERE clause. The WHERE clause is omitted when there are no conditions.
        """
        op_sql_dict = {0: ">", 1: "<", 2: "==", 3: "!="}
        agg_sql_dict = {0: "", 1: "AVG", 2: "MAX", 3: "MIN", 4: "COUNT", 5: "SUM"}
        conn_sql_dict = {0: "", 1: "and", 2: "or"}
        headers = table['header']
        agg = sql_pred['agg']
        sel = sql_pred['sel']

        # The WHERE clause does not depend on the selected column: build it once.
        where_parts = []
        for cond in sql_pred['conds']:
            cond_col = headers[cond[0]]
            op = cond[1]
            value = str(cond[2])
            if op == 2 or op == 3:
                # '==' / '!=' compare against a quoted string literal; double any
                # embedded single quote so the literal stays well-formed SQL.
                value = "'" + value.replace("'", "''") + "'"
            where_parts.append(cond_col + ' ' + op_sql_dict[op] + ' ' + value)
        joiner = ' ' + conn_sql_dict[sql_pred['cond_conn_op']] + ' '
        # Without conditions, do not emit a dangling " where " (invalid SQL).
        where_sql = ' where ' + joiner.join(where_parts) if where_parts else ''

        res = []
        for idx, agg_id in enumerate(agg):
            sel_name = headers[sel[idx]]
            agg_name = agg_sql_dict[agg_id]
            # Aggregation id 0 maps to "" meaning "no aggregation".
            select_expr = agg_name + '(' + sel_name + ')' if agg_name != '' else sel_name
            res.append('select ' + select_expr + ' from ' + table_id + where_sql)
        return res

    def predict(self, sql, table):
        '''Predict the structured query for a single question.

        Example inputs:
            sql:   {'question': 'PE2011大于11或者EPS2011大于11的公司有哪些', 'table_id': '69d4941c334311e9aefd542696d6e445'}
            table: dict with keys 'rows', 'name', 'title', 'header', 'common', 'id', 'types'

        Returns:
            The post-processed prediction for ``sql``; ``text2sql`` reads its
            'agg', 'sel', 'cond_conn_op' and 'conds' keys.
        '''
        # The batching helpers expect a list of questions and a dict of tables
        # keyed by table id.
        sql_data = [sql.copy()]
        table_data = {table['id']: table.copy()}

        # With batch_size 1 there is never a partial trailing batch, so
        # len // batch_size batches cover every question (no "+1" needed).
        batch_size = 1
        perm = list(range(len(sql_data)))
        pred_record = []

        for st in tqdm(range(len(sql_data) // batch_size)):
            if st * batch_size == len(perm):
                break
            ed = min((st + 1) * batch_size, len(perm))
            st = st * batch_size
            with torch.no_grad():
                if isinstance(self.model, SQLBert):
                    q_seq, col_seq, col_num, raw_q_seq, table_ids, header_type = to_batch_seq_test(sql_data, table_data, perm, st, ed, tokenizer=self.tokenizer)
                    bert_inputs, q_lens, sel_col_nums, where_col_nums = gen_batch_bert_seq(self.tokenizer, q_seq, col_seq, header_type)
                    score = self.model.forward(bert_inputs, return_logits=False)
                    sql_preds = self.model.gen_query(score, q_seq, col_seq, sql_data, table_data, perm, st, ed)
                else:
                    q_seq, col_seq, col_num, raw_q_seq, table_ids, header_type = to_batch_seq_test(sql_data, table_data, perm, st, ed)
                    score = self.model.forward(q_seq, col_seq, col_num)
                    sql_preds = self.model.gen_query(score, q_seq, col_seq, raw_q_seq)
                sql_preds = post_process(sql_preds, sql_data, table_data, perm, st, ed)
            for sql_pred in sql_preds:
                # NOTE(review): eval() round-trips the prediction through its repr
                # (a deep copy / conversion to plain values). The 'conds' values
                # appear to be spans of the user's question, so prefer
                # ast.literal_eval or copy.deepcopy once it is confirmed the repr
                # only contains literals.
                sql_pred = eval(str(sql_pred))
                pred_record.append(sql_pred)

        return pred_record[0]

    def infer(self, sql, table):
        """Predict ``sql`` (question dict) against ``table`` and render it as SQL.

        Returns:
            List of SQL strings as produced by ``text2sql``.
        """
        sql_pred = self.predict(sql, table)
        return self.text2sql(sql_pred, table['id'], table)
        
        


In [22]:
# Demo table in the shape predict()/text2sql() consume: 'id' keys the table,
# 'header' holds the column names, 'types' one type tag per column and each
# entry of 'rows' one record.
# NOTE(review): missing NAV / 折价率 values are the string 'None' even though
# those columns are typed 'real' — confirm downstream code tolerates this.
table_data = dict(
    rows=[
        ['600340.SH', '华夏幸福', 17.49, 1.54, 2.03, 2.67, 11.36, 8.61, 6.56, 'None', 'None', 5.8, '推荐'],
        ['000402.SZ', '金融街', 6.53, 0.67, 0.78, 0.91, 9.8, 8.41, 7.2, '10.66', '-38.7', 1.0, '谨慎推荐'],
        ['600823.SH', '世茂股份', 11.79, 1.01, 1.13, 1.39, 11.66, 10.4, 8.47, '22.09', '-46.6', 1.3, '无'],
        ['600716.SH', '凤凰股份', 5.54, 0.32, 0.45, 0.66, 17.51, 12.45, 8.39, '7.46', '-25.8', 2.5, '谨慎推荐'],
        ['000608.SZ', '阳光股份', 4.79, 0.23, 0.29, 0.32, 20.76, 16.31, 14.75, '6.71', '-28.7', 1.4, '谨慎推荐'],
        ['002285.SZ', '世联地产', 15.07, 0.48, 0.84, 1.05, 31.27, 18.04, 14.34, 'None', 'None', 3.6, '无'],
    ],
    name='Table_69d4941c334311e9aefd542696d6e445',
    title='66 表3：2012年6月12日非住宅开发重点覆盖公司估值表 ',
    header=['证券代码', '公司名称', '股价', 'EPS2011', 'EPS2012E', 'EPS2013E',
            'PE2011', 'PE2012E', 'PE2013E', 'NAV', '折价率', 'PB2012Q1', '评级'],
    common='资料来源：wind',
    id='69d4941c334311e9aefd542696d6e445',
    types=['text', 'text'] + ['real'] * 10 + ['text'],
)

# The question to translate; 'table_id' must match table_data['id'].
sql_data = dict(question='PE2011大于11或者EPS2011大于11的公司有哪些',
                table_id=table_data['id'])

In [23]:

# Load the tokenizer and saved weights, then translate the demo question; the
# returned list of SQL strings is this cell's displayed output.
ns = nl2sql(bert_model_dir=args.bert_model_dir,
            best_model_path=args.restore_model_path)
ns.infer(sql=sql_data, table=table_data)

Loading from ../model/chinese-bert_chinese_wwm_pytorch/
Loaded model from ../model/saved_bert_model


100%|██████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.45it/s]


['select 公司名称 from 69d4941c334311e9aefd542696d6e445 where EPS2011 > 11 or PE2011 > 11']

In [26]:
# Scratch check of the kernel's working directory (IPython automagic): the
# relative default paths in `args` ('../data/', '../model/...') resolve from here.
pwd

'/home/jasoncheung/project/work/nl2sql/code'

draft.py     result.json         start_train_bert.sh  train_bert.py
infer.ipynb  sqlnet/             test_bert.py         train.ipynb
README.md    start_test_bert.sh  test_ensemble.py     Untitled.ipynb
