In [1]:
%reload_ext autoreload
%autoreload 2

import torch
import json
import pandas as pd
import numpy as np
from IPython.display import display
import os 
import sys
sys.path.append('../')

from train import construct_hyper_param, \
                  get_data, get_models, get_wemb_bert, \
                  sort_and_generate_pr_w, generate_sql_q, \
                  tokenize_corenlp_direct_version
from sqlova.utils.utils_wikisql import *
from sqlova.utils.utils import topk_multi_dim
from sqlnet.dbengine import DBEngine

import corenlp
from konlpy.tag import Mecab

import argparse

In [2]:
parser = argparse.ArgumentParser()
args = construct_hyper_param(parser, notebook=True)

# gpu
# args.gpu = 1

# GPU_NUM = args.gpu # select gpu number
# device = torch.device(f'cuda:{GPU_NUM}' if torch.cuda.is_available() else 'cpu')
# torch.cuda.set_device(device) # change allocation of current GPU
# print ('Current cuda device ', torch.cuda.current_device()) # check

device = 'cpu'

In [3]:
def inference(inputs, sql, datadir, bert_config, model_bert, tokenizer):
    nlu, nlu_t, hds, tb = inputs
    
    
    engine = DBEngine(os.path.join(datadir, 'test.db'))
    
    # prediction
    wemb_n, wemb_h, l_n, l_hpu, l_hs, \
    nlu_tt, t_to_tt_idx, tt_to_t_idx = get_wemb_bert(bert_config, 
                                                     model_bert, 
                                                     tokenizer, 
                                                     nlu_t, 
                                                     hds, 
                                                     args.max_seq_length,
                                                     num_out_layers_n=args.num_target_layers, 
                                                     num_out_layers_h=args.num_target_layers,
                                                     device=device)

    prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(wemb_n, 
                                                                                    l_n, 
                                                                                    wemb_h, 
                                                                                    l_hpu,
                                                                                    l_hs, 
                                                                                    engine, 
                                                                                    tb,
                                                                                    nlu_t, 
                                                                                    nlu_tt,
                                                                                    tt_to_t_idx, 
                                                                                    nlu,
                                                                                    beam_size=args.beam_size,
                                                                                    device=device)

    pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)

    pr_sql_q1 = generate_sql_q(pr_sql_i, tb)
    pr_ans, _ = engine.execute_return_query(tb[0]['id'], pr_sc[0], pr_sa[0], pr_sql_i[0]['conds'])
    
    # ground truth
    g_sql_q1 = generate_sql_q(sql, tb)

    g_ans, _ = engine.execute_return_query(tb[0]['id'], 
                                           sql[0]['sel'], 
                                           sql[0]['agg'], 
                                           sql[0]['conds'])
    
    # print results
    print()
    print('='*30)
    print('Logical Form')
    print('='*30)
    print('[PRED]: ',pr_sql_q1[0])
    print('[TRUE]: ',g_sql_q1[0])
    print()
    print('='*30)
    print('Execution')
    print('='*30)
    print('[PRED]: ',pr_ans[0])
    print('[TRUE]: ',g_ans[0])
    

# Testset Performance

In [4]:
lognames = ['ko_token','ko_from_table']
result_lst = []

for name in lognames:
    file = json.load(open(f'../logs/{name}/dev_performance.json','r'))
    result_lst.append(file)

In [5]:
pd.DataFrame(np.array([list(result_lst[i].values()) for i in range(len(lognames))]), 
             columns=result_lst[0].keys(),
             index=lognames).round(3)

Unnamed: 0,loss,acc_sc,acc_sa,acc_wn,acc_wc,acc_wo,acc_wvi,acc_wv,acc_lx,acc_x
ko_token,0.0,0.898,0.87,0.919,0.862,0.878,0.0,0.872,0.702,0.781
ko_from_table,0.0,0.869,0.868,0.913,0.866,0.88,0.0,0.874,0.702,0.773


# Load data & model

In [4]:
args.datadir = '../data/ko_token'
args.logdir = '../logs/ko_token'
args.bert_name = 'bert-base-multilingual-cased'

In [5]:
train_data, train_table, dev_data, dev_table, test_data, test_table, _, _, test_loader = get_data(args.datadir, args)


# To start from the pre-trained models, un-comment following lines.
path_model_bert = os.path.join(args.logdir, 'model_bert_best.pt')
path_model = os.path.join(args.logdir, 'model_best.pt')
model, model_bert, tokenizer, bert_config = get_models(args, 
                                                       trained=True,
                                                       path_model_bert=path_model_bert, 
                                                       path_model=path_model, 
                                                       device=device)

BERT: pretrained bert-base-multilingual-cased
BERT: learning rate: 1e-05
BERT: Fine-tune BERT: False
Seq-to-SQL: the number of final BERT layers to be used: 2
Seq-to-SQL: the size of hidden dimension = 100
Seq-to-SQL: LSTM encoding layer size = 2
Seq-to-SQL: dropout rate = 0.3
Seq-to-SQL: learning rate = 0.001


In [8]:
idx = 1039

tb = [train_table[train_data[idx]['table_id']]]
hds = [tb[0]['header']]
nlu_t = [train_data[idx]['question_tok']]
nlu = [train_data[idx]['question']]
sql = [train_data[idx]['sql']]

display(pd.DataFrame(tb[0]['rows'], columns=tb[0]['header']))

print('Question: ',nlu[0])

Unnamed: 0,No. in series,No. in season,Title,Directed by,Written by,Featured character(s),Original air date,U.S. viewers (million)
0,104/105,1/2,""" LA X """,Jack Bender,Damon Lindelof & Carlton Cuse,Various,"February2,2010",12.09
1,106,3,""" What Kate Does """,Paul Edwards,Edward Kitsis & Adam Horowitz,Kate,"February9,2010",11.05
2,107,4,""" The Substitute """,Tucker Gates,Elizabeth Sarnoff & Melinda Hsu Taylor,Locke,"February16,2010",9.82
3,108,5,""" Lighthouse """,Jack Bender,Carlton Cuse & Damon Lindelof,Jack,"February23,2010",9.95
4,109,6,""" Sundown """,Bobby Roth,Paul Zbyszewski & Graham Roland,Sayid,"March2,2010",9.29
5,110,7,""" Dr. Linus """,Mario Van Peebles,Edward Kitsis & Adam Horowitz,Ben,"March9,2010",9.49
6,111,8,""" Recon """,Jack Bender,Elizabeth Sarnoff & Jim Galasso,Sawyer,"March16,2010",8.87
7,112,9,""" Ab Aeterno """,Tucker Gates,Melinda Hsu Taylor & Greggory Nations,Richard,"March23,2010",9.31
8,113,10,""" The Package """,Paul Edwards,Paul Zbyszewski & Graham Roland,Sun & Jin,"March30,2010",10.13
9,114,11,""" Happily Ever After """,Jack Bender,Carlton Cuse & Damon Lindelof,Desmond,"April6,2010",9.55


Question:  에피소드에서 Jack & Locke이 나오는 에피소드 번호는 무엇인가요?


# Inference

In [9]:
idx = 100

tb = [test_table[test_data[idx]['table_id']]]
hds = [tb[0]['header']]
nlu_t = [test_data[idx]['question_tok']]
nlu = [test_data[idx]['question']]
sql = [test_data[idx]['sql']]

display(pd.DataFrame(tb[0]['rows'], columns=tb[0]['header']))

print('Question: ',nlu[0])

Unnamed: 0,Pick #,CFL Team,Player,Position,College
0,25,Hamilton Tiger-Cats (via Montreal via Hamilton),Robert Pavlovic,TE,South Carolina
1,26,Edmonton Eskimos,Micheal Jean-louis,DL,Laval
2,27,Edmonton Eskimos (via Winnipeg),Calvin McCarty,RB,Western Washington
3,28,Saskatchewan Roughriders,Ryan Ackerman,OL,Regina
4,29,Toronto Argonauts,Eric Maranda,LB,Laval
5,30,Toronto Argonauts (via Calgary),Steve Schmidt,TE,San Diego State
6,31,Montreal Alouettes,James Judges,DE,Buffalo


Question:  Calvin McCarty은 무슨 college에서 뛰었나요?


In [10]:
inference(inputs=[nlu, nlu_t, hds, tb], 
          sql=sql, 
          datadir=args.datadir, 
          bert_config=bert_config, 
          model_bert=model_bert, 
          tokenizer=tokenizer)


Logical Form
[PRED]:  SELECT (College) FROM 1-10812403-4 WHERE Player = calvin mccarty
[TRUE]:  SELECT (College) FROM 1-10812403-4 WHERE Player = Calvin McCarty

Execution
[PRED]:  western washington
[TRUE]:  western washington


# Step by Step

In [6]:
args.datadir = '../data/ko_token'
args.logdir = '../logs/ko_token'
args.bert_name = 'bert-base-multilingual-cased'
args.bS = 4

In [7]:
train_data, train_table, dev_data, dev_table, test_data, test_table, _, _, test_loader = get_data(args.datadir, args)


# To start from the pre-trained models, un-comment following lines.
path_model_bert = os.path.join(args.logdir, 'model_bert_best.pt')
path_model = os.path.join(args.logdir, 'model_best.pt')
model, model_bert, tokenizer, bert_config = get_models(args, 
                                                       trained=True,
                                                       path_model_bert=path_model_bert, 
                                                       path_model=path_model, 
                                                       device=device)

BERT: pretrained bert-base-multilingual-cased
BERT: learning rate: 1e-05
BERT: Fine-tune BERT: False
Seq-to-SQL: the number of final BERT layers to be used: 2
Seq-to-SQL: the size of hidden dimension = 100
Seq-to-SQL: LSTM encoding layer size = 2
Seq-to-SQL: dropout rate = 0.3
Seq-to-SQL: learning rate = 0.001


In [8]:
model.eval()
model_bert.eval()
print()




# Top K

In [76]:
def beam_search(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, engine, tb,
                nlu_t, nlu_wp_t, wp_to_wh_index, nlu,
                beam_size=4):
    # sc
    s_sc = model.scp(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_sc=False)
    prob_sc = F.softmax(s_sc, dim=-1)
    bS, mcL = s_sc.shape

    # minimum_hs_length = min(l_hs)
    # beam_size = minimum_hs_length if beam_size > minimum_hs_length else beam_size

    # sa
    if mcL < beam_size:
        beam_size_mcL = mcL
    else:
        beam_size_mcL = beam_size
        
    # Construct all possible sc_sa_score
    prob_sc_sa = torch.zeros([bS, beam_size_mcL, model.n_agg_ops]).to(device)
    prob_sca = torch.zeros_like(prob_sc_sa).to(device)

    # get the top-k indices.  pr_sc_beam = [B, beam_size]
    
    pr_sc_beam = pred_sc_beam(s_sc, beam_size_mcL)

    # calculate and predict s_sa.
    for i_beam in range(beam_size_mcL):
        pr_sc = list( array(pr_sc_beam)[:,i_beam] )
        s_sa = model.sap(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, pr_sc, show_p_sa=False)
        prob_sa = F.softmax(s_sa, dim=-1)
        prob_sc_sa[:, i_beam, :] = prob_sa

        prob_sc_selected = prob_sc[range(bS), pr_sc] # [B]
        prob_sca[:,i_beam,:] =  (prob_sa.t() * prob_sc_selected).t()
        # [mcL, B] * [B] -> [mcL, B] (element-wise multiplication)
        # [mcL, B] -> [B, mcL]

    # Calculate the dimension of tensor
    # tot_dim = len(prob_sca.shape)

    # First flatten to 1-d
    if np.prod(prob_sca.shape[1:]) < beam_size:
        beam_size_sca = np.prod(prob_sca.shape[1:])
    else:
        beam_size_sca = beam_size
    idxs_s, values_s = topk_multi_dim(torch.tensor(prob_sca), n_topk=beam_size_sca, batch_exist=True)
    # Now as sc_idx is already sorted, re-map them properly.

    idxs_s = remap_sc_idx(idxs_s, pr_sc_beam) # [sc_beam_idx, sa_idx] -> [sc_idx, sa_idx]
    idxs_arr = array(idxs_s)
    # [B, beam_size, remainig dim]
    # idxs[b][0] gives first probable [sc_idx, sa_idx] pairs.
    # idxs[b][1] gives of second.

    # Calculate prob_sca, a joint probability
    beam_idx_sca = [0] * bS
    beam_meet_the_final = [False] * bS

    while True:
        pr_sc = idxs_arr[range(bS),beam_idx_sca,0]
        pr_sa = idxs_arr[range(bS),beam_idx_sca,1]

        # map index properly
        check = check_sc_sa_pairs(tb, pr_sc, pr_sa)

        if sum(check) == bS:
            break
        else:
            for b, check1 in enumerate(check):
                if not check1: # wrong pair
                    beam_idx_sca[b] += 1
                    if beam_idx_sca[b] >= beam_size_sca:
                        beam_meet_the_final[b] = True
                        beam_idx_sca[b] -= 1
                else:
                    beam_meet_the_final[b] = True

        if sum(beam_meet_the_final) == bS:
            break


    # Now pr_sc, pr_sa are properly predicted.
    pr_sc_best = list(pr_sc)
    pr_sa_best = list(pr_sa)

    # Now, Where-clause beam search.
    s_wn = model.wnp(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_wn=False)
    prob_wn = F.softmax(s_wn, dim=-1).detach().to('cpu').numpy()

    # Found "executable" most likely 4(=max_num_of_conditions) where-clauses.
    # wc
    s_wc = model.wcp(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, show_p_wc=False, penalty=True)
    prob_wc = F.sigmoid(s_wc).detach().to('cpu').numpy()
    # pr_wc_sorted_by_prob = pred_wc_sorted_by_prob(s_wc)

    # get max_wn # of most probable columns & their prob.
    pr_wn_max = [model.max_wn]*bS
    pr_wc_max = pred_wc(pr_wn_max, s_wc) # if some column do not have executable where-claouse, omit that column
    prob_wc_max = zeros([bS, model.max_wn])
    for b, pr_wc_max1 in enumerate(pr_wc_max):
        prob_wc_max[b,:] = prob_wc[b,pr_wc_max1]

    # get most probable max_wn where-clouses
    # wo
    s_wo_max = model.wop(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, wn=pr_wn_max, wc=pr_wc_max, show_p_wo=False)
    prob_wo_max = F.softmax(s_wo_max, dim=-1).detach().to('cpu').numpy()
    # [B, max_wn, n_cond_op]

    pr_wvi_beam_op_list = []
    prob_wvi_beam_op_list = []
    for i_op  in range(model.n_cond_ops-1):
        pr_wo_temp = [ [i_op]*model.max_wn ]*bS
        # wv
        s_wv = model.wvp(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, wn=pr_wn_max, wc=pr_wc_max, wo=pr_wo_temp, show_p_wv=False)
        prob_wv = F.softmax(s_wv, dim=-2).detach().to('cpu').numpy()

        # prob_wv
        pr_wvi_beam, prob_wvi_beam = pred_wvi_se_beam(model.max_wn, s_wv, beam_size)
        pr_wvi_beam_op_list.append(pr_wvi_beam)
        prob_wvi_beam_op_list.append(prob_wvi_beam)
        # pr_wvi_beam = [B, max_wn, k_logit**2 [st, ed] paris]

        # pred_wv_beam

    # Calculate joint probability of where-clause
    # prob_w = [batch, wc, wo, wv] = [B, max_wn, n_cond_op, n_pairs]
    n_wv_beam_pairs = prob_wvi_beam.shape[2]
    prob_w = zeros([bS, model.max_wn, model.n_cond_ops-1, n_wv_beam_pairs])
    for b in range(bS):
        for i_wn in range(model.max_wn):
            for i_op in range(model.n_cond_ops-1): # do not use final one
                for i_wv_beam in range(n_wv_beam_pairs):
                    # i_wc = pr_wc_max[b][i_wn] # already done
                    p_wc = prob_wc_max[b, i_wn]
                    p_wo = prob_wo_max[b, i_wn, i_op]
                    p_wv = prob_wvi_beam_op_list[i_op][b, i_wn, i_wv_beam]

                    prob_w[b, i_wn, i_op, i_wv_beam] = p_wc * p_wo * p_wv

    # Perform execution guided decoding
    conds_max = []
    prob_conds_max = []

    if 4 < beam_size:
        beam_size_w = 4
    else:
        beam_size_w = beam_size
        
    idxs_w, values_w = topk_multi_dim(torch.tensor(prob_w), n_topk=beam_size_w, batch_exist=True)
    # idxs = [B, i_wc_beam, i_op, i_wv_pairs]

    # Construct conds1
    for b, idxs1 in enumerate(idxs_w):
        conds_max1 = []
        prob_conds_max1 = []
        for i_wn, idxs11 in enumerate(idxs1):
            i_wc = pr_wc_max[b][idxs11[0]]
            i_op = idxs11[1]
            wvi = pr_wvi_beam_op_list[i_op][b][idxs11[0]][idxs11[2]]

            # get wv_str
            temp_pr_wv_str, _ = convert_pr_wvi_to_string([[wvi]], [nlu_t[b]], [nlu_wp_t[b]], [wp_to_wh_index[b]], [nlu[b]])
            merged_wv11 = merge_wv_t1_eng(temp_pr_wv_str[0][0], nlu[b])
            conds11 = [i_wc, i_op, merged_wv11]

            prob_conds11 = prob_w[b, idxs11[0], idxs11[1], idxs11[2] ]

            # test execution
            pr_ans = engine.execute(tb[b]['id'], pr_sc[b], pr_sa[b], [conds11])
            if bool(pr_ans):
                # pr_ans is not empty!
                conds_max1.append(conds11)
                prob_conds_max1.append(prob_conds11)

        conds_max.append(conds_max1)
        prob_conds_max.append(prob_conds_max1)

        # May need to do more exhuastive search?
        # i.e. up to.. getting all executable cases.

    # Calculate total probability to decide the number of where-clauses
    pr_sql_i = []
    prob_wn_w = []
    pr_wn_based_on_prob = []     

    for b, prob_wn1 in enumerate(prob_wn):
        max_executable_wn1 = len( conds_max[b] )

        prob_wn_w1 = []
        prob_wn_w1.append(prob_wn1[0])  # wn=0 case.
        for i_wn in range(max_executable_wn1):
            prob_wn_w11 = prob_wn1[i_wn+1] * prob_conds_max[b][i_wn]
            prob_wn_w1.append(prob_wn_w11)
        pr_wn_based_on_prob.append(argmax(prob_wn_w1))
        prob_wn_w.append(prob_wn_w1)

        pr_sql_i1 = {'agg': pr_sa_best[b], 'sel': pr_sc_best[b], 'conds': conds_max[b][:pr_wn_based_on_prob[b]]}
        pr_sql_i.append(pr_sql_i1)


    #=================
    # top k 
    #=================
    
    beam_size_sca = idxs_arr.shape[1]
    pr_sc_topk = np.zeros((bS, beam_size_sca), dtype=np.int)
    for i in range(beam_size_sca):
        pr_sc_i = idxs_arr[range(bS),i,0]
        pr_sa_i = idxs_arr[range(bS),i,1]

        # map index properly

        check = check_sc_sa_pairs(tb, pr_sc_i, pr_sa_i)

        for b, check_b in enumerate(check):
            if check_b: # wrong pair
                pr_sc_topk[b][i] += 1

    values_s = values_s * pr_sc_topk

    total_prob = np.zeros((bS, beam_size_sca, 5))

    for b, prob_wn1 in enumerate(prob_wn):
        max_executable_wn1 = len( conds_max[b] )
        prob_wn_w1 = []
        prob_wn_w1.append(prob_wn1[0])  # wn=0 case.

        for sc_sa_idx, i_sc_sa in enumerate(values_s[b]):
            total_prob[b][sc_sa_idx][0] = i_sc_sa * prob_wn1[0] # wn=0 case.

            for w_idx, i_wn in enumerate(range(max_executable_wn1)):

                prob_wn_w11 = prob_wn1[i_wn+1] * prob_conds_max[b][i_wn]

                total_prob[b][sc_sa_idx][i_wn+1] = i_sc_sa * prob_wn_w11

    if np.prod(total_prob.shape[1:]) < beam_size:
        beam_size_total = np.prod(total_prob.shape[1:])
    else:
        beam_size_total = beam_size
    idx_total, values_total = topk_multi_dim(torch.tensor(total_prob), n_topk=beam_size_total, batch_exist=True)

    # top k sql
    pr_sql_topk = []
    for b, idx in enumerate(idx_total):
        pr_sql_topk_i = []
        for k, idx_k in enumerate(idx):
            pr_sql_k = {'agg': idxs_arr[b][idx_k[0]][1], 
                        'sel': idxs_arr[b][idx_k[0]][0], 
                        'conds': conds_max[b][:idx_k[1]]}

            pr_sql_topk_i.append(pr_sql_k)

        pr_sql_topk.append(pr_sql_topk_i)
        
    return pr_sql_topk
    

In [77]:

def topk_acc(bS, tb, topk_x_acc, topk_lx_acc, 
             g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, sql_i, pr_sql_topk_i):
    
    cnt_x = np.zeros(bS, dtype=np.int)
    cnt_lx = np.zeros(bS, dtype=np.int)

    for i in range(beam_size):
        pr_sql_b = np.array(pr_sql_topk_i)[:,i]

        pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_b)
        pr_sc = [pr_sql['sel'] for pr_sql in pr_sql_b]
        pr_sa = [pr_sql['agg'] for pr_sql in pr_sql_b]
        pr_wn = [len(pr_sql['conds']) for pr_sql in pr_sql_b]

        # where value index is None
        g_wvi = None
        pr_wvi = None

        cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, \
        cnt_wc1_list, cnt_wo1_list, \
        cnt_wvi1_list, cnt_wv1_list = get_cnt_sw_list(g_sc, g_sa, g_wn, g_wc, g_wo, g_wvi,
                                                      pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi,
                                                      sql_i, pr_sql_b,
                                                      mode='test')

        # Logical Form
        cnt_lx1_list = get_cnt_lx_list(cnt_sc1_list, cnt_sa1_list, cnt_wn1_list, cnt_wc1_list,
                                       cnt_wo1_list, cnt_wv1_list)
        cnt_lx += cnt_lx1_list

        cnt_lx_k = sum(cnt_lx > 0)

        if f'Top-{i+1} lx' in topk_lx_acc.keys():
            topk_lx_acc[f'Top-{i+1} lx'] += cnt_lx_k

        # Execution 
        cnt_x1_list, _, _ = get_cnt_x_list(engine, tb, sql_i, pr_sql_b)
        cnt_x += cnt_x1_list

        cnt_x_k = sum(cnt_x > 0)

        if f'Top-{i+1} x' in topk_x_acc.keys():
            topk_x_acc[f'Top-{i+1} x'] += cnt_x_k
            
    return topk_lx_acc, topk_x_acc

In [78]:
def inference_topk(engine, t, data_table, beam_size, k_lst, topk_lx_acc, topk_x_acc):

    # get fields
    nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, data_table, no_hs_t=True, no_sql_t=True)


    # prediction
    wemb_n, wemb_hpu, l_n, l_hpu, l_hs, \
    nlu_wp_t, t_to_tt_idx, wp_to_wh_index = get_wemb_bert(bert_config, 
                                                         model_bert, 
                                                         tokenizer, 
                                                         nlu_t, 
                                                         hds, 
                                                         args.max_seq_length,
                                                         num_out_layers_n=args.num_target_layers, 
                                                         num_out_layers_h=args.num_target_layers,
                                                         device=device)

    # beam search
    pr_sql_topk_i = beam_search(wemb_n, l_n, wemb_hpu, l_hpu, l_hs, engine, tb,
                                nlu_t, nlu_wp_t, wp_to_wh_index, nlu,
                                beam_size=beam_size)

    
    

    g_sc, g_sa, g_wn, g_wc, g_wo, g_wv = get_g(sql_i)
    
    topk_lx_acc, topk_x_acc = topk_acc(len(t), tb, topk_x_acc, topk_lx_acc, 
                                       g_sc, g_sa, g_wn, g_wc, g_wo, g_wv, sql_i, pr_sql_topk_i)

    
    return topk_lx_acc, topk_x_acc

In [79]:
engine = DBEngine(os.path.join(args.datadir, 'test.db'))
data_table = test_table

beam_size = 64
k_lst = [1,2,3,4,5,10]

# top k
topk_x_acc = dict([(f'Top-{k} x',0) for k in k_lst])
topk_lx_acc = dict([(f'Top-{k} lx',0) for k in k_lst])

In [80]:
for i in range(2):
    t = test_loader.dataset[i:i+4]
    
    topk_lx_acc, topk_x_acc = inference_topk(engine, t, data_table, beam_size, k_lst, 
                                             topk_lx_acc, topk_x_acc)
    

In [81]:
topk_lx_acc

{'Top-1 lx': 7,
 'Top-2 lx': 7,
 'Top-3 lx': 7,
 'Top-4 lx': 7,
 'Top-5 lx': 7,
 'Top-10 lx': 8}

In [82]:
topk_x_acc

{'Top-1 x': 7,
 'Top-2 x': 7,
 'Top-3 x': 7,
 'Top-4 x': 7,
 'Top-5 x': 7,
 'Top-10 x': 8}