# SPARTA (Semantic Parsing And Relational Table Aware)

This is a term project for the `Unstructured Text Analysis` class.  
We implement a deep learning model that converts Korean natural-language questions into SQL queries.

- github: https://github.com/TooTouch/SPARTA

<br>
<img src='https://user-images.githubusercontent.com/37654013/119700897-bec2c100-be8e-11eb-9d61-36de1ca66d5a.png'>
<br>

**Team Members**
- Hoonsang Yoon 
- Jaehyuk Heo 
- Jungwoo Choi
- Jeongseob Kim

**Information**
- Korea University [DSBA Lab](http://dsba.korea.ac.kr/)
- Advisor: [Pilsung Kang](http://dsba.korea.ac.kr/professor/)



In [21]:
%reload_ext autoreload
%autoreload 2

import json
import os
import sys

import numpy as np
import pandas as pd
from konlpy.tag import Mecab
# Put the BRIDGE repo at the FRONT of sys.path so its top-level `src` package is
# resolved before any other `src` on the path (append would place it last, where it
# can be shadowed). Remove earlier copies first so re-running this cell is idempotent.
_bridge_root = '../TabularSemanticParsing'
while _bridge_root in sys.path:
    sys.path.remove(_bridge_root)
sys.path.insert(0, _bridge_root)

In [22]:
# Make sure the BRIDGE repo is searched first. Drop any stale copies (e.g. one
# appended earlier in the notebook, or left over from a previous run of this cell)
# so the path list does not grow on every re-run.
_bridge_root = '../TabularSemanticParsing'
while _bridge_root in sys.path:
    sys.path.remove(_bridge_root)
sys.path.insert(0, _bridge_root)

In [43]:
# Bridge Model Define
from src.semantic_parser.learn_framework import EncoderDecoderLFramework
from src.semantic_parser.ensemble_configs import model_dirs as ensemble_model_dirs
from src.eval.wikisql.lib.dbengine import DBEngine

from src.demos.demos import Text2SQLWrapper
from src.data_processor.path_utils import get_model_dir, get_checkpoint_path
from src.data_processor.schema_graph import SchemaGraph
from src.data_processor.vocab_processor import build_vocab
from src.data_processor.data_processor import preprocess
from src.data_processor.processors.data_processor_wikisql import generate_sql_q, generate_sql_q1
 
import src.utils.utils as utils
import src.eval.eval_tools as eval_tools

import src.data_processor.processor_utils as data_utils
import src.data_processor.data_loader as data_loader
import src.common.ops as ops
import random

import ast

# parser = argparse.ArgumentParser()
# args_bridge = parser.parse_args(args=[])
from src.parse_args import args as args_bridge

# Overrides applied on top of the defaults parsed by `src.parse_args`.
# NOTE(review): several values (768-512-512, bs_8, ppl-0.85.2, 0.0003, 3000, ...)
# also appear in the `model_dir` name below; keep the two in sync when retraining.
_bridge_overrides = dict(
    # --- data / preprocessing / vocabulary ---
    data_dir="../TabularSemanticParsing/data/ko_from_table",
    db_dir="../TabularSemanticParsing/data/ko_from_table",
    dataset_name="wikisql",
    model="bridge",
    model_id=2,
    question_split=True,
    query_split=False,
    question_only=True,
    normalize_variables=False,
    denormalize_sql=True,
    omit_from_clause=True,
    table_shuffling=True,
    use_graph_encoding=False,
    use_typed_field_markers=False,
    use_lstm_encoder=True,
    use_meta_data_encoding=True,
    use_picklist=True,
    no_anchor_text=False,
    anchor_text_match_threshold=0.85,
    top_k_picklist_matches=2,
    atomic_value_copy=False,
    process_sql_in_execution_order=False,
    sql_consistency_check=False,
    share_vocab=False,
    sample_ground_truth=False,
    save_nn_weights_for_visualizations=True,
    vocab_min_freq=0,
    text_vocab_min_freq=0,
    program_vocab_min_freq=0,
    max_in_seq_len=512,
    max_out_seq_len=60,
    # --- training loop ---
    num_steps=10000,
    curriculum_interval=0,
    num_peek_steps=400,
    num_accumulation_steps=3,
    save_best_model_only=True,
    train_batch_size=8,
    dev_batch_size=8,
    # --- architecture sizes ---
    encoder_input_dim=768,
    encoder_hidden_dim=512,
    decoder_input_dim=512,
    num_rnn_layers=1,
    num_const_attn_layers=0,
    # --- schema / data augmentation ---
    use_oracle_tables=False,
    num_random_tables_added=0,
    use_additive_features=False,
    schema_augmentation_factor=1,
    random_field_order=False,
    data_augmentation_factor=1,
    augment_with_wikisql=False,
    num_values_per_field=0,
    # --- pretrained transformer and optimisation ---
    pretrained_transformer="bert-base-multilingual-cased",
    fix_pretrained_transformer_parameters=False,
    bert_finetune_rate=0.00005,
    learning_rate=0.0003,
    learning_rate_scheduler="inverse-square",
    trans_learning_rate_scheduler="inverse-square",
    warmup_init_lr=0.0003,
    warmup_init_ft_lr=0,
    num_warmup_steps=3000,
    # --- dropout, attention heads, gradient clipping ---
    emb_dropout_rate=0.3,
    pretrained_lm_dropout_rate=0,
    rnn_layer_dropout_rate=0.1,
    rnn_weight_dropout_rate=0,
    cross_attn_dropout_rate=0,
    cross_attn_num_heads=8,
    res_input_dropout_rate=0.2,
    res_layer_dropout_rate=0,
    ff_input_dropout_rate=0.4,
    ff_hidden_dropout_rate=0.0,
    grad_norm=0.3,
    # --- decoding / runtime ---
    decoding_algorithm="beam-search",
    beam_size=8,
    bs_alpha=1.0,
    data_parallel=False,
)

# Apply in declaration order, one attribute at a time.
for _attr_name, _attr_value in _bridge_overrides.items():
    setattr(args_bridge, _attr_name, _attr_value)

# Directory of the trained checkpoint loaded below.
args_bridge.model_dir = '../TabularSemanticParsing/model/wikisql.bridge.lstm.meta.ts.ko_from_table.bs_8.ppl-0.85.2.dn.no_from.feat.bert-base-multilingual-cased.xavier-768-512-512-8-3-0.0003-inv-sqr-0.0003-3000-5e-05-inv-sqr-0.0-3000-0.3-0.3-0.0-0.0-1-8-0.1-0.0-res-0.2-0.0-ff-0.4-0.0.test'

# Build the encoder-decoder framework on the GPU, restore the checkpoint that
# `get_checkpoint_path` resolves from `args_bridge`, and switch to eval mode.
sp = EncoderDecoderLFramework(args_bridge).cuda()
sp.load_checkpoint(get_checkpoint_path(args_bridge))
sp.eval()

# Preprocessed data; later cells index it by split name and by 'schema'.
dataset = data_loader.load_processed_data(args_bridge)

# Execution engine over the split's `.db` file; only built for the WikiSQL format,
# every other dataset runs without one.
split = 'test'
engine = None
if args_bridge.dataset_name == 'wikisql':
    engine_path = os.path.join(args_bridge.data_dir, f'{split}.db')
    engine = DBEngine(engine_path)

* text vocab size = 119547
* program vocab size = 99

pretrained_transformer = bert-base-multilingual-cased
fix_pretrained_transformer_parameters = False

share_vocab is False
bridge module created
=> loading checkpoint '../TabularSemanticParsing/model/wikisql.bridge.lstm.meta.ts.ko_from_table.bs_8.ppl-0.85.2.dn.no_from.feat.bert-base-multilingual-cased.xavier-768-512-512-8-3-0.0003-inv-sqr-0.0003-3000-5e-05-inv-sqr-0.0-3000-0.3-0.3-0.0-0.0-1-8-0.1-0.0-res-0.2-0.0-ff-0.4-0.0.test/model-best.8.tar'
loading preprocessed data: ../TabularSemanticParsing/data/ko_from_table/wikisql.bridge.question-split.ppl-0.85.2.dn.no_from.bert.multilingual.TabularSemanticParsing.pkl


# Single-question inference and table lookup

In [57]:
question_idx = 205  # position of the demo example inside the test split of `dataset`

In [58]:
# Show the question text stored for the chosen example (the next cell overwrites it).
# Index by `split` rather than a literal 'test' so this stays consistent with the
# engine and the example-selection cell.
dataset[split][question_idx].text

'얼마나 많은 팀을 뛰었는지 week 6'

In [59]:
# Replace the stored question of example `question_idx` with a hand-written paraphrase.
# NOTE(review): only `.text` is overwritten. Whether `sp.inference` re-tokenizes from
# `.text` or uses ids cached during preprocessing is not visible here. Confirm this,
# otherwise the model may still be answering the original question.
question = 'week 6에 얼마나 많은 팀이 뛰었을까?'
dataset[split][question_idx].text = question

# A one-element slice keeps `examples` a list, which is what `sp.inference` is given later.
examples = dataset[split][question_idx:question_idx + 1]
sp.schema_graphs = dataset['schema']

table_id = examples[0].db_name

# Load every table definition of this split (JSON Lines: one object per line).
# json.loads is the correct parser here. ast.literal_eval rejects JSON's
# true/false/null and mis-decodes escapes such as "\/". The explicit UTF-8 encoding
# keeps Korean headers and cells intact regardless of the platform default.
tables_path = os.path.join(args_bridge.data_dir, '{}.tables.jsonl'.format(split))
with open(tables_path, 'r', encoding='utf-8') as json_file:
    tables = [json.loads(line) for line in json_file if line.strip()]

print('{} {} examples loaded'.format(len(examples), split))

1 test examples loaded


In [60]:
# Load the framework's restored-prediction cache and note how many entries it holds.
pred_restored_cache = sp.load_pred_restored_cache()
pred_restored_cache_size = sum(map(len, pred_restored_cache.values()))

# Decode the selected example(s), evaluating inline against the execution engine.
out_dict = sp.inference(
    examples,
    restore_clause_order=args_bridge.process_sql_in_execution_order,
    pred_restored_cache=pred_restored_cache,
    check_schema_consistency_=args_bridge.sql_consistency_check,
    engine=engine,
    inline_eval=True,
    verbose=True,
)

100%|██████████| 1/1 [00:00<00:00,  3.61it/s]


In [61]:
# Take one decoded prediction for the first (only) example, wrapped in a list
# because the next cells pass lists of predictions and tables.
# NOTE(review): `[-1]` selects the LAST decoded candidate. If `pred_decoded[0]` is a
# best-first beam list, this is the lowest-ranked one. Confirm `[0]` was not intended.
pred_sql = [out_dict['pred_decoded'][0][-1]]

# Definition of the table the question is about, matched by id.
target_table = [table for table in tables if table["id"] == table_id]
# Fail here with a clear message instead of later in generate_sql_q / engine.execute.
assert target_table, 'table {!r} not found in the loaded tables'.format(table_id)

In [62]:
# Render the structured prediction (sel / agg / conds) as a SQL string for display.
# The output is a list of strings. The table id is not quoted, so treat it as
# human-readable only. Execution goes through `engine.execute` below.
generate_sql_q(pred_sql, target_table)

['SELECT count(Opponent) FROM 1-11391448-2 WHERE Week = 6']

In [63]:
# Unpack the structured prediction and execute it against the split's database.
pred_query = pred_sql[0]
sel, agg, conds = (pred_query[field] for field in ('sel', 'agg', 'conds'))

engine.execute(table_id, sel, agg, conds)

[1]